2000字范文,分享全网优秀范文,学习好帮手!
2000字范文 > 人工智能图片分类Python小程序

人工智能图片分类Python小程序

时间:2023-09-06 19:56:54

相关推荐

人工智能图片分类Python小程序

个人小作业,虽说做的很差,也算是一个学习的转化;主要用于分类自己下载的壁纸

1 背景

学期末需要一个学习成果的展示,高难度的自己做不来,模型也跑不动(电脑有点渣),刚好自己也有图片分类的需求,最后决定做了这个,确实也算做了一个自己用得到的小程序

2 项目说明

2.1 项目需求

需要自动加载指定目录所有图片,自行迁移至指定目录并存入不同的文件夹

2.2 实现思路

数据来源于各大壁纸网站,通过下载分类好的图片免去了自己手动分类的痛苦将图片进行微缩处理,将1920 × \times × 1080的图片转化为192 × \times × 108的尺寸,不然尺寸太大硬件吃不消。第二步可以将图片转化为单通道,数据量会小很多,但是测试过程中发现数据集较小时准确率比直接使用三通道要高一些,但是数据集大之后三通道的图片识别更加准确目前数据集是共10000多张图片共五个分类(差不多自己电脑的上限),通过第二步、第三步的三通道缩小处理后,所有数据集大小约600MB,还在接受范围内。模型的搭建与其他模型搭建基本一致

3 项目说明

3.1 项目结构

│ colorUi.ui正在使用的UI界面文件│ fun.py对于模型函数的初步封装,为PyQt界面提供支持│ main.py入口部分│ model.py模型的训练、加载│ ui.py正在使用的UI界面py文件│ ui.ui老的UI界面文件│ utils.py一些读取图片处理图片的函数├─fun_test内含各类图片共100张,用于最后的功能测试├─make_data_set用于处理制作数据集├─model训练好的模型存储的路径├─test内含处理好的数据集的测试集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组├─test_pic测试集原始数据目录,路径下各种图片独占一个目录,用于通过make_data_set制作数据集,目录应与train_pic对应│ ├─dongman其中一个分类│ ├─dongwu其中一个分类│ ├─fengjing其中一个分类│ ├─meinv其中一个分类│ └─youxi其中一个分类├─train内含处理好的数据集的训练集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组└─train_pic├─dongman其中一个分类├─dongwu其中一个分类├─fengjing其中一个分类├─meinv其中一个分类└─youxi其中一个分类

3.2 源码说明

3.2.1 模型的创建、加载、训练

import jsonimport osimport cv2import numpyimport numpy as npimport tensorflow as tffrom tensorflow import kerasfrom tqdm import tqdmfrom utils import img_resizedef init_network():"""初始化神经网络,支持五种类型:return: 模型"""model = tf.keras.Sequential([tf.keras.layers.Conv2D(filters=48, kernel_size=(3, 3), padding='same', activation='relu', strides=1,input_shape=(108, 192, 3)),tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),# 抑制过拟合tf.keras.layers.Dropout(rate=0.6),tf.keras.layers.Conv2D(filters=24, kernel_size=(3, 3), padding='same', activation='relu', strides=1),# 2*2池化取最大值tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),# 抑制过拟合tf.keras.layers.Dropout(rate=0.6),# 维度拉伸成1维tf.keras.layers.Flatten(),# 第二层隐藏层,使用relu激活函数tf.keras.layers.Dense(256, activation='relu'),# 抑制过拟合tf.keras.layers.Dropout(rate=0.6),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dropout(rate=0.5),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dropout(rate=0.5),# 输出层tf.keras.layers.Dense(5, activation='softmax')])pile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])model.summary()return modeldef getTrainData():"""获取训练集数据:return: train_images, train_labels, class_names"""fp = open('./train/train.json', 'r', encoding='utf8')class_names = json.load(fp)['support']fp.close()# 返回加载来的数据集pic_train_images = numpy.load('./train/train_pic.npy')train_images = pic_train_images.reshape(pic_train_images.shape[0], 108, 192, 3) / 255.0print(train_images.shape)train_labels = numpy.load('./train/train_labels.npy')print(numpy.load('./train/train_labels.npy').shape)return train_images, train_labels, class_namesdef getTestData():"""获取测试集包的数据:return: train_images, train_labels, class_names"""fp = open('./test/test.json', 'r', encoding='utf8')class_names = json.load(fp)['support']fp.close()# 返回加载来的数据集pic_test_images = numpy.load('./test/test_pic.npy')test_images = pic_test_images.reshape(pic_test_images.shape[0], 108, 192, 3) / 255.0print(test_images.shape)test_labels = numpy.load('./test/test_labels.npy')print(numpy.load('./test/test_labels.npy').shape)return test_images, test_labels, class_namesdef getTestImages():"""加载测试集1920*1080的壁纸"""path = './test_pic'imgs = []labels = []k = 0paths = os.listdir(path)paths.sort()for j in paths:pbar = tqdm(total=100)for i in os.listdir(path + '/' + j):pbar.update(100.0 / len(os.listdir(path + '/' + j)))pic_path = path + '/' + j + '/' + i# img = img_resize(cv2.imread(pic_path, cv2.IMREAD_GRAYSCALE))img = img_resize(cv2.imread(pic_path))if img.shape[0] != 108 or img.shape[1] != 192:os.remove(pic_path)continueimgs.append(img)labels.append(k)pbar.close()k = k + 1pic_test_images = np.array(imgs)test_images = pic_test_images.reshape(pic_test_images.shape[0], 108, 192, 3) / 255.0return test_images, np.array(labels)def getModel(train_mode=False):"""获取模型:param train_mode: 是否训练:return: 模型"""# 如果训练if train_mode:# 初始化神经网络model = init_network()# 加载数据集train_images, train_labels, _ = getTrainData()test_images, test_labels, _ = getTestData()print(train_images.shape)print(train_labels.shape)print(test_images.shape)print(test_labels.shape)# 开始训练,训练二十次,显示日志信息model.fit(train_images, keras.utils.to_categorical(train_labels), batch_size=128, epochs=100, verbose=2)# 评估模型,不输出预测结果test_loss, test_acc = model.evaluate(test_images, keras.utils.to_categorical(test_labels), verbose=2)# 输出损失值print('测试集损失:', test_loss)# 输出正确率print('测试集正确率:', test_acc)# 保存模型model.save('.\\model\\expll.h5')return model, test_loss, test_accelse:# 加载模型model = tf.keras.models.load_model('.\\model\\780_3x3_1_3_100_expll.h5')# 打印模型信息model.summary()test_images, test_labels, _ = getTestData()# 评估模型,不输出预测结果test_loss, test_acc = model.evaluate(test_images, keras.utils.to_categorical(test_labels), verbose=2)# print([np.where(i == np.max(i))[0][0] for i in model.predict(test_images)])return model, test_loss, test_acc# 训练模型# if __name__ == '__main__':#model = getModel(True)

3.2.2 模型功能的封装,用于支持PyQt功能界面逻辑

import jsonimport osimport shutilimport numpy as npfrom PyQt5.QtCore import *import utilsfrom model import getModeldef getModelSupportTypes(data):"""获取模型支持的分类:return:"""temp = ''for i in data:temp = temp + ' ' + ireturn tempdef getModelInfo(loss, acc):"""获取模型信息:return: 模型测试准确度"""return '测试集损失:{:.3f}\n测试集准确率:{:.3f}%'.format(loss, acc * 100)class Service(QObject):signalRunTime = pyqtSignal(str, bool)model = NonesignalWorking = pyqtSignal(bool)loadModelStatus = FalsesignalModelInfo = pyqtSignal(str)signalModelSupportTypes = pyqtSignal(str)def __init__(self):super().__init__()def predict(self, imgs: np.array):"""预测:param imgs: 预测图片集:return: 预测结果"""rs = self.model.predict(imgs)return [np.where(i == np.max(i))[0][0] for i in rs]def iniModel(self):"""初始化加载模型"""if self.loadModelStatus:self.signalRunTime.emit('模型加载中···', False)returnself.loadModelStatus = Trueself.signalRunTime.emit('正在加载模型···', False)self.model, loss, acc = getModel()with open('model/model.json', 'r', encoding='utf8') as fp:info = json.load(fp)self.signalModelInfo.emit('方法:' + info['way'] + '\n' + getModelInfo(loss, acc))self.signalModelSupportTypes.emit(getModelSupportTypes(info['support']))self.signalRunTime.emit('模型加载完成', False)self.loadModelStatus = Falsedef startRun(self, window):"""开始进行分类:param window: 窗口对象"""if len(window.getFromPath()) == 0 or len(window.getTargetPath()) == 0:self.signalRunTime.emit('\n存在路径为空\n', False)self.signalWorking.emit(False)returnlist_path = []self.signalRunTime.emit('\n检索中······\n', False)utils.getListDir(window.fromPath.toPlainText(), window.getRecursionPathStatus(), list_path, imageCallback=None,dirCallback=lambda x: self.signalRunTime.emit('检索检索到目录: {0}\n'.format(x), False))self.signalRunTime.emit('检索完成,共计{0}张图片\n'.format(len(list_path)), False)if len(list_path) == 0:self.signalWorking.emit(False)returnself.signalRunTime.emit('开始读取图片······', False)img = utils.get_data(list_path, lambda x: self.signalRunTime.emit('已加载: {0}\n'.format(x), False))self.signalRunTime.emit('读取图片完成', False)self.signalRunTime.emit('维度信息:{0}'.format(img.shape), False)self.signalRunTime.emit('进行分类识别中······', False)rs = self.predict(img)self.signalRunTime.emit('分类识别完成\n***********\n识别结果:\n***********\n***********\n***********\n', False)with open('.\\model\\model.json', encoding='utf8') as fp:supportTypes = json.load(fp)['support']outRunInfo = '\n'for i in zip(list_path, rs):outRunInfo = outRunInfo + '路径: {0}; 结果:{1}\n\n'.format(i[0], supportTypes[i[1]])self.signalRunTime.emit(outRunInfo + '\n\n***********\n***********\n识别结果输出结束\n***********\n***********\n',False)targetPathRoot = window.getTargetPath()for i in supportTypes:if not os.path.exists(targetPathRoot + '/' + i):os.mkdir(targetPathRoot + '/' + i)self.signalRunTime.emit('\n\n开始进行分类迁移······', False)onlyMoveMax = window.getOnlyNumber()with open('.\\model\\model.json', encoding='utf8') as fp:supportTypes = json.load(fp)['support']for j in range(0, int(len(list_path) * 1.0 / onlyMoveMax + 1)):for i in list(zip(list_path, rs))[onlyMoveMax * j:onlyMoveMax * (j + 1)]:try:self.signalRunTime.emit('来源: {0}; 迁移至:{1}\n\n'.format(i[0], (targetPathRoot + '/' + supportTypes[i[1]])), False)shutil.move(i[0], targetPathRoot + '/' + supportTypes[i[1]])except Exception as e:self.signalRunTime.emit('ERROR: {0}'.format(e, False))self.signalRunTime.emit('\n\n迁移结束,任务完成\n\n', False)self.signalWorking.emit(False)

3.2.3 入口部分

# -*- coding: utf-8 -*-import osimport sysfrom concurrent.futures import ThreadPoolExecutorfrom PyQt5.QtWidgets import *import funfrom ui import Ui_FormthreadPool = ThreadPoolExecutor(max_workers=20)def openPath(callback):# 选择图片path = QFileDialog.getExistingDirectory(None, "选择存储文件夹", os.getcwd())if path == "":return 0callback(path)class MainWindow(QWidget, Ui_Form):service = Noneimg = Noneworking = Falsedef __init__(self, service_):super(MainWindow, self).__init__()self.service = service_self.setupUi(self)def openFromPath(self):"""选择来源路径"""openPath(callback=lambda x: self.fromPath.setText(x))def openTargetPath(self):"""选择输出路径"""openPath(callback=lambda x: self.targetPath.setText(x))def outRuntimeInfo(self, data, refresh=True):"""输出运行时:param data: 日志:param refresh: 追加或清空再输出"""if refresh:self.runtimeInfor.setText(data)else:self.runtimeInfor.setText(self.runtimeInfor.toPlainText() + '\n' + data)self.runtimeInfor.moveCursor(self.runtimeInfor.textCursor().End)def getFromPath(self):"""获取源路径:return: 源路径"""return self.fromPath.toPlainText()def getTargetPath(self):"""获取输出路径:return: 输出路径"""return self.targetPath.toPlainText()def outSupportTypes(self, data):"""输出模型支持的类型:param data: 类型串"""self.modelType.setText(data)def outModelInfo(self, data):"""输出模型信息:param data: 模型信息"""self.modelInfor.setText(data)def getOnlyNumber(self):"""单次处理图片数量:return: 数量"""return self.onlyNumber.value()def getRecursionPathStatus(self):"""是否递归目录"""return self.recursionPath.checkState() == 2def startRun(self):"""开始分类"""if self.working:self.outRuntimeInfo('任务执行中', False)returntry:threadPool.submit(service.startRun, self)except Exception as e:print(e)def setWorking(self, status):self.working = statusif __name__ == '__main__':service = fun.Service()app = QApplication(sys.argv)# 初始化窗口m = MainWindow(service)m.btu_selectFromPath.clicked.connect(m.openFromPath)m.btu_selectTargetPath.clicked.connect(m.openTargetPath)m.btu_startRun.clicked.connect(m.startRun)m.setWindowTitle('1920*1080壁纸分类')m.show()service.signalRunTime.connect(m.outRuntimeInfo)service.signalWorking.connect(m.setWorking)service.signalModelInfo.connect(m.outModelInfo)service.signalModelSupportTypes.connect(m.outSupportTypes)threadPool.submit(service.iniModel)sys.exit(app.exec_())

3.2.4 UI界面

4 结语

  虽说很简单,或许显得很那么······没用,但是也是自己的一个小成果,也算是又做了一个对自己有用的工具吧!

项目文件所在地址,内含训练好的模型,目前支持五种:/WindSnowLi/picture-classify

原文

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。