국비지원교육/Python
선형회귀 with googlecolab (1)
HanJW96
2024. 5. 30. 10:50
1. import
import torch
from torch.autograd import Variable
2. Tensor vs Variable
x_tensor = torch.Tensor(3,4)
x_tensor
tensor => Variable로 바꾸기
x_variable = Variable(x_tensor)
x_variable
근데 왜 똑같이 tensor로 나올까?
pytorch 0.4이상 버전에서는 Tensor에 Variable이 통합되고, Variable은 deprecated임
굳이 variable로 바꿀 필요가없다!
# data 속성
x_variable.data
tensor([[1.4729e+07, 4.4386e-41, 1.5362e+07, 4.4386e-41],
[1.5362e+07, 4.4386e-41, 1.5362e+07, 4.4386e-41],
[1.5362e+07, 4.4386e-41, 1.5362e+07, 4.4386e-41]])
# grad 속성 : 값에 대한 gradient
# Variable 생성시 초기화되면서, gradient 도 같이 정의됨
print(x_variable.grad)
None
경사도가 None이 나온 이유 ?
# requires_grad 속성 : 값에 대한 gradient 요구시 사용함
print(x_variable.requires_grad)
False
생성할 때 grad를 계산하도록 해야함. 아무것도 쓰지않으면 requires_grad = false임
다시 만들기
# gradient 에 대한 연산을 수행하게 함 : True
x_variable = Variable(x_tensor, requires_grad=True)
x_variable.requires_grad
3. Graph & Variables
# create graph
x = Variable(torch.FloatTensor(3,4), requires_grad=True)
y = x**2 + 4*x
z = 2*y + 3
x.requires_grad, y.requires_grad, z.requires_grad
backword(gradient, retain_graph, create_graph, retain_variables)??
현재 값 w, r, t, graph 에 대한 gradient 계산 함수임 역전파 알고리즘 적용된 계산함수
위의 z 값으로 x의 gradient를 계산해 냄
loss = torch.FloatTensor(3,4)
z.backward(loss)
print(x.grad)
y.grad, z.grad
<result>
tensor([[ 2.0223e+08, 3.5509e-40, -8.9486e+20, 2.4784e-40],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 5.6052e-44, 6.1642e+32, 5.7718e+23]])
<ipython-input-19-f26e95b4536c>:6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)
y.grad, z.grad
(None, None)