[Pytorch] torch.no_grad
2021. 5. 22. 15:15ㆍDeep Learning/framework
728x90
기록을 추적하는 것(과 메모리를 사용하는 것)을 방지하기 위해, 코드 블럭을 with torch.no_grad(): 로 감쌀 수 있음.
https://pytorch.org/docs/stable/generated/torch.no_grad.html
no_grad — PyTorch 1.8.1 documentation
Shortcuts
pytorch.org
- gradient 계산이 되지 않게 하는 Context-manager
- gradient calculation을 금지하는 것은 inference에서 유용
- (Tensor.backward()를 호출하지 않는다는 가정하에)
- 이 context 내부에서 새로 생성된 tensor들은 requires_grad = False상태가 됨
- 메모리 사용량을 아껴줌
- thread local한 context manager -> 다른 thread의 계산에는 영향 x
- 파이썬의 데코레이터로도 사용 가능
>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
728x90
'Deep Learning > framework' 카테고리의 다른 글
[Pytorch] Tensor 의 .view()메소드 / .reshape()메소드 (0) | 2021.05.24 |
---|---|
[Pytorch] torch.sort (0) | 2021.05.22 |
[PyTorch] tensor 부등호 연산 (?), list와 차이점 (0) | 2021.05.13 |
[Pytorch] gpu가 인식이 안 되는 오류 - Linux 에서 Nvidia driver설치 (0) | 2021.04.05 |
torch.backends.cudnn.benchmark = True (0) | 2021.03.30 |