diff --git a/embedder.py b/embedder.py
index e2ea749..6db15a8 100644
--- a/embedder.py
+++ b/embedder.py
@@ -27,6 +27,7 @@ from pathlib import Path
 
 # NLP and ML libraries
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer
 from sklearn.decomposition import PCA
 from sklearn.manifold import TSNE
 from sklearn.metrics.pairwise import cosine_similarity
@@ -108,7 +109,24 @@ class PersianVectorAnalyzer:
         except Exception as e:
             logger.error(f"Error loading model: {e}")
             raise
-
+    def split_sentence(self, sentence: str) -> List[str]:
+        """Split a sentence that exceeds the model's 512-token limit into shorter pieces."""
+        sentences = []
+        sentence_len = len(self.tokenize_sentence(sentence))
+        if sentence_len < 512:
+            sentences.append(sentence)
+        else:
+            temp_sentences = str(sentence).split('.')
+            for sent in temp_sentences:
+                sent_len = len(self.tokenize_sentence(sent))
+                if sent_len > 512:
+                    temp_sentences_2 = str(sent).split('،')
+                    for snt in temp_sentences_2:
+                        sentences.append(snt)
+                else:
+                    sentences.append(sent)
+        return sentences
+
     def load_json_data(self, file_path: str) -> List[str]:
         """
         Load Persian sentences from JSON file.
@@ -136,11 +154,11 @@ class PersianVectorAnalyzer:
                     # Extract sentences from different possible keys
                     for key in ['persian_translate']:
                         if key in item and item[key]:
-                            splited_sentences = str(item[key]).split('.')
+                            splited_sentences = self.split_sentence(item[key])
                             for sent in splited_sentences:
                                 sentences.append(sent)
                 elif isinstance(item, str):
-                    splited_sentences = str(item).split('.')
+                    splited_sentences = self.split_sentence(item)
                     for sent in splited_sentences:
                         sentences.append(sent)
             elif isinstance(data, dict):
@@ -181,6 +199,18 @@ class PersianVectorAnalyzer:
 
         return text.strip()
 
+    def tokenize_sentence(self, sentence: str) -> List[str]:
+        """Tokenize a sentence with the model's tokenizer, loading it once and caching it."""
+        try:
+            if not hasattr(self, 'tokenizer'):
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            tokens = self.tokenizer.tokenize(sentence)
+            return tokens
+        except Exception as e:
+            error = f"An exception occurred in tokenizer for {self.model_name}: {e}"
+            logger.error(error)
+            return []
+
     def extract_words(self, sentences: List[str]) -> List[str]:
         """
         Extract all words from sentences.
@@ -464,8 +494,8 @@ class PersianVectorAnalyzer:
         Run the complete processing pipeline.
 
         Args:
-            input_file: Path to input JSON file
-            output_dir: Output directory for results
+            input_file (str): Path to input JSON file
+            output_dir (str): Output directory for results
         """
         # Create output directory
         Path(output_dir).mkdir(exist_ok=True)
@@ -475,6 +505,10 @@ class PersianVectorAnalyzer:
 
         # Step 1: Load data
        sentences = self.load_json_data(input_file)
+        for s in sentences:
+            s_len = len(self.tokenize_sentence(s))
+            if s_len > 512:
+                logger.warning(f"Sentence still exceeds 512 tokens: {s}")
 
         # Step 2: Extract words
         # all_words = self.extract_words(sentences)
@@ -530,8 +564,8 @@ def main():
     analyzer = PersianVectorAnalyzer()
 
     # Define input and output paths
-    input_file = "./data/final_wisdom.json"
-    output_dir = "output"
+    input_file = "./out/nahj_speeches.json"
+    output_dir = "output-speeches"
 
     # Run the complete pipeline
     analyzer.process_pipeline(input_file, output_dir)
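
For review, a minimal sketch (not part of the patch) of how the new splitting helpers are expected to be used. It assumes `embedder.py` is importable from the working directory and constructs the analyzer with no arguments, as `main()` does; the sample sentence is a placeholder, not real data.

```python
# Sketch only: exercises the split_sentence / tokenize_sentence helpers added in this patch.
from embedder import PersianVectorAnalyzer

analyzer = PersianVectorAnalyzer()

long_sentence = "..."  # hypothetical long Persian sentence

# Sentences over 512 tokens come back split on '.' and, where a piece is
# still too long, further split on the Persian comma '،'.
for chunk in analyzer.split_sentence(long_sentence):
    n_tokens = len(analyzer.tokenize_sentence(chunk))
    print(n_tokens, chunk[:40])
```

Note that a comma-split piece can itself still exceed 512 tokens; the sanity-check loop added in `process_pipeline` is what surfaces those leftovers.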