안녕하세요!
오늘은 컴퓨터 비전 분야의 패러다임을 바꾼 Vision Transformer(ViT)의 핵심 개념과 이를 응용하여 여러 개의 레이블을 한 번에 예측하는 Multi-Branch Classifier를 구현한 과정을 정리해 보겠습니다.
Vision Transformer(ViT)
ViT는 "An Image is Worth 16x16 Words"라는 논문 제목처럼, 이미지를 마치 문장(Sequence)처럼 처리하는 모델입니다. 기존의 CNN이 이미지의 국소적인 특징(Local Feature)을 추출하는 데 집중했다면, ViT는 이미지 전체의 관계를 한 번에 파악하는 Global Context 학습에 강점이 있습니다.
핵심 구조 및 작동 원리
- Patch Partition & Embedding: 이미지를 고정된 크기(예: 16x16)의 패치로 나눕니다. 각 패치를 1차원으로 펼친(Flatten) 후 선형 변환을 거쳐 Patch Token으로 만듭니다.
- CLS Token & Positional Embedding:
- 이미지 전체의 정보를 요약하기 위해 학습 가능한 [CLS] 토큰을 시퀀스 맨 앞에 추가합니다.
- Transformer는 위치 정보를 알지 못하므로, 각 패치의 위치 정보를 더해주는 Positional Embedding을 수행합니다.
- Transformer Encoder: Self-Attention 메커니즘을 통해 이미지 내 모든 패치 간의 상관관계를 계산합니다. 이를 통해 멀리 떨어진 픽셀 간의 연관성도 효과적으로 학습합니다.
- MLP Head: 최종적으로 [CLS] 토큰의 출력값을 사용하여 이미지를 분류합니다.
1. PyTorch로 이해하는 ViT 전처리 (Code Review)
이미지 데이터를 Transformer가 이해할 수 있는 토큰 형태로 만드는 과정이 가장 중요합니다.
# 8x8 이미지를 4x4 패치로 나누는 과정 예시
image = np.random.rand(8, 8, 3).astype(np.float32)
# 1. 패치 단위로 reshape (2x2개의 패치 생성)
image = image.reshape(2, 4, 2, 4, 3)
# 2. 차원 순서 변경 (패치행, 패치열, H, W, C)
image = image.transpose(0, 2, 1, 3, 4)
# 3. 최종 패치 리스트 생성
patches = image.reshape(-1, 4, 4, 3) # (4, 4, 4, 3)
이후 각 패치는 Linear Layer를 통과하여 embedding_dim 차원의 벡터로 투영(Projection)됩니다.
2. 실습: Multi-Branch ViT 모델 설계
이번 실습에서는 하나의 이미지에서 여러 개의 숫자 레이블(예: 4자리 숫자 인식)을 동시에 예측하기 위해 Multi-Branch 구조를 사용했습니다.
모델 아키텍처 (timm 라이브러리 활용)
timm 라이브러리의 vit_base_patch16_224 모델을 Backbone으로 사용하고, 상단에 독립적인 4개의 Classifier(Branch)를 달아 각 자릿수를 예측하도록 설계했습니다.
class ViTMultiBranchClassifier(nn.Module):
def __init__(self, num_classes, num_branches, model_name='vit_base_patch16_224', pretrained=True):
super().__init__()
# Pretrained ViT 로드 (Feature Extractor로 사용)
self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
# 각 자릿수를 예측할 독립적인 Head들
self.branches = nn.ModuleList([
nn.Sequential(
nn.Linear(self.vit.embed_dim, 256),
nn.GELU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
) for _ in range(num_branches)
])
def forward(self, x):
features = self.vit(x) # 공통 특징 추출
outputs = [branch(features) for branch in self.branches] # 각 Branch별 예측
return outputs
3. 데이터 증강 및 학습 전략
성능을 높이기 위해 Albumentations 라이브러리를 활용하여 강력한 데이터 증강(Augmentation)을 적용했습니다.
- ShiftScaleRotate: 이미지를 랜덤하게 회전하고 이동시켜 위치 변화에 강인하게 만듭니다.
- LongestMaxSize & PadIfNeeded: 이미지의 비율을 유지하면서 모델 입력 크기(224x224)에 맞게 패딩을 채웁니다.
- Normalize: ImageNet 데이터셋의 평균과 표준편차를 사용하여 정규화합니다.
주의점: ToTensorV2는 torchvision.transforms의 ToTensor와 달리 0~1 사이로 스케일링을 자동으로 해주지 않으므로, 반드시 그전에 Normalize를 수행해야 합니다.
4. 학습 및 결과 확인
학습 시 Loss는 각 Branch에서 나온 CrossEntropyLoss의 합으로 계산하였습니다.
- Optimizer: Adam (Learning Rate: 0.0001)
- Evaluation: 각 Branch별로 Accuracy를 측정하고, 가장 높은 성능을 보인 모델을 저장했습니다.

최종 결과 시각화
학습된 모델로 검증 데이터를 예측한 결과, 여러 자리의 숫자를 정확하게 맞추는 것을 확인할 수 있었습니다. denormalize 과정을 거쳐 원래의 이미지를 복원하고 예측값(Pred)과 함께 출력하여 직관적으로 성능을 검토했습니다.

5. 요약 및 결론
- ViT는 이미지를 패치 단위 시퀀스로 변환하여 Transformer의 Self-Attention을 적용한 모델입니다.
- 데이터가 많을수록 CNN보다 뛰어난 성능을 보이며, 패치 크기 조절을 통해 유연한 대응이 가능합니다.
- Multi-Branch 구조를 통해 하나의 Backbone 네트워크로 여러 개의 태스크(여러 숫자 인식 등)를 동시에 수행할 수 있음을 확인했습니다.
'개념 정리 step2 > 멀티모달(Multi-modal)' 카테고리의 다른 글
| [멀티 모달] 오토인코더(Autoencoder): 비지도 학습과 생성 모델의 기초 (0) | 2026.01.30 |
|---|---|
| [Deep Learning] 동영상 데이터 분석: 3D CNN과 수화 인식 실습 (0) | 2026.01.29 |
| [NLP] KLUE-BERT 기반 멀티레이블 혐오 표현 분류 실습 (0) | 2026.01.22 |
| [Deep Learning] 트랜스포머(Transformer): NLP 아키텍처 정리 (0) | 2026.01.19 |
| [딥러닝 NLP] NLU에서 트랜스포머 어텐션까지 핵심 개념 정리 (1) | 2026.01.16 |