45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
import json
|
|
import time
|
|
from tqdm import tqdm
|
|
from hezar.models import Model
|
|
from flair.data import Sentence
|
|
from flair.models import SequenceTagger
|
|
|
|
# Load both POS taggers once, at import time, so every call to the helper
# functions below reuses the same model objects.
# NOTE(review): "persain" is reproduced exactly as written in the model
# identifier; confirm against the model hub before "correcting" the spelling.
tagger_flair = SequenceTagger.load("hamedkhaledi/persain-flair-upos")
tagger_hezar = Model.load("hezarai/distilbert-fa-pos-lscp-500k")
|
|
|
|
|
|
def get_POS_flair(input):
    """Tag ``input`` with the Flair tagger and return the labels as one string.

    Args:
        input: Text handed to ``Sentence``. The parameter name shadows the
            builtin but is kept so existing callers are unaffected.

    Returns:
        The ``shortstring`` of every label returned by
        ``sentence.get_labels()``, each one followed by a comma (so a
        non-empty result ends with ``,``). An empty string when the
        sentence has no labels.
    """
    sentence = Sentence(input)
    # predict() annotates the sentence in place; the labels are read back below.
    tagger_flair.predict(sentence)
    # Every piece keeps its own trailing comma so the output format is
    # unchanged from the previous loop-based concatenation.
    return "".join(label.shortstring + "," for label in sentence.get_labels())
|
|
|
|
def get_POS_hezar(input):
    """Run the Hezar tagger on ``input`` and return its prediction as text.

    The text is wrapped in a one-element list before being passed to
    ``tagger_hezar.predict``; whatever that call returns is rendered with
    default string formatting and returned as a ``str``.
    """
    predictions = tagger_hezar.predict([input])
    # format(obj) with no spec is what an f-string placeholder evaluates to.
    return format(predictions)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the dataset, attach the output of both taggers
    # to every item, write the enriched dataset to a new file, and report the
    # elapsed wall-clock time.
    print('start')
    start_time = time.time()

    # The context manager closes the file even if json.load raises.
    with open('./main_classes_dataset_03.json', "r", encoding='utf-8') as inputfile:
        data = json.load(inputfile)

    # NOTE(review): assumes `data` maps each class key to a list of dicts that
    # each carry a 'content' entry -- confirm against the dataset file.
    for c in tqdm(data):
        for item in tqdm(data[c]):
            content = item['content']
            item['POS_flair'] = get_POS_flair(content)
            item['POS_hezar'] = get_POS_hezar(content)

    # Serialize before opening the output so a serialization error cannot
    # leave a truncated or empty output file behind.
    serialized = json.dumps(data, ensure_ascii=False, indent = 4)
    with open('./main_classes_dataset_POS_03.json', "w", encoding='utf-8') as outputfile:
        outputfile.write(serialized)

    end_time = time.time()
    print(f"elapsed time: {end_time-start_time}")
    print("end")