[Deep Learning] Vision Transformer(ViT) Multi-Branch 구현 실습

2026. 1. 23. 22:55·개념 정리 step2/멀티모달(Multi-modal)

안녕하세요!

오늘은 컴퓨터 비전 분야의 패러다임을 바꾼 Vision Transformer(ViT)의 핵심 개념과 이를 응용하여 여러 개의 레이블을 한 번에 예측하는 Multi-Branch Classifier를 구현한 과정을 정리해 보겠습니다.


Vision Transformer(ViT)

ViT는 "An Image is Worth 16x16 Words"라는 논문 제목처럼, 이미지를 마치 문장(Sequence)처럼 처리하는 모델입니다. 기존의 CNN이 이미지의 국소적인 특징(Local Feature)을 추출하는 데 집중했다면, ViT는 이미지 전체의 관계를 한 번에 파악하는 Global Context 학습에 강점이 있습니다.

핵심 구조 및 작동 원리

  1. Patch Partition & Embedding: 이미지를 고정된 크기(예: 16x16)의 패치로 나눕니다. 각 패치를 1차원으로 펼친(Flatten) 후 선형 변환을 거쳐 Patch Token으로 만듭니다.
  2. CLS Token & Positional Embedding:
    • 이미지 전체의 정보를 요약하기 위해 학습 가능한 [CLS] 토큰을 시퀀스 맨 앞에 추가합니다.
    • Transformer는 위치 정보를 알지 못하므로, 각 패치의 위치 정보를 더해주는 Positional Embedding을 수행합니다.
  3. Transformer Encoder: Self-Attention 메커니즘을 통해 이미지 내 모든 패치 간의 상관관계를 계산합니다. 이를 통해 멀리 떨어진 픽셀 간의 연관성도 효과적으로 학습합니다.
  4. MLP Head: 최종적으로 [CLS] 토큰의 출력값을 사용하여 이미지를 분류합니다.

1. PyTorch로 이해하는 ViT 전처리 (Code Review)

이미지 데이터를 Transformer가 이해할 수 있는 토큰 형태로 만드는 과정이 가장 중요합니다.

# 8x8 이미지를 4x4 패치로 나누는 과정 예시
image = np.random.rand(8, 8, 3).astype(np.float32)

# 1. 패치 단위로 reshape (2x2개의 패치 생성)
image = image.reshape(2, 4, 2, 4, 3) 
# 2. 차원 순서 변경 (패치행, 패치열, H, W, C)
image = image.transpose(0, 2, 1, 3, 4) 
# 3. 최종 패치 리스트 생성
patches = image.reshape(-1, 4, 4, 3) # (4, 4, 4, 3)

이후 각 패치는 Linear Layer를 통과하여 embedding_dim 차원의 벡터로 투영(Projection)됩니다.


2. 실습: Multi-Branch ViT 모델 설계

이번 실습에서는 하나의 이미지에서 여러 개의 숫자 레이블(예: 4자리 숫자 인식)을 동시에 예측하기 위해 Multi-Branch 구조를 사용했습니다.

모델 아키텍처 (timm 라이브러리 활용)

timm 라이브러리의 vit_base_patch16_224 모델을 Backbone으로 사용하고, 상단에 독립적인 4개의 Classifier(Branch)를 달아 각 자릿수를 예측하도록 설계했습니다.

class ViTMultiBranchClassifier(nn.Module):
    def __init__(self, num_classes, num_branches, model_name='vit_base_patch16_224', pretrained=True):
        super().__init__()
        # Pretrained ViT 로드 (Feature Extractor로 사용)
        self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        
        # 각 자릿수를 예측할 독립적인 Head들
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.vit.embed_dim, 256),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_classes)
            ) for _ in range(num_branches)
        ])

    def forward(self, x):
        features = self.vit(x) # 공통 특징 추출
        outputs = [branch(features) for branch in self.branches] # 각 Branch별 예측
        return outputs

3. 데이터 증강 및 학습 전략

성능을 높이기 위해 Albumentations 라이브러리를 활용하여 강력한 데이터 증강(Augmentation)을 적용했습니다.

  • ShiftScaleRotate: 이미지를 랜덤하게 회전하고 이동시켜 위치 변화에 강인하게 만듭니다.
  • LongestMaxSize & PadIfNeeded: 이미지의 비율을 유지하면서 모델 입력 크기(224x224)에 맞게 패딩을 채웁니다.
  • Normalize: ImageNet 데이터셋의 평균과 표준편차를 사용하여 정규화합니다.

주의점: ToTensorV2는 torchvision.transforms의 ToTensor와 달리 0~1 사이로 스케일링을 자동으로 해주지 않으므로, 반드시 그전에 Normalize를 수행해야 합니다.


4. 학습 및 결과 확인

학습 시 Loss는 각 Branch에서 나온 CrossEntropyLoss의 합으로 계산하였습니다.

  • Optimizer: Adam (Learning Rate: 0.0001)
  • Evaluation: 각 Branch별로 Accuracy를 측정하고, 가장 높은 성능을 보인 모델을 저장했습니다.

최종 결과 시각화

학습된 모델로 검증 데이터를 예측한 결과, 여러 자리의 숫자를 정확하게 맞추는 것을 확인할 수 있었습니다. denormalize 과정을 거쳐 원래의 이미지를 복원하고 예측값(Pred)과 함께 출력하여 직관적으로 성능을 검토했습니다.


5. 요약 및 결론

  1. ViT는 이미지를 패치 단위 시퀀스로 변환하여 Transformer의 Self-Attention을 적용한 모델입니다.
  2. 데이터가 많을수록 CNN보다 뛰어난 성능을 보이며, 패치 크기 조절을 통해 유연한 대응이 가능합니다.
  3. Multi-Branch 구조를 통해 하나의 Backbone 네트워크로 여러 개의 태스크(여러 숫자 인식 등)를 동시에 수행할 수 있음을 확인했습니다.

'개념 정리 step2 > 멀티모달(Multi-modal)' 카테고리의 다른 글

[멀티 모달] 오토인코더(Autoencoder): 비지도 학습과 생성 모델의 기초  (0) 2026.01.30
[Deep Learning] 동영상 데이터 분석: 3D CNN과 수화 인식 실습  (0) 2026.01.29
[NLP] KLUE-BERT 기반 멀티레이블 혐오 표현 분류 실습  (0) 2026.01.22
[Deep Learning] 트랜스포머(Transformer): NLP 아키텍처 정리  (0) 2026.01.19
[딥러닝 NLP] NLU에서 트랜스포머 어텐션까지 핵심 개념 정리  (1) 2026.01.16
'개념 정리 step2/멀티모달(Multi-modal)' 카테고리의 다른 글
  • [멀티 모달] 오토인코더(Autoencoder): 비지도 학습과 생성 모델의 기초
  • [Deep Learning] 동영상 데이터 분석: 3D CNN과 수화 인식 실습
  • [NLP] KLUE-BERT 기반 멀티레이블 혐오 표현 분류 실습
  • [Deep Learning] 트랜스포머(Transformer): NLP 아키텍처 정리
고니3000원
고니3000원
공부 내용 정리, 자기발전 블로그 입니다. 기존 네이버 블로그에서 티스토리로 이전했습니다. https://blog.naver.com/pak1010pak
  • 고니3000원
    곤이의 공부 블로그
    고니3000원
  • 전체
    오늘
    어제
    • 분류 전체보기 (176)
      • 1. AI 논문 + 모델 분석 (19)
        • AI 논문 분석 (13)
        • AI 모델 분석 (6)
      • 2. 자료구조와 알고리즘 (16)
        • 2-1 자료구조와 알고리즘 (13)
        • 2-2 강화학습 알고리즘 (3)
      • 3. 자습 & 메모(실전, 실습, 프로젝트) (25)
        • 3-1 문제 해석 (4)
        • 3-2 메모(실전, 프로젝트) (14)
        • 3-3 배포 실전 공부 (7)
      • 4. [팀] 프로젝트 및 공모전 (14)
        • 4-1 팀 프로젝트(메모, 공부) (1)
        • 4-2 Meat-A-Eye (6)
        • 4-3 RL-Tycoon-Agent (3)
        • 4-4 구조물 안정성 물리 추론 AI 경진대회(D.. (4)
      • 5. [개인] 프로젝트 및 공모전 (0)
        • 4-1 귀멸의칼날디펜스(자바스크립트 활용) (5)
        • 4-2 바탕화면 AI 펫 프로그램 (4)
        • 4-3 개인 프로젝트(기타) (3)
      • 개념 정리 step1 (32)
        • Python 기초 (7)
        • DBMS (1)
        • HTML | CSS (3)
        • Git | GitHub (1)
        • JavaScript (5)
        • Node.js (5)
        • React (1)
        • 데이터 분석 (6)
        • Python Engineering (3)
      • 개념 정리 step2 (56)
        • Machine | Deep Learning (15)
        • 멀티모달(Multi-modal) (23)
        • 강화 학습 (10)
        • AI Agent (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

    • 네이버 곤이의 블로그(Naver->Tistory)
    • Github
  • 공지사항

  • 인기 글

  • 태그

    github
    javascript
    학습
    구현
    paddleocr
    자료구조
    Attention Is All You Need
    Ai
    OCR학습
    알고리즘
    파이썬
    pandas
    Vision
    프로젝트
    bottleneck
    transformer
    Python
    Grad-CAM
    강화학습
    논문 리뷰
    파인튜닝
    ViT
    데이터분석
    자바스크립트
    귀칼
    html
    공모전
    강화 학습
    OCR
    EfficientNet
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.5
고니3000원
[Deep Learning] Vision Transformer(ViT) Multi-Branch 구현 실습
상단으로

티스토리툴바