Sym-GAN
Published: 2019-06-15


from __future__ import print_function  # must come before any other statement

import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")

import os
import tarfile

import cv2
import matplotlib as mpl
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
import numpy as np

import mxnet as mx
from mxnet import autograd, gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn, utils
from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, \
    BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout

# Hyperparameters
epochs = 500
batch_size = 10
use_gpu = True
ctx = mx.gpu() if use_gpu else mx.cpu()
lr = 0.0002
beta1 = 0.5
#lambda1 = 100
lambda1 = 10
pool_size = 50
img_horizon = mx.image.HorizontalFlipAug(1)

def load_retinex(batch_size):
    img_in_list = []
    img_out_list = []
    """
    path = 'CAS/Lighting_aligned_128'
    ground_path = 'CAS/Lighting_aligned_128_retinex_to_color'
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            lingting_img = os.path.join(path, fname)
            ground_img = os.path.join(ground_path, fname)
            # Augment with horizontal flips and brightness raised or lowered by 50%
            img_arr_fname = mx.image.imread(lingting_img).astype(np.float32)/127.5 - 1
            img_arr_fname_t = img_horizon(img_arr_fname)
            img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
            img_arr_gnema_t = img_horizon(img_arr_gnema)

            img_arr_fname = cv2.cvtColor(img_arr_fname.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_fname_t = cv2.cvtColor(img_arr_fname_t.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_gnema = cv2.cvtColor(img_arr_gnema.asnumpy(), cv2.COLOR_RGB2LAB)
            img_arr_gnema_t = cv2.cvtColor(img_arr_gnema_t.asnumpy(), cv2.COLOR_RGB2LAB)

            img_arr_in, img_arr_out = [img_arr_fname[:, :, 0].reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_in)
            img_out_list.append(img_arr_out)

            img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema_t, (2, 0, 1))]
            img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                           img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
            img_in_list.append(img_arr_in_t)
            img_out_list.append(img_arr_out_t)
    """
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting/'
    mulpaht_ground = 'MultiPIE/MultiPIE_Lighting/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            # Characters 14:16 of the file name encode the flash condition;
            # '07' is the frontally lit image used as ground truth.
            num = fname[14:16]
            if num != '07':
                lingting_img = os.path.join(mulpath_lighting, fname)
                ground_img = os.path.join(mulpaht_ground, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lingting_img).astype(np.float32)/127.5 - 1
                img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
                #img_arr_fname = mx.image.imresize(img_arr_fname, 256, 256)
                #img_arr_gnema = mx.image.imresize(img_arr_gnema, 256, 256)
                # Augment with horizontal flips and brightness raised or lowered by 50%
                #img_arr_fname_b = img_bright(img_arr_fname)
                img_arr_fname_t = img_horizon(img_arr_fname)
                img_arr_gnema_t = img_horizon(img_arr_gnema)
                # 4 lighting images and 2 normal ground-truth images in total
                img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema, (2, 0, 1))]
                img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                           img_arr_out.reshape((1,) + img_arr_out.shape)]
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)
                img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                               nd.transpose(img_arr_gnema_t, (2, 0, 1))]
                img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                               img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
                img_in_list.append(img_arr_in_t)
                img_out_list.append(img_arr_out_t)
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
img_wd = 256
img_ht = 256
train_img_path = '../data/edges2handbags/train_mini/'
val_img_path = '../data/edges2handbags/val/'

def load_data(path, batch_size, is_reversed=False):
    img_in_list = []
    img_out_list = []
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img = os.path.join(path, fname)
            img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop input and output images
            img_arr_in, img_arr_out = [mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht),
                                       mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2, 0, 1)),
                                       nd.transpose(img_arr_out, (2, 0, 1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_out if is_reversed else img_arr_in)
            img_out_list.append(img_arr_in if is_reversed else img_arr_out)
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)

train_data = load_data(train_img_path, batch_size, is_reversed=False)
val_data = load_data(val_img_path, batch_size, is_reversed=False)
img_horizon = mx.image.HorizontalFlipAug(1)

def load_retinex(batch_size):
    img_in_list = []
    img_out_list = []
    path = 'CAS/Lighting_aligned_128'
    ground_path = 'CAS/Normal_aligned_128'
    """
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            temp_name = fname[0:9] + '_IEU+00_PM+00_EN_A0_D0_T0_BB_M0_R0_S0.png'
            ground_img = os.path.join(ground_path, temp_name)
            if not os.path.exists(ground_img):
                temp_name = fname[0:9] + '_IEU+00_PM+00_EN_A0_D0_T0_BB_M0_R1_S0.png'
                ground_img = os.path.join(ground_path, temp_name)
            if not os.path.exists(ground_img):
                continue
            lingting_img = os.path.join(path, fname)
            # Augment with horizontal flips and brightness raised or lowered by 50%
            img_arr_fname = mx.image.imread(lingting_img).astype(np.float32)/127.5 - 1
            img_arr_fname_t = img_horizon(img_arr_fname)
            img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
            img_arr_gnema_t = img_horizon(img_arr_gnema)
            img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2, 0, 1)),
                                       nd.transpose(img_arr_gnema, (2, 0, 1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_in_list.append(img_arr_in)
            img_out_list.append(img_arr_out)
            img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema_t, (2, 0, 1))]
            img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                           img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
            img_in_list.append(img_arr_in_t)
            img_out_list.append(img_arr_out_t)
    """
    mulpath_lighting = 'MultiPIE/MultiPIE_Lighting_128/'
    mulpaht_ground = 'MultiPIE/MultiPIE_Lighting_128/'
    for path, _, fnames in os.walk(mulpath_lighting):
        for fname in fnames:
            # Characters 14:16 of the file name encode the flash condition;
            # '07' is the frontally lit image used as ground truth.
            num = fname[14:16]
            if num != '07':
                lingting_img = os.path.join(mulpath_lighting, fname)
                ground_img = os.path.join(mulpaht_ground, fname[:14] + '07.png')
                img_arr_fname = mx.image.imread(lingting_img).astype(np.float32)/127.5 - 1
                img_arr_gnema = mx.image.imread(ground_img).astype(np.float32)/127.5 - 1
                #img_arr_fname = mx.image.imresize(img_arr_fname, 256, 256)
                #img_arr_gnema = mx.image.imresize(img_arr_gnema, 256, 256)
                # Augment with horizontal flips and brightness raised or lowered by 50%
                #img_arr_fname_b = img_bright(img_arr_fname)
                img_arr_fname_t = img_horizon(img_arr_fname)
                img_arr_gnema_t = img_horizon(img_arr_gnema)
                # 4 lighting images and 2 normal ground-truth images in total
                img_arr_in, img_arr_out = [nd.transpose(img_arr_fname, (2, 0, 1)),
                                           nd.transpose(img_arr_gnema, (2, 0, 1))]
                img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                           img_arr_out.reshape((1,) + img_arr_out.shape)]
                img_in_list.append(img_arr_in)
                img_out_list.append(img_arr_out)
                img_arr_in_t, img_arr_out_t = [nd.transpose(img_arr_fname_t, (2, 0, 1)),
                                               nd.transpose(img_arr_gnema_t, (2, 0, 1))]
                img_arr_in_t, img_arr_out_t = [img_arr_in_t.reshape((1,) + img_arr_in_t.shape),
                                               img_arr_out_t.reshape((1,) + img_arr_out_t.shape)]
                img_in_list.append(img_arr_in_t)
                img_out_list.append(img_arr_out_t)
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)],
                             batch_size=batch_size)
def visualize(img_arr):
    # Map a CHW tensor in [-1, 1] back to an HWC uint8 image
    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')

def preview_train_data(train_data):
    img_in_list, img_out_list = train_data.next().data
    for i in range(4):
        plt.subplot(2, 4, i + 1)
        visualize(img_in_list[i])
        plt.subplot(2, 4, i + 5)
        visualize(img_out_list[i])
    plt.show()

train_data = load_retinex(10)  # switch train_data to the MultiPIE relighting pairs
preview_train_data(train_data)
# Define Unet generator skip block
class UnetSkipUnit(HybridBlock):
    def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False,
                 use_dropout=False, use_bias=False):
        super(UnetSkipUnit, self).__init__()
        with self.name_scope():
            self.outermost = outermost
            en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1,
                             in_channels=outer_channels, use_bias=use_bias)
            en_relu = LeakyReLU(alpha=0.2)
            en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels)
            de_relu = Activation(activation='relu')
            de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels)

            if innermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels, use_bias=use_bias)
                encoder = [en_relu, en_conv]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + decoder
            elif outermost:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels * 2)
                encoder = [en_conv]
                decoder = [de_relu, de_conv, Activation(activation='tanh')]
                model = encoder + [inner_block] + decoder
            else:
                de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1,
                                          in_channels=inner_channels * 2, use_bias=use_bias)
                encoder = [en_relu, en_conv, en_norm]
                decoder = [de_relu, de_conv, de_norm]
                model = encoder + [inner_block] + decoder
            if use_dropout:
                model += [Dropout(rate=0.5)]

            self.model = HybridSequential()
            with self.model.name_scope():
                for block in model:
                    self.model.add(block)

    def hybrid_forward(self, F, x):
        if self.outermost:
            return self.model(x)
        else:
            # Skip connection: concatenate the block's input with its output
            return F.concat(self.model(x), x, dim=1)

# Define Unet generator
class UnetGenerator(HybridBlock):
    def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True):
        super(UnetGenerator, self).__init__()
        # Build unet generator structure from the innermost unit outwards
        unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True)
        for _ in range(num_downs - 5):
            unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout)
        unet = UnetSkipUnit(ngf * 8, ngf * 4, unet)
        unet = UnetSkipUnit(ngf * 4, ngf * 2, unet)
        unet = UnetSkipUnit(ngf * 2, ngf * 1, unet)
        unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True)

        with self.name_scope():
            self.model = unet

    def hybrid_forward(self, F, x):
        return self.model(x)

# Define the PatchGAN discriminator
class Discriminator(HybridBlock):
    def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False):
        super(Discriminator, self).__init__()
        with self.name_scope():
            self.model = HybridSequential()
            kernel_size = 4
            padding = int(np.ceil((kernel_size - 1) / 2))
            self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2,
                                  padding=padding, in_channels=in_channels))
            self.model.add(LeakyReLU(alpha=0.2))

            nf_mult = 1
            for n in range(1, n_layers):
                nf_mult_prev = nf_mult
                nf_mult = min(2 ** n, 8)
                self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=2,
                                      padding=padding, in_channels=ndf * nf_mult_prev,
                                      use_bias=use_bias))
                self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
                self.model.add(LeakyReLU(alpha=0.2))

            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n_layers, 8)
            self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=1,
                                  padding=padding, in_channels=ndf * nf_mult_prev,
                                  use_bias=use_bias))
            self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult))
            self.model.add(LeakyReLU(alpha=0.2))
            self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1,
                                  padding=padding, in_channels=ndf * nf_mult))
            if use_sigmoid:
                self.model.add(Activation(activation='sigmoid'))

    def hybrid_forward(self, F, x):
        out = self.model(x)
        #print(out)
        return out
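Before wiring these networks into training loops, it helps to sanity-check their output shapes. The sketch below is illustrative only: it assumes 128x128 RGB crops, matching the MultiPIE loader above, runs on CPU, and netG_test/netD_test are throwaway names, not part of the original code.

# A minimal shape check, assuming 128x128 RGB inputs as used with MultiPIE.
netG_test = UnetGenerator(in_channels=3, num_downs=6)   # six stride-2 downsamplings: 128 -> 2
netD_test = Discriminator(in_channels=6)                # D sees input and output stacked on channels
netG_test.initialize(mx.init.Normal(0.02), ctx=mx.cpu())
netD_test.initialize(mx.init.Normal(0.02), ctx=mx.cpu())
x = nd.random_normal(shape=(1, 3, 128, 128))
y = netG_test(x)
print(y.shape)                                   # (1, 3, 128, 128): same size as the input
print(netD_test(nd.concat(x, y, dim=1)).shape)   # a spatial patch map, not a single scalar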
def param_init(param):
    if param.name.find('conv') != -1:
        if param.name.find('weight') != -1:
            param.initialize(init=mx.init.Normal(0.02), ctx=ctx)
        else:
            param.initialize(init=mx.init.Zero(), ctx=ctx)
    elif param.name.find('batchnorm') != -1:
        param.initialize(init=mx.init.Zero(), ctx=ctx)
        # Initialize gamma from normal distribution with mean 1 and std 0.02
        if param.name.find('gamma') != -1:
            param.set_data(nd.random_normal(1, 0.02, param.data().shape))

def network_init(net):
    with net.name_scope():
        for param in net.collect_params().values():
            param_init(param)

def set_network():
    # Pixel2pixel networks
    netG1 = UnetGenerator(in_channels=3, num_downs=6)
    netD1 = Discriminator(in_channels=6)
    netG2 = UnetGenerator(in_channels=3, num_downs=6)
    netD2 = Discriminator(in_channels=6)
    # Initialize parameters
    network_init(netG1)
    network_init(netD1)
    network_init(netG2)
    network_init(netD2)
    # Trainers for the generators and the discriminators
    trainerG1 = gluon.Trainer(netG1.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    trainerG2 = gluon.Trainer(netG2.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
    return netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2

# Loss
#GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
GAN_loss = gluon.loss.L2Loss()  # squared error on D's output: a least-squares GAN loss
L1_loss = gluon.loss.L1Loss()
L2_loss = gluon.loss.L2Loss()

netG1, netD1, trainerG1, trainerD1, netG2, netD2, trainerG2, trainerD2 = set_network()
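A note on the loss choice above: with GAN_loss set to gluon.loss.L2Loss rather than SigmoidBinaryCrossEntropyLoss, the adversarial part of every training loop below is effectively a least-squares GAN (LSGAN) objective. Reading the real/fake labels as 1 and 0, this corresponds to (an interpretation of the code, not something stated in the original post):

$$\min_D \; \tfrac{1}{2}\,\mathbb{E}_{y}\big[(D(y)-1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{x}\big[D(G(x))^2\big], \qquad \min_G \; \tfrac{1}{2}\,\mathbb{E}_{x}\big[(D(G(x))-1)^2\big]$$

gluon's L2Loss already carries the 1/2 factor, and the (errD_real + errD_fake) * 0.5 in the loops below averages the real and fake terms.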
class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        ret_imgs = []
        for i in range(images.shape[0]):
            image = nd.expand_dims(images[i], axis=0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                ret_imgs.append(image)
            else:
                p = nd.random_uniform(0, 1, shape=(1,)).asscalar()
                if p > 0.5:
                    random_id = nd.random_uniform(0, self.pool_size - 1, shape=(1,)).astype(np.uint8).asscalar()
                    tmp = self.images[random_id].copy()
                    self.images[random_id] = image
                    ret_imgs.append(tmp)
                else:
                    ret_imgs.append(image)
        ret_imgs = nd.concat(*ret_imgs, dim=0)
        return ret_imgs
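This pool is the generated-image history buffer used by SimGAN and CycleGAN: once the buffer is full, each queried image is swapped, with probability 0.5, for an older generated one, so the discriminators also see stale fakes and oscillate less. A minimal usage sketch; the shapes are illustrative assumptions, not values from the original code:

pool = ImagePool(pool_size=4)
for step in range(6):
    fake_batch = nd.random_normal(shape=(2, 3, 8, 8))  # stand-in for a generator output
    mixed = pool.query(fake_batch)
    # Same shape comes back; once the pool is full, roughly half the
    # returned images are older fakes rather than the newest ones.
    assert mixed.shape == fake_batch.shape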

# The Retinex code used below

def singleScaleRetinex(img, sigma):
    retinex = np.log10(img) - np.log10(cv2.GaussianBlur(img, (0, 0), sigma))
    return retinex

def multiScaleRetinex(img, sigma_list):
    retinex = np.zeros_like(img)
    for sigma in sigma_list:
        retinex += singleScaleRetinex(img, sigma)
    retinex = retinex / len(sigma_list)
    return retinex

def colorRestoration(img, alpha, beta):
    img_sum = np.sum(img, axis=2, keepdims=True)
    color_restoration = beta * (np.log10(alpha * img) - np.log10(img_sum))
    return color_restoration

def simplestColorBalance(img, low_clip, high_clip):
    total = img.shape[0] * img.shape[1]
    for i in range(img.shape[2]):
        unique, counts = np.unique(img[:, :, i], return_counts=True)
        current = 0
        for u, c in zip(unique, counts):
            if float(current) / total < low_clip:
                low_val = u
            if float(current) / total < high_clip:
                high_val = u
            current += c
        img[:, :, i] = np.maximum(np.minimum(img[:, :, i], high_val), low_val)
    return img

def MSRCR(img, sigma_list, G, b, alpha, beta, low_clip, high_clip):
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)
    img_color = colorRestoration(img, alpha, beta)
    img_msrcr = G * (img_retinex * img_color + b)
    for i in range(img_msrcr.shape[2]):
        img_msrcr[:, :, i] = (img_msrcr[:, :, i] - np.min(img_msrcr[:, :, i])) / \
                             (np.max(img_msrcr[:, :, i]) - np.min(img_msrcr[:, :, i])) * \
                             255
    img_msrcr = np.uint8(np.minimum(np.maximum(img_msrcr, 0), 255))
    img_msrcr = simplestColorBalance(img_msrcr, low_clip, high_clip)
    return img_msrcr

def automatedMSRCR(img, sigma_list):
    img = np.float64(img) + 1.0
    img_retinex = multiScaleRetinex(img, sigma_list)
    for i in range(img_retinex.shape[2]):
        unique, count = np.unique(np.int32(img_retinex[:, :, i] * 100), return_counts=True)
        for u, c in zip(unique, count):
            if u == 0:
                zero_count = c
                break
        low_val = unique[0] / 100.0
        high_val = unique[-1] / 100.0
        for u, c in zip(unique, count):
            if u < 0 and c < zero_count * 0.1:
                low_val = u / 100.0
            if u > 0 and c < zero_count * 0.1:
                high_val = u / 100.0
                break
        img_retinex[:, :, i] = np.maximum(np.minimum(img_retinex[:, :, i], high_val), low_val)
        img_retinex[:, :, i] = (img_retinex[:, :, i] - np.min(img_retinex[:, :, i])) / \
                               (np.max(img_retinex[:, :, i]) - np.min(img_retinex[:, :, i])) \
                               * 255
    img_retinex = np.uint8(img_retinex)
    return img_retinex

def MSRCP(img, sigma_list, low_clip, high_clip):
    img = np.float64(img) + 1.0
    intensity = np.sum(img, axis=2) / img.shape[2]
    retinex = multiScaleRetinex(intensity, sigma_list)
    intensity = np.expand_dims(intensity, 2)
    retinex = np.expand_dims(retinex, 2)
    intensity1 = simplestColorBalance(retinex, low_clip, high_clip)
    intensity1 = (intensity1 - np.min(intensity1)) / \
                 (np.max(intensity1) - np.min(intensity1)) * \
                 255.0 + 1.0
    img_msrcp = np.zeros_like(img)
    for y in range(img_msrcp.shape[0]):
        for x in range(img_msrcp.shape[1]):
            B = np.max(img[y, x])
            A = np.minimum(256.0 / B, intensity1[y, x, 0] / intensity[y, x, 0])
            img_msrcp[y, x, 0] = A * img[y, x, 0]
            img_msrcp[y, x, 1] = A * img[y, x, 1]
            img_msrcp[y, x, 2] = A * img[y, x, 2]
    img_msrcp = np.uint8(img_msrcp - 1.0)
    return img_msrcp
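The functions above follow the standard multi-scale Retinex recipe: the log of the image minus the log of its Gaussian-blurred version, averaged over several sigmas, with optional color restoration (MSRCR) or intensity-channel processing (MSRCP). A hedged usage sketch; the input path is hypothetical, and the constants are the same defaults this post uses later in land_mark_errs, not values tuned here:

img_bgr = cv2.imread('some_face.png')  # hypothetical input path
sigma_list = [15, 80, 250]
out_msrcr = MSRCR(img_bgr, sigma_list, G=5.0, b=25.0, alpha=125.0,
                  beta=46.0, low_clip=0.01, high_clip=0.99)
out_auto = automatedMSRCR(img_bgr, sigma_list)
out_msrcp = MSRCP(img_bgr, sigma_list, low_clip=0.01, high_clip=0.99)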

# Pre-training

from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def pre_train():
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch in train_data:
            real_in = batch.data[0].as_in_context(ctx)
            real_out = batch.data[1].as_in_context(ctx)
            # Pre-train G1 and G2 with L1 reconstruction losses only
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(fake_out, real_out) * lambda1
                #errG1 = land_mark_errs(real_in, fake_out)
                errG1.backward()
            trainerG1.step(batch.data[0].shape[0])
            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(fake_out2, real_in) * lambda1
                errG2.backward()
            trainerG2.step(batch.data[0].shape[0])
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()
        #fake_img2 = fake_out2[0]
        #visualize(fake_img2)
        #plt.show()

pre_train()
def save_data(path, tpath):
    img_in_list = []
    img_out_list = []
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith('.jpg'):
                continue
            img = os.path.join(path, fname)
            img_arr = mx.image.imread(img).astype(np.float32)/127.5 - 1
            img_arr = mx.image.imresize(img_arr, img_wd * 2, img_ht)
            # Crop input and output images
            img_arr_in, img_arr_out = [mx.image.fixed_crop(img_arr, 0, 0, img_wd, img_ht),
                                       mx.image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            #img_arr_in = mx.image.imresize(img_arr_in, 128, 128)
            #img_arr_out = mx.image.imresize(img_arr_out, 128, 128)
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2, 0, 1)),
                                       nd.transpose(img_arr_out, (2, 0, 1))]
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape),
                                       img_arr_out.reshape((1,) + img_arr_out.shape)]
            img_out = netG1(img_arr_out.as_in_context(ctx))
            img_out1 = img_out[0]
            img_out2 = ((img_out1.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
            save_name = tpath + fname
            # mx.image.imread returns RGB but cv2.imwrite expects BGR, so convert first
            cv2.imwrite(save_name, cv2.cvtColor(img_out2, cv2.COLOR_RGB2BGR))

save_data("../data/edges2handbags/val/", "../data/edges2handbags/G1andG2/")
netD1 = Discriminator(in_channels=6)
netD2 = Discriminator(in_channels=6)
network_init(netD1)
network_init(netD2)
trainerD1 = gluon.Trainer(netD1.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
trainerD2 = gluon.Trainer(netD2.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def dual_pre_train():
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()  # both iterators must be reset, otherwise zip() is empty after epoch 0
        PIE_normal_to_lighting.reset()
        iter = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighing_bad = batch2.data[0].as_in_context(ctx)
            lighing_good = batch2.data[1].as_in_context(ctx)
            with autograd.record():
                fake_out = netG1(real_in)
                #errG1 = L1_loss(real_out, fake_out) + L1_loss(netG1(netG2(fake_out)), real_out)
                errG1 = L1_loss(real_in, fake_out) + L1_loss(netG1(netG2(lighing_good)), lighing_good)
                # Add a third reconstruction loss term
                #errG1 = L1_loss(real_out, fake_out) + L1_loss(netG1(netG2(fake_out)), real_out) \
                #        + L1_loss(netG1(netG2(fake_out)), fake_out)
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])
            with autograd.record():
                fake_out3 = netG2(real_out)
                #errG2 = L1_loss(real_in, fake_out3) + L1_loss(netG2(netG1(fake_out3)), real_in)
                errG2 = L1_loss(lighing_good, fake_out3) + L1_loss(netG2(netG1(real_in)), real_in)
                # Add a third reconstruction loss term
                #errG2 = L1_loss(real_in, fake_out3) + L1_loss(netG2(netG1(fake_out3)), real_in) \
                #        + L1_loss(netG2(netG1(fake_out3)), fake_out3)
                errG2.backward()
            trainerG2.step(batch2.data[0].shape[0])
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

dual_pre_train()
def test_netG(Spath, Tpath):
    for path, _, fnames in os.walk(Spath):
        for fname in fnames:
            if not fname.endswith('.png'):
                continue
            #num = fname[14:16]
            #if num != '07':
            #    continue
            test_img = os.path.join(path, fname)
            img_fname = mx.image.imread(test_img)
            img_arr_fname = img_fname.astype(np.float32)/127.5 - 1
            img_arr_fname = mx.image.imresize(img_arr_fname, 128, 128)
            img_arr_in = nd.transpose(img_arr_fname, (2, 0, 1))
            img_arr_in = img_arr_in.reshape((1,) + img_arr_in.shape)
            img_out = netG1(img_arr_in.as_in_context(ctx))
            img_out = img_out[0]
            #img_out = mx.image.imresize(img_out, 120, 165)
            save_name = Tpath + fname
            plt.imsave(save_name, ((img_out.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))

#test_netG('MultiPIE/MultiPIE_test_128_Gray/', 'MultiPIE/relighting/')
test_netG('MultiPIE/Bio_relighing2/', 'MultiPIE/Bio_color/')

# Using face landmark points as a loss term (via dlib/OpenFace)

fileDir = '/home/hxj/gluon-tutorials/GAN/openface/'
sys.path.append(os.path.join(fileDir))
import argparse
import cv2
import dlib
import matplotlib.pyplot as plt
from pylab import plot
from openface.align_dlib import AlignDlib

modelDir = os.path.join(fileDir, 'models')
openfaceModelDir = os.path.join(modelDir, 'openface')
dlibModelDir = os.path.join(modelDir, 'dlib')
dlibFacePredictor = os.path.join(dlibModelDir, "shape_predictor_68_face_landmarks.dat")

def land_mark_errs(batch1, batch2):
    align = AlignDlib(dlibFacePredictor)
    sum_err = nd.zeros((10)).as_in_context(ctx)
    i = 0
    for (x, y) in zip(batch1, batch2):
        x1 = ((x.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        y1 = ((y.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
        """
        # Face detection with Retinex and LUV fallbacks, disabled in favour of
        # the fixed bounding box below.
        bbx = align.getLargestFaceBoundingBox(x1)
        if bbx is None:
            x1_r = MSRCR(x1, [15, 80, 250], 5.0, 25.0, 125.0, 46.0, 0.01, 0.99)
            bbx = align.getLargestFaceBoundingBox(x1_r)
            if bbx is None:
                lab = cv2.cvtColor(x1, cv2.COLOR_RGB2LUV)
                bbx = align.getLargestFaceBoundingBox(lab)
                #if bbx is None:
                    #print('bbx is none')
        bby = align.getLargestFaceBoundingBox(y1)
        if bby is None:
            y1_r = MSRCR(y1, [15, 80, 250], 5.0, 25.0, 125.0, 46.0, 0.01, 0.99)
            bby = align.getLargestFaceBoundingBox(y1_r)
            if bby is None:
                lab = cv2.cvtColor(y1, cv2.COLOR_RGB2LUV)
                bby = align.getLargestFaceBoundingBox(lab)
                #if bby is None:
                    #print('bby is none')
        if bby is None:
            continue
        if bbx is None:
            #bbx = bby
            continue
        """
        # The crops are already aligned, so use a fixed bounding box instead
        # of running face detection on every image
        bbx = dlib.rectangle(-19, -19, 124, 125)
        bby = dlib.rectangle(-19, -19, 124, 125)
        landmarks_x = nd.array(align.findLandmarks(x1, bbx))
        landmarks_y = nd.array(align.findLandmarks(y1, bby))
        if landmarks_x is None:
            continue
        if landmarks_y is None:
            continue
        sum_err[i] = nd.sum(nd.abs(landmarks_x - landmarks_y)) / 68
        i += 1
    return sum_err
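A minimal sketch of probing this loss on one batch pair, assuming train_data yields (input, target) batches of shape (10, 3, 128, 128) in [-1, 1] as loaded above, and that the dlib predictor file is in place. Note that dlib's landmark detector is not differentiable, so this term supplies a per-image scalar penalty rather than gradients through the generators:

batch = train_data.next()
in_imgs = batch.data[0].as_in_context(ctx)
out_imgs = batch.data[1].as_in_context(ctx)
errs = land_mark_errs(in_imgs, out_imgs)  # mean L1 distance over the 68 landmarks, per image
print(errs)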
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def generate_train_single():
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1 in train_data:
            # Fetch the input and target images from the training batch
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            # G1: compute fake_out inside record() so gradients reach netG1
            with autograd.record():
                fake_out = netG1(real_in)
                errG1 = L1_loss(real_out, fake_out) * 20 + L1_loss(netG1(netG2(real_out)), real_out) * 10
                #       + land_mark_errs(real_in, fake_out) * 0.4
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])
            # G2
            with autograd.record():
                fake_out2 = netG2(real_out)
                errG2 = L1_loss(real_in, fake_out2) * 20 + L1_loss(netG2(netG1(real_in)), real_in) * 10
                #       + land_mark_errs(real_out, fake_out2) * 0.4
                errG2.backward()  # propagate gradients before the update step
            trainerG2.step(batch1.data[0].shape[0])
            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

generate_train_single()
from skimage import io

bgrImg = cv2.imread('CAS/test_aligned_128/FM_000046_IFD+90_PM+00_EN_A0_D0_T0_BW_M0_R1_S0.png')
rgbImg = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2RGB)
plt.imshow(rgbImg)
plt.show()
lab = cv2.cvtColor(bgrImg, cv2.COLOR_BGR2LAB)
plt.imshow(lab)
plt.show()
# Feed only the L (lightness) channel through the generator
img_test = lab[:, :, 0].astype(np.float32)/127.5 - 1
img_test = nd.array(img_test)
img_arr_in = img_test.reshape((1, 1,) + img_test.shape).as_in_context(ctx)
test1 = netG1(img_arr_in)
test2 = test1[0][0]
plt.imshow(((test2.asnumpy() + 1.0) * 127.5).astype(np.uint8), cmap='gray')
plt.show()
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def Dual_train_single():
    image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        iter = 0
        for batch1 in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            # Fetch the input and target images from the training batch
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)

            # D1
            fake_out = netG1(real_in)
            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                output = netD1(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                # Train with real image; real pairs are not pooled, only generated ones
                real_concat = nd.concat(real_in, real_out, dim=1)
                output = netD1(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD1 = (errD_real + errD_fake) * 0.5
                errD1.backward()
                metric.update([real_label, ], [output, ])
            trainerD1.step(batch1.data[0].shape[0])

            # G1: recompute fake_out inside record() so gradients reach netG1;
            # the image pool is only for the discriminator updates
            with autograd.record():
                fake_out = netG1(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1)
                output = netD1(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + \
                #        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                #        L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * 20 + \
                        L1_loss(netG1(netG2(real_out)), real_out) * 10
                #       + land_mark_errs(real_out, fake_out)
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # D2
            fake_out2 = netG2(real_out)
            fake_concat2 = image_pool.query(nd.concat(real_out, fake_out2, dim=1))
            with autograd.record():
                output2 = netD2(fake_concat2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2, ], [output2, ])
                # Train with real image
                real_concat2 = nd.concat(real_out, real_in, dim=1)
                output2 = netD2(real_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)
                errD2 = (errD_real2 + errD_fake2) * 0.5
                errD2.backward()
                metric.update([real_label2, ], [output2, ])
            trainerD2.step(batch1.data[0].shape[0])

            # G2: recompute fake_out2 inside record() so gradients reach netG2
            with autograd.record():
                fake_out2 = netG2(real_out)
                fake_concat2 = nd.concat(real_out, fake_out2, dim=1)
                output2 = netD2(fake_concat2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                #errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * lambda1 + \
                #        L1_loss(netG1(netG2(lighing_good)), lighing_good) * lambda1
                errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * 20 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * 10
                #       + land_mark_errs(real_in, fake_out2)
                errG2.backward()
            trainerG2.step(batch1.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD1).asscalar(), nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD2).asscalar(), nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

Dual_train_single()
from datetime import datetime
import time
import logging

def facc(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()

def train():
    #image_pool = ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        retinex_data.reset()
        PIE_normal_to_lighting.reset()
        iter = 0
        for (batch1, batch2) in zip(retinex_data, PIE_normal_to_lighting):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            # Fetch the input and target images from the training batches
            real_in = batch1.data[0].as_in_context(ctx)
            real_out = batch1.data[1].as_in_context(ctx)
            lighing_bad = batch2.data[0].as_in_context(ctx)
            lighing_good = batch2.data[1].as_in_context(ctx)

            fake_out = netG1(real_in)
            # D1: unconditional discriminator, fed the image alone
            with autograd.record():
                #fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
                #output = netD1(fake_concat)
                output = netD1(fake_out)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([fake_label, ], [output, ])
                # Train with real image
                #real_concat = image_pool.query(nd.concat(real_in, lighing_good, dim=1))
                output = netD1(lighing_good)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD1 = (errD_real + errD_fake) * 0.5
                errD1.backward()
                metric.update([real_label, ], [output, ])
            trainerD1.step(batch1.data[0].shape[0])

            # G1
            with autograd.record():
                #fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
                #output = netD1(fake_concat)
                fake_out = netG1(real_in)
                output = netD1(fake_out)
                real_label = nd.ones(output.shape, ctx=ctx)
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + \
                #        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                #errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                #        L1_loss(netG1(netG2(fake_out)), fake_out) * lambda1
                errG1 = GAN_loss(output, real_label) + L1_loss(real_in, fake_out) * lambda1 + \
                        L1_loss(netG1(netG2(lighing_good)), lighing_good) * lambda1
                errG1.backward()
            trainerG1.step(batch1.data[0].shape[0])

            # D2
            fake_out2 = netG2(lighing_good)
            with autograd.record():
                #fake_concat2 = image_pool.query(nd.concat(lighing_good, fake_out2, dim=1))
                output2 = netD2(fake_out2)
                fake_label2 = nd.zeros(output2.shape, ctx=ctx)
                errD_fake2 = GAN_loss(output2, fake_label2)
                metric.update([fake_label2, ], [output2, ])
                # Train with real image
                #real_concat2 = image_pool.query(nd.concat(lighing_good, real_in, dim=1))
                output2 = netD2(real_in)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                errD_real2 = GAN_loss(output2, real_label2)
                errD2 = (errD_real2 + errD_fake2) * 0.5
                errD2.backward()
                metric.update([real_label2, ], [output2, ])
            trainerD2.step(batch2.data[0].shape[0])

            # G2
            with autograd.record():
                fake_out2 = netG2(lighing_good)
                #fake_concat2 = image_pool.query(nd.concat(lighing_good, fake_out2, dim=1))
                output2 = netD2(fake_out2)
                real_label2 = nd.ones(output2.shape, ctx=ctx)
                #errG2 = GAN_loss(output2, real_label2) + L1_loss(real_in, fake_out2) * lambda1 + \
                #        L1_loss(netG1(netG2(lighing_good)), lighing_good) * lambda1
                errG2 = GAN_loss(output2, real_label2) + L1_loss(lighing_good, fake_out2) * lambda1 + \
                        L1_loss(netG2(netG1(real_in)), real_in) * lambda1
                errG2.backward()
            trainerG2.step(batch2.data[0].shape[0])

            # Print log information every ten batches
            if iter % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator1 loss = %f, generator1 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD1).asscalar(), nd.mean(errG1).asscalar(), acc, iter, epoch))
                logging.info('discriminator2 loss = %f, generator2 loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (nd.mean(errD2).asscalar(), nd.mean(errG2).asscalar(), acc, iter, epoch))
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        # Visualize one generated image for each epoch
        fake_img = fake_out[0]
        visualize(fake_img)
        plt.show()

train()

 

Reposted from: https://www.cnblogs.com/hxjbc/p/9480112.html
