[Pytorch] model.eval()

2021. 7. 27. 17:41Deep Learning/framework

728x90

Pytorch로 구현된 딥러닝모델 코드들을 보면 자주 볼 수 있다

(저는 cyclegan 코드리뷰 중 궁금해졌습니다.)

def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    ...
    ...

nn.Module에서 train time과 eval time에서 수행하는 다른 작업을 수행할 수 있도록 switching 하는 함수

train time과 eval time은 수행하야 하는 작업에 차이가 있다

Dropout이나 BatchNorm같은 것은 train에서만 수행하는 작업이다

.eval()함수는 evaluation과정에서 사용하지 않는 레이어들을 알아서 동작하지 않도록 해준다!

evaluation이나 validation과정에선 보통 model.eval() torch.no_grad()를 함께 사용

 

다시 train mode로 바꿔주려면 .train() 을 수행해줘야겠죠?

model.train()
728x90

'Deep Learning > framework' 카테고리의 다른 글

[Pytorch] torch.utils.data의 Dataset과 DataLoader  (0) 2021.08.01
[Pytorch] torch.nn.Parameter  (0) 2021.07.26
JSON파일 Python 으로 다루기  (0) 2021.07.04
[Pytorch] Variable  (0) 2021.06.19
[pytorch] nn.Upsample  (0) 2021.06.16