2000字范文,分享全网优秀范文,学习好帮手!
2000字范文 > 深度学习模型保存_Web服务部署深度学习模型

深度学习模型保存_Web服务部署深度学习模型

时间:2020-03-19 03:25:48

相关推荐

深度学习模型保存_Web服务部署深度学习模型

本文的目的是介绍如何使用Web服务快速部署深度学习模型,虽然TF有TFserving可以进行模型部署,但是对于Pytorch无能为力(如果要使用的话需要把torch模型进行转换,有些麻烦);因此,本文在这里介绍一种使用Web服务部署深度学习的方法(简单有效,不喜勿喷)。

本文以简单的新闻分类模型来举例,模型:BERT;数据来源:清华新闻语料(地址:

THUCTC: 一个高效的中文文本分类工具),清华新闻语料共有14个类别,分别是体育,娱乐,家居,彩票,房产,教育,时尚,时政,星座,游戏,社会,科技,股票和财经。为了快速训练模型,本人在每个类别中分别随机挑选1000个作为训练集,200个作为验证集。数据预处理、模型训练和pb模型保存代码见:新闻分类模型训练github地址。(非重点,不过多介绍了,github上有详细的使用说明,有问题可留言。)

为了使web服务部署变得简洁,因此本人构造一个方法类,方便加载pb模型,对传入文本进行数据预处理以及进行模型预测。

模型初始化代码如下:

import bert_tokenizationimport tensorflow as tffrom tensorflow.python.platform import gfileimport numpy as npimport osclass ClassificationModel(object):def __init__(self):self.tokenizer = Noneself.sess = Noneself.is_train = Noneself.input_ids = Noneself.input_mask = Noneself.segment_ids = Noneself.predictions = Noneself.max_seq_length = Noneself.label_dict = ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']

其中,tokenizer 为分词器;sess为TF的session模块;is_train、input_ids、input_mask和segment_ids分别是pb模型的输入;predictions为pb模型的输出;max_seq_length为模型的最大输入长度;label_dict为新闻分类标签。

加载pb模型代码如下:

def load_model(self, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'os.environ['CUDA_VISIBLE_DEVICES'] = gpu_idself.tokenizer = bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)sess_config = tf.ConfigProto(gpu_options=gpu_options)self.sess = tf.Session(config=sess_config)with gfile.FastGFile(model_path, "rb") as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())self.sess.graph.as_default()tf.import_graph_def(graph_def, name="")self.sess.run(tf.global_variables_initializer())self.is_train = self.sess.graph.get_tensor_by_name("input/is_train:0")self.input_ids = self.sess.graph.get_tensor_by_name("input/input_ids:0")self.input_mask = self.sess.graph.get_tensor_by_name("input/input_mask:0")self.segment_ids = self.sess.graph.get_tensor_by_name("input/segment_ids:0")self.predictions = self.sess.graph.get_tensor_by_name("output_layer/predictions:0")self.max_seq_length = max_seq_length

其中,gpu_id为使用GPU的序号;vocab_file为BERT模型所使用的字典路径;gpu_memory_fraction为使用GPU时所占用的比例;model_path为pb模型的路径;max_seq_length为BERT模型的最大长度。

将传入文本转化成模型所需格式代码如下:

def convert_fearture(self, text):max_seq_length = self.max_seq_lengthmax_length_context = max_seq_length - 2content_token = self.tokenizer.tokenize(text)if len(content_token) > max_length_context:content_token = content_token[:max_length_context]tokens = []segment_ids = []tokens.append("[CLS]")segment_ids.append(0)for token in content_token:tokens.append(token)segment_ids.append(0)tokens.append("[SEP]")segment_ids.append(0)input_ids = self.tokenizer.convert_tokens_to_ids(tokens)input_mask = [1] * len(input_ids)while len(input_ids) < max_seq_length:input_ids.append(0)input_mask.append(0)segment_ids.append(0)assert len(input_ids) == max_seq_lengthassert len(input_mask) == max_seq_lengthassert len(segment_ids) == max_seq_lengthinput_ids = np.array(input_ids)input_mask = np.array(input_mask)segment_ids = np.array(segment_ids)return input_ids, input_mask, segment_ids

预测代码如下:

def predict(self, text):input_ids_temp, input_mask_temp, segment_ids_temp = self.convert_fearture(text)feed = {self.is_train: False,self.input_ids: input_ids_temp.reshape(1, self.max_seq_length),self.input_mask: input_mask_temp.reshape(1, self.max_seq_length),self.segment_ids: segment_ids_temp.reshape(1, self.max_seq_length)}[label] = self.sess.run([self.predictions], feed)label_name = self.label_dict[label[0]]return label[0], label_name

其中,输入是一个新闻文本,输出为类别序号以及对应的标签名称。详细完整代码见github:

ClassificationModel.py文件。

(划重点)上面介绍的都是如何方便简洁地加载模型,下面开始使用web服务挂起模型。通俗地讲,其实本人就是通过flask框架,搭建了一个web服务,来获取外部的输入;并且使用挂载的模型进行预测;最后将预测结果通过web服务传出。

from gevent import monkeymonkey.patch_all()from flask import Flask, requestfrom gevent import wsgiimport jsonfrom ClassificationModel import ClassificationModeldef start_sever(http_id, port, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):model = ClassificationModel()model.load_model(gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length)print("load model ending!")app = Flask(__name__)@app.route('/')def index():return "This is News Classification Model Server"@app.route('/news-classification', methods=['Get', 'POST'])def response_request():if request.method == 'POST':text = request.form.get('text')else:text = request.args.get('text')label, label_name = model.predict(text)d = {"label": str(label), "label_name": label_name}print(d)return json.dumps(d, ensure_ascii=False)server = wsgi.WSGIServer((str(http_id), port), app)server.serve_forever()

其中,http_id为web服务的地址;port为端口号;gpu_id、vocab_file、gpu_memory_fraction、model_path和max_seq_length为上面介绍的加载模型所需要的参数,详细见上文。

index函数用于检验web服务是否畅通。如图1所示。

图1

response_request函数为响应函数。定义了两种请求数据的方式,get和post。当使用get方法获取web输入时,获取命令为request.args.get('text');当使用post方法获取web输入时,获取命令为request.form.get('text')。

当web服务起起来之后,就可以调用啦!!!

浏览器调用如图2所示。

图2

Code调用如下:

import requestsdef http_test(text):url = 'http://127.0.0.1:5555/news-classification'raw_data = {'text': text}res = requests.post(url, raw_data)result = res.json()return resultif __name__ == "__main__":text = "姚明在NBA打球,很强。"result = http_test(text)print(result["label_name"])

以上就是通过web服务部署深度学习模型的全部内容,喜欢的同学还请多多点赞~~~~~

推荐几篇本人之前写的一些文章:

刘聪NLP:短文本相似度算法研究

刘聪NLP:阅读笔记:开放域检索问答(ORQA)

刘聪NLP:论文阅读笔记:文本蕴含之BiMPM

喜欢的同学,可以关注一下专栏,关注一下作者,还请多多点赞~~~~~~

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。