
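NLP-Information Extraction-Relation Extraction-: Attention-BiLSTM Entity Relation Classifier [Relation Classification Based on Bidirectional LSTM and Attention Mechanism] [Dataset: SemEval- Task 8]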

Posted: 2021-01-13 08:49:35


Original paper: "Attention-based bidirectional long short-term memory networks for relation classification"

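I. Overview

1. Motivation

Most traditional methods rely on existing lexical resources (such as WordNet), NLP systems, or hand-crafted features. These approaches can increase computational complexity, feature extraction itself costs a great deal of time and effort, and the quality of the extracted features strongly affects the final results.

The paper proposes the Att-BLSTM network to solve relation classification end to end.

Starting from this observation, the paper proposes an attention-based bidirectional LSTM model for relation classification. The attention mechanism automatically discovers the words that play a key role in classification, so the model can capture the most important semantic information in each sentence, and it does not rely on any external knowledge or NLP system.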

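2. Significance of the paper

The paper adds an attention mechanism to a bidirectional LSTM for relation classification. This avoids the complex feature engineering of traditional approaches, greatly simplifies the experimental pipeline while still achieving quite good results, and gives related research a practical direction to follow.

The paper's logic is very clear and stays tightly focused on its motivation. The overall idea is simple and the model can be understood at a glance, yet its results are strong.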

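3. Core of the abstract

Current relation classification depends on NLP tools to extract features; the paper proposes Att-BLSTM, a relation classification method that needs no complex preprocessing; experiments show that the method is effective and achieves state-of-the-art results.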

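II. Attention-BiLSTM Model Architecture

1. Model architecture

Att-BLSTM builds on word embeddings and adds entity position indicators to the input; its attention-augmented BLSTM lets the model learn, for each sentence, which words matter most for classifying the relation.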

As shown in Figure 1, the model proposed in this paper contains five components:

(1) Input layer: input sentence to this model;
(2) Embedding layer: map each word into a low dimension vector;
(3) LSTM layer: utilize BLSTM to get high level features from step (2);
(4) Attention layer: produce a weight vector, and merge word-level features from each time step into a sentence-level feature vector, by multiplying the weight vector;
(5) Output layer: the sentence-level feature vector is finally used for relation classification.
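Written out as equations, steps (3) to (5) look roughly as follows. The notation here is chosen to match model.py in this post rather than copied from the paper: T is the sentence length, d_h is hidden_size, w is the trainable vector att_weight, and padded positions get a score of -inf before the softmax, so they receive zero weight.

```latex
% (3) BLSTM: forward and backward states are summed element-wise (lstm_layer)
h_i = \overrightarrow{h_i} + \overleftarrow{h_i}, \qquad H = [h_1, h_2, \dots, h_T] \in \mathbb{R}^{d_h \times T}
% (4) attention over time steps (attention_layer)
M = \tanh(H), \qquad \alpha = \operatorname{softmax}(w^{\top} M) \in \mathbb{R}^{1 \times T}, \qquad r = H \alpha^{\top} \in \mathbb{R}^{d_h}
% (5) sentence representation and classifier (dense layer + CrossEntropyLoss)
h^{*} = \tanh(r), \qquad \hat{p}(y \mid S) = \operatorname{softmax}\big(W^{(S)} h^{*} + b^{(S)}\big), \qquad \hat{y} = \arg\max_{y} \hat{p}(y \mid S)
```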

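2. Attention principle

An attention mechanism lets the model assign a different weight to each part of the input X, so that it extracts the most critical information and makes more accurate decisions, while adding little computation or storage overhead.

Depending on the region over which attention is computed, it can be divided into the following types:

Soft attention / global attention: the most common form. It computes a weight (probability) for every key, so every key has a corresponding weight; because the computation covers all positions, it is also called global attention.
Hard attention: locates exactly one key and ignores all the others, i.e. that key gets probability 1 and every other key gets probability 0. This alignment is very demanding because it must be right in one step, and a wrong alignment has a large impact. Moreover, because it is not differentiable, it usually has to be trained with reinforcement learning.
Local attention: a compromise between the two. It first uses the hard approach to locate a position, takes a window centered on that point, and then computes soft attention only within that small window.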
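To make the three variants concrete, here is a toy sketch over one-dimensional scores in plain Python. The function names are hypothetical, and choosing the local window's center by argmax is a simplifying assumption for illustration, not a specific published variant. Att-BLSTM itself uses the soft form: a softmax over every non-padded word.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_attention(scores):
    """Soft / global attention: every key receives a non-zero weight."""
    return softmax(scores)


def hard_attention(scores):
    """Hard attention: weight 1 on the best-scoring key, 0 elsewhere (argmax, not differentiable)."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return [1.0 if i == best else 0.0 for i in range(len(scores))]


def local_attention(scores, window=1):
    """Local attention: hard-pick a center, then apply softmax only inside the window around it."""
    center = max(range(len(scores)), key=lambda i: scores[i])
    lo, hi = max(0, center - window), min(len(scores), center + window + 1)
    weights = [0.0] * len(scores)
    weights[lo:hi] = softmax(scores[lo:hi])
    return weights


def pool(values, weights):
    """Attention output: weighted sum of the values."""
    return sum(v * w for v, w in zip(values, weights))


scores = [0.0, 1.0, 5.0, 1.0, 0.0]
values = [10.0, 20.0, 30.0, 40.0, 50.0]
for name, fn in [('soft', soft_attention), ('hard', hard_attention), ('local', local_attention)]:
    print(name, [round(w, 3) for w in fn(scores)], round(pool(values, fn(scores)), 3))
```

Only the soft variant gives every position a gradient, which is why it can be trained end to end with ordinary backpropagation.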

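3. Tricks

Insert special marker tokens before and after each entity to indicate its position.

Train with a regularization-constrained loss (in the code below, L2 weight decay set by L2_decay).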
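Both tricks are small enough to sketch in plain Python. add_entity_markers is a hypothetical helper that produces the same token layout as the processed train.json (markers as separate tokens). The loss is written out explicitly here, whereas the code in this post applies the L2 term through the optimizer's weight_decay, which adds l2 * θ to each gradient, i.e. the gradient of (l2 / 2) * ||θ||².

```python
import math


def add_entity_markers(tokens, e1, e2):
    """Wrap two non-overlapping, non-empty token spans with <e1>...</e1> and <e2>...</e2>.
    Spans are (start, end) with `end` exclusive."""
    out = []
    for i, tok in enumerate(tokens):
        if i == e1[0]:
            out.append('<e1>')
        if i == e2[0]:
            out.append('<e2>')
        out.append(tok)
        if i == e1[1] - 1:
            out.append('</e1>')
        if i == e2[1] - 1:
            out.append('</e2>')
    return out


def cross_entropy(logits, label):
    """-log softmax(logits)[label], computed with the log-sum-exp trick."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_z - logits[label]


def l2_regularized_loss(logits, label, params, l2=1e-5):
    """Cross-entropy plus (l2 / 2) * ||params||^2; its gradient adds l2 * param,
    which is what an optimizer's weight_decay=l2 adds to each gradient."""
    return cross_entropy(logits, label) + 0.5 * l2 * sum(p * p for p in params)


print(add_entity_markers(['The', 'student', 'association', 'is'], (1, 2), (2, 3)))
print(l2_regularized_loss([2.0, 0.5, -1.0], 0, params=[0.3, -0.2], l2=1e-5))
```

The first call reproduces sample 5 of train.json, where the two entities are adjacent.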

III. Experimental Results

The paper compares various model configurations on the SemEval- Task 8 dataset.

IV. Conclusions

1. Key point

It does not rely on any other NLP tools.

2. Innovation

It introduces the Attention-BiLSTM architecture.

3. Takeaway

The network does not depend on any NLP tools or lexical resources; it only needs raw text with position indicators as input.

V. Code

1. Dataset

1.1 Raw dataset

train_file.txt [samples 1-8000]

1	"The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>."
Component-Whole(e2,e1)
Comment: Not a collection: there is structure here, organisation.

2	"The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord."
Other
Comment:

3	"The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code."
Instrument-Agency(e2,e1)
Comment:

4	"A misty <e1>ridge</e1> uprises from the <e2>surge</e2>."
Other
Comment:

5	"The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo."
Member-Collection(e1,e2)
Comment:

6	"This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver."
Other
Comment:

7	"The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach."
Cause-Effect(e2,e1)
Comment:

8	"<e1>People</e1> have been moving back into <e2>downtown</e2>."
Entity-Destination(e1,e2)
Comment:

9	"The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces."
Content-Container(e1,e2)
Comment: prototypical example

10	"The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial."
Entity-Destination(e1,e2)
Comment:

......

test_file.txt [samples 8001-10717]

8001	"The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.

8002	"The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
Product-Producer(e2,e1)
Comment: (a) is satisfied

8003	"The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>."
Instrument-Agency(e2,e1)
Comment:

8004	"The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>."
Entity-Destination(e1,e2)
Comment:

8005	"Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>."
Cause-Effect(e2,e1)
Comment:

8006	"The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant."
Component-Whole(e1,e2)
Comment:

8007	"A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist."
Product-Producer(e1,e2)
Comment: (a) is satisfied; negation is outside

8008	"Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs."
Member-Collection(e2,e1)
Comment:

8009	"The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street."
Component-Whole(e1,e2)
Comment:

8010	"This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease."
Message-Topic(e1,e2)
Comment: may be we could leave clinical out of e2.
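Each raw sample occupies four lines: the id and the quoted sentence separated by a tab, the relation label, a "Comment:" line, and a blank line, which is why preprocess.py walks the file in steps of 4. A stdlib-only sketch of that grouping, without tokenization (parse_records is a hypothetical helper, not part of the original code):

```python
import re


def parse_records(lines):
    """Group raw SemEval lines into dicts. Each sample takes four lines:
    '<id><TAB>"<sentence>"', '<relation>', 'Comment: ...', and a blank line."""
    records = []
    for i in range(0, len(lines) - 2, 4):
        sample_id, quoted = lines[i].rstrip('\n').split('\t', 1)
        sentence = quoted.strip()[1:-1]  # drop the surrounding double quotes
        records.append({
            'id': int(sample_id),
            'sentence': sentence,
            'e1': re.search(r'<e1>(.*?)</e1>', sentence).group(1),
            'e2': re.search(r'<e2>(.*?)</e2>', sentence).group(1),
            'relation': lines[i + 1].strip(),
            'comment': lines[i + 2].strip()[len('Comment:'):].strip(),
        })
    return records


raw = [
    '8002\t"The <e1>company</e1> fabricates plastic <e2>chairs</e2>."\n',
    'Product-Producer(e2,e1)\n',
    'Comment: (a) is satisfied\n',
    '\n',
]
print(parse_records(raw))
```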

1.2 Processed data

preprocess.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import json
import re

from nltk.tokenize import word_tokenize  # needs NLTK's Punkt tokenizer data


def search_entity(sentence):
    # pull out the entity strings between the markers
    e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
    e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
    # pad the markers with spaces so they end up as separate tokens
    sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
    sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
    sentence = word_tokenize(sentence)
    sentence = ' '.join(sentence)
    # the tokenizer breaks the markers apart (e.g. '< e1 >'); glue them back together
    sentence = sentence.replace('< e1 >', '<e1>')
    sentence = sentence.replace('< e2 >', '<e2>')
    sentence = sentence.replace('< /e1 >', '</e1>')
    sentence = sentence.replace('< /e2 >', '</e2>')
    sentence = sentence.split()

    assert '<e1>' in sentence
    assert '<e2>' in sentence
    assert '</e1>' in sentence
    assert '</e2>' in sentence

    return sentence


def convert(path_src, path_des):
    with open(path_src, 'r', encoding='utf-8') as fr:
        data = fr.readlines()
    with open(path_des, 'w', encoding='utf-8') as fw:
        # each sample spans 4 lines: id<TAB>"sentence", relation, Comment: ..., blank line
        for i in range(0, len(data), 4):
            id_s, sentence = data[i].strip().split('\t')
            sentence = sentence[1:-1]  # strip the surrounding quotes
            sentence = search_entity(sentence)
            meta = dict(
                id=id_s,
                relation=data[i+1].strip(),
                sentence=sentence,
                comment=data[i+2].strip()[8:]  # drop the 'Comment:' prefix
            )
            json.dump(meta, fw, ensure_ascii=False)
            fw.write('\n')


if __name__ == '__main__':
    path_train = './SemEval_task8_all_data/SemEval_task8_training/TRAIN_FILE.TXT'
    path_test = './SemEval_task8_all_data/SemEval_task8_testing_keys/TEST_FILE_FULL.TXT'
    convert(path_train, 'train.json')
    convert(path_test, 'test.json')

train.json

{"id": "1", "relation": "Component-Whole(e2,e1)", "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "<e1>", "configuration", "</e1>", "of", "antenna", "<e2>", "elements", "</e2>", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "sentence": ["The", "<e1>", "child", "</e1>", "was", "carefully", "wrapped", "and", "bound", "into", "the", "<e2>", "cradle", "</e2>", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "<e1>", "author", "</e1>", "of", "a", "keygen", "uses", "a", "<e2>", "disassembler", "</e2>", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "sentence": ["A", "misty", "<e1>", "ridge", "</e1>", "uprises", "from", "the", "<e2>", "surge", "</e2>", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "sentence": ["The", "<e1>", "student", "</e1>", "<e2>", "association", "</e2>", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
......

test.json

{"id": "8001", "relation": "Message-Topic(e1,e2)", "sentence": ["The", "most", "common", "<e1>", "audits", "</e1>", "were", "about", "<e2>", "waste", "</e2>", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "sentence": ["The", "<e1>", "company", "</e1>", "fabricates", "plastic", "<e2>", "chairs", "</e2>", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "school", "<e1>", "master", "</e1>", "teaches", "the", "lesson", "with", "a", "<e2>", "stick", "</e2>", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "sentence": ["The", "suspect", "dumped", "the", "dead", "<e1>", "body", "</e1>", "into", "a", "local", "<e2>", "reservoir", "</e2>", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "sentence": ["Avian", "<e1>", "influenza", "</e1>", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "<e2>", "virus", "</e2>", "."], "comment": ""}
......
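train.json and test.json are JSON Lines files: preprocess.py writes one json.dump per sample followed by a newline, so every line parses on its own. A small sketch for reading them and checking the label distribution (the helper names are hypothetical; with a real file, pass the open file object instead of the list, since iterating a file yields its lines):

```python
import json
from collections import Counter


def load_jsonl(lines):
    """JSON Lines: one JSON object per line; blank lines are ignored."""
    return [json.loads(line) for line in lines if line.strip()]


def relation_counts(records):
    """How many samples carry each relation label."""
    return Counter(record['relation'] for record in records)


lines = [
    '{"id": "8002", "relation": "Product-Producer(e2,e1)", "sentence": ["The", "<e1>", "company", "</e1>", "fabricates", "plastic", "<e2>", "chairs", "</e2>", "."], "comment": " (a) is satisfied"}\n',
    '{"id": "8004", "relation": "Entity-Destination(e1,e2)", "sentence": ["The", "suspect", "dumped", "the", "dead", "<e1>", "body", "</e1>", "into", "a", "local", "<e2>", "reservoir", "</e2>", "."], "comment": ""}\n',
]
records = load_jsonl(lines)
print(relation_counts(records).most_common())
```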

1.3 relation2id

Other 0
Cause-Effect(e1,e2) 1
Cause-Effect(e2,e1) 2
Component-Whole(e1,e2) 3
Component-Whole(e2,e1) 4
Content-Container(e1,e2) 5
Content-Container(e2,e1) 6
Entity-Destination(e1,e2) 7
Entity-Destination(e2,e1) 8
Entity-Origin(e1,e2) 9
Entity-Origin(e2,e1) 10
Instrument-Agency(e1,e2) 11
Instrument-Agency(e2,e1) 12
Member-Collection(e1,e2) 13
Member-Collection(e2,e1) 14
Message-Topic(e1,e2) 15
Message-Topic(e2,e1) 16
Product-Producer(e1,e2) 17
Product-Producer(e2,e1) 18
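The id layout is deliberate: Other is 0 and the two directions of each relation type get consecutive ids (2k-1, 2k), so math.ceil(id / 2) folds a directed label onto one of 9 undirected types. semeval_scorer in evaluate.py relies on exactly this, which is why its class_num defaults to 10 (Other plus 9 types). A sketch with hypothetical helper names:

```python
import math


def parse_relation2id(text):
    """Each non-empty line is '<relation> <id>' (whitespace-separated), as RelationLoader reads it."""
    rel2id = {}
    for line in text.splitlines():
        if line.strip():
            relation, idx = line.split()
            rel2id[relation] = int(idx)
    return rel2id


def coarse_class(label_id):
    """Fold a directed label id onto its undirected type: 0 (Other) -> 0, {1, 2} -> 1, ..., {17, 18} -> 9."""
    return math.ceil(label_id / 2)


rel2id = parse_relation2id('Other 0\nCause-Effect(e1,e2) 1\nCause-Effect(e2,e1) 2\nProduct-Producer(e2,e1) 18\n')
print({rel: coarse_class(i) for rel, i in rel2id.items()})
```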

2. Pre-trained word embeddings: static HLBL embeddings

hlbl-embeddings-scaled.EMBEDDING_SIZE=50

*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.11915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.17731828 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.09974522 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.09346492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......
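Each line of the file is a word followed by its 50 values. WordEmbeddingLoader turns this into a vocabulary with six special tokens in front (PAD, UNK and the four entity markers); as far as the code shows, the file's own *UNKNOWN* row is treated as an ordinary word, separate from UNK. A stdlib-only sketch of the same idea (load_text_embeddings is a hypothetical name; it also skips repeated words so that ids stay aligned with rows):

```python
import random
import statistics

# same order as WordEmbeddingLoader: PAD=0, UNK=1, then the four entity markers
SPECIAL_TOKENS = ['PAD', 'UNK', '<e1>', '<e2>', '</e1>', '</e2>']


def load_text_embeddings(lines, dim, seed=0):
    """Parse '<word> v_1 ... v_dim' lines into (word2id, vectors).

    Row i of `vectors` belongs to the token whose id is i. Special tokens get random
    vectors drawn from the mean/std of the pre-trained values; PAD is all zeros."""
    word2id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    vectors = []
    for line in lines:
        parts = line.split()
        # skip malformed lines (wrong dimension) and repeated words, so ids stay aligned with rows
        if len(parts) != dim + 1 or parts[0] in word2id:
            continue
        word2id[parts[0]] = len(word2id)
        vectors.append([float(x) for x in parts[1:]])
    flat = [x for vec in vectors for x in vec]
    mean, std = statistics.mean(flat), statistics.pstdev(flat)
    rng = random.Random(seed)
    special = [[rng.gauss(mean, std) for _ in range(dim)] for _ in SPECIAL_TOKENS]
    special[0] = [0.0] * dim
    return word2id, special + vectors


sample = [
    '*UNKNOWN* 0.1 0.2',
    'the 0.3 -0.1',
    'broken 1.0',   # wrong length, skipped
    'the 9.0 9.0',  # repeated word, skipped
]
word2id, vectors = load_text_embeddings(sample, dim=2)
print(word2id)
```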

3. config.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import argparse
import json
import os
import random

import numpy as np
import torch


class Config(object):
    def __init__(self):
        # get init config
        args = self.__get_config()
        for key in args.__dict__:
            setattr(self, key, args.__dict__[key])

        # select device
        self.device = None
        if self.cuda >= 0 and torch.cuda.is_available():
            self.device = torch.device('cuda:{}'.format(self.cuda))
        else:
            self.device = torch.device('cpu')

        # determine the model name and model dir
        if self.model_name is None:
            self.model_name = 'Att_BLSTM'
        self.model_dir = os.path.join(self.output_dir, self.model_name)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # backup data
        self.__config_backup(args)

        # set the random seed
        self.__set_seed(self.seed)

    def __get_config(self):
        parser = argparse.ArgumentParser()
        parser.description = 'config for models'

        # several key selective parameters
        parser.add_argument('--data_dir', type=str, default='./data', help='dir to load data')
        parser.add_argument('--output_dir', type=str, default='./output', help='dir to save output')

        # word embedding (to use the 50-d HLBL file, pass its path together with --word_dim 50)
        parser.add_argument('--embedding_path', type=str, default='./embedding/glove.6B.100d.txt', help='pre_trained word embedding')
        parser.add_argument('--word_dim', type=int, default=100, help='dimension of word embedding')

        # train settings
        parser.add_argument('--model_name', type=str, default=None, help='model name')
        parser.add_argument('--mode', type=int, default=1, choices=[0, 1], help='running mode: 1 for training; otherwise testing')
        parser.add_argument('--seed', type=int, default=5782, help='random seed')
        parser.add_argument('--cuda', type=int, default=0, help='num of gpu device, if -1, select cpu')
        parser.add_argument('--epoch', type=int, default=30, help='max epochs during training')

        # hyper parameters
        parser.add_argument('--batch_size', type=int, default=10, help='batch size')
        parser.add_argument('--lr', type=float, default=1.0, help='learning rate')
        parser.add_argument('--max_len', type=int, default=100, help='max length of sentence')
        parser.add_argument('--emb_dropout', type=float, default=0.3, help='the probability of dropout in embedding layer')
        parser.add_argument('--lstm_dropout', type=float, default=0.3, help='the probability of dropout in (Bi)LSTM layer')
        parser.add_argument('--linear_dropout', type=float, default=0.5, help='the probability of dropout in linear layer')
        parser.add_argument('--hidden_size', type=int, default=100, help='the dimension of hidden units in (Bi)LSTM layer')
        parser.add_argument('--layers_num', type=int, default=1, help='num of RNN layers')
        parser.add_argument('--L2_decay', type=float, default=1e-5, help='L2 weight decay')

        args = parser.parse_args()
        return args

    def __set_seed(self, seed=1234):
        os.environ['PYTHONHASHSEED'] = '{}'.format(seed)  # only affects processes started after this point
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)           # set seed for cpu
        torch.cuda.manual_seed(seed)      # set seed for current gpu
        torch.cuda.manual_seed_all(seed)  # set seed for all gpu

    def __config_backup(self, args):
        config_backup_path = os.path.join(self.model_dir, 'config.json')
        with open(config_backup_path, 'w', encoding='utf-8') as fw:
            json.dump(vars(args), fw, ensure_ascii=False)

    def print_config(self):
        for key in self.__dict__:
            print(key, end=' = ')
            print(self.__dict__[key])


if __name__ == '__main__':
    config = Config()
    config.print_config()
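The defaults point at a 100-dimensional GloVe file, while section 2 shows the 50-dimensional HLBL embeddings. Because WordEmbeddingLoader skips every line whose token count is not word_dim + 1, the path and --word_dim have to be changed together; with a mismatched pair every line is skipped and np.stack fails on the empty list. A sketch using a hypothetical, trimmed copy of the parser (the HLBL path below is an assumed location):

```python
import argparse


def build_parser():
    """Hypothetical, trimmed copy of Config.__get_config keeping only a few flags."""
    parser = argparse.ArgumentParser(description='config for models')
    parser.add_argument('--embedding_path', type=str, default='./embedding/glove.6B.100d.txt')
    parser.add_argument('--word_dim', type=int, default=100)
    parser.add_argument('--mode', type=int, default=1, choices=[0, 1])
    return parser


# Switch to the 50-d HLBL vectors (assumed path) and run in test mode.
args = build_parser().parse_args([
    '--embedding_path', './embedding/hlbl-embeddings-scaled.EMBEDDING_SIZE=50',
    '--word_dim', '50',
    '--mode', '0',
])
print(args.embedding_path, args.word_dim, args.mode)
```

From the shell, the same flags go straight to the script, e.g. python train_or_test.py --embedding_path <path to the HLBL file> --word_dim 50.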

4. model.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Att_BLSTM(nn.Module):
    def __init__(self, word_vec, class_num, config):
        super().__init__()
        self.word_vec = word_vec
        self.class_num = class_num

        # hyper parameters and others
        self.max_len = config.max_len
        self.word_dim = config.word_dim
        self.hidden_size = config.hidden_size
        self.layers_num = config.layers_num
        self.emb_dropout_value = config.emb_dropout
        self.lstm_dropout_value = config.lstm_dropout
        self.linear_dropout_value = config.linear_dropout

        # net structures and operations
        self.word_embedding = nn.Embedding.from_pretrained(
            embeddings=self.word_vec,
            freeze=False,
        )
        self.lstm = nn.LSTM(
            input_size=self.word_dim,
            hidden_size=self.hidden_size,
            num_layers=self.layers_num,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
        self.tanh = nn.Tanh()
        self.emb_dropout = nn.Dropout(self.emb_dropout_value)
        self.lstm_dropout = nn.Dropout(self.lstm_dropout_value)
        self.linear_dropout = nn.Dropout(self.linear_dropout_value)

        self.att_weight = nn.Parameter(torch.randn(1, self.hidden_size, 1))  # attention vector w
        self.dense = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.class_num,
            bias=True
        )

        # initialize weight
        init.xavier_normal_(self.dense.weight)
        init.constant_(self.dense.bias, 0.)

    def lstm_layer(self, x, mask):
        # pack_padded_sequence expects the lengths on the CPU when they are given as a tensor
        lengths = torch.sum(mask.gt(0), dim=-1).cpu()
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        h, (_, _) = self.lstm(x)
        h, _ = pad_packed_sequence(h, batch_first=True, padding_value=0.0, total_length=self.max_len)
        # each step is [forward | backward]; sum the two directions element-wise
        h = h.view(-1, self.max_len, 2, self.hidden_size)
        h = torch.sum(h, dim=2)  # B*L*H
        return h

    def attention_layer(self, h, mask):
        att_weight = self.att_weight.expand(mask.shape[0], -1, -1)  # B*H*1
        att_score = torch.bmm(self.tanh(h), att_weight)  # B*L*H * B*H*1 -> B*L*1

        # mask, remove the effect of 'PAD'
        mask = mask.unsqueeze(dim=-1)  # B*L*1
        att_score = att_score.masked_fill(mask.eq(0), float('-inf'))  # B*L*1
        att_weight = F.softmax(att_score, dim=1)  # B*L*1

        reps = torch.bmm(h.transpose(1, 2), att_weight).squeeze(dim=-1)  # B*H*L * B*L*1 -> B*H*1 -> B*H
        reps = self.tanh(reps)  # B*H
        return reps

    def forward(self, data):
        token = data[:, 0, :].view(-1, self.max_len)
        mask = data[:, 1, :].view(-1, self.max_len)
        emb = self.word_embedding(token)  # B*L*word_dim
        emb = self.emb_dropout(emb)
        h = self.lstm_layer(emb, mask)  # B*L*H
        h = self.lstm_dropout(h)
        reps = self.attention_layer(h, mask)  # B*H
        reps = self.linear_dropout(reps)
        logits = self.dense(reps)
        return logits
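Two steps in model.py are easy to check by hand: the reshape that sums the forward and backward halves of each BiLSTM output (PyTorch concatenates the two directions along the last dimension), and the masked softmax in attention_layer. A pure-Python restatement on tiny lists, not the author's code, with hypothetical function names:

```python
import math


def sum_directions(outputs, hidden):
    """Each step's BiLSTM output is [forward | backward] of length 2*hidden;
    return their element-wise sum (what h.view(-1, L, 2, H).sum(dim=2) does per step)."""
    return [[f + b for f, b in zip(step[:hidden], step[hidden:])] for step in outputs]


def masked_attention(h, mask, w):
    """h: T vectors of length H, mask: T flags (1 = word, 0 = PAD), w: attention vector of length H.
    Returns (tanh(sum_t alpha_t * h_t), alpha). Needs at least one non-PAD step."""
    scores = [sum(wi * math.tanh(x) for wi, x in zip(w, ht)) if m else float('-inf')
              for ht, m in zip(h, mask)]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # exp(-inf) is 0.0, so PAD steps get no weight
    total = sum(exps)
    alpha = [e / total for e in exps]
    r = [sum(a * ht[j] for a, ht in zip(alpha, h)) for j in range(len(w))]
    return [math.tanh(x) for x in r], alpha


h = sum_directions([[1.0, 0.0, 0.0, 0.0], [0.0, 0.5, 0.0, 0.5], [5.0, 5.0, 5.0, 5.0]], hidden=2)
reps, alpha = masked_attention(h, mask=[1, 1, 0], w=[1.0, 1.0])
print(h, alpha, reps)
```

The third step is padding, so its large values never reach the sentence vector: the two real steps tie and share the weight equally.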

5. train_or_test.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import os

import torch
import torch.nn as nn
import torch.optim as optim

from config import Config
from utils import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader
from model import Att_BLSTM
from evaluate import Eval


def print_result(predict_label, id2rel, start_idx=8001):
    # one '<sample id>\t<relation>' line per test sample
    with open('predicted_result.txt', 'w', encoding='utf-8') as fw:
        for i in range(0, predict_label.shape[0]):
            fw.write('{}\t{}\n'.format(start_idx + i, id2rel[int(predict_label[i])]))


def train(model, criterion, loader, config):
    train_loader, dev_loader, _ = loader
    optimizer = optim.Adadelta(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)

    print(model)
    print('training model parameters:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print('%s : %s' % (name, str(param.data.shape)))
    print('--------------------------------------')
    print('start to train the model ...')

    eval_tool = Eval(config)
    best_f1 = -float('inf')
    for epoch in range(1, config.epoch + 1):
        for step, (data, label) in enumerate(train_loader):
            model.train()
            data = data.to(config.device)
            label = label.to(config.device)

            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            nn.utils.clip_grad_value_(model.parameters(), clip_value=5)
            optimizer.step()

        _, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
        f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)

        # semeval_scorer returns a macro-averaged F1
        print('[%03d] train_loss: %.3f | dev_loss: %.3f | macro f1 on dev: %.4f'
              % (epoch, train_loss, dev_loss, f1), end=' ')
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))
            print('>>> save models!')
        else:
            print()


def test(model, criterion, loader, config):
    print('--------------------------------------')
    print('start test ...')

    _, _, test_loader = loader
    # map_location lets a GPU-trained checkpoint be loaded on a CPU-only machine
    model.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl'), map_location=config.device))
    eval_tool = Eval(config)
    f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)
    print('test_loss: %.3f | macro f1 on test: %.4f' % (test_loss, f1))
    return predict_label


if __name__ == '__main__':
    config = Config()
    print('--------------------------------------')
    print('some config:')
    config.print_config()

    print('--------------------------------------')
    print('start to load data ...')
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)

    train_loader, dev_loader = None, None
    if config.mode == 1:  # train mode
        train_loader = loader.get_train()
        dev_loader = loader.get_dev()
    test_loader = loader.get_test()
    loader = [train_loader, dev_loader, test_loader]
    print('finish!')

    print('--------------------------------------')
    model = Att_BLSTM(word_vec=word_vec, class_num=class_num, config=config)
    model = model.to(config.device)
    criterion = nn.CrossEntropyLoss()

    if config.mode == 1:  # train mode
        train(model, criterion, loader, config)
    predict_label = test(model, criterion, loader, config)
    print_result(predict_label, id2rel)
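train() uses torch.optim.Adadelta with lr=1.0 and weight_decay=L2_decay. Adadelta scales each gradient by the ratio of a running RMS of past updates to a running RMS of gradients, so lr=1.0 (also PyTorch's default for Adadelta) only scales that ratio; the first steps are tiny, on the order of sqrt(eps), and grow as the update average builds up. A scalar sketch of one step, assuming PyTorch's defaults rho=0.9 and eps=1e-6 since train() does not set them; adadelta_step is a hypothetical helper, not the torch API:

```python
import math


def adadelta_step(param, grad, state, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0):
    """One Adadelta update for a scalar parameter, in the order torch.optim.Adadelta applies it.
    `state` holds the running averages of grad**2 ('square_avg') and update**2 ('acc_delta')."""
    if weight_decay:
        grad = grad + weight_decay * param  # the L2 term that weight_decay=config.L2_decay adds
    state['square_avg'] = rho * state['square_avg'] + (1 - rho) * grad * grad
    std = math.sqrt(state['square_avg'] + eps)
    delta = math.sqrt(state['acc_delta'] + eps) / std * grad
    state['acc_delta'] = rho * state['acc_delta'] + (1 - rho) * delta * delta
    return param - lr * delta


# Minimise f(x) = x**2 (gradient 2x) from x = 1.0 for a few steps.
state = {'square_avg': 0.0, 'acc_delta': 0.0}
x = 1.0
for _ in range(5):
    x = adadelta_step(x, 2 * x, state)
print(x)
```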

6. evaluate.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import math

import numpy as np
import torch


def semeval_scorer(predict_label, true_label, class_num=10):
    # class_num=10: 'Other' plus 9 undirected relation types; ceil(id / 2) folds the
    # (e1,e2)/(e2,e1) ids of each type onto one index (see relation2id)
    assert true_label.shape[0] == predict_label.shape[0]
    confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)  # [predicted][gold]
    xDIRx = np.zeros(shape=[class_num], dtype=np.float32)  # right type, wrong direction
    for i in range(true_label.shape[0]):
        true_idx = math.ceil(true_label[i] / 2)
        predict_idx = math.ceil(predict_label[i] / 2)
        if true_label[i] == predict_label[i]:
            confusion_matrix[predict_idx][true_idx] += 1
        else:
            if true_idx == predict_idx:
                xDIRx[predict_idx] += 1
            else:
                confusion_matrix[predict_idx][true_idx] += 1

    col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)  # per gold type
    row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)  # per predicted type
    f1 = np.zeros(shape=[class_num], dtype=np.float32)

    for i in range(0, class_num):  # 'Other' (index 0) is left out of the average below
        try:
            r = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])  # recall
            p = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])  # precision
            f1[i] = (2 * p * r / (p + r))
        except ZeroDivisionError:
            pass

    actual_class = 0
    total_f1 = 0.0
    for i in range(1, class_num):
        if f1[i] > 0.0:  # types with F1 = 0 are left out of the average
            actual_class += 1
            total_f1 += f1[i]

    macro_f1 = total_f1 / actual_class if actual_class > 0 else 0.0
    return macro_f1


class Eval(object):
    def __init__(self, config):
        self.device = config.device

    def evaluate(self, model, criterion, data_loader):
        predict_label = []
        true_label = []
        total_loss = 0.0
        with torch.no_grad():
            model.eval()
            for _, (data, label) in enumerate(data_loader):
                data = data.to(self.device)
                label = label.to(self.device)

                logits = model(data)
                loss = criterion(logits, label)
                total_loss += loss.item() * logits.shape[0]

                _, pred = torch.max(logits, dim=1)  # argmax of the logits equals argmax of the softmax
                pred = pred.cpu().detach().numpy().reshape((-1, 1))
                label = label.cpu().detach().numpy().reshape((-1, 1))
                predict_label.append(pred)
                true_label.append(label)
        predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)
        true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)
        eval_loss = total_loss / predict_label.shape[0]

        f1 = semeval_scorer(predict_label, true_label)
        return f1, eval_loss, predict_label
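semeval_scorer is direction-aware. A prediction with the right relation type but the wrong direction counts as a miss, yet that sample still counts toward both the gold and the predicted totals of its type (the xDIRx vector). Below is a compact pure-Python restatement with the same counting rules (semeval_macro_f1 is a hypothetical name) on a toy example. Type 1 gets P = R = 0.5; type 2 gets P = 0.5 and R = 1, so F1 = 2/3. The macro F1 is therefore (0.5 + 2/3) / 2 = 7/12. Because types with zero F1 are skipped rather than averaged in, numbers meant for comparison with published results are safer to compute with the official Perl scorer distributed with the task.

```python
import math


def semeval_macro_f1(pred, true, class_num=10):
    """Direction-aware macro F1 over the 9 relation types (Other excluded), mirroring
    semeval_scorer: the right type with the wrong direction counts as a miss for that type."""
    tp = [0] * class_num      # exact (type + direction) matches per undirected type
    n_true = [0] * class_num  # gold samples per undirected type
    n_pred = [0] * class_num  # predicted samples per undirected type
    for p, t in zip(pred, true):
        cp, ct = math.ceil(p / 2), math.ceil(t / 2)
        n_pred[cp] += 1
        n_true[ct] += 1
        if p == t:
            tp[cp] += 1
    f1s = []
    for i in range(1, class_num):  # skip 0 = Other
        if tp[i] == 0:
            continue  # F1 would be 0; semeval_scorer leaves such types out of the average
        precision, recall = tp[i] / n_pred[i], tp[i] / n_true[i]
        f1s.append(2 * precision * recall / (precision + recall))
    return sum(f1s) / len(f1s) if f1s else 0.0


pred = [1, 1, 3, 3]
true = [1, 2, 3, 0]
print(semeval_macro_f1(pred, true))
```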

7. utils.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import os
import json

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class WordEmbeddingLoader(object):
    """
    A loader for pre-trained word embedding
    """

    def __init__(self, config):
        self.path_word = config.embedding_path  # path of pre-trained word embedding
        self.word_dim = config.word_dim  # dimension of word embedding

    def load_embedding(self):
        word2id = dict()  # word to wordID
        word_vec = list()  # wordID to word embedding

        word2id['PAD'] = len(word2id)  # PAD character
        word2id['UNK'] = len(word2id)  # out of vocabulary
        word2id['<e1>'] = len(word2id)
        word2id['<e2>'] = len(word2id)
        word2id['</e1>'] = len(word2id)
        word2id['</e2>'] = len(word2id)

        with open(self.path_word, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip().split()
                # skip lines of the wrong dimension and repeated words (keeps ids aligned with rows)
                if len(line) != self.word_dim + 1 or line[0] in word2id:
                    continue
                word2id[line[0]] = len(word2id)
                word_vec.append(np.asarray(line[1:], dtype=np.float32))

        word_vec = np.stack(word_vec)
        vec_mean, vec_std = word_vec.mean(), word_vec.std()
        special_emb = np.random.normal(vec_mean, vec_std, (6, self.word_dim))
        special_emb[0] = 0  # <pad> is initialized as zero

        word_vec = np.concatenate((special_emb, word_vec), axis=0)
        word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
        word_vec = torch.from_numpy(word_vec)
        return word2id, word_vec


class RelationLoader(object):
    def __init__(self, config):
        self.data_dir = config.data_dir

    def __load_relation(self):
        relation_file = os.path.join(self.data_dir, 'relation2id.txt')
        rel2id = {}
        id2rel = {}
        with open(relation_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                relation, id_s = line.strip().split()
                id_d = int(id_s)
                rel2id[relation] = id_d
                id2rel[id_d] = relation
        return rel2id, id2rel, len(rel2id)

    def get_relation(self):
        return self.__load_relation()


class SemEvalDataset(Dataset):
    def __init__(self, filename, rel2id, word2id, config):
        self.filename = filename
        self.rel2id = rel2id
        self.word2id = word2id
        self.max_len = config.max_len
        self.data_dir = config.data_dir
        self.dataset, self.label = self.__load_data()

    def __symbolize_sentence(self, sentence):
        """
            Args:
                sentence (list)
        """
        mask = [1] * len(sentence)
        words = []
        length = min(self.max_len, len(sentence))
        mask = mask[:length]
        for i in range(length):
            words.append(self.word2id.get(sentence[i].lower(), self.word2id['UNK']))

        if length < self.max_len:
            for i in range(length, self.max_len):
                mask.append(0)  # 'PAD' mask is zero
                words.append(self.word2id['PAD'])

        unit = np.asarray([words, mask], dtype=np.int64)
        unit = unit.reshape(1, 2, self.max_len)
        return unit

    def __load_data(self):
        path_data_file = os.path.join(self.data_dir, self.filename)
        data = []
        labels = []
        with open(path_data_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                label = line['relation']
                sentence = line['sentence']
                label_idx = self.rel2id[label]

                one_sentence = self.__symbolize_sentence(sentence)
                data.append(one_sentence)
                labels.append(label_idx)
        return data, labels

    def __getitem__(self, index):
        data = self.dataset[index]
        label = self.label[index]
        return data, label

    def __len__(self):
        return len(self.label)


class SemEvalDataLoader(object):
    def __init__(self, rel2id, word2id, config):
        self.rel2id = rel2id
        self.word2id = word2id
        self.config = config

    def __collate_fn(self, batch):
        data, label = zip(*batch)  # unzip the batch data
        data = list(data)
        label = list(label)
        data = torch.from_numpy(np.concatenate(data, axis=0))  # B*2*L
        label = torch.from_numpy(np.asarray(label, dtype=np.int64))
        return data, label

    def __get_data(self, filename, shuffle=False):
        dataset = SemEvalDataset(filename, self.rel2id, self.word2id, self.config)
        loader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=2,
            collate_fn=self.__collate_fn
        )
        return loader

    def get_train(self):
        return self.__get_data('train.json', shuffle=True)

    def get_dev(self):
        # note: there is no separate dev split; the test set doubles as the dev set,
        # so checkpoint selection in train() also sees the test data
        return self.__get_data('test.json', shuffle=False)

    def get_test(self):
        return self.__get_data('test.json', shuffle=False)


if __name__ == '__main__':
    from config import Config
    config = Config()
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)
    test_loader = loader.get_train()

    for step, (data, label) in enumerate(test_loader):
        print(type(data), data.shape)
        print(type(label), label.shape)
        break
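Each sample is stored as two rows of length max_len: token ids (lower-cased lookup with UNK fallback, padded with PAD) and a 0/1 mask that both pack_padded_sequence and the attention masking read. A pure-Python sketch of __symbolize_sentence (symbolize and the toy vocabulary are hypothetical). Note that truncation at max_len can cut off an entity marker, as the second call shows with a very small max_len:

```python
def symbolize(tokens, word2id, max_len):
    """Return [ids, mask], both of length max_len: ids are lower-cased lookups with UNK
    fallback, truncated or padded with PAD; mask is 1 for real tokens and 0 for padding."""
    length = min(max_len, len(tokens))
    ids = [word2id.get(tok.lower(), word2id['UNK']) for tok in tokens[:length]]
    pad = max_len - length
    return [ids + [word2id['PAD']] * pad, [1] * length + [0] * pad]


vocab = {'PAD': 0, 'UNK': 1, '<e1>': 2, '<e2>': 3, '</e1>': 4, '</e2>': 5, 'the': 6, 'child': 7}
tokens = ['The', '<e1>', 'child', '</e1>', 'slept']
print(symbolize(tokens, vocab, max_len=7))
print(symbolize(tokens, vocab, max_len=3))  # '</e1>' is lost to truncation
```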
