utils_ner.py
import os, pickle
import torch
from torch.utils.data import TensorDataset
wrap = ['B-', 'I-','E-','S-']
wgseg_tag = ['w']
label2id_dict = {'O':0}
#pos_tags = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
id2space = {0:'', 1:'', 2:'', 3:' ', 4:' '}
for i, tag in enumerate(wgseg_tag):
    for j, w in enumerate(wrap):
        one_tag = w + tag
        label2id_dict[one_tag] = len(label2id_dict)
id2label_dict = {v: k for k, v in label2id_dict.items()}
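# With the settings above, the label maps work out to a BIES word-segmentation
# tag set plus 'O':
#   label2id_dict = {'O': 0, 'B-w': 1, 'I-w': 2, 'E-w': 3, 'S-w': 4}
#   id2label_dict = {0: 'O', 1: 'B-w', 2: 'I-w', 3: 'E-w', 4: 'S-w'}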
def half2full(s):
    # Convert half-width (ASCII) characters in s to their full-width forms.
    n = []
    for char in s:
        num = ord(char)
        if num == 32:
            # half-width space -> ideographic (full-width) space
            num = 0x3000
        elif 0x21 <= num <= 0x7E:
            num += 0xfee0
        num = chr(num)
        n.append(num)
    return ''.join(n)
def full2half(strs):
    n = []
    for char in strs:
        num = ord(char)
        if num == 0x3000:
            num = 32
        elif 0xFF01 <= num <= 0xFF5E:
            num -= 0xfee0
        num = chr(num)
        n.append(num)
    return ''.join(n)
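# Worked example of the width conversions above:
#   full2half('Ａｂｃ　１２３') -> 'Abc 123'   (full-width letters/digits, ideographic space U+3000)
#   half2full('Abc 123') -> 'Ａｂｃ　１２３'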
class Tokenize():
    def __init__(self, dict_path):
        self.dict_path = dict_path
        self.token2id = {}
        self.id2token = {}
        self.unkid = 100  # default unknown-token id (100 is [UNK] in standard BERT vocab files)
        self.dict_size = 0
        self.load_dict()

    def load_dict(self,):
        for i, line in enumerate(open(self.dict_path, 'r', encoding='utf-8')):
            token = line.strip('\n')
            if token.rfind('##') == 0:
                # skip WordPiece continuation entries such as '##ing'
                continue
            self.id2token[i] = token
            self.token2id[token] = i
        self.dict_size = len(self.token2id)

    def sentence2id(self, line):
        ids = [self.token2id.get(c, self.unkid) for c in line]
        return ids

    def id2sentence(self, ids):
        line = [self.id2token.get(c) for c in ids]
        return ''.join(line)

    def convert_tokens_to_ids(self, line):
        return self.sentence2id(line)
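# Tokenize expects a BERT-style vocabulary file: one token per line, with the
# line number as the token id; WordPiece continuation entries ('##xx') are
# skipped so only whole tokens/characters are indexed.  A minimal usage sketch,
# where 'vocab.txt' is an assumed path rather than a file shipped here:
#   tok = Tokenize('vocab.txt')
#   ids = tok.sentence2id('abc')   # characters missing from the vocab map to tok.unkid
#   text = tok.id2sentence(ids)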
class Examples:
    def __init__(self, tokens, label_ids):
        self.tokens = tokens
        self.label_ids = label_ids

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += f"tokens: {' '.join(self.tokens)}\n"
        s += f"labels: {' '.join(str(i) for i in self.label_ids)}\n"
        return s
class Featues:
    def __init__(self, token_ids, input_mask, label_ids):
        self.token_ids = token_ids
        self.input_mask = input_mask
        self.label_ids = label_ids

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += f"token_ids: {' '.join(str(i) for i in self.token_ids)}\n"
        s += f"label_ids: {' '.join(str(i) for i in self.label_ids)}\n"
        s += f"input_mask: {' '.join(str(i) for i in self.input_mask)}\n"
        return s
def read_examples(input_file, mode=0):
    examples = []
    tokens = []
    label_ids = []
    errors = 0
    if mode == 0:
        # training/eval mode: tab-separated lines, a blank line ends a sentence
        # (note: assumes the file ends with a blank line so the last sentence is flushed)
        with open(input_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if len(line.strip()) == 0:
                    if len(tokens) > 0:
                        examples.append(Examples(tokens, label_ids))
                        tokens = []
                        label_ids = []
                    continue
                try:
                    tup = line.strip().split('\t')
                    token, seg_tag, pos_tag, entity_tag = tup
                except ValueError:
                    errors += 1
                    print("Num errors: ", errors)
                    continue
                tokens.append(full2half(token.lower()))
                label_ids.append(label2id_dict[seg_tag])
    elif mode == 1:
        # inference mode: one raw sentence per line, all labels default to 'O'
        with open(input_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                tokens = list(full2half(line.lower().strip().replace(' ', '')))
                label_ids = [0] * len(tokens)
                examples.append(Examples(tokens, label_ids))
    return examples
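# Input formats implied by the parsing above:
#   mode 0 (train/eval): one token per line with four tab-separated fields,
#       token \t seg_tag \t pos_tag \t entity_tag
#     and a blank line between sentences; only seg_tag is used as the label,
#     pos_tag and entity_tag are read but ignored here.
#   mode 1 (inference): one raw sentence per line; labels default to 'O'.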
def convert_example_to_features(examples, tokenizer, max_seq_length,
                                cls_token='[CLS]', sep_token='[SEP]', pad_token_id=0, txtmode=0):
    features = []
    PAD = '[PAD]'
    overlap = 8
    pad_label = [label2id_dict['O']]
    stride = max_seq_length - overlap * 2
    tokens = []
    token_ids = []
    label_ids = []
    examples_rsp = []
    ## reshape examples: concatenate all sentences into one long sequence,
    ## then cut it into overlapping windows of max_seq_length.
    # head pad
    tokens.extend([PAD] * overlap)
    token_ids.extend([0] * overlap)
    label_ids.extend([0] * overlap)
    for example in examples:
        stokens = [cls_token] + example.tokens + [sep_token]
        stoken_ids = tokenizer.convert_tokens_to_ids(stokens)
        slabel_ids = pad_label + example.label_ids + pad_label
        tokens.extend(stokens)
        token_ids.extend(stoken_ids)
        label_ids.extend(slabel_ids)
    # tail pad
    tokens.extend([PAD] * overlap)
    token_ids.extend([0] * overlap)
    label_ids.extend([0] * overlap)
    assert len(token_ids) == len(label_ids)
    length = len(token_ids)
    start = 0
    # eager pad: cut a window, then move its end back onto an E/S/O label
    # so a segmented word is never split across two windows
    while start < length:
        end = start + max_seq_length
        while end < length and id2label_dict[label_ids[end - 1]][0] not in 'ESO':
            end -= 1
        toks = tokens[start:end]
        ids = token_ids[start:end]
        labs = label_ids[start:end]
        masks = [1] * len(ids)
        pad_len = max_seq_length - len(ids)
        toks = toks + [PAD] * pad_len
        ids = ids + [pad_token_id] * pad_len
        masks = masks + [0] * pad_len
        labs = labs + pad_label * pad_len
        examples_rsp.append(Examples(toks, labs))
        features.append(Featues(token_ids=ids, input_mask=masks, label_ids=labs))
        if start == 0:
            print(f'examples 0 show:\n {examples_rsp[0].__str__()}')
        start += stride
        # move the next window start forward onto a B/S/O label
        while start < length and id2label_dict[label_ids[start]][0] not in 'BSO':
            start += 1
    return examples_rsp, features
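# Window arithmetic for the defaults used below: with max_seq_length=160 and
# overlap=8, stride = 160 - 2*8 = 144, so consecutive windows share up to 16
# positions; the boundary-adjustment loops above then keep window edges on
# E/S/O (end) and B/S/O (start) labels so no segmented word is cut in half.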
def read_features(input_file, max_seq_length=160, tokenizer=None, cls_token='[CLS]', sep_token='[SEP]', pad_token_id=0, txtmode=0, dump=True):
    cached_features_file = input_file + f'_{max_seq_length}.bin'
    if os.path.exists(cached_features_file) and dump is True:
        with open(cached_features_file, 'rb') as f:
            examples, features = pickle.load(f)
    else:
        raw_examples = read_examples(input_file, txtmode)
        examples, features = convert_example_to_features(raw_examples, tokenizer, max_seq_length, cls_token, sep_token, pad_token_id, txtmode)
        if dump is True:
            with open(cached_features_file, 'wb') as f:
                pickle.dump([examples, features], f)
    all_token_ids = torch.tensor([f.token_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
    dataset = TensorDataset(all_token_ids, all_input_mask, all_label_ids)
    return examples, dataset
if __name__ == '__main__':
    pass
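    # Minimal usage sketch (illustrative only): 'vocab.txt' and 'train.txt'
    # below are assumed paths, not files shipped with this module.
    if os.path.exists('vocab.txt') and os.path.exists('train.txt'):
        tokenizer = Tokenize('vocab.txt')
        examples, dataset = read_features('train.txt', max_seq_length=160,
                                          tokenizer=tokenizer, txtmode=0)
        print(examples[0])
        print('num windows:', len(dataset))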