add sentences and some edits
This commit is contained in:
parent 0402193403
commit 46ca9a6b50

embedder.py | 15
@@ -119,7 +119,7 @@ class PersianVectorAnalyzer:
             for sent in temp_sentences:
                 sent_len = len(self.tokenize_sentence(sent))
                 if sent_len > 512:
-                    temp_sentences_2 = str(sentence).split('،')
+                    temp_sentences_2 = str(sent).split('،')
                     for snt in temp_sentences_2:
                         sentences.append(snt)
                 else:
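
The fix above resolves a variable-name bug: the Persian-comma fallback split was applied to `sentence` (a name from an enclosing scope) instead of the current loop variable `sent`, so over-long sentences were never split on the text they actually belonged to. A minimal sketch of the surrounding logic as it presumably works after the fix; the method name `split_long_sentences` is hypothetical, while `tokenize_sentence` and the 512-token limit come from the diff:

    def split_long_sentences(self, temp_sentences):
        # Sketch only: keep sentences that fit the model's 512-token limit,
        # and split the over-long ones on the Persian comma '،'.
        sentences = []
        for sent in temp_sentences:
            sent_len = len(self.tokenize_sentence(sent))
            if sent_len > 512:
                for snt in str(sent).split('،'):
                    sentences.append(snt)
            else:
                sentences.append(sent)
        return sentences
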
@@ -149,7 +149,8 @@ class PersianVectorAnalyzer:
 
         sentences = []
         if isinstance(data, list):
-            for item in data:
+            for index, item in enumerate(data):
+                print(f'split sentence {index}')
                 if isinstance(item, dict):
                     # Extract sentences from different possible keys
                     for key in ['persian_translate']:
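
This hunk only adds a per-record progress line while enumerating the input. For orientation, a hedged sketch of the extraction loop it sits in, assuming the input is a list of dicts whose Persian text lives under the 'persian_translate' key (the only key the diff shows being read); `collect_sentences` is a hypothetical name:

    def collect_sentences(data):
        # Sketch under the assumed input shape [{'persian_translate': '...'}, ...].
        sentences = []
        if isinstance(data, list):
            for index, item in enumerate(data):
                print(f'split sentence {index}')  # progress line added in this commit
                if isinstance(item, dict):
                    for key in ['persian_translate']:
                        if item.get(key):
                            sentences.append(str(item[key]))
        return sentences
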
@@ -203,11 +204,11 @@ class PersianVectorAnalyzer:
 
         try:
             tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            print(self.model_name)
+            # print(self.model_name)
             tokens = tokenizer.tokenize(sentence)
             return tokens
         except:
-            error = "An exception occurred in tokenizer : " + model_checkpoint
+            error = "An exception occurred in tokenizer : " + self.model_name
             #file.write( error + '\n' )
             return []
 
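
The second change in this hunk fixes an undefined name: the `except` branch referenced `model_checkpoint`, which does not exist in this scope, so any tokenizer failure would itself raise a NameError. A compact sketch of the method after the fix; the narrowed `except Exception` clause is an editorial suggestion rather than what the file contains:

    from transformers import AutoTokenizer

    def tokenize_sentence(self, sentence):
        """Tokenize one sentence with the configured model; return [] on failure."""
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            return tokenizer.tokenize(sentence)
        except Exception:
            print("An exception occurred in tokenizer : " + self.model_name)
            return []

Note that loading the tokenizer inside the method means AutoTokenizer.from_pretrained runs once per sentence; caching the tokenizer on the instance would avoid the repeated loads.
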
@@ -508,7 +509,7 @@ class PersianVectorAnalyzer:
         for s in sentences:
             s_len = len(self.tokenize_sentence(s))
             if s_len > 512:
-                print(s)
+                print(f'long: {s}')
         # Step 2: Extract words
         # all_words = self.extract_words(sentences)
 
@@ -523,7 +524,7 @@ class PersianVectorAnalyzer:
         sentences_vectors = self.compute_word_vectors(sentences)
 
         # Step 6: Save word vectors
-        self.save_json(sentences_vectors, f"{output_dir}/sentences_vector.json")
+        self.save_json(sentences_vectors, f"{output_dir}/speech-sentences-vector.json")
 
         # Step 7: Find closest words to key words
         # selected_words = self.find_closest_words(word_vectors, self.key_words)
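
The vectors file is renamed from sentences_vector.json to speech-sentences-vector.json; the retriever script further down in this diff updates its DATA_PATH to match. save_json itself is not shown; a minimal sketch of such a helper, assuming it simply serialises the dict to disk:

    import json
    import os

    def save_json(self, data, path):
        # ensure_ascii=False keeps the Persian text human-readable in the file
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
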
@@ -564,7 +565,7 @@ def main():
     analyzer = PersianVectorAnalyzer()
 
     # Define input and output paths
-    input_file = "./out/nahj_speeches.json"
+    input_file = "./output/nahj_speeches.json"
     output_dir = "output-speechs"
 
     # Run the complete pipeline
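
main() now reads the speeches from ./output/ instead of the non-existent ./out/ while still writing results to output-speechs. The body of the pipeline is outside this hunk; the sketch below shows the call sequence implied by the step comments, with `load_json` and `split_sentences` as assumed helper names (only `compute_word_vectors` and `save_json` appear in the diff):

    def main():
        analyzer = PersianVectorAnalyzer()

        input_file = "./output/nahj_speeches.json"
        output_dir = "output-speechs"

        data = analyzer.load_json(input_file)             # assumed helper
        sentences = analyzer.split_sentences(data)        # assumed helper
        vectors = analyzer.compute_word_vectors(sentences)
        analyzer.save_json(vectors, f"{output_dir}/speech-sentences-vector.json")

The remaining hunks belong to the second changed file, the hybrid retrieval/reranking script; its name is not visible in this excerpt.
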
@@ -28,7 +28,7 @@ from sklearn.metrics.pairwise import cosine_similarity
 # -------------------
 EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 RERANKER_MODEL = "BAAI/bge-reranker-v2-m3"
-DATA_PATH = "./output/sentences_vector.json"
+DATA_PATH = "./output-speechs/speech-sentences-vector.json"
 
 
 def load_dataset(path: str) -> Tuple[List[str], np.ndarray]:
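
DATA_PATH now points at the speech-sentence vectors produced by embedder.py. load_dataset is only shown by its signature; a sketch consistent with that signature, assuming the JSON maps each sentence string to its embedding vector:

    import json
    from typing import List, Tuple

    import numpy as np

    def load_dataset(path: str) -> Tuple[List[str], np.ndarray]:
        # Assumed layout: {"sentence text": [0.12, -0.03, ...], ...}
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        sentences = list(data.keys())
        emb_matrix = np.asarray([data[s] for s in sentences], dtype=np.float32)
        return sentences, emb_matrix
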
@@ -154,7 +154,7 @@ class HybridRetrieverReranker:
 
 
 def main():
-    query = "انسان در فتنه ها باید چگونه عملی کند؟"
+    query = "افراد کوتاه قد چه ویژگی هایی دارند؟"
     sentences, emb_matrix = load_dataset(DATA_PATH)
 
     pipe = HybridRetrieverReranker(sentences, emb_matrix, dense_alpha=0.6)
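
The test query changes from "انسان در فتنه ها باید چگونه عملی کند؟" ("How should a person act in times of turmoil?") to "افراد کوتاه قد چه ویژگی هایی دارند؟" ("What characteristics do short people have?"). The dense_alpha=0.6 argument suggests a weighted blend of dense (embedding) and lexical scores before reranking; a rough sketch of that blending idea, not the class's actual implementation:

    import numpy as np

    def blend_scores(dense_scores: np.ndarray,
                     sparse_scores: np.ndarray,
                     dense_alpha: float = 0.6) -> np.ndarray:
        # Min-max normalise each signal so the two score ranges are comparable,
        # then weight the dense side by dense_alpha (assumed interpretation).
        def norm(x):
            span = x.max() - x.min()
            return (x - x.min()) / span if span > 0 else np.zeros_like(x)
        return dense_alpha * norm(dense_scores) + (1.0 - dense_alpha) * norm(sparse_scores)
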
@@ -163,6 +163,8 @@ def main():
     print("\nTop results:")
     for i, r in enumerate(results, 1):
         print(f"{i}. [score={r['rerank_score']:.4f}] {r['sentence']}")
+        print("--"*100)
+    print("--"*100)
 
 
 if __name__ == "__main__":
@@ -170,18 +172,18 @@ if __name__ == "__main__":
     start = datetime.datetime.now()
     main()
     time2 = datetime.datetime.now()
-    print(time2 - start)
+    print(f'p1: {time2 - start}')
 
     main()
     time3 = datetime.datetime.now()
-    print(time3 - time2)
+    print(f'p2: {time3 - time2}')
 
     main()
     time4 = datetime.datetime.now()
-    print(time4 - time3)
+    print(f'p3: {time4 - time3}')
 
     main()
     time5 = datetime.datetime.now()
-    print(time5 - time4)
+    print(f'p4: {time5 - time4}')
 
     pass
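
Labelling the four timing prints p1 through p4 makes it possible to separate the first run, which pays the model download and load cost, from the warm runs that follow. The same measurement can be written as a loop; a sketch assuming main() is defined as above:

    import datetime

    if __name__ == "__main__":
        prev = datetime.datetime.now()
        for i in range(1, 5):
            main()
            now = datetime.datetime.now()
            print(f'p{i}: {now - prev}')
            prev = now
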
output-speechs/speech-sentences-vector.json | 362939 (new file)
File diff suppressed because it is too large