[Pytorch] torch.no_grad

2021. 5. 22. 15:15Deep Learning/framework

728x90

기록을 추적하는 것(과 메모리를 사용하는 것)을 방지하기 위해, 코드 블럭을 with torch.no_grad(): 로 감쌀 수 있음.
https://pytorch.org/docs/stable/generated/torch.no_grad.html

 

no_grad — PyTorch 1.8.1 documentation

Shortcuts

pytorch.org


  1. gradient 계산이 되지 않게 하는 Context-manager
  2. gradient calculation을 금지하는 것은 inference에서 유용
  3. (Tensor.backward()를 호출하지 않는다는 가정하에)
  4. 이 context 내부에서 새로 생성된 tensor들은 requires_grad = False상태가 됨
  5. 메모리 사용량을 아껴줌
  6. thread local한 context manager -> 다른 thread의 계산에는 영향 x
  7. 파이썬의 데코레이터로도 사용 가능
>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
728x90