2000字范文,分享全网优秀范文,学习好帮手!
2000字范文 > Keras花卉分类全流程(预处理+训练+预测)

Keras花卉分类全流程(预处理+训练+预测)

时间:2020-12-31 09:12:45

相关推荐

Keras花卉分类全流程(预处理+训练+预测)

本文的代码包括以下内容的示例:

1.用一个类封装自己的模型和训练、预测等过程

2.使用图片生成器(ImageDataGenerator)进行数据预处理,这一功能是Keras很 方面的地方,省去了自己进行数据处理的过程

3.在使用图片生成器的基础上进行训练的过程(fit_generator用法)

4.如何使用图片生成器进行预测和精度验证,这一部分包括predict_generator 的用法,以及相关的标签提取的过程,其中的一些细节。这一部分内容,在官方文档亦是没有提及。(我是没找到)

提醒

对于keras初学者来说,可能难以看懂,但是只要坚持看完代码,弄清楚数据生成器的使用,对使用keras进行分类的问题将是巨大帮助。这个东西我也是写加改弄了三四天,感觉帮助很大。

模型

本文采用的是简化版的VGG的A型的网络

优化方法:

本文采用的是Adamdelta

数据集

是Google的一个花卉数据集

共分为五类:

daisy,dandelion,roses,sunflower,tulips

我已手动分好了数据

当然比例有点问题

不过就是个小实验嘛

链接:/s/1ktu-6GOWnSYjuzHyxFeL7Q

提取码:bask

在给出代码之前,我想总结的几点经验:

1.模型难以收敛可能与图片尺寸有关系,较小的图片比较容易收敛,模型也相对容易训练。较大的图片使得模型的参数呈几何倍数增加,训练难度加大。

2.我上一篇博客写的是关于数据预处理的,是将原始图片分成了训练集、验证集、和测试集,存成了numpy数组的形式,进行保存。这样子不甚方便。

keras数据生成器的.flow_from_directroy()方法,直接生成各式数据。不需要进行人工的预处理。

数据需要组织成以下形式:

主文件夹下属train、validation、test三个文件夹,每个文件夹内下属多个类别名称文件夹,每个类别名称文件夹下下属该类别的图片。

3.全连接层与最后一个卷积层之间相连的时候,所需的显存极大,参数极多,容易造成资源枯竭的错误。尤其是当图片的尺寸较大的时候,更容易发生错误。

我的配置:

CPU:i7-8750H

内存:8GB

显卡:GTX1060 6GB显存

这样的配置,在一个7x7x512*4096的大张量下直接不行。最后进行了调整,适应了我的电脑配置。

通常出现: OOM when allocating tensor 就是这个地方有问题

4.loss不下降很有可能是你训练的轮次不够多。

因为从mnist的一下子训练完,到现在你训练一个比较大的网络,所需的时间是不一样的。需要转过弯来。尤其是图片尺寸和网络深度都相比较较大的时候,训练绝不是能够快速的就拟合的,而是需要一个缓慢的变化过程。

5.loss为nan有可能是你的数据没有成功的喂入网络中。尤其是一上来loss和acc就为nan的情况。

6.遇到问题,百度不到的,可以试试Google,找到答案的几率更大。

代码部分:

下面是代码:

##coding:utf-8#from __future__ import print_functionimport kerasfrom keras.models import Sequentialimport numpy as npfrom keras.layers import Dense,Conv2D,Dropout,BatchNormalizationfrom keras.layers import Activation,MaxPooling2D,Flattenfrom keras.preprocessing.image import ImageDataGeneratorimport matplotlib.pyplot as pltfrom PIL import Imagefrom keras.utils.vis_utils import plot_modelimport keras.backend as K'''Use 11 weight layers vursion to reduceparameters to train fastThe dataset is small,only 3000 pictures'''class SampleVGGForFlower:'''This class is for flower dataset of Google'''def __init__(self,train_path,input_shape=(100,100,3),imagesize=(100,100),validation_path=None,test_path=None,train=True,show=True,plot=False):'''模型的传入图片的维度和尺寸都可以另行改变因为不同的尺寸训练的难度也很不同show:是否展示图片的训练过程的精度和损失值变化plot:是否保存权重文件train:是进行训练还是不进行训练只进行预测'''self.num_classes = 5self.batch_size = 16self.num_epoch = 200self.input_shape = input_shapeself.imagesize = imagesizeself.learningrate = 1self.trainpath = train_pathself.validationpath = validation_pathself.testpath = test_pathself.show = showself.plot = plotself.model = self.build_model(plot)if train:self.model = self.train(self.model,show)else:self.model.load_weights('flowervgg.h5')def build_model(self,plot):'''Build the model with 11 layers.This is a sample vursion of VGG netBecause of my device,some parameters arechanged.我设备不行,必须去掉4096长的全连接层采取贯序模型'''model = Sequential()model.add(Conv2D(64,(3,3),padding='same',input_shape=self.input_shape))model.add(Activation('relu'))model.add(BatchNormalization())model.add(MaxPooling2D())model.add(Conv2D(128,(3,3),padding='same'))model.add(Activation('relu'))model.add(BatchNormalization())model.add(MaxPooling2D())model.add(Conv2D(256,(3,3),padding='same'))model.add(Activation('relu'))model.add(BatchNormalization())model.add(Dropout(0.3))model.add(Conv2D(256,(3,3),padding='same'))model.add(Activation('relu'))model.add(BatchNormalization())model.add(MaxPooling2D())model.add(Conv2D(512,(3,3),padding='same'))model.add(Activation('relu'))model.add(BatchNormalization())model.add(Dropout(0.3))model.add(Conv2D(512,(3,3),padding='same'))model.add(Activation('relu'))model.add(BatchNormalization())model.add(MaxPooling2D())model.add(Flatten())model.add(Dense(1000))model.add(Activation('relu'))model.add(BatchNormalization())model.add(Dropout(0.3))model.add(Dense(self.num_classes))model.add(Activation('softmax'))#有人说keras里softmax和交叉熵不能一起用#无稽之谈if plot:plot_model(model,'model.png',show_shapes=True,show_layer_names=True)return modeldef train(self,model,show=True):'''In this function,model will be trained.And if show was set,a image of the modelwill be stored at the directory.训练网络'''train_datagen = ImageDataGenerator(rotation_range=15,width_shift_range=0.1,height_shift_range=0.1,horizontal_flip=True,vertical_flip=False)validation_datagen = ImageDataGenerator()train_generator = train_datagen.flow_from_directory(directory=self.trainpath,target_size=self.imagesize,color_mode='rgb',classes=['daisy','dandelion','roses','sunflowers','tulips'],class_mode='categorical',batch_size=self.batch_size,shuffle=True)validation_generator = validation_datagen.flow_from_directory(directory=self.validationpath,target_size=self.imagesize,classes=['daisy','dandelion','roses','sunflowers','tulips'],color_mode='rgb',class_mode='categorical',batch_size=self.batch_size,shuffle=True)opt = keras.optimizers.adadelta(lr = self.learningrate)pile(optimizer=opt,loss='categorical_crossentropy',metrics=['accuracy'])#设置早停功能,二十五轮没有下降就自动停止,调整参数earlystop = keras.callbacks.EarlyStopping(monitor='loss',patience=25,mode='auto')history = model.fit_generator(generator=train_generator,steps_per_epoch=90,#this parameter depend on datasetepochs=self.num_epoch,callbacks=[earlystop],validation_data=validation_generator,validation_steps=4 #this parameter depend on my dataset)model.save_weights('flowervgg.h5')#保存模型'''This is the way to show the accuracy or lossin the training.这里是一个绘制模型训练变化的示例'''if show:plt.plot(history.history['acc'])plt.plot(history.history['val_acc'])plt.title('model acc')plt.ylabel('acc')plt.xlabel('epoch')plt.legend(['train','validation'])plt.show()else:passreturn modeldef predict(self):'''In this funtion,a path that testdata in should be provide,and the function could give a accuracy back.模型的预测的方法'''test_datagen = ImageDataGenerator()test_generator = test_datagen.flow_from_directory(directory=self.testpath,target_size=self.imagesize,color_mode='rgb',classes=['daisy','dandelion','roses','sunflowers','tulips'],class_mode='categorical',batch_size=1,shuffle=False#you must let shuffle False)#获取生成器生成的文件夹的全部文件名,计算数量filenames = test_generator.filenamesnb_samples = len(filenames)#这里的step必须是全部文件的数量,不然没法算精度了#返回预测值的一维数组predict = self.model.predict_generator(generator=test_generator,steps=nb_samples)#classes属性提供了文件夹内全部文件的类别#是一个一维数组true_labels = (test_generator.classes)pre_labels = np.argmax(predict,axis=-1)correct_label = np.equal(true_labels,pre_labels)accuracy = np.mean(correct_label)print('Accuracy is: ',accuracy)if __name__ == '__main__':#从这里传入参数model = SampleVGGForFlower(train_path='C:\\Users\\Dash\\Desktop\\Tensorflow\\preprocess\\flower_photos\\train',validation_path='C:\\Users\\Dash\\Desktop\\Tensorflow\\preprocess\\flower_photos\\validation',test_path='C:\\Users\\Dash\\Desktop\\Tensorflow\\preprocess\\flower_photos\\test',train = False)#你也可以不预测,把这行注释掉即可#尤其是你没有传入testpath的情况下model.predict()

训练集准确率能够达到90%以上,测试集比较小,准确率在75%左右,估计是过拟合了。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。