Pytorch

[파이토치] 기본2

왕초보코딩러 2024. 8. 1. 20:44
728x90

tensor shape 변환

요소의 수를 맞추는 것이 중요!! 안 맞추면 에러

-1을 이용하여 알아서 계산하게 할 수 있다

1. view()

# 3X3X3
data = torch.rand(3, 3, 3) # 총 요소 수 27개

# 3X9로 바꾸기
data.view(3, 9)

# 행은 3으로, 열은 알아서 계산
data.view(3, -1)

# 1차원으로
data.view(-1)

2. reshape()

# 3X3X3
data = torch.rand(3, 3, 3) # 총 요소 수 27개

# 3X9로 바꾸기
data.reshape(3, 9)

# 행은 3으로, 열은 알아서 계산
data.reshape(3, -1)

# 1차원으로
data.reshape(-1)

 

tensor의 형태 확인

1. shape 확인

data.shape

2. 차원 확인

data.dim()

3. 데이터 타입 확인

data.dtype

4. 데이터 타입 변경: type()

# 바꿀 타입 쓰기
data.type(torch.float32)

 

tensor 연산

+, -, *, / 가능

 

최소값, 최대값, 평균 등

# 최소값
data.min() # 하나 리턴

data.min(dim=0)
# dim=0: 각 행의
# values: 최소값
# indices: 몇 번째 인덱스에 있는지(열 인덱스)


# 최대값
data.max() # 하나 리턴

data.max(dim=1)
# dim=1: 각 열의
# values: 최소값
# indices: 몇 번째 인덱스에 있는지(행 인덱스)

 

 

인덱싱과 슬라이싱

data = torch.tensor([[1,2,3],[4,5,6]])

data[0] tensor([1, 2, 3])

data[0][1] # tensor(2)

data[:2, :2] # 행, 열 tensor([[1, 2], [4, 5]])

 

차원 추가, 삭제

추가: unsqueeze()

# 배치 차원 붙여야 할 때: unsqueeze (자주 쓰임)
# 앞에 차원 1 붙여준다
data = torch.rand(3, 128, 128)
data.unsqueeze(dim=0).shape       # torch.Size([1, 3, 128, 128])


data = torch.rand(3, 128, 128)
data.unsqueeze(dim=1).shape       # torch.Size([3, 1, 128, 128])

 

삭제: squeeze()

# 차원이 1인 게 없으면 -> 변화 없음
data = torch.rand(128, 2, 128)
data.squeeze().shape                        # torch.Size([128, 2, 128])

# 차원이 1인 게 있으면 -> 하나 삭제
data = torch.rand(1, 128, 128)
data.squeeze().shape                        # torch.Size([128, 128])

 

 

데이터 합치기

cat(): 합치려는 차원 외의 shape 값이 같아야 한다

# error 합치려는 차원 외 shape 값이 같아야 함
data1 = torch.rand(1,4,2)
data2 = torch.rand(3,4,2)

result = torch.cat((data1, data2), dim=1) # 1차원 제외 값이 같아야 함


result = torch.cat((data1, data2), dim=0) # 0차원 제외 값이 같아야 함
result.shape                              # torch.Size([4, 4, 2])

 

 

데이터 쌓기

stack(): 모든 shape값이 같아야 한다

data1 = torch.rand(3, 128, 128)
data2 = torch.rand(3, 128, 128)

torch.stack([data1, data2], dim=0).shape  # torch.Size([2, 3, 128, 128])

torch.stack([data1, data2], dim=1).shape  # torch.Size([3, 2, 128, 128])

'Pytorch' 카테고리의 다른 글

[Pytorch] 모델 저장하기, 불러오기  (0) 2024.07.23
[Pytorch] 텐서(Tensor) 자료형  (0) 2024.07.16