commit a991f030e4 (parent 18eb11dea7)

3-NLP_services/.gitignore (vendored): 228 deleted lines
@@ -1,228 +0,0 @@
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

*.jsonl
*.7zip
*.7z
.idea/
fastapi_flair/idea/

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

data/ARMAN-MSR-persian-base-PN-summary
data/mT5_multilingual_XLSum
data/ner-model.pt
data/pos-model.pt
data/pos-model.7z

@@ -1,34 +0,0 @@
FROM python:3.9-slim

# Maintainer info
LABEL maintainer="khaledihkh@gmail.com"

# Make the working directory
RUN mkdir -p /service
WORKDIR /service

# Upgrade pip with no cache
RUN pip install --no-cache-dir -U pip

# Copy the application requirements file to the working directory
COPY requirements.txt .

# Install application dependencies from the requirements file
RUN pip install -r requirements.txt

# Copy every file in the source folder to the working directory
COPY . .
RUN pip install ./src/piraye
RUN pip install "transformers[sentencepiece]"

ENV NER_MODEL_NAME ""
ENV NORMALIZE "False"
ENV SHORT_MODEL_NAME ""
ENV LONG_MODEL_NAME ""
ENV POS_MODEL_NAME ""
ENV BATCH_SIZE "4"
ENV DEVICE "gpu"

EXPOSE 80
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]

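For reference, a sketch of building and running this image; the tag `services` matches the docker-compose file later in this commit. Note that `main.py` in this commit hard-codes its configuration and leaves the `os.getenv` block commented out, so the `-e` overrides below only take effect if that block is re-enabled.

```shell
# build the image from this Dockerfile
docker build -t services .

# run it, publishing the container port and overriding the baked-in defaults
docker run -p 8080:80 \
    -e NER_MODEL_NAME="default" \
    -e DEVICE="cpu" \
    services
```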
@@ -1,18 +0,0 @@
# Services
## Requirements
````shell
pip install -r requirements.txt
````
## Download Models
Download the models and place them in the `data` folder:
https://www.mediafire.com/folder/tz3t9c9rpf6fo/models

## Getting started

````shell
cd fastapi_flair
uvicorn main:app --reload
````

## Documentation
[Open Documentation](./docs/docs.md)

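Once the server is up, it can be exercised with a couple of requests. This is a sketch: the port is uvicorn's default (8000), the `Client-Id`/`Client-Password` headers are the ones `main.py` checks, and the credentials must match an entry in the clients file shown below.

```shell
# health check (no credentials required)
curl http://127.0.0.1:8000/

# queue an NER job; the response carries the id to poll later
curl -X POST http://127.0.0.1:8000/ner \
     -H "Content-Type: application/json" \
     -H "Client-Id: hassan" \
     -H "Client-Password: 1234" \
     -d '[{"lang": "fa", "text": "..."}]'
```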
@@ -1,10 +0,0 @@
[
  {
    "client-id": "hassan",
    "client-password": "1234"
  },
  {
    "client-id": "ali",
    "client-password": "12345"
  }
]

@@ -1,8 +0,0 @@
version: '3.0'

services:
  web:
    image: "services"
    # uvicorn defaults to port 8000, which matches the container port mapped below
    command: uvicorn main:app --host 0.0.0.0
    ports:
      - "8008:8000"

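A minimal sketch of bringing this stack up, assuming the image has been built and tagged `services` as in the Dockerfile example above; the API is then reachable on port 8008 of the host:

```shell
docker build -t services .
docker compose up -d
curl http://localhost:8008/
```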
@@ -1,10 +0,0 @@
### Libraries:

1. Uvicorn: an ASGI web server implementation for Python. ASGI, the Asynchronous Server Gateway Interface, is a calling convention that lets web servers forward requests to async-capable Python frameworks and applications; it was built as the successor to the Web Server Gateway Interface (WSGI). Until recently Python lacked a minimal low-level server/application interface for async frameworks; the ASGI specification fills this gap, which means a common set of tooling usable across all async frameworks can now be built.
2. FastAPI: a modern, fast (high-performance) web framework for building APIs with Python 3.7+ based on standard Python type hints. It is easy to learn, fast to code, and ready for production.
3. NLTK: a leading platform for building Python programs that work with human language data. It provides easy-to-use interfaces to over 50 corpora and lexical resources such as WordNet, along with a suite of text-processing libraries for classification, tokenization, stemming, tagging, parsing, and semantic reasoning, wrappers for industrial-strength NLP libraries, and an active discussion forum.
4. Piraye: NLP utils; a utility for normalizing Persian, Arabic, and English texts.
5. ...

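To make the Uvicorn + FastAPI pairing concrete, here is a minimal, self-contained sketch of an ASGI app served by Uvicorn; it mirrors the root endpoint of this service but is otherwise independent of the code in this commit:

```python
# minimal_app.py: a tiny FastAPI app served by Uvicorn (illustrative only)
import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/")
async def root():
    # FastAPI converts the returned dict into a JSON response
    return {"message": "Hello World"}


if __name__ == "__main__":
    # run the ASGI app with Uvicorn on port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)
```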
@@ -1,199 +0,0 @@
from __future__ import annotations

import torch
import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from starlette.status import HTTP_201_CREATED

from src.database_handler import Database
from src.requests_data import Requests
from model.request_models import InputNerType, InputSummaryType, InputNREType
from model.response_models import BaseResponse, ResponseNerModel, ResponseSummaryModel, ResponseOIEModel, \
    ResponseNREModel
from src.response_ner import NerResponse, ResponsesNer
from src.response_nre import NREResponse, ResponsesNRE
from src.response_oie import OIEResponse, ResponsesOIE
from src.response_summary import ShortResponse, LongResponse, ResponsesSummary, NormalizerTexts

print("torch version : ", torch.__version__)
app = FastAPI()

# Reading the configuration from the environment (as the Dockerfile intends)
# is currently disabled; the hard-coded values below are used instead.
# ner_model_name = os.getenv("NER_MODEL_NAME")
# normalize = os.getenv("NORMALIZE")
# short_model_name = os.getenv("SHORT_MODEL_NAME")
# long_model_name = os.getenv("LONG_MODEL_NAME")
# pos_model_name = os.getenv("POS_MODEL_NAME")
# batch_size = int(os.getenv("BATCH_SIZE"))
# device = os.getenv("DEVICE")
#
# print(ner_model_name, normalize, short_model_name, long_model_name, pos_model_name, batch_size, device)

# initial docker parameters; the later assignments override the earlier ones:
# "default" selects a model's bundled weights, "" disables its endpoint
pos_model_name = "./data/pos-model.pt"
ner_model_name = "./data/ner-model.pt"
short_model_name = "csebuetnlp/mT5_multilingual_XLSum"
long_model_name = "alireza7/ARMAN-MSR-persian-base-PN-summary"
# oie_model_name = "default"
oie_model_name = ""
ner_model_name = "default"
pos_model_name = "default"
nre_model_name = ""

bert_oie = './data/mbert-base-parsinlu-entailment'
short_model_name = ""
long_model_name = ""
normalize = "False"
batch_size = 4
device = "cpu"

if ner_model_name == "":
    ner_response = None
else:
    if ner_model_name.lower() == "default":
        ner_model_name = "./data/ner-model.pt"
    ner_response = NerResponse(ner_model_name, batch_size=batch_size, device=device)

if oie_model_name == "":
    oie_response = None
else:
    if oie_model_name.lower() == "default":
        oie_model_name = "./data/mbertEntail-2GBEn.bin"
    oie_response = OIEResponse(oie_model_name, BERT=bert_oie, batch_size=batch_size, device=device)

if nre_model_name == "":
    nre_response = None
else:
    if nre_model_name.lower() == "default":
        nre_model_name = "./data/model.ckpt.pth.tar"
    nre_response = NREResponse(nre_model_name)

if normalize.lower() == "false" or pos_model_name == "":
    short_response = ShortResponse(normalizer_input=None, device=device)
    long_response = LongResponse(normalizer_input=None, device=device)
else:
    if pos_model_name.lower() == "default":
        pos_model_name = "./data/pos-model.pt"
    normalizer_pos = NormalizerTexts(pos_model_name)
    short_response = ShortResponse(normalizer_input=normalizer_pos, device=device)
    long_response = LongResponse(normalizer_input=normalizer_pos, device=device)
if short_model_name != "":
    short_response.load_model(short_model_name)
if long_model_name != "":
    long_response.load_model(long_model_name)

requests: Requests = Requests()
database: Database = Database()


@app.get("/")
async def root():
    return {"message": "Hello World"}


# Note: FastAPI binds each route through the decorator, so reusing the function
# name `read_item` (and `oie` below) works, although distinct names would be clearer.
@app.get("/ner/{item_id}", response_model=ResponseNerModel)
async def read_item(item_id, request: Request) -> ResponseNerModel:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and ner_response:
        if not ResponsesNer.check_id(int(item_id)):
            raise HTTPException(status_code=404, detail="Item not found")
        return ResponsesNer.get_response(int(item_id))
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")


@app.get("/oie/{item_id}", response_model=ResponseOIEModel)
async def read_item(item_id, request: Request) -> ResponseOIEModel:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and oie_response:
        if not ResponsesOIE.check_id(int(item_id)):
            raise HTTPException(status_code=404, detail="Item not found")
        return ResponsesOIE.get_response(int(item_id))
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")


@app.get("/nre/{item_id}", response_model=ResponseNREModel)
async def read_item(item_id, request: Request) -> ResponseNREModel:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and nre_response:
        if not ResponsesNRE.check_id(int(item_id)):
            raise HTTPException(status_code=404, detail="Item not found")
        return ResponsesNRE.get_response(int(item_id))
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")


@app.get("/summary/{item_id}", response_model=ResponseSummaryModel)
async def read_item(item_id, request: Request) -> ResponseSummaryModel:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')):
        if not ResponsesSummary.check_id(int(item_id)):
            raise HTTPException(status_code=404, detail="Item not found")
        return ResponsesSummary.get_response(int(item_id))
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")


@app.post("/ner", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def ner(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and ner_model_name != "":
        request_id = requests.add_request_ner(input_json)
        background_tasks.add_task(ner_response.predict_json, input_json, request_id)
        return BaseResponse(status="input received successfully", id=request_id)
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")
    # return ner_response.predict_json(input_json)


@app.post("/oie", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and oie_model_name != "":
        request_id = requests.add_request_oie(input_json)
        background_tasks.add_task(oie_response.predict_json, input_json, request_id)
        return BaseResponse(status="input received successfully", id=request_id)
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")
    # return ner_response.predict_json(input_json)


@app.post("/nre", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNREType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and nre_model_name != "":
        request_id = requests.add_request_nre(input_json)
        background_tasks.add_task(nre_response.predict_json, input_json, request_id)
        return BaseResponse(status="input received successfully", id=request_id)
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")
    # return ner_response.predict_json(input_json)


@app.post("/summary/short", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def short_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                        request: Request) -> BaseResponse:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and short_model_name != "":
        request_id = requests.add_request_summary(input_json)
        background_tasks.add_task(short_response.get_result, input_json, request_id)
        return BaseResponse(status="input received successfully", id=request_id)
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")
    # return ner_response.predict_json(input_json)


@app.post("/summary/long", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def long_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
                       request: Request) -> BaseResponse:
    if database.check_client(client_id=request.headers.get('Client-Id'),
                             client_password=request.headers.get('Client-Password')) and long_model_name != "":
        request_id = requests.add_request_summary(input_json)
        background_tasks.add_task(long_response.get_result, input_json, request_id)
        return BaseResponse(status="input received successfully", id=request_id)
    else:
        raise HTTPException(status_code=403, detail="you are not allowed")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

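Because the POST endpoints only enqueue a background task and return an id, a client has to poll the matching GET endpoint for the result. A sketch of that flow with the `requests` library (the host, credentials, and input text are assumptions; the polling condition relies on `result` being `Optional` in the response models below):

```python
# poll_ner.py: submit an NER job and poll for its result (illustrative sketch)
import time

import requests

BASE = "http://127.0.0.1:8000"
HEADERS = {"Client-Id": "hassan", "Client-Password": "1234"}  # from the clients file

# 1) enqueue the job; the service answers 201 with {"id": ..., "status": ...}
payload = [{"lang": "fa", "text": "some input text"}]
job = requests.post(f"{BASE}/ner", json=payload, headers=HEADERS).json()
job_id = job["id"]

# 2) poll until the background task has filled in a result
while True:
    resp = requests.get(f"{BASE}/ner/{job_id}", headers=HEADERS).json()
    if resp.get("result") is not None:  # ResponseNerModel.result is Optional
        print(resp["result"])
        break
    time.sleep(1)
```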
@@ -1,28 +0,0 @@
from typing import List

from pydantic import BaseModel


class InputElement(BaseModel):
    lang: str
    text: str


class NREHType(BaseModel):
    pos: str


class NRETType(BaseModel):
    pos: str


class InputElementNRE(BaseModel):
    text: str
    h: NREHType
    t: NRETType


InputNerType = List[InputElement]
InputOIEType = List[InputElement]
InputNREType = List[InputElementNRE]

InputSummaryType = List[str]

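Concretely, `/ner` and `/oie` accept a JSON array of `{lang, text}` objects and the summary endpoints accept a plain JSON array of strings, while `/nre` expects the shape sketched below (the text and `pos` values are placeholders; the models above only constrain them to be strings):

```json
[
  {
    "text": "a sample sentence for relation extraction",
    "h": {"pos": "head entity"},
    "t": {"pos": "tail entity"}
  }
]
```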
@@ -1,43 +0,0 @@
from typing import List, Optional, Tuple

from pydantic import BaseModel


class EntityNerResponseModel(BaseModel):
    entity_group: str
    word: str
    start: int
    end: int
    score: float


class EntityOIEResponseModel(BaseModel):
    score: str
    relation: str
    arg1: str
    arg2: str


class BaseResponse(BaseModel):
    id: int
    status: str


class ResponseNerModel(BaseModel):
    progression: str
    result: Optional[List[List[EntityNerResponseModel]]]


class ResponseSummaryModel(BaseModel):
    progression: str
    result: Optional[List[str]]


class ResponseOIEModel(BaseModel):
    progression: str
    result: Optional[List[EntityOIEResponseModel]]


class ResponseNREModel(BaseModel):
    progression: str
    result: Optional[List[Tuple[str, float]]]

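As an invented illustration (the actual `progression` strings are produced by the response classes under `src/`, which are not part of this diff), a completed `ResponseNerModel` would serialize roughly as:

```json
{
  "progression": "done",
  "result": [
    [
      {"entity_group": "PER", "word": "Obama", "start": 0, "end": 5, "score": 0.98}
    ]
  ]
}
```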
@@ -1,672 +0,0 @@
{
  "openapi": "3.0.2",
  "info": {"title": "FastAPI", "version": "0.1.0"},
  "paths": {
    "/": {
      "get": {
        "summary": "Root",
        "operationId": "root__get",
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {}}}
          }
        }
      }
    },
    "/ner/{item_id}": {
      "get": {
        "summary": "Read Item",
        "operationId": "read_item_ner__item_id__get",
        "parameters": [
          {"required": true, "schema": {"title": "Item Id"}, "name": "item_id", "in": "path"}
        ],
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ResponseNerModel"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/oie/{item_id}": {
      "get": {
        "summary": "Read Item",
        "operationId": "read_item_oie__item_id__get",
        "parameters": [
          {"required": true, "schema": {"title": "Item Id"}, "name": "item_id", "in": "path"}
        ],
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ResponseOIEModel"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/nre/{item_id}": {
      "get": {
        "summary": "Read Item",
        "operationId": "read_item_nre__item_id__get",
        "parameters": [
          {"required": true, "schema": {"title": "Item Id"}, "name": "item_id", "in": "path"}
        ],
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ResponseNREModel"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/summary/{item_id}": {
      "get": {
        "summary": "Read Item",
        "operationId": "read_item_summary__item_id__get",
        "parameters": [
          {"required": true, "schema": {"title": "Item Id"}, "name": "item_id", "in": "path"}
        ],
        "responses": {
          "200": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ResponseSummaryModel"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/ner": {
      "post": {
        "summary": "Ner",
        "operationId": "ner_ner_post",
        "requestBody": {
          "content": {
            "application/json": {
              "schema": {
                "title": "Input Json",
                "type": "array",
                "items": {"$ref": "#/components/schemas/InputElement"}
              }
            }
          },
          "required": true
        },
        "responses": {
          "201": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/oie": {
      "post": {
        "summary": "Oie",
        "operationId": "oie_oie_post",
        "requestBody": {
          "content": {
            "application/json": {
              "schema": {
                "title": "Input Json",
                "type": "array",
                "items": {"$ref": "#/components/schemas/InputElement"}
              }
            }
          },
          "required": true
        },
        "responses": {
          "201": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/nre": {
      "post": {
        "summary": "Oie",
        "operationId": "oie_nre_post",
        "requestBody": {
          "content": {
            "application/json": {
              "schema": {
                "title": "Input Json",
                "type": "array",
                "items": {"$ref": "#/components/schemas/InputElementNRE"}
              }
            }
          },
          "required": true
        },
        "responses": {
          "201": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/summary/short": {
      "post": {
        "summary": "Short Summary",
        "operationId": "short_summary_summary_short_post",
        "requestBody": {
          "content": {
            "application/json": {
              "schema": {
                "title": "Input Json",
                "type": "array",
                "items": {"type": "string"}
              }
            }
          },
          "required": true
        },
        "responses": {
          "201": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    },
    "/summary/long": {
      "post": {
        "summary": "Long Summary",
        "operationId": "long_summary_summary_long_post",
        "requestBody": {
          "content": {
            "application/json": {
              "schema": {
                "title": "Input Json",
                "type": "array",
                "items": {"type": "string"}
              }
            }
          },
          "required": true
        },
        "responses": {
          "201": {
            "description": "Successful Response",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseResponse"}}}
          },
          "422": {
            "description": "Validation Error",
            "content": {"application/json": {"schema": {"$ref": "#/components/schemas/HTTPValidationError"}}}
          }
        }
      }
    }
  },
  "components": {
    "schemas": {
      "BaseResponse": {
        "title": "BaseResponse",
        "required": ["id", "status"],
        "type": "object",
        "properties": {
          "id": {"title": "Id", "type": "integer"},
          "status": {"title": "Status", "type": "string"}
        }
      },
      "EntityNerResponseModel": {
        "title": "EntityNerResponseModel",
        "required": ["entity_group", "word", "start", "end", "score"],
        "type": "object",
        "properties": {
          "entity_group": {"title": "Entity Group", "type": "string"},
          "word": {"title": "Word", "type": "string"},
          "start": {"title": "Start", "type": "integer"},
          "end": {"title": "End", "type": "integer"},
          "score": {"title": "Score", "type": "number"}
        }
      },
      "EntityOIEResponseModel": {
        "title": "EntityOIEResponseModel",
        "required": ["score", "relation", "arg1", "arg2"],
        "type": "object",
        "properties": {
          "score": {"title": "Score", "type": "string"},
          "relation": {"title": "Relation", "type": "string"},
          "arg1": {"title": "Arg1", "type": "string"},
          "arg2": {"title": "Arg2", "type": "string"}
        }
      },
      "HTTPValidationError": {
        "title": "HTTPValidationError",
        "type": "object",
        "properties": {
          "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": "#/components/schemas/ValidationError"}
          }
        }
      },
      "InputElement": {
        "title": "InputElement",
        "required": ["lang", "text"],
        "type": "object",
        "properties": {
          "lang": {"title": "Lang", "type": "string"},
          "text": {"title": "Text", "type": "string"}
        }
      },
      "InputElementNRE": {
        "title": "InputElementNRE",
        "required": ["text", "h", "t"],
        "type": "object",
        "properties": {
          "text": {"title": "Text", "type": "string"},
          "h": {"$ref": "#/components/schemas/NREHType"},
          "t": {"$ref": "#/components/schemas/NRETType"}
        }
      },
      "NREHType": {
        "title": "NREHType",
        "required": ["pos"],
        "type": "object",
        "properties": {
          "pos": {"title": "Pos", "type": "string"}
        }
      },
      "NRETType": {
        "title": "NRETType",
        "required": ["pos"],
        "type": "object",
        "properties": {
          "pos": {"title": "Pos", "type": "string"}
        }
      },
      "ResponseNREModel": {
        "title": "ResponseNREModel",
        "required": ["progression"],
        "type": "object",
        "properties": {
          "progression": {"title": "Progression", "type": "string"},
          "result": {
            "title": "Result",
            "type": "array",
            "items": {
              "type": "array",
              "items": [
                {"type": "string"},
                {"type": "number"}
              ]
            }
          }
        }
      },
      "ResponseNerModel": {
        "title": "ResponseNerModel",
        "required": ["progression"],
        "type": "object",
        "properties": {
          "progression": {"title": "Progression", "type": "string"},
          "result": {
            "title": "Result",
            "type": "array",
            "items": {
              "type": "array",
              "items": {"$ref": "#/components/schemas/EntityNerResponseModel"}
            }
          }
        }
      },
      "ResponseOIEModel": {
        "title": "ResponseOIEModel",
        "required": ["progression"],
        "type": "object",
        "properties": {
          "progression": {"title": "Progression", "type": "string"},
          "result": {
            "title": "Result",
            "type": "array",
            "items": {"$ref": "#/components/schemas/EntityOIEResponseModel"}
          }
        }
      },
      "ResponseSummaryModel": {
        "title": "ResponseSummaryModel",
        "required": ["progression"],
        "type": "object",
        "properties": {
          "progression": {"title": "Progression", "type": "string"},
          "result": {
            "title": "Result",
            "type": "array",
            "items": {"type": "string"}
          }
        }
      },
      "ValidationError": {
        "title": "ValidationError",
        "required": ["loc", "msg", "type"],
        "type": "object",
        "properties": {
          "loc": {
            "title": "Location",
            "type": "array",
            "items": {
              "anyOf": [
                {"type": "string"},
                {"type": "integer"}
              ]
            }
          },
          "msg": {"title": "Message", "type": "string"},
          "type": {"title": "Error Type", "type": "string"}
        }
      }
    }
  }
}

@@ -1 +0,0 @@
web: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app

@@ -1,14 +0,0 @@
requests==2.28.0
flair==0.10
langdetect==1.0.9
transformers
pydantic==1.8.2
uvicorn==0.17.6
torch>=1.12.1+cu116
fastapi==0.78.0
protobuf==3.20.0
sacremoses
nltk>=3.3
tqdm>=4.64.0
setuptools>=62.6.0
starlette>=0.19.1

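One note on installing these pins: the `+cu116` local version tag on `torch` is served from PyTorch's own wheel index rather than PyPI, so the install typically needs an extra index URL:

```shell
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu116
```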
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2020 Youngbin Ro

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -1,372 +0,0 @@
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.container import ModuleList
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

import src.Multi2OIE.utils.bio as bio

# from dataset import load_data
# from extract import extract


def _get_clones(module, n):
    return ModuleList([copy.deepcopy(module) for _ in range(n)])


def _get_position_idxs(pred_mask, input_ids):
    position_idxs = torch.zeros(pred_mask.shape, dtype=int, device=pred_mask.device)
    for mask_idx, cur_mask in enumerate(pred_mask):
        position_idxs[mask_idx, :] += 2
        cur_nonzero = (cur_mask == 0).nonzero()
        start = torch.min(cur_nonzero).item()
        end = torch.max(cur_nonzero).item()
        position_idxs[mask_idx, start:end + 1] = 1
        pad_start = max(input_ids[mask_idx].nonzero()).item() + 1
        position_idxs[mask_idx, pad_start:] = 0
    return position_idxs


def _get_pred_feature(pred_hidden, pred_mask):
    B, L, D = pred_hidden.shape
    pred_features = torch.zeros((B, L, D), device=pred_mask.device)
    for mask_idx, cur_mask in enumerate(pred_mask):
        pred_position = (cur_mask == 0).nonzero().flatten()
        pred_feature = torch.mean(pred_hidden[mask_idx, pred_position], dim=0)
        pred_feature = torch.cat(L * [pred_feature.unsqueeze(0)])
        pred_features[mask_idx, :, :] = pred_feature
    return pred_features


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)


class ArgExtractorLayer(nn.Module):
    def __init__(self,
                 d_model=768,
                 n_heads=8,
                 d_feedforward=2048,
                 dropout=0.1,
                 activation='relu'):
        """
        A layer similar to a Transformer decoder without decoder self-attention
        (only encoder-decoder multi-head attention followed by feed-forward layers).

        :param d_model: model dimensionality (default=768 from BERT-base)
        :param n_heads: number of heads in the multi-head attention layer
        :param d_feedforward: dimensionality of the point-wise feed-forward layer
        :param dropout: drop rate of all layers
        :param activation: activation function after the first feed-forward layer
        """
        super(ArgExtractorLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_feedforward)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def forward(self, target, source, key_mask=None):
        """
        Single Transformer decoder layer without self-attention.

        :param target: a tensor which takes the role of the query
        :param source: a tensor which takes the role of the key & value
        :param key_mask: key mask tensor with the shape of (batch_size, sequence_length)
        """
        # Multi-head attention layer (+ add & norm)
        attended = self.multihead_attn(
            target, source, source,
            key_padding_mask=key_mask)[0]
        skipped = target + self.dropout1(attended)
        normed = self.norm1(skipped)

        # Point-wise feed-forward layer (+ add & norm)
        projected = self.linear2(self.dropout2(self.activation(self.linear1(normed))))
        skipped = normed + self.dropout1(projected)
        normed = self.norm2(skipped)
        return normed


class ArgModule(nn.Module):
    def __init__(self, arg_layer, n_layers):
        """
        Module for extracting arguments based on the given encoder output and predicates.
        It uses ArgExtractorLayer as the base block and repeats the block N ('n_layers') times.

        :param arg_layer: an instance of the ArgExtractorLayer() class (required)
        :param n_layers: the number of sub-layers in the ArgModule (required)
        """
        super(ArgModule, self).__init__()
        self.layers = _get_clones(arg_layer, n_layers)
        self.n_layers = n_layers

    def forward(self, encoded, predicate, pred_mask=None):
        """
        :param encoded: output from the sentence encoder with the shape of (L, B, D),
            where L is the sequence length, B is the batch size, and D is the embedding dimension
        :param predicate: output from the predicate module with the shape of (L, B, D)
        :param pred_mask: mask that prevents attention to tokens which are not predicates,
            with the shape of (B, L)
        :return: tensor like the Transformer decoder layer output
        """
        output = encoded
        for layer_idx in range(self.n_layers):
            output = self.layers[layer_idx](
                target=output, source=predicate, key_mask=pred_mask)
        return output


class Multi2OIE(nn.Module):
    def __init__(self,
                 bert_config,
                 mh_dropout=0.1,
                 pred_clf_dropout=0.,
                 arg_clf_dropout=0.3,
                 n_arg_heads=8,
                 n_arg_layers=4,
                 pos_emb_dim=64,
                 pred_n_labels=3,
                 arg_n_labels=9):
        super(Multi2OIE, self).__init__()
        self.pred_n_labels = pred_n_labels
        self.arg_n_labels = arg_n_labels

        self.bert = BertModel.from_pretrained(
            bert_config,
            output_hidden_states=True)
        d_model = self.bert.config.hidden_size
        self.pred_dropout = nn.Dropout(pred_clf_dropout)
        self.pred_classifier = nn.Linear(d_model, self.pred_n_labels)

        self.position_emb = nn.Embedding(3, pos_emb_dim, padding_idx=0)
        # argument-extractor input: [hidden states, predicate feature, position embedding]
        d_model += (d_model + pos_emb_dim)
        arg_layer = ArgExtractorLayer(
            d_model=d_model,
            n_heads=n_arg_heads,
            dropout=mh_dropout)
        self.arg_module = ArgModule(arg_layer, n_arg_layers)
        self.arg_dropout = nn.Dropout(arg_clf_dropout)
        self.arg_classifier = nn.Linear(d_model, arg_n_labels)

    def forward(self,
                input_ids,
                attention_mask,
                predicate_mask=None,
                predicate_hidden=None,
                total_pred_labels=None,
                arg_labels=None):
        # forward() is used for training and assumes that both
        # total_pred_labels and arg_labels are provided.

        # predicate extraction
        bert_hidden = self.bert(input_ids, attention_mask)[0]
        pred_logit = self.pred_classifier(self.pred_dropout(bert_hidden))

        # predicate loss
        if total_pred_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = pred_logit.view(-1, self.pred_n_labels)
            active_labels = torch.where(
                active_loss, total_pred_labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(total_pred_labels))
            pred_loss = loss_fct(active_logits, active_labels)

        # inputs for argument extraction
        pred_feature = _get_pred_feature(bert_hidden, predicate_mask)
        position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
        bert_hidden = torch.cat([bert_hidden, pred_feature, position_vectors], dim=2)
        bert_hidden = bert_hidden.transpose(0, 1)

        # argument extraction
        arg_hidden = self.arg_module(bert_hidden, bert_hidden, predicate_mask)
        arg_hidden = arg_hidden.transpose(0, 1)
        arg_logit = self.arg_classifier(self.arg_dropout(arg_hidden))

        # argument loss
        if arg_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = arg_logit.view(-1, self.arg_n_labels)
            active_labels = torch.where(
                active_loss, arg_labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(arg_labels))
            arg_loss = loss_fct(active_logits, active_labels)

        # total loss
        batch_loss = pred_loss + arg_loss
        outputs = (batch_loss, pred_loss, arg_loss)
        return outputs

    def extract_predicate(self,
                          input_ids,
                          attention_mask):
        bert_hidden = self.bert(input_ids, attention_mask)[0]
        pred_logit = self.pred_classifier(bert_hidden)
        return pred_logit, bert_hidden

    def extract_argument(self,
                         input_ids,
                         predicate_hidden,
                         predicate_mask):
        pred_feature = _get_pred_feature(predicate_hidden, predicate_mask)
        position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
        arg_input = torch.cat([predicate_hidden, pred_feature, position_vectors], dim=2)
        arg_input = arg_input.transpose(0, 1)
        arg_hidden = self.arg_module(arg_input, arg_input, predicate_mask)
        arg_hidden = arg_hidden.transpose(0, 1)
        return self.arg_classifier(arg_hidden)


class OieEvalDataset(Dataset):
    def __init__(self, sentences, max_len, tokenizer_config):
        self.sentences = sentences
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_config)
        self.vocab = self.tokenizer.vocab
        self.max_len = max_len

        self.pad_idx = self.vocab['[PAD]']
        self.cls_idx = self.vocab['[CLS]']
        self.sep_idx = self.vocab['[SEP]']
        self.mask_idx = self.vocab['[MASK]']

    def add_pad(self, token_ids):
        diff = self.max_len - len(token_ids)
        if diff > 0:
            token_ids += [self.pad_idx] * diff
        else:
            token_ids = token_ids[:self.max_len - 1] + [self.sep_idx]
        return token_ids

    def idx2mask(self, token_ids):
        return [token_id != self.pad_idx for token_id in token_ids]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        token_ids = self.add_pad(self.tokenizer.encode(self.sentences[idx]))
        att_mask = self.idx2mask(token_ids)
        token_strs = self.tokenizer.convert_ids_to_tokens(token_ids)
        sentence = self.sentences[idx]

        assert len(token_ids) == self.max_len
        assert len(att_mask) == self.max_len
        assert len(token_strs) == self.max_len
        batch = [
            torch.tensor(token_ids),
            torch.tensor(att_mask),
            token_strs,
            sentence
        ]
        return batch


def extract(model, loader, device, tokenizer):
    # model.eval()
    result = []
    for step, batch in tqdm(enumerate(loader), desc='eval_steps', total=len(loader)):
        token_strs = [[word for word in sent] for sent in np.asarray(batch[-2]).T]
        sentences = batch[-1]
        print(sentences)
        token_ids, att_mask = map(lambda x: x.to(device), batch[:-2])

        with torch.no_grad():
            """
            We iterate B (batch_size) times
            because there can be more than one predicate in one batch.
            When feeding the argument extractor, the number of predicates takes the role of the batch size.

            pred_logit: (B, L, 3)
            pred_hidden: (B, L, D)
            pred_tags: (B, P, L) ~ list of tensors, where P is the number of predicates in each batch
            """
            pred_logit, pred_hidden = model.extract_predicate(
                input_ids=token_ids, attention_mask=att_mask)
            pred_tags = torch.argmax(pred_logit, 2)
            pred_tags = bio.filter_pred_tags(pred_tags, token_strs)
            pred_tags = bio.get_single_predicate_idxs(pred_tags)
            pred_probs = torch.nn.Softmax(2)(pred_logit)

            # iterate B times (one iteration means extraction for one sentence)
            for cur_pred_tags, cur_pred_hidden, cur_att_mask, cur_token_id, cur_pred_probs, token_str, sentence \
                    in zip(pred_tags, pred_hidden, att_mask, token_ids, pred_probs, token_strs, sentences):

                # generate a temporary batch for this sentence and feed it to the argument module
                cur_pred_masks = bio.get_pred_mask(cur_pred_tags).to(device)
                n_predicates = cur_pred_masks.shape[0]
                if n_predicates == 0:
                    continue  # if there is no predicate, we cannot extract.
                cur_pred_hidden = torch.cat(n_predicates * [cur_pred_hidden.unsqueeze(0)])
                cur_token_id = torch.cat(n_predicates * [cur_token_id.unsqueeze(0)])
                cur_arg_logit = model.extract_argument(
                    input_ids=cur_token_id,
                    predicate_hidden=cur_pred_hidden,
                    predicate_mask=cur_pred_masks)

                # filter and get the argument tags with the highest probability
                cur_arg_tags = torch.argmax(cur_arg_logit, 2)
                cur_arg_probs = torch.nn.Softmax(2)(cur_arg_logit)
                cur_arg_tags = bio.filter_arg_tags(cur_arg_tags, cur_pred_tags, token_str)

                # get string tuples and write results
                cur_extractions, cur_extraction_idxs = bio.get_tuple(sentence, cur_pred_tags, cur_arg_tags, tokenizer)
                cur_confidences = bio.get_confidence_score(cur_pred_probs, cur_arg_probs, cur_extraction_idxs)
                for extraction, confidence in zip(cur_extractions, cur_confidences):
                    # print('\n')
                    # print("\t".join([sentence] + [str(1.0)] + extraction[:3]))
                    res_dict = {
                        'score': str(confidence),
                        'relation': extraction[0],
                        'arg1': extraction[1],
                        'arg2': extraction[2]
                    }
                    result.append(res_dict)
    # print("\nExtraction Done.\n")
    return result


class OIE:
    def __init__(self, model_path, BERT, batch_size):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model = Multi2OIE(bert_config=BERT).to(self.device)
        model_weights = torch.load(model_path, map_location=torch.device(self.device))
        self.model.load_state_dict(model_weights)
        self.max_len = 64
        self.batch_size = batch_size
        self.tokenizer_config = BERT
        self.tokenizer = BertTokenizer.from_pretrained(BERT)

    def predict(self, sentences):
        loader = DataLoader(
            dataset=OieEvalDataset(
                sentences,
                self.max_len,
                self.tokenizer_config),
            batch_size=self.batch_size,
            num_workers=4,
            pin_memory=True)
        return extract(model=self.model, loader=loader, tokenizer=self.tokenizer, device=self.device)

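For reference, a minimal sketch of using the `OIE` wrapper above; the module path and the two model paths are assumptions based on the imports and the defaults hard-coded in `main.py` earlier in this commit:

```python
# sketch: load the extractor and run it over a sentence
from src.Multi2OIE.model import OIE  # assumed module path within this repo

oie = OIE(
    model_path="./data/mbertEntail-2GBEn.bin",      # fine-tuned Multi2OIE weights
    BERT="./data/mbert-base-parsinlu-entailment",   # BERT config + tokenizer directory
    batch_size=4,
)
# returns a list of dicts: {"score", "relation", "arg1", "arg2"}
print(oie.predict(["The quick brown fox jumped over the lazy dog."]))
```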
@@ -1,182 +0,0 @@
# Multi^2OIE: <u>Multi</u>lingual Open Information Extraction Based on <u>Multi</u>-Head Attention with BERT

> Source code for learning Multi^2OIE for (multilingual) open information extraction.

## Paper
[**Multi^2OIE: <u>Multi</u>lingual Open Information Extraction Based on <u>Multi</u>-Head Attention with BERT**](https://arxiv.org/abs/2009.08128)<br>
[Youngbin Ro](https://github.com/youngbin-ro), [Yukyung Lee](https://github.com/yukyunglee), and [Pilsung Kang](https://github.com/pilsung-kang)*<br>
Accepted to Findings of ACL: EMNLP 2020. (*corresponding author)

<br>

## Overview
### What is Open Information Extraction (Open IE)?
[Niklaus et al. (2018)](https://www.aclweb.org/anthology/C18-1326/) describe Open IE as follows:

> Information extraction (IE) **<u>turns the unstructured information expressed in natural language text into a structured representation</u>** in the form of relational tuples consisting of a set of arguments and a phrase denoting a semantic relation between them: <arg1; rel; arg2>. (...) Unlike traditional IE methods, Open IE is **<u>not limited to a small set of target relations</u>** known in advance, but rather extracts all types of relations found in a text.



#### Note
- Systems adopting a sequence-generation scheme ([Cui et al., 2018](https://www.aclweb.org/anthology/P18-2065/); [Kolluru et al., 2020](https://www.aclweb.org/anthology/2020.acl-main.521/)) can extract (actually generate) relations that do not appear verbatim in the given texts.
- Multi^2OIE, however, adopts a sequence-labeling scheme ([Stanovsky et al., 2018](https://www.aclweb.org/anthology/N18-1081/)) for computational efficiency and multilingual ability.

### Our Approach



#### Step 1: Extract predicates (relations) from the input sentence using BERT
- Conduct token-level classification on the BERT output sequence
- Use BIO tagging for representing arguments and predicates (see the small example after this list)
|
|
||||||
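To make the tagging scheme concrete, here is a small, hypothetical illustration; the `P`/`A0`/`A1` BIO labels follow the convention used in this repo's evaluation code, not necessarily the model's exact label set:

```
# Hypothetical BIO tags for the sentence "He was born in Hawaii"
tokens    = ["He",   "was", "born", "in",  "Hawaii"]
pred_tags = ["O",    "P-B", "P-I",  "P-I", "O"]      # Step 1: predicate span "was born in"
arg_tags  = ["A0-B", "O",   "O",    "O",   "A1-B"]   # Step 2: arguments for that predicate
```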

#### Step 2: Extract arguments using multi-head attention blocks
- Concatenate the whole BERT hidden sequence, the average vector of the hidden states at the predicate positions, and a binary embedding vector indicating whether each token is included in the predicate span.
- Apply the multi-head attention operation N times (see the sketch below):
  - Query: the whole hidden sequence
  - Key-Value pairs: hidden states of the predicate positions
- Conduct token-level classification on the multi-head attention output sequence
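The asymmetric query/key pattern is the core of the argument module. Below is a minimal PyTorch sketch of one such attention step, not the authors' exact module; shapes are illustrative, the head count matches the hyper-parameters listed later, and `batch_first` requires PyTorch 1.9 or later:

```
import torch

B, L, H = 2, 64, 768                 # batch size, sequence length, BERT-base hidden size
hidden = torch.randn(B, L, H)        # whole BERT output sequence
pred_hidden = torch.randn(B, 4, H)   # hidden states at the predicate positions

attn = torch.nn.MultiheadAttention(embed_dim=H, num_heads=8, batch_first=True)
# Query: every token in the sentence; Key/Value: only the predicate positions
out, _ = attn(query=hidden, key=pred_hidden, value=pred_hidden)
print(out.shape)  # torch.Size([2, 64, 768])
```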

#### Multilingual Extraction

- Replace the English BERT with Multilingual BERT
- Train the model only with English data
- Test the model on three different languages (English, Spanish, and Portuguese) in a zero-shot manner.

<br>


## Usage

### Prerequisites

- Python 3.7

- CUDA 10.0 or above

### Environmental Setup

#### Install
##### using 'conda' command,
~~~~
# this makes a new conda environment
conda env create -f environment.yml
conda activate multi2oie
~~~~

##### using 'pip' command,
~~~~
pip install -r requirements.txt
~~~~

#### NLTK setup
```
python -c "import nltk; nltk.download('stopwords')"
```

### Datasets

#### Dataset Released
- `openie4_train.pkl`: https://drive.google.com/file/d/1DrWj1CjLFIno-UBfLI3_uIratN4QY6Y3/view?usp=sharing

#### Do-it-yourself
The original data file (a bootstrapped sample from OpenIE4, also used in SpanOIE) can be downloaded from [here](https://drive.google.com/file/d/1AEfwbh3BQnsv2VM977cS4tEoldrayKB6/view).
After downloading, put the data in './datasets' and use preprocess.py to convert it into the format suitable for Multi^2OIE.

~~~~
cd utils
python preprocess.py \
    --mode 'train' \
    --data '../datasets/structured_data.json' \
    --save_path '../datasets/openie4_train.pkl' \
    --bert_config 'bert-base-cased' \
    --max_len 64
~~~~

For multilingual training data, set **'bert_config'** to **'bert-base-multilingual-cased'**.


### Run the Code

#### Model Released
- English Model: https://drive.google.com/file/d/11BaLuGjMVVB16WHcyaHWLgL6dg0_9xHQ/view?usp=sharing
- Multilingual Model: https://drive.google.com/file/d/1lHQeetbacFOqvyPQ3ZzVUGPgn-zwTRA_/view?usp=sharing

We used a TITAN RTX GPU for training; using a different GPU may lead to slightly different final performance.

##### for training,

~~~~
python main.py [--FLAGS]
~~~~

##### for testing,

~~~~
python test.py [--FLAGS]
~~~~

<br>

## Model Configurations

### # of Parameters

- Original BERT: 110M
- \+ Multi-Head Attention Blocks: 66M


### Hyper-parameters {& searching bounds}

- epochs: 1 {**1**, 2, 3}
- dropout rate for multi-head attention blocks: 0.2 {0.0, 0.1, **0.2**}
- dropout rate for argument classifier: 0.2 {0.0, 0.1, **0.2**, 0.3}
- batch size: 128 {64, **128**, 256, 512}
- learning rate: 3e-5 {2e-5, **3e-5**, 5e-5}
- number of multi-head attention heads: 8 {4, **8**}
- number of multi-head attention blocks: 4 {2, **4**, 8}
- position embedding dimension: 64 {**64**, 128, 256}
- gradient clipping norm: 1.0 (not tuned)
- learning rate warm-up steps: 10% of total steps (not tuned)
- for other unspecified parameters, the default values can be used.

<br>

## Expected Results

### Development set

#### OIE2016

- F1: 71.7
- AUC: 55.4

#### CaRB

- F1: 54.3
- AUC: 34.8


### Testing set

#### Re-OIE2016

- F1: 83.9
- AUC: 74.6

#### CaRB

- F1: 52.3
- AUC: 32.6

<br>

## References

- https://github.com/gabrielStanovsky/oie-benchmark
- https://github.com/dair-iitd/CaRB
- https://github.com/zhanjunlang/Span_OIE
@@ -1,21 +0,0 @@
import nltk
from operator import itemgetter


class Argument:
    def __init__(self, arg):
        self.words = [x for x in arg[0].strip().split(' ') if x]
        # wrap in list() so the POS tags can be reused (map() is a one-shot iterator in Python 3)
        self.posTags = list(map(itemgetter(1), nltk.pos_tag(self.words)))
        self.indices = arg[1]
        self.feats = {}

    def __str__(self):
        return "({})".format('\t'.join(map(str,
                                           [escape_special_chars(' '.join(self.words)),
                                            str(self.indices)])))

COREF = 'coref'

## Helper functions
def escape_special_chars(s):
    return s.replace('\t', '\\t')
@@ -1,380 +0,0 @@
'''
Usage:
    benchmark --gold=GOLD_OIE --out=OUTPUT_FILE (--openiefive=OPENIE5 | --stanford=STANFORD_OIE | --ollie=OLLIE_OIE | --reverb=REVERB_OIE | --clausie=CLAUSIE_OIE | --openiefour=OPENIEFOUR_OIE | --props=PROPS_OIE | --tabbed=TABBED_OIE | --benchmarkGold=BENCHMARK_GOLD | --allennlp=ALLENNLP_OIE ) [--exactMatch | --predMatch | --lexicalMatch | --binaryMatch | --simpleMatch | --strictMatch] [--error-file=ERROR_FILE] [--binary]

Options:
    --gold=GOLD_OIE              The gold reference Open IE file (by default, it should be under ./oie_corpus/all.oie).
    --benchmarkGold=GOLD_OIE     The benchmark's gold reference.
    --out=OUTPUT_FILE            The output file, into which the precision-recall curve will be written.
    --clausie=CLAUSIE_OIE        Read ClausIE format from file CLAUSIE_OIE.
    --ollie=OLLIE_OIE            Read OLLIE format from file OLLIE_OIE.
    --openiefour=OPENIEFOUR_OIE  Read Open IE 4 format from file OPENIEFOUR_OIE.
    --openiefive=OPENIE5         Read Open IE 5 format from file OPENIE5.
    --props=PROPS_OIE            Read PropS format from file PROPS_OIE.
    --reverb=REVERB_OIE          Read ReVerb format from file REVERB_OIE.
    --stanford=STANFORD_OIE      Read Stanford format from file STANFORD_OIE.
    --tabbed=TABBED_OIE          Read a simple tab-separated format from file TABBED_OIE, where each line consists of:
                                 sent, prob, pred, arg1, arg2, ...
    --exactMatch                 Use exact match when judging whether an extraction is correct.
'''
from __future__ import division
import docopt
import string
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc, roc_auc_score
import re
import logging
import pdb
import ipdb
from collections import defaultdict

from carb.goldReader import GoldReader
from carb.gold_relabel import Relabel_GoldReader
from carb.matcher import Matcher  # assumed location of matcher.py within this package
# NOTE: the system readers used in __main__ (StanfordReader, OllieReader, ReVerbReader,
# ClausieReader, PropSReader, OpenieFourReader, OpenieFiveReader, BenchmarkGoldReader,
# TabReader) are assumed to come from an oie_readers package, as in the original CaRB code.
logging.basicConfig(level = logging.INFO)

from operator import itemgetter
import pprint
from copy import copy
pp = pprint.PrettyPrinter(indent=4)

class Benchmark:
    ''' Compare the gold OIE dataset against a predicted equivalent '''
    def __init__(self, gold_fn):
        ''' Load the gold Open IE; it will serve as the reference for the compare function '''

        if 'Re-OIE2016' in gold_fn:
            gr = Relabel_GoldReader()
        else:
            gr = GoldReader()
        gr.read(gold_fn)
        self.gold = gr.oie

    def compare(self, predicted, matchingFunc, output_fn, error_file = None, binary=False):
        ''' Compare gold against predicted using a specified matching function.
            Outputs the PR curve to output_fn '''

        y_true = []
        y_scores = []
        errors = []
        correct = 0
        incorrect = 0

        correctTotal = 0
        unmatchedCount = 0
        predicted = Benchmark.normalizeDict(predicted)
        gold = Benchmark.normalizeDict(self.gold)
        if binary:
            predicted = Benchmark.binarize(predicted)
            gold = Benchmark.binarize(gold)

        # take all distinct confidence values as thresholds
        confidence_thresholds = set()
        for sent in predicted:
            for predicted_ex in predicted[sent]:
                confidence_thresholds.add(predicted_ex.confidence)

        confidence_thresholds = sorted(list(confidence_thresholds))
        num_conf = len(confidence_thresholds)

        results = {}
        p = np.zeros(num_conf)
        pl = np.zeros(num_conf)
        r = np.zeros(num_conf)
        rl = np.zeros(num_conf)

        for sent, goldExtractions in gold.items():

            if sent in predicted:
                predictedExtractions = predicted[sent]
            else:
                predictedExtractions = []

            scores = [[None for _ in predictedExtractions] for __ in goldExtractions]

            for i, goldEx in enumerate(goldExtractions):
                for j, predictedEx in enumerate(predictedExtractions):
                    score = matchingFunc(goldEx, predictedEx, ignoreStopwords = True, ignoreCase = True)
                    scores[i][j] = score

            # OPTIMISED GLOBAL MATCH
            sent_confidences = [extraction.confidence for extraction in predictedExtractions]
            sent_confidences.sort()
            prev_c = 0
            for conf in sent_confidences:
                c = confidence_thresholds.index(conf)
                ext_indices = []
                for ext_indx, extraction in enumerate(predictedExtractions):
                    if extraction.confidence >= conf:
                        ext_indices.append(ext_indx)

                recall_numerator = 0
                for i, row in enumerate(scores):
                    max_recall_row = max([row[ext_indx][1] for ext_indx in ext_indices], default=0)
                    recall_numerator += max_recall_row

                precision_numerator = 0

                selected_rows = []
                selected_cols = []
                num_precision_matches = min(len(scores), len(ext_indices))
                for t in range(num_precision_matches):
                    matched_row = -1
                    matched_col = -1
                    matched_precision = -1  # initialised to <0 so that it updates even when precision is 0
                    for i in range(len(scores)):
                        if i in selected_rows:
                            continue
                        for ext_indx in ext_indices:
                            if ext_indx in selected_cols:
                                continue
                            if scores[i][ext_indx][0] > matched_precision:
                                matched_precision = scores[i][ext_indx][0]
                                matched_row = i
                                matched_col = ext_indx

                    selected_rows.append(matched_row)
                    selected_cols.append(matched_col)
                    precision_numerator += scores[matched_row][matched_col][0]

                p[prev_c:c+1] += precision_numerator
                pl[prev_c:c+1] += len(ext_indices)
                r[prev_c:c+1] += recall_numerator
                rl[prev_c:c+1] += len(scores)

                prev_c = c+1

            # for indices beyond the maximum sentence confidence, len(scores) has to be added to the denominator of recall
            rl[prev_c:] += len(scores)

        prec_scores = [a/b if b > 0 else 1 for a, b in zip(p, pl)]
        rec_scores = [a/b if b > 0 else 0 for a, b in zip(r, rl)]

        f1s = [Benchmark.f1(p, r) for p, r in zip(prec_scores, rec_scores)]
        try:
            optimal_idx = np.nanargmax(f1s)
            optimal = (prec_scores[optimal_idx], rec_scores[optimal_idx], f1s[optimal_idx])
        except ValueError:
            # When there is no prediction
            optimal = (0, 0, 0)

        # In order to calculate auc, we need to add the point corresponding to precision=1, recall=0 to the PR-curve
        temp_rec_scores = rec_scores.copy()
        temp_prec_scores = prec_scores.copy()
        temp_rec_scores.append(0)
        temp_prec_scores.append(1)

        with open(output_fn, 'w') as fout:
            fout.write('{0}\t{1}\t{2}\n'.format("Precision", "Recall", "Confidence"))
            for cur_p, cur_r, cur_conf in sorted(zip(prec_scores, rec_scores, confidence_thresholds), key = lambda cur: cur[1]):
                fout.write('{0}\t{1}\t{2}\n'.format(cur_p, cur_r, cur_conf))

        if len(f1s) > 0:
            rec_prec_dict = {rec: prec for rec, prec in zip(temp_rec_scores, temp_prec_scores)}
            rec_prec_dict = sorted(rec_prec_dict.items(), key=lambda x: x[0])
            temp_rec_scores = [rec for rec, _ in rec_prec_dict]
            temp_prec_scores = [prec for _, prec in rec_prec_dict]
            return np.round(auc(temp_rec_scores, temp_prec_scores), 5), np.round(optimal, 5)
        else:
            # When there is no prediction
            return 0, (0, 0, 0)

    @staticmethod
    def binarize(extrs):
        res = defaultdict(lambda: [])
        for sent, extr in extrs.items():
            for ex in extr:
                # Add (a1, r, a2)
                temp = copy(ex)
                temp.args = ex.args[:2]
                res[sent].append(temp)

                if len(ex.args) <= 2:
                    continue

                # Add (a1, r a2, a3 ...)
                for arg in ex.args[2:]:
                    temp.args = [ex.args[0]]
                    temp.pred = ex.pred + ' ' + ex.args[1]
                    words = arg.split()

                    # Add the argument's preposition to the relation
                    if words[0].lower() in Benchmark.PREPS:
                        temp.pred += ' ' + words[0]
                        words = words[1:]
                    temp.args.append(' '.join(words))
                    res[sent].append(temp)

        return res

    @staticmethod
    def f1(prec, rec):
        try:
            return 2*prec*rec / (prec+rec)
        except ZeroDivisionError:
            return 0

    @staticmethod
    def aggregate_scores_greedily(scores):
        # Greedy match: pick the prediction/gold match with the best f1 and exclude
        # them both, until nothing left matches. Each input square is a [prec, rec]
        # pair. Returns precision and recall as score-and-denominator pairs.
        matches = []
        while True:
            max_s = 0
            gold, pred = None, None
            for i, gold_ss in enumerate(scores):
                if i in [m[0] for m in matches]:
                    # Those rows are already taken
                    continue
                for j, pred_s in enumerate(scores[i]):
                    if j in [m[1] for m in matches]:
                        # Those columns are already used
                        continue
                    if pred_s and Benchmark.f1(*pred_s) > max_s:
                        max_s = Benchmark.f1(*pred_s)
                        gold = i
                        pred = j
            if max_s == 0:
                break
            matches.append([gold, pred])
        # Now that matches are determined, compute final scores.
        prec_scores = [scores[i][j][0] for i, j in matches]
        rec_scores = [scores[i][j][1] for i, j in matches]
        total_prec = sum(prec_scores)
        total_rec = sum(rec_scores)
        scoring_metrics = {"precision": [total_prec, len(scores[0])],
                           "recall": [total_rec, len(scores)],
                           "precision_of_matches": prec_scores,
                           "recall_of_matches": rec_scores
                           }
        return scoring_metrics

    # Helper functions:
    @staticmethod
    def normalizeDict(d):
        return dict([(Benchmark.normalizeKey(k), v) for k, v in d.items()])

    @staticmethod
    def normalizeKey(k):
        return Benchmark.removePunct(str(Benchmark.PTB_unescape(k.replace(' ', ''))))

    @staticmethod
    def PTB_escape(s):
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(u, e)
        return s

    @staticmethod
    def PTB_unescape(s):
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(e, u)
        return s

    @staticmethod
    def removePunct(s):
        return Benchmark.regex.sub('', s)

    # CONSTANTS
    regex = re.compile('[%s]' % re.escape(string.punctuation))

    # Penn treebank bracket escapes
    # Taken from: https://github.com/nlplab/brat/blob/master/server/src/gtbtokenize.py
    PTB_ESCAPES = [('(', '-LRB-'),
                   (')', '-RRB-'),
                   ('[', '-LSB-'),
                   (']', '-RSB-'),
                   ('{', '-LCB-'),
                   ('}', '-RCB-'),]

    PREPS = ['above','across','against','along','among','around','at','before','behind','below','beneath','beside','between','by','for','from','in','into','near','of','off','on','to','toward','under','upon','with','within']

def f_beta(precision, recall, beta = 1):
    """
    Get the F_beta score from precision and recall.
    """
    beta = float(beta) # Make sure that results are in float
    return (1 + pow(beta, 2)) * (precision * recall) / ((pow(beta, 2) * precision) + recall)


if __name__ == '__main__':
    args = docopt.docopt(__doc__)
    logging.debug(args)

    if args['--stanford']:
        predicted = StanfordReader()
        predicted.read(args['--stanford'])

    if args['--props']:
        predicted = PropSReader()
        predicted.read(args['--props'])

    if args['--ollie']:
        predicted = OllieReader()
        predicted.read(args['--ollie'])

    if args['--reverb']:
        predicted = ReVerbReader()
        predicted.read(args['--reverb'])

    if args['--clausie']:
        predicted = ClausieReader()
        predicted.read(args['--clausie'])

    if args['--openiefour']:
        predicted = OpenieFourReader()
        predicted.read(args['--openiefour'])

    if args['--openiefive']:
        predicted = OpenieFiveReader()
        predicted.read(args['--openiefive'])

    if args['--benchmarkGold']:
        predicted = BenchmarkGoldReader()
        predicted.read(args['--benchmarkGold'])

    if args['--tabbed']:
        predicted = TabReader()
        predicted.read(args['--tabbed'])

    if args['--binaryMatch']:
        matchingFunc = Matcher.binary_tuple_match

    elif args['--simpleMatch']:
        matchingFunc = Matcher.simple_tuple_match

    elif args['--exactMatch']:
        matchingFunc = Matcher.argMatch

    elif args['--predMatch']:
        matchingFunc = Matcher.predMatch

    elif args['--lexicalMatch']:
        matchingFunc = Matcher.lexicalMatch

    elif args['--strictMatch']:
        matchingFunc = Matcher.tuple_match

    else:
        matchingFunc = Matcher.binary_linient_tuple_match

    b = Benchmark(args['--gold'])
    out_filename = args['--out']

    logging.info("Writing PR curve of {} to {}".format(predicted.name, out_filename))

    auc, optimal_f1_point = b.compare(predicted = predicted.oie,
                                      matchingFunc = matchingFunc,
                                      output_fn = out_filename,
                                      error_file = args["--error-file"],
                                      binary = args["--binary"])

    print("AUC: {}\t Optimal (precision, recall, F1): {}".format(auc, optimal_f1_point))
@@ -1,444 +0,0 @@
from sklearn.preprocessing import binarize  # public import path (sklearn.preprocessing.data is private/removed in newer sklearn)
from carb.argument import Argument
from operator import itemgetter
from collections import defaultdict
import nltk
import itertools
import logging
import numpy as np
import pdb

class Extraction:
    """
    Stores the sentence, a single predicate and the corresponding arguments.
    """
    def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
        self.pred = pred
        self.head_pred_index = head_pred_index
        self.sent = sent
        self.args = []
        self.confidence = confidence
        self.matched = []
        self.questions = {}
        self.indsForQuestions = defaultdict(lambda: set())
        self.is_mwp = False
        self.question_dist = question_dist
        self.index = index

    def distArgFromPred(self, arg):
        assert(len(self.pred) == 2)
        dists = []
        for x in self.pred[1]:
            for y in arg.indices:
                dists.append(abs(x - y))

        return min(dists)

    def argsByDistFromPred(self, question):
        return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))

    def addArg(self, arg, question = None):
        self.args.append(arg)
        if question:
            self.questions[question] = self.questions.get(question, []) + [Argument(arg)]

    def noPronounArgs(self):
        """
        Returns True iff none of this extraction's arguments is a pronoun.
        """
        for (a, _) in self.args:
            tokenized_arg = nltk.word_tokenize(a)
            if len(tokenized_arg) == 1:
                _, pos_tag = nltk.pos_tag(tokenized_arg)[0]
                if 'PRP' in pos_tag:
                    return False
        return True

    def isContiguous(self):
        return all([indices for (_, indices) in self.args])

    def toBinary(self):
        ''' Try to represent this extraction's arguments as binary.
            If this fails, the function returns an empty list. '''

        ret = [self.elementToStr(self.pred)]

        if len(self.args) == 2:
            # we're in luck
            return ret + [self.elementToStr(arg) for arg in self.args]

        return []
        # NOTE: the early return above leaves the index-merging logic below unreachable

        if not self.isContiguous():
            # give up on non-contiguous arguments (as we need indices)
            return []

        # otherwise, try to merge based on indices
        # TODO: other merging strategies could be explored here
        binarized = self.binarizeByIndex()

        if binarized:
            return ret + binarized

        return []

    def elementToStr(self, elem, print_indices = True):
        ''' Formats an extraction element (pred or arg) as a raw string,
            removing indices and trailing spaces '''
        if print_indices:
            return str(elem)
        if isinstance(elem, str):
            return elem
        if isinstance(elem, tuple):
            ret = elem[0].rstrip().lstrip()
        else:
            ret = ' '.join(elem.words)
        assert ret, "empty element? {0}".format(elem)
        return ret

    def binarizeByIndex(self):
        extraction = [self.pred] + self.args
        markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
        # sort elements by the index of their first word
        # (tuple-unpacking lambdas are invalid in Python 3, so index explicitly)
        sortedExtraction = sorted(markPred, key = lambda elem: elem[1][0])
        s = ' '.join(['{1} {0} {1}'.format(self.elementToStr(elem), SEP) if elem[2] else self.elementToStr(elem) for elem in sortedExtraction])
        binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]

        if len(binArgs) == 2:
            return binArgs

        # failure
        return []

    def bow(self):
        return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])

    def getSortedArgs(self):
        """
        Sort the list of arguments.
        If a question distribution is provided - use it;
        otherwise, default to the order of appearance in the sentence.
        """
        if self.question_dist:
            # There's a question distribution - use it
            return self.sort_args_by_distribution()
        ls = []
        for q, args in self.questions.items():
            if len(args) != 1:
                logging.debug("Not one argument: {}".format(args))
                continue
            arg = args[0]
            indices = list(self.indsForQuestions[q].union(arg.indices))
            if not indices:
                logging.debug("Empty indices for arg {} -- backing off to zero".format(arg))
                indices = [0]
            ls.append(((arg, q), indices))
        return [a for a, _ in sorted(ls,
                                     key = lambda pair: min(pair[1]))]

    def question_prob_for_loc(self, question, loc):
        """
        Returns the probability of the given question leading to the argument
        appearing in the given location of the output slot.
        """
        gen_question = generalize_question(question)
        q_dist = self.question_dist[gen_question]
        logging.debug("distribution of {}: {}".format(gen_question,
                                                      q_dist))

        return float(q_dist.get(loc, 0)) / \
            sum(q_dist.values())

    def sort_args_by_distribution(self):
        """
        Use this instance's question distribution (this func assumes it exists)
        to determine the positioning of the arguments.
        Greedy algorithm:
        0. Decide which argument will serve as the ``subject'' (first slot) of this extraction
        0.1 Based on the most probable one for this spot
        (special care is given to selecting the highly-influential subject position)
        1. For all other arguments, sort arguments by the prevalence of their questions
        2. For each argument:
        2.1 Assign to it the most probable slot still available
        2.2 If none such exists (fallback) - default to putting it in the last location
        """
        INF_LOC = 100 # Used as an impractical last argument

        # Store arguments by slot
        ret = {INF_LOC: []}
        logging.debug("sorting: {}".format(self.questions))

        # Find the most suitable argument for the subject location
        logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
                                                      for (q, _) in self.questions.items()]))

        subj_question, subj_args = max(self.questions.items(),
                                       key = lambda qa: self.question_prob_for_loc(qa[0], 0))

        ret[0] = [(subj_args[0], subj_question)]

        # Find the rest
        for (question, args) in sorted([(q, a)
                                        for (q, a) in self.questions.items() if q not in [subj_question]],
                                       key = lambda qa: \
                                       sum(self.question_dist[generalize_question(qa[0])].values()),
                                       reverse = True):
            gen_question = generalize_question(question)
            arg = args[0]
            assigned_flag = False
            for (loc, count) in sorted(self.question_dist[gen_question].items(),
                                       key = lambda lc: lc[1],
                                       reverse = True):
                if loc not in ret:
                    # Found an empty slot for this item --
                    # place it there and break out
                    ret[loc] = [(arg, question)]
                    assigned_flag = True
                    break

            if not assigned_flag:
                # Add this argument to the non-assigned (hopefully this doesn't happen much)
                logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
                ret[INF_LOC].append((arg, question))

        logging.debug("Linearizing arg list: {}".format(ret))

        # Finished iterating - consolidate and return a list of arguments
        return [arg
                for (_, arg_ls) in sorted(ret.items(),
                                          key = lambda kv: int(kv[0]))
                for arg in arg_ls]

    def __str__(self):
        pred_str = self.elementToStr(self.pred)
        return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
                                   self.compute_global_pred(pred_str,
                                                            self.questions.keys()),
                                   '\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
                                                                                                  question))
                                              for arg, question in self.getSortedArgs()]))

    def get_base_verb(self, surface_pred):
        """
        Given the surface pred, return the original annotated verb
        """
        # Assumes that at this point the verb is always the last word
        # in the surface predicate
        return surface_pred.split(' ')[-1]

    def compute_global_pred(self, surface_pred, questions):
        """
        Given the surface pred and all instantiations of questions,
        make global coherence decisions regarding the final form of the predicate.
        This should hopefully take care of multi-word predicates and correct inflections.
        """
        from operator import itemgetter
        split_surface = surface_pred.split(' ')

        if len(split_surface) > 1:
            # This predicate has a modal preceding the base verb
            verb = split_surface[-1]
            ret = split_surface[:-1] # get all of the elements in the modal
        else:
            verb = split_surface[0]
            ret = []

        # wrap the maps in list() so they can be indexed and reused (Python 3)
        split_questions = list(map(lambda question: question.split(' '),
                                   questions))

        preds = list(map(normalize_element,
                         map(itemgetter(QUESTION_TRG_INDEX),
                             split_questions)))
        if len(set(preds)) > 1:
            # This predicate appears in multiple ways, let's stick to the base form
            ret.append(verb)

        if len(set(preds)) == 1:
            # Change the predicate to the inflected form
            # if there's exactly one way in which the predicate is conveyed
            ret.append(preds[0])

        pps = list(map(normalize_element,
                       map(itemgetter(QUESTION_PP_INDEX),
                           split_questions)))

        obj2s = list(map(normalize_element,
                         map(itemgetter(QUESTION_OBJ2_INDEX),
                             split_questions)))

        if len(set(pps)) == 1:
            # If all questions for the predicate include the same pp attachment -
            # assume it's a multi-word predicate
            self.is_mwp = True # Signal to arguments that they shouldn't take the preposition
            ret.append(pps[0])

        # Concat all elements in the predicate and return
        return " ".join(ret).strip()

    def augment_arg_with_question(self, arg, question):
        """
        Decide which elements from the question to incorporate in the given
        corresponding argument
        """
        # Parse question
        wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
                                                question.split(' ')[:-1]) # Last split is the question mark

        # Place the preposition in the argument.
        # This is safer when dealing with n-ary arguments, as it directly attaches to the
        # appropriate argument
        if (not self.is_mwp) and pp and (not obj2):
            if not arg.startswith("{} ".format(pp)):
                # Avoid repeating the preposition in cases where both question and answer contain it
                return " ".join([pp,
                                 arg])

        # Normal cases
        return arg

    def clusterScore(self, cluster):
        """
        Calculate a cluster density score as the mean over each slot's maximum distance
        from the centroid. A lower score represents a denser cluster.
        """
        logging.debug("*-*-*- Cluster: {}".format(cluster))

        # Find the global centroid
        arr = np.array([x for ls in cluster for x in ls])
        centroid = np.sum(arr)/arr.shape[0]
        logging.debug("Centroid: {}".format(centroid))

        # Calculate the mean over all maximum points
        return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])

    def resolveAmbiguity(self):
        """
        Heuristic to map the elements (arguments and predicates) of this extraction
        back to the indices of the sentence.
        """
        ## TODO: This removes arguments for which no consecutive span was found.
        ## Some of these are non-consecutive arguments,
        ## but others could be a bug in recognizing some punctuation marks

        elements = [self.pred] \
                   + [(s, indices)
                      for (s, indices)
                      in self.args
                      if indices]
        logging.debug("Resolving ambiguity in: {}".format(elements))

        # Collect all possible combinations of argument and predicate indices
        # (hopefully it's not too many)
        all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
        logging.debug("Number of combinations: {}".format(len(all_combinations)))

        # Choose the ones with the best clustering and unfold them
        # (wrapped in list() so the result can be indexed below)
        resolved_elements = list(zip(map(itemgetter(0), elements),
                                     min(all_combinations,
                                         key = lambda cluster: self.clusterScore(cluster))))
        logging.debug("Resolved elements = {}".format(resolved_elements))

        self.pred = resolved_elements[0]
        self.args = resolved_elements[1:]

    def conll(self, external_feats = []):
        """
        Return a CoNLL string representation of this extraction
        """
        # NOTE: external_feats defaults to a list, since it is concatenated with lists below
        return '\n'.join(["\t".join(map(str,
                                        [i, w] + \
                                        list(self.pred) + \
                                        [self.head_pred_index] + \
                                        external_feats + \
                                        [self.get_label(i)]))
                          for (i, w)
                          in enumerate(self.sent.split(" "))]) + '\n'

    def get_label(self, index):
        """
        Given the index of a word in the sentence -- returns the appropriate BIO CoNLL label.
        Assumes that ambiguity was already resolved.
        """
        # Get the element(s) in which this index appears
        ent = [(elem_ind, elem)
               for (elem_ind, elem)
               in enumerate(map(itemgetter(1),
                                [self.pred] + self.args))
               if index in elem]

        if not ent:
            # index doesn't appear in any element
            return "O"

        if len(ent) > 1:
            # The same word appears in two different answers;
            # in this case we choose the first one as the label
            logging.warning("Index {} appears in more than one element: {}".\
                format(index,
                       "\t".join(map(str,
                                     [ent,
                                      self.sent,
                                      self.pred,
                                      self.args]))))

        ## Some indices appear in more than one argument (ones where the above message appears)
        ## From empirical observation, these seem to mostly consist of different levels of granularity:
        ##     what had _ been taken       _ _ _ ?    loan commitments topping $ 3 billion
        ##     how much had _ been taken   _ _ _ ?    topping $ 3 billion
        ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
        ## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)

        elem_ind, elem = min(ent, key = lambda e: len(e[1]))

        # Distinguish between predicate and arguments
        prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)

        # Distinguish between Beginning and Inside labels
        suffix = "B" if index == elem[0] else "I"

        return "{}-{}".format(prefix, suffix)

    # NOTE: this second __str__ definition shadows the one defined earlier in the class
    def __str__(self):
        return '{0}\t{1}'.format(self.elementToStr(self.pred,
                                                   print_indices = True),
                                 '\t'.join([self.elementToStr(arg)
                                            for arg
                                            in self.args]))

# Flatten a list of lists
flatten = lambda l: [item for sublist in l for item in sublist]


def normalize_element(elem):
    """
    Return a surface form of the given question element.
    The output should be able to properly precede a predicate (or be blank otherwise).
    """
    return elem.replace("_", " ") \
        if elem != "_" \
        else ""

## Helper functions
def escape_special_chars(s):
    return s.replace('\t', '\\t')


def generalize_question(question):
    """
    Given a question in the context of the sentence and the predicate index within
    the question - return a generalized version which extracts only order-imposing features
    """
    import nltk # Using nltk since we couldn't get spaCy to agree on the tokenization
    wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1] # Last split is the question mark
    return ' '.join([wh, sbj, obj1])


## CONSTANTS
SEP = ';;;'
QUESTION_TRG_INDEX = 3 # index of the predicate within the question
QUESTION_PP_INDEX = 5
QUESTION_OBJ2_INDEX = 6
@@ -1,53 +0,0 @@
from carb.oieReader import OieReader
from carb.extraction import Extraction
from collections import defaultdict
import ipdb

class GoldReader(OieReader):

    # Path relative to the repo root folder
    default_filename = './oie_corpus/all.oie'

    def __init__(self):
        self.name = 'Gold'

    def read(self, fn):
        d = defaultdict(lambda: [])
        multilingual = False
        for lang in ['spanish']:
            if lang in fn:
                multilingual = True
                encoding = lang
                break
        if multilingual and encoding == 'spanish':
            fin = open(fn, 'r', encoding='latin-1')
        else:
            fin = open(fn)
        for line_ind, line in enumerate(fin):
            data = line.strip().split('\t')
            text, rel = data[:2]
            args = data[2:]
            confidence = 1

            curExtraction = Extraction(pred = rel.strip(),
                                       head_pred_index = None,
                                       sent = text.strip(),
                                       confidence = float(confidence),
                                       index = line_ind)
            for arg in args:
                if "C: " in arg:
                    continue
                curExtraction.addArg(arg.strip())

            d[text.strip()].append(curExtraction)
        self.oie = d


if __name__ == '__main__':
    g = GoldReader()
    g.read('../oie_corpus/all.oie')  # read() takes only the file name
    d = g.oie
    e = list(d.items())[0]  # dict views are not indexable in Python 3
    print(e[1][0].bow())
    print(g.count())
@@ -1,46 +0,0 @@
from carb.oieReader import OieReader
from carb.extraction import Extraction
from collections import defaultdict
import json

class Relabel_GoldReader(OieReader):

    # Path relative to the repo root folder
    default_filename = './oie_corpus/all.oie'

    def __init__(self):
        self.name = 'Relabel_Gold'

    def read(self, fn):
        d = defaultdict(lambda: [])
        with open(fn) as fin:
            data = json.load(fin)
        for sentence in data:
            tuples = data[sentence]
            for t in tuples:
                if t["pred"].strip() == "<be>":
                    rel = "[is]"
                else:
                    rel = t["pred"].replace("<be> ", "")
                confidence = 1

                curExtraction = Extraction(pred = rel,
                                           head_pred_index = None,
                                           sent = sentence,
                                           confidence = float(confidence),
                                           index = None)
                # add every non-empty argument slot, in order
                for key in ("arg0", "arg1", "arg2", "arg3", "temp", "loc"):
                    if t[key] != "":
                        curExtraction.addArg(t[key])

                d[sentence].append(curExtraction)
        self.oie = d
@ -1,340 +0,0 @@
|
||||||
from __future__ import division
|
|
||||||
import string
|
|
||||||
from nltk.translate.bleu_score import sentence_bleu
|
|
||||||
from nltk.corpus import stopwords
|
|
||||||
from copy import copy
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
class Matcher:
|
|
||||||
@staticmethod
|
|
||||||
def bowMatch(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
"""
|
|
||||||
A binary function testing for exact lexical match (ignoring ordering) between reference
|
|
||||||
and predicted extraction
|
|
||||||
"""
|
|
||||||
s1 = ref.bow()
|
|
||||||
s2 = ex.bow()
|
|
||||||
if ignoreCase:
|
|
||||||
s1 = s1.lower()
|
|
||||||
s2 = s2.lower()
|
|
||||||
|
|
||||||
s1Words = s1.split(' ')
|
|
||||||
s2Words = s2.split(' ')
|
|
||||||
|
|
||||||
if ignoreStopwords:
|
|
||||||
s1Words = Matcher.removeStopwords(s1Words)
|
|
||||||
s2Words = Matcher.removeStopwords(s2Words)
|
|
||||||
|
|
||||||
return sorted(s1Words) == sorted(s2Words)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def predMatch(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
"""
|
|
||||||
Return whehter gold and predicted extractions agree on the predicate
|
|
||||||
"""
|
|
||||||
s1 = ref.elementToStr(ref.pred)
|
|
||||||
s2 = ex.elementToStr(ex.pred)
|
|
||||||
if ignoreCase:
|
|
||||||
s1 = s1.lower()
|
|
||||||
s2 = s2.lower()
|
|
||||||
|
|
||||||
s1Words = s1.split(' ')
|
|
||||||
s2Words = s2.split(' ')
|
|
||||||
|
|
||||||
if ignoreStopwords:
|
|
||||||
s1Words = Matcher.removeStopwords(s1Words)
|
|
||||||
s2Words = Matcher.removeStopwords(s2Words)
|
|
||||||
|
|
||||||
return s1Words == s2Words
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def argMatch(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
"""
|
|
||||||
Return whehter gold and predicted extractions agree on the arguments
|
|
||||||
"""
|
|
||||||
sRef = ' '.join([ref.elementToStr(elem) for elem in ref.args])
|
|
||||||
sEx = ' '.join([ex.elementToStr(elem) for elem in ex.args])
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
for w1 in sRef:
|
|
||||||
for w2 in sEx:
|
|
||||||
if w1 == w2:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
# We check how well does the extraction lexically cover the reference
|
|
||||||
# Note: this is somewhat lenient as it doesn't penalize the extraction for
|
|
||||||
# being too long
|
|
||||||
coverage = float(count) / len(sRef)
|
|
||||||
|
|
||||||
|
|
||||||
return coverage > Matcher.LEXICAL_THRESHOLD
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def bleuMatch(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
sRef = ref.bow()
|
|
||||||
sEx = ex.bow()
|
|
||||||
bleu = sentence_bleu(references = [sRef.split(' ')], hypothesis = sEx.split(' '))
|
|
||||||
return bleu > Matcher.BLEU_THRESHOLD
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def lexicalMatch(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
sRef = ref.bow().split(' ')
|
|
||||||
sEx = ex.bow().split(' ')
|
|
||||||
count = 0
|
|
||||||
#for w1 in sRef:
|
|
||||||
# if w1 in sEx:
|
|
||||||
# count += 1
|
|
||||||
# sEx.remove(w1)
|
|
||||||
for w1 in sRef:
|
|
||||||
for w2 in sEx:
|
|
||||||
if w1 == w2:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
# We check how well does the extraction lexically cover the reference
|
|
||||||
# Note: this is somewhat lenient as it doesn't penalize the extraction for
|
|
||||||
# being too long
|
|
||||||
coverage = float(count) / len(sRef)
|
|
||||||
|
|
||||||
return coverage > Matcher.LEXICAL_THRESHOLD
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tuple_match(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
precision = [0, 0] # 0 out of 0 predicted words match
|
|
||||||
recall = [0, 0] # 0 out of 0 reference words match
|
|
||||||
# If, for each part, any word is the same as a reference word, then it's a match.
|
|
||||||
|
|
||||||
predicted_words = ex.pred.split()
|
|
||||||
gold_words = ref.pred.split()
|
|
||||||
precision[1] += len(predicted_words)
|
|
||||||
recall[1] += len(gold_words)
|
|
||||||
|
|
||||||
# matching_words = sum(1 for w in predicted_words if w in gold_words)
|
|
||||||
matching_words = 0
|
|
||||||
for w in gold_words:
|
|
||||||
if w in predicted_words:
|
|
||||||
matching_words += 1
|
|
||||||
predicted_words.remove(w)
|
|
||||||
|
|
||||||
if matching_words == 0:
|
|
||||||
return False # t <-> gt is not a match
|
|
||||||
precision[0] += matching_words
|
|
||||||
recall[0] += matching_words
|
|
||||||
|
|
||||||
for i in range(len(ref.args)):
|
|
||||||
gold_words = ref.args[i].split()
|
|
||||||
recall[1] += len(gold_words)
|
|
||||||
if len(ex.args) <= i:
|
|
||||||
if i<2:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
predicted_words = ex.args[i].split()
|
|
||||||
precision[1] += len(predicted_words)
|
|
||||||
matching_words = 0
|
|
||||||
for w in gold_words:
|
|
||||||
if w in predicted_words:
|
|
||||||
matching_words += 1
|
|
||||||
predicted_words.remove(w)
|
|
||||||
|
|
||||||
if matching_words == 0 and i<2:
|
|
||||||
return False # t <-> gt is not a match
|
|
||||||
precision[0] += matching_words
|
|
||||||
# Currently this slightly penalises systems when the reference
|
|
||||||
# reformulates the sentence words, because the reformulation doesn't
|
|
||||||
# match the predicted word. It's a one-wrong-word penalty to precision,
|
|
||||||
# to all systems that correctly extracted the reformulated word.
|
|
||||||
recall[0] += matching_words
|
|
||||||
|
|
||||||
prec = 1.0 * precision[0] / precision[1]
|
|
||||||
rec = 1.0 * recall[0] / recall[1]
|
|
||||||
return [prec, rec]
|
|
||||||
|
|
||||||
# STRICTER LINIENT MATCH
|
|
||||||
def linient_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
precision = [0, 0] # 0 out of 0 predicted words match
|
|
||||||
recall = [0, 0] # 0 out of 0 reference words match
|
|
||||||
# If, for each part, any word is the same as a reference word, then it's a match.
|
|
||||||
|
|
||||||
predicted_words = ex.pred.split()
|
|
||||||
gold_words = ref.pred.split()
|
|
||||||
precision[1] += len(predicted_words)
|
|
||||||
recall[1] += len(gold_words)
|
|
||||||
|
|
||||||
# matching_words = sum(1 for w in predicted_words if w in gold_words)
|
|
||||||
matching_words = 0
|
|
||||||
for w in gold_words:
|
|
||||||
if w in predicted_words:
|
|
||||||
matching_words += 1
|
|
||||||
predicted_words.remove(w)
|
|
||||||
|
|
||||||
# matching 'be' with its different forms
|
|
||||||
forms_of_be = ["be","is","am","are","was","were","been","being"]
|
|
||||||
if "be" in predicted_words:
|
|
||||||
for form in forms_of_be:
|
|
||||||
if form in gold_words:
|
|
||||||
matching_words += 1
|
|
||||||
predicted_words.remove("be")
|
|
||||||
break
|
|
||||||
|
|
||||||
if matching_words == 0:
|
|
||||||
return [0,0] # t <-> gt is not a match
|
|
||||||
|
|
||||||
precision[0] += matching_words
|
|
||||||
recall[0] += matching_words
|
|
||||||
|
|
||||||
for i in range(len(ref.args)):
|
|
||||||
gold_words = ref.args[i].split()
|
|
||||||
recall[1] += len(gold_words)
|
|
||||||
if len(ex.args) <= i:
|
|
||||||
if i<2:
|
|
||||||
return [0,0] # changed
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
predicted_words = ex.args[i].split()
|
|
||||||
precision[1] += len(predicted_words)
|
|
||||||
matching_words = 0
|
|
||||||
for w in gold_words:
|
|
||||||
if w in predicted_words:
|
|
||||||
matching_words += 1
|
|
||||||
predicted_words.remove(w)
|
|
||||||
|
|
||||||
precision[0] += matching_words
|
|
||||||
# Currently this slightly penalises systems when the reference
|
|
||||||
# reformulates the sentence words, because the reformulation doesn't
|
|
||||||
# match the predicted word. It's a one-wrong-word penalty to precision,
|
|
||||||
# to all systems that correctly extracted the reformulated word.
|
|
||||||
recall[0] += matching_words
|
|
||||||
|
|
||||||
if(precision[1] == 0):
|
|
||||||
prec = 0
|
|
||||||
else:
|
|
||||||
prec = 1.0 * precision[0] / precision[1]
|
|
||||||
if(recall[1] == 0):
|
|
||||||
rec = 0
|
|
||||||
else:
|
|
||||||
rec = 1.0 * recall[0] / recall[1]
|
|
||||||
return [prec, rec]
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def simple_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
|
|
||||||
ref.args = [ref.args[0], ' '.join(ref.args[1:])]
|
|
||||||
        ex.args = [ex.args[0], ' '.join(ex.args[1:])]
        precision = [0, 0]  # 0 out of 0 predicted words match
        recall = [0, 0]  # 0 out of 0 reference words match
        # If, for each part, any word is the same as a reference word, then it's a match.

        predicted_words = ex.pred.split()
        gold_words = ref.pred.split()
        precision[1] += len(predicted_words)
        recall[1] += len(gold_words)

        matching_words = 0
        for w in gold_words:
            if w in predicted_words:
                matching_words += 1
                predicted_words.remove(w)

        precision[0] += matching_words
        recall[0] += matching_words

        for i in range(len(ref.args)):
            gold_words = ref.args[i].split()
            recall[1] += len(gold_words)
            if len(ex.args) <= i:
                break
            predicted_words = ex.args[i].split()
            precision[1] += len(predicted_words)
            matching_words = 0
            for w in gold_words:
                if w in predicted_words:
                    matching_words += 1
                    predicted_words.remove(w)
            precision[0] += matching_words

            # Currently this slightly penalises systems when the reference
            # reformulates the sentence words, because the reformulation doesn't
            # match the predicted word. It's a one-wrong-word penalty to precision,
            # to all systems that correctly extracted the reformulated word.
            recall[0] += matching_words

        prec = 1.0 * precision[0] / precision[1]
        rec = 1.0 * recall[0] / recall[1]
        return [prec, rec]

    # @staticmethod
    # def binary_linient_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
    #     if len(ref.args) >= 2:
    #         # r = ref.copy()
    #         r = copy(ref)
    #         r.args = [ref.args[0], ' '.join(ref.args[1:])]
    #     else:
    #         r = ref
    #     if len(ex.args) >= 2:
    #         # e = ex.copy()
    #         e = copy(ex)
    #         e.args = [ex.args[0], ' '.join(ex.args[1:])]
    #     else:
    #         e = ex
    #     return Matcher.linient_tuple_match(r, e, ignoreStopwords, ignoreCase)

    @staticmethod
    def binary_linient_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
        if len(ref.args) >= 2:
            r = copy(ref)
            r.args = [ref.args[0], ' '.join(ref.args[1:])]
        else:
            r = ref
        if len(ex.args) >= 2:
            e = copy(ex)
            e.args = [ex.args[0], ' '.join(ex.args[1:])]
        else:
            e = ex
        straight_match = Matcher.linient_tuple_match(r, e, ignoreStopwords, ignoreCase)

        said_type_reln = ["said", "told", "added", "adds", "says"]
        said_type_sentence = False
        for said_verb in said_type_reln:
            if said_verb in ref.pred:
                said_type_sentence = True
                break
        if not said_type_sentence:
            return straight_match
        else:
            # For "said"-type predicates, also try matching with the argument order reversed.
            if len(ex.args) >= 2:
                e = copy(ex)
                e.args = [' '.join(ex.args[1:]), ex.args[0]]
            else:
                e = ex
            reverse_match = Matcher.linient_tuple_match(r, e, ignoreStopwords, ignoreCase)

            return max(straight_match, reverse_match)

    @staticmethod
    def binary_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
        if len(ref.args) >= 2:
            # r = ref.copy()
            r = copy(ref)
            r.args = [ref.args[0], ' '.join(ref.args[1:])]
        else:
            r = ref
        if len(ex.args) >= 2:
            # e = ex.copy()
            e = copy(ex)
            e.args = [ex.args[0], ' '.join(ex.args[1:])]
        else:
            e = ex
        return Matcher.tuple_match(r, e, ignoreStopwords, ignoreCase)

    @staticmethod
    def removeStopwords(ls):
        return [w for w in ls if w.lower() not in Matcher.stopwords]

    # CONSTANTS
    BLEU_THRESHOLD = 0.4
    LEXICAL_THRESHOLD = 0.5  # Note: changing this value didn't change the ordering of the tested systems
    stopwords = stopwords.words('english') + list(string.punctuation)
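For reference, a standalone sketch (with invented strings, not repository code) of the multiset word-overlap bookkeeping used above: each gold word can consume at most one predicted word.

# Toy illustration of the word-overlap counting above (hypothetical data).
def overlap(gold_words, predicted_words):
    predicted_words = list(predicted_words)  # copy, since matched words are removed
    matching = 0
    for w in gold_words:
        if w in predicted_words:
            matching += 1
            predicted_words.remove(w)  # each predicted word can match at most once
    return matching

# "was elected" vs. "was elected president": 2 of 3 predicted words match the gold words.
assert overlap("was elected".split(), "was elected president".split()) == 2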
@@ -1,45 +0,0 @@
class OieReader:

    def read(self, fn, includeNominal):
        ''' should set oie as a class member
        as a dictionary of extractions by sentence'''
        raise Exception("Don't run me")

    def count(self):
        ''' number of extractions '''
        return sum([len(extractions) for _, extractions in self.oie.items()])

    def split_to_corpus(self, corpus_fn, out_fn):
        """
        Given a corpus file name, containing a list of sentences,
        print only the extractions pertaining to it to out_fn in a tab separated format:
        sent, prob, pred, arg1, arg2, ...
        """
        raw_sents = [line.strip() for line in open(corpus_fn)]
        with open(out_fn, 'w') as fout:
            for line in self.get_tabbed().split('\n'):
                data = line.split('\t')
                sent = data[0]
                if sent in raw_sents:
                    fout.write(line + '\n')

    def output_tabbed(self, out_fn):
        """
        Write a tabbed representation of this corpus.
        """
        with open(out_fn, 'w') as fout:
            fout.write(self.get_tabbed())

    def get_tabbed(self):
        """
        Get a tabbed format representation of this corpus (assumes that input was
        already read).
        """
        return "\n".join(['\t'.join(map(str,
                                        [ex.sent,
                                         ex.confidence,
                                         ex.pred,
                                         '\t'.join(ex.args)]))
                          for (sent, exs) in self.oie.items()
                          for ex in exs])
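For orientation, a minimal hypothetical subclass showing the contract read() is expected to fulfil: populate self.oie as a dictionary mapping each sentence to its list of extractions. The two-column input format here is invented for illustration only.

from oie_readers.extraction import Extraction

class ToyReader(OieReader):
    """Hypothetical minimal reader: two tab-separated columns, sent and pred."""
    def read(self, fn, includeNominal=False):
        self.oie = {}
        with open(fn) as fin:
            for line in fin:
                sent, pred = line.rstrip('\n').split('\t')[:2]
                ex = Extraction(pred=pred, head_pred_index=None,
                                sent=sent, confidence=1.0)
                self.oie.setdefault(sent, []).append(ex)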
@@ -1,21 +0,0 @@
import nltk
from operator import itemgetter

class Argument:
    def __init__(self, arg):
        self.words = [x for x in arg[0].strip().split(' ') if x]
        # Materialize the map so the POS tags can be reused (Python 3 maps are lazy).
        self.posTags = list(map(itemgetter(1), nltk.pos_tag(self.words)))
        self.indices = arg[1]
        self.feats = {}

    def __str__(self):
        return "({})".format('\t'.join(map(str,
                                           [escape_special_chars(' '.join(self.words)),
                                            str(self.indices)])))

COREF = 'coref'

## Helper functions
def escape_special_chars(s):
    return s.replace('\t', '\\t')
@@ -1,55 +0,0 @@
""" Usage:
    benchmarkGoldReader --in=INPUT_FILE

Read a tab-formatted file.
Each line consists of:
sent, prob, pred, arg1, arg2, ...

"""

from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
from docopt import docopt
import logging

logging.basicConfig(level = logging.DEBUG)

class BenchmarkGoldReader(OieReader):

    def __init__(self):
        self.name = 'BenchmarkGoldReader'

    def read(self, fn):
        """
        Read a tabbed format line
        Each line consists of:
        sent, prob, pred, arg1, arg2, ...
        """
        d = {}
        ex_index = 0
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                text, rel = data[:2]
                curExtraction = Extraction(pred = rel.strip(),
                                           head_pred_index = None,
                                           sent = text.strip(),
                                           confidence = 1.0,
                                           question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
                                           index = ex_index)
                ex_index += 1

                for arg in data[2:]:
                    curExtraction.addArg(arg.strip())

                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d


if __name__ == "__main__":
    args = docopt(__doc__)
    input_fn = args["--in"]
    tr = BenchmarkGoldReader()
    tr.read(input_fn)
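Note that the code above actually consumes sent, pred, arg1, arg2, ... (data[:2] takes the sentence and the predicate; no probability column is read, despite the docstring). An invented input line, with tabs written as \t, would look like:

John gave Mary a book .\tgave\tJohn\tMary\ta book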
@@ -1,90 +0,0 @@
""" Usage:
    <file-name> --in=INPUT_FILE --out=OUTPUT_FILE [--debug]

Convert to tabbed format
"""
# External imports
import logging
from pprint import pprint
from pprint import pformat
from docopt import docopt

# Local imports
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
import ipdb
#=-----

class ClausieReader(OieReader):

    def __init__(self):
        self.name = 'ClausIE'

    def read(self, fn):
        d = {}
        with open(fn, encoding="utf-8") as fin:
            for line in fin:
                data = line.strip().split('\t')
                if len(data) == 1:
                    text = data[0]
                elif len(data) == 5:
                    arg1, rel, arg2 = [s[1:-1] for s in data[1:4]]
                    confidence = data[4]

                    curExtraction = Extraction(pred = rel,
                                               head_pred_index = -1,
                                               sent = text,
                                               confidence = float(confidence))

                    curExtraction.addArg(arg1)
                    curExtraction.addArg(arg2)
                    d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
        # self.normalizeConfidence()

        # # remove extractions below the confidence threshold
        # if type(self.threshold) != type(None):
        #     new_d = {}
        #     for sent in self.oie:
        #         for extraction in self.oie[sent]:
        #             if extraction.confidence < self.threshold:
        #                 continue
        #             else:
        #                 new_d[sent] = new_d.get(sent, []) + [extraction]
        #     self.oie = new_d

    def normalizeConfidence(self):
        ''' Normalize confidence to resemble probabilities '''
        EPSILON = 1e-3

        confidences = [extraction.confidence for sent in self.oie for extraction in self.oie[sent]]
        maxConfidence = max(confidences)
        minConfidence = min(confidences)

        denom = maxConfidence - minConfidence + (2 * EPSILON)

        for sent, extractions in self.oie.items():
            for extraction in extractions:
                extraction.confidence = ((extraction.confidence - minConfidence) + EPSILON) / denom


if __name__ == "__main__":
    # Parse command line arguments
    args = docopt(__doc__)
    inp_fn = args["--in"]
    out_fn = args["--out"]
    debug = args["--debug"]
    if debug:
        logging.basicConfig(level = logging.DEBUG)
    else:
        logging.basicConfig(level = logging.INFO)

    oie = ClausieReader()
    oie.read(inp_fn)
    oie.output_tabbed(out_fn)

    logging.info("DONE")
@@ -1,444 +0,0 @@
from sklearn.preprocessing import binarize
from oie_readers.argument import Argument
from operator import itemgetter
from collections import defaultdict
import nltk
import itertools
import logging
import numpy as np
import pdb

class Extraction:
    """
    Stores sentence, single predicate and corresponding arguments.
    """
    def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
        self.pred = pred
        self.head_pred_index = head_pred_index
        self.sent = sent
        self.args = []
        self.confidence = confidence
        self.matched = []
        self.questions = {}
        self.indsForQuestions = defaultdict(lambda: set())
        self.is_mwp = False
        self.question_dist = question_dist
        self.index = index

    def distArgFromPred(self, arg):
        assert(len(self.pred) == 2)
        dists = []
        for x in self.pred[1]:
            for y in arg.indices:
                dists.append(abs(x - y))

        return min(dists)

    def argsByDistFromPred(self, question):
        return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))

    def addArg(self, arg, question = None):
        self.args.append(arg)
        if question:
            self.questions[question] = self.questions.get(question, []) + [Argument(arg)]

    def noPronounArgs(self):
        """
        Returns True iff all of this extraction's arguments are not pronouns.
        """
        for (a, _) in self.args:
            tokenized_arg = nltk.word_tokenize(a)
            if len(tokenized_arg) == 1:
                _, pos_tag = nltk.pos_tag(tokenized_arg)[0]
                if ('PRP' in pos_tag):
                    return False
        return True

    def isContiguous(self):
        return all([indices for (_, indices) in self.args])

    def toBinary(self):
        ''' Try to represent this extraction's arguments as binary
        If fails, this function will return an empty list. '''

        ret = [self.elementToStr(self.pred)]

        if len(self.args) == 2:
            # we're in luck
            return ret + [self.elementToStr(arg) for arg in self.args]

        return []

        # NOTE: the early return above makes the rest of this method unreachable.
        if not self.isContiguous():
            # give up on non contiguous arguments (as we need indexes)
            return []

        # otherwise, try to merge based on indices
        # TODO: you can explore other methods for doing this
        binarized = self.binarizeByIndex()

        if binarized:
            return ret + binarized

        return []

    def elementToStr(self, elem, print_indices = True):
        ''' formats an extraction element (pred or arg) as a raw string
        removes indices and trailing spaces '''
        if print_indices:
            return str(elem)
        if isinstance(elem, str):
            return elem
        if isinstance(elem, tuple):
            ret = elem[0].rstrip().lstrip()
        else:
            ret = ' '.join(elem.words)
        assert ret, "empty element? {0}".format(elem)
        return ret

    def binarizeByIndex(self):
        extraction = [self.pred] + self.args
        markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
        # Sort elements by their first index in the sentence.
        sortedExtraction = sorted(markPred, key = lambda elem: elem[1][0])
        s = ' '.join(['{1} {0} {1}'.format(self.elementToStr(elem), SEP) if elem[2] else self.elementToStr(elem) for elem in sortedExtraction])
        binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]

        if len(binArgs) == 2:
            return binArgs

        # failure
        return []

    def bow(self):
        return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])

    def getSortedArgs(self):
        """
        Sort the list of arguments.
        If a question distribution is provided - use it,
        otherwise, default to the order of appearance in the sentence.
        """
        if self.question_dist:
            # There's a question distribution - use it
            return self.sort_args_by_distribution()
        ls = []
        for q, args in self.questions.items():
            if (len(args) != 1):
                logging.debug("Not one argument: {}".format(args))
                continue
            arg = args[0]
            indices = list(self.indsForQuestions[q].union(arg.indices))
            if not indices:
                logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
                indices = [0]
            ls.append(((arg, q), indices))
        return [a for a, _ in sorted(ls,
                                     key = lambda pair: min(pair[1]))]

    def question_prob_for_loc(self, question, loc):
        """
        Returns the probability of the given question leading to argument
        appearing in the given location in the output slot.
        """
        gen_question = generalize_question(question)
        q_dist = self.question_dist[gen_question]
        logging.debug("distribution of {}: {}".format(gen_question,
                                                      q_dist))

        return float(q_dist.get(loc, 0)) / \
            sum(q_dist.values())

    def sort_args_by_distribution(self):
        """
        Use this instance's question distribution (this func assumes it exists)
        in determining the positioning of the arguments.
        Greedy algorithm:
        0. Decide on which argument will serve as the ``subject'' (first slot) of this extraction
        0.1 Based on the most probable one for this spot
        (special care is given to select the highly-influential subject position)
        1. For all other arguments, sort arguments by the prevalence of their questions
        2. For each argument:
        2.1 Assign to it the most probable slot still available
        2.2 If no such slot exists (fallback) - default to putting it in the last location
        """
        INF_LOC = 100  # Used as an impractical last argument

        # Store arguments by slot
        ret = {INF_LOC: []}
        logging.debug("sorting: {}".format(self.questions))

        # Find the most suitable argument for the subject location
        logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
                                                      for (q, _) in self.questions.items()]))

        subj_question, subj_args = max(self.questions.items(),
                                       key = lambda qa: self.question_prob_for_loc(qa[0], 0))

        ret[0] = [(subj_args[0], subj_question)]

        # Find the rest
        for (question, args) in sorted([(q, a)
                                        for (q, a) in self.questions.items() if (q not in [subj_question])],
                                       key = lambda qa: \
                                       sum(self.question_dist[generalize_question(qa[0])].values()),
                                       reverse = True):
            gen_question = generalize_question(question)
            arg = args[0]
            assigned_flag = False
            for (loc, count) in sorted(self.question_dist[gen_question].items(),
                                       key = lambda lc: lc[1],
                                       reverse = True):
                if loc not in ret:
                    # Found an empty slot for this item
                    # Place it there and break out
                    ret[loc] = [(arg, question)]
                    assigned_flag = True
                    break

            if not assigned_flag:
                # Add this argument to the non-assigned (hopefully doesn't happen much)
                logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
                ret[INF_LOC].append((arg, question))

        logging.debug("Linearizing arg list: {}".format(ret))

        # Finished iterating - consolidate and return a list of arguments
        return [arg
                for (_, arg_ls) in sorted(ret.items(),
                                          key = lambda kv: int(kv[0]))
                for arg in arg_ls]

    def __str__(self):
        pred_str = self.elementToStr(self.pred)
        return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
                                   self.compute_global_pred(pred_str,
                                                            self.questions.keys()),
                                   '\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
                                                                                                  question))
                                              for arg, question in self.getSortedArgs()]))

    def get_base_verb(self, surface_pred):
        """
        Given the surface pred, return the original annotated verb
        """
        # Assumes that at this point the verb is always the last word
        # in the surface predicate
        return surface_pred.split(' ')[-1]

    def compute_global_pred(self, surface_pred, questions):
        """
        Given the surface pred and all instantiations of questions,
        make global coherence decisions regarding the final form of the predicate
        This should hopefully take care of multi word predicates and correct inflections
        """
        from operator import itemgetter
        split_surface = surface_pred.split(' ')

        if len(split_surface) > 1:
            # This predicate has a modal preceding the base verb
            verb = split_surface[-1]
            ret = split_surface[:-1]  # get all of the elements in the modal
        else:
            verb = split_surface[0]
            ret = []

        # Materialize the maps; they are consumed more than once below.
        split_questions = list(map(lambda question: question.split(' '),
                                   questions))

        preds = list(map(normalize_element,
                         map(itemgetter(QUESTION_TRG_INDEX),
                             split_questions)))
        if len(set(preds)) > 1:
            # This predicate appears in multiple ways, let's stick to the base form
            ret.append(verb)

        if len(set(preds)) == 1:
            # Change the predicate to the inflected form
            # if there's exactly one way in which the predicate is conveyed
            ret.append(preds[0])

        pps = list(map(normalize_element,
                       map(itemgetter(QUESTION_PP_INDEX),
                           split_questions)))

        obj2s = list(map(normalize_element,
                         map(itemgetter(QUESTION_OBJ2_INDEX),
                             split_questions)))

        if (len(set(pps)) == 1):
            # If all questions for the predicate include the same pp attachment -
            # assume it's a multiword predicate
            self.is_mwp = True  # Signal to arguments that they shouldn't take the preposition
            ret.append(pps[0])

        # Concat all elements in the predicate and return
        return " ".join(ret).strip()

    def augment_arg_with_question(self, arg, question):
        """
        Decide what elements from the question to incorporate in the given
        corresponding argument
        """
        # Parse question
        wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
                                                question.split(' ')[:-1])  # Last split is the question mark

        # Place preposition in argument
        # This is safer when dealing with n-ary arguments, as it directly attaches to the
        # appropriate argument
        if (not self.is_mwp) and pp and (not obj2):
            if not(arg.startswith("{} ".format(pp))):
                # Avoid repeating the preposition in cases where both question and answer contain it
                return " ".join([pp,
                                 arg])

        # Normal cases
        return arg

    def clusterScore(self, cluster):
        """
        Calculate cluster density score as the mean distance of the maximum distance of each slot.
        Lower score represents a denser cluster.
        """
        logging.debug("*-*-*- Cluster: {}".format(cluster))

        # Find global centroid
        arr = np.array([x for ls in cluster for x in ls])
        centroid = np.sum(arr) / arr.shape[0]
        logging.debug("Centroid: {}".format(centroid))

        # Calculate mean over all maximum points
        return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])

    def resolveAmbiguity(self):
        """
        Heuristic to map the elements (argument and predicates) of this extraction
        back to the indices of the sentence.
        """
        ## TODO: This removes arguments for which there was no consecutive span found
        ## Part of these are non-consecutive arguments,
        ## but others could be a bug in recognizing some punctuation marks

        elements = [self.pred] \
            + [(s, indices)
               for (s, indices)
               in self.args
               if indices]
        logging.debug("Resolving ambiguity in: {}".format(elements))

        # Collect all possible combinations of arguments and predicate indices
        # (hopefully it's not too much)
        all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
        logging.debug("Number of combinations: {}".format(len(all_combinations)))

        # Choose the ones with best clustering and unfold them
        resolved_elements = list(zip(map(itemgetter(0), elements),
                                     min(all_combinations,
                                         key = lambda cluster: self.clusterScore(cluster))))
        logging.debug("Resolved elements = {}".format(resolved_elements))

        self.pred = resolved_elements[0]
        self.args = resolved_elements[1:]

    def conll(self, external_feats = []):
        """
        Return a CoNLL string representation of this extraction
        """
        return '\n'.join(["\t".join(map(str,
                                        [i, w] + \
                                        list(self.pred) + \
                                        [self.head_pred_index] + \
                                        external_feats + \
                                        [self.get_label(i)]))
                          for (i, w)
                          in enumerate(self.sent.split(" "))]) + '\n'

    def get_label(self, index):
        """
        Given an index of a word in the sentence -- returns the appropriate BIO conll label
        Assumes that ambiguation was already resolved.
        """
        # Get the element(s) in which this index appears
        ent = [(elem_ind, elem)
               for (elem_ind, elem)
               in enumerate(map(itemgetter(1),
                                [self.pred] + self.args))
               if index in elem]

        if not ent:
            # index doesn't appear in any element
            return "O"

        if len(ent) > 1:
            # The same word appears in two different answers
            # In this case we choose the first one as label
            logging.warning("Index {} appears in more than one element: {}".\
                format(index,
                       "\t".join(map(str,
                                     [ent,
                                      self.sent,
                                      self.pred,
                                      self.args]))))

        ## Some indices appear in more than one argument (ones where the above message appears)
        ## From empirical observation, these seem to mostly consist of different levels of granularity:
        ##     what had _ been taken _ _ _ ?  loan commitments topping $ 3 billion
        ##     how much had _ been taken _ _ _ ?  topping $ 3 billion
        ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
        ## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)

        elem_ind, elem = min(ent, key = lambda e: len(e[1]))

        # Distinguish between predicate and arguments
        prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)

        # Distinguish between Beginning and Inside labels
        suffix = "B" if index == elem[0] else "I"

        return "{}-{}".format(prefix, suffix)

    # NOTE: this second __str__ definition shadows the earlier one above.
    def __str__(self):
        return '{0}\t{1}'.format(self.elementToStr(self.pred,
                                                   print_indices = True),
                                 '\t'.join([self.elementToStr(arg)
                                            for arg
                                            in self.args]))

# Flatten a list of lists
flatten = lambda l: [item for sublist in l for item in sublist]


def normalize_element(elem):
    """
    Return a surface form of the given question element.
    the output should be properly able to precede a predicate (or blank otherwise)
    """
    return elem.replace("_", " ") \
        if (elem != "_") \
        else ""

## Helper functions
def escape_special_chars(s):
    return s.replace('\t', '\\t')


def generalize_question(question):
    """
    Given a question in the context of the sentence and the predicate index within
    the question - return a generalized version which extracts only order-imposing features
    """
    import nltk  # Using nltk since couldn't get spaCy to agree on the tokenization
    wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1]  # Last split is the question mark
    return ' '.join([wh, sbj, obj1])


## CONSTANTS
SEP = ';;;'
QUESTION_TRG_INDEX = 3  # index of the predicate within the question
QUESTION_PP_INDEX = 5
QUESTION_OBJ2_INDEX = 6
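A worked example (hypothetical sentence and index spans, not repository data) of the BIO labels that get_label above produces:

# Hypothetical worked example of get_label's BIO scheme.
# sentence: "John was elected president"
# pred = ('was elected', [1, 2])   -> tokens 1-2
# args = [('John', [0])]           -> token 0
# get_label(0) == 'A0-B'   (beginning of argument 0)
# get_label(1) == 'P-B'    (beginning of predicate)
# get_label(2) == 'P-I'    (inside predicate)
# get_label(3) == 'O'      (outside any element)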
@@ -1,44 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
from collections import defaultdict
import ipdb

class GoldReader(OieReader):

    # Path relative to repo root folder
    default_filename = './oie_corpus/all.oie'

    def __init__(self):
        self.name = 'Gold'

    def read(self, fn):
        d = defaultdict(lambda: [])
        with open(fn) as fin:
            for line_ind, line in enumerate(fin):
                # print line
                data = line.strip().split('\t')
                text, rel = data[:2]
                args = data[2:]
                confidence = 1

                curExtraction = Extraction(pred = rel.strip(),
                                           head_pred_index = None,
                                           sent = text.strip(),
                                           confidence = float(confidence),
                                           index = line_ind)
                for arg in args:
                    if "C: " in arg:
                        continue
                    curExtraction.addArg(arg.strip())

                d[text.strip()].append(curExtraction)
        self.oie = d


if __name__ == '__main__':
    g = GoldReader()
    g.read('../oie_corpus/all.oie')
    d = g.oie
    e = list(d.items())[0]
    print(e[1][0].bow())
    print(g.count())
@@ -1,45 +0,0 @@
class OieReader:

    def read(self, fn, includeNominal):
        ''' should set oie as a class member
        as a dictionary of extractions by sentence'''
        raise Exception("Don't run me")

    def count(self):
        ''' number of extractions '''
        return sum([len(extractions) for _, extractions in self.oie.items()])

    def split_to_corpus(self, corpus_fn, out_fn):
        """
        Given a corpus file name, containing a list of sentences,
        print only the extractions pertaining to it to out_fn in a tab separated format:
        sent, prob, pred, arg1, arg2, ...
        """
        raw_sents = [line.strip() for line in open(corpus_fn)]
        with open(out_fn, 'w') as fout:
            for line in self.get_tabbed().split('\n'):
                data = line.split('\t')
                sent = data[0]
                if sent in raw_sents:
                    fout.write(line + '\n')

    def output_tabbed(self, out_fn):
        """
        Write a tabbed representation of this corpus.
        """
        with open(out_fn, 'w') as fout:
            fout.write(self.get_tabbed())

    def get_tabbed(self):
        """
        Get a tabbed format representation of this corpus (assumes that input was
        already read).
        """
        return "\n".join(['\t'.join(map(str,
                                        [ex.sent,
                                         ex.confidence,
                                         ex.pred,
                                         '\t'.join(ex.args)]))
                          for (sent, exs) in self.oie.items()
                          for ex in exs])
@@ -1,22 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction

class OllieReader(OieReader):

    def __init__(self):
        self.name = 'OLLIE'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            fin.readline()  # remove header
            for line in fin:
                data = line.strip().split('\t')
                confidence, arg1, rel, arg2, enabler, attribution, text = data[:7]
                curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
@@ -1,38 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction

class OpenieFiveReader(OieReader):

    def __init__(self):
        self.name = 'OpenIE-5'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                confidence = data[0]

                if not all(data[2:5]):
                    continue
                arg1, rel = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:4]]
                # args = data[4].strip().split(');')
                # print arg2s
                args = [s[s.index('(') + 1:s.index(',List(')] for s in data[4].strip().split(');')]
                # if arg1 == "the younger La Flesche":
                #     print len(args)
                text = data[5]
                if data[1]:
                    # print arg1, rel
                    s = data[1]
                    if not (arg1 + ' ' + rel).startswith(s[s.index('(') + 1:s.index(',List(')]):
                        # print "##########Not adding context"
                        arg1 = s[s.index('(') + 1:s.index(',List(')] + ' ' + arg1
                    # print arg1 + rel, ",,,,, ", s[s.index('(') + 1:s.index(',List(')]
                # curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence))
                curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
                curExtraction.addArg(arg1)
                for arg in args:
                    curExtraction.addArg(arg)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
@@ -1,59 +0,0 @@
""" Usage:
    <file-name> --in=INPUT_FILE --out=OUTPUT_FILE [--debug]

Convert to tabbed format
"""
# External imports
import logging
from pprint import pprint
from pprint import pformat
from docopt import docopt

# Local imports
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
import ipdb

#=-----

class OpenieFourReader(OieReader):

    def __init__(self):
        self.name = 'OpenIE-4'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                confidence = data[0]
                if not all(data[2:5]):
                    logging.debug("Skipped line: {}".format(line))
                    continue
                arg1, rel, arg2 = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:5]]
                text = data[5]
                curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d


if __name__ == "__main__":
    # Parse command line arguments
    args = docopt(__doc__)
    inp_fn = args["--in"]
    out_fn = args["--out"]
    debug = args["--debug"]
    if debug:
        logging.basicConfig(level = logging.DEBUG)
    else:
        logging.basicConfig(level = logging.INFO)

    oie = OpenieFourReader()
    oie.read(inp_fn)
    oie.output_tabbed(out_fn)

    logging.info("DONE")
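The repeated slicing s[s.index('(') + 1:s.index(',List(')] in the two OpenIE readers above strips the surface string out of fields shaped like SimpleArgument(text,List(...)). On a made-up field it behaves like this:

# Made-up OpenIE-4/5 style field, shaped as the slicing above expects.
s = 'SimpleArgument(Obama,List([0, 5)))'
assert s[s.index('(') + 1:s.index(',List(')] == 'Obama'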
@@ -1,44 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction


class PropSReader(OieReader):

    def __init__(self):
        self.name = 'PropS'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                confidence, text, rel = data[:3]
                curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence), head_pred_index=-1)

                for arg in data[4::2]:
                    curExtraction.addArg(arg)

                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
        # self.normalizeConfidence()

    def normalizeConfidence(self):
        ''' Normalize confidence to resemble probabilities '''
        EPSILON = 1e-3

        self.confidences = [extraction.confidence for sent in self.oie for extraction in self.oie[sent]]
        maxConfidence = max(self.confidences)
        minConfidence = min(self.confidences)

        denom = maxConfidence - minConfidence + (2 * EPSILON)

        for sent, extractions in self.oie.items():
            for extraction in extractions:
                extraction.confidence = ((extraction.confidence - minConfidence) + EPSILON) / denom
@@ -1,29 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction

class ReVerbReader(OieReader):

    def __init__(self):
        self.inputSents = [sent.strip() for sent in open(ReVerbReader.RAW_SENTS_FILE).readlines()]
        self.name = 'ReVerb'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                arg1, rel, arg2 = data[2:5]
                confidence = data[11]
                text = self.inputSents[int(data[1]) - 1]

                # head_pred_index = -1 added to match Extraction's required signature,
                # as in the other readers.
                curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d

    # ReVerb requires a separate file from which to get the input sentences
    # Relative to repo root folder
    RAW_SENTS_FILE = './raw_sentences/all.txt'
@@ -1,37 +0,0 @@
""" Usage:
    split_corpus --corpus=CORPUS_FN --reader=READER --in=INPUT_FN --out=OUTPUT_FN

Split OIE extractions according to raw sentences.
This is used in order to split a large file into train, dev and test.

READER - selects which OIE reader to use (see the dictionary below for possible entries)
"""
from clausieReader import ClausieReader
from ollieReader import OllieReader
from openieFourReader import OpenieFourReader
from propsReader import PropSReader
from reVerbReader import ReVerbReader
from stanfordReader import StanfordReader
from docopt import docopt
import logging
logging.basicConfig(level = logging.INFO)

available_readers = {
    "clausie": ClausieReader,
    "ollie": OllieReader,
    "openie4": OpenieFourReader,
    "props": PropSReader,
    "reverb": ReVerbReader,
    "stanford": StanfordReader
}


if __name__ == "__main__":
    args = docopt(__doc__)
    inp = args["--in"]
    out = args["--out"]
    corpus = args["--corpus"]
    reader = available_readers[args["--reader"]]()
    reader.read(inp)
    reader.split_to_corpus(corpus,
                           out)
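A hypothetical invocation of the splitter above (all file paths are placeholders):

python split_corpus.py --corpus=raw_sentences/dev.txt --reader=clausie --in=system_output/clausie.txt --out=splits/clausie_dev.txt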
@@ -1,22 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction

class StanfordReader(OieReader):

    def __init__(self):
        self.name = 'Stanford'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                arg1, rel, arg2 = data[2:5]
                confidence = data[11]
                text = data[12]

                curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
@@ -1,56 +0,0 @@
""" Usage:
    tabReader --in=INPUT_FILE

Read a tab-formatted file.
Each line consists of:
sent, prob, pred, arg1, arg2, ...

"""

from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
from docopt import docopt
import logging
import ipdb

logging.basicConfig(level = logging.DEBUG)

class TabReader(OieReader):

    def __init__(self):
        self.name = 'TabReader'

    def read(self, fn):
        """
        Read a tabbed format line
        Each line consists of:
        sent, prob, pred, arg1, arg2, ...
        """
        d = {}
        ex_index = 0
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                text, confidence, rel = data[:3]
                curExtraction = Extraction(pred = rel,
                                           head_pred_index = None,
                                           sent = text,
                                           confidence = float(confidence),
                                           question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
                                           index = ex_index)
                ex_index += 1

                for arg in data[3:]:
                    curExtraction.addArg(arg)

                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d


if __name__ == "__main__":
    args = docopt(__doc__)
    input_fn = args["--in"]
    tr = TabReader()
    tr.read(input_fn)
@@ -1,59 +0,0 @@
""" Usage:
    tabReader --in=INPUT_FILE

Read a tab-formatted file.
Each line consists of:
sent, prob, pred, arg1, arg2, ...

"""

from carb.oieReader import OieReader
from carb.extraction import Extraction
from docopt import docopt
import logging
import ipdb

logging.basicConfig(level = logging.DEBUG)

class TabReader(OieReader):

    def __init__(self):
        self.name = 'TabReader'

    def read(self, fn):
        """
        Read a tabbed format line
        Each line consists of:
        sent, prob, pred, arg1, arg2, ...
        """
        d = {}
        ex_index = 0
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                try:
                    text, confidence, rel = data[:3]
                except ValueError:
                    continue
                curExtraction = Extraction(pred = rel,
                                           head_pred_index = None,
                                           sent = text,
                                           confidence = float(confidence),
                                           question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
                                           index = ex_index)
                ex_index += 1

                for arg in data[3:]:
                    curExtraction.addArg(arg)

                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d


if __name__ == "__main__":
    args = docopt(__doc__)
    input_fn = args["--in"]
    tr = TabReader()
    tr.read(input_fn)
@@ -1,162 +0,0 @@
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from utils import utils
from transformers import BertTokenizer
from utils.bio import pred_tag2idx, arg_tag2idx


def load_data(data_path,
              batch_size,
              max_len=64,
              train=True,
              tokenizer_config='bert-base-cased'):
    if train:
        return DataLoader(
            dataset=OieDataset(
                data_path,
                max_len,
                tokenizer_config),
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=True)
    else:
        return DataLoader(
            dataset=OieEvalDataset(
                data_path,
                max_len,
                tokenizer_config),
            batch_size=batch_size,
            num_workers=4,
            pin_memory=True)


class OieDataset(Dataset):
    def __init__(self, data_path, max_len=64, tokenizer_config='bert-base-cased'):
        data = utils.load_pkl(data_path)
        self.tokens = data['tokens']
        self.single_pred_labels = data['single_pred_labels']
        self.single_arg_labels = data['single_arg_labels']
        self.all_pred_labels = data['all_pred_labels']

        self.max_len = max_len
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_config)
        self.vocab = self.tokenizer.vocab

        self.pad_idx = self.vocab['[PAD]']
        self.cls_idx = self.vocab['[CLS]']
        self.sep_idx = self.vocab['[SEP]']
        self.mask_idx = self.vocab['[MASK]']

    def add_pad(self, token_ids):
        diff = self.max_len - len(token_ids)
        if diff > 0:
            token_ids += [self.pad_idx] * diff
        else:
            token_ids = token_ids[:self.max_len - 1] + [self.sep_idx]
        return token_ids

    def add_special_token(self, token_ids):
        return [self.cls_idx] + token_ids + [self.sep_idx]

    def idx2mask(self, token_ids):
        return [token_id != self.pad_idx for token_id in token_ids]

    def add_pad_to_labels(self, pred_label, arg_label, all_pred_label):
        pred_outside = np.array([pred_tag2idx['O']])
        arg_outside = np.array([arg_tag2idx['O']])

        pred_label = np.concatenate([pred_outside, pred_label, pred_outside])
        arg_label = np.concatenate([arg_outside, arg_label, arg_outside])
        all_pred_label = np.concatenate([pred_outside, all_pred_label, pred_outside])

        diff = self.max_len - pred_label.shape[0]
        if diff > 0:
            pred_pad = np.array([pred_tag2idx['O']] * diff)
            arg_pad = np.array([arg_tag2idx['O']] * diff)
            pred_label = np.concatenate([pred_label, pred_pad])
            arg_label = np.concatenate([arg_label, arg_pad])
            all_pred_label = np.concatenate([all_pred_label, pred_pad])
        elif diff == 0:
            pass
        else:
            pred_label = np.concatenate([pred_label[:-1], pred_outside])
            arg_label = np.concatenate([arg_label[:-1], arg_outside])
            all_pred_label = np.concatenate([all_pred_label[:-1], pred_outside])
        return [pred_label, arg_label, all_pred_label]

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokens[idx])
        token_ids_padded = self.add_pad(self.add_special_token(token_ids))
        att_mask = self.idx2mask(token_ids_padded)
        labels = self.add_pad_to_labels(
            self.single_pred_labels[idx],
            self.single_arg_labels[idx],
            self.all_pred_labels[idx])
        single_pred_label, single_arg_label, all_pred_label = labels

        assert len(token_ids_padded) == self.max_len
        assert len(att_mask) == self.max_len
        assert single_pred_label.shape[0] == self.max_len
        assert single_arg_label.shape[0] == self.max_len
        assert all_pred_label.shape[0] == self.max_len

        batch = [
            torch.tensor(token_ids_padded),
            torch.tensor(att_mask),
            torch.tensor(single_pred_label),
            torch.tensor(single_arg_label),
            torch.tensor(all_pred_label)
        ]
        return batch


class OieEvalDataset(Dataset):
    def __init__(self, data_path, max_len, tokenizer_config='bert-base-cased'):
        self.sentences = utils.load_pkl(data_path)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_config)
        self.vocab = self.tokenizer.vocab
        self.max_len = max_len

        self.pad_idx = self.vocab['[PAD]']
        self.cls_idx = self.vocab['[CLS]']
        self.sep_idx = self.vocab['[SEP]']
        self.mask_idx = self.vocab['[MASK]']

    def add_pad(self, token_ids):
        diff = self.max_len - len(token_ids)
        if diff > 0:
            token_ids += [self.pad_idx] * diff
        else:
            token_ids = token_ids[:self.max_len - 1] + [self.sep_idx]
        return token_ids

    def idx2mask(self, token_ids):
        return [token_id != self.pad_idx for token_id in token_ids]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        token_ids = self.add_pad(self.tokenizer.encode(self.sentences[idx]))
        att_mask = self.idx2mask(token_ids)
        token_strs = self.tokenizer.convert_ids_to_tokens(token_ids)
        sentence = self.sentences[idx]

        assert len(token_ids) == self.max_len
        assert len(att_mask) == self.max_len
        assert len(token_strs) == self.max_len
        batch = [
            torch.tensor(token_ids),
            torch.tensor(att_mask),
            token_strs,
            sentence
        ]
        return batch
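A sketch of how load_data above would typically be driven (paths and hyperparameters are placeholders; assumes a pickle produced by the project's preprocessing):

# Hypothetical usage of load_data (placeholder path and batch size).
train_loader = load_data('datasets/train.pkl', batch_size=32, max_len=64, train=True)
for token_ids, att_mask, pred_label, arg_label, all_pred_label in train_loader:
    pass  # feed the five tensors of each batch to the model here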
@@ -1,97 +0,0 @@
name: multi2oie
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _pytorch_select=0.2=gpu_0
  - attrs=19.3.0=py_0
  - blas=1.0=mkl
  - brotlipy=0.7.0=py37h8f50634_1000
  - ca-certificates=2020.4.5.1=hecc5488_0
  - catalogue=1.0.0=py_0
  - certifi=2020.4.5.1=py37hc8dfbb8_0
  - cffi=1.14.0=py37he30daa8_1
  - chardet=3.0.4=py37hc8dfbb8_1006
  - cryptography=2.9.2=py37hb09aad4_0
  - cudatoolkit=10.1.243=h6bb024c_0
  - cudnn=7.6.5=cuda10.1_0
  - cymem=2.0.3=py37h3340039_1
  - cython-blis=0.4.1=py37h8f50634_1
  - idna=2.9=py_1
  - importlib-metadata=1.6.0=py37hc8dfbb8_0
  - importlib_metadata=1.6.0=0
  - intel-openmp=2020.1=217
  - joblib=0.15.1=py_0
  - jsonschema=3.2.0=py37hc8dfbb8_1
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.3=he6710b0_1
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - mkl=2020.1=217
  - mkl-service=2.3.0=py37he904b0f_0
  - mkl_fft=1.0.15=py37ha843d7b_0
  - mkl_random=1.1.1=py37h0573a6f_0
  - murmurhash=1.0.0=py37h3340039_0
  - ncurses=6.2=he6710b0_1
  - ninja=1.9.0=py37hfd86e86_0
  - numpy=1.18.1=py37h4f9e942_0
  - numpy-base=1.18.1=py37hde5b4d6_1
  - openssl=1.1.1g=h516909a_0
  - pandas=1.0.3=py37h0573a6f_0
  - pip=20.0.2=py37_3
  - plac=0.9.6=py37_0
  - preshed=3.0.2=py37h3340039_2
  - pycparser=2.20=py_0
  - pyopenssl=19.1.0=py_1
  - pyrsistent=0.16.0=py37h8f50634_0
  - pysocks=1.7.1=py37hc8dfbb8_1
  - python=3.7.7=hcff3b4d_5
  - python-dateutil=2.8.1=py_0
  - python_abi=3.7=1_cp37m
  - pytorch=1.4.0=cuda101py37h02f0884_0
  - pytz=2020.1=py_0
  - readline=8.0=h7b6447c_0
  - requests=2.23.0=pyh8c360ce_2
  - scikit-learn=0.22.1=py37hd81dba3_0
  - scipy=1.4.1=py37h0b6359f_0
  - setuptools=46.4.0=py37_0
  - six=1.14.0=py37_0
  - spacy=2.2.4=py37h99015e2_1
  - sqlite=3.31.1=h62c20be_1
  - srsly=1.0.2=py37h3340039_0
  - thinc=7.4.0=py37h99015e2_2
  - tk=8.6.8=hbc83047_0
  - tqdm=4.46.0=pyh9f0ad1d_0
  - urllib3=1.25.9=py_0
  - wasabi=0.6.0=py_0
  - wheel=0.34.2=py37_0
  - xz=5.2.5=h7b6447c_0
  - zipp=3.1.0=py_0
  - zlib=1.2.11=h7b6447c_3
  - pip:
    - backcall==0.1.0
    - click==7.1.2
    - decorator==4.4.2
    - docopt==0.6.2
    - filelock==3.0.12
    - ipdb==0.13.2
    - ipython==7.14.0
    - ipython-genutils==0.2.0
    - jedi==0.17.0
    - nltk==3.5
    - parso==0.7.0
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - prompt-toolkit==3.0.5
    - ptyprocess==0.6.0
    - pygments==2.6.1
    - regex==2020.5.14
    - sacremoses==0.0.43
    - sentencepiece==0.1.91
    - tokenizers==0.7.0
    - traitlets==4.3.3
    - transformers==2.10.0
    - wcwidth==0.1.9
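To reproduce this environment, the usual conda commands apply (the environment name multi2oie comes from the file above):

conda env create -f environment.yml
conda activate multi2oie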
@@ -1,21 +0,0 @@
import nltk
from operator import itemgetter

class Argument:
    def __init__(self, arg):
        self.words = [x for x in arg[0].strip().split(' ') if x]
        # Materialize the map so the POS tags can be reused (Python 3 maps are lazy).
        self.posTags = list(map(itemgetter(1), nltk.pos_tag(self.words)))
        self.indices = arg[1]
        self.feats = {}

    def __str__(self):
        return "({})".format('\t'.join(map(str,
                                           [escape_special_chars(' '.join(self.words)),
                                            str(self.indices)])))

COREF = 'coref'

## Helper functions
def escape_special_chars(s):
    return s.replace('\t', '\\t')
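The benchmark file below rescales sklearn's recall by a recallMultiplier so that gold extractions the system never matched still count against recall. A worked example of that correction with invented counts:

# Hypothetical numbers illustrating the recallMultiplier correction used in prCurve below.
correctTotal, unmatchedCount = 100, 40  # 100 gold extractions, 40 never matched
recallMultiplier = (correctTotal - unmatchedCount) / float(correctTotal)  # 0.6
r_prime = 0.5  # sklearn recall at some threshold: 30 of the 60 matchable were found
true_recall = r_prime * recallMultiplier  # 0.3 == 30 / 100 gold extractions
assert abs(true_recall - 0.3) < 1e-9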
@@ -1,219 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 25 19:24:30 2018

@author: longzhan
"""

import string
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
import re
import logging
import sys
logging.basicConfig(level=logging.INFO)

from evaluate.generalReader import GeneralReader
from evaluate.goldReader import GoldReader
from evaluate.gold_relabel import Relabel_GoldReader
from evaluate.matcher import Matcher
from operator import itemgetter


class Benchmark:
    ''' Compare the gold OIE dataset against a predicted equivalent '''
    def __init__(self, gold_fn):
        ''' Load gold Open IE, this will serve to compare against using the compare function '''
        if 'Re' in gold_fn:
            gr = Relabel_GoldReader()
        else:
            gr = GoldReader()
        gr.read(gold_fn)
        self.gold = gr.oie

    def compare(self, predicted, matchingFunc, output_fn, error_file = None):
        ''' Compare gold against predicted using a specified matching function.
        Outputs PR curve to output_fn '''

        y_true = []
        y_scores = []
        errors = []

        correctTotal = 0
        unmatchedCount = 0

        non_sent = 0
        non_match = 0

        predicted = Benchmark.normalizeDict(predicted)
        gold = Benchmark.normalizeDict(self.gold)

        for sent, goldExtractions in gold.items():
            if sent not in predicted:
                # The extractor didn't find any extractions for this sentence
                for goldEx in goldExtractions:
                    non_sent += 1
                unmatchedCount += len(goldExtractions)
                correctTotal += len(goldExtractions)
                continue

            predictedExtractions = predicted[sent]

            for goldEx in goldExtractions:
                correctTotal += 1
                found = False

                for predictedEx in predictedExtractions:
                    if output_fn in predictedEx.matched:
                        # This predicted extraction was already matched against a gold extraction
                        # Don't allow to match it again
                        continue

                    if matchingFunc(goldEx,
                                    predictedEx,
                                    ignoreStopwords=True,
                                    ignoreCase=True):

                        y_true.append(1)
                        y_scores.append(predictedEx.confidence)
                        predictedEx.matched.append(output_fn)
                        found = True
                        break

                if not found:
                    non_match += 1
                    errors.append(goldEx.index)
                    unmatchedCount += 1

            for predictedEx in [x for x in predictedExtractions if (output_fn not in x.matched)]:
                # Add false positives
                y_true.append(0)
                y_scores.append(predictedEx.confidence)

        print("non_sent: ", non_sent)
        print("non_match: ", non_match)
        print("correctTotal: ", correctTotal)
        print("unmatchedCount: ", unmatchedCount)

        # Recall (r') computed on y_true, y_scores is |matched above threshold| / |gold matched by extractor|.
        # To get true recall we do:
        # r' * (|gold matched by extractor| / |gold total|) = |matched above threshold| / |gold total|
        (p, r), optimal = Benchmark.prCurve(np.array(y_true), np.array(y_scores),
                                            recallMultiplier = ((correctTotal - unmatchedCount) / float(correctTotal)))
        cur_auc = auc(r, p)
        print("AUC: {}\n Optimal (precision, recall, F1, threshold): {}".format(cur_auc, optimal))

        # Write error log to file
        if error_file:
            logging.info("Writing {} error indices to {}".format(len(errors),
                                                                 error_file))
            with open(error_file, 'w') as fout:
                fout.write('\n'.join([str(error)
                                      for error
                                      in errors]) + '\n')

        # write PR to file
        with open(output_fn, 'w') as fout:
            fout.write('{0}\t{1}\n'.format("Precision", "Recall"))
            for cur_p, cur_r in sorted(zip(p, r), key = lambda cur: cur[1]):
                fout.write('{0}\t{1}\n'.format(cur_p, cur_r))

        return optimal[:-1], cur_auc

    @staticmethod
    def prCurve(y_true, y_scores, recallMultiplier):
        # Recall multiplier - accounts for the percentage of examples unreached
        # Returns (precision [list], recall [list]), (optimal precision, recall, F1, threshold)
        y_scores = [score \
                    if not (np.isnan(score) or (not np.isfinite(score))) \
                    else 0
                    for score in y_scores]

        precision_ls, recall_ls, thresholds = precision_recall_curve(y_true, y_scores)
        recall_ls = recall_ls * recallMultiplier
        optimal = max([(precision, recall, f_beta(precision, recall, beta = 1), threshold)
                       for ((precision, recall), threshold)
                       in zip(zip(precision_ls[:-1], recall_ls[:-1]),
                              thresholds)],
                      key=itemgetter(2))  # Sort by f1 score

        return ((precision_ls, recall_ls),
                optimal)

    # Helper functions:
    @staticmethod
    def normalizeDict(d):
        return dict([(Benchmark.normalizeKey(k), v) for k, v in d.items()])

    @staticmethod
    def normalizeKey(k):
        return Benchmark.removePunct(Benchmark.PTB_unescape(k.replace(' ', '')))

    @staticmethod
    def PTB_escape(s):
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(u, e)
        return s

    @staticmethod
    def PTB_unescape(s):
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(e, u)
        return s

    @staticmethod
    def removePunct(s):
        return Benchmark.regex.sub('', s)

    # CONSTANTS
    regex = re.compile('[%s]' % re.escape(string.punctuation))

    # Penn treebank bracket escapes
    # Taken from: https://github.com/nlplab/brat/blob/master/server/src/gtbtokenize.py
    PTB_ESCAPES = [('(', '-LRB-'),
                   (')', '-RRB-'),
                   ('[', '-LSB-'),
                   (']', '-RSB-'),
                   ('{', '-LCB-'),
                   ('}', '-RCB-'),]


def f_beta(precision, recall, beta = 1):
    """
    Get F_beta score from precision and recall.
    """
    beta = float(beta)  # Make sure that results are in float
    return (1 + pow(beta, 2)) * (precision * recall) / ((pow(beta, 2) * precision) + recall)


f1 = lambda precision, recall: f_beta(precision, recall, beta = 1)

#gold_flag = sys.argv[1] # to choose whether to use OIE2016 or Re-OIE2016
#in_path = sys.argv[2] # input file
#out_path = sys.argv[3] # output file

if __name__ == '__main__':
|
|
||||||
|
|
||||||
gold = gold_flag
|
|
||||||
matchingFunc = Matcher.lexicalMatch
|
|
||||||
error_fn = "error.txt"
|
|
||||||
|
|
||||||
if gold == "old":
|
|
||||||
gold_fn = "OIE2016.txt"
|
|
||||||
else:
|
|
||||||
gold_fn = "Re-OIE2016.json"
|
|
||||||
|
|
||||||
b = Benchmark(gold_fn)
|
|
||||||
s_fn = in_path
|
|
||||||
p = GeneralReader()
|
|
||||||
other_p = GeneralReader()
|
|
||||||
other_p.read(s_fn)
|
|
||||||
b.compare(predicted = other_p.oie,
|
|
||||||
matchingFunc = matchingFunc,
|
|
||||||
output_fn = out_path,
|
|
||||||
error_file = error_fn)
|
|
|
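A quick numeric sketch of the recall correction used in `compare`/`prCurve` above; all numbers are made up for illustration:

```python
# Suppose the gold set has 100 extractions (correctTotal) and 40 of them were
# never matched (unmatchedCount), so the extractor "reached" 60 of them.
correctTotal, unmatchedCount = 100, 40
recallMultiplier = (correctTotal - unmatchedCount) / float(correctTotal)  # 0.6

# If precision_recall_curve reports r' = 0.5 at some threshold on the reached
# subset, the true recall against the full gold set is r = r' * 0.6 = 0.3.
r_prime = 0.5
r = r_prime * recallMultiplier
assert abs(r - 0.3) < 1e-9
```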
@ -1,444 +0,0 @@
from evaluate.argument import Argument
from operator import itemgetter
from collections import defaultdict
import nltk
import itertools
import logging
import numpy as np


class Extraction:
    """
    Stores a sentence, a single predicate, and the corresponding arguments.
    """
    def __init__(self, pred, head_pred_index, sent, confidence, question_dist='', index=-1):
        self.pred = pred
        self.head_pred_index = head_pred_index
        self.sent = sent
        self.args = []
        self.confidence = confidence
        self.matched = []
        self.questions = {}
        self.indsForQuestions = defaultdict(lambda: set())
        self.is_mwp = False
        self.question_dist = question_dist
        self.index = index

    def distArgFromPred(self, arg):
        assert(len(self.pred) == 2)
        dists = []
        for x in self.pred[1]:
            for y in arg.indices:
                dists.append(abs(x - y))

        return min(dists)

    def argsByDistFromPred(self, question):
        return sorted(self.questions[question], key=lambda arg: self.distArgFromPred(arg))

    def addArg(self, arg, question=None):
        self.args.append(arg)
        if question:
            self.questions[question] = self.questions.get(question, []) + [Argument(arg)]

    def noPronounArgs(self):
        """
        Returns True iff none of this extraction's arguments is a pronoun.
        """
        for (a, _) in self.args:
            tokenized_arg = nltk.word_tokenize(a)
            if len(tokenized_arg) == 1:
                _, pos_tag = nltk.pos_tag(tokenized_arg)[0]
                if 'PRP' in pos_tag:
                    return False
        return True

    def isContiguous(self):
        return all([indices for (_, indices) in self.args])

    def toBinary(self):
        ''' Try to represent this extraction's arguments as binary.
        If this fails, return an empty list. '''

        ret = [self.elementToStr(self.pred)]

        if len(self.args) == 2:
            # we're in luck
            return ret + [self.elementToStr(arg) for arg in self.args]

        if not self.isContiguous():
            # give up on non-contiguous arguments (as we need indexes)
            return []

        # otherwise, try to merge based on indices
        # TODO: you can explore other methods for doing this
        binarized = self.binarizeByIndex()

        if binarized:
            return ret + binarized

        return []

    def elementToStr(self, elem, print_indices=True):
        ''' Formats an extraction element (pred or arg) as a raw string,
        removing indices and trailing spaces. '''
        if print_indices:
            return str(elem)
        if isinstance(elem, str):
            return elem
        if isinstance(elem, tuple):
            ret = elem[0].rstrip().lstrip()
        else:
            ret = ' '.join(elem.words)
        assert ret, "empty element? {0}".format(elem)
        return ret

    def binarizeByIndex(self):
        extraction = [self.pred] + self.args
        markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
        sortedExtraction = sorted(markPred, key=lambda x: x[1][0])
        s = ' '.join(['{1} {0} {1}'.format(self.elementToStr(elem), SEP) if elem[2] else self.elementToStr(elem) for elem in sortedExtraction])
        binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]

        if len(binArgs) == 2:
            return binArgs

        # failure
        return []

    def bow(self):
        return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])

    def getSortedArgs(self):
        """
        Sort the list of arguments.
        If a question distribution is provided - use it;
        otherwise, default to the order of appearance in the sentence.
        """
        if self.question_dist:
            # There's a question distribution - use it
            return self.sort_args_by_distribution()
        ls = []
        for q, args in self.questions.items():
            if len(args) != 1:
                logging.debug("Not one argument: {}".format(args))
                continue
            arg = args[0]
            indices = list(self.indsForQuestions[q].union(arg.indices))
            if not indices:
                logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
                indices = [0]
            ls.append(((arg, q), indices))
        return [a for a, _ in sorted(ls, key=lambda x: min(x[1]))]

    def question_prob_for_loc(self, question, loc):
        """
        Returns the probability of the given question leading to the argument
        appearing in the given location in the output slot.
        """
        gen_question = generalize_question(question)
        q_dist = self.question_dist[gen_question]
        logging.debug("distribution of {}: {}".format(gen_question,
                                                      q_dist))

        return float(q_dist.get(loc, 0)) / \
            sum(q_dist.values())

    def sort_args_by_distribution(self):
        """
        Use this instance's question distribution (this func assumes it exists)
        in determining the positioning of the arguments.
        Greedy algorithm:
        0. Decide which argument will serve as the ``subject'' (first slot) of this extraction
           0.1 Based on the most probable one for this spot
               (special care is given to select the highly-influential subject position)
        1. For all other arguments, sort arguments by the prevalence of their questions
        2. For each argument:
           2.1 Assign to it the most probable slot still available
           2.2 If none such exists (fallback) - default to putting it in the last location
        """
        INF_LOC = 100  # Used as an impractical last argument

        # Store arguments by slot
        ret = {INF_LOC: []}
        logging.debug("sorting: {}".format(self.questions))

        # Find the most suitable argument for the subject location
        logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
                                                      for (q, _) in self.questions.items()]))

        subj_question, subj_args = max(self.questions.items(),
                                       key=lambda x: self.question_prob_for_loc(x[0], 0))

        ret[0] = [(subj_args[0], subj_question)]

        # Find the rest
        for (question, args) in sorted([(q, a)
                                        for (q, a) in self.questions.items() if (q not in [subj_question])],
                                       key=lambda x:
                                       sum(self.question_dist[generalize_question(x[0])].values()),
                                       reverse=True):
            gen_question = generalize_question(question)
            arg = args[0]
            assigned_flag = False
            for (loc, count) in sorted(self.question_dist[gen_question].items(),
                                       key=lambda x: x[1],
                                       reverse=True):
                if loc not in ret:
                    # Found an empty slot for this item
                    # Place it there and break out
                    ret[loc] = [(arg, question)]
                    assigned_flag = True
                    break

            if not assigned_flag:
                # Add this argument to the non-assigned (hopefully doesn't happen much)
                logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
                ret[INF_LOC].append((arg, question))

        logging.debug("Linearizing arg list: {}".format(ret))

        # Finished iterating - consolidate and return a list of arguments
        return [arg
                for (_, arg_ls) in sorted(ret.items(),
                                          key=lambda x: int(x[0]))
                for arg in arg_ls]

    def __str__(self):
        # QA-formatted representation; note that the plain tab-separated
        # __str__ defined further below overrides this one.
        pred_str = self.elementToStr(self.pred)
        return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
                                   self.compute_global_pred(pred_str,
                                                            self.questions.keys()),
                                   '\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
                                                                                                  question))
                                              for arg, question in self.getSortedArgs()]))

    def get_base_verb(self, surface_pred):
        """
        Given the surface pred, return the original annotated verb.
        """
        # Assumes that at this point the verb is always the last word
        # in the surface predicate
        return surface_pred.split(' ')[-1]

    def compute_global_pred(self, surface_pred, questions):
        """
        Given the surface pred and all instantiations of questions,
        make global coherence decisions regarding the final form of the predicate.
        This should hopefully take care of multi-word predicates and correct inflections.
        """
        split_surface = surface_pred.split(' ')

        if len(split_surface) > 1:
            # This predicate has a modal preceding the base verb
            verb = split_surface[-1]
            ret = split_surface[:-1]  # get all of the elements in the modal
        else:
            verb = split_surface[0]
            ret = []

        split_questions = [question.split(' ') for question in questions]

        preds = [normalize_element(q[QUESTION_TRG_INDEX]) for q in split_questions]
        if len(set(preds)) > 1:
            # This predicate appears in multiple ways, let's stick to the base form
            ret.append(verb)

        if len(set(preds)) == 1:
            # Change the predicate to the inflected form
            # if there's exactly one way in which the predicate is conveyed
            ret.append(preds[0])

        pps = [normalize_element(q[QUESTION_PP_INDEX]) for q in split_questions]

        obj2s = [normalize_element(q[QUESTION_OBJ2_INDEX]) for q in split_questions]

        if len(set(pps)) == 1:
            # If all questions for the predicate include the same pp attachment -
            # assume it's a multi-word predicate
            self.is_mwp = True  # Signal to arguments that they shouldn't take the preposition
            ret.append(pps[0])

        # Concat all elements in the predicate and return
        return " ".join(ret).strip()

    def augment_arg_with_question(self, arg, question):
        """
        Decide which elements from the question to incorporate in the given
        corresponding argument.
        """
        # Parse question
        wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
                                                question.split(' ')[:-1])  # Last split is the question mark

        # Place the preposition in the argument.
        # This is safer when dealing with n-ary arguments, as it directly attaches to the
        # appropriate argument
        if (not self.is_mwp) and pp and (not obj2):
            if not arg.startswith("{} ".format(pp)):
                # Avoid repeating the preposition in cases where both question and answer contain it
                return " ".join([pp,
                                 arg])

        # Normal cases
        return arg

    def clusterScore(self, cluster):
        """
        Calculate the cluster density score as the mean over each slot's maximum
        distance from the centroid. A lower score represents a denser cluster.
        """
        logging.debug("*-*-*- Cluster: {}".format(cluster))

        # Find global centroid
        arr = np.array([x for ls in cluster for x in ls])
        centroid = np.sum(arr) / arr.shape[0]
        logging.debug("Centroid: {}".format(centroid))

        # Calculate mean over all maximum points
        return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])

    def resolveAmbiguity(self):
        """
        Heuristic to map the elements (arguments and predicate) of this extraction
        back to the indices of the sentence.
        """
        ## TODO: This removes arguments for which no consecutive span was found.
        ## Some of these are genuinely non-consecutive arguments,
        ## but others could be a bug in recognizing some punctuation marks

        elements = [self.pred] \
            + [(s, indices)
               for (s, indices)
               in self.args
               if indices]
        logging.debug("Resolving ambiguity in: {}".format(elements))

        # Collect all possible combinations of argument and predicate indices
        # (hopefully it's not too many)
        all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
        logging.debug("Number of combinations: {}".format(len(all_combinations)))

        # Choose the ones with the best clustering and unfold them
        resolved_elements = list(zip(map(itemgetter(0), elements),
                                     min(all_combinations,
                                         key=lambda cluster: self.clusterScore(cluster))))
        logging.debug("Resolved elements = {}".format(resolved_elements))

        self.pred = resolved_elements[0]
        self.args = resolved_elements[1:]

    def conll(self, external_feats=[]):
        """
        Return a CoNLL string representation of this extraction.
        """
        return '\n'.join(["\t".join(map(str,
                                        [i, w] +
                                        list(self.pred) +
                                        [self.head_pred_index] +
                                        external_feats +
                                        [self.get_label(i)]))
                          for (i, w)
                          in enumerate(self.sent.split(" "))]) + '\n'

    def get_label(self, index):
        """
        Given the index of a word in the sentence -- return the appropriate BIO CoNLL label.
        Assumes that ambiguity was already resolved.
        """
        # Get the element(s) in which this index appears
        ent = [(elem_ind, elem)
               for (elem_ind, elem)
               in enumerate(map(itemgetter(1),
                                [self.pred] + self.args))
               if index in elem]

        if not ent:
            # index doesn't appear in any element
            return "O"

        if len(ent) > 1:
            # The same word appears in two different answers
            # In this case we choose the first one as the label
            logging.warning("Index {} appears in more than one element: {}".
                            format(index,
                                   "\t".join(map(str,
                                                 [ent,
                                                  self.sent,
                                                  self.pred,
                                                  self.args]))))

        ## Some indices appear in more than one argument (ones where the above message appears)
        ## From empirical observation, these seem to mostly consist of different levels of granularity:
        ##     what had _ been taken _ _ _ ?    loan commitments topping $ 3 billion
        ##     how much had _ been taken _ _ _ ?    topping $ 3 billion
        ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
        ## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)

        elem_ind, elem = min(ent, key=lambda x: len(x[1]))

        # Distinguish between predicate and arguments
        prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)

        # Distinguish between Beginning and Inside labels
        suffix = "B" if index == elem[0] else "I"

        return "{}-{}".format(prefix, suffix)

    def __str__(self):
        return '{0}\t{1}'.format(self.elementToStr(self.pred,
                                                   print_indices=True),
                                 '\t'.join([self.elementToStr(arg)
                                            for arg
                                            in self.args]))


# Flatten a list of lists
flatten = lambda l: [item for sublist in l for item in sublist]


def normalize_element(elem):
    """
    Return a surface form of the given question element.
    The output should be able to properly precede a predicate (or be blank otherwise).
    """
    return elem.replace("_", " ") \
        if (elem != "_") \
        else ""


## Helper functions
def escape_special_chars(s):
    return s.replace('\t', '\\t')


def generalize_question(question):
    """
    Given a question in the context of the sentence and the predicate index within
    the question - return a generalized version which extracts only order-imposing features.
    """
    wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1]  # Last split is the question mark
    return ' '.join([wh, sbj, obj1])


## CONSTANTS
SEP = ';;;'
QUESTION_TRG_INDEX = 3  # index of the predicate within the question
QUESTION_PP_INDEX = 5
QUESTION_OBJ2_INDEX = 6
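A small illustration of the BIO labels produced by `get_label` above; the sentence and index assignments are hypothetical:

```python
# Sentence: "John gave Mary a book", after ambiguity resolution:
#   pred = ('gave', [1]); args = [('John', [0]), ('Mary a book', [2, 3, 4])]
# get_label(i) then yields, token by token:
#   John -> A0-B, gave -> P-B, Mary -> A1-B, a -> A1-I, book -> A1-I
```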
@ -1,43 +0,0 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 15 21:05:15 2018

@author: win 10
"""

from evaluate.oieReader import OieReader
from evaluate.extraction import Extraction


class GeneralReader(OieReader):

    def __init__(self):
        self.name = 'General'

    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                if len(data) >= 4:
                    text = data[0]
                    confidence = data[1]
                    rel = data[2]
                    arg1 = data[3]
                    arg_else = data[4:]

                    curExtraction = Extraction(pred=rel, head_pred_index=-1, sent=text, confidence=float(confidence))
                    curExtraction.addArg(arg1)
                    for arg in arg_else:
                        curExtraction.addArg(arg)
                    d[text] = d.get(text, []) + [curExtraction]
        self.oie = d


if __name__ == "__main__":
    fn = "../data/other_systems/openie4_test.txt"
    reader = GeneralReader()
    reader.read(fn)
    for key in reader.oie:
        print(key)
        print(reader.oie[key][0].pred)
        print(reader.oie[key][0].args)
        print(reader.oie[key][0].confidence)
@ -1,58 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 25 19:24:30 2018

@author: longzhan
"""

from evaluate.oieReader import OieReader
from evaluate.extraction import Extraction
from collections import defaultdict


class GoldReader(OieReader):

    # Path relative to repo root folder
    default_filename = './oie_corpus/all.oie'

    def __init__(self):
        self.name = 'Gold'

    def read(self, fn):
        d = defaultdict(lambda: [])
        multilingual = False
        for lang in ['spanish']:
            if lang in fn:
                multilingual = True
                encoding = lang
                break
        if multilingual and encoding == 'spanish':
            fin = open(fn, 'r', encoding='latin-1')
        else:
            fin = open(fn)
        for line_ind, line in enumerate(fin):
            data = line.strip().split('\t')
            text, rel = data[:2]
            args = data[2:]
            confidence = 1

            curExtraction = Extraction(pred=rel.strip(),
                                       head_pred_index=None,
                                       sent=text.strip(),
                                       confidence=float(confidence),
                                       index=line_ind)
            for arg in args:
                curExtraction.addArg(arg)

            d[text.strip()].append(curExtraction)
        fin.close()
        self.oie = d


if __name__ == '__main__':
    g = GoldReader()
    g.read('../test.oie')
    d = g.oie
    e = list(d.items())[0]
    print(e[1][0].bow())
    print(g.count())
@ -1,46 +0,0 @@
from evaluate.oieReader import OieReader
from evaluate.extraction import Extraction
from collections import defaultdict
import json


class Relabel_GoldReader(OieReader):

    # Path relative to repo root folder
    default_filename = './oie_corpus/all.oie'

    def __init__(self):
        self.name = 'Relabel_Gold'

    def read(self, fn):
        d = defaultdict(lambda: [])
        with open(fn) as fin:
            data = json.load(fin)
        for sentence in data:
            tuples = data[sentence]
            for t in tuples:
                if t["pred"].strip() == "<be>":
                    rel = "[is]"
                else:
                    rel = t["pred"].replace("<be> ", "")
                confidence = 1

                curExtraction = Extraction(pred=rel,
                                           head_pred_index=None,
                                           sent=sentence,
                                           confidence=float(confidence),
                                           index=None)
                # Add every non-empty argument slot in canonical order
                for key in ("arg0", "arg1", "arg2", "arg3", "temp", "loc"):
                    if t[key] != "":
                        curExtraction.addArg(t[key])

                d[sentence].append(curExtraction)
        self.oie = d
@ -1,109 +0,0 @@
import string
import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

nltk.download('wordnet', quiet=True)
nltk.download('stopwords', quiet=True)
lemmatizer = WordNetLemmatizer()


class Matcher:
    @staticmethod
    def bowMatch(ref, ex, ignoreStopwords, ignoreCase):
        """
        A binary function testing for exact lexical match (ignoring ordering) between reference
        and predicted extraction.
        """
        s1 = ref.bow()
        s2 = ex.bow()
        if ignoreCase:
            s1 = s1.lower()
            s2 = s2.lower()

        s1Words = s1.split(' ')
        s2Words = s2.split(' ')

        if ignoreStopwords:
            s1Words = Matcher.removeStopwords(s1Words)
            s2Words = Matcher.removeStopwords(s2Words)

        return sorted(s1Words) == sorted(s2Words)

    @staticmethod
    def predMatch(ref, ex, ignoreStopwords, ignoreCase):
        """
        Return whether gold and predicted extractions agree on the predicate.
        """
        s1 = ref.elementToStr(ref.pred)
        s2 = ex.elementToStr(ex.pred)
        if ignoreCase:
            s1 = s1.lower()
            s2 = s2.lower()

        s1Words = s1.split(' ')
        s2Words = s2.split(' ')

        if ignoreStopwords:
            s1Words = Matcher.removeStopwords(s1Words)
            s2Words = Matcher.removeStopwords(s2Words)

        return s1Words == s2Words

    @staticmethod
    def argMatch(ref, ex, ignoreStopwords, ignoreCase):
        """
        Return whether gold and predicted extractions agree on the arguments.
        """
        # Split into word lists; the original iterated over characters here
        sRef = ' '.join([ref.elementToStr(elem) for elem in ref.args]).split(' ')
        sEx = ' '.join([ex.elementToStr(elem) for elem in ex.args]).split(' ')

        count = 0

        for w1 in sRef:
            for w2 in sEx:
                if w1 == w2:
                    count += 1

        # We check how well the extraction lexically covers the reference.
        # Note: this is somewhat lenient as it doesn't penalize the extraction
        # for being too long
        coverage = float(count) / len(sRef)

        return coverage > Matcher.LEXICAL_THRESHOLD

    @staticmethod
    def bleuMatch(ref, ex, ignoreStopwords, ignoreCase):
        sRef = ref.bow()
        sEx = ex.bow()
        bleu = sentence_bleu(references=[sRef.split(' ')], hypothesis=sEx.split(' '))
        return bleu > Matcher.BLEU_THRESHOLD

    @staticmethod
    def lexicalMatch(ref, ex, ignoreStopwords, ignoreCase):
        sRef = ref.bow().split(' ')
        sEx = ex.bow().split(' ')
        count = 0

        for w1 in sRef:
            for w2 in sEx:
                if w1 == w2:
                    count += 1

        # We check how well the extraction lexically covers the reference.
        # Note: this is somewhat lenient as it doesn't penalize the extraction
        # for being too long
        coverage = float(count) / len(sRef)
        return coverage > Matcher.LEXICAL_THRESHOLD

    @staticmethod
    def removeStopwords(ls):
        return [w for w in ls if w.lower() not in Matcher.stopwords]

    # CONSTANTS
    BLEU_THRESHOLD = 0.4
    LEXICAL_THRESHOLD = 0.5  # Note: changing this value didn't change the ordering of the tested systems
    stopwords = stopwords.words('english') + list(string.punctuation)
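A minimal sketch of the coverage criterion behind `lexicalMatch`, run on plain token lists (the tokens are made up; real calls go through `Extraction.bow()`):

```python
ref_tokens = "Obama was born in Hawaii".split(' ')
ex_tokens = "Obama born in Hawaii".split(' ')

# Count reference tokens that also occur in the extraction (pairwise equality,
# exactly as in lexicalMatch above), then normalize by the reference length.
count = sum(1 for w1 in ref_tokens for w2 in ex_tokens if w1 == w2)
coverage = count / len(ref_tokens)  # 4 / 5 = 0.8 > LEXICAL_THRESHOLD (0.5)
```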
@ -1,45 +0,0 @@
class OieReader:

    def read(self, fn, includeNominal):
        ''' Should set self.oie as a class member:
        a dictionary of extractions keyed by sentence. '''
        raise Exception("Don't run me")

    def count(self):
        ''' Number of extractions. '''
        return sum([len(extractions) for _, extractions in self.oie.items()])

    def split_to_corpus(self, corpus_fn, out_fn):
        """
        Given a corpus file name containing a list of sentences,
        print only the extractions pertaining to it to out_fn in a tab-separated format:
        sent, prob, pred, arg1, arg2, ...
        """
        raw_sents = [line.strip() for line in open(corpus_fn)]
        with open(out_fn, 'w') as fout:
            for line in self.get_tabbed().split('\n'):
                data = line.split('\t')
                sent = data[0]
                if sent in raw_sents:
                    fout.write(line + '\n')

    def output_tabbed(self, out_fn):
        """
        Write a tabbed representation of this corpus.
        """
        with open(out_fn, 'w') as fout:
            fout.write(self.get_tabbed())

    def get_tabbed(self):
        """
        Get a tabbed format representation of this corpus (assumes that input was
        already read).
        """
        return "\n".join(['\t'.join(map(str,
                                        [ex.sent,
                                         ex.confidence,
                                         ex.pred,
                                         '\t'.join(ex.args)]))
                          for (sent, exs) in self.oie.items()
                          for ex in exs])
@ -1,6 +0,0 @@
# Evaluation
This code is mainly based on the code from the original [OIE2016 repository](https://github.com/gabrielStanovsky/oie-benchmark). The command to run the code is <br>
```python evaluate.py [new/old] input_file output_file``` <br>
"new" means that Re-OIE2016 is used as the benchmark, and "old" means that OIE2016 is used. The input_file holds the extractions of an open IE system, one extraction per line, in the following tab-separated format: <br>
```sentence	confidence_score	predicate	arg0	arg1	arg2 ...``` <br>
The script outputs the AUC and the best F1 score of the system. The output file is used to draw the PR curve; the script to draw it is [here](https://github.com/gabrielStanovsky/oie-benchmark/blob/master/pr_plot.py).
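For instance, a single (hypothetical) input line would look like this, with the fields separated by tabs: <br>
```The cat sat on the mat .	0.93	sat	The cat	on the mat```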
@ -1,72 +0,0 @@
import os
import torch
import numpy as np
import utils.bio as bio
from transformers import BertTokenizer
from tqdm import tqdm


def extract(args,
            model,
            loader,
            output_path):
    model.eval()
    os.makedirs(output_path, exist_ok=True)
    extraction_path = os.path.join(output_path, "extraction.txt")
    tokenizer = BertTokenizer.from_pretrained(args.bert_config)
    f = open(extraction_path, 'w')

    for step, batch in tqdm(enumerate(loader), desc='eval_steps', total=len(loader)):
        token_strs = [[word for word in sent] for sent in np.asarray(batch[-2]).T]
        sentences = batch[-1]
        token_ids, att_mask = map(lambda x: x.to(args.device), batch[:-2])

        with torch.no_grad():
            """
            We iterate B (batch_size) times because there can be more than one
            predicate per batch. When feeding the argument extractor, the number
            of predicates plays the role of the batch size.

            pred_logit: (B, L, 3)
            pred_hidden: (B, L, D)
            pred_tags: (B, P, L) ~ list of tensors, where P is the number of predicates in each batch
            """
            pred_logit, pred_hidden = model.extract_predicate(
                input_ids=token_ids, attention_mask=att_mask)
            pred_tags = torch.argmax(pred_logit, 2)
            pred_tags = bio.filter_pred_tags(pred_tags, token_strs)
            pred_tags = bio.get_single_predicate_idxs(pred_tags)
            pred_probs = torch.nn.Softmax(2)(pred_logit)

            # iterate B times (one iteration means extraction for one sentence)
            for cur_pred_tags, cur_pred_hidden, cur_att_mask, cur_token_id, cur_pred_probs, token_str, sentence \
                    in zip(pred_tags, pred_hidden, att_mask, token_ids, pred_probs, token_strs, sentences):

                # generate a temporary batch for this sentence and feed it to the argument module
                cur_pred_masks = bio.get_pred_mask(cur_pred_tags).to(args.device)
                n_predicates = cur_pred_masks.shape[0]
                if n_predicates == 0:
                    continue  # if there is no predicate, we cannot extract.
                cur_pred_hidden = torch.cat(n_predicates * [cur_pred_hidden.unsqueeze(0)])
                cur_token_id = torch.cat(n_predicates * [cur_token_id.unsqueeze(0)])
                cur_arg_logit = model.extract_argument(
                    input_ids=cur_token_id,
                    predicate_hidden=cur_pred_hidden,
                    predicate_mask=cur_pred_masks)

                # filter and get the argument tags with the highest probability
                cur_arg_tags = torch.argmax(cur_arg_logit, 2)
                cur_arg_probs = torch.nn.Softmax(2)(cur_arg_logit)
                cur_arg_tags = bio.filter_arg_tags(cur_arg_tags, cur_pred_tags, token_str)

                # get string tuples and write results
                cur_extractions, cur_extraction_idxs = bio.get_tuple(sentence, cur_pred_tags, cur_arg_tags, tokenizer)
                cur_confidences = bio.get_confidence_score(cur_pred_probs, cur_arg_probs, cur_extraction_idxs)
                for extraction, confidence in zip(cur_extractions, cur_confidences):
                    if args.binary:
                        f.write("\t".join([sentence] + [str(1.0)] + extraction[:3]) + '\n')
                    else:
                        f.write("\t".join([sentence] + [str(confidence)] + extraction) + '\n')
    f.close()
    print("\nExtraction Done.\n")
Binary file not shown. (image, 124 KiB)
Binary file not shown. (image, 49 KiB)
@ -1,112 +0,0 @@
import argparse
import os
import torch
from utils import utils
from utils.utils import SummaryManager
from dataset import load_data
from tqdm import tqdm
from train import train
from extract import extract
from test import do_eval  # local module, consistent with the sibling imports above


def main(args):
    utils.set_seed(args.seed)
    model = utils.get_models(
        bert_config=args.bert_config,
        pred_n_labels=args.pred_n_labels,
        arg_n_labels=args.arg_n_labels,
        n_arg_heads=args.n_arg_heads,
        n_arg_layers=args.n_arg_layers,
        lstm_dropout=args.lstm_dropout,
        mh_dropout=args.mh_dropout,
        pred_clf_dropout=args.pred_clf_dropout,
        arg_clf_dropout=args.arg_clf_dropout,
        pos_emb_dim=args.pos_emb_dim,
        use_lstm=args.use_lstm,
        device=args.device)

    trn_loader = load_data(
        data_path=args.trn_data_path,
        batch_size=args.batch_size,
        max_len=args.max_len,
        tokenizer_config=args.bert_config)
    dev_loaders = [
        load_data(
            data_path=cur_dev_path,
            batch_size=args.dev_batch_size,
            tokenizer_config=args.bert_config,
            train=False)
        for cur_dev_path in args.dev_data_path]
    args.total_steps = round(len(trn_loader) * args.epochs)
    args.warmup_steps = round(args.total_steps / 10)

    optimizer, scheduler = utils.get_train_modules(
        model=model,
        lr=args.learning_rate,
        total_steps=args.total_steps,
        warmup_steps=args.warmup_steps)
    model.zero_grad()
    summarizer = SummaryManager(args)
    print("\nTraining Starts\n")

    for epoch in tqdm(range(1, args.epochs + 1), desc='epochs'):
        trn_results = train(
            args, epoch, model, trn_loader, dev_loaders,
            summarizer, optimizer, scheduler)

        # extraction on the dev set
        dev_iter = zip(args.dev_data_path, args.dev_gold_path, dev_loaders)
        dev_results = list()
        total_sum = 0
        for dev_input, dev_gold, dev_loader in dev_iter:
            dev_name = dev_input.split('/')[-1].replace('.pkl', '')
            output_path = os.path.join(args.save_path, f'epoch{epoch}_dev/end_epoch/{dev_name}')
            extract(args, model, dev_loader, output_path)
            dev_result = do_eval(output_path, dev_gold)
            utils.print_results(f"EPOCH{epoch} EVAL",
                                dev_result, ["F1 ", "PREC", "REC ", "AUC "])
            total_sum += dev_result[0] + dev_result[-1]
            dev_result.append(dev_result[0] + dev_result[-1])
            dev_results += dev_result
        summarizer.save_results([epoch] + trn_results + dev_results + [total_sum])
        model_name = utils.set_model_name(total_sum, epoch)
        torch.save(model.state_dict(), os.path.join(args.save_path, model_name))
    print("\nTraining Ended\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # settings
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--save_path', default='./results')
    parser.add_argument('--bert_config', default='bert-base-cased', help='or bert-base-multilingual-cased')
    parser.add_argument('--trn_data_path', default='./datasets/openie4_train.pkl')
    parser.add_argument('--dev_data_path', nargs='+', default=['./datasets/oie2016_dev.pkl', './datasets/carb_dev.pkl'])
    parser.add_argument('--dev_gold_path', nargs='+', default=['./evaluate/OIE2016_dev.txt', './carb/CaRB_dev.tsv'])
    parser.add_argument('--max_len', type=int, default=64)
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--visible_device', default="0")
    parser.add_argument('--summary_step', type=int, default=100)
    parser.add_argument('--use_lstm', nargs='?', const=True, default=False, type=utils.str2bool)
    parser.add_argument('--binary', nargs='?', const=True, default=False, type=utils.str2bool)

    # hyper-parameters
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--lstm_dropout', type=float, default=0.)
    parser.add_argument('--mh_dropout', type=float, default=0.2)
    parser.add_argument('--pred_clf_dropout', type=float, default=0.)
    parser.add_argument('--arg_clf_dropout', type=float, default=0.2)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--dev_batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=3e-5)
    parser.add_argument('--n_arg_heads', type=int, default=8)
    parser.add_argument('--n_arg_layers', type=int, default=4)
    parser.add_argument('--pos_emb_dim', type=int, default=64)
    main_args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = main_args.visible_device
    main_args = utils.clean_config(main_args)
    main(main_args)
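A hypothetical invocation of this training script (the paths are placeholders that must exist locally; all flags are defined in the argparse block above):

```
python main.py --bert_config bert-base-cased --trn_data_path ./datasets/openie4_train.pkl --epochs 1 --batch_size 128 --visible_device 0
```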
@ -1,316 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torch.nn.modules.container import ModuleList
from transformers import BertModel


class ArgModule(nn.Module):
    def __init__(self, arg_layer, n_layers):
        """
        Module for extracting arguments based on the given encoder output and predicates.
        It uses ArgExtractorLayer as a base block and repeats the block n_layers times.

        :param arg_layer: an instance of the ArgExtractorLayer() class (required)
        :param n_layers: the number of sub-layers in the ArgModule (required)
        """
        super(ArgModule, self).__init__()
        self.layers = _get_clones(arg_layer, n_layers)
        self.n_layers = n_layers

    def forward(self, encoded, predicate, pred_mask=None):
        """
        :param encoded: output from the sentence encoder with the shape of (L, B, D),
            where L is the sequence length, B is the batch size, and D is the embedding dimension
        :param predicate: output from the predicate module with the shape of (L, B, D)
        :param pred_mask: mask that prevents attention to tokens which are not predicates,
            with the shape of (B, L)
        :return: tensor shaped like a Transformer decoder layer output
        """
        output = encoded
        for layer_idx in range(self.n_layers):
            output = self.layers[layer_idx](
                target=output, source=predicate, key_mask=pred_mask)
        return output


class ArgExtractorLayer(nn.Module):
    def __init__(self,
                 d_model=768,
                 n_heads=8,
                 d_feedforward=2048,
                 dropout=0.1,
                 activation='relu'):
        """
        A layer similar to a Transformer decoder without decoder self-attention
        (only encoder-decoder multi-head attention followed by feed-forward layers).

        :param d_model: model dimensionality (default=768 from BERT-base)
        :param n_heads: number of heads in the multi-head attention layer
        :param d_feedforward: dimensionality of the point-wise feed-forward layer
        :param dropout: drop rate of all layers
        :param activation: activation function after the first feed-forward layer
        """
        super(ArgExtractorLayer, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_feedforward)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def forward(self, target, source, key_mask=None):
        """
        Single Transformer decoder layer without self-attention.

        :param target: a tensor which takes the role of a query
        :param source: a tensor which takes the role of a key & value
        :param key_mask: key mask tensor with the shape of (batch_size, sequence_length)
        """
        # Multi-head attention layer (+ add & norm)
        attended = self.multihead_attn(
            target, source, source,
            key_padding_mask=key_mask)[0]
        skipped = target + self.dropout1(attended)
        normed = self.norm1(skipped)

        # Point-wise feed-forward layer (+ add & norm)
        projected = self.linear2(self.dropout2(self.activation(self.linear1(normed))))
        skipped = normed + self.dropout1(projected)
        normed = self.norm2(skipped)
        return normed


class Multi2OIE(nn.Module):
    def __init__(self,
                 bert_config='bert-base-cased',
                 mh_dropout=0.1,
                 pred_clf_dropout=0.,
                 arg_clf_dropout=0.3,
                 n_arg_heads=8,
                 n_arg_layers=4,
                 pos_emb_dim=64,
                 pred_n_labels=3,
                 arg_n_labels=9):
        super(Multi2OIE, self).__init__()
        self.pred_n_labels = pred_n_labels
        self.arg_n_labels = arg_n_labels

        self.bert = BertModel.from_pretrained(
            bert_config,
            output_hidden_states=True)
        d_model = self.bert.config.hidden_size
        self.pred_dropout = nn.Dropout(pred_clf_dropout)
        self.pred_classifier = nn.Linear(d_model, self.pred_n_labels)

        self.position_emb = nn.Embedding(3, pos_emb_dim, padding_idx=0)
        # argument input = [bert_hidden; pred_feature; position_vectors]
        d_model += (d_model + pos_emb_dim)
        arg_layer = ArgExtractorLayer(
            d_model=d_model,
            n_heads=n_arg_heads,
            dropout=mh_dropout)
        self.arg_module = ArgModule(arg_layer, n_arg_layers)
        self.arg_dropout = nn.Dropout(arg_clf_dropout)
        self.arg_classifier = nn.Linear(d_model, arg_n_labels)

    def forward(self,
                input_ids,
                attention_mask,
                predicate_mask=None,
                predicate_hidden=None,
                total_pred_labels=None,
                arg_labels=None):
        # training-time forward pass: both label tensors are expected

        # predicate extraction
        bert_hidden = self.bert(input_ids, attention_mask)[0]
        pred_logit = self.pred_classifier(self.pred_dropout(bert_hidden))

        # predicate loss
        if total_pred_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = pred_logit.view(-1, self.pred_n_labels)
            active_labels = torch.where(
                active_loss, total_pred_labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(total_pred_labels))
            pred_loss = loss_fct(active_logits, active_labels)

        # inputs for argument extraction
        pred_feature = _get_pred_feature(bert_hidden, predicate_mask)
        position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
        bert_hidden = torch.cat([bert_hidden, pred_feature, position_vectors], dim=2)
        bert_hidden = bert_hidden.transpose(0, 1)

        # argument extraction
        arg_hidden = self.arg_module(bert_hidden, bert_hidden, predicate_mask)
        arg_hidden = arg_hidden.transpose(0, 1)
        arg_logit = self.arg_classifier(self.arg_dropout(arg_hidden))

        # argument loss
        if arg_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = arg_logit.view(-1, self.arg_n_labels)
            active_labels = torch.where(
                active_loss, arg_labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(arg_labels))
            arg_loss = loss_fct(active_logits, active_labels)

        # total loss
        batch_loss = pred_loss + arg_loss
        outputs = (batch_loss, pred_loss, arg_loss)
        return outputs

    def extract_predicate(self,
                          input_ids,
                          attention_mask):
        bert_hidden = self.bert(input_ids, attention_mask)[0]
        pred_logit = self.pred_classifier(bert_hidden)
        return pred_logit, bert_hidden

    def extract_argument(self,
                         input_ids,
                         predicate_hidden,
                         predicate_mask):
        pred_feature = _get_pred_feature(predicate_hidden, predicate_mask)
        position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
        arg_input = torch.cat([predicate_hidden, pred_feature, position_vectors], dim=2)
        arg_input = arg_input.transpose(0, 1)
        arg_hidden = self.arg_module(arg_input, arg_input, predicate_mask)
        arg_hidden = arg_hidden.transpose(0, 1)
        return self.arg_classifier(arg_hidden)


class BERTBiLSTM(nn.Module):
    def __init__(self,
                 bert_config='bert-base-cased',
                 lstm_dropout=0.3,
                 pred_clf_dropout=0.,
                 arg_clf_dropout=0.3,
                 pos_emb_dim=256,
                 pred_n_labels=3,
                 arg_n_labels=9):
        super(BERTBiLSTM, self).__init__()
        self.pred_n_labels = pred_n_labels
        self.arg_n_labels = arg_n_labels

        self.bert = BertModel.from_pretrained(
            bert_config,
            output_hidden_states=True)
        d_model = self.bert.config.hidden_size
        self.pred_dropout = nn.Dropout(pred_clf_dropout)
        self.pred_classifier = nn.Linear(d_model, self.pred_n_labels)

        self.position_emb = nn.Embedding(3, pos_emb_dim, padding_idx=0)
        d_model += pos_emb_dim
        self.arg_module = nn.LSTM(
            input_size=d_model,
            hidden_size=d_model,
            num_layers=3,
            dropout=lstm_dropout,
            batch_first=True,
            bidirectional=True)
        self.arg_dropout = nn.Dropout(arg_clf_dropout)
        self.arg_classifier = nn.Linear(d_model * 2, arg_n_labels)

    def forward(self,
                input_ids,
                attention_mask,
                predicate_mask=None,
                predicate_hidden=None,
                total_pred_labels=None,
                arg_labels=None):
        # training-time forward pass: both label tensors are expected

        # predicate extraction
        bert_hidden = self.bert(input_ids, attention_mask)[0]
        pred_logit = self.pred_classifier(self.pred_dropout(bert_hidden))

        # predicate loss
        if total_pred_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = pred_logit.view(-1, self.pred_n_labels)
            active_labels = torch.where(
                active_loss, total_pred_labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(total_pred_labels))
            pred_loss = loss_fct(active_logits, active_labels)

        # argument extraction
        position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
        bert_hidden = torch.cat([bert_hidden, position_vectors], dim=2)
        arg_hidden = self.arg_module(bert_hidden)[0]
        arg_logit = self.arg_classifier(self.arg_dropout(arg_hidden))

        # argument loss
        if arg_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = arg_logit.view(-1, self.arg_n_labels)
            active_labels = torch.where(
                active_loss, arg_labels.view(-1),
                torch.tensor(loss_fct.ignore_index).type_as(arg_labels))
            arg_loss = loss_fct(active_logits, active_labels)

        # total loss
        batch_loss = pred_loss + arg_loss
        outputs = (batch_loss, pred_loss, arg_loss)
        return outputs

    def extract_predicate(self,
                          input_ids,
                          attention_mask):
        bert_hidden = self.bert(input_ids, attention_mask)[0]
        pred_logit = self.pred_classifier(bert_hidden)
        return pred_logit, bert_hidden

    def extract_argument(self,
                         input_ids,
                         predicate_hidden,
                         predicate_mask):
        position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
        arg_input = torch.cat([predicate_hidden, position_vectors], dim=2)
        arg_hidden = self.arg_module(arg_input)[0]
        return self.arg_classifier(arg_hidden)


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)


def _get_clones(module, n):
    return ModuleList([copy.deepcopy(module) for _ in range(n)])


def _get_position_idxs(pred_mask, input_ids):
    position_idxs = torch.zeros(pred_mask.shape, dtype=torch.long, device=pred_mask.device)
    for mask_idx, cur_mask in enumerate(pred_mask):
        position_idxs[mask_idx, :] += 2
        cur_nonzero = (cur_mask == 0).nonzero()
        start = torch.min(cur_nonzero).item()
        end = torch.max(cur_nonzero).item()
        position_idxs[mask_idx, start:end + 1] = 1
        pad_start = max(input_ids[mask_idx].nonzero()).item() + 1
        position_idxs[mask_idx, pad_start:] = 0
    return position_idxs


def _get_pred_feature(pred_hidden, pred_mask):
    B, L, D = pred_hidden.shape
    pred_features = torch.zeros((B, L, D), device=pred_mask.device)
    for mask_idx, cur_mask in enumerate(pred_mask):
        pred_position = (cur_mask == 0).nonzero().flatten()
        pred_feature = torch.mean(pred_hidden[mask_idx, pred_position], dim=0)
        pred_feature = torch.cat(L * [pred_feature.unsqueeze(0)])
        pred_features[mask_idx, :, :] = pred_feature
    return pred_features
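A minimal shape-check sketch for `ArgExtractorLayer` above; the tensor sizes are illustrative and assume only the class definitions in this file:

```python
import torch

# Illustrative sizes: sequence length L=16, batch B=2, model dim D=768.
layer = ArgExtractorLayer(d_model=768, n_heads=8)
encoded = torch.randn(16, 2, 768)    # (L, B, D) sentence encoding (query)
predicate = torch.randn(16, 2, 768)  # (L, B, D) predicate features (key/value)
key_mask = torch.zeros(2, 16, dtype=torch.bool)  # (B, L); True positions are ignored

out = layer(target=encoded, source=predicate, key_mask=key_mask)
assert out.shape == (16, 2, 768)
```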
@ -1,59 +0,0 @@
attrs==19.3.0
backcall==0.1.0
blis==0.4.1
brotlipy==0.7.0
catalogue==1.0.0
certifi==2020.4.5.1
cffi==1.14.0
chardet==3.0.4
click==7.1.2
cryptography==2.9.2
cymem==2.0.3
decorator==4.4.2
docopt==0.6.2
filelock==3.0.12
idna==2.9
importlib-metadata==1.6.0
ipdb==0.13.2
ipython==7.14.0
ipython-genutils==0.2.0
jedi==0.17.0
joblib==0.15.1
jsonschema==3.2.0
murmurhash==1.0.0
nltk==3.5
numpy==1.18.1
pandas==1.0.3
parso==0.7.0
pexpect==4.8.0
pickleshare==0.7.5
plac==0.9.6
preshed==3.0.2
prompt-toolkit==3.0.5
ptyprocess==0.6.0
pycparser==2.20
Pygments==2.6.1
pyOpenSSL==19.1.0
pyrsistent==0.16.0
PySocks==1.7.1
python-dateutil==2.8.1
pytz==2020.1
regex==2020.5.14
requests==2.23.0
sacremoses==0.0.43
scikit-learn==0.22.1
scipy==1.4.1
sentencepiece==0.1.91
six==1.14.0
spacy==2.2.4
srsly==1.0.2
thinc==7.4.0
tokenizers==0.7.0
torch==1.4.0
tqdm==4.46.0
traitlets==4.3.3
transformers==2.10.0
urllib3==1.25.9
wasabi==0.6.0
wcwidth==0.1.9
zipp==3.1.0
@ -1,103 +0,0 @@
import argparse
import os
import time
import torch
from utils import utils
from dataset import load_data
from extract import extract
from evaluate.evaluate import Benchmark
from evaluate.matcher import Matcher
from evaluate.generalReader import GeneralReader
from carb.carb import Benchmark as CarbBenchmark
from carb.matcher import Matcher as CarbMatcher
from carb.tabReader import TabReader


def get_performance(output_path, gold_path):
    auc, precision, recall, f1 = [None for _ in range(4)]
    if 'evaluate' in gold_path:
        matching_func = Matcher.lexicalMatch
        error_fn = os.path.join(output_path, 'error_idxs.txt')

        evaluator = Benchmark(gold_path)
        reader = GeneralReader()
        reader.read(os.path.join(output_path, 'extraction.txt'))

        (precision, recall, f1), auc = evaluator.compare(
            predicted=reader.oie,
            matchingFunc=matching_func,
            output_fn=os.path.join(output_path, 'pr_curve.txt'),
            error_file=error_fn)
    elif 'carb' in gold_path:
        matching_func = CarbMatcher.binary_linient_tuple_match
        error_fn = os.path.join(output_path, 'error_idxs.txt')

        evaluator = CarbBenchmark(gold_path)
        reader = TabReader()
        reader.read(os.path.join(output_path, 'extraction.txt'))

        auc, (precision, recall, f1) = evaluator.compare(
            predicted=reader.oie,
            matchingFunc=matching_func,
            output_fn=os.path.join(output_path, 'pr_curve.txt'),
            error_file=error_fn)
    return auc, precision, recall, f1


def do_eval(output_path, gold_path):
    auc, prec, rec, f1 = get_performance(output_path, gold_path)
    eval_results = [f1, prec, rec, auc]
    return eval_results
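A minimal usage sketch of `do_eval` (hypothetical paths; the benchmark is picked by matching `'evaluate'` or `'carb'` in the gold path):

```python
# Hypothetical paths for illustration only.
f1, prec, rec, auc = do_eval('./results/carb_test', './carb/CaRB_test.tsv')
```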
def main(args):
    model = utils.get_models(
        bert_config=args.bert_config,
        pred_n_labels=args.pred_n_labels,
        arg_n_labels=args.arg_n_labels,
        n_arg_heads=args.n_arg_heads,
        n_arg_layers=args.n_arg_layers,
        pos_emb_dim=args.pos_emb_dim,
        use_lstm=args.use_lstm,
        device=args.device)
    model.load_state_dict(torch.load(args.model_path))
    model.zero_grad()
    model.eval()

    loader = load_data(
        data_path=args.test_data_path,
        batch_size=args.batch_size,
        tokenizer_config=args.bert_config,
        train=False)
    start = time.time()
    extract(args, model, loader, args.save_path)
    print("TIME: ", time.time() - start)
    test_results = do_eval(args.save_path, args.test_gold_path)
    utils.print_results("TEST RESULT", test_results, ["F1 ", "PREC", "REC ", "AUC "])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_path', default='./results/model.bin')
    parser.add_argument('--save_path', default='./results/carb_test')
    parser.add_argument('--bert_config', default='bert-base-cased')
    parser.add_argument('--test_data_path', default='./datasets/carb_test.pkl')
    parser.add_argument('--test_gold_path', default='./carb/CaRB_test.tsv')
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--visible_device', default="0")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--pos_emb_dim', type=int, default=64)
    parser.add_argument('--n_arg_heads', type=int, default=8)
    parser.add_argument('--n_arg_layers', type=int, default=4)
    parser.add_argument('--use_lstm', nargs='?', const=True, default=False, type=utils.str2bool)
    parser.add_argument('--binary', nargs='?', const=True, default=False, type=utils.str2bool)
    main_args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = main_args.visible_device

    main_args.pred_n_labels = 3
    main_args.arg_n_labels = 9
    device = torch.device(main_args.device if torch.cuda.is_available() else 'cpu')
    main_args.device = device
    main(main_args)
@ -1,79 +0,0 @@
import os
import torch
import torch.nn as nn
import utils.bio as bio
from tqdm import tqdm
from extract import extract
from utils import utils
from src.Multi2OIE.test import do_eval


def train(args,
          epoch,
          model,
          trn_loader,
          dev_loaders,
          summarizer,
          optimizer,
          scheduler):
    total_pred_loss, total_arg_loss, trn_results = 0, 0, None
    epoch_steps = int(args.total_steps / args.epochs)

    iterator = tqdm(enumerate(trn_loader), desc='steps', total=epoch_steps)
    for step, batch in iterator:
        batch = map(lambda x: x.to(args.device), batch)
        token_ids, att_mask, single_pred_label, single_arg_label, all_pred_label = batch
        pred_mask = bio.get_pred_mask(single_pred_label)

        model.train()
        model.zero_grad()

        # feed to predicate model
        batch_loss, pred_loss, arg_loss = model(
            input_ids=token_ids,
            attention_mask=att_mask,
            predicate_mask=pred_mask,
            total_pred_labels=all_pred_label,
            arg_labels=single_arg_label)

        # get performance on this batch
        total_pred_loss += pred_loss.item()
        total_arg_loss += arg_loss.item()

        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        trn_results = [total_pred_loss / (step + 1), total_arg_loss / (step + 1)]
        if step > epoch_steps:
            break

        # interim evaluation
        if step % 1000 == 0 and step != 0:
            dev_iter = zip(args.dev_data_path, args.dev_gold_path, dev_loaders)
            dev_results = list()
            total_sum = 0
            for dev_input, dev_gold, dev_loader in dev_iter:
                dev_name = dev_input.split('/')[-1].replace('.pkl', '')
                output_path = os.path.join(args.save_path, f'epoch{epoch}_dev/step{step}/{dev_name}')
                extract(args, model, dev_loader, output_path)
                dev_result = do_eval(output_path, dev_gold)
                utils.print_results(f"EPOCH{epoch} STEP{step} EVAL",
                                    dev_result, ["F1 ", "PREC", "REC ", "AUC "])
                total_sum += dev_result[0] + dev_result[-1]
                dev_result.append(dev_result[0] + dev_result[-1])
                dev_results += dev_result
            summarizer.save_results([step] + trn_results + dev_results + [total_sum])
            model_name = utils.set_model_name(total_sum, epoch, step)
            torch.save(model.state_dict(), os.path.join(args.save_path, model_name))

        if step % args.summary_step == 0 and step != 0:
            utils.print_results(f"EPOCH{epoch} STEP{step} TRAIN",
                                trn_results, ["PRED LOSS", "ARG LOSS "])

    # end epoch summary
    utils.print_results(f"EPOCH{epoch} TRAIN",
                        trn_results, ["PRED LOSS", "ARG LOSS "])
    return trn_results
@ -1,310 +0,0 @@
import torch
import numpy as np
from src.Multi2OIE.utils import utils

pred_tag2idx = {
    'P-B': 0, 'P-I': 1, 'O': 2
}
arg_tag2idx = {
    'A0-B': 0, 'A0-I': 1,
    'A1-B': 2, 'A1-I': 3,
    'A2-B': 4, 'A2-I': 5,
    'A3-B': 6, 'A3-I': 7,
    'O': 8,
}


def get_pred_idxs(pred_tags):
    idxs = list()
    for pred_tag in pred_tags:
        idxs.append([idx.item() for idx in (pred_tag != 2).nonzero()])
    return idxs


def get_pred_mask(tensor):
    """
    Generate predicate masks by converting each index with an 'O' tag to 1.
    All other indices are converted to 0, which means non-masking.

    :param tensor: predicate-tagged tensor with the shape of (B, L),
        where B is the batch size and L is the sequence length.
    :return: binary mask tensor with the same shape.
    """
    res = tensor.clone()
    res[tensor == pred_tag2idx['O']] = 1
    res[tensor != pred_tag2idx['O']] = 0
    return torch.tensor(res, dtype=torch.bool, device=tensor.device)
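A minimal sketch of the mask convention with a toy tag tensor: 'O' positions become True, predicate positions become False.

```python
import torch

tags = torch.tensor([[2, 0, 1, 2]])   # O, P-B, P-I, O
print(get_pred_mask(tags))
# tensor([[ True, False, False,  True]])
```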
def filter_pred_tags(pred_tags, tokens):
    """
    Filter useless tokens by converting them into the 'Outside' tag.
    We treat an 'Inside' tag that appears before any 'Beginning' tag as a
    meaningful signal, so we change it to a 'Beginning' tag, unlike
    [Stanovsky et al., 2018].

    :param pred_tags: predicate tags with the shape of (B, L).
    :param tokens: list-format sentence pieces with the shape of (B, L).
    :return: tensor of filtered predicate tags with the same shape.
    """
    assert len(pred_tags) == len(tokens)
    assert len(pred_tags[0]) == len(tokens[0])

    # filter by tokens ([CLS], [SEP], [PAD] tokens should be allocated as 'O')
    for pred_idx, cur_tokens in enumerate(tokens):
        for tag_idx, token in enumerate(cur_tokens):
            if token in ['[CLS]', '[SEP]', '[PAD]']:
                pred_tags[pred_idx][tag_idx] = pred_tag2idx['O']

    # filter by tags
    pred_copied = pred_tags.clone()
    for pred_idx, cur_pred_tag in enumerate(pred_copied):
        flag = False
        tag_copied = cur_pred_tag.clone()
        for tag_idx, tag in enumerate(tag_copied):
            if not flag and tag == pred_tag2idx['P-B']:
                flag = True
            elif not flag and tag == pred_tag2idx['P-I']:
                pred_tags[pred_idx][tag_idx] = pred_tag2idx['P-B']
                flag = True
            elif flag and tag == pred_tag2idx['O']:
                flag = False
    return pred_tags
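A minimal sketch with toy tensors showing how an orphan 'Inside' run is promoted to start with a 'Beginning' tag:

```python
import torch

tags = torch.tensor([[2, 1, 1, 2]])              # O, P-I, P-I, O
toks = [['Bill', 'founded', 'Micro', '##soft']]  # made-up pieces, same length
print(filter_pred_tags(tags, toks))
# tensor([[2, 0, 1, 2]])  -- the orphan P-I run now starts with P-B
```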
def filter_arg_tags(arg_tags, pred_tags, tokens):
    """
    Same as the description of @filter_pred_tags().

    :param arg_tags: argument tags with the shape of (B, L).
    :param pred_tags: predicate tags with the same shape.
        They are used to force predicate positions to be allocated the 'Outside' tag.
    :param tokens: list of string tokens with the length of L.
        It is used to force special tokens like [CLS] to be allocated the 'Outside' tag.
    :return: tensor of filtered argument tags with the same shape.
    """

    # filter by tokens ([CLS], [SEP], [PAD] tokens should be allocated as 'O')
    for arg_idx, cur_arg_tag in enumerate(arg_tags):
        for tag_idx, token in enumerate(tokens):
            if token in ['[CLS]', '[SEP]', '[PAD]']:
                arg_tags[arg_idx][tag_idx] = arg_tag2idx['O']

    # filter by tags
    arg_copied = arg_tags.clone()
    for arg_idx, (cur_arg_tag, cur_pred_tag) in enumerate(zip(arg_copied, pred_tags)):
        pred_idxs = [idx[0].item() for idx
                     in (cur_pred_tag != pred_tag2idx['O']).nonzero()]
        arg_tags[arg_idx][pred_idxs] = arg_tag2idx['O']
        cur_arg_copied = arg_tags[arg_idx].clone()
        flag_idx = 999
        for tag_idx, tag in enumerate(cur_arg_copied):
            if tag == arg_tag2idx['O']:
                flag_idx = 999
                continue
            arg_n = tag // 2  # 0: A0 / 1: A1 / ...
            inside = tag % 2  # 0: begin / 1: inside
            if not inside and flag_idx != arg_n:
                flag_idx = arg_n
            # connect args
            elif not inside and flag_idx == arg_n:
                arg_tags[arg_idx][tag_idx] = arg_tag2idx[f'A{arg_n}-I']
            elif inside and flag_idx != arg_n:
                arg_tags[arg_idx][tag_idx] = arg_tag2idx[f'A{arg_n}-B']
                flag_idx = arg_n
    return arg_tags


def get_max_prob_args(arg_tags, arg_probs):
    """
    Among the predicted argument tags, keep only the argument span with the
    highest probability. The comparison is made only between spans that carry
    the same argument label.

    :param arg_tags: argument tags with the shape of (B, L).
    :param arg_probs: argument softmax probabilities with the shape of (B, L, T),
        where B is the batch size, L is the sequence length, and T is the number of tag labels.
    :return: tensor of filtered argument tags with the same shape.
    """
    for cur_arg_tag, cur_probs in zip(arg_tags, arg_probs):
        cur_tag_probs = [cur_probs[idx][tag] for idx, tag in enumerate(cur_arg_tag)]
        for arg_n in range(4):
            b_tag = arg_tag2idx[f"A{arg_n}-B"]
            i_tag = arg_tag2idx[f"A{arg_n}-I"]
            flag = False
            total_tags = []
            cur_tags = []
            for idx, tag in enumerate(cur_arg_tag):
                if not flag and tag == b_tag:
                    flag = True
                    cur_tags.append(idx)
                elif flag and tag == i_tag:
                    cur_tags.append(idx)
                elif flag and tag == b_tag:
                    total_tags.append(cur_tags)
                    cur_tags = [idx]
                else:  # any other tag closes the current span
                    total_tags.append(cur_tags)
                    cur_tags = []
                    flag = False
            max_idxs, max_prob = None, 0.0
            for idxs in total_tags:
                all_probs = [cur_tag_probs[idx].item() for idx in idxs]
                if len(all_probs) == 0:
                    continue
                cur_prob = all_probs[0]
                if cur_prob > max_prob:
                    max_prob = cur_prob
                    max_idxs = idxs
            if max_idxs is None:
                continue
            del_idxs = [idx for idx, tag in enumerate(cur_arg_tag)
                        if (tag in [b_tag, i_tag]) and (idx not in max_idxs)]
            cur_arg_tag[del_idxs] = arg_tag2idx['O']
    return arg_tags


def get_single_predicate_idxs(pred_tags):
    """
    Divide each single batch based on predicted predicates.
    It is necessary for predicting argument tags with a specific predicate.

    :param pred_tags: tensor of predicate tags with the shape of (B, L)
        EX >>> tensor([[2, 0, 0, 1, 0, 1, 0, 2, 2, 2],
                       [2, 2, 2, 0, 1, 0, 1, 2, 2, 2],
                       [2, 2, 2, 2, 2, 2, 2, 2, 0, 1]])

    :return: list of tensors with the shape of (B, P, L);
        the number P can be different for each batch.
        EX >>> [tensor([[2., 0., 2., 2., 2., 2., 2., 2., 2., 2.],
                        [2., 2., 0., 1., 2., 2., 2., 2., 2., 2.],
                        [2., 2., 2., 2., 0., 1., 2., 2., 2., 2.],
                        [2., 2., 2., 2., 2., 2., 0., 2., 2., 2.]]),
                tensor([[2., 2., 2., 0., 1., 2., 2., 2., 2., 2.],
                        [2., 2., 2., 2., 2., 0., 1., 2., 2., 2.]]),
                tensor([[2., 2., 2., 2., 2., 2., 2., 2., 0., 1.]])]
    """
    total_pred_tags = []
    for cur_pred_tag in pred_tags:
        cur_sent_preds = []
        begin_idxs = [idx[0].item() for idx in (cur_pred_tag == pred_tag2idx['P-B']).nonzero()]
        for i, b_idx in enumerate(begin_idxs):
            cur_pred = np.full(cur_pred_tag.shape[0], pred_tag2idx['O'])
            cur_pred[b_idx] = pred_tag2idx['P-B']
            if i == len(begin_idxs) - 1:
                end_idx = cur_pred_tag.shape[0]
            else:
                end_idx = begin_idxs[i + 1]
            for j, tag in enumerate(cur_pred_tag[b_idx:end_idx]):
                if tag.item() == pred_tag2idx['O']:
                    break
                elif tag.item() == pred_tag2idx['P-I']:
                    cur_pred[b_idx + j] = pred_tag2idx['P-I']
            cur_sent_preds.append(cur_pred)
        total_pred_tags.append(cur_sent_preds)
    return [torch.Tensor(pred_tags) for pred_tags in total_pred_tags]


def get_tuple(sentence, pred_tags, arg_tags, tokenizer):
    """
    Get string-format tuples from the given predicate indexes and argument tags.

    :param sentence: string-format raw sentence.
    :param pred_tags: tensor of predicate tags with the shape of (# of predicates, sequence length).
    :param arg_tags: tensor of argument tags with the same shape.
    :param tokenizer: transformers BertTokenizer (bert-base-cased or bert-base-multilingual-cased).

    :return extractions: list of string tuples; each element holds the predicate, arg0, arg1, ...
    :return extraction_idxs: list of indexes of each argument for calculating the confidence score.
    """
    word2piece = utils.get_word2piece(sentence, tokenizer)
    words = sentence.split(' ')
    assert pred_tags.shape[0] == arg_tags.shape[0]  # number of predicates

    pred_tags = pred_tags.tolist()
    arg_tags = arg_tags.tolist()
    extractions = list()
    extraction_idxs = list()

    # loop for each predicate
    for cur_pred_tag, cur_arg_tags in zip(pred_tags, arg_tags):
        cur_extraction = list()
        cur_extraction_idxs = list()

        # get predicate
        pred_labels = [pred_tag2idx['P-B'], pred_tag2idx['P-I']]
        cur_predicate_idxs = [idx for idx, tag in enumerate(cur_pred_tag) if tag in pred_labels]
        if len(cur_predicate_idxs) == 0:
            predicates_str = ''
        else:
            cur_pred_words = list()
            for word_idx, piece_idxs in word2piece.items():
                if set(piece_idxs) <= set(cur_predicate_idxs):
                    cur_pred_words.append(word_idx)
            if len(cur_pred_words) == 0:
                predicates_str = ''
                cur_predicate_idxs = list()
            else:
                predicates_str = ' '.join([words[idx] for idx in cur_pred_words])
        cur_extraction.append(predicates_str)
        cur_extraction_idxs.append(cur_predicate_idxs)

        # get arguments
        for arg_n in range(4):
            cur_arg_labels = [arg_tag2idx[f'A{arg_n}-B'], arg_tag2idx[f'A{arg_n}-I']]
            cur_arg_idxs = [idx for idx, tag in enumerate(cur_arg_tags) if tag in cur_arg_labels]
            if len(cur_arg_idxs) == 0:
                cur_arg_str = ''
            else:
                cur_arg_words = list()
                for word_idx, piece_idxs in word2piece.items():
                    if set(piece_idxs) <= set(cur_arg_idxs):
                        cur_arg_words.append(word_idx)
                if len(cur_arg_words) == 0:
                    cur_arg_str = ''
                    cur_arg_idxs = list()
                else:
                    cur_arg_str = ' '.join([words[idx] for idx in cur_arg_words])
            cur_extraction.append(cur_arg_str)
            cur_extraction_idxs.append(cur_arg_idxs)
        extractions.append(cur_extraction)
        extraction_idxs.append(cur_extraction_idxs)
    return extractions, extraction_idxs


def get_confidence_score(pred_probs, arg_probs, extraction_idxs):
    """
    Get the confidence score of each extraction for drawing the PR curve.

    :param pred_probs: (sequence length, # of predicate labels)
    :param arg_probs: (# of predicates, sequence length, # of argument labels)
    :param extraction_idxs: [[[2, 3, 4], [0, 1], [9, 10]], [[0, 1, 2], [7, 8], [4, 5]], ...]
    """
    confidence_scores = list()
    for cur_arg_prob, cur_ext_idxs in zip(arg_probs, extraction_idxs):
        if len(cur_ext_idxs[0]) == 0:
            confidence_scores.append(0)
            continue
        cur_score = 0

        # predicate score
        pred_score = max(pred_probs[cur_ext_idxs[0][0]]).item()
        cur_score += pred_score

        # argument score
        for arg_idx in cur_ext_idxs[1:]:
            if len(arg_idx) == 0:
                continue
            begin_idxs = _find_begins(arg_idx)
            arg_score = np.mean([max(cur_arg_prob[cur_idx]).item() for cur_idx in begin_idxs])
            cur_score += arg_score
        confidence_scores.append(cur_score)
    return confidence_scores
def _find_begins(idxs):
    found_begins = [idxs[0]]
    cur_flag_idx = idxs[0]
    for cur_idx in idxs[1:]:
        if cur_idx - cur_flag_idx != 1:
            found_begins.append(cur_idx)
        cur_flag_idx = cur_idx
    return found_begins
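`_find_begins` returns the first index of each contiguous run, e.g. with a made-up index list:

```python
print(_find_begins([2, 3, 4, 7, 8]))
# [2, 7]  -- the starts of the runs 2-4 and 7-8
```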
@ -1,237 +0,0 @@
"""
Script for transforming the original datasets for Multi^2OIE training and evaluation.

1. if Mode == 'train'
    - input: structured.json (https://github.com/zhanjunlang/Span_OIE)
    - output: ../datasets/openie4_train.pkl

2. if Mode == 'dev_input'
    - input: dev.oie.conll (https://github.com/gabrielStanovsky/supervised-oie/tree/master/data)
    - output: ../datasets/oie2016_dev.pkl

3. if Mode == 'dev_gold'
    - input: dev.oie.conll (https://github.com/gabrielStanovsky/supervised-oie/tree/master/data)
    - output: ../evaluate/OIE2016_dev.txt
"""

import json
import numpy as np
import argparse
import pickle
from transformers import BertTokenizer
from tqdm import tqdm


pred_tag2idx = {
    'P-B': 0, 'P-I': 1, 'O': 2
}

arg_tag2idx = {
    'A0-B': 0, 'A0-I': 1,
    'A1-B': 2, 'A1-I': 3,
    'A2-B': 4, 'A2-I': 5,
    'A3-B': 6, 'A3-I': 7,
    'O': 8,
}


def preprocess_train(args):
    print("loading dataset...")

    with open(args.data) as json_file:
        data = json.load(json_file)
    iterator = tqdm(data)
    tokenizer = BertTokenizer.from_pretrained(args.bert_config)
    print("done. preprocessing starts.")

    openie4_train = {
        'tokens': list(),
        'single_pred_labels': list(),
        'single_arg_labels': list(),
        'all_pred_labels': list()
    }

    rel_pos_malformed = 0
    max_over_case = 0

    for cur_data in iterator:
        words = cur_data['sentence'].replace('\xa0', ' ').split(' ')
        word2piece = {idx: list() for idx in range(len(words))}
        sentence_pieces = list()
        piece_idx = 0
        for word_idx, word in enumerate(words):
            pieces = tokenizer.tokenize(word)
            sentence_pieces += pieces
            for piece_idx_added, piece in enumerate(pieces):
                word2piece[word_idx].append(piece_idx + piece_idx_added)
            piece_idx += len(pieces)
        assert len(sentence_pieces) == piece_idx

        # if the length of the sentence pieces is over maxlen-2, we skip the sentence
        if piece_idx > args.max_len - 2:
            max_over_case += 1
            continue

        all_pred_label = np.asarray([pred_tag2idx['O'] for _ in range(len(sentence_pieces))])
        cur_tuple_malformed = 0
        for cur_tuple in cur_data['tuples']:

            # add predicate labels
            pred_label = np.asarray([pred_tag2idx['O'] for _ in range(len(sentence_pieces))])
            if -1 in cur_tuple['rel_pos']:
                rel_pos_malformed += 1
                cur_tuple_malformed += 1
                continue
            else:
                start_idx, end_idx = cur_tuple['rel_pos'][:2]
                for pred_word_idx in range(start_idx, end_idx + 1):
                    pred_label[word2piece[pred_word_idx]] = pred_tag2idx['P-I']
                    all_pred_label[word2piece[pred_word_idx]] = pred_tag2idx['P-I']
                pred_label[word2piece[start_idx][0]] = pred_tag2idx['P-B']
                all_pred_label[word2piece[start_idx][0]] = pred_tag2idx['P-B']
            openie4_train['single_pred_labels'].append(pred_label)

            # add argument-0 labels
            arg_label = np.asarray([arg_tag2idx['O'] for _ in range(len(sentence_pieces))])
            start_idx, end_idx = cur_tuple['arg0_pos']
            for arg_word_idx in range(start_idx, end_idx + 1):
                arg_label[word2piece[arg_word_idx]] = arg_tag2idx['A0-I']
            arg_label[word2piece[start_idx][0]] = arg_tag2idx['A0-B']

            # add additional argument labels
            for arg_n, arg_pos in enumerate(cur_tuple['args_pos'][:3]):
                arg_n += 1
                start_idx, end_idx = arg_pos
                for arg_word_idx in range(start_idx, end_idx + 1):
                    arg_label[word2piece[arg_word_idx]] = arg_tag2idx[f'A{arg_n}-I']
                arg_label[word2piece[start_idx][0]] = arg_tag2idx[f'A{arg_n}-B']
            openie4_train['single_arg_labels'].append(arg_label)

        # add sentence pieces and the total predicate label of the current sentence
        for _ in range(len(cur_data['tuples']) - cur_tuple_malformed):
            openie4_train['tokens'].append(sentence_pieces)
            openie4_train['all_pred_labels'].append(all_pred_label)

    assert len(openie4_train['tokens']) == len(openie4_train['all_pred_labels'])
    assert len(openie4_train['all_pred_labels']) == len(openie4_train['single_pred_labels'])
    assert len(openie4_train['single_pred_labels']) == len(openie4_train['single_arg_labels'])

    save_pkl(args.save_path, openie4_train)
    print(f"# of data over max length: {max_over_case}")
    print(f"# of data with malformed relation positions: {rel_pos_malformed}")
    print("\npreprocessing done.")

"""
For English BERT,
# of data over max length: 5097
# of data with malformed relation positions: 1959

For Multilingual BERT,
# of data over max length: 2480
# of data with malformed relation positions: 1974
"""


def preprocess_dev_input(args):
    print("loading dataset...")
    with open(args.data) as f:
        lines = f.readlines()
    lines = [line.strip().split('\t') for line in lines]
    print("done. preprocessing starts.")

    sentences = list()
    words_queue = list()
    for line in tqdm(lines[1:]):
        if len(line) != len(lines[0]):
            sentence = " ".join(words_queue)
            sentences.append(sentence)
            words_queue = list()
            continue
        words_queue.append(line[1])
    sentences = list(set(sentences))
    save_pkl(args.save_path, sentences)
    print("\npreprocessing done.")


def preprocess_dev_gold(args):
    print("loading dataset...")
    with open(args.data) as f:
        lines = f.readlines()
    lines = [line.strip().split('\t') for line in lines]
    print("done. preprocessing starts.")

    f = open(args.save_path, 'w')
    words_queue = list()
    args_queue = {
        'A0': list(), 'A1': list(), 'A2': list(), 'A3': list()
    }

    for line in tqdm(lines[1:]):
        if len(line) != len(lines[0]):
            new_line = list()
            new_line.append(" ".join(words_queue))
            new_line.append(words_queue[pred_head_id])
            new_line.append(pred)
            for label in list(args_queue.keys()):
                if len(args_queue[label]) != 0:
                    new_line.append(" ".join(args_queue[label]))
            f.write("\t".join(new_line) + "\n")

            words_queue = list()
            args_queue = {
                'A0': list(), 'A1': list(), 'A2': list(), 'A3': list()
            }
            continue
        word = line[1]
        pred = line[2]
        pred_head_id = int(line[4])
        words_queue.append(word)
        for label in list(args_queue.keys()):
            if label in line[-1]:
                args_queue[label].append(word)
    f.close()


def _get_word2piece(sentence, tokenizer):
    words = sentence.replace('\xa0', ' ').split(' ')
    word2piece = {idx: list() for idx in range(len(words))}
    sentence_pieces = list()
    piece_idx = 1
    for word_idx, word in enumerate(words):
        pieces = tokenizer.tokenize(word)
        sentence_pieces += pieces
        for piece_idx_added, piece in enumerate(pieces):
            word2piece[word_idx].append(piece_idx + piece_idx_added)
        piece_idx += len(pieces)
    assert len(sentence_pieces) == piece_idx - 1
    return word2piece


def save_pkl(path, file):
    with open(path, 'wb') as f:
        pickle.dump(file, f)


def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train')
    parser.add_argument('--data', default='../datasets/structured_data.json')
    parser.add_argument('--save_path', default='../datasets/openie4_train.pkl')
    parser.add_argument('--bert_config', default='bert-base-cased')
    parser.add_argument('--max_len', type=int, default=64)
    main_args = parser.parse_args()

    if main_args.mode == 'train':
        preprocess_train(main_args)
    elif main_args.mode == 'dev_input':
        preprocess_dev_input(main_args)
    elif main_args.mode == 'dev_gold':
        preprocess_dev_gold(main_args)
    else:
        raise ValueError(f"Invalid preprocessing mode: {main_args.mode}")
@ -1,153 +0,0 @@
import argparse
import torch
import os
import random
import numpy as np
import pickle
import pandas as pd
import json
import copy
from src.Multi2OIE.model import Multi2OIE, BERTBiLSTM
from transformers import get_linear_schedule_with_warmup, AdamW


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def clean_config(config):
    device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
    config.device = device
    config.pred_n_labels = 3
    config.arg_n_labels = 9
    os.makedirs(config.save_path, exist_ok=True)
    return config


def get_models(bert_config,
               pred_n_labels=3,
               arg_n_labels=9,
               n_arg_heads=8,
               n_arg_layers=4,
               lstm_dropout=0.3,
               mh_dropout=0.1,
               pred_clf_dropout=0.,
               arg_clf_dropout=0.3,
               pos_emb_dim=64,
               use_lstm=False,
               device=None):
    if not use_lstm:
        return Multi2OIE(
            bert_config=bert_config,
            mh_dropout=mh_dropout,
            pred_clf_dropout=pred_clf_dropout,
            arg_clf_dropout=arg_clf_dropout,
            n_arg_heads=n_arg_heads,
            n_arg_layers=n_arg_layers,
            pos_emb_dim=pos_emb_dim,
            pred_n_labels=pred_n_labels,
            arg_n_labels=arg_n_labels).to(device)
    else:
        return BERTBiLSTM(
            bert_config=bert_config,
            lstm_dropout=lstm_dropout,
            pred_clf_dropout=pred_clf_dropout,
            arg_clf_dropout=arg_clf_dropout,
            pos_emb_dim=pos_emb_dim,
            pred_n_labels=pred_n_labels,
            arg_n_labels=arg_n_labels).to(device)


def save_pkl(path, file):
    with open(path, 'wb') as f:
        pickle.dump(file, f)


def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def get_word2piece(sentence, tokenizer):
    words = sentence.split(' ')
    word2piece = {idx: list() for idx in range(len(words))}
    sentence_pieces = list()
    piece_idx = 1
    for word_idx, word in enumerate(words):
        pieces = tokenizer.tokenize(word)
        sentence_pieces += pieces
        for piece_idx_added, piece in enumerate(pieces):
            word2piece[word_idx].append(piece_idx + piece_idx_added)
        piece_idx += len(pieces)
    assert len(sentence_pieces) == piece_idx - 1
    return word2piece
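A minimal sketch of the mapping, assuming each word tokenizes to a single piece; piece indices start at 1 so index 0 stays free for [CLS]:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
print(get_word2piece('Bill founded Microsoft', tokenizer))
# e.g. {0: [1], 1: [2], 2: [3]} if every word stays a single piece
```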
def get_train_modules(model,
                      lr,
                      total_steps,
                      warmup_steps):
    optimizer = AdamW(
        model.parameters(), lr=lr, correct_bias=False)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, warmup_steps, total_steps)
    return optimizer, scheduler
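A minimal wiring sketch with hypothetical hyperparameters (the learning rate and step counts below are made up):

```python
# `model` is assumed to come from get_models(...) above.
optimizer, scheduler = get_train_modules(model, lr=3e-5,
                                         total_steps=10000, warmup_steps=1000)
```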
class SummaryManager:
    def __init__(self, config):
        self.config = config
        self.save_config()
        columns = ['epoch', 'train_predicate_loss', 'train_argument_loss']
        for cur_dev_path in config.dev_data_path:
            cur_dev_name = cur_dev_path.split('/')[-1].replace('.pkl', '')
            for metric in ['f1', 'prec', 'rec', 'auc', 'sum']:
                columns.append(f'{cur_dev_name}_{metric}')
        columns.append('total_sum')
        self.result_df = pd.DataFrame(columns=columns)
        self.save_df()

    def save_config(self, display=True):
        if display:
            for key, value in self.config.__dict__.items():
                print("{}: {}".format(key, value))
            print()
        copied = copy.deepcopy(self.config)
        copied.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        with open(os.path.join(copied.save_path, 'config.json'), 'w') as fp:
            json.dump(copied.__dict__, fp, indent=4)

    def save_results(self, results):
        self.result_df = pd.read_csv(os.path.join(self.config.save_path, 'train_results.csv'))
        self.result_df.loc[len(self.result_df.index)] = results
        self.save_df()

    def save_df(self):
        self.result_df.to_csv(os.path.join(self.config.save_path, 'train_results.csv'), index=False)


def set_model_name(dev_results, epoch, step=None):
    if step is not None:
        return "model-epoch{}-step{}-score{:.4f}.bin".format(epoch, step, dev_results)
    else:
        return "model-epoch{}-end-score{:.4f}.bin".format(epoch, dev_results)


def print_results(message, results, names):
    print(f"\n===== {message} =====")
    for result, name in zip(results, names):
        print("{}: {:.5f}".format(name, result))
    print()
132
3-NLP_services/src/OpenNRE/.gitignore
vendored
@ -1,132 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# test
test.py

# ckpt
ckpt

# vscode
.vscode

# tacred
benchmark/tacred
*.swp

# data and pretrain
pretrain
benchmark
!benchmark/*.sh
!pretrain/*.sh

# test env
.test

# package
opennre-egg.info

*.sh
@ -1,21 +0,0 @@
MIT License

Copyright (c) 2019 Tianyu Gao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File diff suppressed because it is too large
@ -1,151 +0,0 @@
# OpenNRE

****** Update ******

We provide two distantly-supervised datasets with human-annotated test sets, **NYT10m** and **Wiki20m**. Check the [datasets](#datasets) section for details.

****** Update ******

OpenNRE is an open-source and extensible toolkit that provides a unified framework to implement relation extraction models. This package is designed for the following groups:

* **New to relation extraction**: We have hands-on tutorials and detailed documents that can not only enable you to use relation extraction tools, but also help you better understand the research progress in this field.
* **Developers**: Our easy-to-use interface and high-performance implementation can accelerate your deployment in real-world applications. Besides, we provide several pretrained models which can be put into production without any training.
* **Researchers**: With our modular design, various task settings and metric tools, you can easily carry out experiments on your own models with only minor modifications. We have also provided several of the most-used benchmarks for different settings of relation extraction.
* **Anyone who needs to submit an NLP homework to impress their professors**: With state-of-the-art models, our package can definitely help you stand out among your classmates!

This package is mainly contributed by [Tianyu Gao](https://github.com/gaotianyu1350), [Xu Han](https://github.com/THUCSTHanxu13), [Shulin Cao](https://github.com/ShulinCao), [Lumin Tang](https://github.com/Tsingularity), [Yankai Lin](https://github.com/Mrlyk423), [Zhiyuan Liu](http://nlp.csai.tsinghua.edu.cn/~lzy/)

## What is Relation Extraction

Relation extraction is a natural language processing (NLP) task aiming at extracting relations (e.g., *founder of*) between entities (e.g., **Bill Gates** and **Microsoft**). For example, from the sentence *Bill Gates founded Microsoft*, we can extract the relation triple (**Bill Gates**, *founder of*, **Microsoft**).

Relation extraction is a crucial technique in automatic knowledge graph construction. By using relation extraction, we can accumulatively extract new relation facts and expand the knowledge graph, which, as a way for machines to understand the human world, has many downstream applications like question answering, recommender systems and search engines.

## How to Cite

A good research work is always accompanied by a thorough and faithful reference. If you use or extend our work, please cite the following paper:

```
@inproceedings{han-etal-2019-opennre,
    title = "{O}pen{NRE}: An Open and Extensible Toolkit for Neural Relation Extraction",
    author = "Han, Xu and Gao, Tianyu and Yao, Yuan and Ye, Deming and Liu, Zhiyuan and Sun, Maosong",
    booktitle = "Proceedings of EMNLP-IJCNLP: System Demonstrations",
    year = "2019",
    url = "https://www.aclweb.org/anthology/D19-3029",
    doi = "10.18653/v1/D19-3029",
    pages = "169--174"
}
```

It's our honor to help you better explore relation extraction with our OpenNRE toolkit!

## Papers and Document

If you want to learn more about neural relation extraction, visit another project of ours ([NREPapers](https://github.com/thunlp/NREPapers)).

You can refer to our [document](https://opennre-docs.readthedocs.io/en/latest/) for more details about this project.

## Install

### Install as A Python Package

We are now working on deploying OpenNRE as a Python package. Coming soon!

### Using Git Repository

Clone the repository from our GitHub page (don't forget to star us!)

```bash
git clone https://github.com/thunlp/OpenNRE.git
```

If it is too slow, you can try

```bash
git clone https://github.com/thunlp/OpenNRE.git --depth 1
```

Then install all the requirements:

```bash
pip install -r requirements.txt
```

**Note**: Please choose the appropriate PyTorch version based on your machine (related to your CUDA version). For details, refer to https://pytorch.org/.

Then install the package with

```bash
python setup.py install
```

If you also want to modify the code, run this:

```bash
python setup.py develop
```

Note that we have excluded all data and pretrain files for fast deployment. You can manually download them by running scripts in the ``benchmark`` and ``pretrain`` folders. For example, if you want to download the FewRel dataset, you can run

```bash
bash benchmark/download_fewrel.sh
```

## Easy Start

Make sure you have installed OpenNRE as instructed above. Then import our package and load pre-trained models.

```python
>>> import opennre
>>> model = opennre.get_model('wiki80_cnn_softmax')
```

Note that it may take a few minutes to download the checkpoint and data for the first time. Then use `infer` to do sentence-level relation extraction

```python
>>> model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}})
('father', 0.5108704566955566)
```

You will get the relation result and its confidence score.

If you want to use the model on your GPU, just run

```python
>>> model = model.cuda()
```

before calling the inference function.

For now, we have the following available models:

* `wiki80_cnn_softmax`: trained on the `wiki80` dataset with a CNN encoder.
* `wiki80_bert_softmax`: trained on the `wiki80` dataset with a BERT encoder.
* `wiki80_bertentity_softmax`: trained on the `wiki80` dataset with a BERT encoder (using entity representation concatenation).
* `tacred_bert_softmax`: trained on the `TACRED` dataset with a BERT encoder.
* `tacred_bertentity_softmax`: trained on the `TACRED` dataset with a BERT encoder (using entity representation concatenation).

## Datasets

You can go into the `benchmark` folder and download datasets using our scripts. We also list some of the information about the datasets in [this document](https://opennre-docs.readthedocs.io/en/latest/get_started/benchmark.html#bag-level-relation-extraction).

## Training

You can train your own models on your own data with OpenNRE. In the `example` folder we give example training code for supervised RE models and bag-level RE models. You can either use our provided datasets or your own datasets. For example, you can use the following script to train a PCNN-ATT bag-level model on the NYT10 dataset with the manual test set:

```bash
python example/train_bag_cnn.py \
    --metric auc \
    --dataset nyt10m \
    --batch_size 160 \
    --lr 0.1 \
    --weight_decay 1e-5 \
    --max_epoch 100 \
    --max_length 128 \
    --seed 42 \
    --encoder pcnn \
    --aggr att
```

Or use the following script to train a BERT model on the Wiki80 dataset:

```bash
python example/train_supervised_bert.py \
    --pretrain_path bert-base-uncased \
    --dataset wiki80
```

We provide many options in the example training code and you can check them out for detailed instructions.
@ -1,479 +0,0 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "apMnCeBi0kAa",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1668340169194,
          "user_tz": -210,
          "elapsed": 26751,
          "user": {
            "displayName": "Mohammad Ebrahimi",
            "userId": "10407139745331958037"
          }
        },
        "outputId": "852beb1d-3b81-4f1e-8b41-511ee62a6458"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "juTvp3cXy0vT",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1668340312085,
          "user_tz": -210,
          "elapsed": 2,
          "user": {
            "displayName": "Mohammad Ebrahimi",
            "userId": "10407139745331958037"
          }
        },
        "outputId": "76fc9722-9631-4573-b68c-f2d460ebaf96"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/drive/MyDrive/OpenNRE-master\n"
          ]
        }
      ],
      "source": [
        "%cd /content/drive/MyDrive/OpenNRE-master"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -r requirements.txt"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Zw-L_bBWy3q6",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1668340295967,
          "user_tz": -210,
          "elapsed": 119014,
          "user": {
            "displayName": "Mohammad Ebrahimi",
            "userId": "10407139745331958037"
          }
        },
        "outputId": "c73fdd46-9d49-4643-a588-3021e6395c32"
      },
      "execution_count": 5,
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "ERROR: Could not find a version that satisfies the requirement torch==1.6.0 (from versions: 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1)\n",
            "ERROR: No matching distribution found for torch==1.6.0\n",
            "\n",
            "[notice] A new release of pip available: 22.3.1 -> 23.0\n",
            "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pip in e:\\hamed\\work\\5\\opennre-master\\venv\\lib\\site-packages (22.3.1)\n",
            "Collecting pip\n",
            "  Using cached pip-23.0-py3-none-any.whl (2.1 MB)\n",
            "Installing collected packages: pip\n",
            "  Attempting uninstall: pip\n",
            "    Found existing installation: pip 22.3.1\n",
            "    Uninstalling pip-22.3.1:\n",
            "      Successfully uninstalled pip-22.3.1\n",
            "Successfully installed pip-23.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# %%shell\n",
        "!python train_supervised_bert.py \\\n",
        "    --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
        "    --dataset none \\\n",
        "    --train_file ./Perlex/Perlex_train.txt \\\n",
        "    --val_file ./Perlex/Perlex_val.txt \\\n",
        "    --test_file ./Perlex/Perlex_test.txt \\\n",
        "    --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
        "    --max_epoch 20"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aEiTLWpPzV3J",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1665308620117,
          "user_tz": -210,
          "elapsed": 7519388,
          "user": {
            "displayName": "arian ebrahimi",
            "userId": "00418818321983401320"
          }
        },
        "outputId": "90c72478-d09d-4f96-9417-7c2ad805a66b"
      },
      "execution_count": 3,
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Traceback (most recent call last):\n",
            "  File \"E:\\Hamed\\Work\\5\\OpenNRE-master\\train_supervised_bert.py\", line 2, in <module>\n",
            "    import torch\n",
            "ModuleNotFoundError: No module named 'torch'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%shell\n",
        "python train_supervised_bert.py \\\n",
        "    --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
        "    --dataset none \\\n",
        "    --train_file ./Perlex/Perlex_train.txt \\\n",
        "    --val_file ./Perlex/Perlex_val.txt \\\n",
        "    --test_file ./Perlex/Perlex_test.txt \\\n",
        "    --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
        "    --max_epoch 20"
      ],
      "metadata": {
        "id": "zhYjakhv13c8",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "status": "ok",
          "timestamp": 1665318000439,
          "user_tz": -210,
          "elapsed": 7490397,
          "user": {
            "displayName": "arian ebrahimi",
            "userId": "00418818321983401320"
          }
        },
        "outputId": "d2d9d494-0222-425f-ff3a-50e01b33419f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100% 242/242 [05:39<00:00, 1.40s/it, acc=0.285, loss=2.38]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.627]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.643, loss=1.2]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.718]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.739, loss=0.848]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.742]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.797, loss=0.675]\n",
            "100% 47/47 [00:25<00:00, 1.83it/s, acc=0.742]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.842, loss=0.544]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.752]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.872, loss=0.452]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.742]\n",
            "100% 242/242 [05:45<00:00, 1.43s/it, acc=0.901, loss=0.354]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.749]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.926, loss=0.292]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.941, loss=0.236]\n",
            "100% 47/47 [00:25<00:00, 1.81it/s, acc=0.747]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.951, loss=0.195]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.748]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.962, loss=0.162]\n",
            "100% 47/47 [00:25<00:00, 1.83it/s, acc=0.746]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.969, loss=0.14]\n",
            "100% 47/47 [00:25<00:00, 1.83it/s, acc=0.744]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.974, loss=0.118]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.978, loss=0.102]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.981, loss=0.0912]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.74]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.985, loss=0.081]\n",
            "100% 47/47 [00:25<00:00, 1.83it/s, acc=0.742]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.986, loss=0.0736]\n",
            "100% 47/47 [00:25<00:00, 1.83it/s, acc=0.74]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.986, loss=0.0724]\n",
            "100% 47/47 [00:25<00:00, 1.81it/s, acc=0.742]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.99, loss=0.065]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.742]\n",
            "100% 242/242 [05:46<00:00, 1.43s/it, acc=0.99, loss=0.0624]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.739]\n",
            "100% 47/47 [00:25<00:00, 1.82it/s, acc=0.744]\n",
            "Test set results:\n",
            "Accuracy: 0.7444963308872582\n",
            "Micro precision: 0.7831612390786339\n",
            "Micro recall: 0.7919678714859437\n",
            "Micro F1: 0.7875399361022364\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": []
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%shell\n",
        "python train_supervised_bert.py \\\n",
        "    --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
        "    --dataset none \\\n",
        "    --train_file ./Perlex/Perlex_train.txt \\\n",
        "    --val_file ./Perlex/Perlex_val.txt \\\n",
|
|
||||||
" --test_file ./Perlex/Perlex_test.txt \\\n",
|
|
||||||
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
|
|
||||||
" --max_epoch 20 \\\n",
|
|
||||||
" --lr 15e-6"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"outputId": "bcdc6f55-f9b9-40b0-a661-777145267ede",
|
|
||||||
"id": "zUa7CDEeFBGJ"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"100% 242/242 [05:38<00:00, 1.40s/it, acc=0.444, loss=1.84]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.704]\n",
|
|
||||||
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.733, loss=0.842]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.753]\n",
|
|
||||||
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.871, loss=0.432]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.764]\n",
|
|
||||||
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.942, loss=0.208]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.747]\n",
|
|
||||||
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.979, loss=0.0991]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.81it/s, acc=0.75]\n",
|
|
||||||
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.991, loss=0.0488]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.81it/s, acc=0.737]\n",
|
|
||||||
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.998, loss=0.0243]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
|
|
||||||
"100% 242/242 [05:45<00:00, 1.43s/it, acc=0.999, loss=0.0157]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.747]\n",
|
|
||||||
"100% 242/242 [05:45<00:00, 1.43s/it, acc=1, loss=0.0098]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.753]\n",
|
|
||||||
"100% 242/242 [05:45<00:00, 1.43s/it, acc=0.999, loss=0.00806]\n",
|
|
||||||
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.751]\n",
|
|
||||||
" 65% 157/242 [03:45<02:01, 1.43s/it, acc=0.999, loss=0.00759]"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"source": [
|
|
||||||
"%%shell\n",
|
|
||||||
"python train_supervised_bert.py \\\n",
|
|
||||||
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
|
|
||||||
" --dataset none \\\n",
|
|
||||||
" --train_file ./Perlex/Perlex_train.txt \\\n",
|
|
||||||
" --val_file ./Perlex/Perlex_val.txt \\\n",
|
|
||||||
" --test_file ./Perlex/Perlex_test.txt \\\n",
|
|
||||||
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
|
|
||||||
" --max_epoch 20 \\\n",
|
|
||||||
" --lr 15e-6"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "spUwXoWnPchu",
|
|
||||||
"outputId": "b1b389c8-c1b5-41bf-ee81-7dfc2cb5f8a0"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"Downloading: 100% 654M/654M [00:09<00:00, 66.7MB/s]\n",
|
|
||||||
"Downloading: 100% 1.22M/1.22M [00:01<00:00, 872kB/s]\n",
|
|
||||||
"100% 242/242 [05:18<00:00, 1.32s/it, acc=0.444, loss=1.84]\n",
|
|
||||||
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.704]\n",
|
|
||||||
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.733, loss=0.842]\n",
|
|
||||||
"100% 47/47 [00:23<00:00, 1.98it/s, acc=0.753]\n",
|
|
||||||
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.871, loss=0.432]\n",
|
|
||||||
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.764]\n",
|
|
||||||
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.942, loss=0.208]\n",
|
|
||||||
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.747]\n",
|
|
||||||
"100% 242/242 [05:28<00:00, 1.36s/it, acc=0.979, loss=0.0991]\n",
|
|
||||||
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.75]\n",
|
|
||||||
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.991, loss=0.0488]\n",
|
|
||||||
"100% 47/47 [00:23<00:00, 2.00it/s, acc=0.737]\n",
|
|
||||||
" 31% 75/242 [01:43<03:47, 1.36s/it, acc=0.998, loss=0.0257]"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"source": [
|
|
||||||
"%%shell\n",
|
|
||||||
"python train_supervised_bert.py \\\n",
|
|
||||||
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
|
|
||||||
" --dataset none \\\n",
|
|
||||||
" --train_file ./Perlex/Perlex_train.txt \\\n",
|
|
||||||
" --val_file ./Perlex/Perlex_val.txt \\\n",
|
|
||||||
" --test_file ./Perlex/Perlex_test.txt \\\n",
|
|
||||||
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
|
|
||||||
" --max_epoch 20 \\\n",
|
|
||||||
" --lr 15e-6"
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"id": "_lWuRi4EuuV4",
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"executionInfo": {
|
|
||||||
"status": "ok",
|
|
||||||
"timestamp": 1668347083563,
|
|
||||||
"user_tz": -210,
|
|
||||||
"elapsed": 441647,
|
|
||||||
"user": {
|
|
||||||
"displayName": "Mohammad Ebrahimi",
|
|
||||||
"userId": "10407139745331958037"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputId": "4080fb9e-cf2c-498c-a412-a8ba1b44ad74"
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"100% 101/101 [02:00<00:00, 1.20s/it, acc=0.278, loss=2.4]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 2.06it/s, acc=0.541]\n",
|
|
||||||
"100% 101/101 [02:11<00:00, 1.30s/it, acc=0.667, loss=1.15]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.615]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.784, loss=0.721]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.99it/s, acc=0.66]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.879, loss=0.422]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 2.00it/s, acc=0.656]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.947, loss=0.209]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.659]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.981, loss=0.0976]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.655]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.99, loss=0.0528]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.655]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.997, loss=0.0308]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.96it/s, acc=0.659]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.998, loss=0.0199]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.657]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.999, loss=0.0145]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.93it/s, acc=0.659]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.0111]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.653]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00921]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.658]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00797]\n",
|
|
||||||
"100% 25/25 [00:13<00:00, 1.89it/s, acc=0.659]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00623]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.655]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00641]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.65]\n",
|
|
||||||
"100% 101/101 [02:11<00:00, 1.31s/it, acc=1, loss=0.00551]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.99it/s, acc=0.655]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00564]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.99it/s, acc=0.655]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00474]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.93it/s, acc=0.654]\n",
|
|
||||||
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.999, loss=0.00528]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.95it/s, acc=0.655]\n",
|
|
||||||
"100% 101/101 [02:11<00:00, 1.31s/it, acc=1, loss=0.00506]\n",
|
|
||||||
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.653]\n",
|
|
||||||
"100% 43/43 [00:21<00:00, 2.04it/s, acc=0.725]\n",
|
|
||||||
"Test set results:\n",
|
|
||||||
"Accuracy: 0.7245949926362297\n",
|
|
||||||
"Micro precision: 0.7602262837249782\n",
|
|
||||||
"Micro recall: 0.7723253757736517\n",
|
|
||||||
"Micro F1: 0.7662280701754387\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"output_type": "execute_result",
|
|
||||||
"data": {
|
|
||||||
"text/plain": []
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"execution_count": 3
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"source": [],
|
|
||||||
"metadata": {
|
|
||||||
"id": "K8ArBYUg1gmV"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
|
@ -1,19 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .pretrain import check_root, get_model, download, download_pretrain
import logging
import os

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=os.environ.get("LOGLEVEL", "INFO"))

def fix_seed(seed=12345):
    import torch
    import numpy as np
    import random
    torch.manual_seed(seed)  # cpu
    torch.cuda.manual_seed(seed)  # gpu
    np.random.seed(seed)  # numpy
    random.seed(seed)  # random and transforms
    torch.backends.cudnn.deterministic = True  # cudnn
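
# Usage sketch (added comment, not part of the original file): calling
# fix_seed() once before building loaders and models makes runs repeatable,
# e.g.:
#
#     import opennre
#     opennre.fix_seed(12345)
#
# Full GPU determinism may additionally require torch.backends.cudnn.benchmark = False.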
@ -1,14 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .cnn_encoder import CNNEncoder
from .pcnn_encoder import PCNNEncoder
from .bert_encoder import BERTEncoder, BERTEntityEncoder

__all__ = [
    'CNNEncoder',
    'PCNNEncoder',
    'BERTEncoder',
    'BERTEntityEncoder'
]
@ -1,154 +0,0 @@
import math, logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..tokenization import WordTokenizer

class BaseEncoder(nn.Module):

    def __init__(self,
                 token2id,
                 max_length=128,
                 hidden_size=230,
                 word_size=50,
                 position_size=5,
                 blank_padding=True,
                 word2vec=None,
                 mask_entity=False):
        """
        Args:
            token2id: dictionary of token->idx mapping
            max_length: max length of sentence, used for position embedding
            hidden_size: hidden size
            word_size: size of word embedding
            position_size: size of position embedding
            blank_padding: padding for CNN
            word2vec: pretrained word2vec numpy array
            mask_entity: if True, replace entity mentions with [UNK]
        """
        # Hyperparameters
        super().__init__()

        self.token2id = token2id
        self.max_length = max_length
        self.num_token = len(token2id)
        self.num_position = max_length * 2
        self.mask_entity = mask_entity

        if word2vec is None:
            self.word_size = word_size
        else:
            self.word_size = word2vec.shape[-1]

        self.position_size = position_size
        self.hidden_size = hidden_size
        # Use self.word_size here so a word2vec matrix whose dimension differs
        # from the word_size default still yields a consistent input size.
        self.input_size = self.word_size + position_size * 2
        self.blank_padding = blank_padding

        if '[UNK]' not in self.token2id:
            self.token2id['[UNK]'] = len(self.token2id)
            self.num_token += 1
        if '[PAD]' not in self.token2id:
            self.token2id['[PAD]'] = len(self.token2id)
            self.num_token += 1

        # Word embedding
        self.word_embedding = nn.Embedding(self.num_token, self.word_size)
        if word2vec is not None:
            logging.info("Initializing word embedding with word2vec.")
            word2vec = torch.from_numpy(word2vec)
            if self.num_token == len(word2vec) + 2:
                unk = torch.randn(1, self.word_size) / math.sqrt(self.word_size)
                blk = torch.zeros(1, self.word_size)
                self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0))
            else:
                self.word_embedding.weight.data.copy_(word2vec)

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0)
        self.tokenizer = WordTokenizer(vocab=self.token2id, unk_token="[UNK]")

    def forward(self, token, pos1, pos2):
        """
        Args:
            token: (B, L), index of tokens
            pos1: (B, L), relative position to head entity
            pos2: (B, L), relative position to tail entity
        Return:
            (B, H), representations for sentences
        """
        # Implemented by subclasses (e.g. CNNEncoder); the base class is abstract.
        pass

    def tokenize(self, item):
        """
        Args:
            item: input instance, including sentence, entity positions, etc.
        Return:
            index numbers of tokens and positions
        """
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        # Sentence -> token
        if not is_token:
            if pos_head[0] > pos_tail[0]:
                pos_min, pos_max = [pos_tail, pos_head]
                rev = True
            else:
                pos_min, pos_max = [pos_head, pos_tail]
                rev = False
            sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
            ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            if self.mask_entity:
                ent_0 = ['[UNK]']
                ent_1 = ['[UNK]']
            tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
            if rev:
                pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
                pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
            else:
                pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
                pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
        else:
            tokens = sentence

        # Token -> index
        if self.blank_padding:
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
        else:
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id=self.token2id['[UNK]'])

        # Position -> index
        pos1 = []
        pos2 = []
        pos1_in_index = min(pos_head[0], self.max_length)
        pos2_in_index = min(pos_tail[0], self.max_length)
        for i in range(len(tokens)):
            pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
            pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))

        if self.blank_padding:
            while len(pos1) < self.max_length:
                pos1.append(0)
            while len(pos2) < self.max_length:
                pos2.append(0)
            indexed_tokens = indexed_tokens[:self.max_length]
            pos1 = pos1[:self.max_length]
            pos2 = pos2[:self.max_length]

        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)
        pos1 = torch.tensor(pos1).long().unsqueeze(0)  # (1, L)
        pos2 = torch.tensor(pos2).long().unsqueeze(0)  # (1, L)

        return indexed_tokens, pos1, pos2
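
# Illustrative note (added comment, not part of the original file): positions
# are shifted by max_length so they are always valid embedding indices. With
# max_length=128 and a head entity starting at token 3, token i gets pos1
# index min(i - 3 + 128, 255): token 0 -> 125, token 3 -> 128, token 10 -> 135.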
@ -1,215 +0,0 @@
import logging
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from .base_encoder import BaseEncoder

class BERTEncoder(nn.Module):
    def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False):
        """
        Args:
            max_length: max length of sentence
            pretrain_path: path of pretrained model
            blank_padding: if True, pad / truncate sequences to max_length
            mask_entity: if True, replace entity mentions with reserved tokens
        """
        super().__init__()
        self.max_length = max_length
        self.blank_padding = blank_padding
        self.hidden_size = 768
        self.mask_entity = mask_entity
        logging.info('Loading BERT pre-trained checkpoint.')
        self.bert = BertModel.from_pretrained(pretrain_path)
        self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)

    def forward(self, token, att_mask, pos1, pos2):
        """
        Args:
            token: (B, L), index of tokens
            att_mask: (B, L), attention mask (1 for contents and 0 for padding)
        Return:
            (B, H), representations for sentences
        """
        # return_dict=False is set to adapt to newer versions of transformers,
        # which otherwise return a ModelOutput object; x is the pooled output.
        _, x = self.bert(token, attention_mask=att_mask, return_dict=False)
        return x

    def tokenize(self, item):
        """
        Args:
            item: data instance containing 'text' / 'token', 'h' and 't'
        Return:
            indexed tokens, attention mask, and the two entity start positions
        """
        # Sentence -> token
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        pos_min = pos_head
        pos_max = pos_tail
        if pos_head[0] > pos_tail[0]:
            pos_min = pos_tail
            pos_max = pos_head
            rev = True
        else:
            rev = False

        if not is_token:
            sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
        else:
            sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
            ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
            sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
            ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
            sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))

        if self.mask_entity:
            ent0 = ['[unused4]'] if not rev else ['[unused5]']
            ent1 = ['[unused5]'] if not rev else ['[unused4]']
        else:
            ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
            ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']

        re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']

        pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
        pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
        pos1 = min(self.max_length - 1, pos1)
        pos2 = min(self.max_length - 1, pos2)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
        avai_len = len(indexed_tokens)
        pos1 = torch.tensor([[pos1]]).long()
        pos2 = torch.tensor([[pos2]]).long()

        # Padding
        if self.blank_padding:
            while len(indexed_tokens) < self.max_length:
                indexed_tokens.append(0)  # 0 is the id for [PAD]
            indexed_tokens = indexed_tokens[:self.max_length]
        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)

        # Attention mask
        att_mask = torch.zeros(indexed_tokens.size()).long()  # (1, L)
        att_mask[0, :avai_len] = 1

        return indexed_tokens, att_mask, pos1, pos2
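
# Illustrative note (added comment, not part of the original file): the
# reserved [unusedN] vocabulary entries serve as entity markers. For example,
# "Bill founded Microsoft" with head "Bill" and tail "Microsoft" becomes
# roughly [CLS] [unused0] Bill [unused1] founded [unused2] Microsoft
# [unused3] [SEP], and pos1/pos2 point at the [unused0]/[unused2] markers.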

class BERTEntityEncoder(nn.Module):
    def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False):
        """
        Args:
            max_length: max length of sentence
            pretrain_path: path of pretrained model
            blank_padding: if True, pad / truncate sequences to max_length
            mask_entity: if True, replace entity mentions with reserved tokens
        """
        super().__init__()
        self.max_length = max_length
        self.blank_padding = blank_padding
        self.hidden_size = 768 * 2
        self.mask_entity = mask_entity
        logging.info('Loading BERT pre-trained checkpoint.')
        self.bert = BertModel.from_pretrained(pretrain_path)
        self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)
        self.linear = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, token, att_mask, pos1, pos2):
        """
        Args:
            token: (B, L), index of tokens
            att_mask: (B, L), attention mask (1 for contents and 0 for padding)
            pos1: (B, 1), position of the head entity starter
            pos2: (B, 1), position of the tail entity starter
        Return:
            (B, 2H), representations for sentences
        """
        hidden, _ = self.bert(token, attention_mask=att_mask, return_dict=False)
        # Get entity start hidden states
        onehot_head = torch.zeros(hidden.size()[:2]).float().to(hidden.device)  # (B, L)
        onehot_tail = torch.zeros(hidden.size()[:2]).float().to(hidden.device)  # (B, L)
        onehot_head = onehot_head.scatter_(1, pos1, 1)
        onehot_tail = onehot_tail.scatter_(1, pos2, 1)
        head_hidden = (onehot_head.unsqueeze(2) * hidden).sum(1)  # (B, H)
        tail_hidden = (onehot_tail.unsqueeze(2) * hidden).sum(1)  # (B, H)
        x = torch.cat([head_hidden, tail_hidden], 1)  # (B, 2H)
        x = self.linear(x)
        return x

    def tokenize(self, item):
        """
        Args:
            item: data instance containing 'text' / 'token', 'h' and 't'
        Return:
            indexed tokens, attention mask, and the two entity start positions
        """
        # Sentence -> token
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        pos_min = pos_head
        pos_max = pos_tail
        if pos_head[0] > pos_tail[0]:
            pos_min = pos_tail
            pos_max = pos_head
            rev = True
        else:
            rev = False

        if not is_token:
            sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
        else:
            sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
            ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
            sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
            ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
            sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))

        if self.mask_entity:
            ent0 = ['[unused4]'] if not rev else ['[unused5]']
            ent1 = ['[unused5]'] if not rev else ['[unused4]']
        else:
            ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
            ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']

        re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
        pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
        pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
        pos1 = min(self.max_length - 1, pos1)
        pos2 = min(self.max_length - 1, pos2)

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
        avai_len = len(indexed_tokens)

        # Position
        pos1 = torch.tensor([[pos1]]).long()
        pos2 = torch.tensor([[pos2]]).long()

        # Padding
        if self.blank_padding:
            while len(indexed_tokens) < self.max_length:
                indexed_tokens.append(0)  # 0 is the id for [PAD]
            indexed_tokens = indexed_tokens[:self.max_length]
        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)

        # Attention mask
        att_mask = torch.zeros(indexed_tokens.size()).long()  # (1, L)
        att_mask[0, :avai_len] = 1

        return indexed_tokens, att_mask, pos1, pos2
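
# Illustrative note (added comment, not part of the original file): the
# one-hot scatter + sum in forward() is a batched gather, i.e. head_hidden[b]
# equals hidden[b, pos1[b, 0]]; concatenating the two entity-start vectors
# yields the (B, 2H) relation representation fed to the linear layer.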
@ -1,68 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..module.nn import CNN
from ..module.pool import MaxPool
from .base_encoder import BaseEncoder

class CNNEncoder(BaseEncoder):

    def __init__(self,
                 token2id,
                 max_length=128,
                 hidden_size=230,
                 word_size=50,
                 position_size=5,
                 blank_padding=True,
                 word2vec=None,
                 kernel_size=3,
                 padding_size=1,
                 dropout=0,
                 activation_function=F.relu,
                 mask_entity=False):
        """
        Args:
            token2id: dictionary of token->idx mapping
            max_length: max length of sentence, used for position embedding
            hidden_size: hidden size
            word_size: size of word embedding
            position_size: size of position embedding
            blank_padding: padding for CNN
            word2vec: pretrained word2vec numpy array
            kernel_size: kernel size for CNN
            padding_size: padding size for CNN
        """
        # Hyperparameters
        super(CNNEncoder, self).__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity)
        self.drop = nn.Dropout(dropout)
        self.kernel_size = kernel_size
        self.padding_size = padding_size
        self.act = activation_function

        self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
        self.pool = nn.MaxPool1d(self.max_length)

    def forward(self, token, pos1, pos2):
        """
        Args:
            token: (B, L), index of tokens
            pos1: (B, L), relative position to head entity
            pos2: (B, L), relative position to tail entity
        Return:
            (B, EMBED), representations for sentences
        """
        # Check size of tensors
        if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size():
            raise Exception("Size of token, pos1 and pos2 should be (B, L)")
        x = torch.cat([self.word_embedding(token),
                       self.pos1_embedding(pos1),
                       self.pos2_embedding(pos2)], 2)  # (B, L, EMBED)
        x = x.transpose(1, 2)  # (B, EMBED, L)
        x = self.act(self.conv(x))  # (B, H, L)
        x = self.pool(x).squeeze(-1)  # (B, H)
        x = self.drop(x)
        return x

    def tokenize(self, item):
        return super().tokenize(item)
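
# Illustrative shape trace (added comment, not part of the original file),
# assuming batch size B=2, max_length L=128, word_size=50, position_size=5:
#   concatenated embeddings -> (2, 128, 60)
#   transpose(1, 2)         -> (2, 60, 128)
#   conv + activation       -> (2, 230, 128)  # padding=1 keeps length L
#   max-pool over L         -> (2, 230)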
@ -1,173 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..module.nn import CNN
from ..module.pool import MaxPool
from .base_encoder import BaseEncoder

from nltk import word_tokenize

class PCNNEncoder(BaseEncoder):

    def __init__(self,
                 token2id,
                 max_length=128,
                 hidden_size=230,
                 word_size=50,
                 position_size=5,
                 blank_padding=True,
                 word2vec=None,
                 kernel_size=3,
                 padding_size=1,
                 dropout=0.0,
                 activation_function=F.relu,
                 mask_entity=False):
        """
        Args:
            token2id: dictionary of token->idx mapping
            max_length: max length of sentence, used for position embedding
            hidden_size: hidden size
            word_size: size of word embedding
            position_size: size of position embedding
            blank_padding: padding for CNN
            word2vec: pretrained word2vec numpy array
            kernel_size: kernel size for CNN
            padding_size: padding size for CNN
        """
        # Hyperparameters
        super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity)
        self.drop = nn.Dropout(dropout)
        self.kernel_size = kernel_size
        self.padding_size = padding_size
        self.act = activation_function

        self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
        self.pool = nn.MaxPool1d(self.max_length)
        self.mask_embedding = nn.Embedding(4, 3)
        self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]))
        self.mask_embedding.weight.requires_grad = False
        self._minus = -100

        self.hidden_size *= 3

    def forward(self, token, pos1, pos2, mask):
        """
        Args:
            token: (B, L), index of tokens
            pos1: (B, L), relative position to head entity
            pos2: (B, L), relative position to tail entity
        Return:
            (B, EMBED), representations for sentences
        """
        # Check size of tensors
        if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size():
            raise Exception("Size of token, pos1 and pos2 should be (B, L)")
        x = torch.cat([self.word_embedding(token),
                       self.pos1_embedding(pos1),
                       self.pos2_embedding(pos2)], 2)  # (B, L, EMBED)
        x = x.transpose(1, 2)  # (B, EMBED, L)
        x = self.conv(x)  # (B, H, L)

        mask = 1 - self.mask_embedding(mask).transpose(1, 2)  # (B, L) -> (B, L, 3) -> (B, 3, L)
        pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :]))  # (B, H, 1)
        pool2 = self.pool(self.act(x + self._minus * mask[:, 1:2, :]))
        pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :]))
        x = torch.cat([pool1, pool2, pool3], 1)  # (B, 3H, 1)
        x = x.squeeze(2)  # (B, 3H)
        x = self.drop(x)

        return x

    def tokenize(self, item):
        """
        Args:
            item: input instance containing 'text' / 'token' and the head ('h')
                and tail ('t') entity positions
        Return:
            indexed tokens, position indices and the piecewise pooling mask
        """
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        # Sentence -> token
        if not is_token:
            if pos_head[0] > pos_tail[0]:
                pos_min, pos_max = [pos_tail, pos_head]
                rev = True
            else:
                pos_min, pos_max = [pos_head, pos_tail]
                rev = False
            sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
            ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            if self.mask_entity:
                ent_0 = ['[UNK]']
                ent_1 = ['[UNK]']
            tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
            if rev:
                pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
                pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
            else:
                pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
                pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
        else:
            tokens = sentence

        # Token -> index
        if self.blank_padding:
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
        else:
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id=self.token2id['[UNK]'])

        # Position -> index
        pos1 = []
        pos2 = []
        pos1_in_index = min(pos_head[0], self.max_length)
        pos2_in_index = min(pos_tail[0], self.max_length)
        for i in range(len(tokens)):
            pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
            pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))

        if self.blank_padding:
            while len(pos1) < self.max_length:
                pos1.append(0)
            while len(pos2) < self.max_length:
                pos2.append(0)
            indexed_tokens = indexed_tokens[:self.max_length]
            pos1 = pos1[:self.max_length]
            pos2 = pos2[:self.max_length]

        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)
        pos1 = torch.tensor(pos1).long().unsqueeze(0)  # (1, L)
        pos2 = torch.tensor(pos2).long().unsqueeze(0)  # (1, L)

        # Mask
        mask = []
        pos_min = min(pos1_in_index, pos2_in_index)
        pos_max = max(pos1_in_index, pos2_in_index)
        for i in range(len(tokens)):
            if i <= pos_min:
                mask.append(1)
            elif i <= pos_max:
                mask.append(2)
            else:
                mask.append(3)
        # Padding
        if self.blank_padding:
            while len(mask) < self.max_length:
                mask.append(0)
            mask = mask[:self.max_length]

        mask = torch.tensor(mask).long().unsqueeze(0)  # (1, L)
        return indexed_tokens, pos1, pos2, mask
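
# Illustrative note (added comment, not part of the original file): the mask
# splits each sentence into three segments around the two entity heads. With
# the head at token 2 and the tail at token 5 in an 8-token sentence, the mask
# is [1, 1, 1, 2, 2, 2, 3, 3] (padding gets 0), and forward() max-pools the
# convolution output over each segment separately before concatenating.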
@ -1,20 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .data_loader import SentenceREDataset, SentenceRELoader, BagREDataset, BagRELoader, MultiLabelSentenceREDataset, MultiLabelSentenceRELoader
from .sentence_re import SentenceRE
from .bag_re import BagRE
from .multi_label_sentence_re import MultiLabelSentenceRE

__all__ = [
    'SentenceREDataset',
    'SentenceRELoader',
    'SentenceRE',
    'BagRE',
    'BagREDataset',
    'BagRELoader',
    'MultiLabelSentenceREDataset',
    'MultiLabelSentenceRELoader',
    'MultiLabelSentenceRE'
]
@ -1,184 +0,0 @@
import torch
from torch import nn, optim
import json
from .data_loader import SentenceRELoader, BagRELoader
from .utils import AverageMeter
from tqdm import tqdm
import os

class BagRE(nn.Module):

    def __init__(self,
                 model,
                 train_path,
                 val_path,
                 test_path,
                 ckpt,
                 batch_size=32,
                 max_epoch=100,
                 lr=0.1,
                 weight_decay=1e-5,
                 opt='sgd',
                 bag_size=0,
                 loss_weight=False):

        super().__init__()
        self.max_epoch = max_epoch
        self.bag_size = bag_size
        # Load data
        if train_path is not None:
            self.train_loader = BagRELoader(
                train_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                True,
                bag_size=bag_size,
                entpair_as_bag=False)

        if val_path is not None:
            self.val_loader = BagRELoader(
                val_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                False,
                bag_size=bag_size,
                entpair_as_bag=True)

        if test_path is not None:
            self.test_loader = BagRELoader(
                test_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                False,
                bag_size=bag_size,
                entpair_as_bag=True
            )
        # Model
        self.model = nn.DataParallel(model)
        # Criterion
        if loss_weight:
            self.criterion = nn.CrossEntropyLoss(weight=self.train_loader.dataset.weight)
        else:
            self.criterion = nn.CrossEntropyLoss()
        # Params and optimizer
        params = self.model.parameters()
        self.lr = lr
        if opt == 'sgd':
            self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
        elif opt == 'adam':
            self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
        elif opt == 'adamw':
            from transformers import AdamW
            params = list(self.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            grouped_params = [
                {
                    'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
                    'weight_decay': 0.01,
                    'lr': lr,
                    'ori_lr': lr
                },
                {
                    'params': [p for n, p in params if any(nd in n for nd in no_decay)],
                    'weight_decay': 0.0,
                    'lr': lr,
                    'ori_lr': lr
                }
            ]
            self.optimizer = AdamW(grouped_params, correct_bias=False)
        else:
            raise Exception("Invalid optimizer. Must be 'sgd', 'adam' or 'adamw'.")
        # Cuda
        if torch.cuda.is_available():
            self.cuda()
        # Ckpt
        self.ckpt = ckpt

    def train_model(self, metric='auc'):
        best_metric = 0
        for epoch in range(self.max_epoch):
            # Train
            self.train()
            print("=== Epoch %d train ===" % epoch)
            avg_loss = AverageMeter()
            avg_acc = AverageMeter()
            avg_pos_acc = AverageMeter()
            t = tqdm(self.train_loader)
            for iter, data in enumerate(t):
                if torch.cuda.is_available():
                    for i in range(len(data)):
                        try:
                            data[i] = data[i].cuda()
                        except:
                            pass  # non-tensor fields (e.g. bag names) stay on CPU
                label = data[0]
                bag_name = data[1]
                scope = data[2]
                args = data[3:]
                logits = self.model(label, scope, *args, bag_size=self.bag_size)
                loss = self.criterion(logits, label)
                score, pred = logits.max(-1)  # (B)
                acc = float((pred == label).long().sum()) / label.size(0)
                pos_total = (label != 0).long().sum()
                pos_correct = ((pred == label).long() * (label != 0).long()).sum()
                if pos_total > 0:
                    pos_acc = float(pos_correct) / float(pos_total)
                else:
                    pos_acc = 0

                # Log
                avg_loss.update(loss.item(), 1)
                avg_acc.update(acc, 1)
                avg_pos_acc.update(pos_acc, 1)
                t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg, pos_acc=avg_pos_acc.avg)

                # Optimize
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Val
            print("=== Epoch %d val ===" % epoch)
            result = self.eval_model(self.val_loader)
            print("AUC: %.4f" % result['auc'])
            print("Micro F1: %.4f" % (result['max_micro_f1']))
            if result[metric] > best_metric:
                print("Best ckpt and saved.")
                torch.save({'state_dict': self.model.module.state_dict()}, self.ckpt)
                best_metric = result[metric]
        print("Best %s on val set: %f" % (metric, best_metric))

    def eval_model(self, eval_loader):
        self.model.eval()
        with torch.no_grad():
            t = tqdm(eval_loader)
            pred_result = []
            for iter, data in enumerate(t):
                if torch.cuda.is_available():
                    for i in range(len(data)):
                        try:
                            data[i] = data[i].cuda()
                        except:
                            pass  # non-tensor fields stay on CPU
                label = data[0]
                bag_name = data[1]
                scope = data[2]
                args = data[3:]
                logits = self.model(None, scope, *args, train=False, bag_size=self.bag_size)  # results after softmax
                logits = logits.cpu().numpy()
                for i in range(len(logits)):
                    for relid in range(self.model.module.num_class):
                        if self.model.module.id2rel[relid] != 'NA':
                            pred_result.append({
                                'entpair': bag_name[i][:2],
                                'relation': self.model.module.id2rel[relid],
                                'score': logits[i][relid]
                            })
            result = eval_loader.dataset.eval(pred_result)
        return result

    def load_state_dict(self, state_dict):
        self.model.module.load_state_dict(state_dict)
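
# Usage sketch (hypothetical, not part of the original file): BagRE wraps a
# bag-level model and its loaders; the paths and checkpoint name below are
# placeholders.
#
#     framework = BagRE(model, 'train.txt', 'val.txt', 'test.txt',
#                       ckpt='ckpt/bag.pth.tar', batch_size=32, opt='adamw')
#     framework.train_model(metric='auc')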
@ -1,459 +0,0 @@
import torch
import torch.utils.data as data
import os, random, json, logging
import numpy as np
import sklearn.metrics

class SentenceREDataset(data.Dataset):
    """
    Sentence-level relation extraction dataset
    """
    def __init__(self, path, rel2id, tokenizer, kwargs):
        """
        Args:
            path: path of the input file
            rel2id: dictionary of relation->id mapping
            tokenizer: function of tokenizing
        """
        super().__init__()
        self.path = path
        self.tokenizer = tokenizer
        self.rel2id = rel2id
        self.kwargs = kwargs

        # Load the file (one dict-like record per line)
        f = open(path)
        self.data = []
        for line in f.readlines():
            line = line.rstrip()
            if len(line) > 0:
                self.data.append(eval(line))  # each line is a Python/JSON dict literal
        f.close()
        logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        seq = list(self.tokenizer(item, **self.kwargs))
        return [self.rel2id[item['relation']]] + seq  # label, seq1, seq2, ...

    def collate_fn(data):
        data = list(zip(*data))
        labels = data[0]
        seqs = data[1:]
        batch_labels = torch.tensor(labels).long()  # (B)
        batch_seqs = []
        for seq in seqs:
            batch_seqs.append(torch.cat(seq, 0))  # (B, L)
        return [batch_labels] + batch_seqs

    def eval(self, pred_result, use_name=False):
        """
        Args:
            pred_result: a list of predicted labels (ids).
                Make sure that the `shuffle` param is set to `False` when getting the loader.
            use_name: if True, `pred_result` contains predicted relation names instead of ids
        Return:
            {'acc': xx}
        """
        correct = 0
        total = len(self.data)
        correct_positive = 0
        pred_positive = 0
        gold_positive = 0
        neg = -1
        for name in ['NA', 'na', 'no_relation', 'Other', 'Others']:
            if name in self.rel2id:
                if use_name:
                    neg = name
                else:
                    neg = self.rel2id[name]
                break
        for i in range(total):
            if use_name:
                golden = self.data[i]['relation']
            else:
                golden = self.rel2id[self.data[i]['relation']]
            if golden == pred_result[i]:
                correct += 1
                if golden != neg:
                    correct_positive += 1
            if golden != neg:
                gold_positive += 1
            if pred_result[i] != neg:
                pred_positive += 1
        acc = float(correct) / float(total)
        try:
            micro_p = float(correct_positive) / float(pred_positive)
        except ZeroDivisionError:
            micro_p = 0
        try:
            micro_r = float(correct_positive) / float(gold_positive)
        except ZeroDivisionError:
            micro_r = 0
        try:
            micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
        except ZeroDivisionError:
            micro_f1 = 0
        result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}
        logging.info('Evaluation result: {}.'.format(result))
        return result
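
# Worked example (added comment, not part of the original file): micro
# precision/recall only count non-negative relations. With 10 instances,
# 4 predicted positive, 5 gold positive and 3 positives predicted correctly:
# micro_p = 3/4 = 0.75, micro_r = 3/5 = 0.6,
# micro_f1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) = 2/3.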
|
|
||||||
def SentenceRELoader(path, rel2id, tokenizer, batch_size,
|
|
||||||
shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs):
|
|
||||||
dataset = SentenceREDataset(path = path, rel2id = rel2id, tokenizer = tokenizer, kwargs=kwargs)
|
|
||||||
data_loader = data.DataLoader(dataset=dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=shuffle,
|
|
||||||
pin_memory=True,
|
|
||||||
num_workers=num_workers,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
return data_loader
|
|
||||||
|
|
||||||
class BagREDataset(data.Dataset):
|
|
||||||
"""
|
|
||||||
Bag-level relation extraction dataset. Note that relation of NA should be named as 'NA'.
|
|
||||||
"""
|
|
||||||
def __init__(self, path, rel2id, tokenizer, entpair_as_bag=False, bag_size=0, mode=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
path: path of the input file
|
|
||||||
rel2id: dictionary of relation->id mapping
|
|
||||||
tokenizer: function of tokenizing
|
|
||||||
entpair_as_bag: if True, bags are constructed based on same
|
|
||||||
entity pairs instead of same relation facts (ignoring
|
|
||||||
relation labels)
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.rel2id = rel2id
|
|
||||||
self.entpair_as_bag = entpair_as_bag
|
|
||||||
self.bag_size = bag_size
|
|
||||||
|
|
||||||
# Load the file
|
|
||||||
f = open(path)
|
|
||||||
self.data = []
|
|
||||||
for line in f:
|
|
||||||
line = line.rstrip()
|
|
||||||
if len(line) > 0:
|
|
||||||
self.data.append(eval(line))
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
# Construct bag-level dataset (a bag contains instances sharing the same relation fact)
|
|
||||||
if mode == None:
|
|
||||||
self.weight = np.ones((len(self.rel2id)), dtype=np.float32)
|
|
||||||
self.bag_scope = []
|
|
||||||
self.name2id = {}
|
|
||||||
self.bag_name = []
|
|
||||||
self.facts = {}
|
|
||||||
for idx, item in enumerate(self.data):
|
|
||||||
# Annotated test set
|
|
||||||
if 'anno_relation_list' in item:
|
|
||||||
for r in item['anno_relation_list']:
|
|
||||||
fact = (item['h']['id'], item['t']['id'], r)
|
|
||||||
if r != 'NA':
|
|
||||||
self.facts[fact] = 1
|
|
||||||
assert entpair_as_bag
|
|
||||||
name = (item['h']['id'], item['t']['id'])
|
|
||||||
else:
|
|
||||||
fact = (item['h']['id'], item['t']['id'], item['relation'])
|
|
||||||
if item['relation'] != 'NA':
|
|
||||||
self.facts[fact] = 1
|
|
||||||
if entpair_as_bag:
|
|
||||||
name = (item['h']['id'], item['t']['id'])
|
|
||||||
else:
|
|
||||||
name = fact
|
|
||||||
if name not in self.name2id:
|
|
||||||
self.name2id[name] = len(self.name2id)
|
|
||||||
self.bag_scope.append([])
|
|
||||||
self.bag_name.append(name)
|
|
||||||
self.bag_scope[self.name2id[name]].append(idx)
|
|
||||||
self.weight[self.rel2id[item['relation']]] += 1.0
|
|
||||||
self.weight = 1.0 / (self.weight ** 0.05)
|
|
||||||
self.weight = torch.from_numpy(self.weight)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.bag_scope)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
bag = self.bag_scope[index]
|
|
||||||
if self.bag_size > 0:
|
|
||||||
if self.bag_size <= len(bag):
|
|
||||||
resize_bag = random.sample(bag, self.bag_size)
|
|
||||||
else:
|
|
||||||
resize_bag = bag + list(np.random.choice(bag, self.bag_size - len(bag)))
|
|
||||||
bag = resize_bag
|
|
||||||
|
|
||||||
seqs = None
|
|
||||||
rel = self.rel2id[self.data[bag[0]]['relation']]
|
|
||||||
for sent_id in bag:
|
|
||||||
item = self.data[sent_id]
|
|
||||||
seq = list(self.tokenizer(item))
|
|
||||||
if seqs is None:
|
|
||||||
seqs = []
|
|
||||||
for i in range(len(seq)):
|
|
||||||
seqs.append([])
|
|
||||||
for i in range(len(seq)):
|
|
||||||
seqs[i].append(seq[i])
|
|
||||||
for i in range(len(seqs)):
|
|
||||||
seqs[i] = torch.cat(seqs[i], 0) # (n, L), n is the size of bag
|
|
||||||
return [rel, self.bag_name[index], len(bag)] + seqs
|
|
||||||
|
|
||||||
def collate_fn(data):
|
|
||||||
data = list(zip(*data))
|
|
||||||
label, bag_name, count = data[:3]
|
|
||||||
seqs = data[3:]
|
|
||||||
for i in range(len(seqs)):
|
|
||||||
seqs[i] = torch.cat(seqs[i], 0) # (sumn, L)
|
|
||||||
seqs[i] = seqs[i].expand((torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1, ) + seqs[i].size())
|
|
||||||
scope = [] # (B, 2)
|
|
||||||
start = 0
|
|
||||||
for c in count:
|
|
||||||
scope.append((start, start + c))
|
|
||||||
start += c
|
|
||||||
assert(start == seqs[0].size(1))
|
|
||||||
scope = torch.tensor(scope).long()
|
|
||||||
label = torch.tensor(label).long() # (B)
|
|
||||||
return [label, bag_name, scope] + seqs
|
|
||||||
|
|
||||||
def collate_bag_size_fn(data):
|
|
||||||
data = list(zip(*data))
|
|
||||||
label, bag_name, count = data[:3]
|
|
||||||
seqs = data[3:]
|
|
||||||
for i in range(len(seqs)):
|
|
||||||
seqs[i] = torch.stack(seqs[i], 0) # (batch, bag, L)
|
|
||||||
scope = [] # (B, 2)
|
|
||||||
start = 0
|
|
||||||
for c in count:
|
|
||||||
scope.append((start, start + c))
|
|
||||||
start += c
|
|
||||||
label = torch.tensor(label).long() # (B)
|
|
||||||
return [label, bag_name, scope] + seqs

    def eval(self, pred_result, threshold=0.5):
        """
        Args:
            pred_result: a list of dicts {'entpair': (head_id, tail_id), 'relation': rel, 'score': score}.
                Note that the NA relation should be excluded.
        Return:
            a dict with the precision-recall curve ('np_prec', 'np_rec'), AUC,
            P@{100,200,300}, and micro/macro F1 summaries. prec (precision) and
            rec (recall) are micro-averaged and sorted in decreasing order of
            score; max_micro_f1 is the maximum F1 over those precision-recall
            points.
        """
        sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True)
        prec = []
        rec = []
        correct = 0
        total = len(self.facts)

        entpair = {}

        for i, item in enumerate(sorted_pred_result):
            # Save entpair label and result for later calculating F1
            idtf = item['entpair'][0] + '#' + item['entpair'][1]
            if idtf not in entpair:
                entpair[idtf] = {
                    'label': np.zeros((len(self.rel2id)), dtype=np.int64),
                    'pred': np.zeros((len(self.rel2id)), dtype=np.int64),
                    'score': np.zeros((len(self.rel2id)), dtype=np.float64)
                }
            if (item['entpair'][0], item['entpair'][1], item['relation']) in self.facts:
                correct += 1
                entpair[idtf]['label'][self.rel2id[item['relation']]] = 1
            if item['score'] >= threshold:
                entpair[idtf]['pred'][self.rel2id[item['relation']]] = 1
            entpair[idtf]['score'][self.rel2id[item['relation']]] = item['score']

            prec.append(float(correct) / float(i + 1))
            rec.append(float(correct) / float(total))

        auc = sklearn.metrics.auc(x=rec, y=prec)
        np_prec = np.array(prec)
        np_rec = np.array(rec)
        max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
        best_threshold = sorted_pred_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score']
        mean_prec = np_prec.mean()

        label_vec = []
        pred_result_vec = []
        score_vec = []
        for ep in entpair:
            label_vec.append(entpair[ep]['label'])
            pred_result_vec.append(entpair[ep]['pred'])
            score_vec.append(entpair[ep]['score'])
        label_vec = np.stack(label_vec, 0)
        pred_result_vec = np.stack(pred_result_vec, 0)
        score_vec = np.stack(score_vec, 0)

        # Relation 0 (NA) is excluded from all averaged metrics below.
        micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
        micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
        micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')

        macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
        macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
        macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')

        pred_result_vec = score_vec >= best_threshold
        max_macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
        max_micro_f1_each_relation = {}
        for rel in self.rel2id:
            if rel != 'NA':
                max_micro_f1_each_relation[rel] = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=[self.rel2id[rel]], average='micro')

        return {'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_macro_f1': max_macro_f1, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299], 'avg_p300': (np_prec[99] + np_prec[199] + np_prec[299]) / 3, 'micro_f1': micro_f1, 'macro_f1': macro_f1, 'max_micro_f1_each_relation': max_micro_f1_each_relation}
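
# A minimal, self-contained sketch (not part of the original API) of how the
# curve-based metrics above are derived: precision/recall are accumulated over
# predictions sorted by score, and max F1 is taken over all points on the curve.
def _max_micro_f1_demo():
    import numpy as np
    # Hypothetical ranking: hits at ranks 1, 2 and 4; total = 3 gold facts.
    hits = np.array([1, 1, 0, 1])
    prec = np.cumsum(hits) / (np.arange(len(hits)) + 1)  # [1.0, 1.0, 0.667, 0.75]
    rec = np.cumsum(hits) / 3.0                          # [0.333, 0.667, 0.667, 1.0]
    f1 = 2 * prec * rec / (prec + rec + 1e-20)
    return f1.max()  # ~0.857, attained at the fourth point of the curve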


def BagRELoader(path, rel2id, tokenizer, batch_size,
        shuffle, entpair_as_bag=False, bag_size=0, num_workers=8,
        collate_fn=BagREDataset.collate_fn):
    # The collate_fn argument is always reselected below, depending on whether
    # bags are sampled/padded to a fixed size.
    if bag_size == 0:
        collate_fn = BagREDataset.collate_fn
    else:
        collate_fn = BagREDataset.collate_bag_size_fn
    dataset = BagREDataset(path, rel2id, tokenizer, entpair_as_bag=entpair_as_bag, bag_size=bag_size)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=shuffle,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    return data_loader
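
# Usage sketch (illustrative only; the file path and tokenizer below are
# assumptions, not part of this module): each batch yields
# [label, bag_name, scope, seq1, ...], where scope[i] gives the row range of
# bag i inside the flattened sequence tensors.
#
#     loader = BagRELoader('nyt10_train.txt', rel2id, my_encoder.tokenize,
#                          batch_size=32, shuffle=True, bag_size=0)
#     label, bag_name, scope, *seqs = next(iter(loader))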


class MultiLabelSentenceREDataset(data.Dataset):
    """
    Sentence-level relation extraction dataset (multi-label variant)
    """
    def __init__(self, path, rel2id, tokenizer, kwargs):
        """
        Args:
            path: path of the input file
            rel2id: dictionary of relation -> id mapping
            tokenizer: function of tokenizing
        """
        super().__init__()
        self.path = path
        self.tokenizer = tokenizer
        self.rel2id = rel2id
        self.kwargs = kwargs

        # Load the file: one JSON object per non-empty line.
        self.data = []
        with open(path) as f:
            for line in f:
                line = line.rstrip()
                if len(line) > 0:
                    self.data.append(json.loads(line))
        logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        seq = list(self.tokenizer(item, **self.kwargs))
        return [self.rel2id[item['relation']]] + seq  # label, seq1, seq2, ...

    def collate_fn(data):
        data = list(zip(*data))
        labels = data[0]
        seqs = data[1:]
        batch_labels = torch.tensor(labels).long()  # (B)
        batch_seqs = []
        for seq in seqs:
            batch_seqs.append(torch.cat(seq, 0))  # (B, L)
        return [batch_labels] + batch_seqs

    def eval(self, pred_score, threshold=0.5, use_name=False):
        """
        Args:
            pred_score: [sent_num, label_num]
            use_name: if True, `pred_result` contains predicted relation names instead of ids (currently unused)
        Return:
            a dict with accuracy, micro/macro precision/recall/F1, the
            precision-recall curve, AUC, and P@{100,200,300}
        """
        assert len(self.data) == len(pred_score)
        pred_score = np.array(pred_score)

        # Calculate AUC
        sorted_result = []
        total = 0
        for sent_id in range(len(self.data)):
            for rel in self.rel2id:
                if rel not in ['NA', 'na', 'N/A', 'None', 'none', 'n/a', 'no_relation']:
                    sorted_result.append({'sent_id': sent_id, 'relation': rel, 'score': pred_score[sent_id][self.rel2id[rel]]})
                    if 'anno_relation_list' in self.data[sent_id]:
                        if rel in self.data[sent_id]['anno_relation_list']:
                            total += 1
                    else:
                        if rel == self.data[sent_id]['relation']:
                            total += 1

        sorted_result.sort(key=lambda x: x['score'], reverse=True)
        prec = []
        rec = []
        correct = 0
        for i, item in enumerate(sorted_result):
            if 'anno_relation_list' in self.data[item['sent_id']]:
                if item['relation'] in self.data[item['sent_id']]['anno_relation_list']:
                    correct += 1
            else:
                if item['relation'] == self.data[item['sent_id']]['relation']:
                    correct += 1
            prec.append(float(correct) / float(i + 1))
            rec.append(float(correct) / float(total))
        auc = sklearn.metrics.auc(x=rec, y=prec)
        np_prec = np.array(prec)
        np_rec = np.array(rec)
        max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
        max_micro_f1_threshold = sorted_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score']
        mean_prec = np_prec.mean()

        # Calculate F1
        pred_result_vec = np.zeros((len(self.data), len(self.rel2id)), dtype=np.int64)
        pred_result_vec[pred_score >= threshold] = 1
        label_vec = []
        for item in self.data:
            if 'anno_relation_list' in item:
                label_vec.append(np.array(item['anno_relation_vec'], dtype=np.int64))
            else:
                one_hot = np.zeros((len(self.rel2id)), dtype=np.int64)
                one_hot[self.rel2id[item['relation']]] = 1
                label_vec.append(one_hot)
        label_vec = np.stack(label_vec, 0)
        assert label_vec.shape == pred_result_vec.shape

        micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
        micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
        micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')

        macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
        macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
        macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')

        acc = (label_vec == pred_result_vec).mean()

        result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1, 'macro_p': macro_p, 'macro_r': macro_r, 'macro_f1': macro_f1, 'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_micro_f1_threshold': max_micro_f1_threshold, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299]}
        logging.info('Evaluation result: {}.'.format(result))
        return result

def MultiLabelSentenceRELoader(path, rel2id, tokenizer, batch_size,
        shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs):
    dataset = MultiLabelSentenceREDataset(path=path, rel2id=rel2id, tokenizer=tokenizer, kwargs=kwargs)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=shuffle,
                                  pin_memory=True,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    return data_loader
@ -1,175 +0,0 @@

import os, logging, json
from tqdm import tqdm
import torch
from torch import nn, optim
from .data_loader import MultiLabelSentenceRELoader
from .utils import AverageMeter
import numpy as np

class MultiLabelSentenceRE(nn.Module):

    def __init__(self,
                 model,
                 train_path,
                 val_path,
                 test_path,
                 ckpt,
                 batch_size=32,
                 max_epoch=100,
                 lr=0.1,
                 weight_decay=1e-5,
                 warmup_step=300,
                 opt='sgd'):

        super().__init__()
        self.max_epoch = max_epoch
        # Load data
        if train_path is not None:
            self.train_loader = MultiLabelSentenceRELoader(
                train_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                True)

        if val_path is not None:
            self.val_loader = MultiLabelSentenceRELoader(
                val_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                False)

        if test_path is not None:
            self.test_loader = MultiLabelSentenceRELoader(
                test_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                False)
        # Model
        self.model = model
        self.parallel_model = nn.DataParallel(self.model)
        # Criterion: one independent binary decision per relation
        self.criterion = nn.BCEWithLogitsLoss()
        # Params and optimizer
        params = self.parameters()
        self.lr = lr
        if opt == 'sgd':
            self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
        elif opt == 'adam':
            self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
        elif opt == 'adamw':  # Optimizer for BERT
            from transformers import AdamW
            params = list(self.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            grouped_params = [
                {
                    'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
                    'weight_decay': 0.01,
                    'lr': lr,
                    'ori_lr': lr
                },
                {
                    'params': [p for n, p in params if any(nd in n for nd in no_decay)],
                    'weight_decay': 0.0,
                    'lr': lr,
                    'ori_lr': lr
                }
            ]
            self.optimizer = AdamW(grouped_params, correct_bias=False)
        else:
            raise ValueError("Invalid optimizer. Must be 'sgd', 'adam' or 'adamw'.")
        # Warmup
        if warmup_step > 0:
            from transformers import get_linear_schedule_with_warmup
            training_steps = len(self.train_loader.dataset) // batch_size * self.max_epoch
            self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps)
        else:
            self.scheduler = None
        # Cuda
        if torch.cuda.is_available():
            self.cuda()
        # Ckpt
        self.ckpt = ckpt

    def train_model(self, metric='acc'):
        best_metric = 0
        global_step = 0
        for epoch in range(self.max_epoch):
            self.train()
            logging.info("=== Epoch %d train ===" % epoch)
            avg_loss = AverageMeter()
            avg_acc = AverageMeter()
            t = tqdm(self.train_loader)
            for step, data in enumerate(t):
                if torch.cuda.is_available():
                    for i in range(len(data)):
                        # Only tensors are moved; bag names etc. stay on CPU.
                        if isinstance(data[i], torch.Tensor):
                            data[i] = data[i].cuda()
                label = data[0]
                args = data[1:]
                logits = self.parallel_model(*args)

                # Build the multi-hot target on the same device as the logits,
                # then drop column 0 (NA) so only real relations contribute to
                # the binary cross-entropy.
                label_vec = torch.zeros_like(logits)
                label_vec[torch.arange(label_vec.size(0)), label] = 1
                label_vec = label_vec[:, 1:]
                logits = logits[:, 1:]

                loss = self.criterion(logits.reshape(-1), label_vec.reshape(-1))
                pred = (torch.sigmoid(logits) >= 0.5).long()
                acc = float((pred == label_vec).long().sum()) / (label_vec.size(0) * label_vec.size(1))

                # Log
                avg_loss.update(loss.item(), 1)
                avg_acc.update(acc, 1)
                t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg)
                # Optimize
                loss.backward()
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                global_step += 1
            # Val
            logging.info("=== Epoch %d val ===" % epoch)
            result = self.eval_model(self.val_loader)
            logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric))
            if result[metric] > best_metric:
                logging.info("Best ckpt and saved.")
                folder_path = os.path.dirname(self.ckpt)
                if folder_path and not os.path.exists(folder_path):
                    os.makedirs(folder_path, exist_ok=True)
                torch.save({'state_dict': self.model.state_dict()}, self.ckpt)
                best_metric = result[metric]
        logging.info("Best %s on val set: %f" % (metric, best_metric))

    def eval_model(self, eval_loader):
        self.eval()
        pred_score = []
        with torch.no_grad():
            t = tqdm(eval_loader)
            for step, data in enumerate(t):
                if torch.cuda.is_available():
                    for i in range(len(data)):
                        if isinstance(data[i], torch.Tensor):
                            data[i] = data[i].cuda()
                label = data[0]
                args = data[1:]
                logits = self.parallel_model(*args)
                score = self.parallel_model.module.logit_to_score(logits).cpu().numpy()
                # Save result
                pred_score.append(score)
        pred_score = np.concatenate(pred_score, 0)
        result = eval_loader.dataset.eval(pred_score)
        return result

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
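
# A minimal sketch (illustrative, not part of the original trainer) of the
# NA-masking step used in train_model above: the single gold label id is
# expanded to a multi-hot vector and column 0 (NA) is dropped, so "no relation"
# is represented by all remaining targets being zero.
#
#     import torch
#     logits = torch.randn(2, 4)            # batch of 2, relations {NA, r1, r2, r3}
#     label = torch.tensor([0, 2])          # first sample is NA, second is r2
#     target = torch.zeros_like(logits)
#     target[torch.arange(2), label] = 1
#     target, logits = target[:, 1:], logits[:, 1:]
#     loss = torch.nn.BCEWithLogitsLoss()(logits.reshape(-1), target.reshape(-1))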
@ -1,171 +0,0 @@

import os, logging, json
from tqdm import tqdm
import torch
from torch import nn, optim
from .data_loader import SentenceRELoader
from .utils import AverageMeter

class SentenceRE(nn.Module):

    def __init__(self,
                 model,
                 train_path,
                 val_path,
                 test_path,
                 ckpt,
                 batch_size=32,
                 max_epoch=100,
                 lr=0.1,
                 weight_decay=1e-5,
                 warmup_step=300,
                 opt='sgd'):

        super().__init__()
        self.max_epoch = max_epoch
        # Load data
        if train_path is not None:
            self.train_loader = SentenceRELoader(
                train_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                True)

        if val_path is not None:
            self.val_loader = SentenceRELoader(
                val_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                False)

        if test_path is not None:
            self.test_loader = SentenceRELoader(
                test_path,
                model.rel2id,
                model.sentence_encoder.tokenize,
                batch_size,
                False)
        # Model
        self.model = model
        self.parallel_model = nn.DataParallel(self.model)
        # Criterion
        self.criterion = nn.CrossEntropyLoss()
        # Params and optimizer
        params = self.parameters()
        self.lr = lr
        if opt == 'sgd':
            self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
        elif opt == 'adam':
            self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
        elif opt == 'adamw':  # Optimizer for BERT
            from transformers import AdamW
            params = list(self.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            grouped_params = [
                {
                    'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
                    'weight_decay': 0.01,
                    'lr': lr,
                    'ori_lr': lr
                },
                {
                    'params': [p for n, p in params if any(nd in n for nd in no_decay)],
                    'weight_decay': 0.0,
                    'lr': lr,
                    'ori_lr': lr
                }
            ]
            self.optimizer = AdamW(grouped_params, correct_bias=False)
        else:
            raise ValueError("Invalid optimizer. Must be 'sgd', 'adam' or 'adamw'.")
        # Warmup
        if warmup_step > 0:
            from transformers import get_linear_schedule_with_warmup
            training_steps = len(self.train_loader.dataset) // batch_size * self.max_epoch
            self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps)
        else:
            self.scheduler = None
        # Cuda
        if torch.cuda.is_available():
            self.cuda()
        # Ckpt
        self.ckpt = ckpt

    def train_model(self, metric='acc'):
        best_metric = 0
        global_step = 0
        for epoch in range(self.max_epoch):
            self.train()
            logging.info("=== Epoch %d train ===" % epoch)
            avg_loss = AverageMeter()
            avg_acc = AverageMeter()
            t = tqdm(self.train_loader)
            for step, data in enumerate(t):
                if torch.cuda.is_available():
                    for i in range(len(data)):
                        if isinstance(data[i], torch.Tensor):
                            data[i] = data[i].cuda()
                label = data[0]
                args = data[1:]
                logits = self.parallel_model(*args)
                loss = self.criterion(logits, label)
                score, pred = logits.max(-1)  # (B)
                acc = float((pred == label).long().sum()) / label.size(0)
                # Log
                avg_loss.update(loss.item(), 1)
                avg_acc.update(acc, 1)
                t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg)
                # Optimize
                loss.backward()
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                global_step += 1
            # Val
            logging.info("=== Epoch %d val ===" % epoch)
            result = self.eval_model(self.val_loader)
            logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric))
            if result[metric] > best_metric:
                logging.info("Best ckpt and saved.")
                folder_path = os.path.dirname(self.ckpt)
                if folder_path and not os.path.exists(folder_path):
                    os.makedirs(folder_path, exist_ok=True)
                torch.save({'state_dict': self.model.state_dict()}, self.ckpt)
                best_metric = result[metric]
        logging.info("Best %s on val set: %f" % (metric, best_metric))

    def eval_model(self, eval_loader):
        self.eval()
        avg_acc = AverageMeter()
        pred_result = []
        with torch.no_grad():
            t = tqdm(eval_loader)
            for step, data in enumerate(t):
                if torch.cuda.is_available():
                    for i in range(len(data)):
                        if isinstance(data[i], torch.Tensor):
                            data[i] = data[i].cuda()
                label = data[0]
                args = data[1:]
                logits = self.parallel_model(*args)
                score, pred = logits.max(-1)  # (B)
                # Save result
                for i in range(pred.size(0)):
                    pred_result.append(pred[i].item())
                # Log
                acc = float((pred == label).long().sum()) / label.size(0)
                avg_acc.update(acc, pred.size(0))
                t.set_postfix(acc=avg_acc.avg)
        result = eval_loader.dataset.eval(pred_result)
        return result

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
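
# Usage sketch (illustrative only; the encoder, model class and paths below are
# assumptions, not part of this module):
#
#     model = SoftmaxNN(my_cnn_encoder, len(rel2id), rel2id)
#     framework = SentenceRE(model, 'train.txt', 'val.txt', 'test.txt',
#                            ckpt='ckpt/sentence_re.pth', opt='adamw', lr=2e-5)
#     framework.train_model(metric='acc')
#     framework.load_state_dict(torch.load('ckpt/sentence_re.pth')['state_dict'])
#     result = framework.eval_model(framework.test_loader)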
@ -1,29 +0,0 @@

class AverageMeter(object):
    """
    Computes and stores the average and current value of metrics.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=0):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / (0.0001 + self.count)

    def __str__(self):
        """
        String representation for logging
        """
        # for values that should be recorded exactly, e.g. the iteration number
        if self.count == 0:
            return str(self.val)
        # for stats
        return '%.4f (%.4f)' % (self.val, self.avg)
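
# Usage sketch (illustrative only): update() with n > 0 accumulates a running
# average; update() with the default n=0 just records the raw value, which is
# what __str__ falls back to when count == 0.
#
#     meter = AverageMeter()
#     for batch_loss in [0.9, 0.7, 0.5]:
#         meter.update(batch_loss, 1)
#     print(meter)  # "0.5000 (0.7000)": current value, then running average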
@ -1,21 +0,0 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .base_model import SentenceRE, BagRE, FewShotRE, NER
from .softmax_nn import SoftmaxNN
from .sigmoid_nn import SigmoidNN
from .bag_attention import BagAttention
from .bag_average import BagAverage
from .bag_one import BagOne

__all__ = [
    'SentenceRE',
    'BagRE',
    'FewShotRE',
    'NER',
    'SoftmaxNN',
    'SigmoidNN',
    'BagAttention',
    'BagAverage',
    'BagOne'
]
@ -1,182 +0,0 @@

import torch
from torch import nn, optim
from .base_model import BagRE

class BagAttention(BagRE):
    """
    Instance-level attention for bag-level relation extraction.
    """

    def __init__(self, sentence_encoder, num_class, rel2id, use_diag=True):
        """
        Args:
            sentence_encoder: encoder for sentences
            num_class: number of classes
            rel2id: dictionary of relation -> id mapping
        """
        super().__init__()
        self.sentence_encoder = sentence_encoder
        self.num_class = num_class
        self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
        self.softmax = nn.Softmax(-1)
        self.rel2id = rel2id
        self.id2rel = {}
        self.drop = nn.Dropout()
        for rel, id in rel2id.items():
            self.id2rel[id] = rel
        if use_diag:
            self.use_diag = True
            # Learned diagonal scaling applied to the attention query vectors.
            self.diag = nn.Parameter(torch.ones(self.sentence_encoder.hidden_size))
        else:
            self.use_diag = False

    def infer(self, bag):
        """
        Args:
            bag: bag of sentences with the same entity pair
                [{
                    'text' or 'token': ...,
                    'h': {'pos': [start, end], ...},
                    't': {'pos': [start, end], ...}
                }]
        Return:
            (relation, score)
        """
        self.eval()
        tokens = []
        pos1s = []
        pos2s = []
        masks = []
        for item in bag:
            token, pos1, pos2, mask = self.sentence_encoder.tokenize(item)
            tokens.append(token)
            pos1s.append(pos1)
            pos2s.append(pos2)
            masks.append(mask)
        tokens = torch.cat(tokens, 0).unsqueeze(0)  # (1, n, L)
        pos1s = torch.cat(pos1s, 0).unsqueeze(0)
        pos2s = torch.cat(pos2s, 0).unsqueeze(0)
        masks = torch.cat(masks, 0).unsqueeze(0)
        scope = torch.tensor([[0, len(bag)]]).long()  # (1, 2)
        bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0)  # (N) after softmax
        score, pred = bag_logits.max(0)
        score = score.item()
        pred = pred.item()
        rel = self.id2rel[pred]
        return (rel, score)
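
    # Usage sketch (illustrative only; the encoder and the token values are
    # assumptions, not part of this class):
    #
    #     model = BagAttention(my_cnn_encoder, len(rel2id), rel2id)
    #     rel, score = model.infer([
    #         {'token': ['Bill', 'was', 'born', 'in', 'Seattle', '.'],
    #          'h': {'pos': [0, 1]}, 't': {'pos': [4, 5]}},
    #     ])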

    def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
        """
        Args:
            label: (B), label of the bag
            scope: (B), scope for each bag
            token: (nsum, L), index of tokens
            pos1: (nsum, L), relative position to head entity
            pos2: (nsum, L), relative position to tail entity
            mask: (nsum, L), used for piece-wise CNN
        Return:
            logits, (B, N)

        Dirty hack:
            When the encoder is BERT, the input is actually token, att_mask, pos1, pos2, but
            since the arguments are fed into the BERT encoder in the original order,
            the encoder still works correctly.
        """
        if bag_size > 0:
            token = token.view(-1, token.size(-1))
            pos1 = pos1.view(-1, pos1.size(-1))
            pos2 = pos2.view(-1, pos2.size(-1))
            if mask is not None:
                mask = mask.view(-1, mask.size(-1))
        else:
            begin, end = scope[0][0], scope[-1][1]
            token = token[:, begin:end, :].view(-1, token.size(-1))
            pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
            pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
            if mask is not None:
                mask = mask[:, begin:end, :].view(-1, mask.size(-1))
            scope = scope - begin  # shift scopes so they index the sliced tensors

        # Attention
        if train:
            if mask is not None:
                rep = self.sentence_encoder(token, pos1, pos2, mask)  # (nsum, H)
            else:
                rep = self.sentence_encoder(token, pos1, pos2)  # (nsum, H)

            if bag_size == 0:
                bag_rep = []
                query = torch.zeros((rep.size(0))).long()
                if torch.cuda.is_available():
                    query = query.cuda()
                for i in range(len(scope)):
                    query[scope[i][0]:scope[i][1]] = label[i]
                # During training, attention is queried with the gold relation
                # embedding, i.e. rows of the classifier weight matrix.
                att_mat = self.fc.weight[query]  # (nsum, H)
                if self.use_diag:
                    att_mat = att_mat * self.diag.unsqueeze(0)
                att_score = (rep * att_mat).sum(-1)  # (nsum)

                for i in range(len(scope)):
                    bag_mat = rep[scope[i][0]:scope[i][1]]  # (n, H)
                    softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]])  # (n)
                    bag_rep.append((softmax_att_score.unsqueeze(-1) * bag_mat).sum(0))  # (n, 1) * (n, H) -> (n, H) -> (H)
                bag_rep = torch.stack(bag_rep, 0)  # (B, H)
            else:
                batch_size = label.size(0)
                query = label.unsqueeze(1)  # (B, 1)
                att_mat = self.fc.weight[query]  # (B, 1, H)
                if self.use_diag:
                    att_mat = att_mat * self.diag.unsqueeze(0)
                rep = rep.view(batch_size, bag_size, -1)
                att_score = (rep * att_mat).sum(-1)  # (B, bag)
                softmax_att_score = self.softmax(att_score)  # (B, bag)
                bag_rep = (softmax_att_score.unsqueeze(-1) * rep).sum(1)  # (B, bag, 1) * (B, bag, H) -> (B, bag, H) -> (B, H)
            bag_rep = self.drop(bag_rep)
            bag_logits = self.fc(bag_rep)  # (B, N)
        else:
            if bag_size == 0:
                # Encode in chunks to bound memory when bags are large.
                rep = []
                bs = 256
                total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
                for b in range(total_bs):
                    with torch.no_grad():
                        left = bs * b
                        right = min(bs * (b + 1), len(token))
                        if mask is not None:
                            rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach())  # (nsum, H)
                        else:
                            rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach())  # (nsum, H)
                rep = torch.cat(rep, 0)

                bag_logits = []
                att_mat = self.fc.weight.transpose(0, 1)
                if self.use_diag:
                    att_mat = att_mat * self.diag.unsqueeze(1)
                att_score = torch.matmul(rep, att_mat)  # (nsum, H) * (H, N) -> (nsum, N)
                for i in range(len(scope)):
                    bag_mat = rep[scope[i][0]:scope[i][1]]  # (n, H)
                    softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]].transpose(0, 1))  # (N, (softmax)n)
                    rep_for_each_rel = torch.matmul(softmax_att_score, bag_mat)  # (N, n) * (n, H) -> (N, H)
                    logit_for_each_rel = self.softmax(self.fc(rep_for_each_rel))  # ((each rel)N, (logit)N)
                    logit_for_each_rel = logit_for_each_rel.diag()  # (N)
                    bag_logits.append(logit_for_each_rel)
                bag_logits = torch.stack(bag_logits, 0)  # after **softmax**
            else:
                if mask is not None:
                    rep = self.sentence_encoder(token, pos1, pos2, mask)  # (nsum, H)
                else:
                    rep = self.sentence_encoder(token, pos1, pos2)  # (nsum, H)

                batch_size = rep.size(0) // bag_size
                att_mat = self.fc.weight.transpose(0, 1)
                if self.use_diag:
                    att_mat = att_mat * self.diag.unsqueeze(1)
                att_score = torch.matmul(rep, att_mat)  # (nsum, H) * (H, N) -> (nsum, N)
                att_score = att_score.view(batch_size, bag_size, -1)  # (B, bag, N)
                rep = rep.view(batch_size, bag_size, -1)  # (B, bag, H)
                softmax_att_score = self.softmax(att_score.transpose(1, 2))  # (B, N, (softmax)bag)
                rep_for_each_rel = torch.matmul(softmax_att_score, rep)  # (B, N, bag) * (B, bag, H) -> (B, N, H)
                bag_logits = self.softmax(self.fc(rep_for_each_rel)).diagonal(dim1=1, dim2=2)  # (B, (each rel)N)
        return bag_logits
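
# A minimal, self-contained sketch (not part of the original class) of the
# selective-attention step above, for one bag of n=3 sentence vectors and a
# single query vector q (a row of the classifier weight matrix):
#
#     import torch
#     rep = torch.randn(3, 8)                       # (n, H) sentence vectors
#     q = torch.randn(8)                            # (H,) relation query
#     alpha = torch.softmax(rep @ q, dim=0)         # (n,) attention over instances
#     bag_rep = (alpha.unsqueeze(-1) * rep).sum(0)  # (H,) weighted bag vector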
@ -1,134 +0,0 @@

import torch
from torch import nn, optim
from .base_model import BagRE

class BagAverage(BagRE):
    """
    Average policy for bag-level relation extraction.
    """

    def __init__(self, sentence_encoder, num_class, rel2id):
        """
        Args:
            sentence_encoder: encoder for sentences
            num_class: number of classes
            rel2id: dictionary of relation -> id mapping
        """
        super().__init__()
        self.sentence_encoder = sentence_encoder
        self.num_class = num_class
        self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
        self.softmax = nn.Softmax(-1)
        self.rel2id = rel2id
        self.id2rel = {}
        self.drop = nn.Dropout()
        for rel, id in rel2id.items():
            self.id2rel[id] = rel
    def infer(self, bag):
        """
        Args:
            bag: bag of sentences with the same entity pair
                [{
                    'text' or 'token': ...,
                    'h': {'pos': [start, end], ...},
                    't': {'pos': [start, end], ...}
                }]
        Return:
            (relation, score)
        """
        raise NotImplementedError

    # Unused reference implementation (kept for documentation):
    """
    tokens = []
    pos1s = []
    pos2s = []
    masks = []
    for item in bag:
        if 'text' in item:
            token, pos1, pos2, mask = self.tokenizer(item['text'],
                item['h']['pos'], item['t']['pos'], is_token=False, padding=True)
        else:
            token, pos1, pos2, mask = self.tokenizer(item['token'],
                item['h']['pos'], item['t']['pos'], is_token=True, padding=True)
        tokens.append(token)
        pos1s.append(pos1)
        pos2s.append(pos2)
        masks.append(mask)
    tokens = torch.cat(tokens, 0)  # (n, L)
    pos1s = torch.cat(pos1s, 0)
    pos2s = torch.cat(pos2s, 0)
    masks = torch.cat(masks, 0)
    scope = torch.tensor([[0, len(bag)]]).long()  # (1, 2)
    bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0)  # (N) after softmax
    score, pred = bag_logits.max(0)
    score = score.item()
    pred = pred.item()
    rel = self.id2rel[pred]
    return (rel, score)
    """

    def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
        """
        Args:
            label: (B), label of the bag
            scope: (B), scope for each bag
            token: (nsum, L), index of tokens
            pos1: (nsum, L), relative position to head entity
            pos2: (nsum, L), relative position to tail entity
            mask: (nsum, L), used for piece-wise CNN
        Return:
            logits, (B, N)
        """
        if bag_size > 0:
            token = token.view(-1, token.size(-1))
            pos1 = pos1.view(-1, pos1.size(-1))
            pos2 = pos2.view(-1, pos2.size(-1))
            if mask is not None:
                mask = mask.view(-1, mask.size(-1))
        else:
            begin, end = scope[0][0], scope[-1][1]
            token = token[:, begin:end, :].view(-1, token.size(-1))
            pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
            pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
            if mask is not None:
                mask = mask[:, begin:end, :].view(-1, mask.size(-1))
            scope = scope - begin  # shift scopes so they index the sliced tensors

        if train or bag_size > 0:
            if mask is not None:
                rep = self.sentence_encoder(token, pos1, pos2, mask)  # (nsum, H)
            else:
                rep = self.sentence_encoder(token, pos1, pos2)  # (nsum, H)
        else:
            # Encode in chunks to bound memory when bags are large.
            rep = []
            bs = 256
            total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
            for b in range(total_bs):
                with torch.no_grad():
                    left = bs * b
                    right = min(bs * (b + 1), len(token))
                    if mask is not None:
                        rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach())  # (nsum, H)
                    else:
                        rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach())  # (nsum, H)
            rep = torch.cat(rep, 0)

        # Average
        if bag_size == 0:
            bag_rep = []
            for i in range(len(scope)):
                bag_rep.append(rep[scope[i][0]:scope[i][1]].mean(0))
            bag_rep = torch.stack(bag_rep, 0)  # (B, H)
        else:
            batch_size = len(scope)
            rep = rep.view(batch_size, bag_size, -1)  # (B, bag, H)
            bag_rep = rep.mean(1)  # (B, H)
        bag_rep = self.drop(bag_rep)
        bag_logits = self.fc(bag_rep)  # (B, N)

        if not train:
            bag_logits = torch.softmax(bag_logits, -1)

        return bag_logits
@ -1,155 +0,0 @@

import torch
from torch import nn, optim
from .base_model import BagRE

class BagOne(BagRE):
    """
    Instance "at-least-one" (max) policy for bag-level relation extraction.
    """

    def __init__(self, sentence_encoder, num_class, rel2id):
        """
        Args:
            sentence_encoder: encoder for sentences
            num_class: number of classes
            rel2id: dictionary of relation -> id mapping
        """
        super().__init__()
        self.sentence_encoder = sentence_encoder
        self.num_class = num_class
        self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
        self.softmax = nn.Softmax(-1)
        self.rel2id = rel2id
        self.id2rel = {}
        self.drop = nn.Dropout()
        for rel, id in rel2id.items():
            self.id2rel[id] = rel
    def infer(self, bag):
        """
        Args:
            bag: bag of sentences with the same entity pair
                [{
                    'text' or 'token': ...,
                    'h': {'pos': [start, end], ...},
                    't': {'pos': [start, end], ...}
                }]
        Return:
            (relation, score)
        """
        self.eval()
        tokens = []
        pos1s = []
        pos2s = []
        masks = []
        for item in bag:
            token, pos1, pos2, mask = self.sentence_encoder.tokenize(item)
            tokens.append(token)
            pos1s.append(pos1)
            pos2s.append(pos2)
            masks.append(mask)
        tokens = torch.cat(tokens, 0).unsqueeze(0)  # (1, n, L)
        pos1s = torch.cat(pos1s, 0).unsqueeze(0)
        pos2s = torch.cat(pos2s, 0).unsqueeze(0)
        masks = torch.cat(masks, 0).unsqueeze(0)
        scope = torch.tensor([[0, len(bag)]]).long()  # (1, 2)
        bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0)  # (N) after softmax
        score, pred = bag_logits.max(0)
        score = score.item()
        pred = pred.item()
        rel = self.id2rel[pred]
        return (rel, score)

    def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
        """
        Args:
            label: (B), label of the bag
            scope: (B), scope for each bag
            token: (nsum, L), index of tokens
            pos1: (nsum, L), relative position to head entity
            pos2: (nsum, L), relative position to tail entity
            mask: (nsum, L), used for piece-wise CNN
        Return:
            logits, (B, N)
        """
        # Encode
        if bag_size > 0:
            token = token.view(-1, token.size(-1))
            pos1 = pos1.view(-1, pos1.size(-1))
            pos2 = pos2.view(-1, pos2.size(-1))
            if mask is not None:
                mask = mask.view(-1, mask.size(-1))
        else:
            begin, end = scope[0][0], scope[-1][1]
            token = token[:, begin:end, :].view(-1, token.size(-1))
            pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
            pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
            if mask is not None:
                mask = mask[:, begin:end, :].view(-1, mask.size(-1))
            scope = scope - begin  # shift scopes so they index the sliced tensors

        if train or bag_size > 0:
            if mask is not None:
                rep = self.sentence_encoder(token, pos1, pos2, mask)  # (nsum, H)
            else:
                rep = self.sentence_encoder(token, pos1, pos2)  # (nsum, H)
        else:
            # Encode in chunks to bound memory when bags are large.
            rep = []
            bs = 256
            total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
            for b in range(total_bs):
                with torch.no_grad():
                    left = bs * b
                    right = min(bs * (b + 1), len(token))
                    if mask is not None:
                        rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach())  # (nsum, H)
                    else:
                        rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach())  # (nsum, H)
            rep = torch.cat(rep, 0)

        # Max
        if train:
            if bag_size == 0:
                bag_rep = []
                for i in range(len(scope)):  # iterate over bags
                    bag_mat = rep[scope[i][0]:scope[i][1]]  # (n, H)
                    instance_logit = self.softmax(self.fc(bag_mat))  # (n, N)
                    # select the instance that scores highest on the bag's gold label
                    max_index = instance_logit[:, label[i]].argmax()
                    bag_rep.append(bag_mat[max_index])  # (n, H) -> (H)
                bag_rep = torch.stack(bag_rep, 0)  # (B, H)
            else:
                batch_size = label.size(0)
                query = label  # (B)
                rep = rep.view(batch_size, bag_size, -1)
                instance_logit = self.softmax(self.fc(rep))  # (B, bag, N)
                max_index = instance_logit[torch.arange(batch_size), :, query].argmax(-1)  # (B)
                bag_rep = rep[torch.arange(batch_size), max_index]  # (B, H)

            bag_rep = self.drop(bag_rep)
            bag_logits = self.fc(bag_rep)  # (B, N)

        else:
            if bag_size == 0:
                bag_logits = []
                for i in range(len(scope)):
                    bag_mat = rep[scope[i][0]:scope[i][1]]  # (n, H)
                    instance_logit = self.softmax(self.fc(bag_mat))  # (n, N)
                    logit_for_each_rel = instance_logit.max(dim=0)[0]  # (N)
                    bag_logits.append(logit_for_each_rel)
                bag_logits = torch.stack(bag_logits, 0)  # after **softmax**
            else:
                batch_size = rep.size(0) // bag_size
                rep = rep.view(batch_size, bag_size, -1)
                bag_logits = self.softmax(self.fc(rep)).max(1)[0]  # (B, N)

        return bag_logits
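
# A minimal, self-contained sketch (not part of the original class) of the
# "at-least-one" selection used in forward above: during training, each bag is
# represented by the single instance scoring highest on the gold relation.
#
#     import torch
#     instance_logit = torch.tensor([[0.1, 0.7, 0.2],
#                                    [0.5, 0.3, 0.2]])  # (n=2, N=3), post-softmax
#     gold = 1                                          # bag's gold relation id
#     j_star = instance_logit[:, gold].argmax()         # -> 0 (first instance)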
@ -1,71 +0,0 @@

from torch import nn

class SentenceRE(nn.Module):
    def __init__(self):
        super().__init__()

    def infer(self, item):
        """
        Args:
            item: {'text' or 'token', 'h': {'pos': [start, end]}, 't': ...}
        Return:
            (name of the relation of the sentence, score)
        """
        raise NotImplementedError


class BagRE(nn.Module):
    def __init__(self):
        super().__init__()

    def infer(self, bag):
        """
        Args:
            bag: bag of sentences with the same entity pair
                [{
                    'text' or 'token': ...,
                    'h': {'pos': [start, end], ...},
                    't': {'pos': [start, end], ...}
                }]
        Return:
            (relation, score)
        """
        raise NotImplementedError


class FewShotRE(nn.Module):
    def __init__(self):
        super().__init__()

    def infer(self, support, query):
        """
        Args:
            support: supporting set.
                [{'text' or 'token': ...,
                  'h': {'pos': [start, end], ...},
                  't': {'pos': [start, end], ...},
                  'relation': ...}]
            query: same format as support
        Return:
            [(relation, score), ...]

        For few-shot relation extraction, please refer to FewRel:
        https://github.com/thunlp/FewRel
        """
        raise NotImplementedError


class NER(nn.Module):
    def __init__(self):
        super().__init__()

    def ner(self, sentence, is_token=False):
        """
        Args:
            sentence: string, the input sentence
            is_token: if is_token == True, sentence is an array of tokens
        Return:
            [{name: xx, pos: [start, end]}], a list of named entities
        """
        raise NotImplementedError
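
# A minimal sketch (illustrative only) of how these abstract bases are meant to
# be extended; the scoring logic here is a stand-in, not a real model:
#
#     class MySentenceRE(SentenceRE):
#         def __init__(self, encoder, rel2id):
#             super().__init__()
#             self.encoder, self.rel2id = encoder, rel2id
#             self.id2rel = {v: k for k, v in rel2id.items()}
#
#         def infer(self, item):
#             logits = self.encoder(item)  # assumed to return (N,) scores
#             score, pred = logits.max(0)
#             return self.id2rel[pred.item()], score.item()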