This commit is contained in:
mustafa 2025-04-09 10:38:54 +03:30
parent 18eb11dea7
commit a991f030e4
215 changed files with 11 additions and 137808 deletions

View File

@@ -1,228 +0,0 @@
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
*.jsonl
*.7zip
*.7z
.idea/
fastapi_flair/idea/
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
data/ARMAN-MSR-persian-base-PN-summary
data/mT5_multilingual_XLSum
data/ner-model.pt
data/pos-model.pt
data/pos-model.7z

View File

@@ -1,34 +0,0 @@
FROM python:3.9-slim
# Maintainer info
LABEL maintainer="khaledihkh@gmail.com"
# Make working directories
RUN mkdir -p /service
WORKDIR /service
# Upgrade pip with no cache
RUN pip install --no-cache-dir -U pip
# Copy application requirements file to the created working directory
COPY requirements.txt .
# Install application dependencies from the requirements file
RUN pip install -r requirements.txt
# Copy every file in the source folder to the created working directory
COPY . .
RUN pip install ./src/piraye
RUN pip install transformers[sentencepiece]
ENV NER_MODEL_NAME ""
ENV NORMALIZE "False"
ENV SHORT_MODEL_NAME ""
ENV LONG_MODEL_NAME ""
ENV POS_MODEL_NAME ""
ENV BATCH_SIZE "4"
ENV DEVICE "gpu"
EXPOSE 80
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]

View File

@@ -1,18 +0,0 @@
# Services
## Requirements
````shell
pip install -r requirements.txt
````
## Download Models
Download the models from the link below and place them in the `data` folder:
https://www.mediafire.com/folder/tz3t9c9rpf6fo/models
## Getting started
````shell
cd fastapi_flair
uvicorn main:app --reload
````
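## Example Request
Once the server is running, a quick smoke test might look like this (a sketch: the credentials come from `clients.json`, and the port assumes the default `uvicorn` setting):
````shell
# submit a NER job; the response carries the request id
curl -X POST http://localhost:8000/ner \
  -H 'Client-Id: hassan' -H 'Client-Password: 1234' \
  -H 'Content-Type: application/json' \
  -d '[{"lang": "fa", "text": "sample text"}]'
# fetch the result with the returned id
curl http://localhost:8000/ner/<id> \
  -H 'Client-Id: hassan' -H 'Client-Password: 1234'
````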
## Documentation
[Open Documentation](./docs/docs.md)

View File

@@ -1,10 +0,0 @@
[
{
"client-id":"hassan",
"client-password":"1234"
},
{
"client-id":"ali",
"client-password":"12345"
}
]

View File

@@ -1,8 +0,0 @@
version: '3.0'
services:
web:
image: "services"
command: uvicorn main:app --host 0.0.0.0
ports:
- 8008:8000
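# Bring the stack up with: docker-compose up
# (uvicorn listens on its default port 8000 inside the container; it is published on host port 8008)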

View File

@@ -1,10 +0,0 @@
### Libraries:
1. Uvicorn: an ASGI web server implementation for Python. (ASGI, the Asynchronous Server Gateway Interface, is a calling convention that lets web servers forward requests to asynchronous-capable Python frameworks and applications; it was built as a successor to the Web Server Gateway Interface.) Until recently, Python lacked a minimal low-level server/application interface for async frameworks; the ASGI specification fills this gap and enables a common set of tooling usable across all async frameworks.
2. FastAPI: a modern, high-performance web framework for building APIs with Python 3.7+ based on standard Python type hints. It is easy to learn, fast to code, and ready for production.
3. NLTK: a leading platform for building Python programs that work with human language data. It provides easy-to-use interfaces to over 50 corpora and lexical resources such as WordNet, along with a suite of text-processing libraries for classification, tokenization, stemming, tagging, parsing, and semantic reasoning, plus wrappers for industrial-strength NLP libraries and an active discussion forum.
4. Piraye: NLP utils; a utility for normalizing Persian, Arabic, and English texts.
5. ...
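A minimal sketch of how the first two pieces fit together (names here are illustrative, not taken from this service):

```python
# minimal FastAPI app served by Uvicorn
from fastapi import FastAPI
import uvicorn

app = FastAPI()


@app.get("/ping")
async def ping():
    return {"message": "pong"}


if __name__ == "__main__":
    # Uvicorn is the ASGI server that actually handles the HTTP traffic
    uvicorn.run(app, host="0.0.0.0", port=8000)
```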

View File

@@ -1,199 +0,0 @@
from __future__ import annotations
import uvicorn
import torch
from fastapi import FastAPI, HTTPException, Request
from starlette.status import HTTP_201_CREATED
from src.database_handler import Database
from src.requests_data import Requests
from model.request_models import InputNerType, InputSummaryType, InputNREType
from model.response_models import BaseResponse, ResponseNerModel, ResponseSummaryModel, ResponseOIEModel, \
ResponseNREModel
from src.response_ner import NerResponse, ResponsesNer
from fastapi import BackgroundTasks
from src.response_nre import NREResponse, ResponsesNRE
from src.response_oie import OIEResponse, ResponsesOIE
from src.response_summary import ShortResponse, LongResponse, ResponsesSummary, NormalizerTexts
print("torch version : ", torch.__version__)
app = FastAPI()
# ner_model_name = os.getenv("NER_MODEL_NAME")
# normalize = os.getenv("NORMALIZE")
# short_model_name = os.getenv("SHORT_MODEL_NAME")
# long_model_name = os.getenv("LONG_MODEL_NAME")
# pos_model_name = os.getenv("POS_MODEL_NAME")
# batch_size = int(os.getenv("BATCH_SIZE"))
# device = os.getenv("DEVICE")
#
# print(ner_model_name, normalize, short_model_name, long_model_name, pos_model_name, batch_size, device)
# initial docker parameters
pos_model_name = "./data/pos-model.pt"
ner_model_name = "./data/ner-model.pt"
short_model_name = "csebuetnlp/mT5_multilingual_XLSum"
long_model_name = "alireza7/ARMAN-MSR-persian-base-PN-summary"
# oie_model_name = "default"
oie_model_name = ""
ner_model_name = "default"
pos_model_name = "default"
nre_model_name = ""
bert_oie = './data/mbert-base-parsinlu-entailment'
short_model_name = ""
long_model_name = ""
normalize = "False"
batch_size = 4
device = "cpu"
if ner_model_name == "":
ner_response = None
else:
if ner_model_name.lower() == "default":
ner_model_name = "./data/ner-model.pt"
ner_response = NerResponse(ner_model_name, batch_size=batch_size, device=device)
if oie_model_name == "":
oie_response = None
else:
if oie_model_name.lower() == "default":
oie_model_name = "./data/mbertEntail-2GBEn.bin"
oie_response = OIEResponse(oie_model_name, BERT=bert_oie, batch_size=batch_size, device=device)
if nre_model_name == "":
nre_response = None
else:
if nre_model_name.lower() == "default":
nre_model_name = "./data/model.ckpt.pth.tar"
nre_response = NREResponse(nre_model_name)
if normalize.lower() == "false" or pos_model_name == "":
short_response = ShortResponse(normalizer_input=None, device=device)
long_response = LongResponse(normalizer_input=None, device=device)
else:
if pos_model_name.lower() == "default":
pos_model_name = "./data/pos-model.pt"
normalizer_pos = NormalizerTexts(pos_model_name)
short_response = ShortResponse(normalizer_input=normalizer_pos, device=device)
long_response = LongResponse(normalizer_input=normalizer_pos, device=device)
if short_model_name != "":
short_response.load_model(short_model_name)
if long_model_name != "":
long_response.load_model(long_model_name)
requests: Requests = Requests()
database: Database = Database()
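# All model endpoints below authenticate the caller via the Client-Id / Client-Password
# request headers, checked against the registered clients; the root route is an
# unauthenticated health check.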
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get("/ner/{item_id}", response_model=ResponseNerModel)
async def read_item(item_id, request: Request) -> ResponseNerModel:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and ner_response:
if not ResponsesNer.check_id(int(item_id)):
raise HTTPException(status_code=404, detail="Item not found")
return ResponsesNer.get_response(int(item_id))
else:
raise HTTPException(status_code=403, detail="you are not allowed")
@app.get("/oie/{item_id}", response_model=ResponseOIEModel)
async def read_item(item_id, request: Request) -> ResponseOIEModel:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and oie_response:
if not ResponsesOIE.check_id(int(item_id)):
raise HTTPException(status_code=404, detail="Item not found")
return ResponsesOIE.get_response(int(item_id))
else:
raise HTTPException(status_code=403, detail="you are not allowed")
@app.get("/nre/{item_id}", response_model=ResponseNREModel)
async def read_item(item_id, request: Request) -> ResponseNREModel:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and nre_response:
if not ResponsesNRE.check_id(int(item_id)):
raise HTTPException(status_code=404, detail="Item not found")
return ResponsesNRE.get_response(int(item_id))
else:
raise HTTPException(status_code=403, detail="you are not allowed")
@app.get("/summary/{item_id}", response_model=ResponseSummaryModel)
async def read_item(item_id, request: Request) -> ResponseSummaryModel:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')):
if not ResponsesSummary.check_id(int(item_id)):
raise HTTPException(status_code=404, detail="Item not found")
return ResponsesSummary.get_response(int(item_id))
else:
raise HTTPException(status_code=403, detail="you are not allowed")
@app.post("/ner", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def ner(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and ner_model_name != "":
request_id = requests.add_request_ner(input_json)
background_tasks.add_task(ner_response.predict_json, input_json, request_id)
return BaseResponse(status="input received successfully", id=request_id)
else:
raise HTTPException(status_code=403, detail="you are not allowed")
# return ner_response.predict_json(input_json)
@app.post("/oie", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNerType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and oie_model_name != "":
request_id = requests.add_request_oie(input_json)
background_tasks.add_task(oie_response.predict_json, input_json, request_id)
return BaseResponse(status="input received successfully", id=request_id)
else:
raise HTTPException(status_code=403, detail="you are not allowed")
# return ner_response.predict_json(input_json)
@app.post("/nre", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def oie(input_json: InputNREType, background_tasks: BackgroundTasks, request: Request) -> BaseResponse:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and nre_model_name != "":
request_id = requests.add_request_nre(input_json)
background_tasks.add_task(nre_response.predict_json, input_json, request_id)
return BaseResponse(status="input received successfully", id=request_id)
else:
raise HTTPException(status_code=403, detail="you are not allowed")
# return ner_response.predict_json(input_json)
@app.post("/summary/short", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def short_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
request: Request) -> BaseResponse:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and short_model_name != "":
request_id = requests.add_request_summary(input_json)
background_tasks.add_task(short_response.get_result, input_json, request_id)
return BaseResponse(status="input received successfully", id=request_id)
else:
raise HTTPException(status_code=403, detail="you are not allowed")
# return ner_response.predict_json(input_json)
@app.post("/summary/long", response_model=BaseResponse, status_code=HTTP_201_CREATED)
async def long_summary(input_json: InputSummaryType, background_tasks: BackgroundTasks,
request: Request) -> BaseResponse:
if database.check_client(client_id=request.headers.get('Client-Id'),
client_password=request.headers.get('Client-Password')) and long_model_name != "":
request_id = requests.add_request_summary(input_json)
background_tasks.add_task(long_response.get_result, input_json, request_id)
return BaseResponse(status="input received successfully", id=request_id)
else:
raise HTTPException(status_code=403, detail="you are not allowed")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -1,28 +0,0 @@
from typing import List, Dict, Tuple
from pydantic import BaseModel
class InputElement(BaseModel):
lang: str
text: str
class NREHType(BaseModel):
pos: str
class NRETType(BaseModel):
pos: str
class InputElementNRE(BaseModel):
text: str
h: NREHType
t: NRETType
InputNerType = List[InputElement]
InputOIEType = List[InputElement]
InputNREType = List[InputElementNRE]
InputSummaryType = List[str]
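# Illustrative payloads (examples, not from the repo):
#   InputNerType:     [{"lang": "fa", "text": "..."}]
#   InputNREType:     [{"text": "...", "h": {"pos": "..."}, "t": {"pos": "..."}}]
#   InputSummaryType: ["first document", "second document"]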

View File

@@ -1,43 +0,0 @@
from typing import List, Optional, Tuple
from pydantic import BaseModel
class EntityNerResponseModel(BaseModel):
entity_group: str
word: str
start: int
end: int
score: float
class EntityOIEResponseModel(BaseModel):
score: str
relation: str
arg1: str
arg2: str
class BaseResponse(BaseModel):
id: int
status: str
class ResponseNerModel(BaseModel):
progression: str
result: Optional[List[List[EntityNerResponseModel]]]
class ResponseSummaryModel(BaseModel):
progression: str
result: Optional[List[str]]
class ResponseOIEModel(BaseModel):
progression: str
result: Optional[List[EntityOIEResponseModel]]
class ResponseNREModel(BaseModel):
progression: str
result: Optional[List[Tuple[str, float]]]
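# Illustrative response shapes (values are examples, not from the repo):
#   BaseResponse:     {"id": 3, "status": "input received successfully"}
#   ResponseNerModel: {"progression": "...", "result": [[{"entity_group": "...", "word": "...",
#                      "start": 0, "end": 4, "score": 0.99}]]}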

View File

@@ -1,672 +0,0 @@
{
"openapi": "3.0.2",
"info": {
"title": "FastAPI",
"version": "0.1.0"
},
"paths": {
"/": {
"get": {
"summary": "Root",
"operationId": "root__get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {}
}
}
}
}
}
},
"/ner/{item_id}": {
"get": {
"summary": "Read Item",
"operationId": "read_item_ner__item_id__get",
"parameters": [
{
"required": true,
"schema": {
"title": "Item Id"
},
"name": "item_id",
"in": "path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResponseNerModel"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/oie/{item_id}": {
"get": {
"summary": "Read Item",
"operationId": "read_item_oie__item_id__get",
"parameters": [
{
"required": true,
"schema": {
"title": "Item Id"
},
"name": "item_id",
"in": "path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResponseOIEModel"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/nre/{item_id}": {
"get": {
"summary": "Read Item",
"operationId": "read_item_nre__item_id__get",
"parameters": [
{
"required": true,
"schema": {
"title": "Item Id"
},
"name": "item_id",
"in": "path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResponseNREModel"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/summary/{item_id}": {
"get": {
"summary": "Read Item",
"operationId": "read_item_summary__item_id__get",
"parameters": [
{
"required": true,
"schema": {
"title": "Item Id"
},
"name": "item_id",
"in": "path"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ResponseSummaryModel"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/ner": {
"post": {
"summary": "Ner",
"operationId": "ner_ner_post",
"requestBody": {
"content": {
"application/json": {
"schema": {
"title": "Input Json",
"type": "array",
"items": {
"$ref": "#/components/schemas/InputElement"
}
}
}
},
"required": true
},
"responses": {
"201": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BaseResponse"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/oie": {
"post": {
"summary": "Oie",
"operationId": "oie_oie_post",
"requestBody": {
"content": {
"application/json": {
"schema": {
"title": "Input Json",
"type": "array",
"items": {
"$ref": "#/components/schemas/InputElement"
}
}
}
},
"required": true
},
"responses": {
"201": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BaseResponse"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/nre": {
"post": {
"summary": "Oie",
"operationId": "oie_nre_post",
"requestBody": {
"content": {
"application/json": {
"schema": {
"title": "Input Json",
"type": "array",
"items": {
"$ref": "#/components/schemas/InputElementNRE"
}
}
}
},
"required": true
},
"responses": {
"201": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BaseResponse"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/summary/short": {
"post": {
"summary": "Short Summary",
"operationId": "short_summary_summary_short_post",
"requestBody": {
"content": {
"application/json": {
"schema": {
"title": "Input Json",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"required": true
},
"responses": {
"201": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BaseResponse"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/summary/long": {
"post": {
"summary": "Long Summary",
"operationId": "long_summary_summary_long_post",
"requestBody": {
"content": {
"application/json": {
"schema": {
"title": "Input Json",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"required": true
},
"responses": {
"201": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BaseResponse"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
}
},
"components": {
"schemas": {
"BaseResponse": {
"title": "BaseResponse",
"required": [
"id",
"status"
],
"type": "object",
"properties": {
"id": {
"title": "Id",
"type": "integer"
},
"status": {
"title": "Status",
"type": "string"
}
}
},
"EntityNerResponseModel": {
"title": "EntityNerResponseModel",
"required": [
"entity_group",
"word",
"start",
"end",
"score"
],
"type": "object",
"properties": {
"entity_group": {
"title": "Entity Group",
"type": "string"
},
"word": {
"title": "Word",
"type": "string"
},
"start": {
"title": "Start",
"type": "integer"
},
"end": {
"title": "End",
"type": "integer"
},
"score": {
"title": "Score",
"type": "number"
}
}
},
"EntityOIEResponseModel": {
"title": "EntityOIEResponseModel",
"required": [
"score",
"relation",
"arg1",
"arg2"
],
"type": "object",
"properties": {
"score": {
"title": "Score",
"type": "string"
},
"relation": {
"title": "Relation",
"type": "string"
},
"arg1": {
"title": "Arg1",
"type": "string"
},
"arg2": {
"title": "Arg2",
"type": "string"
}
}
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {
"$ref": "#/components/schemas/ValidationError"
}
}
}
},
"InputElement": {
"title": "InputElement",
"required": [
"lang",
"text"
],
"type": "object",
"properties": {
"lang": {
"title": "Lang",
"type": "string"
},
"text": {
"title": "Text",
"type": "string"
}
}
},
"InputElementNRE": {
"title": "InputElementNRE",
"required": [
"text",
"h",
"t"
],
"type": "object",
"properties": {
"text": {
"title": "Text",
"type": "string"
},
"h": {
"$ref": "#/components/schemas/NREHType"
},
"t": {
"$ref": "#/components/schemas/NRETType"
}
}
},
"NREHType": {
"title": "NREHType",
"required": [
"pos"
],
"type": "object",
"properties": {
"pos": {
"title": "Pos",
"type": "string"
}
}
},
"NRETType": {
"title": "NRETType",
"required": [
"pos"
],
"type": "object",
"properties": {
"pos": {
"title": "Pos",
"type": "string"
}
}
},
"ResponseNREModel": {
"title": "ResponseNREModel",
"required": [
"progression"
],
"type": "object",
"properties": {
"progression": {
"title": "Progression",
"type": "string"
},
"result": {
"title": "Result",
"type": "array",
"items": {
"type": "array",
"items": [
{
"type": "string"
},
{
"type": "number"
}
]
}
}
}
},
"ResponseNerModel": {
"title": "ResponseNerModel",
"required": [
"progression"
],
"type": "object",
"properties": {
"progression": {
"title": "Progression",
"type": "string"
},
"result": {
"title": "Result",
"type": "array",
"items": {
"type": "array",
"items": {
"$ref": "#/components/schemas/EntityNerResponseModel"
}
}
}
}
},
"ResponseOIEModel": {
"title": "ResponseOIEModel",
"required": [
"progression"
],
"type": "object",
"properties": {
"progression": {
"title": "Progression",
"type": "string"
},
"result": {
"title": "Result",
"type": "array",
"items": {
"$ref": "#/components/schemas/EntityOIEResponseModel"
}
}
}
},
"ResponseSummaryModel": {
"title": "ResponseSummaryModel",
"required": [
"progression"
],
"type": "object",
"properties": {
"progression": {
"title": "Progression",
"type": "string"
},
"result": {
"title": "Result",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"ValidationError": {
"title": "ValidationError",
"required": [
"loc",
"msg",
"type"
],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {
"anyOf": [
{
"type": "string"
},
{
"type": "integer"
}
]
}
},
"msg": {
"title": "Message",
"type": "string"
},
"type": {
"title": "Error Type",
"type": "string"
}
}
}
}
}
}

View File

@@ -1 +0,0 @@
web: gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app

View File

@@ -1,14 +0,0 @@
requests==2.28.0
flair==0.10
langdetect==1.0.9
transformers
pydantic==1.8.2
uvicorn==0.17.6
torch>=1.12.1+cu116
fastapi==0.78.0
protobuf==3.20.0
sacremoses
nltk>=3.3
tqdm>=4.64.0
setuptools>=62.6.0
starlette>=0.19.1

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2020 Youngbin Ro
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,372 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from torch.nn.modules.container import ModuleList
import copy
# from dataset import load_data
from transformers import BertTokenizer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np
import src.Multi2OIE.utils.bio as bio
# from extract import extract
def _get_clones(module, n):
return ModuleList([copy.deepcopy(module) for _ in range(n)])
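# Builds per-token position ids: 2 outside the predicate span, 1 inside it, 0 on padding.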
def _get_position_idxs(pred_mask, input_ids):
position_idxs = torch.zeros(pred_mask.shape, dtype=int, device=pred_mask.device)
for mask_idx, cur_mask in enumerate(pred_mask):
position_idxs[mask_idx, :] += 2
cur_nonzero = (cur_mask == 0).nonzero()
start = torch.min(cur_nonzero).item()
end = torch.max(cur_nonzero).item()
position_idxs[mask_idx, start:end + 1] = 1
pad_start = max(input_ids[mask_idx].nonzero()).item() + 1
position_idxs[mask_idx, pad_start:] = 0
return position_idxs
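# Averages the hidden states at the predicate positions and tiles that vector along the sequence.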
def _get_pred_feature(pred_hidden, pred_mask):
B, L, D = pred_hidden.shape
pred_features = torch.zeros((B, L, D), device=pred_mask.device)
for mask_idx, cur_mask in enumerate(pred_mask):
pred_position = (cur_mask == 0).nonzero().flatten()
pred_feature = torch.mean(pred_hidden[mask_idx, pred_position], dim=0)
pred_feature = torch.cat(L * [pred_feature.unsqueeze(0)])
pred_features[mask_idx, :, :] = pred_feature
return pred_features
def _get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
else:
raise RuntimeError("activation should be relu/gelu, not %s." % activation)
class ArgExtractorLayer(nn.Module):
def __init__(self,
d_model=768,
n_heads=8,
d_feedforward=2048,
dropout=0.1,
activation='relu'):
"""
A layer similar to Transformer decoder without decoder self-attention.
(only encoder-decoder multi-head attention followed by feed-forward layers)
:param d_model: model dimensionality (default=768 from BERT-base)
:param n_heads: number of heads in multi-head attention layer
:param d_feedforward: dimensionality of point-wise feed-forward layer
:param dropout: drop rate of all layers
:param activation: activation function after first feed-forward layer
"""
super(ArgExtractorLayer, self).__init__()
self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.linear1 = nn.Linear(d_model, d_feedforward)
self.dropout1 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def forward(self, target, source, key_mask=None):
"""
Single Transformer Decoder layer without self-attention
:param target: a tensor which takes a role as a query
:param source: a tensor which takes a role as a key & value
:param key_mask: key mask tensor with the shape of (batch_size, sequence_length)
"""
# Multi-head attention layer (+ add & norm)
attended = self.multihead_attn(
target, source, source,
key_padding_mask=key_mask)[0]
skipped = target + self.dropout1(attended)
normed = self.norm1(skipped)
# Point-wise feed-forward layer (+ add & norm)
projected = self.linear2(self.dropout2(self.activation(self.linear1(normed))))
skipped = normed + self.dropout1(projected)
normed = self.norm2(skipped)
return normed
class ArgModule(nn.Module):
def __init__(self, arg_layer, n_layers):
"""
Module for extracting arguments based on given encoder output and predicates.
It uses ArgExtractorLayer as a base block and repeats the block n_layers times
:param arg_layer: an instance of the ArgExtractorLayer() class (required)
:param n_layers: the number of sub-layers in the ArgModule (required).
"""
super(ArgModule, self).__init__()
self.layers = _get_clones(arg_layer, n_layers)
self.n_layers = n_layers
def forward(self, encoded, predicate, pred_mask=None):
"""
:param encoded: output from sentence encoder with the shape of (L, B, D),
where L is the sequence length, B is the batch size, D is the embedding dimension
:param predicate: output from predicate module with the shape of (L, B, D)
:param pred_mask: mask that prevents attention to tokens which are not predicates
with the shape of (B, L)
:return: tensor like Transformer Decoder Layer Output
"""
output = encoded
for layer_idx in range(self.n_layers):
output = self.layers[layer_idx](
target=output, source=predicate, key_mask=pred_mask)
return output
class Multi2OIE(nn.Module):
def __init__(self,
bert_config,
mh_dropout=0.1,
pred_clf_dropout=0.,
arg_clf_dropout=0.3,
n_arg_heads=8,
n_arg_layers=4,
pos_emb_dim=64,
pred_n_labels=3,
arg_n_labels=9):
super(Multi2OIE, self).__init__()
self.pred_n_labels = pred_n_labels
self.arg_n_labels = arg_n_labels
self.bert = BertModel.from_pretrained(
bert_config,
output_hidden_states=True)
d_model = self.bert.config.hidden_size
self.pred_dropout = nn.Dropout(pred_clf_dropout)
self.pred_classifier = nn.Linear(d_model, self.pred_n_labels)
self.position_emb = nn.Embedding(3, pos_emb_dim, padding_idx=0)
d_model += (d_model + pos_emb_dim)
arg_layer = ArgExtractorLayer(
d_model=d_model,
n_heads=n_arg_heads,
dropout=mh_dropout)
self.arg_module = ArgModule(arg_layer, n_arg_layers)
self.arg_dropout = nn.Dropout(arg_clf_dropout)
self.arg_classifier = nn.Linear(d_model, arg_n_labels)
def forward(self,
input_ids,
attention_mask,
predicate_mask=None,
predicate_hidden=None,
total_pred_labels=None,
arg_labels=None):
# predicate extraction
bert_hidden = self.bert(input_ids, attention_mask)[0]
pred_logit = self.pred_classifier(self.pred_dropout(bert_hidden))
# predicate loss
if total_pred_labels is not None:
loss_fct = nn.CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = pred_logit.view(-1, self.pred_n_labels)
active_labels = torch.where(
active_loss, total_pred_labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(total_pred_labels))
pred_loss = loss_fct(active_logits, active_labels)
# inputs for argument extraction
pred_feature = _get_pred_feature(bert_hidden, predicate_mask)
position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
bert_hidden = torch.cat([bert_hidden, pred_feature, position_vectors], dim=2)
bert_hidden = bert_hidden.transpose(0, 1)
# argument extraction
arg_hidden = self.arg_module(bert_hidden, bert_hidden, predicate_mask)
arg_hidden = arg_hidden.transpose(0, 1)
arg_logit = self.arg_classifier(self.arg_dropout(arg_hidden))
# argument loss
if arg_labels is not None:
loss_fct = nn.CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = arg_logit.view(-1, self.arg_n_labels)
active_labels = torch.where(
active_loss, arg_labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(arg_labels))
arg_loss = loss_fct(active_logits, active_labels)
# total loss
batch_loss = pred_loss + arg_loss
outputs = (batch_loss, pred_loss, arg_loss)
return outputs
def extract_predicate(self,
input_ids,
attention_mask):
bert_hidden = self.bert(input_ids, attention_mask)[0]
pred_logit = self.pred_classifier(bert_hidden)
return pred_logit, bert_hidden
def extract_argument(self,
input_ids,
predicate_hidden,
predicate_mask):
pred_feature = _get_pred_feature(predicate_hidden, predicate_mask)
position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
arg_input = torch.cat([predicate_hidden, pred_feature, position_vectors], dim=2)
arg_input = arg_input.transpose(0, 1)
arg_hidden = self.arg_module(arg_input, arg_input, predicate_mask)
arg_hidden = arg_hidden.transpose(0, 1)
return self.arg_classifier(arg_hidden)
class OieEvalDataset(Dataset):
def __init__(self, sentences, max_len, tokenizer_config):
self.sentences = sentences
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_config)
self.vocab = self.tokenizer.vocab
self.max_len = max_len
self.pad_idx = self.vocab['[PAD]']
self.cls_idx = self.vocab['[CLS]']
self.sep_idx = self.vocab['[SEP]']
self.mask_idx = self.vocab['[MASK]']
def add_pad(self, token_ids):
diff = self.max_len - len(token_ids)
if diff > 0:
token_ids += [self.pad_idx] * diff
else:
token_ids = token_ids[:self.max_len - 1] + [self.sep_idx]
return token_ids
def idx2mask(self, token_ids):
return [token_id != self.pad_idx for token_id in token_ids]
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
token_ids = self.add_pad(self.tokenizer.encode(self.sentences[idx]))
att_mask = self.idx2mask(token_ids)
token_strs = self.tokenizer.convert_ids_to_tokens(token_ids)
sentence = self.sentences[idx]
assert len(token_ids) == self.max_len
assert len(att_mask) == self.max_len
assert len(token_strs) == self.max_len
batch = [
torch.tensor(token_ids),
torch.tensor(att_mask),
token_strs,
sentence
]
return batch
def extract(model, loader, device, tokenizer):
# model.eval()
result = []
for step, batch in tqdm(enumerate(loader), desc='eval_steps', total=len(loader)):
token_strs = [[word for word in sent] for sent in np.asarray(batch[-2]).T]
sentences = batch[-1]
print(sentences)
token_ids, att_mask = map(lambda x: x.to(device), batch[:-2])
with torch.no_grad():
"""
We iterate B (batch_size) times because there can be more than one predicate per batch.
When feeding the argument extractor, the number of predicates takes the role of the batch size.
pred_logit: (B, L, 3)
pred_hidden: (B, L, D)
pred_tags: (B, P, L) ~ list of tensors, where P is # of predicate in each batch
"""
pred_logit, pred_hidden = model.extract_predicate(
input_ids=token_ids, attention_mask=att_mask)
pred_tags = torch.argmax(pred_logit, 2)
pred_tags = bio.filter_pred_tags(pred_tags, token_strs)
pred_tags = bio.get_single_predicate_idxs(pred_tags)
pred_probs = torch.nn.Softmax(2)(pred_logit)
# iterate B times (one iteration means extraction for one sentence)
for cur_pred_tags, cur_pred_hidden, cur_att_mask, cur_token_id, cur_pred_probs, token_str, sentence \
in zip(pred_tags, pred_hidden, att_mask, token_ids, pred_probs, token_strs, sentences):
# generate temporary batch for this sentence and feed to argument module
cur_pred_masks = bio.get_pred_mask(cur_pred_tags).to(device)
n_predicates = cur_pred_masks.shape[0]
if n_predicates == 0:
continue # if there is no predicate, we cannot extract.
cur_pred_hidden = torch.cat(n_predicates * [cur_pred_hidden.unsqueeze(0)])
cur_token_id = torch.cat(n_predicates * [cur_token_id.unsqueeze(0)])
cur_arg_logit = model.extract_argument(
input_ids=cur_token_id,
predicate_hidden=cur_pred_hidden,
predicate_mask=cur_pred_masks)
# filter and get argument tags with highest probability
cur_arg_tags = torch.argmax(cur_arg_logit, 2)
cur_arg_probs = torch.nn.Softmax(2)(cur_arg_logit)
cur_arg_tags = bio.filter_arg_tags(cur_arg_tags, cur_pred_tags, token_str)
# get string tuples and write results
cur_extractions, cur_extraction_idxs = bio.get_tuple(sentence, cur_pred_tags, cur_arg_tags, tokenizer)
cur_confidences = bio.get_confidence_score(cur_pred_probs, cur_arg_probs, cur_extraction_idxs)
for extraction, confidence in zip(cur_extractions, cur_confidences):
# print('\n')
# print("\t".join([sentence] + [str(1.0)] + extraction[:3]))
res_dict = {
'score': str(confidence),
'relation': extraction[0],
'arg1': extraction[1],
'arg2': extraction[2]
}
result.append(res_dict)
# print("\nExtraction Done.\n")
return result
class OIE:
def __init__(self, model_path, BERT, batch_size):
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.model = Multi2OIE(bert_config=BERT).to(self.device)
model_weights = torch.load(model_path, map_location=torch.device(self.device))
self.model.load_state_dict(model_weights)
self.max_len = 64
self.batch_size = batch_size
self.tokenizer_config = BERT
self.tokenizer = BertTokenizer.from_pretrained(BERT)
def predict(self, sentences):
loader = DataLoader(
dataset=OieEvalDataset(
sentences,
self.max_len,
self.tokenizer_config),
batch_size=self.batch_size,
num_workers=4,
pin_memory=True)
return extract(model=self.model, loader=loader, tokenizer=self.tokenizer, device=self.device)

View File

@@ -1,182 +0,0 @@
# Multi^2OIE: <u>Multi</u>lingual Open Information Extraction Based on <u>Multi</u>-Head Attention with BERT
> Source code for learning Multi^2OIE for (multilingual) open information extraction.
## Paper
[**Multi^2OIE: <u>Multi</u>lingual Open Information Extraction Based on <u>Multi</u>-Head Attention with BERT**](https://arxiv.org/abs/2009.08128)<br>
[Youngbin Ro](https://github.com/youngbin-ro), [Yukyung Lee](https://github.com/yukyunglee), and [Pilsung Kang](https://github.com/pilsung-kang)*<br>
Accepted to Findings of ACL: EMNLP 2020. (*corresponding author)
<br>
## Overview
### What is Open Information Extraction (Open IE)?
[Niklaus et al. (2018)](https://www.aclweb.org/anthology/C18-1326/) describe Open IE as follows:
> Information extraction (IE) **<u>turns the unstructured information expressed in natural language text into a structured representation</u>** in the form of relational tuples consisting of a set of arguments and a phrase denoting a semantic relation between them: <arg1; rel; arg2>. (...) Unlike traditional IE methods, Open IE is **<u>not limited to a small set of target relations</u>** known in advance, but rather extracts all types of relations found in a text.
![openie_overview](https://github.com/youngbin-ro/Multi2OIE/blob/master/images/openie_overview.PNG?raw=true)
#### Note
- Systems that adopt a sequence generation scheme ([Cui et al., 2018](https://www.aclweb.org/anthology/P18-2065/); [Kolluru et al., 2020](https://www.aclweb.org/anthology/2020.acl-main.521/)) can extract (in fact, generate) relations that lie outside of the given text.
- Multi^2OIE, however, adopts a sequence labeling scheme ([Stanovsky et al., 2018](https://www.aclweb.org/anthology/N18-1081/)) for computational efficiency and multilingual capability.
### Our Approach
![multi2oie_overview](https://github.com/youngbin-ro/Multi2OIE/blob/master/images/multi2oie_overview.PNG?raw=true)
#### Step 1: Extract predicates (relations) from the input sentence using BERT
- Conduct token-level classification on the BERT output sequence
- Use BIO Tagging for representing arguments and predicates
#### Step 2: Extract arguments using multi-head attention blocks
- Concatenate the whole BERT hidden sequence, the average vector of the hidden states at the predicate positions, and a binary embedding vector indicating whether each token is included in the predicate span.
- Apply the multi-head attention operation N times
- Query: whole hidden sequence
- Key-Value pairs: hidden states of predicate positions
- Conduct token-level classification on the multi-head attention output sequence
#### Multilingual Extraction
- Replace the English BERT with multilingual BERT
- Train the model only on English data
- Test the model on three different languages (English, Spanish, and Portuguese) in a zero-shot manner.
<br>
## Usage
### Prerequisites
- Python 3.7
- CUDA 10.0 or above
### Environmental Setup
#### Install
##### using 'conda' command,
~~~~
# this makes a new conda environment
conda env create -f environment.yml
conda activate multi2oie
~~~~
##### using 'pip' command,
~~~~
pip install -r requirements.txt
~~~~
#### NLTK setup
```
python -c "import nltk; nltk.download('stopwords')"
```
### Datasets
#### Dataset Released
- `openie4_train.pkl`: https://drive.google.com/file/d/1DrWj1CjLFIno-UBfLI3_uIratN4QY6Y3/view?usp=sharing
#### Do-it-yourself
Original data file (bootstrapped sample from OpenIE4; used in SpanOIE) can be downloaded from [here](https://drive.google.com/file/d/1AEfwbh3BQnsv2VM977cS4tEoldrayKB6/view).
After downloading, put the data in `./datasets` and use `preprocess.py` to convert it into the format suitable for Multi^2OIE.
~~~~
cd utils
python preprocess.py \
--mode 'train' \
--data '../datasets/structured_data.json' \
--save_path '../datasets/openie4_train.pkl' \
--bert_config 'bert-base-cased' \
--max_len 64
~~~~
For multilingual training data, set **'bert_config'** to **'bert-base-multilingual-cased'**.
### Run the Code
#### Model Released
- English Model: https://drive.google.com/file/d/11BaLuGjMVVB16WHcyaHWLgL6dg0_9xHQ/view?usp=sharing
- Multilingual Model: https://drive.google.com/file/d/1lHQeetbacFOqvyPQ3ZzVUGPgn-zwTRA_/view?usp=sharing
We used a TITAN RTX GPU for training; using a different GPU may yield a different final performance.
##### for training,
~~~~
python main.py [--FLAGS]
~~~~
##### for testing,
~~~~
python test.py [--FLAGS]
~~~~
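For quick inference with the `OIE` wrapper bundled in this repository's service code, a hypothetical sketch (the checkpoint path and module path are assumptions) is:

```python
# hypothetical usage of the OIE wrapper defined in this repo's model code
from model import OIE  # module path is an assumption

oie = OIE(model_path="./data/multilingual_model.bin",  # illustrative checkpoint path
          BERT="bert-base-multilingual-cased",
          batch_size=4)
triples = oie.predict(["Multi2OIE extracts relational tuples from raw text."])
# each element: {"score": ..., "relation": ..., "arg1": ..., "arg2": ...}
```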
<br>
## Model Configurations
### # of Parameters
- Original BERT: 110M
- \+ Multi-Head Attention Blocks: 66M
### Hyper-parameters {& searching bounds}
- epochs: 1 {**1**, 2, 3}
- dropout rate for multi-head attention blocks: 0.2 {0.0, 0.1, **0.2**}
- dropout rate for argument classifier: 0.2 {0.0, 0.1, **0.2**, 0.3}
- batch size: 128 {64, **128**, 256, 512}
- learning rate: 3e-5 {2e-5, **3e-5**, 5e-5}
- number of multi-head attention heads: 8 {4, **8**}
- number of multi-head attention blocks: 4 {2, **4**, 8}
- position embedding dimension: 64 {**64**, 128, 256}
- gradient clipping norm: 1.0 (not tuned)
- learning rate warm-up steps: 10% of total steps (not tuned)
- for other unspecified parameters, the default values can be used.
<br>
## Expected Results
### Development set
#### OIE2016
- F1: 71.7
- AUC: 55.4
#### CaRB
- F1: 54.3
- AUC: 34.8
### Testing set
#### Re-OIE2016
- F1: 83.9
- AUC: 74.6
#### CaRB
- F1: 52.3
- AUC: 32.6
<br>
## References
- https://github.com/gabrielStanovsky/oie-benchmark
- https://github.com/dair-iitd/CaRB
- https://github.com/zhanjunlang/Span_OIE

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,21 +0,0 @@
import nltk
from operator import itemgetter
class Argument:
def __init__(self, arg):
self.words = [x for x in arg[0].strip().split(' ') if x]
self.posTags = map(itemgetter(1), nltk.pos_tag(self.words))
self.indices = arg[1]
self.feats = {}
def __str__(self):
return "({})".format('\t'.join(map(str,
[escape_special_chars(' '.join(self.words)),
str(self.indices)])))
COREF = 'coref'
## Helper functions
def escape_special_chars(s):
return s.replace('\t', '\\t')

View File

@@ -1,380 +0,0 @@
'''
Usage:
benchmark --gold=GOLD_OIE --out=OUTPUT_FILE (--openiefive=OPENIE5 | --stanford=STANFORD_OIE | --ollie=OLLIE_OIE |--reverb=REVERB_OIE | --clausie=CLAUSIE_OIE | --openiefour=OPENIEFOUR_OIE | --props=PROPS_OIE | --tabbed=TABBED_OIE | --benchmarkGold=BENCHMARK_GOLD | --allennlp=ALLENNLP_OIE ) [--exactMatch | --predMatch | --lexicalMatch | --binaryMatch | --simpleMatch | --strictMatch] [--error-file=ERROR_FILE] [--binary]
Options:
--gold=GOLD_OIE The gold reference Open IE file (by default, it should be under ./oie_corpus/all.oie).
--benchmarkGold=GOLD_OIE The benchmark's gold reference.
--out=OUTPUT_FILE The output file, into which the precision recall curve will be written.
--clausie=CLAUSIE_OIE Read ClausIE format from file CLAUSIE_OIE.
--ollie=OLLIE_OIE Read OLLIE format from file OLLIE_OIE.
--openiefour=OPENIEFOUR_OIE Read Open IE 4 format from file OPENIEFOUR_OIE.
--openiefive=OPENIE5 Read Open IE 5 format from file OPENIE5.
--props=PROPS_OIE Read PropS format from file PROPS_OIE
--reverb=REVERB_OIE Read ReVerb format from file REVERB_OIE
--stanford=STANFORD_OIE Read Stanford format from file STANFORD_OIE
--tabbed=TABBED_OIE Read simple tab format file, where each line consists of:
sent, prob, pred, arg1, arg2, ...
--exactMatch Use exact match when judging whether an extraction is correct.
'''
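# Example invocation (file names are illustrative):
#   python benchmark.py --gold=./oie_corpus/all.oie --out=pr_curve.txt --tabbed=predictions.txt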
from __future__ import division
import docopt
import string
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc, roc_auc_score
import re
import logging
import pdb
import ipdb
from _collections import defaultdict
from carb.goldReader import GoldReader
from carb.gold_relabel import Relabel_GoldReader
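# NOTE: the Matcher class and the system readers used in __main__ below (StanfordReader,
# OllieReader, ReVerbReader, ...) are referenced but never imported in this snapshot.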
logging.basicConfig(level = logging.INFO)
from operator import itemgetter
import pprint
from copy import copy
pp = pprint.PrettyPrinter(indent=4)
class Benchmark:
''' Compare the gold OIE dataset against a predicted equivalent '''
def __init__(self, gold_fn):
''' Load gold Open IE, this will serve to compare against using the compare function '''
if 'Re-OIE2016' in gold_fn:
gr = Relabel_GoldReader()
else:
gr = GoldReader()
gr.read(gold_fn)
self.gold = gr.oie
def compare(self, predicted, matchingFunc, output_fn, error_file = None, binary=False):
''' Compare gold against predicted using a specified matching function.
Outputs PR curve to output_fn '''
y_true = []
y_scores = []
errors = []
correct = 0
incorrect = 0
correctTotal = 0
unmatchedCount = 0
predicted = Benchmark.normalizeDict(predicted)
gold = Benchmark.normalizeDict(self.gold)
if binary:
predicted = Benchmark.binarize(predicted)
gold = Benchmark.binarize(gold)
#gold = self.gold
# taking all distinct values of confidences as thresholds
confidence_thresholds = set()
for sent in predicted:
for predicted_ex in predicted[sent]:
confidence_thresholds.add(predicted_ex.confidence)
confidence_thresholds = sorted(list(confidence_thresholds))
num_conf = len(confidence_thresholds)
results = {}
p = np.zeros(num_conf)
pl = np.zeros(num_conf)
r = np.zeros(num_conf)
rl = np.zeros(num_conf)
for sent, goldExtractions in gold.items():
if sent in predicted:
predictedExtractions = predicted[sent]
else:
predictedExtractions = []
scores = [[None for _ in predictedExtractions] for __ in goldExtractions]
# print("***Gold Extractions***")
# print("\n".join([goldExtractions[i].pred + ' ' + " ".join(goldExtractions[i].args) for i in range(len(goldExtractions))]))
# print("***Predicted Extractions***")
# print("\n".join([predictedExtractions[i].pred+ " ".join(predictedExtractions[i].args) for i in range(len(predictedExtractions))]))
for i, goldEx in enumerate(goldExtractions):
for j, predictedEx in enumerate(predictedExtractions):
score = matchingFunc(goldEx, predictedEx,ignoreStopwords = True,ignoreCase = True)
scores[i][j] = score
# OPTIMISED GLOBAL MATCH
sent_confidences = [extraction.confidence for extraction in predictedExtractions]
sent_confidences.sort()
prev_c = 0
for conf in sent_confidences:
c = confidence_thresholds.index(conf)
ext_indices = []
for ext_indx, extraction in enumerate(predictedExtractions):
if extraction.confidence >= conf:
ext_indices.append(ext_indx)
recall_numerator = 0
for i, row in enumerate(scores):
max_recall_row = max([row[ext_indx][1] for ext_indx in ext_indices ], default=0)
recall_numerator += max_recall_row
precision_numerator = 0
selected_rows = []
selected_cols = []
num_precision_matches = min(len(scores), len(ext_indices))
for t in range(num_precision_matches):
matched_row = -1
matched_col = -1
matched_precision = -1 # initialised to <0 so that it updates whenever precision is 0 as well
for i in range(len(scores)):
if i in selected_rows:
continue
for ext_indx in ext_indices:
if ext_indx in selected_cols:
continue
if scores[i][ext_indx][0] > matched_precision:
matched_precision = scores[i][ext_indx][0]
matched_row = i
matched_col = ext_indx
selected_rows.append(matched_row)
selected_cols.append(matched_col)
precision_numerator += scores[matched_row][matched_col][0]
p[prev_c:c+1] += precision_numerator
pl[prev_c:c+1] += len(ext_indices)
r[prev_c:c+1] += recall_numerator
rl[prev_c:c+1] += len(scores)
prev_c = c+1
# for indices beyond the maximum sentence confidence, len(scores) has to be added to the denominator of recall
rl[prev_c:] += len(scores)
prec_scores = [a/b if b>0 else 1 for a,b in zip(p,pl) ]
rec_scores = [a/b if b>0 else 0 for a,b in zip(r,rl)]
f1s = [Benchmark.f1(p,r) for p,r in zip(prec_scores, rec_scores)]
try:
optimal_idx = np.nanargmax(f1s)
optimal = (prec_scores[optimal_idx], rec_scores[optimal_idx], f1s[optimal_idx])
except ValueError:
# When there is no prediction
optimal = (0,0,0)
# In order to calculate auc, we need to add the point corresponding to precision=1, recall=0 to the PR curve
temp_rec_scores = rec_scores.copy()
temp_prec_scores = prec_scores.copy()
temp_rec_scores.append(0)
temp_prec_scores.append(1)
# print("AUC: {}\t Optimal (precision, recall, F1): {}".format( np.round(auc(temp_rec_scores, temp_prec_scores),3), np.round(optimal,3) ))
with open(output_fn, 'w') as fout:
fout.write('{0}\t{1}\t{2}\n'.format("Precision", "Recall", "Confidence"))
for cur_p, cur_r, cur_conf in sorted(zip(prec_scores, rec_scores, confidence_thresholds), key = lambda cur: cur[1]):
fout.write('{0}\t{1}\t{2}\n'.format(cur_p, cur_r, cur_conf))
if len(f1s)>0:
rec_prec_dict = {rec: prec for rec, prec in zip(temp_rec_scores, temp_prec_scores)}
rec_prec_dict = sorted(rec_prec_dict.items(), key=lambda x: x[0])
temp_rec_scores = [rec for rec, _ in rec_prec_dict]
temp_prec_scores = [prec for _, prec in rec_prec_dict]
return np.round(auc(temp_rec_scores, temp_prec_scores),5), np.round(optimal,5)
else:
# When there is no prediction
return 0, (0,0,0)
@staticmethod
def binarize(extrs):
res = defaultdict(lambda: [])
for sent,extr in extrs.items():
for ex in extr:
#Add (a1, r, a2)
temp = copy(ex)
temp.args = ex.args[:2]
res[sent].append(temp)
if len(ex.args) <= 2:
continue
#Add (a1, r a2 , a3 ...)
for arg in ex.args[2:]:
temp.args = [ex.args[0]]
temp.pred = ex.pred + ' ' + ex.args[1]
words = arg.split()
#Add preposition of arg to rel
if words[0].lower() in Benchmark.PREPS:
temp.pred += ' ' + words[0]
words = words[1:]
temp.args.append(' '.join(words))
res[sent].append(temp)
return res
@staticmethod
def f1(prec, rec):
try:
return 2*prec*rec / (prec+rec)
except ZeroDivisionError:
return 0
@staticmethod
def aggregate_scores_greedily(scores):
# Greedy match: pick the prediction/gold match with the best f1 and exclude
# them both, until nothing left matches. Each input square is a [prec, rec]
# pair. Returns precision and recall as score-and-denominator pairs.
matches = []
while True:
max_s = 0
gold, pred = None, None
for i, gold_ss in enumerate(scores):
if i in [m[0] for m in matches]:
# Those are already taken rows
continue
for j, pred_s in enumerate(scores[i]):
if j in [m[1] for m in matches]:
# Those are used columns
continue
if pred_s and Benchmark.f1(*pred_s) > max_s:
max_s = Benchmark.f1(*pred_s)
gold = i
pred = j
if max_s == 0:
break
matches.append([gold, pred])
# Now that matches are determined, compute final scores.
prec_scores = [scores[i][j][0] for i,j in matches]
rec_scores = [scores[i][j][1] for i,j in matches]
total_prec = sum(prec_scores)
total_rec = sum(rec_scores)
scoring_metrics = {"precision" : [total_prec, len(scores[0])],
"recall" : [total_rec, len(scores)],
"precision_of_matches" : prec_scores,
"recall_of_matches" : rec_scores
}
return scoring_metrics
# Helper functions:
@staticmethod
def normalizeDict(d):
return dict([(Benchmark.normalizeKey(k), v) for k, v in d.items()])
@staticmethod
def normalizeKey(k):
# return Benchmark.removePunct(unicode(Benchmark.PTB_unescape(k.replace(' ','')), errors = 'ignore'))
return Benchmark.removePunct(str(Benchmark.PTB_unescape(k.replace(' ',''))))
@staticmethod
def PTB_escape(s):
for u, e in Benchmark.PTB_ESCAPES:
s = s.replace(u, e)
return s
@staticmethod
def PTB_unescape(s):
for u, e in Benchmark.PTB_ESCAPES:
s = s.replace(e, u)
return s
@staticmethod
def removePunct(s):
return Benchmark.regex.sub('', s)
# CONSTANTS
regex = re.compile('[%s]' % re.escape(string.punctuation))
# Penn treebank bracket escapes
# Taken from: https://github.com/nlplab/brat/blob/master/server/src/gtbtokenize.py
PTB_ESCAPES = [('(', '-LRB-'),
(')', '-RRB-'),
('[', '-LSB-'),
(']', '-RSB-'),
('{', '-LCB-'),
('}', '-RCB-'),]
PREPS = ['above','across','against','along','among','around','at','before','behind','below','beneath','beside','between','by','for','from','in','into','near','of','off','on','to','toward','under','upon','with','within']
def f_beta(precision, recall, beta = 1):
"""
Get F_beta score from precision and recall.
"""
beta = float(beta) # Make sure that results are in float
return (1 + pow(beta, 2)) * (precision * recall) / ((pow(beta, 2) * precision) + recall)
if __name__ == '__main__':
args = docopt.docopt(__doc__)
logging.debug(args)
if args['--stanford']:
predicted = StanfordReader()
predicted.read(args['--stanford'])
if args['--props']:
predicted = PropSReader()
predicted.read(args['--props'])
if args['--ollie']:
predicted = OllieReader()
predicted.read(args['--ollie'])
if args['--reverb']:
predicted = ReVerbReader()
predicted.read(args['--reverb'])
if args['--clausie']:
predicted = ClausieReader()
predicted.read(args['--clausie'])
if args['--openiefour']:
predicted = OpenieFourReader()
predicted.read(args['--openiefour'])
if args['--openiefive']:
predicted = OpenieFiveReader()
predicted.read(args['--openiefive'])
if args['--benchmarkGold']:
predicted = BenchmarkGoldReader()
predicted.read(args['--benchmarkGold'])
if args['--tabbed']:
predicted = TabReader()
predicted.read(args['--tabbed'])
if args['--binaryMatch']:
matchingFunc = Matcher.binary_tuple_match
elif args['--simpleMatch']:
matchingFunc = Matcher.simple_tuple_match
elif args['--exactMatch']:
matchingFunc = Matcher.argMatch
elif args['--predMatch']:
matchingFunc = Matcher.predMatch
elif args['--lexicalMatch']:
matchingFunc = Matcher.lexicalMatch
elif args['--strictMatch']:
matchingFunc = Matcher.tuple_match
else:
matchingFunc = Matcher.binary_linient_tuple_match
b = Benchmark(args['--gold'])
out_filename = args['--out']
logging.info("Writing PR curve of {} to {}".format(predicted.name, out_filename))
auc, optimal_f1_point = b.compare(predicted = predicted.oie,
matchingFunc = matchingFunc,
output_fn = out_filename,
error_file = args["--error-file"],
binary = args["--binary"])
print("AUC: {}\t Optimal (precision, recall, F1): {}".format( auc, optimal_f1_point ))

View File

@@ -1,444 +0,0 @@
from sklearn.preprocessing.data import binarize
from carb.argument import Argument
from operator import itemgetter
from collections import defaultdict
import nltk
import itertools
import logging
import numpy as np
import pdb
class Extraction:
"""
Stores sentence, single predicate and corresponding arguments.
"""
def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
self.pred = pred
self.head_pred_index = head_pred_index
self.sent = sent
self.args = []
self.confidence = confidence
self.matched = []
self.questions = {}
self.indsForQuestions = defaultdict(lambda: set())
self.is_mwp = False
self.question_dist = question_dist
self.index = index
def distArgFromPred(self, arg):
assert(len(self.pred) == 2)
dists = []
for x in self.pred[1]:
for y in arg.indices:
dists.append(abs(x - y))
return min(dists)
def argsByDistFromPred(self, question):
return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))
def addArg(self, arg, question = None):
self.args.append(arg)
if question:
self.questions[question] = self.questions.get(question,[]) + [Argument(arg)]
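# Usage sketch (illustrative values): build an extraction for "Obama visited Paris"
# and attach its arguments.
#   ext = Extraction(pred='visited', head_pred_index=1,
#                    sent='Obama visited Paris', confidence=1.0)
#   ext.addArg('Obama')
#   ext.addArg('Paris')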
def noPronounArgs(self):
"""
Returns True iff all of this extraction's arguments are not pronouns.
"""
for (a, _) in self.args:
tokenized_arg = nltk.word_tokenize(a)
if len(tokenized_arg) == 1:
_, pos_tag = nltk.pos_tag(tokenized_arg)[0]
if ('PRP' in pos_tag):
return False
return True
def isContiguous(self):
return all([indices for (_, indices) in self.args])
def toBinary(self):
    ''' Try to represent this extraction's arguments as binary.
    If that fails, this function will return an empty list. '''
    ret = [self.elementToStr(self.pred)]
    if len(self.args) == 2:
        # we're in luck
        return ret + [self.elementToStr(arg) for arg in self.args]
    if not self.isContiguous():
        # give up on non-contiguous arguments (as we need indexes)
        return []
    # otherwise, try to merge based on indices
    # TODO: you can explore other methods for doing this
    binarized = self.binarizeByIndex()
    if binarized:
        return ret + binarized
    return []
def elementToStr(self, elem, print_indices = True):
''' formats an extraction element (pred or arg) as a raw string
removes indices and trailing spaces '''
if print_indices:
return str(elem)
if isinstance(elem, str):
return elem
if isinstance(elem, tuple):
ret = elem[0].rstrip().lstrip()
else:
ret = ' '.join(elem.words)
assert ret, "empty element? {0}".format(elem)
return ret
def binarizeByIndex(self):
    extraction = [self.pred] + self.args
    markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
    # Sort elements by their first word index
    sortedExtraction = sorted(markPred, key=lambda e: e[1][0])
    s = ' '.join(['{1} {0} {1}'.format(self.elementToStr((w, ind), print_indices=False), SEP)
                  if is_pred
                  else self.elementToStr((w, ind), print_indices=False)
                  for (w, ind, is_pred) in sortedExtraction])
    # The two SEP fences around the predicate split the sequence into
    # (left args, predicate, right args)
    left, _, right = [chunk.strip() for chunk in s.split(SEP)]
    binArgs = [a for a in (left, right) if a]
    if len(binArgs) == 2:
        return binArgs
    # failure
    return []
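# Merge sketch (illustrative): with pred ('ate', [1]) and args [('John', [0]),
# ('an apple', [2, 3])], elements sort by first index, the predicate is fenced
# by SEP (';;;'), and the left/right remainders ['John', 'an apple'] become the
# two binary arguments.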
def bow(self):
return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])
def getSortedArgs(self):
    """
    Sort the list of arguments.
    If a question distribution is provided - use it,
    otherwise, default to the order of appearance in the sentence.
    """
    if self.question_dist:
        # There's a question distribution - use it
        return self.sort_args_by_distribution()
    ls = []
    for q, args in self.questions.items():
        if len(args) != 1:
            logging.debug("Not one argument: {}".format(args))
            continue
        arg = args[0]
        indices = list(self.indsForQuestions[q].union(arg.indices))
        if not indices:
            logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
            indices = [0]
        ls.append(((arg, q), indices))
    return [a for a, _ in sorted(ls, key=lambda t: min(t[1]))]
def question_prob_for_loc(self, question, loc):
"""
Returns the probability of the given question leading to argument
appearing in the given location in the output slot.
"""
gen_question = generalize_question(question)
q_dist = self.question_dist[gen_question]
logging.debug("distribution of {}: {}".format(gen_question,
q_dist))
return float(q_dist.get(loc, 0)) / \
sum(q_dist.values())
def sort_args_by_distribution(self):
    """
    Use this instance's question distribution (this func assumes it exists)
    in determining the positioning of the arguments.
    Greedy algorithm:
    0. Decide which argument will serve as the ``subject'' (first slot) of this
       extraction, based on the most probable question for that spot
       (special care is given to the highly influential subject position)
    1. Sort the remaining arguments by the prevalence of their questions
    2. For each argument:
       2.1 Assign it the most probable slot still available
       2.2 If none exists (fallback) - place it in the last location
    """
    INF_LOC = 100  # used as an impractical last location
    # Store arguments by slot
    ret = {INF_LOC: []}
    logging.debug("sorting: {}".format(self.questions))

    # Find the most suitable argument for the subject location
    logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
                                                  for (q, _) in self.questions.items()]))
    subj_question, subj_args = max(self.questions.items(),
                                   key=lambda qa: self.question_prob_for_loc(qa[0], 0))
    ret[0] = [(subj_args[0], subj_question)]

    # Find the rest
    for (question, args) in sorted([(q, a)
                                    for (q, a) in self.questions.items()
                                    if q != subj_question],
                                   key=lambda qa: sum(self.question_dist[generalize_question(qa[0])].values()),
                                   reverse=True):
        gen_question = generalize_question(question)
        arg = args[0]
        assigned_flag = False
        for (loc, count) in sorted(self.question_dist[gen_question].items(),
                                   key=lambda lc: lc[1],
                                   reverse=True):
            if loc not in ret:
                # Found an empty slot for this item - place it there and break out
                ret[loc] = [(arg, question)]
                assigned_flag = True
                break
        if not assigned_flag:
            # Add this argument to the non-assigned slot (hopefully doesn't happen much)
            logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
            ret[INF_LOC].append((arg, question))
    logging.debug("Linearizing arg list: {}".format(ret))

    # Finished iterating - consolidate and return a list of arguments
    return [arg
            for (_, arg_ls) in sorted(ret.items(), key=lambda kv: int(kv[0]))
            for arg in arg_ls]
def __str__(self):
pred_str = self.elementToStr(self.pred)
return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
self.compute_global_pred(pred_str,
self.questions.keys()),
'\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
question))
for arg, question in self.getSortedArgs()]))
def get_base_verb(self, surface_pred):
"""
Given the surface pred, return the original annotated verb
"""
# Assumes that at this point the verb is always the last word
# in the surface predicate
return surface_pred.split(' ')[-1]
def compute_global_pred(self, surface_pred, questions):
    """
    Given the surface pred and all instantiations of questions,
    make global coherence decisions regarding the final form of the predicate.
    This should hopefully take care of multi-word predicates and correct inflections.
    """
    split_surface = surface_pred.split(' ')
    if len(split_surface) > 1:
        # This predicate has a modal preceding the base verb
        verb = split_surface[-1]
        ret = split_surface[:-1]  # get all of the elements in the modal
    else:
        verb = split_surface[0]
        ret = []

    # Materialize as lists -- map() objects are single-use iterators in Python 3
    split_questions = [question.split(' ') for question in questions]
    preds = [normalize_element(q[QUESTION_TRG_INDEX]) for q in split_questions]
    if len(set(preds)) > 1:
        # This predicate appears in multiple ways, let's stick to the base form
        ret.append(verb)
    if len(set(preds)) == 1:
        # Change the predicate to the inflected form
        # if there's exactly one way in which the predicate is conveyed
        ret.append(preds[0])

    pps = [normalize_element(q[QUESTION_PP_INDEX]) for q in split_questions]
    obj2s = [normalize_element(q[QUESTION_OBJ2_INDEX]) for q in split_questions]
    if len(set(pps)) == 1:
        # If all questions for the predicate include the same pp attachment -
        # assume it's a multi-word predicate
        self.is_mwp = True  # signal to arguments that they shouldn't take the preposition
        ret.append(pps[0])

    # Concat all elements in the predicate and return
    return " ".join(ret).strip()
def augment_arg_with_question(self, arg, question):
"""
Decide what elements from the question to incorporate in the given
corresponding argument
"""
# Parse question
wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
question.split(' ')[:-1]) # Last split is the question mark
# Place preposition in argument
# This is safer when dealing with n-ary arguments, as it directly attaches the
# preposition to the appropriate argument
if (not self.is_mwp) and pp and (not obj2):
    if not arg.startswith("{} ".format(pp)):
        # Avoid repeating the preposition when both question and answer contain it
return " ".join([pp,
arg])
# Normal cases
return arg
def clusterScore(self, cluster):
"""
Calculate cluster density score as the mean distance of the maximum distance of each slot.
Lower score represents a denser cluster.
"""
logging.debug("*-*-*- Cluster: {}".format(cluster))
# Find global centroid
arr = np.array([x for ls in cluster for x in ls])
centroid = np.sum(arr)/arr.shape[0]
logging.debug("Centroid: {}".format(centroid))
# Calculate the mean over all maximum points
return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])
def resolveAmbiguity(self):
    """
    Heuristic to map the elements (arguments and predicate) of this extraction
    back to the indices of the sentence.
    """
    ## TODO: This removes arguments for which there was no consecutive span found
    ## Part of these are non-consecutive arguments,
    ## but others could be a bug in recognizing some punctuation marks
    elements = [self.pred] \
        + [(s, indices)
           for (s, indices) in self.args
           if indices]
    logging.debug("Resolving ambiguity in: {}".format(elements))

    # Collect all possible combinations of argument and predicate indices
    # (hopefully it's not too many)
    all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
    logging.debug("Number of combinations: {}".format(len(all_combinations)))

    # Choose the combination with the best clustering and unfold it
    resolved_elements = list(zip(map(itemgetter(0), elements),
                                 min(all_combinations,
                                     key=lambda cluster: self.clusterScore(cluster))))
    logging.debug("Resolved elements = {}".format(resolved_elements))
    self.pred = resolved_elements[0]
    self.args = resolved_elements[1:]
def conll(self, external_feats=None):
    """
    Return a CoNLL string representation of this extraction
    """
    # Use a fresh list default -- concatenating a dict (the old default) would fail
    external_feats = external_feats if external_feats is not None else []
    return '\n'.join(["\t".join(map(str,
                                    [i, w]
                                    + list(self.pred)
                                    + [self.head_pred_index]
                                    + external_feats
                                    + [self.get_label(i)]))
                      for (i, w) in enumerate(self.sent.split(" "))]) + '\n'
def get_label(self, index):
    """
    Given an index of a word in the sentence -- returns the appropriate BIO conll label.
    Assumes that ambiguity was already resolved.
    """
    # Get the element(s) in which this index appears
    ent = [(elem_ind, elem)
           for (elem_ind, elem) in enumerate(map(itemgetter(1), [self.pred] + self.args))
           if index in elem]
    if not ent:
        # index doesn't appear in any element
        return "O"
    if len(ent) > 1:
        # The same word appears in two different answers
        # In this case we choose the first one as label
        logging.warning("Index {} appears in more than one element: {}".format(
            index,
            "\t".join(map(str, [ent, self.sent, self.pred, self.args]))))
    ## Some indices appear in more than one argument (ones where the above message appears)
    ## From empirical observation, these seem to mostly consist of different levels of granularity:
    ##     what had _ been taken _ _ _ ?    loan commitments topping $ 3 billion
    ##     how much had _ been taken _ _ _ ?    topping $ 3 billion
    ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
    ## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)
    elem_ind, elem = min(ent, key=lambda e: len(e[1]))

    # Distinguish between predicate and arguments
    prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)
    # Distinguish between Beginning and Inside labels
    suffix = "B" if index == elem[0] else "I"
    return "{}-{}".format(prefix, suffix)
def __str__(self):
return '{0}\t{1}'.format(self.elementToStr(self.pred,
print_indices = True),
'\t'.join([self.elementToStr(arg)
for arg
in self.args]))
# Flatten a list of lists
flatten = lambda l: [item for sublist in l for item in sublist]
def normalize_element(elem):
    """
    Return a surface form of the given question element.
    The output should be able to properly precede a predicate (or be blank otherwise).
    """
    return elem.replace("_", " ") if elem != "_" else ""
## Helper functions
def escape_special_chars(s):
return s.replace('\t', '\\t')
def generalize_question(question):
    """
    Given a question in the context of the sentence and the predicate index within
    the question - return a generalized version which extracts only order-imposing features
    """
    wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1]  # last split is the question mark
    return ' '.join([wh, sbj, obj1])
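# Sketch with an assumed QA-SRL-style question "what did _ take _ _ _ ?":
# the seven slots are (wh, aux, sbj, trg, obj1, pp, obj2), and the generalized
# form keeps only (wh, sbj, obj1), here "what _ _".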
## CONSTANTS
SEP = ';;;'
QUESTION_TRG_INDEX = 3 # index of the predicate within the question
QUESTION_PP_INDEX = 5
QUESTION_OBJ2_INDEX = 6

View File

@@ -1,53 +0,0 @@
from carb.oieReader import OieReader
from carb.extraction import Extraction
from collections import defaultdict


class GoldReader(OieReader):
    # Path relative to repo root folder
    default_filename = './oie_corpus/all.oie'

    def __init__(self):
        self.name = 'Gold'

    def read(self, fn):
        d = defaultdict(list)
        # Spanish gold files ship in latin-1; everything else uses the default encoding
        encoding = 'latin-1' if 'spanish' in fn else None
        with open(fn, encoding=encoding) as fin:
            for line_ind, line in enumerate(fin):
                data = line.strip().split('\t')
                text, rel = data[:2]
                args = data[2:]
                confidence = 1
                curExtraction = Extraction(pred=rel.strip(),
                                           head_pred_index=None,
                                           sent=text.strip(),
                                           confidence=float(confidence),
                                           index=line_ind)
                for arg in args:
                    if "C: " in arg:
                        # Skip context ("C:") pseudo-arguments
                        continue
                    curExtraction.addArg(arg.strip())
                d[text.strip()].append(curExtraction)
        self.oie = d
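# Input sketch (tab-separated, illustrative line):
#   Obama visited Paris .\tvisited\tObama\tParis
# becomes one Extraction with pred='visited' and args=['Obama', 'Paris'].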
if __name__ == '__main__':
    g = GoldReader()
    g.read('../oie_corpus/all.oie')
    d = g.oie
    e = list(d.items())[0]
    print(e[1][0].bow())
    print(g.count())

View File

@@ -1,46 +0,0 @@
from carb.oieReader import OieReader
from carb.extraction import Extraction
from collections import defaultdict
import json
class Relabel_GoldReader(OieReader):
# Path relative to repo root folder
default_filename = './oie_corpus/all.oie'
def __init__(self):
self.name = 'Relabel_Gold'
def read(self, fn):
d = defaultdict(lambda: [])
with open(fn) as fin:
data = json.load(fin)
for sentence in data:
tuples = data[sentence]
for t in tuples:
if t["pred"].strip() == "<be>":
rel = "[is]"
else:
rel = t["pred"].replace("<be> ","")
confidence = 1
curExtraction = Extraction(pred = rel,
head_pred_index = None,
sent = sentence,
confidence = float(confidence),
index = None)
if t["arg0"] != "":
curExtraction.addArg(t["arg0"])
if t["arg1"] != "":
curExtraction.addArg(t["arg1"])
if t["arg2"] != "":
curExtraction.addArg(t["arg2"])
if t["arg3"] != "":
curExtraction.addArg(t["arg3"])
if t["temp"] != "":
curExtraction.addArg(t["temp"])
if t["loc"] != "":
curExtraction.addArg(t["loc"])
d[sentence].append(curExtraction)
self.oie = d

View File

@@ -1,340 +0,0 @@
import string
from copy import copy

from nltk.translate.bleu_score import sentence_bleu
from nltk.corpus import stopwords
class Matcher:
@staticmethod
def bowMatch(ref, ex, ignoreStopwords, ignoreCase):
"""
A binary function testing for exact lexical match (ignoring ordering) between reference
and predicted extraction
"""
s1 = ref.bow()
s2 = ex.bow()
if ignoreCase:
s1 = s1.lower()
s2 = s2.lower()
s1Words = s1.split(' ')
s2Words = s2.split(' ')
if ignoreStopwords:
s1Words = Matcher.removeStopwords(s1Words)
s2Words = Matcher.removeStopwords(s2Words)
return sorted(s1Words) == sorted(s2Words)
@staticmethod
def predMatch(ref, ex, ignoreStopwords, ignoreCase):
"""
Return whether gold and predicted extractions agree on the predicate
"""
s1 = ref.elementToStr(ref.pred)
s2 = ex.elementToStr(ex.pred)
if ignoreCase:
s1 = s1.lower()
s2 = s2.lower()
s1Words = s1.split(' ')
s2Words = s2.split(' ')
if ignoreStopwords:
s1Words = Matcher.removeStopwords(s1Words)
s2Words = Matcher.removeStopwords(s2Words)
return s1Words == s2Words
@staticmethod
def argMatch(ref, ex, ignoreStopwords, ignoreCase):
    """
    Return whether gold and predicted extractions agree on the arguments
    """
    # Token-level overlap between reference and predicted argument strings
    sRef = ' '.join([ref.elementToStr(elem) for elem in ref.args]).split(' ')
    sEx = ' '.join([ex.elementToStr(elem) for elem in ex.args]).split(' ')
    count = 0
    for w1 in sRef:
        for w2 in sEx:
            if w1 == w2:
                count += 1
    # We check how well the extraction lexically covers the reference
    # Note: this is somewhat lenient as it doesn't penalize the extraction for
    # being too long
    coverage = float(count) / len(sRef)
    return coverage > Matcher.LEXICAL_THRESHOLD
@staticmethod
def bleuMatch(ref, ex, ignoreStopwords, ignoreCase):
sRef = ref.bow()
sEx = ex.bow()
bleu = sentence_bleu(references = [sRef.split(' ')], hypothesis = sEx.split(' '))
return bleu > Matcher.BLEU_THRESHOLD
@staticmethod
def lexicalMatch(ref, ex, ignoreStopwords, ignoreCase):
sRef = ref.bow().split(' ')
sEx = ex.bow().split(' ')
count = 0
for w1 in sRef:
for w2 in sEx:
if w1 == w2:
count += 1
# We check how well does the extraction lexically cover the reference
# Note: this is somewhat lenient as it doesn't penalize the extraction for
# being too long
coverage = float(count) / len(sRef)
return coverage > Matcher.LEXICAL_THRESHOLD
@staticmethod
def tuple_match(ref, ex, ignoreStopwords, ignoreCase):
precision = [0, 0] # 0 out of 0 predicted words match
recall = [0, 0] # 0 out of 0 reference words match
# If, for each part, any word is the same as a reference word, then it's a match.
predicted_words = ex.pred.split()
gold_words = ref.pred.split()
precision[1] += len(predicted_words)
recall[1] += len(gold_words)
# matching_words = sum(1 for w in predicted_words if w in gold_words)
matching_words = 0
for w in gold_words:
if w in predicted_words:
matching_words += 1
predicted_words.remove(w)
if matching_words == 0:
return False # t <-> gt is not a match
precision[0] += matching_words
recall[0] += matching_words
for i in range(len(ref.args)):
gold_words = ref.args[i].split()
recall[1] += len(gold_words)
if len(ex.args) <= i:
if i<2:
return False
else:
continue
predicted_words = ex.args[i].split()
precision[1] += len(predicted_words)
matching_words = 0
for w in gold_words:
if w in predicted_words:
matching_words += 1
predicted_words.remove(w)
if matching_words == 0 and i<2:
return False # t <-> gt is not a match
precision[0] += matching_words
# Currently this slightly penalises systems when the reference
# reformulates the sentence words, because the reformulation doesn't
# match the predicted word. It's a one-wrong-word penalty to precision,
# to all systems that correctly extracted the reformulated word.
recall[0] += matching_words
prec = 1.0 * precision[0] / precision[1]
rec = 1.0 * recall[0] / recall[1]
return [prec, rec]
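# Scoring sketch (illustrative; args shown as plain strings): gold
# ('gave', ['John', 'a book to Mary']) vs predicted ('gave', ['John', 'a book']):
# every part shares at least one word, and per-word counts give precision
# 4/4 = 1.0 and recall 4/6 ~= 0.67, so [1.0, 0.67] is returned instead of False.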
# STRICTER LENIENT MATCH (the "linient" spelling is kept in identifiers for compatibility)
@staticmethod
def linient_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
precision = [0, 0] # 0 out of 0 predicted words match
recall = [0, 0] # 0 out of 0 reference words match
# If, for each part, any word is the same as a reference word, then it's a match.
predicted_words = ex.pred.split()
gold_words = ref.pred.split()
precision[1] += len(predicted_words)
recall[1] += len(gold_words)
# matching_words = sum(1 for w in predicted_words if w in gold_words)
matching_words = 0
for w in gold_words:
if w in predicted_words:
matching_words += 1
predicted_words.remove(w)
# matching 'be' with its different forms
forms_of_be = ["be","is","am","are","was","were","been","being"]
if "be" in predicted_words:
for form in forms_of_be:
if form in gold_words:
matching_words += 1
predicted_words.remove("be")
break
if matching_words == 0:
return [0,0] # t <-> gt is not a match
precision[0] += matching_words
recall[0] += matching_words
for i in range(len(ref.args)):
gold_words = ref.args[i].split()
recall[1] += len(gold_words)
if len(ex.args) <= i:
if i<2:
return [0,0] # changed
else:
continue
predicted_words = ex.args[i].split()
precision[1] += len(predicted_words)
matching_words = 0
for w in gold_words:
if w in predicted_words:
matching_words += 1
predicted_words.remove(w)
precision[0] += matching_words
# Currently this slightly penalises systems when the reference
# reformulates the sentence words, because the reformulation doesn't
# match the predicted word. It's a one-wrong-word penalty to precision,
# to all systems that correctly extracted the reformulated word.
recall[0] += matching_words
if(precision[1] == 0):
prec = 0
else:
prec = 1.0 * precision[0] / precision[1]
if(recall[1] == 0):
rec = 0
else:
rec = 1.0 * recall[0] / recall[1]
return [prec, rec]
@staticmethod
def simple_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
    # Work on shallow copies so the shared gold/predicted extractions are not mutated in place
    ref = copy(ref)
    ex = copy(ex)
    ref.args = [ref.args[0], ' '.join(ref.args[1:])]
    ex.args = [ex.args[0], ' '.join(ex.args[1:])]

    precision = [0, 0]  # 0 out of 0 predicted words match
    recall = [0, 0]  # 0 out of 0 reference words match
    # If, for each part, any word is the same as a reference word, then it's a match.
    predicted_words = ex.pred.split()
    gold_words = ref.pred.split()
    precision[1] += len(predicted_words)
    recall[1] += len(gold_words)

    matching_words = 0
    for w in gold_words:
        if w in predicted_words:
            matching_words += 1
            predicted_words.remove(w)
    precision[0] += matching_words
    recall[0] += matching_words

    for i in range(len(ref.args)):
        gold_words = ref.args[i].split()
        recall[1] += len(gold_words)
        if len(ex.args) <= i:
            break
        predicted_words = ex.args[i].split()
        precision[1] += len(predicted_words)
        matching_words = 0
        for w in gold_words:
            if w in predicted_words:
                matching_words += 1
                predicted_words.remove(w)
        precision[0] += matching_words
        # Currently this slightly penalises systems when the reference
        # reformulates the sentence words, because the reformulation doesn't
        # match the predicted word. It's a one-wrong-word penalty to precision,
        # to all systems that correctly extracted the reformulated word.
        recall[0] += matching_words

    prec = 1.0 * precision[0] / precision[1]
    rec = 1.0 * recall[0] / recall[1]
    return [prec, rec]
@staticmethod
def binary_linient_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
    if len(ref.args) >= 2:
        r = copy(ref)
        r.args = [ref.args[0], ' '.join(ref.args[1:])]
    else:
        r = ref
    if len(ex.args) >= 2:
        e = copy(ex)
        e.args = [ex.args[0], ' '.join(ex.args[1:])]
    else:
        e = ex
    straight_match = Matcher.linient_tuple_match(r, e, ignoreStopwords, ignoreCase)

    said_type_reln = ["said", "told", "added", "adds", "says"]
    said_type_sentence = any(said_verb in ref.pred for said_verb in said_type_reln)
    if not said_type_sentence:
        return straight_match
    # For reported-speech predicates, also try swapping the two argument slots
    # and keep the better score
    if len(ex.args) >= 2:
        e = copy(ex)
        e.args = [' '.join(ex.args[1:]), ex.args[0]]
    else:
        e = ex
    reverse_match = Matcher.linient_tuple_match(r, e, ignoreStopwords, ignoreCase)
    return max(straight_match, reverse_match)
@staticmethod
def binary_tuple_match(ref, ex, ignoreStopwords, ignoreCase):
    if len(ref.args) >= 2:
        r = copy(ref)
        r.args = [ref.args[0], ' '.join(ref.args[1:])]
    else:
        r = ref
    if len(ex.args) >= 2:
        e = copy(ex)
        e.args = [ex.args[0], ' '.join(ex.args[1:])]
    else:
        e = ex
    return Matcher.tuple_match(r, e, ignoreStopwords, ignoreCase)
@staticmethod
def removeStopwords(ls):
return [w for w in ls if w.lower() not in Matcher.stopwords]
# CONSTANTS
BLEU_THRESHOLD = 0.4
LEXICAL_THRESHOLD = 0.5 # Note: changing this value didn't change the ordering of the tested systems
stopwords = stopwords.words('english') + list(string.punctuation)

View File

@@ -1,45 +0,0 @@
class OieReader:
def read(self, fn, includeNominal):
''' should set oie as a class member
as a dictionary of extractions by sentence'''
raise Exception("Don't run me")
def count(self):
''' number of extractions '''
return sum([len(extractions) for _, extractions in self.oie.items()])
def split_to_corpus(self, corpus_fn, out_fn):
"""
Given a corpus file name, containing a list of sentences
print only the extractions pertaining to it to out_fn in a tab separated format:
sent, prob, pred, arg1, arg2, ...
"""
raw_sents = [line.strip() for line in open(corpus_fn)]
with open(out_fn, 'w') as fout:
for line in self.get_tabbed().split('\n'):
data = line.split('\t')
sent = data[0]
if sent in raw_sents:
fout.write(line + '\n')
def output_tabbed(self, out_fn):
"""
Write a tabbed representation of this corpus.
"""
with open(out_fn, 'w') as fout:
fout.write(self.get_tabbed())
def get_tabbed(self):
"""
Get a tabbed format representation of this corpus (assumes that input was
already read).
"""
return "\n".join(['\t'.join(map(str,
[ex.sent,
ex.confidence,
ex.pred,
'\t'.join(ex.args)]))
for (sent, exs) in self.oie.iteritems()
for ex in exs])

View File

@@ -1,21 +0,0 @@
import nltk
from operator import itemgetter
class Argument:
def __init__(self, arg):
self.words = [x for x in arg[0].strip().split(' ') if x]
self.posTags = [tag for (_, tag) in nltk.pos_tag(self.words)]  # materialize; map() is lazy in Python 3
self.indices = arg[1]
self.feats = {}
def __str__(self):
return "({})".format('\t'.join(map(str,
[escape_special_chars(' '.join(self.words)),
str(self.indices)])))
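# Construction sketch (illustrative values): an argument is a (text, indices) pair,
#   Argument(("the red car", [3, 4, 5]))
# splits the text into words and POS-tags them with nltk.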
COREF = 'coref'
## Helper functions
def escape_special_chars(s):
return s.replace('\t', '\\t')

View File

@@ -1,55 +0,0 @@
""" Usage:
benchmarkGoldReader --in=INPUT_FILE
Read a tab-formatted file.
Each line consists of:
sent, prob, pred, arg1, arg2, ...
"""
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
from docopt import docopt
import logging
logging.basicConfig(level = logging.DEBUG)
class BenchmarkGoldReader(OieReader):
def __init__(self):
self.name = 'BenchmarkGoldReader'
def read(self, fn):
"""
Read a tabbed format line
Each line consists of:
sent, prob, pred, arg1, arg2, ...
"""
d = {}
ex_index = 0
with open(fn) as fin:
for line in fin:
if not line.strip():
continue
data = line.strip().split('\t')
text, rel = data[:2]
curExtraction = Extraction(pred = rel.strip(),
head_pred_index = None,
sent = text.strip(),
confidence = 1.0,
question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
index = ex_index)
ex_index += 1
for arg in data[2:]:
curExtraction.addArg(arg.strip())
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
if __name__ == "__main__":
args = docopt(__doc__)
input_fn = args["--in"]
tr = BenchmarkGoldReader()
tr.read(input_fn)

View File

@@ -1,90 +0,0 @@
""" Usage:
<file-name> --in=INPUT_FILE --out=OUTPUT_FILE [--debug]
Convert to tabbed format
"""
# External imports
import logging
from pprint import pprint
from pprint import pformat
from docopt import docopt
# Local imports
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
#=-----
class ClausieReader(OieReader):
def __init__(self):
self.name = 'ClausIE'
def read(self, fn):
d = {}
with open(fn, encoding="utf-8") as fin:
for line in fin:
data = line.strip().split('\t')
if len(data) == 1:
text = data[0]
elif len(data) == 5:
arg1, rel, arg2 = [s[1:-1] for s in data[1:4]]
confidence = data[4]
curExtraction = Extraction(pred = rel,
head_pred_index = -1,
sent = text,
confidence = float(confidence))
curExtraction.addArg(arg1)
curExtraction.addArg(arg2)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
def normalizeConfidence(self):
''' Normalize confidence to resemble probabilities '''
EPSILON = 1e-3
confidences = [extraction.confidence for sent in self.oie for extraction in self.oie[sent]]
maxConfidence = max(confidences)
minConfidence = min(confidences)
denom = maxConfidence - minConfidence + (2*EPSILON)
for sent, extractions in self.oie.items():
for extraction in extractions:
extraction.confidence = ( (extraction.confidence - minConfidence) + EPSILON) / denom
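# Numeric sketch (illustrative): confidences [2.0, 4.0] with EPSILON=1e-3 map to
# ((c - 2.0) + 1e-3) / (2.0 + 2e-3), i.e. ~0.0005 and ~0.9995, keeping the
# normalized scores strictly inside (0, 1).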
if __name__ == "__main__":
# Parse command line arguments
args = docopt(__doc__)
inp_fn = args["--in"]
out_fn = args["--out"]
debug = args["--debug"]
if debug:
logging.basicConfig(level = logging.DEBUG)
else:
logging.basicConfig(level = logging.INFO)
oie = ClausieReader()
oie.read(inp_fn)
oie.output_tabbed(out_fn)
logging.info("DONE")

View File

@@ -1,444 +0,0 @@
from operator import itemgetter
from collections import defaultdict
import itertools
import logging

import nltk
import numpy as np

from oie_readers.argument import Argument
class Extraction:
"""
Stores sentence, single predicate and corresponding arguments.
"""
def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
self.pred = pred
self.head_pred_index = head_pred_index
self.sent = sent
self.args = []
self.confidence = confidence
self.matched = []
self.questions = {}
self.indsForQuestions = defaultdict(lambda: set())
self.is_mwp = False
self.question_dist = question_dist
self.index = index
def distArgFromPred(self, arg):
assert(len(self.pred) == 2)
dists = []
for x in self.pred[1]:
for y in arg.indices:
dists.append(abs(x - y))
return min(dists)
def argsByDistFromPred(self, question):
return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))
def addArg(self, arg, question = None):
self.args.append(arg)
if question:
self.questions[question] = self.questions.get(question,[]) + [Argument(arg)]
def noPronounArgs(self):
"""
Returns True iff all of this extraction's arguments are not pronouns.
"""
for (a, _) in self.args:
tokenized_arg = nltk.word_tokenize(a)
if len(tokenized_arg) == 1:
_, pos_tag = nltk.pos_tag(tokenized_arg)[0]
if ('PRP' in pos_tag):
return False
return True
def isContiguous(self):
return all([indices for (_, indices) in self.args])
def toBinary(self):
    ''' Try to represent this extraction's arguments as binary.
    If that fails, this function will return an empty list. '''
    ret = [self.elementToStr(self.pred)]
    if len(self.args) == 2:
        # we're in luck
        return ret + [self.elementToStr(arg) for arg in self.args]
    if not self.isContiguous():
        # give up on non-contiguous arguments (as we need indexes)
        return []
    # otherwise, try to merge based on indices
    # TODO: you can explore other methods for doing this
    binarized = self.binarizeByIndex()
    if binarized:
        return ret + binarized
    return []
def elementToStr(self, elem, print_indices = True):
''' formats an extraction element (pred or arg) as a raw string
removes indices and trailing spaces '''
if print_indices:
return str(elem)
if isinstance(elem, str):
return elem
if isinstance(elem, tuple):
ret = elem[0].rstrip().lstrip()
else:
ret = ' '.join(elem.words)
assert ret, "empty element? {0}".format(elem)
return ret
def binarizeByIndex(self):
    extraction = [self.pred] + self.args
    markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
    # Sort elements by their first word index
    sortedExtraction = sorted(markPred, key=lambda e: e[1][0])
    s = ' '.join(['{1} {0} {1}'.format(self.elementToStr((w, ind), print_indices=False), SEP)
                  if is_pred
                  else self.elementToStr((w, ind), print_indices=False)
                  for (w, ind, is_pred) in sortedExtraction])
    # The two SEP fences around the predicate split the sequence into
    # (left args, predicate, right args)
    left, _, right = [chunk.strip() for chunk in s.split(SEP)]
    binArgs = [a for a in (left, right) if a]
    if len(binArgs) == 2:
        return binArgs
    # failure
    return []
def bow(self):
return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])
def getSortedArgs(self):
    """
    Sort the list of arguments.
    If a question distribution is provided - use it,
    otherwise, default to the order of appearance in the sentence.
    """
    if self.question_dist:
        # There's a question distribution - use it
        return self.sort_args_by_distribution()
    ls = []
    for q, args in self.questions.items():
        if len(args) != 1:
            logging.debug("Not one argument: {}".format(args))
            continue
        arg = args[0]
        indices = list(self.indsForQuestions[q].union(arg.indices))
        if not indices:
            logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
            indices = [0]
        ls.append(((arg, q), indices))
    return [a for a, _ in sorted(ls, key=lambda t: min(t[1]))]
def question_prob_for_loc(self, question, loc):
"""
Returns the probability of the given question leading to argument
appearing in the given location in the output slot.
"""
gen_question = generalize_question(question)
q_dist = self.question_dist[gen_question]
logging.debug("distribution of {}: {}".format(gen_question,
q_dist))
return float(q_dist.get(loc, 0)) / \
sum(q_dist.values())
def sort_args_by_distribution(self):
    """
    Use this instance's question distribution (this func assumes it exists)
    in determining the positioning of the arguments.
    Greedy algorithm:
    0. Decide which argument will serve as the ``subject'' (first slot) of this
       extraction, based on the most probable question for that spot
       (special care is given to the highly influential subject position)
    1. Sort the remaining arguments by the prevalence of their questions
    2. For each argument:
       2.1 Assign it the most probable slot still available
       2.2 If none exists (fallback) - place it in the last location
    """
    INF_LOC = 100  # used as an impractical last location
    # Store arguments by slot
    ret = {INF_LOC: []}
    logging.debug("sorting: {}".format(self.questions))

    # Find the most suitable argument for the subject location
    logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
                                                  for (q, _) in self.questions.items()]))
    subj_question, subj_args = max(self.questions.items(),
                                   key=lambda qa: self.question_prob_for_loc(qa[0], 0))
    ret[0] = [(subj_args[0], subj_question)]

    # Find the rest
    for (question, args) in sorted([(q, a)
                                    for (q, a) in self.questions.items()
                                    if q != subj_question],
                                   key=lambda qa: sum(self.question_dist[generalize_question(qa[0])].values()),
                                   reverse=True):
        gen_question = generalize_question(question)
        arg = args[0]
        assigned_flag = False
        for (loc, count) in sorted(self.question_dist[gen_question].items(),
                                   key=lambda lc: lc[1],
                                   reverse=True):
            if loc not in ret:
                # Found an empty slot for this item - place it there and break out
                ret[loc] = [(arg, question)]
                assigned_flag = True
                break
        if not assigned_flag:
            # Add this argument to the non-assigned slot (hopefully doesn't happen much)
            logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
            ret[INF_LOC].append((arg, question))
    logging.debug("Linearizing arg list: {}".format(ret))

    # Finished iterating - consolidate and return a list of arguments
    return [arg
            for (_, arg_ls) in sorted(ret.items(), key=lambda kv: int(kv[0]))
            for arg in arg_ls]
def __str__(self):
pred_str = self.elementToStr(self.pred)
return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
self.compute_global_pred(pred_str,
self.questions.keys()),
'\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
question))
for arg, question in self.getSortedArgs()]))
def get_base_verb(self, surface_pred):
"""
Given the surface pred, return the original annotated verb
"""
# Assumes that at this point the verb is always the last word
# in the surface predicate
return surface_pred.split(' ')[-1]
def compute_global_pred(self, surface_pred, questions):
    """
    Given the surface pred and all instantiations of questions,
    make global coherence decisions regarding the final form of the predicate.
    This should hopefully take care of multi-word predicates and correct inflections.
    """
    split_surface = surface_pred.split(' ')
    if len(split_surface) > 1:
        # This predicate has a modal preceding the base verb
        verb = split_surface[-1]
        ret = split_surface[:-1]  # get all of the elements in the modal
    else:
        verb = split_surface[0]
        ret = []

    # Materialize as lists -- map() objects are single-use iterators in Python 3
    split_questions = [question.split(' ') for question in questions]
    preds = [normalize_element(q[QUESTION_TRG_INDEX]) for q in split_questions]
    if len(set(preds)) > 1:
        # This predicate appears in multiple ways, let's stick to the base form
        ret.append(verb)
    if len(set(preds)) == 1:
        # Change the predicate to the inflected form
        # if there's exactly one way in which the predicate is conveyed
        ret.append(preds[0])

    pps = [normalize_element(q[QUESTION_PP_INDEX]) for q in split_questions]
    obj2s = [normalize_element(q[QUESTION_OBJ2_INDEX]) for q in split_questions]
    if len(set(pps)) == 1:
        # If all questions for the predicate include the same pp attachment -
        # assume it's a multi-word predicate
        self.is_mwp = True  # signal to arguments that they shouldn't take the preposition
        ret.append(pps[0])

    # Concat all elements in the predicate and return
    return " ".join(ret).strip()
def augment_arg_with_question(self, arg, question):
"""
Decide what elements from the question to incorporate in the given
corresponding argument
"""
# Parse question
wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
question.split(' ')[:-1]) # Last split is the question mark
# Place preposition in argument
# This is safer when dealing with n-ary arguments, as it directly attaches the
# preposition to the appropriate argument
if (not self.is_mwp) and pp and (not obj2):
    if not arg.startswith("{} ".format(pp)):
        # Avoid repeating the preposition when both question and answer contain it
return " ".join([pp,
arg])
# Normal cases
return arg
def clusterScore(self, cluster):
"""
Calculate cluster density score as the mean distance of the maximum distance of each slot.
Lower score represents a denser cluster.
"""
logging.debug("*-*-*- Cluster: {}".format(cluster))
# Find global centroid
arr = np.array([x for ls in cluster for x in ls])
centroid = np.sum(arr)/arr.shape[0]
logging.debug("Centroid: {}".format(centroid))
# Calculate the mean over all maximum points
return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])
def resolveAmbiguity(self):
    """
    Heuristic to map the elements (arguments and predicate) of this extraction
    back to the indices of the sentence.
    """
    ## TODO: This removes arguments for which there was no consecutive span found
    ## Part of these are non-consecutive arguments,
    ## but others could be a bug in recognizing some punctuation marks
    elements = [self.pred] \
        + [(s, indices)
           for (s, indices) in self.args
           if indices]
    logging.debug("Resolving ambiguity in: {}".format(elements))

    # Collect all possible combinations of argument and predicate indices
    # (hopefully it's not too many)
    all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
    logging.debug("Number of combinations: {}".format(len(all_combinations)))

    # Choose the combination with the best clustering and unfold it
    resolved_elements = list(zip(map(itemgetter(0), elements),
                                 min(all_combinations,
                                     key=lambda cluster: self.clusterScore(cluster))))
    logging.debug("Resolved elements = {}".format(resolved_elements))
    self.pred = resolved_elements[0]
    self.args = resolved_elements[1:]
def conll(self, external_feats=None):
    """
    Return a CoNLL string representation of this extraction
    """
    # Use a fresh list default -- concatenating a dict (the old default) would fail
    external_feats = external_feats if external_feats is not None else []
    return '\n'.join(["\t".join(map(str,
                                    [i, w]
                                    + list(self.pred)
                                    + [self.head_pred_index]
                                    + external_feats
                                    + [self.get_label(i)]))
                      for (i, w) in enumerate(self.sent.split(" "))]) + '\n'
def get_label(self, index):
    """
    Given an index of a word in the sentence -- returns the appropriate BIO conll label.
    Assumes that ambiguity was already resolved.
    """
    # Get the element(s) in which this index appears
    ent = [(elem_ind, elem)
           for (elem_ind, elem) in enumerate(map(itemgetter(1), [self.pred] + self.args))
           if index in elem]
    if not ent:
        # index doesn't appear in any element
        return "O"
    if len(ent) > 1:
        # The same word appears in two different answers
        # In this case we choose the first one as label
        logging.warning("Index {} appears in more than one element: {}".format(
            index,
            "\t".join(map(str, [ent, self.sent, self.pred, self.args]))))
    ## Some indices appear in more than one argument (ones where the above message appears)
    ## From empirical observation, these seem to mostly consist of different levels of granularity:
    ##     what had _ been taken _ _ _ ?    loan commitments topping $ 3 billion
    ##     how much had _ been taken _ _ _ ?    topping $ 3 billion
    ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
    ## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)
    elem_ind, elem = min(ent, key=lambda e: len(e[1]))

    # Distinguish between predicate and arguments
    prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)
    # Distinguish between Beginning and Inside labels
    suffix = "B" if index == elem[0] else "I"
    return "{}-{}".format(prefix, suffix)
def __str__(self):
return '{0}\t{1}'.format(self.elementToStr(self.pred,
print_indices = True),
'\t'.join([self.elementToStr(arg)
for arg
in self.args]))
# Flatten a list of lists
flatten = lambda l: [item for sublist in l for item in sublist]
def normalize_element(elem):
    """
    Return a surface form of the given question element.
    The output should be able to properly precede a predicate (or be blank otherwise).
    """
    return elem.replace("_", " ") if elem != "_" else ""
## Helper functions
def escape_special_chars(s):
return s.replace('\t', '\\t')
def generalize_question(question):
    """
    Given a question in the context of the sentence and the predicate index within
    the question - return a generalized version which extracts only order-imposing features
    """
    wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1]  # last split is the question mark
    return ' '.join([wh, sbj, obj1])
## CONSTANTS
SEP = ';;;'
QUESTION_TRG_INDEX = 3 # index of the predicate within the question
QUESTION_PP_INDEX = 5
QUESTION_OBJ2_INDEX = 6

View File

@@ -1,44 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
from collections import defaultdict
class GoldReader(OieReader):
# Path relative to repo root folder
default_filename = './oie_corpus/all.oie'
def __init__(self):
self.name = 'Gold'
def read(self, fn):
d = defaultdict(list)
with open(fn) as fin:
    for line_ind, line in enumerate(fin):
        data = line.strip().split('\t')
text, rel = data[:2]
args = data[2:]
confidence = 1
curExtraction = Extraction(pred = rel.strip(),
head_pred_index = None,
sent = text.strip(),
confidence = float(confidence),
index = line_ind)
for arg in args:
if "C: " in arg:
continue
curExtraction.addArg(arg.strip())
d[text.strip()].append(curExtraction)
self.oie = d
if __name__ == '__main__':
    g = GoldReader()
    g.read('../oie_corpus/all.oie')
    d = g.oie
    e = list(d.items())[0]
    print(e[1][0].bow())
    print(g.count())

View File

@@ -1,45 +0,0 @@
class OieReader:
def read(self, fn, includeNominal):
''' should set oie as a class member
as a dictionary of extractions by sentence'''
raise Exception("Don't run me")
def count(self):
''' number of extractions '''
return sum([len(extractions) for _, extractions in self.oie.items()])
def split_to_corpus(self, corpus_fn, out_fn):
"""
Given a corpus file name, containing a list of sentences
print only the extractions pertaining to it to out_fn in a tab separated format:
sent, prob, pred, arg1, arg2, ...
"""
raw_sents = [line.strip() for line in open(corpus_fn)]
with open(out_fn, 'w') as fout:
for line in self.get_tabbed().split('\n'):
data = line.split('\t')
sent = data[0]
if sent in raw_sents:
fout.write(line + '\n')
def output_tabbed(self, out_fn):
"""
Write a tabbed representation of this corpus.
"""
with open(out_fn, 'w') as fout:
fout.write(self.get_tabbed())
def get_tabbed(self):
"""
Get a tabbed format representation of this corpus (assumes that input was
already read).
"""
return "\n".join(['\t'.join(map(str,
[ex.sent,
ex.confidence,
ex.pred,
'\t'.join(ex.args)]))
for (sent, exs) in self.oie.iteritems()
for ex in exs])

View File

@@ -1,22 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
class OllieReader(OieReader):
def __init__(self):
self.name = 'OLLIE'
def read(self, fn):
d = {}
with open(fn) as fin:
fin.readline() #remove header
for line in fin:
data = line.strip().split('\t')
confidence, arg1, rel, arg2, enabler, attribution, text = data[:7]
curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
curExtraction.addArg(arg1)
curExtraction.addArg(arg2)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d

View File

@@ -1,38 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
class OpenieFiveReader(OieReader):
def __init__(self):
self.name = 'OpenIE-5'
def read(self, fn):
d = {}
with open(fn) as fin:
for line in fin:
data = line.strip().split('\t')
confidence = data[0]
if not all(data[2:5]):
continue
arg1, rel = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:4]]
args = [s[s.index('(') + 1:s.index(',List(')] for s in data[4].strip().split(');')]
text = data[5]
if data[1]:
s = data[1]
if not (arg1 + ' ' + rel).startswith(s[s.index('(') + 1:s.index(',List(')]):
#print "##########Not adding context"
arg1 = s[s.index('(') + 1:s.index(',List(')] + ' ' + arg1
curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
curExtraction.addArg(arg1)
for arg in args:
curExtraction.addArg(arg)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
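# Parsing sketch: fields like "SimpleArgument(Obama,List([0, 5)))" are assumed
# (based on the slicing above) to carry the surface string between '(' and
# ',List(', so s[s.index('(') + 1:s.index(',List(')] recovers 'Obama'.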

View File

@@ -1,59 +0,0 @@
""" Usage:
<file-name> --in=INPUT_FILE --out=OUTPUT_FILE [--debug]
Convert to tabbed format
"""
# External imports
import logging
from pprint import pprint
from pprint import pformat
from docopt import docopt
# Local imports
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
#=-----
class OpenieFourReader(OieReader):
def __init__(self):
self.name = 'OpenIE-4'
def read(self, fn):
d = {}
with open(fn) as fin:
for line in fin:
data = line.strip().split('\t')
confidence = data[0]
if not all(data[2:5]):
logging.debug("Skipped line: {}".format(line))
continue
arg1, rel, arg2 = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:5]]
text = data[5]
curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
curExtraction.addArg(arg1)
curExtraction.addArg(arg2)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
if __name__ == "__main__":
# Parse command line arguments
args = docopt(__doc__)
inp_fn = args["--in"]
out_fn = args["--out"]
debug = args["--debug"]
if debug:
logging.basicConfig(level = logging.DEBUG)
else:
logging.basicConfig(level = logging.INFO)
oie = OpenieFourReader()
oie.read(inp_fn)
oie.output_tabbed(out_fn)
logging.info("DONE")

View File

@@ -1,44 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
class PropSReader(OieReader):
def __init__(self):
self.name = 'PropS'
def read(self, fn):
d = {}
with open(fn) as fin:
for line in fin:
if not line.strip():
continue
data = line.strip().split('\t')
confidence, text, rel = data[:3]
curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence), head_pred_index=-1)
for arg in data[4::2]:
curExtraction.addArg(arg)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
# self.normalizeConfidence()
def normalizeConfidence(self):
''' Normalize confidence to resemble probabilities '''
EPSILON = 1e-3
self.confidences = [extraction.confidence for sent in self.oie for extraction in self.oie[sent]]
maxConfidence = max(self.confidences)
minConfidence = min(self.confidences)
denom = maxConfidence - minConfidence + (2*EPSILON)
for sent, extractions in self.oie.items():
for extraction in extractions:
extraction.confidence = ( (extraction.confidence - minConfidence) + EPSILON) / denom

View File

@@ -1,29 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
class ReVerbReader(OieReader):
def __init__(self):
self.inputSents = [sent.strip() for sent in open(ReVerbReader.RAW_SENTS_FILE).readlines()]
self.name = 'ReVerb'
def read(self, fn):
d = {}
with open(fn) as fin:
for line in fin:
data = line.strip().split('\t')
arg1, rel, arg2 = data[2:5]
confidence = data[11]
text = self.inputSents[int(data[1])-1]
curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence))
curExtraction.addArg(arg1)
curExtraction.addArg(arg2)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
# ReVerb requires a different files from which to get the input sentences
# Relative to repo root folder
RAW_SENTS_FILE = './raw_sentences/all.txt'

View File

@@ -1,37 +0,0 @@
""" Usage:
split_corpus --corpus=CORPUS_FN --reader=READER --in=INPUT_FN --out=OUTPUT_FN
Split OIE extractions according to raw sentences.
This is used in order to split a large file into train, dev and test.
READER - points out which oie reader to use (see dictionary for possible entries)
"""
from clausieReader import ClausieReader
from ollieReader import OllieReader
from openieFourReader import OpenieFourReader
from propsReader import PropSReader
from reVerbReader import ReVerbReader
from stanfordReader import StanfordReader
from docopt import docopt
import logging
logging.basicConfig(level = logging.INFO)
available_readers = {
"clausie": ClausieReader,
"ollie": OllieReader,
"openie4": OpenieFourReader,
"props": PropSReader,
"reverb": ReVerbReader,
"stanford": StanfordReader
}
if __name__ == "__main__":
args = docopt(__doc__)
inp = args["--in"]
out = args["--out"]
corpus = args["--corpus"]
reader = available_readers[args["--reader"]]()
reader.read(inp)
reader.split_to_corpus(corpus,
out)
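# Invocation sketch (paths are placeholders, not from the repo):
#   python split_corpus.py --corpus=raw_sentences/dev.txt --reader=clausie \
#       --in=systems_output/clausie.txt --out=splits/clausie_dev.txt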

View File

@@ -1,22 +0,0 @@
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
class StanfordReader(OieReader):
def __init__(self):
self.name = 'Stanford'
def read(self, fn):
d = {}
with open(fn) as fin:
for line in fin:
data = line.strip().split('\t')
arg1, rel, arg2 = data[2:5]
confidence = data[11]
text = data[12]
curExtraction = Extraction(pred = rel, head_pred_index = -1, sent = text, confidence = float(confidence))
curExtraction.addArg(arg1)
curExtraction.addArg(arg2)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d

View File

@@ -1,56 +0,0 @@
""" Usage:
tabReader --in=INPUT_FILE
Read a tab-formatted file.
Each line consists of:
sent, prob, pred, arg1, arg2, ...
"""
from oie_readers.oieReader import OieReader
from oie_readers.extraction import Extraction
from docopt import docopt
import logging

logging.basicConfig(level=logging.DEBUG)
class TabReader(OieReader):
def __init__(self):
self.name = 'TabReader'
def read(self, fn):
"""
Read a tabbed format line
Each line consists of:
sent, prob, pred, arg1, arg2, ...
"""
d = {}
ex_index = 0
with open(fn) as fin:
for line in fin:
if not line.strip():
continue
data = line.strip().split('\t')
text, confidence, rel = data[:3]
curExtraction = Extraction(pred = rel,
head_pred_index = None,
sent = text,
confidence = float(confidence),
question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
index = ex_index)
ex_index += 1
for arg in data[3:]:
curExtraction.addArg(arg)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
if __name__ == "__main__":
args = docopt(__doc__)
input_fn = args["--in"]
tr = TabReader()
tr.read(input_fn)

View File

@@ -1,59 +0,0 @@
""" Usage:
tabReader --in=INPUT_FILE
Read a tab-formatted file.
Each line consists of:
sent, prob, pred, arg1, arg2, ...
"""
from carb.oieReader import OieReader
from carb.extraction import Extraction
from docopt import docopt
import logging

logging.basicConfig(level=logging.DEBUG)
class TabReader(OieReader):
def __init__(self):
self.name = 'TabReader'
def read(self, fn):
"""
Read a tabbed format line
Each line consists of:
sent, prob, pred, arg1, arg2, ...
"""
d = {}
ex_index = 0
with open(fn) as fin:
for line in fin:
if not line.strip():
continue
data = line.strip().split('\t')
try:
text, confidence, rel = data[:3]
except ValueError:
continue
curExtraction = Extraction(pred = rel,
head_pred_index = None,
sent = text,
confidence = float(confidence),
question_dist = "./question_distributions/dist_wh_sbj_obj1.json",
index = ex_index)
ex_index += 1
for arg in data[3:]:
curExtraction.addArg(arg)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
if __name__ == "__main__":
args = docopt(__doc__)
input_fn = args["--in"]
tr = TabReader()
tr.read(input_fn)

View File

@@ -1,162 +0,0 @@
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from utils import utils
from transformers import BertTokenizer
from utils.bio import pred_tag2idx, arg_tag2idx
def load_data(data_path,
batch_size,
max_len=64,
train=True,
tokenizer_config='bert-base-cased'):
if train:
return DataLoader(
dataset=OieDataset(
data_path,
max_len,
tokenizer_config),
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True)
else:
return DataLoader(
dataset=OieEvalDataset(
data_path,
max_len,
tokenizer_config),
batch_size=batch_size,
num_workers=4,
pin_memory=True)
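# Usage sketch (paths and sizes are placeholders, not from the repo):
#   train_loader = load_data('datasets/train.pkl', batch_size=32,
#                            max_len=64, train=True)
#   for token_ids, att_mask, pred_label, arg_label, all_pred_label in train_loader:
#       ...  # one padded, shuffled training batch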
class OieDataset(Dataset):
def __init__(self, data_path, max_len=64, tokenizer_config='bert-base-cased'):
data = utils.load_pkl(data_path)
self.tokens = data['tokens']
self.single_pred_labels = data['single_pred_labels']
self.single_arg_labels = data['single_arg_labels']
self.all_pred_labels = data['all_pred_labels']
self.max_len = max_len
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_config)
self.vocab = self.tokenizer.vocab
self.pad_idx = self.vocab['[PAD]']
self.cls_idx = self.vocab['[CLS]']
self.sep_idx = self.vocab['[SEP]']
self.mask_idx = self.vocab['[MASK]']
def add_pad(self, token_ids):
diff = self.max_len - len(token_ids)
if diff > 0:
token_ids += [self.pad_idx] * diff
else:
token_ids = token_ids[:self.max_len-1] + [self.sep_idx]
return token_ids
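# Padding sketch (max_len assumed 8): [CLS] a b [SEP] -> 4 ids are right-padded
# with [PAD] to length 8; an over-long sequence is cut to 7 ids and [SEP] is
# re-appended, so the last position is always [SEP].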
def add_special_token(self, token_ids):
return [self.cls_idx] + token_ids + [self.sep_idx]
def idx2mask(self, token_ids):
return [token_id != self.pad_idx for token_id in token_ids]
def add_pad_to_labels(self, pred_label, arg_label, all_pred_label):
pred_outside = np.array([pred_tag2idx['O']])
arg_outside = np.array([arg_tag2idx['O']])
pred_label = np.concatenate([pred_outside, pred_label, pred_outside])
arg_label = np.concatenate([arg_outside, arg_label, arg_outside])
all_pred_label = np.concatenate([pred_outside, all_pred_label, pred_outside])
diff = self.max_len - pred_label.shape[0]
if diff > 0:
pred_pad = np.array([pred_tag2idx['O']] * diff)
arg_pad = np.array([arg_tag2idx['O']] * diff)
pred_label = np.concatenate([pred_label, pred_pad])
arg_label = np.concatenate([arg_label, arg_pad])
all_pred_label = np.concatenate([all_pred_label, pred_pad])
elif diff == 0:
pass
else:
pred_label = np.concatenate([pred_label[:-1], pred_outside])
arg_label = np.concatenate([arg_label[:-1], arg_outside])
all_pred_label = np.concatenate([all_pred_label[:-1], pred_outside])
return [pred_label, arg_label, all_pred_label]
def __len__(self):
return len(self.tokens)
def __getitem__(self, idx):
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokens[idx])
token_ids_padded = self.add_pad(self.add_special_token(token_ids))
att_mask = self.idx2mask(token_ids_padded)
labels = self.add_pad_to_labels(
self.single_pred_labels[idx],
self.single_arg_labels[idx],
self.all_pred_labels[idx])
single_pred_label, single_arg_label, all_pred_label = labels
assert len(token_ids_padded) == self.max_len
assert len(att_mask) == self.max_len
assert single_pred_label.shape[0] == self.max_len
assert single_arg_label.shape[0] == self.max_len
assert all_pred_label.shape[0] == self.max_len
batch = [
torch.tensor(token_ids_padded),
torch.tensor(att_mask),
torch.tensor(single_pred_label),
torch.tensor(single_arg_label),
torch.tensor(all_pred_label)
]
return batch
class OieEvalDataset(Dataset):
def __init__(self, data_path, max_len, tokenizer_config='bert-base-cased'):
self.sentences = utils.load_pkl(data_path)
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_config)
self.vocab = self.tokenizer.vocab
self.max_len = max_len
self.pad_idx = self.vocab['[PAD]']
self.cls_idx = self.vocab['[CLS]']
self.sep_idx = self.vocab['[SEP]']
self.mask_idx = self.vocab['[MASK]']
def add_pad(self, token_ids):
diff = self.max_len - len(token_ids)
if diff > 0:
token_ids += [self.pad_idx] * diff
else:
token_ids = token_ids[:self.max_len-1] + [self.sep_idx]
return token_ids
def idx2mask(self, token_ids):
return [token_id != self.pad_idx for token_id in token_ids]
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
token_ids = self.add_pad(self.tokenizer.encode(self.sentences[idx]))
att_mask = self.idx2mask(token_ids)
token_strs = self.tokenizer.convert_ids_to_tokens(token_ids)
sentence = self.sentences[idx]
assert len(token_ids) == self.max_len
assert len(att_mask) == self.max_len
assert len(token_strs) == self.max_len
batch = [
torch.tensor(token_ids),
torch.tensor(att_mask),
token_strs,
sentence
]
return batch
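# A minimal usage sketch, assuming a preprocessed pickle exists at the default
# training path used elsewhere in this repo (adjust the path to your data).
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    trn_set = OieDataset('./datasets/openie4_train.pkl',
                         max_len=64, tokenizer_config='bert-base-cased')
    loader = DataLoader(trn_set, batch_size=32, shuffle=True)
    token_ids, att_mask, pred_label, arg_label, all_pred_label = next(iter(loader))
    print(token_ids.shape)   # torch.Size([32, 64]): padded token ids per sentence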


@@ -1,97 +0,0 @@
name: multi2oie
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _pytorch_select=0.2=gpu_0
- attrs=19.3.0=py_0
- blas=1.0=mkl
- brotlipy=0.7.0=py37h8f50634_1000
- ca-certificates=2020.4.5.1=hecc5488_0
- catalogue=1.0.0=py_0
- certifi=2020.4.5.1=py37hc8dfbb8_0
- cffi=1.14.0=py37he30daa8_1
- chardet=3.0.4=py37hc8dfbb8_1006
- cryptography=2.9.2=py37hb09aad4_0
- cudatoolkit=10.1.243=h6bb024c_0
- cudnn=7.6.5=cuda10.1_0
- cymem=2.0.3=py37h3340039_1
- cython-blis=0.4.1=py37h8f50634_1
- idna=2.9=py_1
- importlib-metadata=1.6.0=py37hc8dfbb8_0
- importlib_metadata=1.6.0=0
- intel-openmp=2020.1=217
- joblib=0.15.1=py_0
- jsonschema=3.2.0=py37hc8dfbb8_1
- ld_impl_linux-64=2.33.1=h53a641e_7
- libedit=3.1.20181209=hc058e9b_0
- libffi=3.3=he6710b0_1
- libgcc-ng=9.1.0=hdf63c60_0
- libgfortran-ng=7.3.0=hdf63c60_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- mkl=2020.1=217
- mkl-service=2.3.0=py37he904b0f_0
- mkl_fft=1.0.15=py37ha843d7b_0
- mkl_random=1.1.1=py37h0573a6f_0
- murmurhash=1.0.0=py37h3340039_0
- ncurses=6.2=he6710b0_1
- ninja=1.9.0=py37hfd86e86_0
- numpy=1.18.1=py37h4f9e942_0
- numpy-base=1.18.1=py37hde5b4d6_1
- openssl=1.1.1g=h516909a_0
- pandas=1.0.3=py37h0573a6f_0
- pip=20.0.2=py37_3
- plac=0.9.6=py37_0
- preshed=3.0.2=py37h3340039_2
- pycparser=2.20=py_0
- pyopenssl=19.1.0=py_1
- pyrsistent=0.16.0=py37h8f50634_0
- pysocks=1.7.1=py37hc8dfbb8_1
- python=3.7.7=hcff3b4d_5
- python-dateutil=2.8.1=py_0
- python_abi=3.7=1_cp37m
- pytorch=1.4.0=cuda101py37h02f0884_0
- pytz=2020.1=py_0
- readline=8.0=h7b6447c_0
- requests=2.23.0=pyh8c360ce_2
- scikit-learn=0.22.1=py37hd81dba3_0
- scipy=1.4.1=py37h0b6359f_0
- setuptools=46.4.0=py37_0
- six=1.14.0=py37_0
- spacy=2.2.4=py37h99015e2_1
- sqlite=3.31.1=h62c20be_1
- srsly=1.0.2=py37h3340039_0
- thinc=7.4.0=py37h99015e2_2
- tk=8.6.8=hbc83047_0
- tqdm=4.46.0=pyh9f0ad1d_0
- urllib3=1.25.9=py_0
- wasabi=0.6.0=py_0
- wheel=0.34.2=py37_0
- xz=5.2.5=h7b6447c_0
- zipp=3.1.0=py_0
- zlib=1.2.11=h7b6447c_3
- pip:
- backcall==0.1.0
- click==7.1.2
- decorator==4.4.2
- docopt==0.6.2
- filelock==3.0.12
- ipdb==0.13.2
- ipython==7.14.0
- ipython-genutils==0.2.0
- jedi==0.17.0
- nltk==3.5
- parso==0.7.0
- pexpect==4.8.0
- pickleshare==0.7.5
- prompt-toolkit==3.0.5
- ptyprocess==0.6.0
- pygments==2.6.1
- regex==2020.5.14
- sacremoses==0.0.43
- sentencepiece==0.1.91
- tokenizers==0.7.0
- traitlets==4.3.3
- transformers==2.10.0
- wcwidth==0.1.9

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large


@@ -1,21 +0,0 @@
import nltk
from operator import itemgetter
class Argument:
def __init__(self, arg):
self.words = [x for x in arg[0].strip().split(' ') if x]
self.posTags = list(map(itemgetter(1), nltk.pos_tag(self.words)))  # materialized: a lazy map would be exhausted after one pass
self.indices = arg[1]
self.feats = {}
def __str__(self):
return "({})".format('\t'.join(map(str,
[escape_special_chars(' '.join(self.words)),
str(self.indices)])))
COREF = 'coref'
## Helper functions
def escape_special_chars(s):
return s.replace('\t', '\\t')

File diff suppressed because it is too large


@@ -1,219 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 25 19:24:30 2018
@author: longzhan
"""
import string
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
import re
import logging
import sys
logging.basicConfig(level=logging.INFO)
from evaluate.generalReader import GeneralReader
from evaluate.goldReader import GoldReader
from evaluate.gold_relabel import Relabel_GoldReader
from evaluate.matcher import Matcher
from operator import itemgetter
class Benchmark:
''' Compare the gold OIE dataset against a predicted equivalent '''
def __init__(self, gold_fn):
''' Load gold Open IE, this will serve to compare against using the compare function '''
if 'Re' in gold_fn:
gr = Relabel_GoldReader()
else:
gr = GoldReader()
gr.read(gold_fn)
self.gold = gr.oie
def compare(self, predicted, matchingFunc, output_fn, error_file = None):
''' Compare gold against predicted using a specified matching function.
Outputs PR curve to output_fn '''
y_true = []
y_scores = []
errors = []
correctTotal = 0
unmatchedCount = 0
non_sent = 0
non_match = 0
predicted = Benchmark.normalizeDict(predicted)
gold = Benchmark.normalizeDict(self.gold)
for sent, goldExtractions in gold.items():
if sent not in predicted:
# The extractor didn't find any extractions for this sentence
for goldEx in goldExtractions:
non_sent += 1
unmatchedCount += len(goldExtractions)
correctTotal += len(goldExtractions)
continue
predictedExtractions = predicted[sent]
for goldEx in goldExtractions:
correctTotal += 1
found = False
for predictedEx in predictedExtractions:
if output_fn in predictedEx.matched:
# This predicted extraction was already matched against a gold extraction
# Don't allow to match it again
continue
if matchingFunc(goldEx,
predictedEx,
ignoreStopwords=True,
ignoreCase=True):
y_true.append(1)
y_scores.append(predictedEx.confidence)
predictedEx.matched.append(output_fn)
found = True
break
if not found:
non_match += 1
errors.append(goldEx.index)
unmatchedCount += 1
for predictedEx in [x for x in predictedExtractions if (output_fn not in x.matched)]:
# Add false positives
y_true.append(0)
y_scores.append(predictedEx.confidence)
print("non_sent: ", non_sent)
print("non_match: ", non_match)
print("correctTotal: ", correctTotal)
print("unmatchedCount: ", unmatchedCount)
# recall r' on (y_true, y_scores) computes |covered by extractor| / |True in what's covered by extractor|
# to get to true recall we do:
# r' * (|True in what's covered by extractor| / |True in gold|) = |true in what's covered| / |true in gold|
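# Toy numbers: with correctTotal = 10 gold extractions and unmatchedCount = 4,
# recallMultiplier = (10 - 4) / 10 = 0.6, so a raw recall r' of 1.0 on the
# matched portion corresponds to a true recall of 1.0 * 0.6 = 0.6.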
(p, r), optimal = Benchmark.prCurve(np.array(y_true), np.array(y_scores),
recallMultiplier = ((correctTotal - unmatchedCount)/float(correctTotal)))
cur_auc = auc(r, p)
print("AUC: {}\n Optimal (precision, recall, F1, threshold): {}".format(cur_auc, optimal))
# Write error log to file
if error_file:
logging.info("Writing {} error indices to {}".format(len(errors),
error_file))
with open(error_file, 'w') as fout:
fout.write('\n'.join([str(error)
for error
in errors]) + '\n')
# write PR to file
with open(output_fn, 'w') as fout:
fout.write('{0}\t{1}\n'.format("Precision", "Recall"))
for cur_p, cur_r in sorted(zip(p, r), key = lambda cur: cur[1]):
fout.write('{0}\t{1}\n'.format(cur_p, cur_r))
return optimal[:-1], cur_auc
@staticmethod
def prCurve(y_true, y_scores, recallMultiplier):
# Recall multiplier - accounts for the percentage examples unreached
# Return (precision [list], recall[list]), (Optimal F1, Optimal threshold)
y_scores = [score if np.isfinite(score) else 0
for score in y_scores]
precision_ls, recall_ls, thresholds = precision_recall_curve(y_true, y_scores)
recall_ls = recall_ls * recallMultiplier
optimal = max([(precision, recall, f_beta(precision, recall, beta = 1), threshold)
for ((precision, recall), threshold)
in zip(zip(precision_ls[:-1], recall_ls[:-1]),
thresholds)],
key=itemgetter(2)) # Sort by f1 score
return ((precision_ls, recall_ls),
optimal)
# Helper functions:
@staticmethod
def normalizeDict(d):
return dict([(Benchmark.normalizeKey(k), v) for k, v in d.items()])
@staticmethod
def normalizeKey(k):
return Benchmark.removePunct(Benchmark.PTB_unescape(k.replace(' ','')))
@staticmethod
def PTB_escape(s):
for u, e in Benchmark.PTB_ESCAPES:
s = s.replace(u, e)
return s
@staticmethod
def PTB_unescape(s):
for u, e in Benchmark.PTB_ESCAPES:
s = s.replace(e, u)
return s
@staticmethod
def removePunct(s):
return Benchmark.regex.sub('', s)
# CONSTANTS
regex = re.compile('[%s]' % re.escape(string.punctuation))
# Penn treebank bracket escapes
# Taken from: https://github.com/nlplab/brat/blob/master/server/src/gtbtokenize.py
PTB_ESCAPES = [('(', '-LRB-'),
(')', '-RRB-'),
('[', '-LSB-'),
(']', '-RSB-'),
('{', '-LCB-'),
('}', '-RCB-'),]
def f_beta(precision, recall, beta = 1):
"""
Get F_beta score from precision and recall.
"""
beta = float(beta) # Make sure that results are in float
return (1 + pow(beta, 2)) * (precision * recall) / ((pow(beta, 2) * precision) + recall)
f1 = lambda precision, recall: f_beta(precision, recall, beta = 1)
if __name__ == '__main__':
gold_flag = sys.argv[1]  # choose whether to use OIE2016 ("old") or Re-OIE2016 ("new")
in_path = sys.argv[2]  # input file
out_path = sys.argv[3]  # output file
gold = gold_flag
matchingFunc = Matcher.lexicalMatch
error_fn = "error.txt"
if gold == "old":
gold_fn = "OIE2016.txt"
else:
gold_fn = "Re-OIE2016.json"
b = Benchmark(gold_fn)
s_fn = in_path
p = GeneralReader()
other_p = GeneralReader()
other_p.read(s_fn)
b.compare(predicted = other_p.oie,
matchingFunc = matchingFunc,
output_fn = out_path,
error_file = error_fn)


@@ -1,444 +0,0 @@
from evaluate.argument import Argument
from operator import itemgetter
from collections import defaultdict
import nltk
import itertools
import logging
import numpy as np
import pdb
class Extraction:
"""
Stores sentence, single predicate and corresponding arguments.
"""
def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
self.pred = pred
self.head_pred_index = head_pred_index
self.sent = sent
self.args = []
self.confidence = confidence
self.matched = []
self.questions = {}
self.indsForQuestions = defaultdict(lambda: set())
self.is_mwp = False
self.question_dist = question_dist
self.index = index
def distArgFromPred(self, arg):
assert(len(self.pred) == 2)
dists = []
for x in self.pred[1]:
for y in arg.indices:
dists.append(abs(x - y))
return min(dists)
def argsByDistFromPred(self, question):
return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))
def addArg(self, arg, question = None):
self.args.append(arg)
if question:
self.questions[question] = self.questions.get(question,[]) + [Argument(arg)]
def noPronounArgs(self):
"""
Returns True iff all of this extraction's arguments are not pronouns.
"""
for (a, _) in self.args:
tokenized_arg = nltk.word_tokenize(a)
if len(tokenized_arg) == 1:
_, pos_tag = nltk.pos_tag(tokenized_arg)[0]
if ('PRP' in pos_tag):
return False
return True
def isContiguous(self):
return all([indices for (_, indices) in self.args])
def toBinary(self):
''' Try to represent this extraction's arguments as binary
If fails, this function will return an empty list. '''
ret = [self.elementToStr(self.pred)]
if len(self.args) == 2:
# we're in luck
return ret + [self.elementToStr(arg) for arg in self.args]
if not self.isContiguous():
# give up on non contiguous arguments (as we need indexes)
return []
# otherwise, try to merge based on indices
# TODO: you can explore other methods for doing this
binarized = self.binarizeByIndex()
if binarized:
return ret + binarized
return []
def elementToStr(self, elem, print_indices = True):
''' formats an extraction element (pred or arg) as a raw string
removes indices and trailing spaces '''
if print_indices:
return str(elem)
if isinstance(elem, str):
return elem
if isinstance(elem, tuple):
ret = elem[0].rstrip().lstrip()
else:
ret = ' '.join(elem.words)
assert ret, "empty element? {0}".format(elem)
return ret
def binarizeByIndex(self):
extraction = [self.pred] + self.args
markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
sortedExtraction = sorted(markPred, key = lambda x : x[1][0])
s = ' '.join(['{1} {0} {1}'.format(self.elementToStr(elem), SEP) if elem[2] else self.elementToStr(elem) for elem in sortedExtraction])
binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]
if len(binArgs) == 2:
return binArgs
# failure
return []
def bow(self):
return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])
def getSortedArgs(self):
"""
Sort the list of arguments.
If a question distribution is provided - use it,
otherwise, default to the order of appearance in the sentence.
"""
if self.question_dist:
# There's a question distribution - use it
return self.sort_args_by_distribution()
ls = []
for q, args in self.questions.items():
if (len(args) != 1):
logging.debug("Not one argument: {}".format(args))
continue
arg = args[0]
indices = list(self.indsForQuestions[q].union(arg.indices))
if not indices:
logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
indices = [0]
ls.append(((arg, q), indices))
return [a for a, _ in sorted(ls,key = lambda x: min(x[1]))]
def question_prob_for_loc(self, question, loc):
"""
Returns the probability of the given question leading to argument
appearing in the given location in the output slot.
"""
gen_question = generalize_question(question)
q_dist = self.question_dist[gen_question]
logging.debug("distribution of {}: {}".format(gen_question,
q_dist))
return float(q_dist.get(loc, 0)) / \
sum(q_dist.values())
def sort_args_by_distribution(self):
"""
Use this instance's question distribution (this func assumes it exists)
in determining the positioning of the arguments.
Greedy algorithm:
0. Decide on which argument will serve as the ``subject'' (first slot) of this extraction
0.1 Based on the most probable one for this spot
(special care is given to select the highly-influential subject position)
1. For all other arguments, sort arguments by the prevalence of their questions
2. For each argument:
2.1 Assign to it the most probable slot still available
2.2 If none such exists (fallback) - default to putting it in the last location
"""
INF_LOC = 100 # Used as an impractical last argument
# Store arguments by slot
ret = {INF_LOC: []}
logging.debug("sorting: {}".format(self.questions))
# Find the most suitable argument for the subject location
logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
for (q, _) in self.questions.iteritems()]))
subj_question, subj_args = max(self.questions.iteritems(),
key = lambda x: self.question_prob_for_loc(x[0], 0))
ret[0] = [(subj_args[0], subj_question)]
# Find the rest
for (question, args) in sorted([(q, a)
for (q, a) in self.questions.items() if (q not in [subj_question])],
key = lambda x: \
sum(self.question_dist[generalize_question(x[0])].values()),
reverse = True):
gen_question = generalize_question(question)
arg = args[0]
assigned_flag = False
for (loc, count) in sorted(self.question_dist[gen_question].items(),
key = lambda x: x[1],
reverse = True):
if loc not in ret:
# Found an empty slot for this item
# Place it there and break out
ret[loc] = [(arg, question)]
assigned_flag = True
break
if not assigned_flag:
# Add this argument to the non-assigned (hopefully doesn't happen much)
logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
ret[INF_LOC].append((arg, question))
logging.debug("Linearizing arg list: {}".format(ret))
# Finished iterating - consolidate and return a list of arguments
return [arg
for (_, arg_ls) in sorted(ret.items(),
key = lambda x: int(x[0]))
for arg in arg_ls]
def __str__(self):
pred_str = self.elementToStr(self.pred)
return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
self.compute_global_pred(pred_str,
self.questions.keys()),
'\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
question))
for arg, question in self.getSortedArgs()]))
def get_base_verb(self, surface_pred):
"""
Given the surface pred, return the original annotated verb
"""
# Assumes that at this point the verb is always the last word
# in the surface predicate
return surface_pred.split(' ')[-1]
def compute_global_pred(self, surface_pred, questions):
"""
Given the surface pred and all instantiations of questions,
make global coherence decisions regarding the final form of the predicate
This should hopefully take care of multi word predicates and correct inflections
"""
from operator import itemgetter
split_surface = surface_pred.split(' ')
if len(split_surface) > 1:
# This predicate has a modal preceding the base verb
verb = split_surface[-1]
ret = split_surface[:-1] # get all of the elements in the modal
else:
verb = split_surface[0]
ret = []
split_questions = list(map(lambda question: question.split(' '),
questions))
preds = list(map(normalize_element,
map(itemgetter(QUESTION_TRG_INDEX),
split_questions)))
if len(set(preds)) > 1:
# This predicate appears in multiple ways, let's stick to the base form
ret.append(verb)
if len(set(preds)) == 1:
# Change the predicate to the inflected form
# if there's exactly one way in which the predicate is conveyed
ret.append(preds[0])
pps = list(map(normalize_element,
map(itemgetter(QUESTION_PP_INDEX),
split_questions)))
obj2s = list(map(normalize_element,
map(itemgetter(QUESTION_OBJ2_INDEX),
split_questions)))
if (len(set(pps)) == 1):
# If all questions for the predicate include the same pp attachment -
# assume it's a multiword predicate
self.is_mwp = True # Signal to arguments that they shouldn't take the preposition
ret.append(pps[0])
# Concat all elements in the predicate and return
return " ".join(ret).strip()
def augment_arg_with_question(self, arg, question):
"""
Decide what elements from the question to incorporate in the given
corresponding argument
"""
# Parse question
wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
question.split(' ')[:-1]) # Last split is the question mark
# Place preposition in argument
# This is safer when dealing with n-ary arguments, as it's directly attaches to the
# appropriate argument
if (not self.is_mwp) and pp and (not obj2):
if not(arg.startswith("{} ".format(pp))):
# Avoid repeating the preposition in cases where both question and answer contain it
return " ".join([pp,
arg])
# Normal cases
return arg
def clusterScore(self, cluster):
"""
Calculate the cluster density score as the mean, over slots, of each slot's maximum distance from the centroid.
Lower score represents a denser cluster.
"""
logging.debug("*-*-*- Cluster: {}".format(cluster))
# Find global centroid
arr = np.array([x for ls in cluster for x in ls])
centroid = np.sum(arr)/arr.shape[0]
logging.debug("Centroid: {}".format(centroid))
# Calculate the mean over all maximum points
return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])
def resolveAmbiguity(self):
"""
Heuristic to map the elements (argument and predicates) of this extraction
back to the indices of the sentence.
"""
## TODO: This removes arguments for which there was no consecutive span found
## Part of these are non-consecutive arguments,
## but other could be a bug in recognizing some punctuation marks
elements = [self.pred] \
+ [(s, indices)
for (s, indices)
in self.args
if indices]
logging.debug("Resolving ambiguity in: {}".format(elements))
# Collect all possible combinations of arguments and predicate indices
# (hopefully it's not too much)
all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
logging.debug("Number of combinations: {}".format(len(all_combinations)))
# Choose the ones with best clustering and unfold them
# materialize the zip: a zip object cannot be indexed in Python 3
resolved_elements = list(zip(map(itemgetter(0), elements),
min(all_combinations,
key = lambda cluster: self.clusterScore(cluster))))
logging.debug("Resolved elements = {}".format(resolved_elements))
self.pred = resolved_elements[0]
self.args = resolved_elements[1:]
def conll(self, external_feats = None):
"""
Return a CoNLL string representation of this extraction
"""
external_feats = external_feats or []  # a dict default would break the list concatenation below
return '\n'.join(["\t".join(map(str,
[i, w] + \
list(self.pred) + \
[self.head_pred_index] + \
external_feats + \
[self.get_label(i)]))
for (i, w)
in enumerate(self.sent.split(" "))]) + '\n'
def get_label(self, index):
"""
Given an index of a word in the sentence -- returns the appropriate BIO conll label
Assumes that ambiguation was already resolved.
"""
# Get the element(s) in which this index appears
ent = [(elem_ind, elem)
for (elem_ind, elem)
in enumerate(map(itemgetter(1),
[self.pred] + self.args))
if index in elem]
if not ent:
# index doesn't appear in any element
return "O"
if len(ent) > 1:
# The same word appears in two different answers
# In this case we choose the first one as label
logging.warn("Index {} appears in one than more element: {}".\
format(index,
"\t".join(map(str,
[ent,
self.sent,
self.pred,
self.args]))))
## Some indices appear in more than one argument (ones where the above message appears)
## From empirical observation, these seem to mostly consist of different levels of granularity:
## what had _ been taken _ _ _ ? loan commitments topping $ 3 billion
## how much had _ been taken _ _ _ ? topping $ 3 billion
## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)
elem_ind, elem = min(ent, key = lambda x: len(x[1]))
# Distinguish between predicate and arguments
prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)
# Distinguish between Beginning and Inside labels
suffix = "B" if index == elem[0] else "I"
return "{}-{}".format(prefix, suffix)
def __str__(self):  # NOTE: overrides the question-based __str__ defined earlier in this class
return '{0}\t{1}'.format(self.elementToStr(self.pred,
print_indices = True),
'\t'.join([self.elementToStr(arg)
for arg
in self.args]))
# Flatten a list of lists
flatten = lambda l: [item for sublist in l for item in sublist]
def normalize_element(elem):
"""
Return a surface form of the given question element.
the output should be properly able to precede a predicate (or blank otherwise)
"""
return elem.replace("_", " ") \
if (elem != "_")\
else ""
## Helper functions
def escape_special_chars(s):
return s.replace('\t', '\\t')
def generalize_question(question):
"""
Given a question in the context of the sentence and the predicate index within
the question - return a generalized version which extracts only order-imposing features
"""
import nltk # Using nltk since couldn't get spaCy to agree on the tokenization
wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1] # Last split is the question mark
return ' '.join([wh, sbj, obj1])
## CONSTANTS
SEP = ';;;'
QUESTION_TRG_INDEX = 3 # index of the predicate within the question
QUESTION_PP_INDEX = 5
QUESTION_OBJ2_INDEX = 6
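# A minimal sketch of building an Extraction by hand, using plain-string
# predicates/arguments as GeneralReader does (values are illustrative).
if __name__ == '__main__':
    ex = Extraction(pred='bought', head_pred_index=-1,
                    sent='John bought a car', confidence=0.9)
    ex.addArg('John')
    ex.addArg('a car')
    print(ex.bow())   # -> 'bought John a car'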


@@ -1,43 +0,0 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 15 21:05:15 2018
@author: win 10
"""
from evaluate.oieReader import OieReader
from evaluate.extraction import Extraction
class GeneralReader(OieReader):
def __init__(self):
self.name = 'General'
def read(self, fn):
d = {}
with open(fn) as fin:
for line in fin:
data = line.strip().split('\t')
if len(data) >= 4:
arg1 = data[3]
rel = data[2]
arg_else = data[4:]
confidence = data[1]
text = data[0]
curExtraction = Extraction(pred = rel, head_pred_index=-1, sent = text, confidence = float(confidence))
curExtraction.addArg(arg1)
for arg in arg_else:
curExtraction.addArg(arg)
d[text] = d.get(text, []) + [curExtraction]
self.oie = d
if __name__ == "__main__":
fn = "../data/other_systems/openie4_test.txt"
reader = GeneralReader()
reader.read(fn)
for key in reader.oie:
print(key)
print(reader.oie[key][0].pred)
print(reader.oie[key][0].args)
print(reader.oie[key][0].confidence)


@@ -1,58 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 25 19:24:30 2018
@author: longzhan
"""
from evaluate.oieReader import OieReader
from evaluate.extraction import Extraction
from _collections import defaultdict
class GoldReader(OieReader):
# Path relative to repo root folder
default_filename = './oie_corpus/all.oie'
def __init__(self):
self.name = 'Gold'
def read(self, fn):
d = defaultdict(lambda: [])
multilingual = False
for lang in ['spanish']:
if lang in fn:
multilingual = True
encoding = lang
break
if multilingual and encoding == 'spanish':
fin = open(fn, 'r', encoding='latin-1')
else:
fin = open(fn)
for line_ind, line in enumerate(fin):
data = line.strip().split('\t')
text, rel = data[:2]
args = data[2:]
confidence = 1
curExtraction = Extraction(pred = rel.strip(),
head_pred_index = None,
sent = text.strip(),
confidence = float(confidence),
index = line_ind)
for arg in args:
curExtraction.addArg(arg)
d[text.strip()].append(curExtraction)
self.oie = d
if __name__ == '__main__':
g = GoldReader()
g.read('../test.oie')
d = g.oie
e = list(d.items())[0]
print(e[1][0].bow())
print(g.count())


@@ -1,46 +0,0 @@
from evaluate.oieReader import OieReader
from evaluate.extraction import Extraction
from _collections import defaultdict
import json
class Relabel_GoldReader(OieReader):
# Path relative to repo root folder
default_filename = './oie_corpus/all.oie'
def __init__(self):
self.name = 'Relabel_Gold'
def read(self, fn):
d = defaultdict(lambda: [])
with open(fn) as fin:
data = json.load(fin)
for sentence in data:
tuples = data[sentence]
for t in tuples:
if t["pred"].strip() == "<be>":
rel = "[is]"
else:
rel = t["pred"].replace("<be> ","")
confidence = 1
curExtraction = Extraction(pred = rel,
head_pred_index = None,
sent = sentence,
confidence = float(confidence),
index = None)
if t["arg0"] != "":
curExtraction.addArg(t["arg0"])
if t["arg1"] != "":
curExtraction.addArg(t["arg1"])
if t["arg2"] != "":
curExtraction.addArg(t["arg2"])
if t["arg3"] != "":
curExtraction.addArg(t["arg3"])
if t["temp"] != "":
curExtraction.addArg(t["temp"])
if t["loc"] != "":
curExtraction.addArg(t["loc"])
d[sentence].append(curExtraction)
self.oie = d


@@ -1,109 +0,0 @@
import string
import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
nltk.download('wordnet', quiet=True)
lemmatizer = WordNetLemmatizer()
class Matcher:
@staticmethod
def bowMatch(ref, ex, ignoreStopwords, ignoreCase):
"""
A binary function testing for exact lexical match (ignoring ordering) between reference
and predicted extraction
"""
s1 = ref.bow()
s2 = ex.bow()
if ignoreCase:
s1 = s1.lower()
s2 = s2.lower()
s1Words = s1.split(' ')
s2Words = s2.split(' ')
if ignoreStopwords:
s1Words = Matcher.removeStopwords(s1Words)
s2Words = Matcher.removeStopwords(s2Words)
return sorted(s1Words) == sorted(s2Words)
@staticmethod
def predMatch(ref, ex, ignoreStopwords, ignoreCase):
"""
Return whether gold and predicted extractions agree on the predicate
"""
s1 = ref.elementToStr(ref.pred)
s2 = ex.elementToStr(ex.pred)
if ignoreCase:
s1 = s1.lower()
s2 = s2.lower()
s1Words = s1.split(' ')
s2Words = s2.split(' ')
if ignoreStopwords:
s1Words = Matcher.removeStopwords(s1Words)
s2Words = Matcher.removeStopwords(s2Words)
return s1Words == s2Words
@staticmethod
def argMatch(ref, ex, ignoreStopwords, ignoreCase):
"""
Return whether gold and predicted extractions agree on the arguments
"""
sRef = ' '.join([ref.elementToStr(elem) for elem in ref.args]).split(' ')
sEx = ' '.join([ex.elementToStr(elem) for elem in ex.args]).split(' ')
count = 0
for w1 in sRef:
for w2 in sEx:
if w1 == w2:
count += 1
# We check how well the extraction lexically covers the reference,
# comparing word-by-word as in lexicalMatch (iterating the joined strings
# directly would compare single characters rather than words).
# Note: this is somewhat lenient as it doesn't penalize the extraction for
# being too long
coverage = float(count) / len(sRef)
return coverage > Matcher.LEXICAL_THRESHOLD
@staticmethod
def bleuMatch(ref, ex, ignoreStopwords, ignoreCase):
sRef = ref.bow()
sEx = ex.bow()
bleu = sentence_bleu(references = [sRef.split(' ')], hypothesis = sEx.split(' '))
return bleu > Matcher.BLEU_THRESHOLD
@staticmethod
def lexicalMatch(ref, ex, ignoreStopwords, ignoreCase):
sRef = ref.bow().split(' ')
sEx = ex.bow().split(' ')
count = 0
for w1 in sRef:
for w2 in sEx:
if w1 == w2:
count += 1
# We check how well the extraction lexically covers the reference
# Note: this is somewhat lenient as it doesn't penalize the extraction for
# being too long
coverage = float(count) / len(sRef)
return coverage > Matcher.LEXICAL_THRESHOLD
@staticmethod
def removeStopwords(ls):
return [w for w in ls if w.lower() not in Matcher.stopwords]
# CONSTANTS
BLEU_THRESHOLD = 0.4
LEXICAL_THRESHOLD = 0.5 # Note: changing this value didn't change the ordering of the tested systems
stopwords = stopwords.words('english') + list(string.punctuation)
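# A minimal sketch of matching a predicted extraction against a gold one with
# lexicalMatch (assumes the repository root is on PYTHONPATH and the NLTK
# stopwords corpus is installed; values are illustrative).
if __name__ == '__main__':
    from evaluate.extraction import Extraction
    gold = Extraction(pred='bought', head_pred_index=-1,
                      sent='John bought a car', confidence=1.0)
    gold.addArg('John')
    gold.addArg('a car')
    pred = Extraction(pred='bought', head_pred_index=-1,
                      sent='John bought a car', confidence=0.8)
    pred.addArg('John')
    pred.addArg('car')
    # True: matched words cover 3/4 of the gold words, above LEXICAL_THRESHOLD = 0.5
    print(Matcher.lexicalMatch(gold, pred, ignoreStopwords=True, ignoreCase=True))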


@@ -1,45 +0,0 @@
class OieReader:
def read(self, fn, includeNominal):
''' should set oie as a class member
as a dictionary of extractions by sentence'''
raise Exception("Don't run me")
def count(self):
''' number of extractions '''
return sum([len(extractions) for _, extractions in self.oie.items()])
def split_to_corpus(self, corpus_fn, out_fn):
"""
Given a corpus file name, containing a list of sentences
print only the extractions pertaining to it to out_fn in a tab separated format:
sent, prob, pred, arg1, arg2, ...
"""
raw_sents = [line.strip() for line in open(corpus_fn)]
with open(out_fn, 'w') as fout:
for line in self.get_tabbed().split('\n'):
data = line.split('\t')
sent = data[0]
if sent in raw_sents:
fout.write(line + '\n')
def output_tabbed(self, out_fn):
"""
Write a tabbed representation of this corpus.
"""
with open(out_fn, 'w') as fout:
fout.write(self.get_tabbed())
def get_tabbed(self):
"""
Get a tabbed format representation of this corpus (assumes that input was
already read).
"""
return "\n".join(['\t'.join(map(str,
[ex.sent,
ex.confidence,
ex.pred,
'\t'.join(ex.args)]))
for (sent, exs) in self.oie.items()
for ex in exs])


@@ -1,6 +0,0 @@
# Evaluation
This code is mainly based on the code from the original [OIE2016 repository](https://github.com/gabrielStanovsky/oie-benchmark). The command to run the code is <br>
```python evaluate.py [new/old] input_file output_file``` <br>
"new" means that Re-OIE2016 is used as the benchmark, and "old" means that OIE2016 is used. The input_file is the extraction output of an Open IE system. Each line follows this tab-separated format:<br>
```sentence confidence_score predicate arg0 arg1 arg2 ...``` <br>
The script outputs the AUC and best F1 score of the system. The output file is used to draw the PR curve; the script to draw it is [here](https://github.com/gabrielStanovsky/oie-benchmark/blob/master/pr_plot.py).
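For example, a single line of input_file could look like this (tab-separated, illustrative values):<br>
```John bought a car	0.92	bought	John	a car``` <br>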


@@ -1,72 +0,0 @@
import os
import torch
import numpy as np
import utils.bio as bio
from transformers import BertTokenizer
from tqdm import tqdm
def extract(args,
model,
loader,
output_path):
model.eval()
os.makedirs(output_path, exist_ok=True)
extraction_path = os.path.join(output_path, "extraction.txt")
tokenizer = BertTokenizer.from_pretrained(args.bert_config)
f = open(extraction_path, 'w')
for step, batch in tqdm(enumerate(loader), desc='eval_steps', total=len(loader)):
token_strs = [[word for word in sent] for sent in np.asarray(batch[-2]).T]
sentences = batch[-1]
token_ids, att_mask = map(lambda x: x.to(args.device), batch[:-2])
with torch.no_grad():
"""
We will iterate B(batch_size) times
because there are more than one predicate in one batch.
In feeding to argument extractor, # of predicates takes a role as batch size.
pred_logit: (B, L, 3)
pred_hidden: (B, L, D)
pred_tags: (B, P, L) ~ list of tensors, where P is # of predicate in each batch
"""
pred_logit, pred_hidden = model.extract_predicate(
input_ids=token_ids, attention_mask=att_mask)
pred_tags = torch.argmax(pred_logit, 2)
pred_tags = bio.filter_pred_tags(pred_tags, token_strs)
pred_tags = bio.get_single_predicate_idxs(pred_tags)
pred_probs = torch.nn.Softmax(2)(pred_logit)
# iterate B times (one iteration means extraction for one sentence)
for cur_pred_tags, cur_pred_hidden, cur_att_mask, cur_token_id, cur_pred_probs, token_str, sentence \
in zip(pred_tags, pred_hidden, att_mask, token_ids, pred_probs, token_strs, sentences):
# generate temporary batch for this sentence and feed to argument module
cur_pred_masks = bio.get_pred_mask(cur_pred_tags).to(args.device)
n_predicates = cur_pred_masks.shape[0]
if n_predicates == 0:
continue # if there is no predicate, we cannot extract.
cur_pred_hidden = torch.cat(n_predicates * [cur_pred_hidden.unsqueeze(0)])
cur_token_id = torch.cat(n_predicates * [cur_token_id.unsqueeze(0)])
cur_arg_logit = model.extract_argument(
input_ids=cur_token_id,
predicate_hidden=cur_pred_hidden,
predicate_mask=cur_pred_masks)
# filter and get argument tags with highest probability
cur_arg_tags = torch.argmax(cur_arg_logit, 2)
cur_arg_probs = torch.nn.Softmax(2)(cur_arg_logit)
cur_arg_tags = bio.filter_arg_tags(cur_arg_tags, cur_pred_tags, token_str)
# get string tuples and write results
cur_extractions, cur_extraction_idxs = bio.get_tuple(sentence, cur_pred_tags, cur_arg_tags, tokenizer)
cur_confidences = bio.get_confidence_score(cur_pred_probs, cur_arg_probs, cur_extraction_idxs)
for extraction, confidence in zip(cur_extractions, cur_confidences):
if args.binary:
f.write("\t".join([sentence] + [str(1.0)] + extraction[:3]) + '\n')
else:
f.write("\t".join([sentence] + [str(confidence)] + extraction) + '\n')
f.close()
print("\nExtraction Done.\n")

Two binary image files deleted (not shown; 124 KiB and 49 KiB).


@@ -1,112 +0,0 @@
import argparse
import os
import torch
from utils import utils
from utils.utils import SummaryManager
from dataset import load_data
from tqdm import tqdm
from train import train
from extract import extract
from src.Multi2OIE.test import do_eval
def main(args):
utils.set_seed(args.seed)
model = utils.get_models(
bert_config=args.bert_config,
pred_n_labels=args.pred_n_labels,
arg_n_labels=args.arg_n_labels,
n_arg_heads=args.n_arg_heads,
n_arg_layers=args.n_arg_layers,
lstm_dropout=args.lstm_dropout,
mh_dropout=args.mh_dropout,
pred_clf_dropout=args.pred_clf_dropout,
arg_clf_dropout=args.arg_clf_dropout,
pos_emb_dim=args.pos_emb_dim,
use_lstm=args.use_lstm,
device=args.device)
trn_loader = load_data(
data_path=args.trn_data_path,
batch_size=args.batch_size,
max_len=args.max_len,
tokenizer_config=args.bert_config)
dev_loaders = [
load_data(
data_path=cur_dev_path,
batch_size=args.dev_batch_size,
tokenizer_config=args.bert_config,
train=False)
for cur_dev_path in args.dev_data_path]
args.total_steps = round(len(trn_loader) * args.epochs)
args.warmup_steps = round(args.total_steps / 10)
optimizer, scheduler = utils.get_train_modules(
model=model,
lr=args.learning_rate,
total_steps=args.total_steps,
warmup_steps=args.warmup_steps)
model.zero_grad()
summarizer = SummaryManager(args)
print("\nTraining Starts\n")
for epoch in tqdm(range(1, args.epochs + 1), desc='epochs'):
trn_results = train(
args, epoch, model, trn_loader, dev_loaders,
summarizer, optimizer, scheduler)
# extraction on devset
dev_iter = zip(args.dev_data_path, args.dev_gold_path, dev_loaders)
dev_results = list()
total_sum = 0
for dev_input, dev_gold, dev_loader in dev_iter:
dev_name = dev_input.split('/')[-1].replace('.pkl', '')
output_path = os.path.join(args.save_path, f'epoch{epoch}_dev/end_epoch/{dev_name}')
extract(args, model, dev_loader, output_path)
dev_result = do_eval(output_path, dev_gold)
utils.print_results(f"EPOCH{epoch} EVAL",
dev_result, ["F1 ", "PREC", "REC ", "AUC "])
total_sum += dev_result[0] + dev_result[-1]
dev_result.append(dev_result[0] + dev_result[-1])
dev_results += dev_result
summarizer.save_results([epoch] + trn_results + dev_results + [total_sum])
model_name = utils.set_model_name(total_sum, epoch)
torch.save(model.state_dict(), os.path.join(args.save_path, model_name))
print("\nTraining Ended\n")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# settings
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--save_path', default='./results')
parser.add_argument('--bert_config', default='bert-base-cased', help='or bert-base-multilingual-cased')
parser.add_argument('--trn_data_path', default='./datasets/openie4_train.pkl')
parser.add_argument('--dev_data_path', nargs='+', default=['./datasets/oie2016_dev.pkl', './datasets/carb_dev.pkl'])
parser.add_argument('--dev_gold_path', nargs='+', default=['./evaluate/OIE2016_dev.txt', './carb/CaRB_dev.tsv'])
parser.add_argument('--max_len', type=int, default=64)
parser.add_argument('--device', default='cuda:0')
parser.add_argument('--visible_device', default="0")
parser.add_argument('--summary_step', type=int, default=100)
parser.add_argument('--use_lstm', nargs='?', const=True, default=False, type=utils.str2bool)
parser.add_argument('--binary', nargs='?', const=True, default=False, type=utils.str2bool)
# hyper-parameters
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--lstm_dropout', type=float, default=0.)
parser.add_argument('--mh_dropout', type=float, default=0.2)
parser.add_argument('--pred_clf_dropout', type=float, default=0.)
parser.add_argument('--arg_clf_dropout', type=float, default=0.2)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--dev_batch_size', type=int, default=32)
parser.add_argument('--learning_rate', type=float, default=3e-5)
parser.add_argument('--n_arg_heads', type=int, default=8)
parser.add_argument('--n_arg_layers', type=int, default=4)
parser.add_argument('--pos_emb_dim', type=int, default=64)
main_args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = main_args.visible_device
main_args = utils.clean_config(main_args)
main(main_args)
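# Example invocation (flags are the argparse options above; paths are the defaults
# and may need adjusting to your local data layout):
#   python main.py --epochs 1 --batch_size 128 \
#       --trn_data_path ./datasets/openie4_train.pkl --visible_device 0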


@@ -1,316 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torch.nn.modules.container import ModuleList
from transformers import BertModel
class ArgModule(nn.Module):
def __init__(self, arg_layer, n_layers):
"""
Module for extracting arguments based on given encoder output and predicates.
It uses ArgExtractorLayer as a base block and repeats it N ('n_layers') times.
:param arg_layer: an instance of the ArgExtractorLayer() class (required)
:param n_layers: the number of sub-layers in the ArgModule (required).
"""
super(ArgModule, self).__init__()
self.layers = _get_clones(arg_layer, n_layers)
self.n_layers = n_layers
def forward(self, encoded, predicate, pred_mask=None):
"""
:param encoded: output from sentence encoder with the shape of (L, B, D),
where L is the sequence length, B is the batch size, D is the embedding dimension
:param predicate: output from predicate module with the shape of (L, B, D)
:param pred_mask: mask that prevents attention to tokens which are not predicates
with the shape of (B, L)
:return: tensor like Transformer Decoder Layer Output
"""
output = encoded
for layer_idx in range(self.n_layers):
output = self.layers[layer_idx](
target=output, source=predicate, key_mask=pred_mask)
return output
class ArgExtractorLayer(nn.Module):
def __init__(self,
d_model=768,
n_heads=8,
d_feedforward=2048,
dropout=0.1,
activation='relu'):
"""
A layer similar to Transformer decoder without decoder self-attention.
(only encoder-decoder multi-head attention followed by feed-forward layers)
:param d_model: model dimensionality (default=768 from BERT-base)
:param n_heads: number of heads in multi-head attention layer
:param d_feedforward: dimensionality of point-wise feed-forward layer
:param dropout: drop rate of all layers
:param activation: activation function after first feed-forward layer
"""
super(ArgExtractorLayer, self).__init__()
self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
self.linear1 = nn.Linear(d_model, d_feedforward)
self.dropout1 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def forward(self, target, source, key_mask=None):
"""
Single Transformer Decoder layer without self-attention
:param target: a tensor which takes a role as a query
:param source: a tensor which takes a role as a key & value
:param key_mask: key mask tensor with the shape of (batch_size, sequence_length)
"""
# Multi-head attention layer (+ add & norm)
attended = self.multihead_attn(
target, source, source,
key_padding_mask=key_mask)[0]
skipped = target + self.dropout1(attended)
normed = self.norm1(skipped)
# Point-wise feed-forward layer (+ add & norm)
projected = self.linear2(self.dropout2(self.activation(self.linear1(normed))))
skipped = normed + self.dropout3(projected)  # dedicated dropout for the feed-forward branch
normed = self.norm2(skipped)
return normed
class Multi2OIE(nn.Module):
def __init__(self,
bert_config='bert-base-cased',
mh_dropout=0.1,
pred_clf_dropout=0.,
arg_clf_dropout=0.3,
n_arg_heads=8,
n_arg_layers=4,
pos_emb_dim=64,
pred_n_labels=3,
arg_n_labels=9):
super(Multi2OIE, self).__init__()
self.pred_n_labels = pred_n_labels
self.arg_n_labels = arg_n_labels
self.bert = BertModel.from_pretrained(
bert_config,
output_hidden_states=True)
d_model = self.bert.config.hidden_size
self.pred_dropout = nn.Dropout(pred_clf_dropout)
self.pred_classifier = nn.Linear(d_model, self.pred_n_labels)
self.position_emb = nn.Embedding(3, pos_emb_dim, padding_idx=0)
d_model += (d_model + pos_emb_dim)
arg_layer = ArgExtractorLayer(
d_model=d_model,
n_heads=n_arg_heads,
dropout=mh_dropout)
self.arg_module = ArgModule(arg_layer, n_arg_layers)
self.arg_dropout = nn.Dropout(arg_clf_dropout)
self.arg_classifier = nn.Linear(d_model, arg_n_labels)
def forward(self,
input_ids,
attention_mask,
predicate_mask=None,
predicate_hidden=None,
total_pred_labels=None,
arg_labels=None):
# predicate extraction
bert_hidden = self.bert(input_ids, attention_mask)[0]
pred_logit = self.pred_classifier(self.pred_dropout(bert_hidden))
# predicate loss
if total_pred_labels is not None:
loss_fct = nn.CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = pred_logit.view(-1, self.pred_n_labels)
active_labels = torch.where(
active_loss, total_pred_labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(total_pred_labels))
pred_loss = loss_fct(active_logits, active_labels)
# inputs for argument extraction
pred_feature = _get_pred_feature(bert_hidden, predicate_mask)
position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
bert_hidden = torch.cat([bert_hidden, pred_feature, position_vectors], dim=2)
bert_hidden = bert_hidden.transpose(0, 1)
# argument extraction
arg_hidden = self.arg_module(bert_hidden, bert_hidden, predicate_mask)
arg_hidden = arg_hidden.transpose(0, 1)
arg_logit = self.arg_classifier(self.arg_dropout(arg_hidden))
# argument loss
if arg_labels is not None:
loss_fct = nn.CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = arg_logit.view(-1, self.arg_n_labels)
active_labels = torch.where(
active_loss, arg_labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(arg_labels))
arg_loss = loss_fct(active_logits, active_labels)
# total loss
batch_loss = pred_loss + arg_loss
outputs = (batch_loss, pred_loss, arg_loss)
return outputs
def extract_predicate(self,
input_ids,
attention_mask):
bert_hidden = self.bert(input_ids, attention_mask)[0]
pred_logit = self.pred_classifier(bert_hidden)
return pred_logit, bert_hidden
def extract_argument(self,
input_ids,
predicate_hidden,
predicate_mask):
pred_feature = _get_pred_feature(predicate_hidden, predicate_mask)
position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
arg_input = torch.cat([predicate_hidden, pred_feature, position_vectors], dim=2)
arg_input = arg_input.transpose(0, 1)
arg_hidden = self.arg_module(arg_input, arg_input, predicate_mask)
arg_hidden = arg_hidden.transpose(0, 1)
return self.arg_classifier(arg_hidden)
class BERTBiLSTM(nn.Module):
def __init__(self,
bert_config='bert-base-cased',
lstm_dropout=0.3,
pred_clf_dropout=0.,
arg_clf_dropout=0.3,
pos_emb_dim=256,
pred_n_labels=3,
arg_n_labels=9):
super(BERTBiLSTM, self).__init__()
self.pred_n_labels = pred_n_labels
self.arg_n_labels = arg_n_labels
self.bert = BertModel.from_pretrained(
bert_config,
output_hidden_states=True)
d_model = self.bert.config.hidden_size
self.pred_dropout = nn.Dropout(pred_clf_dropout)
self.pred_classifier = nn.Linear(d_model, self.pred_n_labels)
self.position_emb = nn.Embedding(3, pos_emb_dim, padding_idx=0)
d_model += pos_emb_dim
self.arg_module = nn.LSTM(
input_size=d_model,
hidden_size=d_model,
num_layers=3,
dropout=lstm_dropout,
batch_first=True,
bidirectional=True)
self.arg_dropout = nn.Dropout(arg_clf_dropout)
self.arg_classifier = nn.Linear(d_model * 2, arg_n_labels)
def forward(self,
input_ids,
attention_mask,
predicate_mask=None,
predicate_hidden=None,
total_pred_labels=None,
arg_labels=None):
# predicate extraction
bert_hidden = self.bert(input_ids, attention_mask)[0]
pred_logit = self.pred_classifier(self.pred_dropout(bert_hidden))
# predicate loss
if total_pred_labels is not None:
loss_fct = nn.CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = pred_logit.view(-1, self.pred_n_labels)
active_labels = torch.where(
active_loss, total_pred_labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(total_pred_labels))
pred_loss = loss_fct(active_logits, active_labels)
# argument extraction
position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
bert_hidden = torch.cat([bert_hidden, position_vectors], dim=2)
arg_hidden = self.arg_module(bert_hidden)[0]
arg_logit = self.arg_classifier(self.arg_dropout(arg_hidden))
# argument loss
if arg_labels is not None:
loss_fct = nn.CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = arg_logit.view(-1, self.arg_n_labels)
active_labels = torch.where(
active_loss, arg_labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(arg_labels))
arg_loss = loss_fct(active_logits, active_labels)
# total loss
batch_loss = pred_loss + arg_loss
outputs = (batch_loss, pred_loss, arg_loss)
return outputs
def extract_predicate(self,
input_ids,
attention_mask):
bert_hidden = self.bert(input_ids, attention_mask)[0]
pred_logit = self.pred_classifier(bert_hidden)
return pred_logit, bert_hidden
def extract_argument(self,
input_ids,
predicate_hidden,
predicate_mask):
position_vectors = self.position_emb(_get_position_idxs(predicate_mask, input_ids))
arg_input = torch.cat([predicate_hidden, position_vectors], dim=2)
arg_hidden = self.arg_module(arg_input)[0]
return self.arg_classifier(arg_hidden)
def _get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
else:
raise RuntimeError("activation should be relu/gelu, not %s." % activation)
def _get_clones(module, n):
return ModuleList([copy.deepcopy(module) for _ in range(n)])
def _get_position_idxs(pred_mask, input_ids):
position_idxs = torch.zeros(pred_mask.shape, dtype=int, device=pred_mask.device)
for mask_idx, cur_mask in enumerate(pred_mask):
position_idxs[mask_idx, :] += 2
cur_nonzero = (cur_mask == 0).nonzero()
start = torch.min(cur_nonzero).item()
end = torch.max(cur_nonzero).item()
position_idxs[mask_idx, start:end + 1] = 1
pad_start = max(input_ids[mask_idx].nonzero()).item() + 1
position_idxs[mask_idx, pad_start:] = 0
return position_idxs
def _get_pred_feature(pred_hidden, pred_mask):
B, L, D = pred_hidden.shape
pred_features = torch.zeros((B, L, D), device=pred_mask.device)
for mask_idx, cur_mask in enumerate(pred_mask):
pred_position = (cur_mask == 0).nonzero().flatten()
pred_feature = torch.mean(pred_hidden[mask_idx, pred_position], dim=0)
pred_feature = torch.cat(L * [pred_feature.unsqueeze(0)])
pred_features[mask_idx, :, :] = pred_feature
return pred_features
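# A minimal smoke-test sketch (assumes network access so the 'bert-base-cased'
# weights can be downloaded; the token ids are random and meaningless).
if __name__ == '__main__':
    model = Multi2OIE(bert_config='bert-base-cased')
    model.eval()
    input_ids = torch.randint(1000, 2000, (2, 16))   # (B=2, L=16) dummy token ids
    att_mask = torch.ones(2, 16, dtype=torch.long)   # attend to every position
    with torch.no_grad():
        pred_logit, bert_hidden = model.extract_predicate(input_ids, att_mask)
    print(pred_logit.shape, bert_hidden.shape)       # (2, 16, 3) and (2, 16, 768)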


@@ -1,59 +0,0 @@
attrs==19.3.0
backcall==0.1.0
blis==0.4.1
brotlipy==0.7.0
catalogue==1.0.0
certifi==2020.4.5.1
cffi==1.14.0
chardet==3.0.4
click==7.1.2
cryptography==2.9.2
cymem==2.0.3
decorator==4.4.2
docopt==0.6.2
filelock==3.0.12
idna==2.9
importlib-metadata==1.6.0
ipdb==0.13.2
ipython==7.14.0
ipython-genutils==0.2.0
jedi==0.17.0
joblib==0.15.1
jsonschema==3.2.0
murmurhash==1.0.0
nltk==3.5
numpy==1.18.1
pandas==1.0.3
parso==0.7.0
pexpect==4.8.0
pickleshare==0.7.5
plac==0.9.6
preshed==3.0.2
prompt-toolkit==3.0.5
ptyprocess==0.6.0
pycparser==2.20
Pygments==2.6.1
pyOpenSSL==19.1.0
pyrsistent==0.16.0
PySocks==1.7.1
python-dateutil==2.8.1
pytz==2020.1
regex==2020.5.14
requests==2.23.0
sacremoses==0.0.43
scikit-learn==0.22.1
scipy==1.4.1
sentencepiece==0.1.91
six==1.14.0
spacy==2.2.4
srsly==1.0.2
thinc==7.4.0
tokenizers==0.7.0
torch==1.4.0
tqdm==4.46.0
traitlets==4.3.3
transformers==2.10.0
urllib3==1.25.9
wasabi==0.6.0
wcwidth==0.1.9
zipp==3.1.0


@@ -1,103 +0,0 @@
import argparse
import os
import time
import torch
from utils import utils
from dataset import load_data
from extract import extract
from evaluate.evaluate import Benchmark
from evaluate.matcher import Matcher
from evaluate.generalReader import GeneralReader
from carb.carb import Benchmark as CarbBenchmark
from carb.matcher import Matcher as CarbMatcher
from carb.tabReader import TabReader
def get_performance(output_path, gold_path):
auc, precision, recall, f1 = [None for _ in range(4)]
if 'evaluate' in gold_path:
matching_func = Matcher.lexicalMatch
error_fn = os.path.join(output_path, 'error_idxs.txt')
evaluator = Benchmark(gold_path)
reader = GeneralReader()
reader.read(os.path.join(output_path, 'extraction.txt'))
(precision, recall, f1), auc = evaluator.compare(
predicted=reader.oie,
matchingFunc=matching_func,
output_fn=os.path.join(output_path, 'pr_curve.txt'),
error_file=error_fn)
elif 'carb' in gold_path:
matching_func = CarbMatcher.binary_linient_tuple_match
error_fn = os.path.join(output_path, 'error_idxs.txt')
evaluator = CarbBenchmark(gold_path)
reader = TabReader()
reader.read(os.path.join(output_path, 'extraction.txt'))
auc, (precision, recall, f1) = evaluator.compare(
predicted=reader.oie,
matchingFunc=matching_func,
output_fn=os.path.join(output_path, 'pr_curve.txt'),
error_file=error_fn)
return auc, precision, recall, f1
def do_eval(output_path, gold_path):
auc, prec, rec, f1 = get_performance(output_path, gold_path)
eval_results = [f1, prec, rec, auc]
return eval_results
def main(args):
model = utils.get_models(
bert_config=args.bert_config,
pred_n_labels=args.pred_n_labels,
arg_n_labels=args.arg_n_labels,
n_arg_heads=args.n_arg_heads,
n_arg_layers=args.n_arg_layers,
pos_emb_dim=args.pos_emb_dim,
use_lstm=args.use_lstm,
device=args.device)
model.load_state_dict(torch.load(args.model_path))
model.zero_grad()
model.eval()
loader = load_data(
data_path=args.test_data_path,
batch_size=args.batch_size,
tokenizer_config=args.bert_config,
train=False)
start = time.time()
extract(args, model, loader, args.save_path)
print("TIME: ", time.time() - start)
test_results = do_eval(args.save_path, args.test_gold_path)
utils.print_results("TEST RESULT", test_results, ["F1 ", "PREC", "REC ", "AUC "])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', default='./results/model.bin')
parser.add_argument('--save_path', default='./results/carb_test')
parser.add_argument('--bert_config', default='bert-base-cased')
parser.add_argument('--test_data_path', default='./datasets/carb_test.pkl')
parser.add_argument('--test_gold_path', default='./carb/CaRB_test.tsv')
parser.add_argument('--device', default='cuda:0')
parser.add_argument('--visible_device', default="0")
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--pos_emb_dim', type=int, default=64)
parser.add_argument('--n_arg_heads', type=int, default=8)
parser.add_argument('--n_arg_layers', type=int, default=4)
parser.add_argument('--use_lstm', nargs='?', const=True, default=False, type=utils.str2bool)
parser.add_argument('--binary', nargs='?', const=True, default=False, type=utils.str2bool)
main_args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = main_args.visible_device
main_args.pred_n_labels = 3
main_args.arg_n_labels = 9
device = torch.device(main_args.device if torch.cuda.is_available() else 'cpu')
main_args.device = device
main(main_args)
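# Example invocation (flags are the argparse options above; paths are illustrative):
#   python test.py --model_path ./results/model.bin \
#       --test_data_path ./datasets/carb_test.pkl --test_gold_path ./carb/CaRB_test.tsv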


@@ -1,79 +0,0 @@
import os
import torch
import torch.nn as nn
import utils.bio as bio
from tqdm import tqdm
from extract import extract
from utils import utils
from src.Multi2OIE.test import do_eval
def train(args,
epoch,
model,
trn_loader,
dev_loaders,
summarizer,
optimizer,
scheduler):
total_pred_loss, total_arg_loss, trn_results = 0, 0, None
epoch_steps = int(args.total_steps / args.epochs)
iterator = tqdm(enumerate(trn_loader), desc='steps', total=epoch_steps)
for step, batch in iterator:
batch = map(lambda x: x.to(args.device), batch)
token_ids, att_mask, single_pred_label, single_arg_label, all_pred_label = batch
pred_mask = bio.get_pred_mask(single_pred_label)
model.train()
model.zero_grad()
# feed to predicate model
batch_loss, pred_loss, arg_loss = model(
input_ids=token_ids,
attention_mask=att_mask,
predicate_mask=pred_mask,
total_pred_labels=all_pred_label,
arg_labels=single_arg_label)
# get performance on this batch
total_pred_loss += pred_loss.item()
total_arg_loss += arg_loss.item()
batch_loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
trn_results = [total_pred_loss / (step + 1), total_arg_loss / (step + 1)]
if step > epoch_steps:
break
# interim evaluation
if step % 1000 == 0 and step != 0:
dev_iter = zip(args.dev_data_path, args.dev_gold_path, dev_loaders)
dev_results = list()
total_sum = 0
for dev_input, dev_gold, dev_loader in dev_iter:
dev_name = dev_input.split('/')[-1].replace('.pkl', '')
output_path = os.path.join(args.save_path, f'epoch{epoch}_dev/step{step}/{dev_name}')
extract(args, model, dev_loader, output_path)
dev_result = do_eval(output_path, dev_gold)
utils.print_results(f"EPOCH{epoch} STEP{step} EVAL",
dev_result, ["F1 ", "PREC", "REC ", "AUC "])
total_sum += dev_result[0] + dev_result[-1]
dev_result.append(dev_result[0] + dev_result[-1])
dev_results += dev_result
summarizer.save_results([step] + trn_results + dev_results + [total_sum])
model_name = utils.set_model_name(total_sum, epoch, step)
torch.save(model.state_dict(), os.path.join(args.save_path, model_name))
if step % args.summary_step == 0 and step != 0:
utils.print_results(f"EPOCH{epoch} STEP{step} TRAIN",
trn_results, ["PRED LOSS", "ARG LOSS "])
# end epoch summary
utils.print_results(f"EPOCH{epoch} TRAIN",
trn_results, ["PRED LOSS", "ARG LOSS "])
return trn_results


@@ -1,310 +0,0 @@
import torch
import numpy as np
from src.Multi2OIE.utils import utils
pred_tag2idx = {
'P-B': 0, 'P-I': 1, 'O': 2
}
arg_tag2idx = {
'A0-B': 0, 'A0-I': 1,
'A1-B': 2, 'A1-I': 3,
'A2-B': 4, 'A2-I': 5,
'A3-B': 6, 'A3-I': 7,
'O': 8,
}
def get_pred_idxs(pred_tags):
idxs = list()
for pred_tag in pred_tags:
idxs.append([idx.item() for idx in (pred_tag != 2).nonzero()])
return idxs
def get_pred_mask(tensor):
"""
Generate predicate masks: positions tagged 'O' become 1 (masked),
all other positions become 0 (attended to).
:param tensor: predicate tagged tensor with the shape of (B, L),
where B is the batch size, L is the sequence length.
:return: binary mask tensor with the same shape.
"""
res = tensor.clone()
res[tensor == pred_tag2idx['O']] = 1
res[tensor != pred_tag2idx['O']] = 0
return res.to(dtype=torch.bool)  # torch.tensor() on an existing tensor would trigger a copy warning
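# Example (indices follow pred_tag2idx above):
#   get_pred_mask(torch.tensor([[2, 0, 1, 2]]))
#   -> tensor([[True, False, False, True]])   # 'O' positions masked, predicate span kept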
def filter_pred_tags(pred_tags, tokens):
"""
Filter useless tokens by converting them into 'Outside' tag.
We treat an 'Inside' tag before any 'Beginning' tag as a meaningful signal,
so we change it to a 'Beginning' tag, unlike [Stanovsky et al., 2018].
:param pred_tags: predicate tags with the shape of (B, L).
:param tokens: list format sentence pieces with the shape of (B, L)
:return: tensor of filtered predicate tags with the same shape.
"""
assert len(pred_tags) == len(tokens)
assert len(pred_tags[0]) == len(tokens[0])
# filter by tokens ([CLS], [SEP], [PAD] tokens should be allocated as 'O')
for pred_idx, cur_tokens in enumerate(tokens):
for tag_idx, token in enumerate(cur_tokens):
if token in ['[CLS]', '[SEP]', '[PAD]']:
pred_tags[pred_idx][tag_idx] = pred_tag2idx['O']
# filter by tags
pred_copied = pred_tags.clone()
for pred_idx, cur_pred_tag in enumerate(pred_copied):
flag = False
tag_copied = cur_pred_tag.clone()
for tag_idx, tag in enumerate(tag_copied):
if not flag and tag == pred_tag2idx['P-B']:
flag = True
elif not flag and tag == pred_tag2idx['P-I']:
pred_tags[pred_idx][tag_idx] = pred_tag2idx['P-B']
flag = True
elif flag and tag == pred_tag2idx['O']:
flag = False
return pred_tags
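# Example of the tag filtering (indices follow pred_tag2idx above):
#   input  [O, P-I, P-I, O] == [2, 1, 1, 2]
#   output [O, P-B, P-I, O] == [2, 0, 1, 2]   # a leading 'Inside' is promoted to 'Beginning'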
def filter_arg_tags(arg_tags, pred_tags, tokens):
"""
Same as the description of @filter_pred_tags().
:param arg_tags: argument tags with the shape of (B, L).
:param pred_tags: predicate tags with the same shape.
It is used to force predicate position to be allocated the 'Outside' tag.
:param tokens: list of string tokens with the length of L.
It is used to force special tokens like [CLS] to be allocated the 'Outside' tag.
:return: tensor of filtered argument tags with the same shape.
"""
# filter by tokens ([CLS], [SEP], [PAD] tokens should be allocated as 'O')
for arg_idx, cur_arg_tag in enumerate(arg_tags):
for tag_idx, token in enumerate(tokens):
if token in ['[CLS]', '[SEP]', '[PAD]']:
arg_tags[arg_idx][tag_idx] = arg_tag2idx['O']
# filter by tags
arg_copied = arg_tags.clone()
for arg_idx, (cur_arg_tag, cur_pred_tag) in enumerate(zip(arg_copied, pred_tags)):
pred_idxs = [idx[0].item() for idx
in (cur_pred_tag != pred_tag2idx['O']).nonzero()]
arg_tags[arg_idx][pred_idxs] = arg_tag2idx['O']
cur_arg_copied = arg_tags[arg_idx].clone()
flag_idx = 999
for tag_idx, tag in enumerate(cur_arg_copied):
if tag == arg_tag2idx['O']:
flag_idx = 999
continue
arg_n = tag // 2 # 0: A0 / 1: A1 / ...
inside = tag % 2 # 0: begin / 1: inside
if not inside and flag_idx != arg_n:
flag_idx = arg_n
# connect_args
elif not inside and flag_idx == arg_n:
arg_tags[arg_idx][tag_idx] = arg_tag2idx[f'A{arg_n}-I']
elif inside and flag_idx != arg_n:
arg_tags[arg_idx][tag_idx] = arg_tag2idx[f'A{arg_n}-B']
flag_idx = arg_n
return arg_tags
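# Sketch of the repair performed above (illustrative): within one sequence,
# 'A0-B A0-B' is merged into 'A0-B A0-I' (connect_args), while an 'A1-I' that
# follows an 'O' or a different label is promoted to 'A1-B'.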
def get_max_prob_args(arg_tags, arg_probs):
"""
Among predicted argument tags, keep only the argument span with the highest probability.
Probabilities are compared only between spans of the same argument label.
:param arg_tags: argument tags with the shape of (B, L).
:param arg_probs: argument softmax probabilities with the shape of (B, L, T),
where B is the batch size, L is the sequence length, and T is the # of tag labels.
:return: tensor of filtered argument tags with the same shape.
"""
for cur_arg_tag, cur_probs in zip(arg_tags, arg_probs):
cur_tag_probs = [cur_probs[idx][tag] for idx, tag in enumerate(cur_arg_tag)]
for arg_n in range(4):
b_tag = arg_tag2idx[f"A{arg_n}-B"]
i_tag = arg_tag2idx[f"A{arg_n}-I"]
flag = False
total_tags = []
cur_tags = []
for idx, tag in enumerate(cur_arg_tag):
if not flag and tag == b_tag:
flag = True
cur_tags.append(idx)
elif flag and tag == i_tag:
cur_tags.append(idx)
elif flag and tag == b_tag:
total_tags.append(cur_tags)
cur_tags = [idx]
elif tag not in (b_tag, i_tag):
total_tags.append(cur_tags)
cur_tags = []
flag = False
max_idxs, max_prob = None, 0.0
for idxs in total_tags:
all_probs = [cur_tag_probs[idx].item() for idx in idxs]
if len(all_probs) == 0:
continue
cur_prob = all_probs[0]
if cur_prob > max_prob:
max_prob = cur_prob
max_idxs = idxs
if max_idxs is None:
continue
del_idxs = [idx for idx, tag in enumerate(cur_arg_tag)
if (tag in [b_tag, i_tag]) and (idx not in max_idxs)]
cur_arg_tag[del_idxs] = arg_tag2idx['O']
return arg_tags
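# Sketch: if two A0 spans are predicted for one predicate, only the span whose
# begin-token probability is higher survives; the other span's tags are reset
# to 'O'. Probabilities of different argument labels are never compared.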
def get_single_predicate_idxs(pred_tags):
"""
Split each sequence in the batch into one sequence per predicted predicate.
This is necessary for predicting argument tags with respect to a specific predicate.
:param pred_tags: tensor of predicate tags with the shape of (B, L)
EX >>> tensor([[2, 0, 0, 1, 0, 1, 0, 2, 2, 2],
[2, 2, 2, 0, 1, 0, 1, 2, 2, 2],
[2, 2, 2, 2, 2, 2, 2, 2, 0, 1]])
:return: list of B tensors, each with the shape of (P, L),
where the number of predicates P can differ for each sequence.
EX >>> [tensor([[2., 0., 2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 0., 1., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 0., 1., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 0., 2., 2., 2.]]),
tensor([[2., 2., 2., 0., 1., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 0., 1., 2., 2., 2.]]),
tensor([[2., 2., 2., 2., 2., 2., 2., 2., 0., 1.]])]
"""
total_pred_tags = []
for cur_pred_tag in pred_tags:
cur_sent_preds = []
begin_idxs = [idx[0].item() for idx in (cur_pred_tag == pred_tag2idx['P-B']).nonzero()]
for i, b_idx in enumerate(begin_idxs):
cur_pred = np.full(cur_pred_tag.shape[0], pred_tag2idx['O'])
cur_pred[b_idx] = pred_tag2idx['P-B']
if i == len(begin_idxs) - 1:
end_idx = cur_pred_tag.shape[0]
else:
end_idx = begin_idxs[i + 1]
for j, tag in enumerate(cur_pred_tag[b_idx:end_idx]):
if tag.item() == pred_tag2idx['O']:
break
elif tag.item() == pred_tag2idx['P-I']:
cur_pred[b_idx + j] = pred_tag2idx['P-I']
cur_sent_preds.append(cur_pred)
total_pred_tags.append(cur_sent_preds)
return [torch.Tensor(cur_preds) for cur_preds in total_pred_tags]
def get_tuple(sentence, pred_tags, arg_tags, tokenizer):
"""
Get string format tuples from given predicate indexes and argument tags.
:param sentence: string format raw sentence.
:param pred_tags: tensor of predicate tags with the shape of (# of predicates, sequence length).
:param arg_tags: tensor of argument tags with the same shape.
:param tokenizer: transformers BertTokenizer (bert-base-cased or bert-base-multilingual-cased)
:return extractions: list of extractions; each is a list of strings [predicate, arg0, arg1, ...].
:return extraction_idxs: list of wordpiece indexes for each extracted element, used for the confidence score.
"""
word2piece = utils.get_word2piece(sentence, tokenizer)
words = sentence.split(' ')
assert pred_tags.shape[0] == arg_tags.shape[0] # number of predicates
pred_tags = pred_tags.tolist()
arg_tags = arg_tags.tolist()
extractions = list()
extraction_idxs = list()
# loop for each predicate
for cur_pred_tag, cur_arg_tags in zip(pred_tags, arg_tags):
cur_extraction = list()
cur_extraction_idxs = list()
# get predicate
pred_labels = [pred_tag2idx['P-B'], pred_tag2idx['P-I']]
cur_predicate_idxs = [idx for idx, tag in enumerate(cur_pred_tag) if tag in pred_labels]
if len(cur_predicate_idxs) == 0:
predicates_str = ''
else:
cur_pred_words = list()
for word_idx, piece_idxs in word2piece.items():
if set(piece_idxs) <= set(cur_predicate_idxs):
cur_pred_words.append(word_idx)
if len(cur_pred_words) == 0:
predicates_str = ''
cur_predicate_idxs = list()
else:
predicates_str = ' '.join([words[idx] for idx in cur_pred_words])
cur_extraction.append(predicates_str)
cur_extraction_idxs.append(cur_predicate_idxs)
# get arguments
for arg_n in range(4):
cur_arg_labels = [arg_tag2idx[f'A{arg_n}-B'], arg_tag2idx[f'A{arg_n}-I']]
cur_arg_idxs = [idx for idx, tag in enumerate(cur_arg_tags) if tag in cur_arg_labels]
if len(cur_arg_idxs) == 0:
cur_arg_str = ''
else:
cur_arg_words = list()
for word_idx, piece_idxs in word2piece.items():
if set(piece_idxs) <= set(cur_arg_idxs):
cur_arg_words.append(word_idx)
if len(cur_arg_words) == 0:
cur_arg_str = ''
cur_arg_idxs = list()
else:
cur_arg_str = ' '.join([words[idx] for idx in cur_arg_words])
cur_extraction.append(cur_arg_str)
cur_extraction_idxs.append(cur_arg_idxs)
extractions.append(cur_extraction)
extraction_idxs.append(cur_extraction_idxs)
return extractions, extraction_idxs
def get_confidence_score(pred_probs, arg_probs, extraction_idxs):
"""
Get the confidence score of each extraction for drawing the PR curve.
:param pred_probs: (sequence length, # of predicate labels)
:param arg_probs: (# of predicates, sequence length, # of argument labels)
:param extraction_idxs: [[[2, 3, 4], [0, 1], [9, 10]], [[0, 1, 2], [7, 8], [4, 5]], ...]
"""
confidence_scores = list()
for cur_arg_prob, cur_ext_idxs in zip(arg_probs, extraction_idxs):
if len(cur_ext_idxs[0]) == 0:
confidence_scores.append(0)
continue
cur_score = 0
# predicate score
pred_score = max(pred_probs[cur_ext_idxs[0][0]]).item()
cur_score += pred_score
# argument score
for arg_idx in cur_ext_idxs[1:]:
if len(arg_idx) == 0:
continue
begin_idxs = _find_begins(arg_idx)
arg_score = np.mean([max(cur_arg_prob[cur_idx]).item() for cur_idx in begin_idxs])
cur_score += arg_score
confidence_scores.append(cur_score)
return confidence_scores
def _find_begins(idxs):
found_begins = [idxs[0]]
cur_flag_idx = idxs[0]
for cur_idx in idxs[1:]:
if cur_idx - cur_flag_idx != 1:
found_begins.append(cur_idx)
cur_flag_idx = cur_idx
return found_begins
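# Sketch: only the first index of each contiguous run is kept,
# e.g. _find_begins([2, 3, 4, 9, 10]) -> [2, 9].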

View File

@ -1,237 +0,0 @@
"""
Script for transforming original dataset for Multi^2OIE training and evaluation
1. if Mode == 'train'
- input: structured.json (https://github.com/zhanjunlang/Span_OIE)
- output: ../datasets/openie4_train.pkl
2. if Mode == 'dev_input'
- input: dev.oie.conll (https://github.com/gabrielStanovsky/supervised-oie/tree/master/data)
- output: ../datasets/oie2016_dev.pkl
3. if Mode == 'dev_gold'
- input: dev.oie.conll (https://github.com/gabrielStanovsky/supervised-oie/tree/master/data)
- output: ../evaluate/OIE2016_dev.txt
"""
import json
import numpy as np
import argparse
import pickle
from transformers import BertTokenizer
from tqdm import tqdm
pred_tag2idx = {
'P-B': 0, 'P-I': 1, 'O': 2
}
arg_tag2idx = {
'A0-B': 0, 'A0-I': 1,
'A1-B': 2, 'A1-I': 3,
'A2-B': 4, 'A2-I': 5,
'A3-B': 6, 'A3-I': 7,
'O': 8,
}
def preprocess_train(args):
print("loading dataset...")
with open(args.data) as json_file:
data = json.load(json_file)
iterator = tqdm(data)
tokenizer = BertTokenizer.from_pretrained(args.bert_config)
print("done. preprocessing starts.")
openie4_train = {
'tokens': list(),
'single_pred_labels': list(),
'single_arg_labels': list(),
'all_pred_labels': list()
}
rel_pos_malformed = 0
max_over_case = 0
for cur_data in iterator:
words = cur_data['sentence'].replace('\xa0', ' ').split(' ')
word2piece = {idx: list() for idx in range(len(words))}
sentence_pieces = list()
piece_idx = 0
for word_idx, word in enumerate(words):
pieces = tokenizer.tokenize(word)
sentence_pieces += pieces
for piece_idx_added, piece in enumerate(pieces):
word2piece[word_idx].append(piece_idx + piece_idx_added)
piece_idx += len(pieces)
assert len(sentence_pieces) == piece_idx
# skip sentences whose wordpiece length exceeds max_len - 2 (leaving room for [CLS] and [SEP])
if piece_idx > args.max_len - 2:
max_over_case += 1
continue
all_pred_label = np.asarray([pred_tag2idx['O'] for _ in range(len(sentence_pieces))])
cur_tuple_malformed = 0
for cur_tuple in cur_data['tuples']:
# add predicate labels
pred_label = np.asarray([pred_tag2idx['O'] for _ in range(len(sentence_pieces))])
if -1 in cur_tuple['rel_pos']:
rel_pos_malformed += 1
cur_tuple_malformed += 1
continue
else:
start_idx, end_idx = cur_tuple['rel_pos'][:2]
for pred_word_idx in range(start_idx, end_idx + 1):
pred_label[word2piece[pred_word_idx]] = pred_tag2idx['P-I']
all_pred_label[word2piece[pred_word_idx]] = pred_tag2idx['P-I']
pred_label[word2piece[start_idx][0]] = pred_tag2idx['P-B']
all_pred_label[word2piece[start_idx][0]] = pred_tag2idx['P-B']
openie4_train['single_pred_labels'].append(pred_label)
# add argument-0 labels
arg_label = np.asarray([arg_tag2idx['O'] for _ in range(len(sentence_pieces))])
start_idx, end_idx = cur_tuple['arg0_pos']
for arg_word_idx in range(start_idx, end_idx + 1):
arg_label[word2piece[arg_word_idx]] = arg_tag2idx['A0-I']
arg_label[word2piece[start_idx][0]] = arg_tag2idx['A0-B']
# add additional argument labels
for arg_n, arg_pos in enumerate(cur_tuple['args_pos'][:3]):
arg_n += 1
start_idx, end_idx = arg_pos
for arg_word_idx in range(start_idx, end_idx + 1):
arg_label[word2piece[arg_word_idx]] = arg_tag2idx[f'A{arg_n}-I']
arg_label[word2piece[start_idx][0]] = arg_tag2idx[f'A{arg_n}-B']
openie4_train['single_arg_labels'].append(arg_label)
# add sentence pieces and total predicate label of current sentence
for _ in range(len(cur_data['tuples']) - cur_tuple_malformed):
openie4_train['tokens'].append(sentence_pieces)
openie4_train['all_pred_labels'].append(all_pred_label)
assert len(openie4_train['tokens']) == len(openie4_train['all_pred_labels'])
assert len(openie4_train['all_pred_labels']) == len(openie4_train['single_pred_labels'])
assert len(openie4_train['single_pred_labels']) == len(openie4_train['single_arg_labels'])
save_pkl(args.save_path, openie4_train)
print(f"# of data over max length: {max_over_case}")
print(f"# of data with malformed relation positions: {rel_pos_malformed}")
print("\npreprocessing done.")
"""
For English BERT,
# of data over max length: 5097
# of data with malformed relation positions: 1959
For Multilingual BERT,
# of data over max length: 2480
# of data with malformed relation positions: 1974
"""
def preprocess_dev_input(args):
print("loading dataset...")
with open(args.data) as f:
lines = f.readlines()
lines = [line.strip().split('\t') for line in lines]
print("done. preprocessing starts.")
sentences = list()
words_queue = list()
for line in tqdm(lines[1:]):
if len(line) != len(lines[0]):
sentence = " ".join(words_queue)
sentences.append(sentence)
words_queue = list()
continue
words_queue.append(line[1])
sentences = list(set(sentences))
save_pkl(args.save_path, sentences)
print("\npreprocessing done.")
def preprocess_dev_gold(args):
print("loading dataset...")
with open(args.data) as f:
lines = f.readlines()
lines = [line.strip().split('\t') for line in lines]
print("done. preprocessing starts.")
f = open(args.save_path, 'w')
words_queue = list()
args_queue = {
'A0': list(), 'A1': list(), 'A2': list(), 'A3': list()
}
for line in tqdm(lines[1:]):
if len(line) != len(lines[0]):
new_line = list()
new_line.append(" ".join(words_queue))
new_line.append(words_queue[pred_head_id])
new_line.append(pred)
for label in list(args_queue.keys()):
if len(args_queue[label]) != 0:
new_line.append(" ".join(args_queue[label]))
f.write("\t".join(new_line) + "\n")
words_queue = list()
args_queue = {
'A0': list(), 'A1': list(), 'A2': list(), 'A3': list()
}
continue
word = line[1]
pred = line[2]
pred_head_id = int(line[4])
words_queue.append(word)
for label in list(args_queue.keys()):
if label in line[-1]:
args_queue[label].append(word)
f.close()
def _get_word2piece(sentence, tokenizer):
words = sentence.replace('\xa0', ' ').split(' ')
word2piece = {idx: list() for idx in range(len(words))}
sentence_pieces = list()
piece_idx = 1
for word_idx, word in enumerate(words):
pieces = tokenizer.tokenize(word)
sentence_pieces += pieces
for piece_idx_added, piece in enumerate(pieces):
word2piece[word_idx].append(piece_idx + piece_idx_added)
piece_idx += len(pieces)
assert len(sentence_pieces) == piece_idx - 1
return word2piece
def save_pkl(path, file):
with open(path, 'wb') as f:
pickle.dump(file, f)
def load_pkl(path):
with open(path, 'rb') as f:
return pickle.load(f)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='train')
parser.add_argument('--data', default='../datasets/structured_data.json')
parser.add_argument('--save_path', default='../datasets/openie4_train.pkl')
parser.add_argument('--bert_config', default='bert-base-cased')
parser.add_argument('--max_len', type=int, default=64)
main_args = parser.parse_args()
if main_args.mode == 'train':
preprocess_train(main_args)
elif main_args.mode == 'dev_input':
preprocess_dev_input(main_args)
elif main_args.mode == 'dev_gold':
preprocess_dev_gold(main_args)
else:
raise ValueError(f"Invalid preprocessing mode: {main_args.mode}")

View File

@ -1,153 +0,0 @@
import argparse
import torch
import os
import random
import numpy as np
import pickle
import pandas as pd
import json
import copy
from src.Multi2OIE.model import Multi2OIE, BERTBiLSTM
from transformers import get_linear_schedule_with_warmup, AdamW
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
def clean_config(config):
device = torch.device(config.device if torch.cuda.is_available() else 'cpu')
config.device = device
config.pred_n_labels = 3
config.arg_n_labels = 9
os.makedirs(config.save_path, exist_ok=True)
return config
def get_models(bert_config,
pred_n_labels=3,
arg_n_labels=9,
n_arg_heads=8,
n_arg_layers=4,
lstm_dropout=0.3,
mh_dropout=0.1,
pred_clf_dropout=0.,
arg_clf_dropout=0.3,
pos_emb_dim=64,
use_lstm=False,
device=None):
if not use_lstm:
return Multi2OIE(
bert_config=bert_config,
mh_dropout=mh_dropout,
pred_clf_dropout=pred_clf_dropout,
arg_clf_dropout=arg_clf_dropout,
n_arg_heads=n_arg_heads,
n_arg_layers=n_arg_layers,
pos_emb_dim=pos_emb_dim,
pred_n_labels=pred_n_labels,
arg_n_labels=arg_n_labels).to(device)
else:
return BERTBiLSTM(
bert_config=bert_config,
lstm_dropout=lstm_dropout,
pred_clf_dropout=pred_clf_dropout,
arg_clf_dropout=arg_clf_dropout,
pos_emb_dim=pos_emb_dim,
pred_n_labels=pred_n_labels,
arg_n_labels=arg_n_labels).to(device)
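# Sketch of typical usage (argument values are illustrative):
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = get_models('bert-base-cased', use_lstm=False, device=device)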
def save_pkl(path, file):
with open(path, 'wb') as f:
pickle.dump(file, f)
def load_pkl(path):
with open(path, 'rb') as f:
return pickle.load(f)
def get_word2piece(sentence, tokenizer):
words = sentence.split(' ')
word2piece = {idx: list() for idx in range(len(words))}
sentence_pieces = list()
piece_idx = 1
for word_idx, word in enumerate(words):
pieces = tokenizer.tokenize(word)
sentence_pieces += pieces
for piece_idx_added, piece in enumerate(pieces):
word2piece[word_idx].append(piece_idx + piece_idx_added)
piece_idx += len(pieces)
assert len(sentence_pieces) == piece_idx - 1
return word2piece
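# Sketch: assuming a WordPiece tokenizer that splits "playing" into
# ['play', '##ing'], get_word2piece("He was playing", tokenizer) returns
# {0: [1], 1: [2], 2: [3, 4]}; piece indexes start at 1 to leave room for [CLS].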
def get_train_modules(model,
lr,
total_steps,
warmup_steps):
optimizer = AdamW(
model.parameters(), lr=lr, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(
optimizer, warmup_steps, total_steps)
return optimizer, scheduler
class SummaryManager:
def __init__(self, config):
self.config = config
self.save_config()
columns = ['epoch', 'train_predicate_loss', 'train_argument_loss']
for cur_dev_path in config.dev_data_path:
cur_dev_name = cur_dev_path.split('/')[-1].replace('.pkl', '')
for metric in ['f1', 'prec', 'rec', 'auc', 'sum']:
columns.append(f'{cur_dev_name}_{metric}')
columns.append('total_sum')
self.result_df = pd.DataFrame(columns=columns)
self.save_df()
def save_config(self, display=True):
if display:
for key, value in self.config.__dict__.items():
print("{}: {}".format(key, value))
print()
copied = copy.deepcopy(self.config)
copied.device = 'cuda' if torch.cuda.is_available() else 'cpu'
with open(os.path.join(copied.save_path, 'config.json'), 'w') as fp:
json.dump(copied.__dict__, fp, indent=4)
def save_results(self, results):
self.result_df = pd.read_csv(os.path.join(self.config.save_path, 'train_results.csv'))
self.result_df.loc[len(self.result_df.index)] = results
self.save_df()
def save_df(self):
self.result_df.to_csv(os.path.join(self.config.save_path, 'train_results.csv'), index=False)
def set_model_name(dev_results, epoch, step=None):
if step is not None:
return "model-epoch{}-step{}-score{:.4f}.bin".format(epoch, step, dev_results)
else:
return "model-epoch{}-end-score{:.4f}.bin".format(epoch, dev_results)
def print_results(message, results, names):
print(f"\n===== {message} =====")
for result, name in zip(results, names):
print("{}: {:.5f}".format(name, result))
print()

View File

@ -1,132 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# test
test.py
# ckpt
ckpt
# vscode
.vscode
# tacred
benchmark/tacred
*.swp
# data and pretrain
pretrain
benchmark
!benchmark/*.sh
!pretrain/*.sh
# test env
.test
# package
opennre.egg-info
*.sh

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2019 Tianyu Gao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

File diff suppressed because it is too large

View File

@ -1,151 +0,0 @@
# OpenNRE
****** Update ******
We provide two distantly-supervised datasets with human-annotated test sets, **NYT10m** and **Wiki20m**. Check the [datasets](#datasets) section for details.
****** Update ******
OpenNRE is an open-source and extensible toolkit that provides a unified framework to implement relation extraction models. This package is designed for the following groups:
* **New to relation extraction**: We have step-by-step tutorials and detailed documents that not only enable you to use relation extraction tools, but also help you better understand the research progress in this field.
* **Developers**: Our easy-to-use interface and high-performance implementation can accelerate your deployment in real-world applications. Besides, we provide several pretrained models that can be put into production without any training.
* **Researchers**: With our modular design, various task settings and metric tools, you can easily carry out experiments on your own models with only minor modifications. We have also provided several of the most-used benchmarks for different settings of relation extraction.
* **Anyone who needs to submit NLP homework to impress their professors**: With state-of-the-art models, our package can definitely help you stand out among your classmates!
This package is mainly contributed by [Tianyu Gao](https://github.com/gaotianyu1350), [Xu Han](https://github.com/THUCSTHanxu13), [Shulin Cao](https://github.com/ShulinCao), [Lumin Tang](https://github.com/Tsingularity), [Yankai Lin](https://github.com/Mrlyk423), [Zhiyuan Liu](http://nlp.csai.tsinghua.edu.cn/~lzy/).
## What is Relation Extraction
Relation extraction is a natural language processing (NLP) task aiming at extracting relations (e.g., *founder of*) between entities (e.g., **Bill Gates** and **Microsoft**). For example, from the sentence *Bill Gates founded Microsoft*, we can extract the relation triple (**Bill Gates**, *founder of*, **Microsoft**).
Relation extraction is a crucial technique in automatic knowledge graph construction. By using relation extraction, we can accumulatively extract new relation facts and expand the knowledge graph, which, as a way for machines to understand the human world, has many downstream applications like question answering, recommender systems and search engines.
## How to Cite
A good research work is always accompanied by a thorough and faithful reference. If you use or extend our work, please cite the following paper:
```
@inproceedings{han-etal-2019-opennre,
title = "{O}pen{NRE}: An Open and Extensible Toolkit for Neural Relation Extraction",
author = "Han, Xu and Gao, Tianyu and Yao, Yuan and Ye, Deming and Liu, Zhiyuan and Sun, Maosong",
booktitle = "Proceedings of EMNLP-IJCNLP: System Demonstrations",
year = "2019",
url = "https://www.aclweb.org/anthology/D19-3029",
doi = "10.18653/v1/D19-3029",
pages = "169--174"
}
```
It's our honor to help you better explore relation extraction with our OpenNRE toolkit!
## Papers and Document
If you want to learn more about neural relation extraction, visit another project of ours ([NREPapers](https://github.com/thunlp/NREPapers)).
You can refer to our [document](https://opennre-docs.readthedocs.io/en/latest/) for more details about this project.
## Install
### Install as a Python Package
We are now working on deploying OpenNRE as a Python package. Coming soon!
### Using Git Repository
Clone the repository from our GitHub page (don't forget to star us!):
```bash
git clone https://github.com/thunlp/OpenNRE.git
```
If it is too slow, you can try a shallow clone:
```bash
git clone https://github.com/thunlp/OpenNRE.git --depth 1
```
Then install all the requirements:
```bash
pip install -r requirements.txt
```
**Note**: Please choose an appropriate PyTorch version based on your machine (related to your CUDA version). For details, refer to https://pytorch.org/.
Then install the package with:
```bash
python setup.py install
```
If you also want to modify the code, run this:
```bash
python setup.py develop
```
Note that we have excluded all data and pretrain files for fast deployment. You can manually download them by running scripts in the ``benchmark`` and ``pretrain`` folders. For example, if you want to download the FewRel dataset, you can run:
```bash
bash benchmark/download_fewrel.sh
```
## Easy Start
Make sure you have installed OpenNRE as instructed above. Then import our package and load pre-trained models.
```python
>>> import opennre
>>> model = opennre.get_model('wiki80_cnn_softmax')
```
Note that it may take a few minutes to download the checkpoint and data for the first time. Then use `infer` to do sentence-level relation extraction:
```python
>>> model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}})
('father', 0.5108704566955566)
```
You will get the relation result and its confidence score.
If you want to use the model on your GPU, just run
```python
>>> model = model.cuda()
```
before calling the inference function.
For now, we have the following available models:
* `wiki80_cnn_softmax`: trained on `wiki80` dataset with a CNN encoder.
* `wiki80_bert_softmax`: trained on `wiki80` dataset with a BERT encoder.
* `wiki80_bertentity_softmax`: trained on `wiki80` dataset with a BERT encoder (using entity representation concatenation).
* `tacred_bert_softmax`: trained on `TACRED` dataset with a BERT encoder.
* `tacred_bertentity_softmax`: trained on `TACRED` dataset with a BERT encoder (using entity representation concatenation).
## Datasets
You can go into the `benchmark` folder and download datasets using our scripts. We also list some of the information about the datasets in [this document](https://opennre-docs.readthedocs.io/en/latest/get_started/benchmark.html#bag-level-relation-extraction).
## Training
You can train your own models on your own data with OpenNRE. In the `example` folder we provide example training code for supervised RE models and bag-level RE models. You can either use our provided datasets or your own datasets. For example, you can use the following script to train a PCNN-ATT bag-level model on the NYT10 dataset with a manual test set:
```bash
python example/train_bag_cnn.py \
--metric auc \
--dataset nyt10m \
--batch_size 160 \
--lr 0.1 \
--weight_decay 1e-5 \
--max_epoch 100 \
--max_length 128 \
--seed 42 \
--encoder pcnn \
--aggr att
```
Or use the following script to train a BERT model on the Wiki80 dataset:
```bash
python example/train_supervised_bert.py \
--pretrain_path bert-base-uncased \
--dataset wiki80
```
We provide many options in the example training code and you can check them out for detailed instructions.

View File

@ -1,479 +0,0 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "apMnCeBi0kAa",
"executionInfo": {
"status": "ok",
"timestamp": 1668340169194,
"user_tz": -210,
"elapsed": 26751,
"user": {
"displayName": "Mohammad Ebrahimi",
"userId": "10407139745331958037"
}
},
"outputId": "852beb1d-3b81-4f1e-8b41-511ee62a6458"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "juTvp3cXy0vT",
"executionInfo": {
"status": "ok",
"timestamp": 1668340312085,
"user_tz": -210,
"elapsed": 2,
"user": {
"displayName": "Mohammad Ebrahimi",
"userId": "10407139745331958037"
}
},
"outputId": "76fc9722-9631-4573-b68c-f2d460ebaf96"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/drive/MyDrive/OpenNRE-master\n"
]
}
],
"source": [
"%cd /content/drive/MyDrive/OpenNRE-master"
]
},
{
"cell_type": "code",
"source": [
"!pip install -r requirements.txt"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zw-L_bBWy3q6",
"executionInfo": {
"status": "ok",
"timestamp": 1668340295967,
"user_tz": -210,
"elapsed": 119014,
"user": {
"displayName": "Mohammad Ebrahimi",
"userId": "10407139745331958037"
}
},
"outputId": "c73fdd46-9d49-4643-a588-3021e6395c32"
},
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR: Could not find a version that satisfies the requirement torch==1.6.0 (from versions: 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1)\n",
"ERROR: No matching distribution found for torch==1.6.0\n",
"\n",
"[notice] A new release of pip available: 22.3.1 -> 23.0\n",
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pip in e:\\hamed\\work\\5\\opennre-master\\venv\\lib\\site-packages (22.3.1)\n",
"Collecting pip\n",
" Using cached pip-23.0-py3-none-any.whl (2.1 MB)\n",
"Installing collected packages: pip\n",
" Attempting uninstall: pip\n",
" Found existing installation: pip 22.3.1\n",
" Uninstalling pip-22.3.1:\n",
" Successfully uninstalled pip-22.3.1\n",
"Successfully installed pip-23.0\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# %%shell\n",
"!python train_supervised_bert.py \\\n",
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
" --dataset none \\\n",
" --train_file ./Perlex/Perlex_train.txt \\\n",
" --val_file ./Perlex/Perlex_val.txt \\\n",
" --test_file ./Perlex/Perlex_test.txt \\\n",
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
" --max_epoch 20"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aEiTLWpPzV3J",
"executionInfo": {
"status": "ok",
"timestamp": 1665308620117,
"user_tz": -210,
"elapsed": 7519388,
"user": {
"displayName": "arian ebrahimi",
"userId": "00418818321983401320"
}
},
"outputId": "90c72478-d09d-4f96-9417-7c2ad805a66b"
},
"execution_count": 3,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"E:\\Hamed\\Work\\5\\OpenNRE-master\\train_supervised_bert.py\", line 2, in <module>\n",
" import torch\n",
"ModuleNotFoundError: No module named 'torch'\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%shell\n",
"python train_supervised_bert.py \\\n",
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
" --dataset none \\\n",
" --train_file ./Perlex/Perlex_train.txt \\\n",
" --val_file ./Perlex/Perlex_val.txt \\\n",
" --test_file ./Perlex/Perlex_test.txt \\\n",
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
" --max_epoch 20"
],
"metadata": {
"id": "zhYjakhv13c8",
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"status": "ok",
"timestamp": 1665318000439,
"user_tz": -210,
"elapsed": 7490397,
"user": {
"displayName": "arian ebrahimi",
"userId": "00418818321983401320"
}
},
"outputId": "d2d9d494-0222-425f-ff3a-50e01b33419f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"100% 242/242 [05:39<00:00, 1.40s/it, acc=0.285, loss=2.38]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.627]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.643, loss=1.2]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.718]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.739, loss=0.848]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.742]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.797, loss=0.675]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.742]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.842, loss=0.544]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.752]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.872, loss=0.452]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.742]\n",
"100% 242/242 [05:45<00:00, 1.43s/it, acc=0.901, loss=0.354]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.749]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.926, loss=0.292]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.941, loss=0.236]\n",
"100% 47/47 [00:25<00:00, 1.81it/s, acc=0.747]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.951, loss=0.195]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.748]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.962, loss=0.162]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.746]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.969, loss=0.14]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.744]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.974, loss=0.118]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.978, loss=0.102]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.981, loss=0.0912]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.74]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.985, loss=0.081]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.742]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.986, loss=0.0736]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.74]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.986, loss=0.0724]\n",
"100% 47/47 [00:25<00:00, 1.81it/s, acc=0.742]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.99, loss=0.065]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.742]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.99, loss=0.0624]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.739]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.744]\n",
"Test set results:\n",
"Accuracy: 0.7444963308872582\n",
"Micro precision: 0.7831612390786339\n",
"Micro recall: 0.7919678714859437\n",
"Micro F1: 0.7875399361022364\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": []
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"%%shell\n",
"python train_supervised_bert.py \\\n",
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
" --dataset none \\\n",
" --train_file ./Perlex/Perlex_train.txt \\\n",
" --val_file ./Perlex/Perlex_val.txt \\\n",
" --test_file ./Perlex/Perlex_test.txt \\\n",
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
" --max_epoch 20 \\\n",
" --lr 15e-6"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bcdc6f55-f9b9-40b0-a661-777145267ede",
"id": "zUa7CDEeFBGJ"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"100% 242/242 [05:38<00:00, 1.40s/it, acc=0.444, loss=1.84]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.704]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.733, loss=0.842]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.753]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.871, loss=0.432]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.764]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.942, loss=0.208]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.747]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.979, loss=0.0991]\n",
"100% 47/47 [00:25<00:00, 1.81it/s, acc=0.75]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.991, loss=0.0488]\n",
"100% 47/47 [00:25<00:00, 1.81it/s, acc=0.737]\n",
"100% 242/242 [05:46<00:00, 1.43s/it, acc=0.998, loss=0.0243]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.746]\n",
"100% 242/242 [05:45<00:00, 1.43s/it, acc=0.999, loss=0.0157]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.747]\n",
"100% 242/242 [05:45<00:00, 1.43s/it, acc=1, loss=0.0098]\n",
"100% 47/47 [00:25<00:00, 1.83it/s, acc=0.753]\n",
"100% 242/242 [05:45<00:00, 1.43s/it, acc=0.999, loss=0.00806]\n",
"100% 47/47 [00:25<00:00, 1.82it/s, acc=0.751]\n",
" 65% 157/242 [03:45<02:01, 1.43s/it, acc=0.999, loss=0.00759]"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%shell\n",
"python train_supervised_bert.py \\\n",
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
" --dataset none \\\n",
" --train_file ./Perlex/Perlex_train.txt \\\n",
" --val_file ./Perlex/Perlex_val.txt \\\n",
" --test_file ./Perlex/Perlex_test.txt \\\n",
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
" --max_epoch 20 \\\n",
" --lr 15e-6"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "spUwXoWnPchu",
"outputId": "b1b389c8-c1b5-41bf-ee81-7dfc2cb5f8a0"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading: 100% 654M/654M [00:09<00:00, 66.7MB/s]\n",
"Downloading: 100% 1.22M/1.22M [00:01<00:00, 872kB/s]\n",
"100% 242/242 [05:18<00:00, 1.32s/it, acc=0.444, loss=1.84]\n",
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.704]\n",
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.733, loss=0.842]\n",
"100% 47/47 [00:23<00:00, 1.98it/s, acc=0.753]\n",
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.871, loss=0.432]\n",
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.764]\n",
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.942, loss=0.208]\n",
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.747]\n",
"100% 242/242 [05:28<00:00, 1.36s/it, acc=0.979, loss=0.0991]\n",
"100% 47/47 [00:23<00:00, 1.99it/s, acc=0.75]\n",
"100% 242/242 [05:29<00:00, 1.36s/it, acc=0.991, loss=0.0488]\n",
"100% 47/47 [00:23<00:00, 2.00it/s, acc=0.737]\n",
" 31% 75/242 [01:43<03:47, 1.36s/it, acc=0.998, loss=0.0257]"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%shell\n",
"python train_supervised_bert.py \\\n",
" --pretrain_path HooshvareLab/bert-base-parsbert-uncased \\\n",
" --dataset none \\\n",
" --train_file ./Perlex/Perlex_train.txt \\\n",
" --val_file ./Perlex/Perlex_val.txt \\\n",
" --test_file ./Perlex/Perlex_test.txt \\\n",
" --rel2id_file ./Perlex/Perlex_rel2id.json \\\n",
" --max_epoch 20 \\\n",
" --lr 15e-6"
],
"metadata": {
"id": "_lWuRi4EuuV4",
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"status": "ok",
"timestamp": 1668347083563,
"user_tz": -210,
"elapsed": 441647,
"user": {
"displayName": "Mohammad Ebrahimi",
"userId": "10407139745331958037"
}
},
"outputId": "4080fb9e-cf2c-498c-a412-a8ba1b44ad74"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"100% 101/101 [02:00<00:00, 1.20s/it, acc=0.278, loss=2.4]\n",
"100% 25/25 [00:12<00:00, 2.06it/s, acc=0.541]\n",
"100% 101/101 [02:11<00:00, 1.30s/it, acc=0.667, loss=1.15]\n",
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.615]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.784, loss=0.721]\n",
"100% 25/25 [00:12<00:00, 1.99it/s, acc=0.66]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.879, loss=0.422]\n",
"100% 25/25 [00:12<00:00, 2.00it/s, acc=0.656]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.947, loss=0.209]\n",
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.659]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.981, loss=0.0976]\n",
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.655]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.99, loss=0.0528]\n",
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.655]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.997, loss=0.0308]\n",
"100% 25/25 [00:12<00:00, 1.96it/s, acc=0.659]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.998, loss=0.0199]\n",
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.657]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.999, loss=0.0145]\n",
"100% 25/25 [00:12<00:00, 1.93it/s, acc=0.659]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.0111]\n",
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.653]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00921]\n",
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.658]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00797]\n",
"100% 25/25 [00:13<00:00, 1.89it/s, acc=0.659]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00623]\n",
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.655]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00641]\n",
"100% 25/25 [00:12<00:00, 1.97it/s, acc=0.65]\n",
"100% 101/101 [02:11<00:00, 1.31s/it, acc=1, loss=0.00551]\n",
"100% 25/25 [00:12<00:00, 1.99it/s, acc=0.655]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00564]\n",
"100% 25/25 [00:12<00:00, 1.99it/s, acc=0.655]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=1, loss=0.00474]\n",
"100% 25/25 [00:12<00:00, 1.93it/s, acc=0.654]\n",
"100% 101/101 [02:12<00:00, 1.31s/it, acc=0.999, loss=0.00528]\n",
"100% 25/25 [00:12<00:00, 1.95it/s, acc=0.655]\n",
"100% 101/101 [02:11<00:00, 1.31s/it, acc=1, loss=0.00506]\n",
"100% 25/25 [00:12<00:00, 1.98it/s, acc=0.653]\n",
"100% 43/43 [00:21<00:00, 2.04it/s, acc=0.725]\n",
"Test set results:\n",
"Accuracy: 0.7245949926362297\n",
"Micro precision: 0.7602262837249782\n",
"Micro recall: 0.7723253757736517\n",
"Micro F1: 0.7662280701754387\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": []
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "K8ArBYUg1gmV"
},
"execution_count": null,
"outputs": []
}
]
}

View File

@ -1,19 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .pretrain import check_root, get_model, download, download_pretrain
import logging
import os
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=os.environ.get("LOGLEVEL", "INFO"))
def fix_seed(seed=12345):
import torch
import numpy as np
import random
torch.manual_seed(seed) # cpu
torch.cuda.manual_seed(seed) # gpu
np.random.seed(seed) # numpy
random.seed(seed) # random and transforms
torch.backends.cudnn.deterministic = True  # cudnn

View File

@ -1,14 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .cnn_encoder import CNNEncoder
from .pcnn_encoder import PCNNEncoder
from .bert_encoder import BERTEncoder, BERTEntityEncoder
__all__ = [
'CNNEncoder',
'PCNNEncoder',
'BERTEncoder',
'BERTEntityEncoder'
]

View File

@ -1,154 +0,0 @@
import math, logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..tokenization import WordTokenizer
class BaseEncoder(nn.Module):
def __init__(self,
token2id,
max_length=128,
hidden_size=230,
word_size=50,
position_size=5,
blank_padding=True,
word2vec=None,
mask_entity=False):
"""
Args:
token2id: dictionary of token->idx mapping
max_length: max length of sentence, used for position embedding
hidden_size: hidden size
word_size: size of word embedding
position_size: size of position embedding
blank_padding: padding for CNN
word2vec: pretrained word2vec numpy
mask_entity: whether to replace entity mentions with [UNK]
"""
# Hyperparameters
super().__init__()
self.token2id = token2id
self.max_length = max_length
self.num_token = len(token2id)
self.num_position = max_length * 2
self.mask_entity = mask_entity
if word2vec is None:
self.word_size = word_size
else:
self.word_size = word2vec.shape[-1]
self.position_size = position_size
self.hidden_size = hidden_size
self.input_size = self.word_size + position_size * 2  # use self.word_size: word2vec may override the default
self.blank_padding = blank_padding
if '[UNK]' not in self.token2id:
self.token2id['[UNK]'] = len(self.token2id)
self.num_token += 1
if '[PAD]' not in self.token2id:
self.token2id['[PAD]'] = len(self.token2id)
self.num_token += 1
# Word embedding
self.word_embedding = nn.Embedding(self.num_token, self.word_size)
if word2vec is not None:
logging.info("Initializing word embedding with word2vec.")
word2vec = torch.from_numpy(word2vec)
if self.num_token == len(word2vec) + 2:
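# word2vec lacks rows for the [UNK]/[PAD] tokens added above, so append
# a random row for [UNK] and a zero row for [PAD]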
unk = torch.randn(1, self.word_size) / math.sqrt(self.word_size)
blk = torch.zeros(1, self.word_size)
self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0))
else:
self.word_embedding.weight.data.copy_(word2vec)
# Position Embedding
self.pos1_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0)
self.pos2_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0)
self.tokenizer = WordTokenizer(vocab=self.token2id, unk_token="[UNK]")
def forward(self, token, pos1, pos2):
"""
Args:
token: (B, L), index of tokens
pos1: (B, L), relative position to head entity
pos2: (B, L), relative position to tail entity
Return:
(B, H), representations for sentences
"""
# The base encoder defines only the interface; subclasses implement encoding.
pass
def tokenize(self, item):
"""
Args:
item: input instance, including sentence, entity positions, etc.
Return:
index number of tokens and positions
"""
if 'text' in item:
sentence = item['text']
is_token = False
else:
sentence = item['token']
is_token = True
pos_head = item['h']['pos']
pos_tail = item['t']['pos']
# Sentence -> token
if not is_token:
if pos_head[0] > pos_tail[0]:
pos_min, pos_max = [pos_tail, pos_head]
rev = True
else:
pos_min, pos_max = [pos_head, pos_tail]
rev = False
sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
if self.mask_entity:
ent_0 = ['[UNK]']
ent_1 = ['[UNK]']
tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
if rev:
pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
else:
pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
else:
tokens = sentence
# Token -> index
if self.blank_padding:
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
else:
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id=self.token2id['[UNK]'])
# Position -> index
pos1 = []
pos2 = []
pos1_in_index = min(pos_head[0], self.max_length)
pos2_in_index = min(pos_tail[0], self.max_length)
for i in range(len(tokens)):
pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))
if self.blank_padding:
while len(pos1) < self.max_length:
pos1.append(0)
while len(pos2) < self.max_length:
pos2.append(0)
indexed_tokens = indexed_tokens[:self.max_length]
pos1 = pos1[:self.max_length]
pos2 = pos2[:self.max_length]
indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L)
pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L)
return indexed_tokens, pos1, pos2

View File

@ -1,215 +0,0 @@
import logging
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from .base_encoder import BaseEncoder
class BERTEncoder(nn.Module):
def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False):
"""
Args:
max_length: max length of sentence
pretrain_path: path of pretrain model
"""
super().__init__()
self.max_length = max_length
self.blank_padding = blank_padding
self.hidden_size = 768
self.mask_entity = mask_entity
logging.info('Loading BERT pre-trained checkpoint.')
self.bert = BertModel.from_pretrained(pretrain_path)
self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)
def forward(self, token, att_mask, pos1, pos2):
"""
Args:
token: (B, L), index of tokens
att_mask: (B, L), attention mask (1 for contents and 0 for padding)
pos1, pos2: unused in this encoder; kept for a consistent interface
Return:
(B, H), representations for sentences
"""
_, x = self.bert(token, attention_mask=att_mask, return_dict=False)
# return_dict=False is set to adapt to the new version of transformers
return x
def tokenize(self, item):
"""
Args:
item: data instance containing 'text' / 'token', 'h' and 't'
Return:
indexed tokens, attention mask, and entity start positions (pos1, pos2)
"""
# Sentence -> token
if 'text' in item:
sentence = item['text']
is_token = False
else:
sentence = item['token']
is_token = True
pos_head = item['h']['pos']
pos_tail = item['t']['pos']
pos_min = pos_head
pos_max = pos_tail
if pos_head[0] > pos_tail[0]:
pos_min = pos_tail
pos_max = pos_head
rev = True
else:
rev = False
if not is_token:
sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
else:
sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))
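# Reserved [unusedN] tokens act as entity markers: [unused0-3] wrap the head
# and tail spans, while [unused4]/[unused5] replace them entirely when
# mask_entity is set (a scheme similar to the entity markers of Soares et al., 2019).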
if self.mask_entity:
ent0 = ['[unused4]'] if not rev else ['[unused5]']
ent1 = ['[unused5]'] if not rev else ['[unused4]']
else:
ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']
re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
pos1 = min(self.max_length - 1, pos1)
pos2 = min(self.max_length - 1, pos2)
indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
avai_len = len(indexed_tokens)
pos1 = torch.tensor([[pos1]]).long()
pos2 = torch.tensor([[pos2]]).long()
# Padding
if self.blank_padding:
while len(indexed_tokens) < self.max_length:
indexed_tokens.append(0) # 0 is id for [PAD]
indexed_tokens = indexed_tokens[:self.max_length]
indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
# Attention mask
att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L)
att_mask[0, :avai_len] = 1
return indexed_tokens, att_mask, pos1, pos2
class BERTEntityEncoder(nn.Module):
def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False):
"""
Args:
max_length: max length of sentence
pretrain_path: path of pretrain model
"""
super().__init__()
self.max_length = max_length
self.blank_padding = blank_padding
self.hidden_size = 768 * 2
self.mask_entity = mask_entity
logging.info('Loading BERT pre-trained checkpoint.')
self.bert = BertModel.from_pretrained(pretrain_path)
self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)
self.linear = nn.Linear(self.hidden_size, self.hidden_size)
def forward(self, token, att_mask, pos1, pos2):
"""
Args:
token: (B, L), index of tokens
att_mask: (B, L), attention mask (1 for contents and 0 for padding)
pos1: (B, 1), position of the head entity starter
pos2: (B, 1), position of the tail entity starter
Return:
(B, 2H), representations for sentences
"""
hidden, _ = self.bert(token, attention_mask=att_mask, return_dict=False)
# Get entity start hidden state
onehot_head = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L)
onehot_tail = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L)
onehot_head = onehot_head.scatter_(1, pos1, 1)
onehot_tail = onehot_tail.scatter_(1, pos2, 1)
head_hidden = (onehot_head.unsqueeze(2) * hidden).sum(1) # (B, H)
tail_hidden = (onehot_tail.unsqueeze(2) * hidden).sum(1) # (B, H)
x = torch.cat([head_hidden, tail_hidden], 1) # (B, 2H)
x = self.linear(x)
return x
def tokenize(self, item):
"""
Args:
item: data instance containing 'text' / 'token', 'h' and 't'
Return:
indexed tokens, attention mask, and entity start positions (pos1, pos2)
"""
# Sentence -> token
if 'text' in item:
sentence = item['text']
is_token = False
else:
sentence = item['token']
is_token = True
pos_head = item['h']['pos']
pos_tail = item['t']['pos']
pos_min = pos_head
pos_max = pos_tail
if pos_head[0] > pos_tail[0]:
pos_min = pos_tail
pos_max = pos_head
rev = True
else:
rev = False
if not is_token:
sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
else:
sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))
if self.mask_entity:
ent0 = ['[unused4]'] if not rev else ['[unused5]']
ent1 = ['[unused5]'] if not rev else ['[unused4]']
else:
ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']
re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
pos1 = min(self.max_length - 1, pos1)
pos2 = min(self.max_length - 1, pos2)
indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
avai_len = len(indexed_tokens)
# Position
pos1 = torch.tensor([[pos1]]).long()
pos2 = torch.tensor([[pos2]]).long()
# Padding
if self.blank_padding:
while len(indexed_tokens) < self.max_length:
indexed_tokens.append(0) # 0 is id for [PAD]
indexed_tokens = indexed_tokens[:self.max_length]
indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
# Attention mask
att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L)
att_mask[0, :avai_len] = 1
return indexed_tokens, att_mask, pos1, pos2

View File

@ -1,68 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..module.nn import CNN
from ..module.pool import MaxPool
from .base_encoder import BaseEncoder
class CNNEncoder(BaseEncoder):
def __init__(self,
token2id,
max_length=128,
hidden_size=230,
word_size=50,
position_size=5,
blank_padding=True,
word2vec=None,
kernel_size=3,
padding_size=1,
dropout=0,
activation_function=F.relu,
mask_entity=False):
"""
Args:
token2id: dictionary of token->idx mapping
max_length: max length of sentence, used for position embedding
hidden_size: hidden size
word_size: size of word embedding
position_size: size of position embedding
blank_padding: padding for CNN
word2vec: pretrained word2vec numpy
kernel_size: kernel size for CNN
padding_size: padding size for CNN
"""
# Hyperparameters
super(CNNEncoder, self).__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity)
self.drop = nn.Dropout(dropout)
self.kernel_size = kernel_size
self.padding_size = padding_size
self.act = activation_function
self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
self.pool = nn.MaxPool1d(self.max_length)
def forward(self, token, pos1, pos2):
"""
Args:
token: (B, L), index of tokens
pos1: (B, L), relative position to head entity
pos2: (B, L), relative position to tail entity
Return:
(B, H), representations for sentences
"""
# Check size of tensors
if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size():
raise Exception("Size of token, pos1 ans pos2 should be (B, L)")
x = torch.cat([self.word_embedding(token),
self.pos1_embedding(pos1),
self.pos2_embedding(pos2)], 2) # (B, L, EMBED)
x = x.transpose(1, 2) # (B, EMBED, L)
x = self.act(self.conv(x)) # (B, H, L)
x = self.pool(x).squeeze(-1)
x = self.drop(x)
return x
def tokenize(self, item):
return super().tokenize(item)

View File

@ -1,173 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..module.nn import CNN
from ..module.pool import MaxPool
from .base_encoder import BaseEncoder
from nltk import word_tokenize
class PCNNEncoder(BaseEncoder):
def __init__(self,
token2id,
max_length=128,
hidden_size=230,
word_size=50,
position_size=5,
blank_padding=True,
word2vec=None,
kernel_size=3,
padding_size=1,
dropout=0.0,
activation_function=F.relu,
mask_entity=False):
"""
Args:
token2id: dictionary of token->idx mapping
            max_length: max length of sentence, used for position embedding
hidden_size: hidden size
word_size: size of word embedding
position_size: size of position embedding
            blank_padding: if True, pad/truncate sequences to max_length
            word2vec: pretrained word2vec numpy matrix
            kernel_size: kernel size for the CNN
            padding_size: padding size for the CNN
"""
# hyperparameters
super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity)
self.drop = nn.Dropout(dropout)
self.kernel_size = kernel_size
self.padding_size = padding_size
self.act = activation_function
self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
self.pool = nn.MaxPool1d(self.max_length)
self.mask_embedding = nn.Embedding(4, 3)
self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]))
self.mask_embedding.weight.requires_grad = False
self._minus = -100
self.hidden_size *= 3
def forward(self, token, pos1, pos2, mask):
"""
Args:
token: (B, L), index of tokens
pos1: (B, L), relative position to head entity
            pos2: (B, L), relative position to tail entity
            mask: (B, L), piecewise segment id of each token (1/2/3; 0 for padding)
Return:
(B, EMBED), representations for sentences
"""
# Check size of tensors
if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size():
            raise Exception("Size of token, pos1 and pos2 should be (B, L)")
x = torch.cat([self.word_embedding(token),
self.pos1_embedding(pos1),
self.pos2_embedding(pos2)], 2) # (B, L, EMBED)
x = x.transpose(1, 2) # (B, EMBED, L)
x = self.conv(x) # (B, H, L)
mask = 1 - self.mask_embedding(mask).transpose(1, 2) # (B, L) -> (B, L, 3) -> (B, 3, L)
pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :])) # (B, H, 1)
pool2 = self.pool(self.act(x + self._minus * mask[:, 1:2, :]))
pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :]))
x = torch.cat([pool1, pool2, pool3], 1) # (B, 3H, 1)
x = x.squeeze(2) # (B, 3H)
x = self.drop(x)
return x
def tokenize(self, item):
"""
Args:
sentence: string, the input sentence
pos_head: [start, end], position of the head entity
pos_end: [start, end], position of the tail entity
is_token: if is_token == True, sentence becomes an array of token
Return:
Name of the relation of the sentence
"""
if 'text' in item:
sentence = item['text']
is_token = False
else:
sentence = item['token']
is_token = True
pos_head = item['h']['pos']
pos_tail = item['t']['pos']
# Sentence -> token
if not is_token:
if pos_head[0] > pos_tail[0]:
pos_min, pos_max = [pos_tail, pos_head]
rev = True
else:
pos_min, pos_max = [pos_head, pos_tail]
rev = False
sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
if self.mask_entity:
ent_0 = ['[UNK]']
ent_1 = ['[UNK]']
tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
if rev:
pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
else:
pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
else:
tokens = sentence
# Token -> index
if self.blank_padding:
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
else:
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]'])
# Position -> index
pos1 = []
pos2 = []
pos1_in_index = min(pos_head[0], self.max_length)
pos2_in_index = min(pos_tail[0], self.max_length)
for i in range(len(tokens)):
pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))
if self.blank_padding:
while len(pos1) < self.max_length:
pos1.append(0)
while len(pos2) < self.max_length:
pos2.append(0)
indexed_tokens = indexed_tokens[:self.max_length]
pos1 = pos1[:self.max_length]
pos2 = pos2[:self.max_length]
indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L)
pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L)
# Mask
mask = []
pos_min = min(pos1_in_index, pos2_in_index)
pos_max = max(pos1_in_index, pos2_in_index)
for i in range(len(tokens)):
if i <= pos_min:
mask.append(1)
elif i <= pos_max:
mask.append(2)
else:
mask.append(3)
# Padding
if self.blank_padding:
while len(mask) < self.max_length:
mask.append(0)
mask = mask[:self.max_length]
mask = torch.tensor(mask).long().unsqueeze(0) # (1, L)
return indexed_tokens, pos1, pos2, mask
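To make the piecewise pooling concrete, here is the mask trick from forward in isolation: the frozen mask embedding turns segment ids into one-hot rows, and adding -100 outside a segment lets a plain max-pool return per-segment maxima. Toy shapes, independent of the classes above:

import torch

mask = torch.tensor([[1, 1, 2, 2, 2, 3, 3, 0]])   # (1, L): before/between/after the entities, 0 = padding
emb = torch.nn.Embedding(4, 3)
emb.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]))
seg = 1 - emb(mask).transpose(1, 2)               # (1, 3, L): 0 inside a segment, 1 outside
x = torch.randn(1, 230, 8)                        # stand-in for the conv output (B, H, L)
pool = torch.nn.MaxPool1d(8)
pieces = [pool(torch.relu(x - 100 * seg[:, k:k + 1, :])) for k in range(3)]
piecewise = torch.cat(pieces, 1).squeeze(2)       # (1, 3H), as in PCNNEncoder.forward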

View File

@@ -1,20 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .data_loader import SentenceREDataset, SentenceRELoader, BagREDataset, BagRELoader, MultiLabelSentenceREDataset, MultiLabelSentenceRELoader
from .sentence_re import SentenceRE
from .bag_re import BagRE
from .multi_label_sentence_re import MultiLabelSentenceRE
__all__ = [
'SentenceREDataset',
'SentenceRELoader',
'SentenceRE',
'BagRE',
'BagREDataset',
'BagRELoader',
'MultiLabelSentenceREDataset',
'MultiLabelSentenceRELoader',
'MultiLabelSentenceRE'
]

View File

@@ -1,184 +0,0 @@
import torch
from torch import nn, optim
import json
from .data_loader import SentenceRELoader, BagRELoader
from .utils import AverageMeter
from tqdm import tqdm
import os
class BagRE(nn.Module):
def __init__(self,
model,
train_path,
val_path,
test_path,
ckpt,
batch_size=32,
max_epoch=100,
lr=0.1,
weight_decay=1e-5,
opt='sgd',
bag_size=0,
loss_weight=False):
super().__init__()
self.max_epoch = max_epoch
self.bag_size = bag_size
# Load data
        if train_path is not None:
self.train_loader = BagRELoader(
train_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
True,
bag_size=bag_size,
entpair_as_bag=False)
        if val_path is not None:
self.val_loader = BagRELoader(
val_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
False,
bag_size=bag_size,
entpair_as_bag=True)
        if test_path is not None:
self.test_loader = BagRELoader(
test_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
False,
bag_size=bag_size,
entpair_as_bag=True
)
# Model
self.model = nn.DataParallel(model)
# Criterion
if loss_weight:
self.criterion = nn.CrossEntropyLoss(weight=self.train_loader.dataset.weight)
else:
self.criterion = nn.CrossEntropyLoss()
# Params and optimizer
params = self.model.parameters()
self.lr = lr
if opt == 'sgd':
self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
elif opt == 'adam':
self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
elif opt == 'adamw':
from transformers import AdamW
params = list(self.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_params = [
{
'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01,
'lr': lr,
'ori_lr': lr
},
{
'params': [p for n, p in params if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
'lr': lr,
'ori_lr': lr
}
]
self.optimizer = AdamW(grouped_params, correct_bias=False)
else:
            raise Exception("Invalid optimizer. Must be 'sgd', 'adam' or 'adamw'.")
# Cuda
if torch.cuda.is_available():
self.cuda()
# Ckpt
self.ckpt = ckpt
def train_model(self, metric='auc'):
best_metric = 0
for epoch in range(self.max_epoch):
# Train
self.train()
print("=== Epoch %d train ===" % epoch)
avg_loss = AverageMeter()
avg_acc = AverageMeter()
avg_pos_acc = AverageMeter()
t = tqdm(self.train_loader)
for iter, data in enumerate(t):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass
label = data[0]
bag_name = data[1]
scope = data[2]
args = data[3:]
logits = self.model(label, scope, *args, bag_size=self.bag_size)
loss = self.criterion(logits, label)
score, pred = logits.max(-1) # (B)
acc = float((pred == label).long().sum()) / label.size(0)
pos_total = (label != 0).long().sum()
pos_correct = ((pred == label).long() * (label != 0).long()).sum()
if pos_total > 0:
pos_acc = float(pos_correct) / float(pos_total)
else:
pos_acc = 0
# Log
avg_loss.update(loss.item(), 1)
avg_acc.update(acc, 1)
avg_pos_acc.update(pos_acc, 1)
t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg, pos_acc=avg_pos_acc.avg)
# Optimize
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
# Val
print("=== Epoch %d val ===" % epoch)
result = self.eval_model(self.val_loader)
print("AUC: %.4f" % result['auc'])
print("Micro F1: %.4f" % (result['max_micro_f1']))
if result[metric] > best_metric:
print("Best ckpt and saved.")
torch.save({'state_dict': self.model.module.state_dict()}, self.ckpt)
best_metric = result[metric]
print("Best %s on val set: %f" % (metric, best_metric))
def eval_model(self, eval_loader):
self.model.eval()
with torch.no_grad():
t = tqdm(eval_loader)
pred_result = []
for iter, data in enumerate(t):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass
label = data[0]
bag_name = data[1]
scope = data[2]
args = data[3:]
logits = self.model(None, scope, *args, train=False, bag_size=self.bag_size) # results after softmax
logits = logits.cpu().numpy()
for i in range(len(logits)):
for relid in range(self.model.module.num_class):
if self.model.module.id2rel[relid] != 'NA':
pred_result.append({
'entpair': bag_name[i][:2],
'relation': self.model.module.id2rel[relid],
'score': logits[i][relid]
})
result = eval_loader.dataset.eval(pred_result)
return result
def load_state_dict(self, state_dict):
self.model.module.load_state_dict(state_dict)
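A hedged driver sketch for this framework; model construction is elided because it depends on modules defined elsewhere in the package, and the paths and hyperparameters below are invented:

def train_bag_model(model):
    # model must expose rel2id and sentence_encoder.tokenize,
    # e.g. a BagAttention head over a PCNNEncoder
    framework = BagRE(model,
                      train_path='data/train.txt',
                      val_path='data/val.txt',
                      test_path='data/test.txt',
                      ckpt='ckpt/bag_pcnn_att.pth.tar',
                      batch_size=160, max_epoch=60, lr=0.5,
                      opt='sgd', bag_size=0)
    framework.train_model(metric='auc')
    return framework.eval_model(framework.test_loader)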

View File

@@ -1,459 +0,0 @@
import torch
import torch.utils.data as data
import os, random, json, logging
import numpy as np
import sklearn.metrics
class SentenceREDataset(data.Dataset):
"""
Sentence-level relation extraction dataset
"""
def __init__(self, path, rel2id, tokenizer, kwargs):
"""
Args:
path: path of the input file
rel2id: dictionary of relation->id mapping
tokenizer: function of tokenizing
"""
super().__init__()
self.path = path
self.tokenizer = tokenizer
self.rel2id = rel2id
self.kwargs = kwargs
# Load the file
        self.data = []
        with open(path) as f:
            for line in f:
                line = line.rstrip()
                if len(line) > 0:
                    # each line is a Python dict literal; eval assumes trusted input
                    self.data.append(eval(line))
logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id)))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data[index]
seq = list(self.tokenizer(item, **self.kwargs))
        return [self.rel2id[item['relation']]] + seq  # label, seq1, seq2, ...
def collate_fn(data):
data = list(zip(*data))
labels = data[0]
seqs = data[1:]
batch_labels = torch.tensor(labels).long() # (B)
batch_seqs = []
for seq in seqs:
batch_seqs.append(torch.cat(seq, 0)) # (B, L)
return [batch_labels] + batch_seqs
def eval(self, pred_result, use_name=False):
"""
Args:
            pred_result: a list of predicted labels (ids)
Make sure that the `shuffle` param is set to `False` when getting the loader.
use_name: if True, `pred_result` contains predicted relation names instead of ids
Return:
{'acc': xx}
"""
correct = 0
total = len(self.data)
correct_positive = 0
pred_positive = 0
gold_positive = 0
neg = -1
for name in ['NA', 'na', 'no_relation', 'Other', 'Others']:
if name in self.rel2id:
if use_name:
neg = name
else:
neg = self.rel2id[name]
break
for i in range(total):
if use_name:
golden = self.data[i]['relation']
else:
golden = self.rel2id[self.data[i]['relation']]
if golden == pred_result[i]:
correct += 1
if golden != neg:
correct_positive += 1
if golden != neg:
gold_positive +=1
if pred_result[i] != neg:
pred_positive += 1
acc = float(correct) / float(total)
        try:
            micro_p = float(correct_positive) / float(pred_positive)
        except ZeroDivisionError:
            micro_p = 0
        try:
            micro_r = float(correct_positive) / float(gold_positive)
        except ZeroDivisionError:
            micro_r = 0
        try:
            micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
        except ZeroDivisionError:
            micro_f1 = 0
result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}
logging.info('Evaluation result: {}.'.format(result))
return result
def SentenceRELoader(path, rel2id, tokenizer, batch_size,
shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs):
    dataset = SentenceREDataset(path=path, rel2id=rel2id, tokenizer=tokenizer, kwargs=kwargs)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=collate_fn)
return data_loader
class BagREDataset(data.Dataset):
"""
Bag-level relation extraction dataset. Note that relation of NA should be named as 'NA'.
"""
def __init__(self, path, rel2id, tokenizer, entpair_as_bag=False, bag_size=0, mode=None):
"""
Args:
path: path of the input file
rel2id: dictionary of relation->id mapping
tokenizer: function of tokenizing
entpair_as_bag: if True, bags are constructed based on same
entity pairs instead of same relation facts (ignoring
relation labels)
"""
super().__init__()
self.tokenizer = tokenizer
self.rel2id = rel2id
self.entpair_as_bag = entpair_as_bag
self.bag_size = bag_size
# Load the file
        self.data = []
        with open(path) as f:
            for line in f:
                line = line.rstrip()
                if len(line) > 0:
                    # each line is a Python dict literal; eval assumes trusted input
                    self.data.append(eval(line))
# Construct bag-level dataset (a bag contains instances sharing the same relation fact)
        if mode is None:
self.weight = np.ones((len(self.rel2id)), dtype=np.float32)
self.bag_scope = []
self.name2id = {}
self.bag_name = []
self.facts = {}
for idx, item in enumerate(self.data):
# Annotated test set
if 'anno_relation_list' in item:
for r in item['anno_relation_list']:
fact = (item['h']['id'], item['t']['id'], r)
if r != 'NA':
self.facts[fact] = 1
assert entpair_as_bag
name = (item['h']['id'], item['t']['id'])
else:
fact = (item['h']['id'], item['t']['id'], item['relation'])
if item['relation'] != 'NA':
self.facts[fact] = 1
if entpair_as_bag:
name = (item['h']['id'], item['t']['id'])
else:
name = fact
if name not in self.name2id:
self.name2id[name] = len(self.name2id)
self.bag_scope.append([])
self.bag_name.append(name)
self.bag_scope[self.name2id[name]].append(idx)
self.weight[self.rel2id[item['relation']]] += 1.0
self.weight = 1.0 / (self.weight ** 0.05)
self.weight = torch.from_numpy(self.weight)
else:
pass
def __len__(self):
return len(self.bag_scope)
def __getitem__(self, index):
bag = self.bag_scope[index]
if self.bag_size > 0:
if self.bag_size <= len(bag):
resize_bag = random.sample(bag, self.bag_size)
else:
resize_bag = bag + list(np.random.choice(bag, self.bag_size - len(bag)))
bag = resize_bag
seqs = None
rel = self.rel2id[self.data[bag[0]]['relation']]
for sent_id in bag:
item = self.data[sent_id]
seq = list(self.tokenizer(item))
if seqs is None:
seqs = []
for i in range(len(seq)):
seqs.append([])
for i in range(len(seq)):
seqs[i].append(seq[i])
for i in range(len(seqs)):
seqs[i] = torch.cat(seqs[i], 0) # (n, L), n is the size of bag
return [rel, self.bag_name[index], len(bag)] + seqs
def collate_fn(data):
data = list(zip(*data))
label, bag_name, count = data[:3]
seqs = data[3:]
for i in range(len(seqs)):
seqs[i] = torch.cat(seqs[i], 0) # (sumn, L)
seqs[i] = seqs[i].expand((torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1, ) + seqs[i].size())
scope = [] # (B, 2)
start = 0
for c in count:
scope.append((start, start + c))
start += c
assert(start == seqs[0].size(1))
scope = torch.tensor(scope).long()
label = torch.tensor(label).long() # (B)
return [label, bag_name, scope] + seqs
def collate_bag_size_fn(data):
data = list(zip(*data))
label, bag_name, count = data[:3]
seqs = data[3:]
for i in range(len(seqs)):
seqs[i] = torch.stack(seqs[i], 0) # (batch, bag, L)
scope = [] # (B, 2)
start = 0
for c in count:
scope.append((start, start + c))
start += c
label = torch.tensor(label).long() # (B)
return [label, bag_name, scope] + seqs
def eval(self, pred_result, threshold=0.5):
"""
Args:
pred_result: a list with dict {'entpair': (head_id, tail_id), 'relation': rel, 'score': score}.
Note that relation of NA should be excluded.
Return:
            {'prec': ndarray[...], 'rec': ndarray[...], 'mean_prec': xx, 'f1': xx, 'auc': xx}
            prec (precision) and rec (recall) are in micro style and are
            sorted in decreasing order of the score.
            f1 is the max f1 score over those precision-recall points
"""
sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True)
prec = []
rec = []
correct = 0
total = len(self.facts)
entpair = {}
for i, item in enumerate(sorted_pred_result):
# Save entpair label and result for later calculating F1
idtf = item['entpair'][0] + '#' + item['entpair'][1]
if idtf not in entpair:
entpair[idtf] = {
                    'label': np.zeros((len(self.rel2id)), dtype=int),   # np.int/np.float were removed in NumPy 1.24
                    'pred': np.zeros((len(self.rel2id)), dtype=int),
                    'score': np.zeros((len(self.rel2id)), dtype=float)
}
if (item['entpair'][0], item['entpair'][1], item['relation']) in self.facts:
correct += 1
entpair[idtf]['label'][self.rel2id[item['relation']]] = 1
if item['score'] >= threshold:
entpair[idtf]['pred'][self.rel2id[item['relation']]] = 1
entpair[idtf]['score'][self.rel2id[item['relation']]] = item['score']
prec.append(float(correct) / float(i + 1))
rec.append(float(correct) / float(total))
auc = sklearn.metrics.auc(x=rec, y=prec)
np_prec = np.array(prec)
np_rec = np.array(rec)
max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
best_threshold = sorted_pred_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score']
mean_prec = np_prec.mean()
label_vec = []
pred_result_vec = []
score_vec = []
for ep in entpair:
label_vec.append(entpair[ep]['label'])
pred_result_vec.append(entpair[ep]['pred'])
score_vec.append(entpair[ep]['score'])
label_vec = np.stack(label_vec, 0)
pred_result_vec = np.stack(pred_result_vec, 0)
score_vec = np.stack(score_vec, 0)
micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
pred_result_vec = score_vec >= best_threshold
max_macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
max_micro_f1_each_relation = {}
for rel in self.rel2id:
if rel != 'NA':
max_micro_f1_each_relation[rel] = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=[self.rel2id[rel]], average='micro')
return {'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_macro_f1': max_macro_f1, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299], 'avg_p300': (np_prec[99] + np_prec[199] + np_prec[299]) / 3, 'micro_f1': micro_f1, 'macro_f1': macro_f1, 'max_micro_f1_each_relation': max_micro_f1_each_relation}
def BagRELoader(path, rel2id, tokenizer, batch_size,
shuffle, entpair_as_bag=False, bag_size=0, num_workers=8,
                collate_fn=BagREDataset.collate_fn):
    # note: the collate_fn argument is always overridden below based on bag_size
    if bag_size == 0:
        collate_fn = BagREDataset.collate_fn
    else:
        collate_fn = BagREDataset.collate_bag_size_fn
dataset = BagREDataset(path, rel2id, tokenizer, entpair_as_bag=entpair_as_bag, bag_size=bag_size)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=collate_fn)
return data_loader
class MultiLabelSentenceREDataset(data.Dataset):
"""
Sentence-level relation extraction dataset
"""
def __init__(self, path, rel2id, tokenizer, kwargs):
"""
Args:
path: path of the input file
rel2id: dictionary of relation->id mapping
tokenizer: function of tokenizing
"""
super().__init__()
self.path = path
self.tokenizer = tokenizer
self.rel2id = rel2id
self.kwargs = kwargs
# Load the file
        self.data = []
        with open(path) as f:
            for line in f:
                line = line.rstrip()
                if len(line) > 0:
                    # each line is a Python dict literal; eval assumes trusted input
                    self.data.append(eval(line))
logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id)))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data[index]
seq = list(self.tokenizer(item, **self.kwargs))
        return [self.rel2id[item['relation']]] + seq  # label, seq1, seq2, ...
def collate_fn(data):
data = list(zip(*data))
labels = data[0]
seqs = data[1:]
batch_labels = torch.tensor(labels).long() # (B)
batch_seqs = []
for seq in seqs:
batch_seqs.append(torch.cat(seq, 0)) # (B, L)
return [batch_labels] + batch_seqs
def eval(self, pred_score, threshold=0.5, use_name=False):
"""
Args:
pred_score: [sent_num, label_num]
use_name: if True, `pred_result` contains predicted relation names instead of ids
Return:
{'acc': xx}
"""
assert len(self.data) == len(pred_score)
pred_score = np.array(pred_score)
# Calculate AUC
sorted_result = []
total = 0
for sent_id in range(len(self.data)):
for rel in self.rel2id:
if rel not in ['NA', 'na', 'N/A', 'None', 'none', 'n/a', 'no_relation']:
sorted_result.append({'sent_id': sent_id, 'relation': rel, 'score': pred_score[sent_id][self.rel2id[rel]]})
if 'anno_relation_list' in self.data[sent_id]:
if rel in self.data[sent_id]['anno_relation_list']:
total += 1
else:
if rel == self.data[sent_id]['relation']:
total += 1
sorted_result.sort(key=lambda x: x['score'], reverse=True)
prec = []
rec = []
correct = 0
for i, item in enumerate(sorted_result):
if 'anno_relation_list' in self.data[item['sent_id']]:
if item['relation'] in self.data[item['sent_id']]['anno_relation_list']:
correct += 1
else:
if item['relation'] == self.data[item['sent_id']]['relation']:
correct += 1
prec.append(float(correct) / float(i + 1))
rec.append(float(correct) / float(total))
auc = sklearn.metrics.auc(x=rec, y=prec)
np_prec = np.array(prec)
np_rec = np.array(rec)
max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
max_micro_f1_threshold = sorted_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score']
mean_prec = np_prec.mean()
# Calculate F1
        pred_result_vec = np.zeros((len(self.data), len(self.rel2id)), dtype=int)  # np.int was removed in NumPy 1.24
pred_result_vec[pred_score >= threshold] = 1
label_vec = []
for item in self.data:
if 'anno_relation_list' in item:
                label_vec.append(np.array(item['anno_relation_vec'], dtype=int))
else:
                one_hot = np.zeros((len(self.rel2id)), dtype=int)
one_hot[self.rel2id[item['relation']]] = 1
label_vec.append(one_hot)
label_vec = np.stack(label_vec, 0)
assert label_vec.shape == pred_result_vec.shape
micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
acc = (label_vec == pred_result_vec).mean()
result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1, 'macro_p': macro_p, 'macro_r': macro_r, 'macro_f1': macro_f1, 'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_micro_f1_threshold': max_micro_f1_threshold, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299]}
logging.info('Evaluation result: {}.'.format(result))
return result
def MultiLabelSentenceRELoader(path, rel2id, tokenizer, batch_size,
        shuffle, num_workers=8, collate_fn=MultiLabelSentenceREDataset.collate_fn, **kwargs):
    dataset = MultiLabelSentenceREDataset(path=path, rel2id=rel2id, tokenizer=tokenizer, kwargs=kwargs)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=collate_fn)
return data_loader
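A small sketch of the scope bookkeeping shared by the collate functions above: variable-size bags are concatenated along dim 0, and scope records each bag's [start, end) slice (the counts here are invented):

counts = [3, 1, 2]                    # sentences per bag in one batch
scope, start = [], 0
for c in counts:
    scope.append((start, start + c))
    start += c
# scope == [(0, 3), (3, 4), (4, 6)]; rep[scope[i][0]:scope[i][1]] selects bag i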

View File

@@ -1,175 +0,0 @@
import os, logging, json
from tqdm import tqdm
import torch
from torch import nn, optim
from .data_loader import MultiLabelSentenceRELoader
from .utils import AverageMeter
import numpy as np
class MultiLabelSentenceRE(nn.Module):
def __init__(self,
model,
train_path,
val_path,
test_path,
ckpt,
batch_size=32,
max_epoch=100,
lr=0.1,
weight_decay=1e-5,
warmup_step=300,
opt='sgd'):
super().__init__()
self.max_epoch = max_epoch
# Load data
        if train_path is not None:
self.train_loader = MultiLabelSentenceRELoader(
train_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
True)
        if val_path is not None:
self.val_loader = MultiLabelSentenceRELoader(
val_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
False)
        if test_path is not None:
self.test_loader = MultiLabelSentenceRELoader(
test_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
False
)
# Model
self.model = model
self.parallel_model = nn.DataParallel(self.model)
# Criterion
self.criterion = nn.BCEWithLogitsLoss()
# Params and optimizer
params = self.parameters()
self.lr = lr
if opt == 'sgd':
self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
elif opt == 'adam':
self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
elif opt == 'adamw': # Optimizer for BERT
from transformers import AdamW
params = list(self.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_params = [
{
'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01,
'lr': lr,
'ori_lr': lr
},
{
'params': [p for n, p in params if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
'lr': lr,
'ori_lr': lr
}
]
self.optimizer = AdamW(grouped_params, correct_bias=False)
else:
raise Exception("Invalid optimizer. Must be 'sgd' or 'adam' or 'adamw'.")
# Warmup
if warmup_step > 0:
from transformers import get_linear_schedule_with_warmup
            training_steps = len(self.train_loader.dataset) // batch_size * self.max_epoch
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps)
else:
self.scheduler = None
# Cuda
if torch.cuda.is_available():
self.cuda()
# Ckpt
self.ckpt = ckpt
def train_model(self, metric='acc'):
best_metric = 0
global_step = 0
for epoch in range(self.max_epoch):
self.train()
logging.info("=== Epoch %d train ===" % epoch)
avg_loss = AverageMeter()
avg_acc = AverageMeter()
t = tqdm(self.train_loader)
for iter, data in enumerate(t):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass
label = data[0]
args = data[1:]
logits = self.parallel_model(*args)
                label_vec = torch.zeros_like(logits)  # zeros_like keeps the device of logits, so no explicit .cuda() is needed
label_vec[torch.arange(label_vec.size(0)), label] = 1
label_vec = label_vec[:, 1:]
logits = logits[:, 1:]
loss = self.criterion(logits.reshape(-1), label_vec.reshape(-1))
pred = (torch.sigmoid(logits) >= 0.5).long()
acc = float((pred == label_vec).long().sum()) / (label_vec.size(0) * label_vec.size(1))
# Log
avg_loss.update(loss.item(), 1)
avg_acc.update(acc, 1)
t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg)
# Optimize
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
self.optimizer.zero_grad()
global_step += 1
# Val
logging.info("=== Epoch %d val ===" % epoch)
result = self.eval_model(self.val_loader)
logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric))
if result[metric] > best_metric:
logging.info("Best ckpt and saved.")
                folder_path = os.path.dirname(self.ckpt)
                if folder_path and not os.path.exists(folder_path):
                    os.makedirs(folder_path)
torch.save({'state_dict': self.model.state_dict()}, self.ckpt)
best_metric = result[metric]
logging.info("Best %s on val set: %f" % (metric, best_metric))
def eval_model(self, eval_loader):
self.eval()
pred_score = []
with torch.no_grad():
t = tqdm(eval_loader)
for iter, data in enumerate(t):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass
label = data[0]
args = data[1:]
logits = self.parallel_model(*args)
score = self.parallel_model.module.logit_to_score(logits).cpu().numpy()
# Save result
pred_score.append(score)
# Log
pred_score = np.concatenate(pred_score, 0)
result = eval_loader.dataset.eval(pred_score)
return result
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
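The one-hot target construction from train_model, shown in isolation with toy shapes (column 0 is assumed to be the NA relation, as in the code above):

import torch

logits = torch.randn(3, 5)                 # (B, N) raw scores
label = torch.tensor([2, 0, 4])            # gold relation ids
label_vec = torch.zeros_like(logits)
label_vec[torch.arange(3), label] = 1      # one-hot rows
# NA (column 0) is dropped before the multi-label binary loss:
loss = torch.nn.BCEWithLogitsLoss()(logits[:, 1:].reshape(-1),
                                    label_vec[:, 1:].reshape(-1))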

View File

@@ -1,171 +0,0 @@
import os, logging, json
from tqdm import tqdm
import torch
from torch import nn, optim
from .data_loader import SentenceRELoader
from .utils import AverageMeter
class SentenceRE(nn.Module):
def __init__(self,
model,
train_path,
val_path,
test_path,
ckpt,
batch_size=32,
max_epoch=100,
lr=0.1,
weight_decay=1e-5,
warmup_step=300,
opt='sgd'):
super().__init__()
self.max_epoch = max_epoch
# Load data
        if train_path is not None:
self.train_loader = SentenceRELoader(
train_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
True)
        if val_path is not None:
self.val_loader = SentenceRELoader(
val_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
False)
        if test_path is not None:
self.test_loader = SentenceRELoader(
test_path,
model.rel2id,
model.sentence_encoder.tokenize,
batch_size,
False
)
# Model
self.model = model
self.parallel_model = nn.DataParallel(self.model)
# Criterion
self.criterion = nn.CrossEntropyLoss()
# Params and optimizer
params = self.parameters()
self.lr = lr
if opt == 'sgd':
self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
elif opt == 'adam':
self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
elif opt == 'adamw': # Optimizer for BERT
from transformers import AdamW
params = list(self.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped_params = [
{
'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01,
'lr': lr,
'ori_lr': lr
},
{
'params': [p for n, p in params if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
'lr': lr,
'ori_lr': lr
}
]
self.optimizer = AdamW(grouped_params, correct_bias=False)
else:
raise Exception("Invalid optimizer. Must be 'sgd' or 'adam' or 'adamw'.")
# Warmup
if warmup_step > 0:
from transformers import get_linear_schedule_with_warmup
            training_steps = len(self.train_loader.dataset) // batch_size * self.max_epoch
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps)
else:
self.scheduler = None
# Cuda
if torch.cuda.is_available():
self.cuda()
# Ckpt
self.ckpt = ckpt
def train_model(self, metric='acc'):
best_metric = 0
global_step = 0
for epoch in range(self.max_epoch):
self.train()
logging.info("=== Epoch %d train ===" % epoch)
avg_loss = AverageMeter()
avg_acc = AverageMeter()
t = tqdm(self.train_loader)
for iter, data in enumerate(t):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass
label = data[0]
args = data[1:]
logits = self.parallel_model(*args)
loss = self.criterion(logits, label)
score, pred = logits.max(-1) # (B)
acc = float((pred == label).long().sum()) / label.size(0)
# Log
avg_loss.update(loss.item(), 1)
avg_acc.update(acc, 1)
t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg)
# Optimize
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
self.optimizer.zero_grad()
global_step += 1
# Val
logging.info("=== Epoch %d val ===" % epoch)
result = self.eval_model(self.val_loader)
logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric))
if result[metric] > best_metric:
logging.info("Best ckpt and saved.")
                folder_path = os.path.dirname(self.ckpt)
                if folder_path and not os.path.exists(folder_path):
                    os.makedirs(folder_path)
torch.save({'state_dict': self.model.state_dict()}, self.ckpt)
best_metric = result[metric]
logging.info("Best %s on val set: %f" % (metric, best_metric))
def eval_model(self, eval_loader):
self.eval()
avg_acc = AverageMeter()
pred_result = []
with torch.no_grad():
t = tqdm(eval_loader)
for iter, data in enumerate(t):
if torch.cuda.is_available():
for i in range(len(data)):
try:
data[i] = data[i].cuda()
except:
pass
label = data[0]
args = data[1:]
logits = self.parallel_model(*args)
score, pred = logits.max(-1) # (B)
# Save result
for i in range(pred.size(0)):
pred_result.append(pred[i].item())
# Log
acc = float((pred == label).long().sum()) / label.size(0)
avg_acc.update(acc, pred.size(0))
t.set_postfix(acc=avg_acc.avg)
result = eval_loader.dataset.eval(pred_result)
return result
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
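A matching driver sketch for the sentence-level framework (model construction elided; the paths and hyperparameters below are invented):

def train_sentence_model(model):
    # model must expose rel2id and sentence_encoder.tokenize,
    # e.g. a SoftmaxNN head over one of the encoders above
    framework = SentenceRE(model,
                           train_path='data/train.txt',
                           val_path='data/val.txt',
                           test_path='data/test.txt',
                           ckpt='ckpt/sent_model.pth.tar',
                           batch_size=64, max_epoch=10, lr=2e-5,
                           opt='adamw')
    framework.train_model(metric='micro_f1')
    return framework.eval_model(framework.test_loader)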

View File

@@ -1,29 +0,0 @@
class AverageMeter(object):
"""
Computes and stores the average and current value of metrics.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=0):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / (.0001 + self.count)
def __str__(self):
"""
String representation for logging
"""
# for values that should be recorded exactly e.g. iteration number
if self.count == 0:
return str(self.val)
# for stats
return '%.4f (%.4f)' % (self.val, self.avg)
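Usage is straightforward; note that n defaults to 0, so update(v) with the default only changes the current value without moving the running average:

meter = AverageMeter()
for loss in [0.9, 0.7, 0.4]:
    meter.update(loss, n=1)
print(meter)   # 0.4000 (0.6666): current value and (epsilon-damped) running average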

View File

@@ -1,21 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .base_model import SentenceRE, BagRE, FewShotRE, NER
from .softmax_nn import SoftmaxNN
from .sigmoid_nn import SigmoidNN
from .bag_attention import BagAttention
from .bag_average import BagAverage
from .bag_one import BagOne
__all__ = [
'SentenceRE',
'BagRE',
'FewShotRE',
'NER',
'SoftmaxNN',
'BagAttention',
'BagAverage',
'BagOne'
]

View File

@@ -1,182 +0,0 @@
import torch
from torch import nn, optim
from .base_model import BagRE
class BagAttention(BagRE):
"""
Instance attention for bag-level relation extraction.
"""
def __init__(self, sentence_encoder, num_class, rel2id, use_diag=True):
"""
Args:
sentence_encoder: encoder for sentences
num_class: number of classes
            rel2id: dictionary of relation name -> id mapping
"""
super().__init__()
self.sentence_encoder = sentence_encoder
self.num_class = num_class
self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
self.softmax = nn.Softmax(-1)
self.rel2id = rel2id
self.id2rel = {}
self.drop = nn.Dropout()
for rel, id in rel2id.items():
self.id2rel[id] = rel
if use_diag:
self.use_diag = True
self.diag = nn.Parameter(torch.ones(self.sentence_encoder.hidden_size))
else:
self.use_diag = False
def infer(self, bag):
"""
Args:
bag: bag of sentences with the same entity pair
[{
'text' or 'token': ...,
'h': {'pos': [start, end], ...},
't': {'pos': [start, end], ...}
}]
Return:
(relation, score)
"""
self.eval()
tokens = []
pos1s = []
pos2s = []
masks = []
for item in bag:
token, pos1, pos2, mask = self.sentence_encoder.tokenize(item)
tokens.append(token)
pos1s.append(pos1)
pos2s.append(pos2)
masks.append(mask)
tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L)
pos1s = torch.cat(pos1s, 0).unsqueeze(0)
pos2s = torch.cat(pos2s, 0).unsqueeze(0)
masks = torch.cat(masks, 0).unsqueeze(0)
scope = torch.tensor([[0, len(bag)]]).long() # (1, 2)
bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax
score, pred = bag_logits.max(0)
score = score.item()
pred = pred.item()
rel = self.id2rel[pred]
return (rel, score)
def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
"""
Args:
label: (B), label of the bag
scope: (B), scope for each bag
token: (nsum, L), index of tokens
pos1: (nsum, L), relative position to head entity
pos2: (nsum, L), relative position to tail entity
mask: (nsum, L), used for piece-wise CNN
Return:
logits, (B, N)
        Dirty hack:
            When the encoder is BERT, the input is actually token, att_mask, pos1, pos2, but
            since the arguments are then fed into the BERT encoder in the original order,
            the encoder still works correctly.
"""
if bag_size > 0:
token = token.view(-1, token.size(-1))
pos1 = pos1.view(-1, pos1.size(-1))
pos2 = pos2.view(-1, pos2.size(-1))
if mask is not None:
mask = mask.view(-1, mask.size(-1))
else:
begin, end = scope[0][0], scope[-1][1]
token = token[:, begin:end, :].view(-1, token.size(-1))
pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
if mask is not None:
mask = mask[:, begin:end, :].view(-1, mask.size(-1))
scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))
# Attention
if train:
if mask is not None:
rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
else:
rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
if bag_size == 0:
bag_rep = []
query = torch.zeros((rep.size(0))).long()
if torch.cuda.is_available():
query = query.cuda()
for i in range(len(scope)):
query[scope[i][0]:scope[i][1]] = label[i]
att_mat = self.fc.weight[query] # (nsum, H)
if self.use_diag:
att_mat = att_mat * self.diag.unsqueeze(0)
att_score = (rep * att_mat).sum(-1) # (nsum)
for i in range(len(scope)):
bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]]) # (n)
bag_rep.append((softmax_att_score.unsqueeze(-1) * bag_mat).sum(0)) # (n, 1) * (n, H) -> (n, H) -> (H)
bag_rep = torch.stack(bag_rep, 0) # (B, H)
else:
batch_size = label.size(0)
query = label.unsqueeze(1) # (B, 1)
att_mat = self.fc.weight[query] # (B, 1, H)
if self.use_diag:
att_mat = att_mat * self.diag.unsqueeze(0)
rep = rep.view(batch_size, bag_size, -1)
att_score = (rep * att_mat).sum(-1) # (B, bag)
softmax_att_score = self.softmax(att_score) # (B, bag)
bag_rep = (softmax_att_score.unsqueeze(-1) * rep).sum(1) # (B, bag, 1) * (B, bag, H) -> (B, bag, H) -> (B, H)
bag_rep = self.drop(bag_rep)
bag_logits = self.fc(bag_rep) # (B, N)
else:
if bag_size == 0:
rep = []
bs = 256
total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
for b in range(total_bs):
with torch.no_grad():
left = bs * b
right = min(bs * (b + 1), len(token))
if mask is not None:
rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H)
else:
rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H)
rep = torch.cat(rep, 0)
bag_logits = []
att_mat = self.fc.weight.transpose(0, 1)
if self.use_diag:
att_mat = att_mat * self.diag.unsqueeze(1)
att_score = torch.matmul(rep, att_mat) # (nsum, H) * (H, N) -> (nsum, N)
for i in range(len(scope)):
bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]].transpose(0, 1)) # (N, (softmax)n)
rep_for_each_rel = torch.matmul(softmax_att_score, bag_mat) # (N, n) * (n, H) -> (N, H)
logit_for_each_rel = self.softmax(self.fc(rep_for_each_rel)) # ((each rel)N, (logit)N)
logit_for_each_rel = logit_for_each_rel.diag() # (N)
bag_logits.append(logit_for_each_rel)
bag_logits = torch.stack(bag_logits, 0) # after **softmax**
else:
if mask is not None:
rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
else:
rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
batch_size = rep.size(0) // bag_size
att_mat = self.fc.weight.transpose(0, 1)
if self.use_diag:
att_mat = att_mat * self.diag.unsqueeze(1)
att_score = torch.matmul(rep, att_mat) # (nsum, H) * (H, N) -> (nsum, N)
att_score = att_score.view(batch_size, bag_size, -1) # (B, bag, N)
rep = rep.view(batch_size, bag_size, -1) # (B, bag, H)
softmax_att_score = self.softmax(att_score.transpose(1, 2)) # (B, N, (softmax)bag)
rep_for_each_rel = torch.matmul(softmax_att_score, rep) # (B, N, bag) * (B, bag, H) -> (B, N, H)
bag_logits = self.softmax(self.fc(rep_for_each_rel)).diagonal(dim1=1, dim2=2) # (B, (each rel)N)
return bag_logits
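The training-time selective attention in isolation (the fixed-bag_size path, without the optional diag weighting): each instance is scored against its bag's gold-relation query vector, and the softmax-weighted sum forms the bag representation. Toy shapes:

import torch

B, bag, H, N = 2, 4, 230, 5
rep = torch.randn(B, bag, H)                   # encoded instances per bag
fc = torch.nn.Linear(H, N)
label = torch.tensor([1, 3])                   # gold relation per bag
att_mat = fc.weight[label.unsqueeze(1)]        # (B, 1, H) relation query vectors
att_score = (rep * att_mat).sum(-1)            # (B, bag)
alpha = torch.softmax(att_score, -1)           # attention over instances
bag_rep = (alpha.unsqueeze(-1) * rep).sum(1)   # (B, H)
bag_logits = fc(bag_rep)                       # (B, N)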

View File

@@ -1,134 +0,0 @@
import torch
from torch import nn, optim
from .base_model import BagRE
class BagAverage(BagRE):
"""
Average policy for bag-level relation extraction.
"""
def __init__(self, sentence_encoder, num_class, rel2id):
"""
Args:
sentence_encoder: encoder for sentences
num_class: number of classes
            rel2id: dictionary of relation name -> id mapping
"""
super().__init__()
self.sentence_encoder = sentence_encoder
self.num_class = num_class
self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
self.softmax = nn.Softmax(-1)
self.rel2id = rel2id
self.id2rel = {}
self.drop = nn.Dropout()
for rel, id in rel2id.items():
self.id2rel[id] = rel
def infer(self, bag):
"""
Args:
bag: bag of sentences with the same entity pair
[{
'text' or 'token': ...,
'h': {'pos': [start, end], ...},
't': {'pos': [start, end], ...}
}]
Return:
(relation, score)
"""
        raise NotImplementedError  # no single-bag inference for the average policy; draft kept below for reference
"""
tokens = []
pos1s = []
pos2s = []
masks = []
for item in bag:
if 'text' in item:
token, pos1, pos2, mask = self.tokenizer(item['text'],
item['h']['pos'], item['t']['pos'], is_token=False, padding=True)
else:
token, pos1, pos2, mask = self.tokenizer(item['token'],
item['h']['pos'], item['t']['pos'], is_token=True, padding=True)
tokens.append(token)
pos1s.append(pos1)
pos2s.append(pos2)
masks.append(mask)
tokens = torch.cat(tokens, 0) # (n, L)
pos1s = torch.cat(pos1s, 0)
pos2s = torch.cat(pos2s, 0)
masks = torch.cat(masks, 0)
scope = torch.tensor([[0, len(bag)]]).long() # (1, 2)
bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax
score, pred = bag_logits.max()
score = score.item()
pred = pred.item()
rel = self.id2rel[pred]
return (rel, score)
"""
    def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
"""
Args:
label: (B), label of the bag
scope: (B), scope for each bag
token: (nsum, L), index of tokens
pos1: (nsum, L), relative position to head entity
pos2: (nsum, L), relative position to tail entity
mask: (nsum, L), used for piece-wise CNN
Return:
logits, (B, N)
"""
if bag_size > 0:
token = token.view(-1, token.size(-1))
pos1 = pos1.view(-1, pos1.size(-1))
pos2 = pos2.view(-1, pos2.size(-1))
if mask is not None:
mask = mask.view(-1, mask.size(-1))
else:
begin, end = scope[0][0], scope[-1][1]
token = token[:, begin:end, :].view(-1, token.size(-1))
pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
if mask is not None:
mask = mask[:, begin:end, :].view(-1, mask.size(-1))
scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))
if train or bag_size > 0:
if mask is not None:
rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
else:
rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
else:
rep = []
bs = 256
total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
for b in range(total_bs):
with torch.no_grad():
left = bs * b
right = min(bs * (b + 1), len(token))
if mask is not None:
rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H)
else:
rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H)
rep = torch.cat(rep, 0)
# Average
bag_rep = []
if bag_size is None or bag_size == 0:
for i in range(len(scope)):
bag_rep.append(rep[scope[i][0]:scope[i][1]].mean(0))
bag_rep = torch.stack(bag_rep, 0) # (B, H)
else:
batch_size = len(scope)
rep = rep.view(batch_size, bag_size, -1) # (B, bag, H)
bag_rep = rep.mean(1) # (B, H)
bag_rep = self.drop(bag_rep)
bag_logits = self.fc(bag_rep) # (B, N)
if not train:
bag_logits = torch.softmax(bag_logits, -1)
return bag_logits

View File

@@ -1,155 +0,0 @@
import torch
from torch import nn, optim
from .base_model import BagRE
class BagOne(BagRE):
"""
Instance one(max) for bag-level relation extraction.
"""
def __init__(self, sentence_encoder, num_class, rel2id):
"""
Args:
sentence_encoder: encoder for sentences
num_class: number of classes
            rel2id: dictionary of relation name -> id mapping
"""
super().__init__()
self.sentence_encoder = sentence_encoder
self.num_class = num_class
self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
self.softmax = nn.Softmax(-1)
self.rel2id = rel2id
self.id2rel = {}
self.drop = nn.Dropout()
for rel, id in rel2id.items():
self.id2rel[id] = rel
def infer(self, bag):
"""
Args:
bag: bag of sentences with the same entity pair
[{
'text' or 'token': ...,
'h': {'pos': [start, end], ...},
't': {'pos': [start, end], ...}
}]
Return:
(relation, score)
"""
self.eval()
tokens = []
pos1s = []
pos2s = []
masks = []
for item in bag:
token, pos1, pos2, mask = self.sentence_encoder.tokenize(item)
tokens.append(token)
pos1s.append(pos1)
pos2s.append(pos2)
masks.append(mask)
tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L)
pos1s = torch.cat(pos1s, 0).unsqueeze(0)
pos2s = torch.cat(pos2s, 0).unsqueeze(0)
masks = torch.cat(masks, 0).unsqueeze(0)
scope = torch.tensor([[0, len(bag)]]).long() # (1, 2)
bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax
score, pred = bag_logits.max(0)
score = score.item()
pred = pred.item()
rel = self.id2rel[pred]
return (rel, score)
def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
"""
Args:
label: (B), label of the bag
scope: (B), scope for each bag
token: (nsum, L), index of tokens
pos1: (nsum, L), relative position to head entity
pos2: (nsum, L), relative position to tail entity
mask: (nsum, L), used for piece-wise CNN
Return:
logits, (B, N)
"""
# Encode
if bag_size > 0:
token = token.view(-1, token.size(-1))
pos1 = pos1.view(-1, pos1.size(-1))
pos2 = pos2.view(-1, pos2.size(-1))
if mask is not None:
mask = mask.view(-1, mask.size(-1))
else:
begin, end = scope[0][0], scope[-1][1]
token = token[:, begin:end, :].view(-1, token.size(-1))
pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
if mask is not None:
mask = mask[:, begin:end, :].view(-1, mask.size(-1))
scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))
if train or bag_size > 0:
if mask is not None:
rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
else:
rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
else:
rep = []
bs = 256
total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
for b in range(total_bs):
with torch.no_grad():
left = bs * b
right = min(bs * (b + 1), len(token))
if mask is not None:
rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H)
else:
rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H)
rep = torch.cat(rep, 0)
# Max
if train:
if bag_size == 0:
bag_rep = []
query = torch.zeros((rep.size(0))).long()
if torch.cuda.is_available():
query = query.cuda()
for i in range(len(scope)):
query[scope[i][0]:scope[i][1]] = label[i]
for i in range(len(scope)): # iterate over bags
bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
instance_logit = self.softmax(self.fc(bag_mat)) # (n, N)
# select j* which scores highest on the known label
max_index = instance_logit[:, query[i]].argmax() # (1)
bag_rep.append(bag_mat[max_index]) # (n, H) -> (H)
bag_rep = torch.stack(bag_rep, 0) # (B, H)
bag_rep = self.drop(bag_rep)
bag_logits = self.fc(bag_rep) # (B, N)
else:
batch_size = label.size(0)
query = label # (B)
rep = rep.view(batch_size, bag_size, -1)
instance_logit = self.softmax(self.fc(rep))
max_index = instance_logit[torch.arange(batch_size), :, query].argmax(-1)
bag_rep = rep[torch.arange(batch_size), max_index]
bag_rep = self.drop(bag_rep)
bag_logits = self.fc(bag_rep) # (B, N)
else:
if bag_size == 0:
bag_logits = []
for i in range(len(scope)):
bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
instance_logit = self.softmax(self.fc(bag_mat)) # (n, N)
logit_for_each_rel = instance_logit.max(dim=0)[0] # (N)
bag_logits.append(logit_for_each_rel)
bag_logits = torch.stack(bag_logits, 0) # after **softmax**
else:
batch_size = rep.size(0) // bag_size
rep = rep.view(batch_size, bag_size, -1)
bag_logits = self.softmax(self.fc(rep)).max(1)[0]
return bag_logits
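The at-least-one selection at train time, in isolation: within each bag, the instance whose softmax probability on the gold relation is highest stands in for the whole bag (toy shapes, fixed-bag_size path):

import torch

B, bag, H, N = 2, 4, 230, 5
rep = torch.randn(B, bag, H)
fc = torch.nn.Linear(H, N)
label = torch.tensor([1, 3])
inst = torch.softmax(fc(rep), -1)                    # (B, bag, N) per-instance distributions
best = inst[torch.arange(B), :, label].argmax(-1)    # (B): best instance index per bag
bag_rep = rep[torch.arange(B), best]                 # (B, H)
bag_logits = fc(bag_rep)                             # (B, N)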

View File

@@ -1,71 +0,0 @@
import torch
from torch import nn
import json
class SentenceRE(nn.Module):
def __init__(self):
super().__init__()
def infer(self, item):
"""
Args:
item: {'text' or 'token', 'h': {'pos': [start, end]}, 't': ...}
Return:
(Name of the relation of the sentence, score)
"""
raise NotImplementedError
class BagRE(nn.Module):
def __init__(self):
super().__init__()
def infer(self, bag):
"""
Args:
bag: bag of sentences with the same entity pair
[{
'text' or 'token': ...,
'h': {'pos': [start, end], ...},
't': {'pos': [start, end], ...}
}]
Return:
(relation, score)
"""
raise NotImplementedError
class FewShotRE(nn.Module):
def __init__(self):
super().__init__()
def infer(self, support, query):
"""
Args:
support: supporting set.
[{'text' or 'token': ...,
'h': {'pos': [start, end], ...},
't': {'pos': [start, end], ...},
'relation': ...}]
query: same format as support
Return:
[(relation, score), ...]
For few-shot relation extraction, please refer to FewRel
https://github.com/thunlp/FewRel
"""
raise NotImplementedError
class NER(nn.Module):
def __init__(self):
super().__init__()
def ner(self, sentence, is_token=False):
"""
Args:
sentence: string, the input sentence
            is_token: if is_token == True, sentence becomes an array of tokens
Return:
[{name: xx, pos: [start, end]}], a list of named entities
"""
raise NotImplementedError

Some files were not shown because too many files have changed in this diff