[Pytorch] torch.nn.Parameter

2021. 7. 26. 21:37Deep Learning/framework

728x90

arameter로 간주되는 tensor의 일종

Parameters는 class = torch.Tensor의 subclasses

Module class와 함께하면 특별한 속성을 가진다

Module Attributes가 할당되면 자동으로 Module의 parameter들의 list로 추가되고 Module.Parameters iterator에 추가 된다.

=>Tensor 할당과 다른 효과

e.i. RNN의 last hidden state같은 일부 임시상태를 모델에 cache하려고 할 수 있기 때문

Parameter같은 class가 없다면 이러한 임시적인 것들도 등록되어버릴것이다.

 

ex)

def __init__(self, input_size, hidden_size, correlation_func=1, do_similarity=False):
        super(AttentionScore, self).__init__()
        self.correlation_func = correlation_func
        self.hidden_size = hidden_size

        if correlation_func == 2 or correlation_func == 3:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
            if do_similarity:
                self.diagonal = Parameter(torch.ones(1, 1, 1) / (hidden_size ** 0.5), requires_grad=False)
            else:
                self.diagonal = Parameter(torch.ones(1, 1, hidden_size), requires_grad=True)

        if correlation_func == 4:
            self.linear = nn.Linear(input_size, input_size, bias=False)

        if correlation_func == 5:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)

 

728x90

'Deep Learning > framework' 카테고리의 다른 글

[Pytorch] torch.utils.data의 Dataset과 DataLoader  (0) 2021.08.01
[Pytorch] model.eval()  (0) 2021.07.27
JSON파일 Python 으로 다루기  (0) 2021.07.04
[Pytorch] Variable  (0) 2021.06.19
[pytorch] nn.Upsample  (0) 2021.06.16