Pytorch

[Pytorch] 모델 저장하기, 불러오기

왕초보코딩러 2024. 7. 23. 15:42
728x90

torch 모델을 학습시키고 저장하거나 불러오는 방법입니다

 

필요 라이브러리 임포트

import torch

 

 

torch 모델 저장

torch.save()

 

두 가지 방법이 있습니다.

1. 모델 전체 저장하기(용량이 크다)

2. 모델 가중치만 저장하기(용량이 작다)

 

1. 모델 전체 저장하기

# 모델, 저장할 이름.pth
torch.save(model, 'model_all.pth')

 

2. 모델 가중치만 저장하기(권장)

# 모델 가중치, 저장할 이름.pth
torch.save(model.state_dict(), 'model_weight.pth')

 

 

 

torch 모델 불러오기

torch.load()

두 방법 모두 모델 구조를 정의하는 클래스가 필요합니다!

 

 

모델 전체를 저장한 pth 파일 불러오기

# 모델 저장 이름
model = torch.load('model_all.pth')

 

 

모델 가중치만 저장한 pth 파일 불러오기

# 클래스로 모델 정의
model = 클래스()
# 모델 저장 이름
model.load_state_dict(torch.load('model_all.pth'))

 

 

 


예를 들어, 제가 torch에 내장된 resnet 모델을 Feature Extraction하겠습니다.

 

필요 라이브러리 임포트

import torch
import torch.nn as nn
from torchvision import models

 

모델 구조 정의

model = models.resnet32(pretraind=True)

for param in model.parameters():
  param.requires_grad=False # 프리징
  
model.fc = nn.Linear(512, 3) # 분류기 부분만 바꿈

 

모델을 학습했다고 치고 저장해보겠습니다.

# resnet34를 사용하여 3개의 카테고리 분류하는 모델
torch.save(model, 'resnet34_3.pth')

 


다른 곳에서 모델 불러오기

 

필요 라이브러리 임포트

import torch
import torch.nn as nn
from torchvision import models

 

모델 구조 정의

load_model = models.resnet32(pretraind=True)
load_model.fc = nn.Linear(512, 3)

load_model.load_state_dict(torch.load('resnet34_3.pth'))

 

이제 model을 다시 사용할 수 있습니다!