445 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			445 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from sklearn.preprocessing.data import binarize
 | |
| from carb.argument import Argument
 | |
| from operator import itemgetter
 | |
| from collections import defaultdict
 | |
| import nltk
 | |
| import itertools
 | |
| import logging
 | |
| import numpy as np
 | |
| import pdb
 | |
| 
 | |
| class Extraction:
 | |
|     """
 | |
|     Stores sentence, single predicate and corresponding arguments.
 | |
|     """
 | |
|     def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
 | |
|         self.pred = pred
 | |
|         self.head_pred_index = head_pred_index
 | |
|         self.sent = sent
 | |
|         self.args = []
 | |
|         self.confidence = confidence
 | |
|         self.matched = []
 | |
|         self.questions = {}
 | |
|         self.indsForQuestions = defaultdict(lambda: set())
 | |
|         self.is_mwp = False
 | |
|         self.question_dist = question_dist
 | |
|         self.index = index
 | |
| 
 | |
|     def distArgFromPred(self, arg):
 | |
|         assert(len(self.pred) == 2)
 | |
|         dists = []
 | |
|         for x in self.pred[1]:
 | |
|             for y in arg.indices:
 | |
|                 dists.append(abs(x - y))
 | |
| 
 | |
|         return min(dists)
 | |
| 
 | |
|     def argsByDistFromPred(self, question):
 | |
|         return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))
 | |
| 
 | |
|     def addArg(self, arg, question = None):
 | |
|         self.args.append(arg)
 | |
|         if question:
 | |
|             self.questions[question] = self.questions.get(question,[]) + [Argument(arg)]
 | |
| 
 | |
|     def noPronounArgs(self):
 | |
|         """
 | |
|         Returns True iff all of this extraction's arguments are not pronouns.
 | |
|         """
 | |
|         for (a, _) in self.args:
 | |
|             tokenized_arg = nltk.word_tokenize(a)
 | |
|             if len(tokenized_arg) == 1:
 | |
|                 _, pos_tag = nltk.pos_tag(tokenized_arg)[0]
 | |
|                 if ('PRP' in pos_tag):
 | |
|                     return False
 | |
|         return True
 | |
| 
 | |
|     def isContiguous(self):
 | |
|         return all([indices for (_, indices) in self.args])
 | |
| 
 | |
|     def toBinary(self):
 | |
|         ''' Try to represent this extraction's arguments as binary
 | |
|         If fails, this function will return an empty list.  '''
 | |
| 
 | |
|         ret = [self.elementToStr(self.pred)]
 | |
| 
 | |
|         if len(self.args) == 2:
 | |
|             # we're in luck
 | |
|             return ret + [self.elementToStr(arg) for arg in self.args]
 | |
| 
 | |
|         return []
 | |
| 
 | |
|         if not self.isContiguous():
 | |
|             # give up on non contiguous arguments (as we need indexes)
 | |
|             return []
 | |
| 
 | |
|         # otherwise, try to merge based on indices
 | |
|         # TODO: you can explore other methods for doing this
 | |
|         binarized = self.binarizeByIndex()
 | |
| 
 | |
|         if binarized:
 | |
|             return ret + binarized
 | |
| 
 | |
|         return []
 | |
| 
 | |
| 
 | |
|     def elementToStr(self, elem, print_indices = True):
 | |
|         ''' formats an extraction element (pred or arg) as a raw string
 | |
|         removes indices and trailing spaces '''
 | |
|         if print_indices:
 | |
|             return str(elem)
 | |
|         if isinstance(elem, str):
 | |
|             return elem
 | |
|         if isinstance(elem, tuple):
 | |
|             ret = elem[0].rstrip().lstrip()
 | |
|         else:
 | |
|             ret = ' '.join(elem.words)
 | |
|         assert ret, "empty element? {0}".format(elem)
 | |
|         return ret
 | |
| 
 | |
|     def binarizeByIndex(self):
 | |
|         extraction = [self.pred] + self.args
 | |
|         markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
 | |
|         sortedExtraction = sorted(markPred, key = lambda ws, indices, f : indices[0])
 | |
|         s =  ' '.join(['{1} {0} {1}'.format(self.elementToStr(elem), SEP) if elem[2] else self.elementToStr(elem) for elem in sortedExtraction])
 | |
|         binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]
 | |
| 
 | |
|         if len(binArgs) == 2:
 | |
|             return binArgs
 | |
| 
 | |
|         # failure
 | |
|         return []
 | |
| 
 | |
|     def bow(self):
 | |
|         return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])
 | |
| 
 | |
|     def getSortedArgs(self):
 | |
|         """
 | |
|         Sort the list of arguments.
 | |
|         If a question distribution is provided - use it,
 | |
|         otherwise, default to the order of appearance in the sentence.
 | |
|         """
 | |
|         if self.question_dist:
 | |
|             # There's a question distribtuion - use it
 | |
|             return self.sort_args_by_distribution()
 | |
|         ls = []
 | |
|         for q, args in self.questions.iteritems():
 | |
|             if (len(args) != 1):
 | |
|                 logging.debug("Not one argument: {}".format(args))
 | |
|                 continue
 | |
|             arg = args[0]
 | |
|             indices = list(self.indsForQuestions[q].union(arg.indices))
 | |
|             if not indices:
 | |
|                 logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
 | |
|                 indices = [0]
 | |
|             ls.append(((arg, q), indices))
 | |
|         return [a for a, _ in sorted(ls,
 | |
|                                      key = lambda _, indices: min(indices))]
 | |
| 
 | |
|     def question_prob_for_loc(self, question, loc):
 | |
|         """
 | |
|         Returns the probability of the given question leading to argument
 | |
|         appearing in the given location in the output slot.
 | |
|         """
 | |
|         gen_question = generalize_question(question)
 | |
|         q_dist = self.question_dist[gen_question]
 | |
|         logging.debug("distribution of {}: {}".format(gen_question,
 | |
|                                                       q_dist))
 | |
| 
 | |
|         return float(q_dist.get(loc, 0)) /  \
 | |
|             sum(q_dist.values())
 | |
| 
 | |
|     def sort_args_by_distribution(self):
 | |
|         """
 | |
|         Use this instance's question distribution (this func assumes it exists)
 | |
|         in determining the positioning of the arguments.
 | |
|         Greedy algorithm:
 | |
|         0. Decide on which argument will serve as the ``subject'' (first slot) of this extraction
 | |
|         0.1 Based on the most probable one for this spot
 | |
|         (special care is given to select the highly-influential subject position)
 | |
|         1. For all other arguments, sort arguments by the prevalance of their questions
 | |
|         2. For each argument:
 | |
|         2.1 Assign to it the most probable slot still available
 | |
|         2.2 If non such exist (fallback) - default to put it in the last location
 | |
|         """
 | |
|         INF_LOC = 100 # Used as an impractical last argument
 | |
| 
 | |
|         # Store arguments by slot
 | |
|         ret = {INF_LOC: []}
 | |
|         logging.debug("sorting: {}".format(self.questions))
 | |
| 
 | |
|         # Find the most suitable arguemnt for the subject location
 | |
|         logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
 | |
|                                                       for (q, _) in self.questions.iteritems()]))
 | |
| 
 | |
|         subj_question, subj_args = max(self.questions.iteritems(),
 | |
|                                        key = lambda q, _: self.question_prob_for_loc(q, 0))
 | |
| 
 | |
|         ret[0] = [(subj_args[0], subj_question)]
 | |
| 
 | |
|         # Find the rest
 | |
|         for (question, args) in sorted([(q, a)
 | |
|                                         for (q, a) in self.questions.iteritems() if (q not in [subj_question])],
 | |
|                                        key = lambda q, _: \
 | |
|                                        sum(self.question_dist[generalize_question(q)].values()),
 | |
|                                        reverse = True):
 | |
|             gen_question = generalize_question(question)
 | |
|             arg = args[0]
 | |
|             assigned_flag = False
 | |
|             for (loc, count) in sorted(self.question_dist[gen_question].iteritems(),
 | |
|                                        key = lambda _ , c: c,
 | |
|                                        reverse = True):
 | |
|                 if loc not in ret:
 | |
|                     # Found an empty slot for this item
 | |
|                     # Place it there and break out
 | |
|                     ret[loc] = [(arg, question)]
 | |
|                     assigned_flag = True
 | |
|                     break
 | |
| 
 | |
|             if not assigned_flag:
 | |
|                 # Add this argument to the non-assigned (hopefully doesn't happen much)
 | |
|                 logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
 | |
|                 ret[INF_LOC].append((arg, question))
 | |
| 
 | |
|         logging.debug("Linearizing arg list: {}".format(ret))
 | |
| 
 | |
|         # Finished iterating - consolidate and return a list of arguments
 | |
|         return [arg
 | |
|                 for (_, arg_ls) in sorted(ret.iteritems(),
 | |
|                                           key = lambda k, v: int(k))
 | |
|                 for arg in arg_ls]
 | |
| 
 | |
| 
 | |
|     def __str__(self):
 | |
|         pred_str = self.elementToStr(self.pred)
 | |
|         return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
 | |
|                                    self.compute_global_pred(pred_str,
 | |
|                                                             self.questions.keys()),
 | |
|                                    '\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
 | |
|                                                                                                   question))
 | |
|                                               for arg, question in self.getSortedArgs()]))
 | |
| 
 | |
|     def get_base_verb(self, surface_pred):
 | |
|         """
 | |
|         Given the surface pred, return the original annotated verb
 | |
|         """
 | |
|         # Assumes that at this point the verb is always the last word
 | |
|         # in the surface predicate
 | |
|         return surface_pred.split(' ')[-1]
 | |
| 
 | |
| 
 | |
|     def compute_global_pred(self, surface_pred, questions):
 | |
|         """
 | |
|         Given the surface pred and all instansiations of questions,
 | |
|         make global coherence decisions regarding the final form of the predicate
 | |
|         This should hopefully take care of multi word predicates and correct inflections
 | |
|         """
 | |
|         from operator import itemgetter
 | |
|         split_surface = surface_pred.split(' ')
 | |
| 
 | |
|         if len(split_surface) > 1:
 | |
|             # This predicate has a modal preceding the base verb
 | |
|             verb = split_surface[-1]
 | |
|             ret = split_surface[:-1] # get all of the elements in the modal
 | |
|         else:
 | |
|             verb = split_surface[0]
 | |
|             ret = []
 | |
| 
 | |
|         split_questions = map(lambda question: question.split(' '),
 | |
|                             questions)
 | |
| 
 | |
|         preds = map(normalize_element,
 | |
|                     map(itemgetter(QUESTION_TRG_INDEX),
 | |
|                         split_questions))
 | |
|         if len(set(preds)) > 1:
 | |
|             # This predicate is appears in multiple ways, let's stick to the base form
 | |
|             ret.append(verb)
 | |
| 
 | |
|         if len(set(preds)) == 1:
 | |
|             # Change the predciate to the inflected form
 | |
|             # if there's exactly one way in which the predicate is conveyed
 | |
|             ret.append(preds[0])
 | |
| 
 | |
|             pps = map(normalize_element,
 | |
|                       map(itemgetter(QUESTION_PP_INDEX),
 | |
|                           split_questions))
 | |
| 
 | |
|             obj2s = map(normalize_element,
 | |
|                         map(itemgetter(QUESTION_OBJ2_INDEX),
 | |
|                             split_questions))
 | |
| 
 | |
|             if (len(set(pps)) == 1):
 | |
|                 # If all questions for the predicate include the same pp attachemnt -
 | |
|                 # assume it's a multiword predicate
 | |
|                 self.is_mwp = True # Signal to arguments that they shouldn't take the preposition
 | |
|                 ret.append(pps[0])
 | |
| 
 | |
|         # Concat all elements in the predicate and return
 | |
|         return " ".join(ret).strip()
 | |
| 
 | |
| 
 | |
|     def augment_arg_with_question(self, arg, question):
 | |
|         """
 | |
|         Decide what elements from the question to incorporate in the given
 | |
|         corresponding argument
 | |
|         """
 | |
|         # Parse question
 | |
|         wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
 | |
|                                                 question.split(' ')[:-1]) # Last split is the question mark
 | |
| 
 | |
|         # Place preposition in argument
 | |
|         # This is safer when dealing with n-ary arguments, as it's directly attaches to the
 | |
|         # appropriate argument
 | |
|         if (not self.is_mwp) and pp and (not obj2):
 | |
|             if not(arg.startswith("{} ".format(pp))):
 | |
|                 # Avoid repeating the preporition in cases where both question and answer contain it
 | |
|                 return " ".join([pp,
 | |
|                                  arg])
 | |
| 
 | |
|         # Normal cases
 | |
|         return arg
 | |
| 
 | |
|     def clusterScore(self, cluster):
 | |
|         """
 | |
|         Calculate cluster density score as the mean distance of the maximum distance of each slot.
 | |
|         Lower score represents a denser cluster.
 | |
|         """
 | |
|         logging.debug("*-*-*- Cluster: {}".format(cluster))
 | |
| 
 | |
|         # Find global centroid
 | |
|         arr = np.array([x for ls in cluster for x in ls])
 | |
|         centroid = np.sum(arr)/arr.shape[0]
 | |
|         logging.debug("Centroid: {}".format(centroid))
 | |
| 
 | |
|         # Calculate mean over all maxmimum points
 | |
|         return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])
 | |
| 
 | |
|     def resolveAmbiguity(self):
 | |
|         """
 | |
|         Heursitic to map the elments (argument and predicates) of this extraction
 | |
|         back to the indices of the sentence.
 | |
|         """
 | |
|         ## TODO: This removes arguments for which there was no consecutive span found
 | |
|         ## Part of these are non-consecutive arguments,
 | |
|         ## but other could be a bug in recognizing some punctuation marks
 | |
| 
 | |
|         elements = [self.pred] \
 | |
|                    + [(s, indices)
 | |
|                       for (s, indices)
 | |
|                       in self.args
 | |
|                       if indices]
 | |
|         logging.debug("Resolving ambiguity in: {}".format(elements))
 | |
| 
 | |
|         # Collect all possible combinations of arguments and predicate indices
 | |
|         # (hopefully it's not too much)
 | |
|         all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
 | |
|         logging.debug("Number of combinations: {}".format(len(all_combinations)))
 | |
| 
 | |
|         # Choose the ones with best clustering and unfold them
 | |
|         resolved_elements = zip(map(itemgetter(0), elements),
 | |
|                                 min(all_combinations,
 | |
|                                     key = lambda cluster: self.clusterScore(cluster)))
 | |
|         logging.debug("Resolved elements = {}".format(resolved_elements))
 | |
| 
 | |
|         self.pred = resolved_elements[0]
 | |
|         self.args = resolved_elements[1:]
 | |
| 
 | |
|     def conll(self, external_feats = {}):
 | |
|         """
 | |
|         Return a CoNLL string representation of this extraction
 | |
|         """
 | |
|         return '\n'.join(["\t".join(map(str,
 | |
|                                         [i, w] + \
 | |
|                                         list(self.pred) + \
 | |
|                                         [self.head_pred_index] + \
 | |
|                                         external_feats + \
 | |
|                                         [self.get_label(i)]))
 | |
|                           for (i, w)
 | |
|                           in enumerate(self.sent.split(" "))]) + '\n'
 | |
| 
 | |
|     def get_label(self, index):
 | |
|         """
 | |
|         Given an index of a word in the sentence -- returns the appropriate BIO conll label
 | |
|         Assumes that ambiguation was already resolved.
 | |
|         """
 | |
|         # Get the element(s) in which this index appears
 | |
|         ent = [(elem_ind, elem)
 | |
|                for (elem_ind, elem)
 | |
|                in enumerate(map(itemgetter(1),
 | |
|                                 [self.pred] + self.args))
 | |
|                if index in elem]
 | |
| 
 | |
|         if not ent:
 | |
|             # index doesnt appear in any element
 | |
|             return "O"
 | |
| 
 | |
|         if len(ent) > 1:
 | |
|             # The same word appears in two different answers
 | |
|             # In this case we choose the first one as label
 | |
|             logging.warn("Index {} appears in one than more element: {}".\
 | |
|                          format(index,
 | |
|                                 "\t".join(map(str,
 | |
|                                               [ent,
 | |
|                                                self.sent,
 | |
|                                                self.pred,
 | |
|                                                self.args]))))
 | |
| 
 | |
|         ## Some indices appear in more than one argument (ones where the above message appears)
 | |
|         ## From empricial observation, these seem to mostly consist of different levels of granularity:
 | |
|         ##     what	had	_	been taken	_	_	_	?	loan commitments topping $ 3 billion
 | |
|         ##     how much	had	_	been taken	_	_	_	?	topping $ 3 billion
 | |
|         ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
 | |
|         ## E.g., in this example two arguemnts are created: (loan commitments, topping $ 3 billion)
 | |
| 
 | |
|         elem_ind, elem = min(ent, key = lambda _, ls: len(ls))
 | |
| 
 | |
|         # Distinguish between predicate and arguments
 | |
|         prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)
 | |
| 
 | |
|         # Distinguish between Beginning and Inside labels
 | |
|         suffix = "B" if index == elem[0] else "I"
 | |
| 
 | |
|         return "{}-{}".format(prefix, suffix)
 | |
| 
 | |
|     def __str__(self):
 | |
|         return '{0}\t{1}'.format(self.elementToStr(self.pred,
 | |
|                                                    print_indices = True),
 | |
|                                  '\t'.join([self.elementToStr(arg)
 | |
|                                             for arg
 | |
|                                             in self.args]))
 | |
| 
 | |
| # Flatten a list of lists
 | |
| flatten = lambda l: [item for sublist in l for item in sublist]
 | |
| 
 | |
| 
 | |
| def normalize_element(elem):
 | |
|     """
 | |
|     Return a surface form of the given question element.
 | |
|     the output should be properly able to precede a predicate (or blank otherwise)
 | |
|     """
 | |
|     return elem.replace("_", " ") \
 | |
|         if (elem != "_")\
 | |
|            else ""
 | |
| 
 | |
| ## Helper functions
 | |
| def escape_special_chars(s):
 | |
|     return s.replace('\t', '\\t')
 | |
| 
 | |
| 
 | |
| def generalize_question(question):
 | |
|     """
 | |
|     Given a question in the context of the sentence and the predicate index within
 | |
|     the question - return a generalized version which extracts only order-imposing features
 | |
|     """
 | |
|     import nltk   # Using nltk since couldn't get spaCy to agree on the tokenization
 | |
|     wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1] # Last split is the question mark
 | |
|     return ' '.join([wh, sbj, obj1])
 | |
| 
 | |
| 
 | |
| 
 | |
| ## CONSTANTS
 | |
| SEP = ';;;'
 | |
| QUESTION_TRG_INDEX =  3 # index of the predicate within the question
 | |
| QUESTION_PP_INDEX = 5
 | |
| QUESTION_OBJ2_INDEX = 6
 |