# Bug fix: `from __future__ import ...` must be the first statement in a
# module; the original placed it after the sys.path hack, which is a
# SyntaxError when this cell is saved as a .py file.
from __future__ import print_function

import sys
# NOTE(review): hard-coded site-packages path from the original author's
# machine; kept for compatibility, but prefer configuring the environment.
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")

import os
import tarfile

import matplotlib as mpl
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
import cv2
import numpy as np

import mxnet as mx
from mxnet import autograd
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet.gluon.nn import (Dense, Activation, Conv2D, Conv2DTranspose,
                            BatchNorm, LeakyReLU, Flatten, HybridSequential,
                            HybridBlock, Dropout)

# Training hyper-parameters.
epochs = 500                            # passes over the training data
batch_size = 10
use_gpu = True
ctx = mx.gpu() if use_gpu else mx.cpu() # single training context
lr = 0.0002                             # Adam learning rate
beta1 = 0.5                             # Adam first-moment decay
#lambda1 = 100
lambda1 = 10                            # weight of the L1 reconstruction term
pool_size = 50                          # capacity of the discriminator ImagePool
img_horizon = mx.image.HorizontalFlipAug(1)


def load_retinex(batch_size):
    """Build an NDArrayIter of (lit face, '07'-lighting face) MultiPIE pairs.

    Every image whose illumination code (filename chars 14:16) is not '07'
    is paired with the '07' (frontal-lighting) image of the same subject,
    and a horizontally-flipped copy of each pair is appended as
    augmentation.  Pixels are scaled from uint8 [0, 255] to float32 [-1, 1]
    and stacked NCHW.

    Parameters
    ----------
    batch_size : int
        Batch size of the returned ``mx.io.NDArrayIter``.
    """
    img_in_list = []
    img_out_list = []
    # (A large commented-out CAS-PEAL/LAB variant of this loader was removed
    #  here; it also contained a use-before-assignment bug.)
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting/'
    mulpath_ground = 'MultiPIE/MultiPIE_Lighting/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            # Robustness fix: the original assumed every directory entry was
            # a .png following the MultiPIE naming scheme and crashed on
            # anything else (stray files, short names).
            if not fname.endswith('.png') or len(fname) < 16:
                continue
            num = fname[14:16]          # illumination code in the filename
            if num == '07':             # '07' is the ground-truth lighting
                continue
            lighting_img = os.path.join(mulpath_lighting, fname)
            ground_img = os.path.join(mulpath_ground, fname[:14] + '07.png')
            img_arr_fname = mx.image.imread(lighting_img).astype(np.float32) / 127.5 - 1
            img_arr_gnema = mx.image.imread(ground_img).astype(np.float32) / 127.5 - 1
            # Horizontal-flip augmentation.
            img_arr_fname_t = img_horizon(img_arr_fname)
            img_arr_gnema_t = img_horizon(img_arr_gnema)
            for src, dst in ((img_arr_fname, img_arr_gnema),
                             (img_arr_fname_t, img_arr_gnema_t)):
                # HWC -> CHW, then add a leading batch axis of 1.
                arr_in = nd.transpose(src, (2, 0, 1))
                arr_out = nd.transpose(dst, (2, 0, 1))
                img_in_list.append(arr_in.reshape((1,) + arr_in.shape))
                img_out_list.append(arr_out.reshape((1,) + arr_out.shape))
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                   nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
img_wd = 256
img_ht = 256
train_img_path = '../data/edges2handbags/train_mini/'
val_img_path = '../data/edges2handbags/val/'


def load_data(path, batch_size, is_reversed=False):
    """Load side-by-side pix2pix .jpg images into an ``mx.io.NDArrayIter``.

    Each image is resized to (2*img_wd, img_ht), split down the middle into
    a left (input) half and a right (output) half, scaled to [-1, 1] and
    stacked NCHW.  ``is_reversed`` swaps which half acts as the input.
    """
    inputs, targets = [], []
    for root, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            full_path = os.path.join(root, fname)
            arr = mx.image.imread(full_path).astype(np.float32) / 127.5 - 1
            arr = mx.image.imresize(arr, img_wd * 2, img_ht)
            # Split the side-by-side image into its two halves.
            left = mx.image.fixed_crop(arr, 0, 0, img_wd, img_ht)
            right = mx.image.fixed_crop(arr, img_wd, 0, img_wd, img_ht)
            # HWC -> CHW with a leading batch axis.
            left = nd.transpose(left, (2, 0, 1))
            right = nd.transpose(right, (2, 0, 1))
            left = left.reshape((1,) + left.shape)
            right = right.reshape((1,) + right.shape)
            inputs.append(right if is_reversed else left)
            targets.append(left if is_reversed else right)
    return mx.io.NDArrayIter(data=[nd.concat(*inputs, dim=0),
                                   nd.concat(*targets, dim=0)],
                             batch_size=batch_size)


train_data = load_data(train_img_path, batch_size, is_reversed=False)
val_data = load_data(val_img_path, batch_size, is_reversed=False)
img_horizon = mx.image.HorizontalFlipAug(1)


def load_retinex(batch_size):
    """Pair every non-'07' MultiPIE_Lighting_128 image with its '07'
    (frontal-lighting) counterpart, append horizontally-flipped copies,
    and return everything as an ``mx.io.NDArrayIter`` of [-1, 1] NCHW
    tensors."""
    img_in_list = []
    img_out_list = []
    # Legacy CAS-PEAL paths, kept as in the original but unused by the live
    # MultiPIE loop below.
    path = 'CAS/Lighting_aligned_128'
    ground_path = 'CAS/Normal_aligned_128'
    img_in_list = []
    img_out_list = []
    # (A commented-out CAS-PEAL loader used to live here.)
    src_dir = 'MultiPIE/MultiPIE_Lighting_128/'
    gt_dir = 'MultiPIE/MultiPIE_Lighting_128/'
    for walk_root, _, fnames in os.walk(src_dir):
        for fname in fnames:
            # Filename chars 14:16 encode the illumination; '07' is the
            # evenly-lit ground truth and is skipped as an input.
            if fname[14:16] == '07':
                continue
            lit_path = os.path.join(src_dir, fname)
            gt_path = os.path.join(gt_dir, fname[:14] + '07.png')
            # uint8 [0,255] -> float32 [-1,1].
            lit = mx.image.imread(lit_path).astype(np.float32) / 127.5 - 1
            gt = mx.image.imread(gt_path).astype(np.float32) / 127.5 - 1
            # Horizontal-flip augmentation of both sides of the pair.
            lit_flip = img_horizon(lit)
            gt_flip = img_horizon(gt)
            for a, b in ((lit, gt), (lit_flip, gt_flip)):
                a_nchw = nd.transpose(a, (2, 0, 1))
                b_nchw = nd.transpose(b, (2, 0, 1))
                img_in_list.append(a_nchw.reshape((1,) + a_nchw.shape))
                img_out_list.append(b_nchw.reshape((1,) + b_nchw.shape))
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0),
                                   nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
def visualize(img_arr):
    """Display one CHW image in [-1, 1] as uint8 RGB, without axes."""
    rgb = ((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
    plt.imshow(rgb)
    plt.axis('off')


def preview_train_data(train_data):
    """Plot the first four (input, target) pairs of the next batch:
    inputs on the top row, targets underneath."""
    img_in_list, img_out_list = train_data.next().data
    for idx in range(4):
        plt.subplot(2, 4, idx + 1)
        visualize(img_in_list[idx])
        plt.subplot(2, 4, idx + 5)
        visualize(img_out_list[idx])
    plt.show()


train_data = load_retinex(10)
preview_train_data(train_data)
# Define Unet generator skip block
class UnetSkipUnit(HybridBlock):
    # One U-Net encoder/decoder stage.  `inner_block` (the next-deeper stage)
    # is sandwiched between this stage's down-conv and up-conv; the forward
    # pass concatenates the stage's input with its output (the skip link),
    # except at the outermost stage, which returns the decoded image directly.
    def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False,
                 outermost=False, use_dropout=False, use_bias=False):
        super(UnetSkipUnit, self).__init__()

        with self.name_scope():
            self.outermost = outermost
            # Downsampling conv: 4x4 kernel, stride 2 halves the resolution.
            en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                             in_channels=outer_channels, use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            if innermost:
                # Bottom of the U: no inner block, so the decoder input is
                # just inner_channels (no skip concatenation below it).
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2,
                                          padding=1, in_channels=inner_channels,
                                          use_bias=use_bias)
                encoder = [en_relu, en_conv]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + decoder
            elif outermost:
                # Top of the U: tanh squashes the output into [-1, 1]; the
                # decoder input doubles because of the inner skip concat.
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2,
                                          padding=1, in_channels=inner_channels * 2)
                encoder = [en_conv]
                decoder = [de_relu, de_conv, Activation(activation='tanh')]
                model = encoder + [inner_block] + decoder
            else:
                # Middle stage: full encoder (relu/conv/norm) and decoder,
                # wrapping the inner block; decoder input doubled by the skip.
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2,
                                          padding=1, in_channels=inner_channels * 2,
                                          use_bias=use_bias)
                encoder = [en_relu, en_conv, en_norm]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + [inner_block] + decoder
            # NOTE(review): the source's indentation was lost; dropout is
            # appended after the if/elif/else as in the official MXNet pix2pix
            # example — confirm it was not meant for the `else` branch only.
            if use_dropout:
                model += [Dropout(rate=0.5)]

            self.model = HybridSequential()
            with self.model.name_scope():
                for block in model:
                    self.model.add(block)

    def hybrid_forward(self, F, x):
        if self.outermost:
            return self.model(x)
        else:
            # Skip connection: concatenate this stage's output with its input
            # along the channel axis.
            return F.concat(self.model(x), x, dim=1)


# Define Unet generator
class UnetGenerator(HybridBlock):
    # Full U-Net assembled inside-out from UnetSkipUnit stages.
    # `num_downs` is the total number of downsampling stages; stages beyond
    # the fixed ngf*8 -> ngf pyramid are extra ngf*8 blocks with dropout.
    def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True):
        super(UnetGenerator, self).__init__()

        # Build unet generator structure, starting from the innermost block.
        unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
        for _ in range(num_downs - 5):
            unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
        unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
        unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
        unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
        unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True)

        with self.name_scope():
            self.model = unet

    def hybrid_forward(self, F, x):
        return self.model(x)


# Define the PatchGAN discriminator
class Discriminator(HybridBlock):
    # PatchGAN: a stack of stride-2 convs (doubling channels up to ndf*8)
    # followed by two stride-1 convs, producing a 1-channel map of per-patch
    # real/fake scores rather than a single scalar.
    def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False):
        super(Discriminator, self).__init__()

        with self.name_scope():
            self.model = HybridSequential()
            kernel_size = 4
            padding = int(np.ceil((kernel_size - 1) / 2))

            self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2,
                                  padding=padding, in_channels=in_channels))
            self.model.add(LeakyReLU(alpha=0.2))

            # Stride-2 pyramid: channel count doubles each layer, capped at 8x.
            nf_mult = 1
            for n in range(1, n_layers):
                nf_mult_prev = nf_mult
                nf_mult = min(2 ** n, 8)
                self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size,
                                      strides=2, padding=padding,
                                      in_channels=ndf * nf_mult_prev, use_bias=use_bias))
                self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
                self.model.add(LeakyReLU(alpha=0.2))

            # Two stride-1 layers: one more feature layer, then the 1-channel
            # patch-score head.
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n_layers, 8)
            self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size,
                                  strides=1, padding=padding,
                                  in_channels=ndf * nf_mult_prev, use_bias=use_bias))
            self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
            self.model.add(LeakyReLU(alpha=0.2))
            self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1,
                                  padding=padding, in_channels=ndf * nf_mult))
            if use_sigmoid:
                # Optional: squash scores to (0, 1) for a BCE-style GAN loss.
                self.model.add(Activation(activation='sigmoid'))

    def hybrid_forward(self, F, x):
        out = self.model(x)
        #print(out)
        return out
def param_init(param):
    """DCGAN-style init for one parameter: conv weights ~ N(0, 0.02), conv
    biases zero, batch-norm params zero with gamma resampled from N(1, 0.02)."""
    name = param.name
    if 'conv' in name:
        init = mx.init.Normal(0.02) if 'weight' in name else mx.init.Zero()
        param.initialize(init=init, ctx=ctx)
    elif 'batchnorm' in name:
        param.initialize(init=mx.init.Zero(), ctx=ctx)
        # Initialize gamma from normal distribution with mean 1 and std 0.02
        if 'gamma' in name:
            param.set_data(nd.random_normal(1, 0.02, param.data().shape))


def network_init(net):
    """Run ``param_init`` on every parameter of *net*."""
    with net.name_scope():
        for p in net.collect_params().values():
            param_init(p)


def set_network():
    """Create two generator/discriminator pairs plus their Adam trainers.

    Returns (netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2,
    trainerD2).  Each discriminator takes 6 channels (input image stacked
    with the generated/real image)."""
    netG1 = UnetGenerator(in_channels=3, num_downs=6)
    netD1 = Discriminator(in_channels=6)
    netG2 = UnetGenerator(in_channels=3, num_downs=6)
    netD2 = Discriminator(in_channels=6)

    for net in (netG1, netD1, netG2, netD2):
        network_init(net)

    trainerG1 = gluon.Trainer(netG1.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    trainerG2 = gluon.Trainer(netG2.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam',
                              {'learning_rate': lr, 'beta1': beta1})
    return netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2


# Losses: the adversarial term uses L2 (LSGAN-style) rather than sigmoid BCE.
#GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
GAN_loss = gluon.loss.L2Loss()
L1_loss = gluon.loss.L1Loss()
L2_loss = gluon.loss.L2Loss()

netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2 = set_network()
class ImagePool():
    """History buffer of generated images for discriminator training.

    ``query()`` returns, for each incoming image, either the image itself
    or (with probability 0.5 once the pool is full) a randomly chosen older
    image, which is then replaced by the incoming one.  Mixing current and
    historical fakes stabilizes adversarial training.
    """

    def __init__(self, pool_size):
        # pool_size == 0 disables the pool entirely.
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch mixing *images* with pooled historical images."""
        if self.pool_size == 0:
            return images
        ret_imgs = []
        for i in range(images.shape[0]):
            image = nd.expand_dims(images[i], axis=0)
            if self.num_imgs < self.pool_size:
                # Pool not yet full: store and return the new image.
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                ret_imgs.append(image)
            else:
                p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
                if p > 0.5:
                    # Bug fix: the original truncated uniform(0, pool_size-1)
                    # through uint8, which (a) could essentially never select
                    # the last slot and (b) overflows for pool_size > 256.
                    # Sample uniformly over [0, pool_size) instead.
                    random_id = int(nd.random_uniform(
                        0, self.pool_size, shape=(1,)).asscalar())
                    random_id = min(random_id, self.pool_size - 1)
                    tmp = self.images[random_id].copy()
                    self.images[random_id] = image
                    ret_imgs.append(tmp)
                else:
                    ret_imgs.append(image)
        ret_imgs = nd.concat(*ret_imgs, dim=0)
        return ret_imgs
#Retinex helper functions (used for illumination normalization below)
def singleScaleRetinex(img, sigma):
    """log10(img) minus log10 of its Gaussian blur at scale *sigma*."""
    return np.log10(img) - np.log10(cv2.GaussianBlur(img, (0, 0), sigma))


def multiScaleRetinex(img, sigma_list):
    """Average of singleScaleRetinex over every scale in *sigma_list*."""
    retinex = np.zeros_like(img)
    for sigma in sigma_list:
        retinex += singleScaleRetinex(img, sigma)
    return retinex / len(sigma_list)


def colorRestoration(img, alpha, beta):
    """MSRCR color-restoration term: beta*(log10(alpha*img)-log10(channel sum))."""
    img_sum = np.sum(img, axis=2, keepdims=True)
    return beta * (np.log10(alpha * img) - np.log10(img_sum))


def simplestColorBalance(img, low_clip, high_clip):
    """Clip each channel (in place) to the [low_clip, high_clip] quantile range.

    Returns *img* for convenience; note the array is modified in place.
    """
    total = img.shape[0] * img.shape[1]
    for i in range(img.shape[2]):
        unique, counts = np.unique(img[:, :, i], return_counts=True)
        # Robustness fix: initialize to the channel extremes so that
        # low_clip == 0 (or degenerate inputs) cannot leave low_val/high_val
        # unbound as in the original (UnboundLocalError).
        low_val, high_val = unique[0], unique[-1]
        current = 0
        for u, c in zip(unique, counts):
            if float(current) / total < low_clip:
                low_val = u
            if float(current) / total < high_clip:
                high_val = u
            current += c
        img[:, :, i] = np.maximum(np.minimum(img[:, :, i], high_val), low_val)
    return img


def MSRCR(img, sigma_list, G, b, alpha, beta, low_clip, high_clip):
    """Multi-Scale Retinex with Color Restoration; returns a uint8 image."""
    img = np.float64(img) + 1.0   # +1 avoids log10(0)
    img_retinex = multiScaleRetinex(img, sigma_list)
    img_color = colorRestoration(img, alpha, beta)
    img_msrcr = G * (img_retinex * img_color + b)
    # Per-channel min/max stretch to [0, 255].
    for i in range(img_msrcr.shape[2]):
        img_msrcr[:, :, i] = (img_msrcr[:, :, i] - np.min(img_msrcr[:, :, i])) / \
                             (np.max(img_msrcr[:, :, i]) - np.min(img_msrcr[:, :, i])) * 255
    img_msrcr = np.uint8(np.minimum(np.maximum(img_msrcr, 0), 255))
    img_msrcr = simplestColorBalance(img_msrcr, low_clip, high_clip)
    return img_msrcr


def automatedMSRCR(img, sigma_list):
    """Multi-scale retinex with per-channel automatic clipping thresholds."""
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)
    for i in range(img_retinex.shape[2]):
        unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100),
                                  return_counts=True)
        # Robustness fix: default to 0 so a channel with no exact-zero bin
        # cannot leave zero_count unbound as in the original; with 0 the
        # c < zero_count*0.1 tests below are simply never true.
        zero_count = 0
        for u, c in zip(unique, count):
            if u == 0:
                zero_count = c
                break
        # Shrink the clip range toward zero while bins stay "rare"
        # (less than 10% of the zero bin's population).
        low_val = unique[0] / 100.0
        high_val = unique[-1] / 100.0
        for u, c in zip(unique, count):
            if u < 0 and c < zero_count * 0.1:
                low_val = u / 100.0
            if u > 0 and c < zero_count * 0.1:
                high_val = u / 100.0
                break
        img_retinex[:, :, i] = np.maximum(
            np.minimum(img_retinex[:, :, i], high_val), low_val)
        img_retinex[:, :, i] = (img_retinex[:, :, i] - np.min(img_retinex[:, :, i])) / \
                               (np.max(img_retinex[:, :, i]) - np.min(img_retinex[:, :, i])) * 255
    img_retinex = np.uint8(img_retinex)
    return img_retinex


def MSRCP(img, sigma_list, low_clip, high_clip):
    """Multi-Scale Retinex on the intensity channel, preserving chromaticity."""
    img = np.float64(img) + 1.0
    intensity = np.sum(img, axis=2) / img.shape[2]
    retinex = multiScaleRetinex(intensity, sigma_list)
    intensity = np.expand_dims(intensity, 2)
    retinex = np.expand_dims(retinex, 2)
    intensity1 = simplestColorBalance(retinex, low_clip, high_clip)
    intensity1 = (intensity1 - np.min(intensity1)) / \
                 (np.max(intensity1) - np.min(intensity1)) * 255.0 + 1.0
    # Rescale each pixel's RGB by the intensity ratio, capped so no channel
    # exceeds 255.
    img_msrcp = np.zeros_like(img)
    for y in range(img_msrcp.shape[0]):
        for x in range(img_msrcp.shape[1]):
            B = np.max(img[y, x])
            A = np.minimum(256.0 / B, intensity1[y, x, 0] / intensity[y, x, 0])
            img_msrcp[y, x, 0] = A * img[y, x, 0]
            img_msrcp[y, x, 1] = A * img[y, x, 1]
            img_msrcp[y, x, 2] = A * img[y, x, 2]
    img_msrcp = np.uint8(img_msrcp - 1.0)
    return img_msrcp
#Pre-training (generators only, plain L1 loss, no adversarial term)
from datetime import datetime
import time
import logging


def facc(label, pred):
    """Binary accuracy with a 0.5 threshold on the raw predictions."""
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()


def pre_train():
    """Warm up netG1 (in->out) and netG2 (out->in) on train_data using only
    an L1 reconstruction loss weighted by lambda1 — no discriminators."""
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        batch_idx = 0
        for minibatch in train_data:
            src = minibatch.data[0].as_in_context(ctx)
            dst = minibatch.data[1].as_in_context(ctx)

            # G1: src -> dst reconstruction.
            with autograd.record():
                fake_out = netG1(src)
                errG1 = L1_loss(fake_out, dst) * lambda1
                #errG1 = land_mark_errs(real_in, fake_out)
                errG1.backward()
            trainerG1.step(minibatch.data[0].shape[0])

            # G2: dst -> src reconstruction.
            with autograd.record():
                fake_out2 = netG2(dst)
                errG2 = L1_loss(fake_out2, src) * lambda1
                errG2.backward()
            trainerG2.step(minibatch.data[0].shape[0])

            # Log every ten batches.
            if batch_idx % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                logging.info('G1generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, batch_idx, epoch))
                logging.info('G1generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, batch_idx, epoch))
            batch_idx = batch_idx + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch.
        visualize(fake_out[0])
        plt.show()
        #fake_img2 = fake_out2[0]
        #visualize(fake_img2)
        #plt.show()


pre_train()
def save_data(path, tpath):
    """Run netG1 on the right half of every side-by-side .jpg under *path*
    and write the generated image (uint8 BGR via cv2) into *tpath*."""
    for root, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img_file = os.path.join(root, fname)
            arr = mx.image.imread(img_file).astype(np.float32) / 127.5 - 1
            arr = mx.image.imresize(arr, img_wd * 2, img_ht)
            # Split into the two halves of the pix2pix pair.
            left = mx.image.fixed_crop(arr, 0, 0, img_wd, img_ht)
            right = mx.image.fixed_crop(arr, img_wd, 0, img_wd, img_ht)
            #left = mx.image.imresize(left, 128, 128)
            #right = mx.image.imresize(right, 128, 128)
            left = nd.transpose(left, (2, 0, 1))
            right = nd.transpose(right, (2, 0, 1))
            left = left.reshape((1,) + left.shape)
            right = right.reshape((1,) + right.shape)
            # Feed the OUTPUT half through G1, as in the original.
            fake = netG1(right.as_in_context(ctx))
            out = ((fake[0].asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
            cv2.imwrite(tpath + fname, out)


save_data("../data/edges2handbags/val/", "../data/edges2handbags/G1andG2/")
# Re-create both discriminators with freshly initialized parameters and new
# Adam trainers (the generators keep whatever weights they already have,
# e.g. from pre-training).
netD1 = Discriminator(in_channels=6)
netD2 = Discriminator(in_channels=6)
network_init(netD1)
network_init(netD2)
trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam', {
    'learning_rate': lr, 'beta1': beta1})
trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam', {
    'learning_rate': lr, 'beta1': beta1})
from datetime import datetime
import time
import logging


def facc(label, pred):
    """Binary accuracy with a 0.5 threshold on the raw predictions."""
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()


def dual_pre_train():
    """Jointly pre-train netG1 and netG2 with L1 + cycle-consistency losses,
    drawing paired batches from ``retinex_data`` and
    ``PIE_normal_to_lighting`` (both assumed to be module-level iterators)."""
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        # Bug fix: the original reset only PIE_normal_to_lighting, so
        # retinex_data stayed exhausted after the first epoch and zip()
        # yielded no batches from epoch 1 onward (compare train(), which
        # resets both iterators).
        retinex_data.reset()
        PIE_normal_to_lighting.reset()
        it = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighting_bad = batch2.data[0].as_in_context(ctx)
            lighting_good = batch2.data[1].as_in_context(ctx)

            # G1: identity-style L1 plus a G1(G2(.)) cycle term on the
            # well-lit images.
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(real_in, fake_out) + \
                    L1_loss(netG1(netG2(lighting_good)), lighting_good)
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # G2: map the retinex target back, plus a G2(G1(.)) cycle term.
            with autograd.record():
                fake_out3 = netG2(real_out)
                errG2 = L1_loss(lighting_good, fake_out3) + \
                    L1_loss(netG2(netG1(real_in)), real_in)
                errG2.backward()
            trainerG2.step(batch2.data[0].shape[0])

            # Print log information every ten batches.
            if it % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                logging.info('G1generator loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, it, epoch))
                logging.info('G2generator loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, it, epoch))
            it = it + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch.
        visualize(fake_out[0])
        plt.show()


dual_pre_train()
def test_netG(Spath, Tpath):
    """Run netG1 on every .png under *Spath* (resized to 128x128) and save
    the uint8 RGB results under the same filename in *Tpath*."""
    for root, _, fnames in os.walk(Spath):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            #num = fname[14:16]
            #if num != '07':
            #    continue
            src = os.path.join(root, fname)
            arr = mx.image.imread(src).astype(np.float32) / 127.5 - 1
            arr = mx.image.imresize(arr, 128, 128)
            # HWC -> NCHW with batch axis 1.
            arr = nd.transpose(arr, (2, 0, 1))
            arr = arr.reshape((1,) + arr.shape)
            out = netG1(arr.as_in_context(ctx))[0]
            #out = mx.image.imresize(out, 120, 165)
            plt.imsave(Tpath + fname,
                       ((out.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))


#test_netG('MultiPIE/MultiPIE_test_128_Gray/','MultiPIE/relighting/')
test_netG('MultiPIE/Bio_relighing2/', 'MultiPIE/Bio_color/')
#Use OpenCV/dlib facial landmark distances as an auxiliary loss
fileDir = '/home/hxj/gluon-tutorials/GAN/openface/'
sys.path.append(os.path.join(fileDir))

import argparse
import cv2
import dlib
import matplotlib.pyplot as plt
from pylab import plot

from openface.align_dlib import AlignDlib

modelDir = os.path.join(fileDir, 'models')
openfaceModelDir = os.path.join(modelDir, 'openface')
dlibModelDir = os.path.join(modelDir, 'dlib')
dlibFacePredictor = os.path.join(dlibModelDir,
                                 "shape_predictor_68_face_landmarks.dat")


def land_mark_errs(batch1, batch2):
    """Per-pair mean L1 distance between 68 dlib facial landmarks.

    batch1/batch2 are NCHW float batches in [-1, 1].  Returns an NDArray
    with one error value per image pair (entries stay 0 where landmarks
    could not be computed).
    """
    align = AlignDlib(dlibFacePredictor)
    # Bug fix: the original allocated nd.zeros((10,)) regardless of the
    # actual batch size, which breaks (or silently truncates) on the last,
    # smaller batch; size the accumulator from the batch instead.
    sum_err = nd.zeros((batch1.shape[0],)).as_in_context(ctx)
    i = 0
    for (x, y) in zip(batch1, batch2):
        # CHW [-1,1] -> HWC uint8 for dlib.
        x_img = ((x.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        y_img = ((y.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        # Fixed face box for the 128x128 aligned crops.  (A commented-out
        # detector-with-retinex-fallback used to live here; face detection
        # was evidently unreliable under extreme lighting.)
        bbx = dlib.rectangle(-19, -19, 124, 125)
        bby = dlib.rectangle(-19, -19, 124, 125)
        # Bug fix: check the raw findLandmarks result for None — the
        # original tested `nd.array(...) is None`, which is never true.
        lm_x = align.findLandmarks(x_img, bbx)
        lm_y = align.findLandmarks(y_img, bby)
        if lm_x is None or lm_y is None:
            continue
        landmarks_x = nd.array(lm_x)
        landmarks_y = nd.array(lm_y)
        sum_err[i] = nd.sum(nd.abs(landmarks_x - landmarks_y)) / 68
        i += 1
    return sum_err
from datetime import datetime
import time
import logging


def facc(label, pred):
    """Binary accuracy with a 0.5 threshold on the raw predictions."""
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()


def generate_train_single():
    """Train netG1/netG2 on train_data with weighted L1 + cycle losses
    (20x direct, 10x cycle); no adversarial term is used here."""
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        it = 0
        for batch1 in train_data:
            real_in = batch1.data[0].as_in_context(ctx)   # training inputs
            real_out = batch1.data[1].as_in_context(ctx)  # training targets

            # ---- G1: real_in -> real_out, plus G1(G2(real_out)) cycle ----
            # Bug fix: the original computed fake_out = netG1(real_in)
            # OUTSIDE autograd.record(), so the direct L1 term carried no
            # gradient into netG1; record the forward pass.
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(real_out, fake_out) * 20 + \
                    L1_loss(netG1(netG2(real_out)), real_out) * 10
                #errG1 = errG1 + land_mark_errs(real_in, fake_out) * 0.4
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # ---- G2: real_out -> real_in, plus G2(G1(real_in)) cycle ----
            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(real_in, fake_out2) * 20 + \
                    L1_loss(netG2(netG1(real_in)), real_in) * 10
                #errG2 = errG2 + land_mark_errs(real_out, fake_out2) * 0.4
                # Bug fix: the original never called errG2.backward(), so
                # trainerG2.step() updated netG2 without fresh gradients.
                errG2.backward()
            trainerG2.step(batch1.data[0].shape[0])

            # Print log information every ten batches.
            if it % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, it, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, it, epoch))
            it = it + 1
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch.
        visualize(fake_out[0])
        plt.show()


generate_train_single()
from skimage import io

# Quick visual sanity check: push the LAB L channel of one CAS test image
# through netG1 and display the result.
bgrImg = cv2.imread('CAS/test_aligned_128/FM_000046_IFD+90_PM+00_EN_A0_D0_T0_BW_M0_R1_S0.png')
rgbImg = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2RGB)
plt.imshow(rgbImg)
plt.show()
lab = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2LAB)
plt.imshow(lab)
plt.show()
# L channel scaled to [-1, 1] and shaped (1, 1, H, W).
img_test = lab[:, :, 0].astype(np.float32) / 127.5 - 1
img_test = nd.array(img_test)
img_arr_in = img_test.reshape((1, 1,) + img_test.shape).as_in_context(ctx)
# NOTE(review): netG1 was constructed with in_channels=3; feeding a single
# channel here will only work if netG1 was rebuilt for 1-channel input.
test1 = netG1(img_arr_in)
test2 = test1[0][0]
# Bug fix: cv2.imshow requires a window name as its first argument, so the
# original one-argument call raised TypeError.  Display with matplotlib,
# consistent with the rest of the file.
plt.imshow(((test2.asnumpy() + 1.0) * 127.5).astype(np.uint8), cmap='gray')
plt.show()
from datetime import datetime
import time
import logging

def facc(label, pred):
    # Binary accuracy: fraction of thresholded predictions matching labels.
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def Dual_train_single():
    # Full adversarial dual training on train_data.  Each iteration performs
    # four updates in order: D1, G1 (real_in -> real_out direction), then
    # D2, G2 (real_out -> real_in direction).  Generators use LSGAN-style
    # GAN_loss plus weighted L1 (20x direct, 10x cycle) terms.
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1 in train_data:
            #for batch in range(400):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            real_in = batch1.data[0].as_in_context(ctx)   # pull the inputs out of the train batch
            real_out = batch1.data[1].as_in_context(ctx)  # and the targets
            #D1
            # Generator forward pass happens outside record(): for the D
            # update we only need D's gradients, not G's.
            fake_out = netG1(real_in)
            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                output = netD1(fake_concat)
                #output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label,], [output,])
                # Train with real image
                # NOTE(review): real pairs are also routed through the image
                # pool here (comment said "ground truth also goes through
                # G1/pool") — unusual for pix2pix; confirm it is intentional.
                real_concat = image_pool.query(nd.concat(real_in, real_out, dim=1))
                output = netD1(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD1 = (errD_real + errD_fake) *0.5
                errD1.backward()
                metric.update([real_label,], [output,])
            trainerD1.step(batch1.data[0].shape[0])
            #G1
            with autograd.record():
                #fake_out = netG1(real_in)
                # NOTE(review): querying the image pool inside the G update
                # can hand back a *historical* fake instead of the current
                # one — confirm this is intended for the generator pass.
                fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
                output = netD1(fake_concat)
                #output = netD1(fake_out)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1+ \
                #L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1+ \
                #L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * 20+ \
                    L1_loss(netG1(netG2(real_out)), real_out) *10
                #land_mark_errs(real_out, fake_out)
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])
            #D2
            fake_out2 = netG2(real_out)
            fake_concat2 = image_pool.query(nd.concat(real_out, fake_out2, dim=1))
            with autograd.record():
                output2 = netD2(fake_concat2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2,], [output2,])
                # Train with real image
                real_concat2 = image_pool.query(nd.concat(real_out, real_in, dim=1))
                output2 = netD2(real_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)
                errD2 = (errD_real2 + errD_fake2) * 0.5
                errD2.backward()
                metric.update([real_label2,], [output2,])
            trainerD2.step(batch1.data[0].shape[0])
            #G2
            with autograd.record():
                #fake_out2 = netG2(real_out)
                fake_concat2 = image_pool.query(nd.concat(real_out, fake_out2, dim=1))
                output2 = netD2(fake_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                #errG2 = GAN_loss(output2, real_label2)+ L1_loss(real_in, fake_out2) * lambda1+ \
                #L1_loss(netG1(netG2(lighing_good)), lighing_good) * lambda1
                errG2 = GAN_loss(output2, real_label2)+ L1_loss(real_in, fake_out2) * 20+ \
                    L1_loss(netG2(netG1(real_in)), real_in) *10
                #land_mark_errs(real_in, fake_out2)
                errG2.backward()
            trainerG2.step(batch1.data[0].shape[0])
            # Print log infomation every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             %(nd.mean(errD1).asscalar(), nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             %(nd.mean(errD2).asscalar(), nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

Dual_train_single()
from datetime import datetime
import time
import logging


def facc(label, pred):
    """Binary accuracy for a sigmoid discriminator.

    Flattens both arrays and returns the fraction of predictions that land
    on the correct side of the 0.5 decision threshold.
    """
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()


def train():
    """Alternating adversarial training of the dual generators/discriminators.

    Each iteration draws one batch from ``retinex_data`` (pairs: lighting-degraded
    input, retinex ground truth) and one from ``PIE_normal_to_lighting``
    (pairs: bad lighting, good lighting), then performs four updates in order:
    D1, G1, D2, G2.  Relies on the module-level networks (``netG1``, ``netG2``,
    ``netD1``, ``netD2``), trainers, losses (``GAN_loss``, ``L1_loss``) and
    hyper-parameters (``epochs``, ``lambda1``, ``batch_size``, ``ctx``)
    defined earlier in this file.
    """
    metric = mx.metric.CustomMetric(facc)
    # NOTE(review): stamp is currently unused — presumably intended for
    # checkpoint/log naming; kept so the import stays meaningful.
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()
        PIE_normal_to_lighting.reset()
        fake_out = None

        # `enumerate` replaces the original hand-incremented `iter` counter,
        # which shadowed the builtin of the same name.
        for batch_no, (batch1, batch2) in enumerate(
                zip(retinex_data, PIE_normal_to_lighting)):
            # Unpack the input/target halves of each paired batch.
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighting_bad = batch2.data[0].as_in_context(ctx)   # currently unused below
            lighting_good = batch2.data[1].as_in_context(ctx)

            ############################
            # (1) Update D networks: maximize log(D(y)) + log(1 - D(G(x)))
            ############################
            # fake_out is computed OUTSIDE record() so D1's update does not
            # backpropagate into G1.
            fake_out = netG1(real_in)

            # --- D1: real = lighting_good, fake = netG1(real_in) ---
            with autograd.record():
                output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label], [output])

                output = netD1(lighting_good)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)

                errD1 = (errD_real + errD_fake) * 0.5
                errD1.backward()
            metric.update([real_label], [output])
            trainerD1.step(batch1.data[0].shape[0])

            # --- G1: adversarial + L1 identity + cycle(G1∘G2) on good lighting ---
            with autograd.record():
                fake_out = netG1(real_in)
                output = netD1(fake_out)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG1 = (GAN_loss(output, real_label)
                         + L1_loss(real_in, fake_out) * lambda1
                         + L1_loss(netG1(netG2(lighting_good)), lighting_good) * lambda1)
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # --- D2: real = real_in, fake = netG2(lighting_good) ---
            fake_out2 = netG2(lighting_good)
            with autograd.record():
                output2 = netD2(fake_out2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2], [output2])

                output2 = netD2(real_in)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)

                errD2 = (errD_real2 + errD_fake2) * 0.5
                errD2.backward()
            metric.update([real_label2], [output2])
            trainerD2.step(batch2.data[0].shape[0])

            # --- G2: adversarial + L1 identity + cycle(G2∘G1) on real_in ---
            with autograd.record():
                fake_out2 = netG2(lighting_good)
                output2 = netD2(fake_out2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errG2 = (GAN_loss(output2, real_label2)
                         + L1_loss(lighting_good, fake_out2) * lambda1
                         + L1_loss(netG2(netG1(real_in)), real_in) * lambda1)
                errG2.backward()
            trainerG2.step(batch2.data[0].shape[0])

            # Print log information every ten batches.
            if batch_no % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, '
                             'binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD1).asscalar(),
                                nd.mean(errG1).asscalar(), acc, batch_no, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, '
                             'binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD2).asscalar(),
                                nd.mean(errG2).asscalar(), acc, batch_no, epoch))
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch.
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()


if __name__ == "__main__":
    train()