2000字范文,分享全网优秀范文,学习好帮手!
2000字范文 > Python深度学习(11):GAN生成青蛙图片

Python深度学习(11):GAN生成青蛙图片

时间:2023-11-29 21:49:57

相关推荐

Python深度学习(11):GAN生成青蛙图片

算法简介

GAN最直观的解释就是博弈,同时训练两个网络(生成网络和判别网络),二者训练都是为了打败彼此。生成网络将随机潜在向量转换为图像,判别器试图分辨真实图像与生成图像。

生成网络:以一张随机向量作为输入,解码为合成图像

判别网络:输入图像,输出真或假的类别

实验中搭建的是DCGAN深度卷积生成式对抗网络,即生成网络和判别网络都是深度卷积网络,具体实现流程如下:

(1)潜在空间抽取随机噪声

(2)生成网络利用这些随机噪声生成图像

(3)将生成图像与真实图像打上标签,并混合,

(4)利用混合后的图像集去训练判别网络

(5)回到(1)

代码实现

import kerasfrom keras import layersimport numpy as npimport osfrom keras.preprocessing import imagelatent_dim = 32height = 32width = 32channels = 3generator_input = keras.Input(shape=(latent_dim,))# 生成模型# 将输入转换为16*16的128个通道的特征图x = layers.Dense(128 * 16 * 16)(generator_input)x = layers.LeakyReLU()(x)x = layers.Reshape((16, 16, 128))(x)x = layers.Conv2D(256, 5, padding='same')(x)x = layers.LeakyReLU()(x)x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x) # 使用Conv2DTranspose层对图像进行上采样x = layers.LeakyReLU()(x)x = layers.Conv2D(256, 5, padding='same')(x)x = layers.LeakyReLU()(x)x = layers.Conv2D(256, 5, padding='same')(x)x = layers.LeakyReLU()(x)x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)generator = keras.models.Model(generator_input, x) # 将生成器实例化(由向量映射到图像)print(generator.summary())# 判别器模型discriminator_input = layers.Input(shape=(height, width, channels))x = layers.Conv2D(128, 3)(discriminator_input)x = layers.LeakyReLU()(x)x = layers.Conv2D(128, 4, strides=2)(x)x = layers.LeakyReLU()(x)x = layers.Conv2D(128, 4, strides=2)(x)x = layers.LeakyReLU()(x)x = layers.Flatten()(x)x = layers.Dropout(0.4)(x)x = layers.Dense(1, activation='sigmoid')(x)discriminator = keras.models.Model(discriminator_input, x) # 将判别器模型实例化discriminator.summary()discriminator_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)pile(optimizer=discriminator_optimizer, loss='binary_crossentropy')discriminator.trainable = False # 将判别器权重设置为不可训练gan_input = keras.Input(shape=(latent_dim,))gan_output = discriminator(generator(gan_input))gan = keras.models.Model(gan_input, gan_output) # gan模型实例化gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)pile(optimizer=gan_optimizer, loss='binary_crossentropy')(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data() # 加载cifar10训练集x_train = x_train[y_train.flatten() == 6] # 选择其中的青蛙图像x_train = x_train.reshape((x_train.shape[0],) + (height, width, channels)).astype('float32') / 255. # 数据标准化iterations = 10000batch_size = 20save_dir = './gan_png' # 指定保存生成图像的目录start = 0for step in range(iterations):random_latent_vectors = np.random.normal(size=(batch_size, latent_dim)) # 在潜在空间采样随机点generated_images = generator.predict(random_latent_vectors) # 将随机点解码为图像stop = start + batch_sizereal_images = x_train[start: stop]combined_images = np.concatenate([generated_images, real_images]) # 将生成的图像与真实图像混合labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))]) # 所有图像的标签labels += 0.05 * np.random.random(labels.shape) # 向标签中添加随机噪声d_loss = discriminator.train_on_batch(combined_images, labels) # 训练判别器random_latent_vectors = np.random.normal(size=(batch_size, latent_dim)) # 在潜在空间采样随机点misleading_targets = np.zeros((batch_size, 1)) # 合并标签a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets) # 冻结判别器,训练生成器# 保存与展示start += batch_sizeif start > len(x_train) - batch_size:start = 0if step % 100 == 0:gan.save_weights('gan.h5')print('discriminator loss:', d_loss)print('adversarial loss:', a_loss)img = image.array_to_img(generated_images[0] * 255., scale=False)img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))img = image.array_to_img(real_images[0] * 255., scale=False)img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))

运行结果:

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。