add files for extract pos tags
This commit is contained in:
parent
8d7a6cd922
commit
3b3a333978
32
hazm_parsivar.py
Normal file
32
hazm_parsivar.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import parsivar
|
||||
import hazm
|
||||
import json
|
||||
|
||||
with open("main_classes_dataset_POS_03.json", "r") as f:
|
||||
qs = json.load(f)
|
||||
|
||||
parsivar_normalizer = parsivar.Normalizer()
|
||||
parsivar_tokenizer = parsivar.Tokenizer()
|
||||
parsivar_tagger = parsivar.POSTagger(
|
||||
tagging_model="wapiti"
|
||||
) # tagging_model = "wapiti" or "stanford". "wapiti" is faster than "stanford"
|
||||
|
||||
hazm_normalizer = hazm.Normalizer()
|
||||
hazm_tagger = hazm.POSTagger(model="pos_tagger.model", universal_tag=False)
|
||||
|
||||
for cls, cls_list in qs.items():
|
||||
for i, q in enumerate(cls_list):
|
||||
text = q["content"]
|
||||
|
||||
qs[cls][i]["parsivar"] = str(
|
||||
parsivar_tagger.parse(
|
||||
parsivar_tokenizer.tokenize_words(parsivar_normalizer.normalize(text))
|
||||
)
|
||||
)
|
||||
|
||||
qs[cls][i]["hazm"] = str(
|
||||
hazm_tagger.tag(tokens=hazm.word_tokenize(hazm_normalizer.normalize(text)))
|
||||
)
|
||||
|
||||
with open("hazm_parsivar_added.json", "w", encoding="utf-8") as f:
|
||||
json.dump(qs, f, indent=4, ensure_ascii=False)
|
45
khaledi_hezar.py
Normal file
45
khaledi_hezar.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from hezar.models import Model
|
||||
from flair.data import Sentence
|
||||
from flair.models import SequenceTagger
|
||||
|
||||
# load tagger
|
||||
tagger_flair = SequenceTagger.load("hamedkhaledi/persain-flair-upos")
|
||||
tagger_hezar = Model.load("hezarai/distilbert-fa-pos-lscp-500k")
|
||||
|
||||
|
||||
def get_POS_flair(input):
|
||||
sentence = Sentence(input)
|
||||
tagger_flair.predict(sentence)
|
||||
#return sentence.to_tagged_string()
|
||||
tg = sentence.get_labels()
|
||||
str = ''
|
||||
for s in tg:
|
||||
str += s.shortstring
|
||||
str += ','
|
||||
return str
|
||||
|
||||
def get_POS_hezar(input):
|
||||
tg = tagger_hezar.predict([input])
|
||||
return f"{tg}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
print('start')
|
||||
start_time = time.time()
|
||||
inputfile = open('./main_classes_dataset_03.json', "r", encoding='utf-8')
|
||||
data = json.load(inputfile)
|
||||
inputfile.close()
|
||||
for c in tqdm(data):
|
||||
for item in tqdm(data[c]):
|
||||
content = item['content']
|
||||
item['POS_flair'] = get_POS_flair(content)
|
||||
item['POS_hezar'] = get_POS_hezar(content)
|
||||
|
||||
outputfile = open('./main_classes_dataset_POS_03.json', "w", encoding='utf-8')
|
||||
outputfile.write(json.dumps(data, ensure_ascii=False, indent = 4))
|
||||
outputfile.close()
|
||||
end_time = time.time()
|
||||
print(f"elapsed time: {end_time-start_time}")
|
||||
print("end")
|
Loading…
Reference in New Issue
Block a user