Pix2Pix code review 코드리뷰

2021. 4. 8. 14:03Computer vision 논문 & 코드리뷰

728x90

아무래도 Test보다는 Train이 고려할 것이 많다보니 Train 과정 먼저 살펴보겠습니다.

빨간 글씨: 의문점/공부해야할 부분

 

<Train.py>

1. parser에서 여러 옵션들을 parse해오고, dataset을 만들어 줌

opt = TrainOptions().parse()   # get training options
dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset)    # get the number of images in the dataset.
print('The number of training images = %d' % dataset_size)

ㅇ이 과정에서 TrainOptions class는 train에 필요한 옵션들을 가져오는데, BaseOptions class를 상속하며, BaseOptions의 initialize함수로 parser을 초기화 한 후 train에 필요한 추가적 옵션들을 추가한다.

create_dataset으로 옵션에 맞게 dataset을 load한다.

torch.utils.data에 대해 깊게 이해를 하는 시간을 가져야할 듯 하다

<BaseOptions의 initialize>
def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        self.initialized = True
        return parser
        
 <TrainOptions의 initialize>
 def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters
        parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
        parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
        return parser

  2. 이렇게 받아온 옵션들을 넣어 model을 만들어주고, setup(네트워크 로드 및 네트워크 출력, schedulers 생성) 

model = create_model(opt)      # create a model given opt.model and other options
model.setup(opt)               # regular setup: load and print networks; create schedulers
visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots
total_iters = 0                # the total number of training iterations

models/__init__.py에 있는 create_model(opt)는 opt.model의 이름을 갖는 모델(BaseModel의 subclass여야 한다)을 찾아 opt의 옵션들을 넣어 instance를 만들어 반환해줍니다. (저희는 pix2pix모델을 사용할 것인데, parser의 --model의 default값이 cyclegan이네요, 코드를 돌릴 때 pix2pix로 바꿔주어야 하겠습니다)

 

visualizer 는 주석대로 시각화/저장 위한 것이므로 자세한 설명은 하지 않겠습니다.

totla_iters는 일단 0으로 초기화를 해주고, epoch수에 따라 계산해줍니다.

 

3. epoch수 만큼 반복 - 본격적인 training

opt.epoch_count : 시작 epoch 값

opt.n_epochs: initial lr로 진행하는 epoch수

opt.n_epochs_decay : 선형적으로 lr을 0으로 감소시킬 epoch수 (???)

    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch
        model.update_learning_rate()    # update learning rates in the beginning of every epoch.
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))

<models/pix2pix_model.py>

본격적으로 pix2pix를 구현한 코드를 볼 수 있는데요, /models/base_model.pyBaseModel을 상속받고 있는 class로 구현되어 있습니다.

class Pix2PixModel(BaseModel):

BaseModel은 abc(abstract base class)로, 

<set_input>: unpack data from dataset and apply preprocessing.
<forward>: produce intermediate results.
<optimize_parameters>: calculate losses, gradients, and update network weights.

이 세가지 메소드가 @abstractmethod가 subclass가 꼭 구현해야 하는 메소드 이며.

주석에서는 subclass를 생성시 __init__에서 BaseModel.__init__(self, opt)를 꼭 처음에 call해주어야 한다고 하고 있다.

(modify_commandline_options 는 optional 이지만 pix2pixModel 클래스에서는 구현되어있음)


: : 주석 내용 : : 

주어진 paired data에서 input images와 output images를 mapping하는 것을 배우는 pix2pix model을 구현한 클래스이며, training 시 '--dataset_mode aligned' dataset 을 필요로 한다고 합니다.

● default

  1. U-Net getnerator('--netG unet256') 
  2. PatchGAN discriminator('--netD basic')를 사용
  3. '--gan_mode' : vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). (vanilla GAN Loss에 대해 공부하기)

1. 일단 __init__함수에서, 위에서 말한 것과 같이 BaseModel의 init을 한번 호출해 줍니다. 이후 pix2pix에 맞게 인스턴스 변수들을 지정해줍니다. 

 

## <BaseModel>의 __init__ ##
def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'
        
## BaseModel.__init__(self, opt)호출 이후의 Pix2PixModel의 __init__ ##
		# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

보시면 BaseModel의 init으로 옵션으로 옵션들, 사용할 device, 저장할 위치 등 기본적인 내용들을 인스턴스 변수로 지정하고, Pix2Pix init에서 loss, optimizer 등의 모델마다 다르게 지정해줘야 하는 것들을 인스턴스 변수로 지정해줍니다.

 

self.loss_names 말 그대로 loss들의 이름을 넣어주겠죠 G_GAN은 generator의 loss 에서 cGAN loss 부분을, 
G_L1은 generator의 loss에서 L1 loss 부분을 말합니다.

D_real과 D_fake는 Discriminator가 real/fake를 판단하면서 계산되는 loss라고 생각했는데, 왜 두개로 나뉘는지 이해가 가지 않았음
: pix2pix_model.py의 backward_D(self) 함수를 보고 해결!
A->B변환에서 fake인 B를 넣었을 때와 real인 B를 넣었을 때를 각 loss를 구해 합쳐서 1/2해 준 것이 최종 D의 loss
self.isTrain은 train인지 /test인지 결정해주기 위한 인스턴스 변수인데, train.py에서 True로 정해주는 걸 볼 수 있습니다.


self.isTrain=True

self.model_names에 'G', 'D'모두 넣어줍니다.
self.netD에서 discriminator를 정의해줍니다.(cGAN: input/output이미지 모두 필요)

models/network.py의 define_G, define_D로 Generator와 Discriminator를 정의해 주는 부분 도 볼 수 있습니다.

self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
        self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                      opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

<network.py define_G, define_D>

define_G

input_nc (int) -- 인풋채널수 output_nc (int) -- 아웃풋채널수 ngf (int) -- last conv layer filter 수 (default:64)
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
(default: resnet_9blocks(???pix2pix.ipynb에서 따로 고쳐주는 부분이 없네요 unet을 사용해야하는 것이 아닌지?))
norm (str) -- the name of normalization layers used in the network: batch | instance | none

(defualt: instance (normalization))
use_dropout (bool) -- if use dropout layers.

pix2pix에서 not opt.no_dropout을 넣어주네요! 실행시 따로 지정해주지 않으면 dropout을 적용하지 않는 것 같습니다.
init_type (str) -- the name of our initialization method(공부필요)

opt.init_type은  [normal | xavier | kaiming | orthogona] 중 하나로 정할 수 있는데 default값은 normal입니다.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.

(default: 0.02)
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

 

def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
   
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)

get_norm_layer함수 내에서는 다음과 같이 normalization을 수행해주네요, functools를 사용합니다.

elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

 

ResnetGenerator, UnetGenerator를 사용해 network를 만들어주는데요(resnet은 일반 인코더디코더구조와 비교하기위햐 넣어줌), UnetGenerator만 살펴보겠습니다.

대부분 define_G를 실행할 때 넣어준 것이지만, num_downs만 살펴보면 되겠네요. unet128 : 7, unet_256 : 8 을 넣어주네요. 

num_downs가 7일 때 128x128이미지가 bottleneck에서 1x1이 된다고 나와 있네요. 다음 그림을 참조하면 좋을 것 같습니다.

https://www.researchgate.net/figure/Pix2pix-UNet-128-GAN-network-architecture_fig2_324962509

전형적인 nn.Module을 상속한 pytorch의 neural network 구현입니다.

이 내에서 사용 되는 UnetSkipConnectionBlock또한 network.py내의 class(nn.Module상속)입니다. skip connection을 가진 unet submodule을 정의한 class인데요, 자세한 코드 내용은 github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 를 참고하시면 됩니다.


class UnetSkipConnectionBlock(nn.Module):
...
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
...

innermost, outermost-> 가장 안쪽, 가장 바깥쪽 layer

outer_nc (int) -- the number of filters in the outer conv layer
inner_nc (int) -- the number of filters in the inner conv layer

inner_nc , outer_nc를 인풋채널수, 아웃풋 채널수와 헷갈리지 말기!

 

input_nc==None이면 input_nc를 outer_nc와 갖게 해주네요

downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                         stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)

위와 같이 Unet에서의 conv, relu, norm을 설정해줍니다.

 

- outermost와 innermost 비교 (encoder부분에서 input쪽이 outermost bottleneck쪽이 innermost)

-innermost도, outtermost도 아닌 경우 (default)

dropout은 논문에서 0.5로 사용

 else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

이후 

 self.model = nn.Sequential(*model)

def forward(self,x) 에서도 outermost일때와 아닐 때의 차이가 있네요 -> 아닐 때 skip connection 적용

def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)

UnetGenerator

class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

UnetSkipConnectionBlock의 inner_nc, outer_nc에 ngf * 8, ngf * 4 ..... -> Unet 공부 필요

이렇게 Unet을 구현하는데, innermost ~> outermost 순으로 구현하는 것이 보이네요.

submodule에 이전에 정의한 unet_block을 넣어주며 점점 쌓아줍니다.

inner_nc와 outer_nc(필터수)를 ngf*x에서 ngf까지 점점 줄여주는 것을 볼 수 있네요, 마지막 UnetSkipConnectionBlock을 추가할 땐, outer_nc가 output_nc입니다.(outerlayer의 필터수가 output의 채널 수가 됩니다.)

............................  ......................................

define_D

discriminator을 생성합니다.

input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the first conv layer netD (str) -- the architecture's name: basic | n_layers | pixel

basic이 default patchgan classifier
pixel이 논문에서의 pixelgan이라 보면 되겠다
n_layer?
n_layers_D (int) -- the number of conv layers in the discriminator;
effective when netD=='n_layers'
norm (str) -- the type of normalization layers used in the network. init_type (str) -- the name of the initialization method.
(normal/xavier/orthogonal)
init_gain (float) -- scaling factor for normal, xavier and orthogonal. gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2  

 

def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):

    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':     # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)

pixceldiscriminator은 비교를 위한 것이므로, basic, n_layers 의 경우에 사용 되는 NLayerDiscriminator 클래스를 살펴뵤자

<models/network.py NLayerDiscriminator>

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
            
            ... (생략)
            
    def forward(self, input):
        """Standard forward."""
        return self.model(input)

전체적인 흐름을 걸명하자면, 일단 batchnorm일때는 bias를 사용하지 않는데(convolution할 때), 주석에 나온 것 처럼 효율성 측면에서 그렇게 설정한 것이고, 

이후 

kw에 kernel size 4 를 지정하고 padw 에 padding 1을 지정하여 계속 사용해주는데,

sequence라는 리스트에 conv layer, ReLU layer, normalizatino layer를 계속 추가해줍니다.

(하핳....대충 적은걸 일단 첨부...원래 글씨 이뻐욤,..,.,ㅎ허)


다시 pix2pix_model.py로 돌아와서, 

        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

이 부분은 loss와 optimizer 설정하는 부분이네요.

networks.GANLoss에선 세가지 gan mode(lsgan, vanilla, wgangp)로 self.loss라는 인스턴트 변수에 loss를 구현하여줍니다. lsgan 이 우리가 사용할 cgan loss겠죠

논문에 나온대로 adam optimizer을 사용하네요.

 

set_input 함수와 forward함수는 간단히 살펴봅니다.

set_input 함수에선 어느방향으로 변환할 것인지 방향을 받아와서, self.real_A는 AtoB일때 A를, self.real_B는 AtoB일 때 B인풋을 말하며, self.image_paths엔 이제 경로를 넣어주는데, set_input함수에 들어가는 input은 dataloader로 생성된 인풋으로, dict형태라서, A, B, A_paths, B_paths등의 key로 데이터를 불러옵니다.(데이터 그 자체와 메타데이터 정보를 가지고 있죠)

 

forward함수는 당연히 netG에 real_A를 넣은 fake_B를 만들어냅니다.

 

이제 중요한 backpropagation과 관련된 부분인데요, 

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

input과 output을 모두 고려하는 cgan이니까, fake_AB라는 변수에 real_A와 fake_B를 합쳐서(concat) 사용해주네요.

- backward_D

netD에 fake_AB 넣어 fake인지 예측해보고, cganloss에 pred_fake와 이미지가 fake임을 알려주는 False를 넣어 self.loss_D_fake(fake이미지 넣었을 때의 loss) 생성하고, 

real_AB는 real이미지들을 concat해 cgan loss를 계산하여 self.loss_D_real를 생성하는데, 

이 두 loss를 1/2하여 최종 D의 loss인 loss_D 계산하며 loss_D.backward()로 backpropagation을 진행합니다.

-backward_G

fake_AB생성은 똑같고, netD에 fake_AB를 넣어 fake인지 예측한 후 cgan loss에 pred_fake를 넣어(이미지가 real이라고 말해줘야함-True) 계산해서 D를 속였는지 확인하고(속였다면 낮은 값이 나오겠죠?)

loss_G_L1은 우리가 논문에서 fake_B, real_B를 모두 넣어 L1 loss 계산한 수 lambda를 곱해준 L1term이네요

이 둘을 더해주어 self.loss_G라는 최종 G의 loss가 나오고, self.loss_G.backward()로 backpropagation을 진행해줍니다.

더보기

(netD의 weight들(model parameters)의 requires_grad에 아무런 영향을 주지 않으며 fake_AB 본인의 requires_grad를 False로 만듦-> fake 본인의 grad는 여전히 None이고, detach()에 의해 grad_fn도 None으로 바뀌었기 때문에 netG의 weight들로 backpropagation도 할 수 없다.->  netD만 학습하려는 것 (GAN이 아니라 일반적인 one stage Neural Network에서도 input의 requires_grad가 True일 필요는 전혀 없다. 우리가 학습하려는 것은 오직 model parameters, 즉 weight들이기 때문에 weight Tensor의 requires_grad만 True이면 된다)

 

 

optimize_parameters함수에선 D와 G를 update 하는 과정을 구현하고 있습니다.

    def optimize_parameters(self):
        self.forward()                   # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        self.backward_D()                # calculate gradients for D
        self.optimizer_D.step()          # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        self.backward_G()                   # calculate graidents for G
        self.optimizer_G.step()             # udpate G's weights
728x90