티스토리 뷰
텐서(Tensor)
흔히 수학에서 사용하는 벡터(1차원), 행렬(2차원)에 이어 데이터를 3차원단위로 묶어 표현할 때 텐서를 사용한다.
즉, 3차원 행렬이라고 보면 된다. 3개의 인덱스가 [세로, 가로, 깊이] 순으로 특정 텐서안의 요소를 지정하는데 사용한다.
예를 들어 이미지인 경우 [batch, width, height] 자연어 처리의 경우 [batch, length, dim]을 의미한다.
Numpy & PyTorch
Numpy 배열과 PyTorch 텐서 비교
NumPy | PyTorch | |
선언 | array | FloatTensor |
차원 | ndim | dim() |
크기 | shape | size() |
초기 선언시에 배열에 선언해 넣는 방법은 동일하다. 사용하기 직관적이고 NumPy와 비슷하다.
Broadcasting
행렬 연산에서 사이즈를 맞추어서 자동으로 만들어서 사용하는 기능을 뜻한다.
다음과 같이 텐서 m1은 1x2크기의 벡터이고 m2는 두 개의 스칼라값 이다.
m1 = torch.FloatTensor([[2,3]])
m2 = torch.FloatTensor([[3],[4]])
print(m1 + m2)
결과는 다음과 같이 텐서들은 2x2로 만들어서 수행된 결과를 보여준다.
tensor([[5., 6.],
[6., 7.]])
매트릭스 곱 & 산술곱
m1.matmul(m2): 매트릭스 곱하기(Matrix Multiplication)
m1.mul(m2): 산술연산(Element-wise Multiplication)
평균(mean)
t.mean(): 전체 평균
t.mean(dim=0): 첫번째 차원단위 평균
t.mean(dim=1): 두번째 차원단위 평균
t = torch.FloatTensor([[1, 2], [3, 4]])
print(t)
print(t.mean())
print(t.mean(dim=0))
print(t.mean(dim=1))
print(t.mean(dim=-1))
실행결과
tensor(2.5000)
tensor([2., 3.])
tensor([1.5000, 3.5000])
tensor([1.5000, 3.5000])
합계(sum)
t.sum()
t.mean(dim=0)
t.mean(dim=1)
최대(Max, ArgMax)
t.max(): 전체 텐서 데이터 중 최대 값만 인출
t.max(dim=0): 특정 차원단위 최대값(max)와 최대값을 갖는 데이터 인덱스 인출
차원을 지정하면 두개의 데이터를 뽑을 수 있다. t.max(dim=0)[0]이 최대값(max)를 나타내고 t.max(dim=0)[1]이 최대값을 갖는 데이터 인덱스를 나타내게 된다.
t = torch.FloatTensor([[1, 2], [3, 4]])
print(t)
print(t.max())
print(t.max(dim=0)) # Returns two values: max and argmax
print('Max: ', t.max(dim=0)[0])
print('Argmax: ', t.max(dim=0)[1])
tensor([[1., 2.],
[3., 4.]])
tensor(4.)
(tensor([3., 4.]), tensor([1, 1]))
Max: tensor([3., 4.])
Argmax: tensor([1, 1])
Reference
https://jhui.github.io/2018/02/09/PyTorch-Basic-operations/
'CVML > PyTorch' 카테고리의 다른 글
Detectron2 Docker환경상에서 설치 후 데모 실행하기 (0) | 2020.10.29 |
---|---|
PyTorch로 Convolution Layer 출력 크기 계산해보기 (0) | 2019.06.08 |
모두를 위한 딥러닝 시즌2 - PyTorch 링크 (0) | 2019.04.01 |
도커로 PyTorch 학습환경 구축하기 (0) | 2019.04.01 |
- Total
- Today
- Yesterday
- docker
- tensorflow
- nvidia
- CAD
- Maker
- 메이커
- Fusion360
- 우분투
- conda
- Streamlit
- Python
- Stable Diffusion
- WSL
- 3d프린터
- ssh
- cura
- 한글
- 단축키
- vvvv
- ubuntu
- git
- Arduino
- nodejs
- MicroBit
- opencv
- 파이썬
- comfyUI
- fablab
- vscode
- Linux
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | 31 |