[논문 구현] 멀티모달 병목 트랜스포머(MBT) PyTorch 구현 및 RAVDESS 감정 인식 테스트

2026. 2. 16. 20:55·1. AI 논문 + 모델 분석/AI 논문 분석

1. 프로젝트 개요

이번 프로젝트는 비디오(시각)와 오디오(청각) 데이터를 효율적으로 융합하는 아키텍처를 제안한 논문, Attention Bottlenecks for Multimodal Fusion (NeurIPS 2021)의 아이디어를 직접 PyTorch로 구현해 본 기록입니다.

논문의 핵심인 멀티모달 병목(Bottleneck) 구조가 실제로 어떻게 작동하는지 확인하기 위해, Kaggle의 RAVDESS (배우 감정 인식 비디오/오디오) 데이터셋을 활용하여 화자의 감정을 8가지로 분류하는 모델을 구축하고 테스트를 진행했습니다.

 

논문 링크:

https://arxiv.org/abs/2107.00135

 

Attention Bottlenecks for Multimodal Fusion

Humans perceive the world by concurrently processing and fusing high-dimensional inputs from multiple modalities such as vision and audio. Machine perception models, in stark contrast, are typically modality-specific and optimised for unimodal benchmarks,

arxiv.org

 

논문 리뷰 블로그:

https://pak1010pak.tistory.com/131

 

[논문 리뷰] Multimodal Bottleneck Transformer (MBT)

안녕하세요!오늘 리뷰할 논문은 인간의 인지 방식에서 영감을 받아 비디오 분류(Video Classification)를 위한 새로운 멀티모달 융합 방식을 제안한 Multimodal Bottleneck Transformer (MBT) 입니다.처음 멀티모

pak1010pak.tistory.com


2. 환경 설정 및 데이터 파이프라인 최적화

학습 환경은 A100 80GB GPU가 탑재된 Google Colab 환경을 사용했습니다.

초기에는 매 에포크마다 원본 .mp4 비디오를 열어서 프레임을 자르고 오디오를 추출했습니다. 하지만 이 방식은 심각한 CPU 병목(Bottleneck) 현상을 유발하여 고성능 GPU를 전혀 활용하지 못하는 문제가 있었습니다.

이를 해결하기 위해 사전 텐서 캐싱(Caching) 파이프라인을 구축했습니다.

  • Vision: OpenCV를 사용하여 균일한 간격으로 16프레임을 추출하고 RGB 변환 및 정규화를 거쳐 텐서로 만듭니다.
  • Audio: Librosa를 활용해 오디오 트랙을 128x128 크기의 Mel-Spectrogram으로 변환합니다.
  • 최적화: 위 과정을 데이터셋 전체에 대해 딱 한 번만 수행한 뒤, 파이토치 텐서 파일(.pt)로 디스크에 저장합니다. 학습 시에는 무거운 디코딩 과정 없이 저장된 텐서만 로드하므로 A100의 연산 속도를 100% 활용할 수 있게 되었습니다.
# 1. 단 한 번만 실행하는 전처리 캐싱 로직
save_dir = "ravdess_preprocessed"
os.makedirs(save_dir, exist_ok=True)

# 원본 비디오/오디오를 미리 텐서로 변환하여 .pt 파일로 디스크에 저장
for i in tqdm(range(len(multimodal_dataset))):
    vision_tensor, audio_tensor, label_tensor = multimodal_dataset[i]
    save_path = os.path.join(save_dir, f"data_{i}.pt")
    torch.save({
        'vision': vision_tensor,
        'audio': audio_tensor,
        'label': label_tensor
    }, save_path)

# 2. 훈련 시 사용할 초고속 커스텀 Dataset
class FastRAVDESSDataset(Dataset):
    def __init__(self, preprocessed_dir):
        self.file_paths = [os.path.join(preprocessed_dir, f) for f in os.listdir(preprocessed_dir) if f.endswith('.pt')]
        
    def __getitem__(self, idx):
        # 무거운 연산 없이 디스크에서 텐서만 바로 메모리로 로드
        data = torch.load(self.file_paths[idx], weights_only=True)
        return data['vision'], data['audio'], data['label']
처음에는 __getitem__ 안에서 매번 cv2.VideoCapture와 librosa.load를 호출했습니다. 하지만 이 방식은 막강한 A100 GPU를 두고도 CPU가 영상 디코딩을 하느라 병목(Bottleneck)이 생기는 참사가 발생했습니다. 이를 해결하기 위해 데이터 사전 전처리 및 캐싱 로직을 적용했습니다.

3. 모델 아키텍처: SimpleMBT 구현

논문에서 제안한 'Mid-Fusion' 전략과 'Attention Bottleneck' 구조를 SimpleMBT 클래스로 구현했습니다.

  • 특징 추출 (Feature Extraction)
    • 시각 정보는 사전 학습된 ResNet18의 초기 레이어를 통과시켜 특징 벡터로 압축합니다.
    • 청각 정보(스펙트로그램)는 Conv2d를 사용한 패치 임베딩을 통해 트랜스포머에 들어갈 시퀀스 형태로 변환합니다.
  • 병목 융합 (Bottleneck Fusion)
    • 초기 레이어(Early phase)에서는 시각과 청각이 서로 독립적인 트랜스포머 인코더를 통과합니다. (유니모달 학습)
    • 지정된 중간 레이어(Fusion layer)부터는 학습 가능한 파라미터인 병목 토큰(Bottleneck Tokens)을 입력 시퀀스에 이어 붙입니다(torch.cat).
    • 시각/청각 토큰들은 직접 교류하지 않고 오직 병목 토큰하고만 어텐션 연산을 수행합니다. 이후 각 모달리티에서 업데이트된 병목 토큰의 평균을 내어 다음 레이어로 전달하는 대칭 업데이트 방식을 적용했습니다.

논문의 핵심: 병목 어텐션 (Bottleneck Attention) 융합 로직

MBT의 가장 큰 특징인 '지정된 레이어부터 병목 토큰을 통해 정보를 교환하는' 로직입니다. 복잡한 수식 대신 PyTorch의 텐서 조작(torch.cat과 슬라이싱)을 통해 직관적으로 구현할 수 있었습니다.

# SimpleMBT 클래스의 forward 함수 내부 로직 중 일부

# 배치 크기에 맞게 병목 토큰(학습 가능한 파라미터) 확장
b_tokens = self.bottleneck_tokens.expand(batch_size, -1, -1)

# 트랜스포머 인코더 통과 및 병목 융합 (Mid-Fusion)
for i in range(len(self.vision_transformers)):
    if i < self.fusion_layer:
        # [Early Phase] 융합 레이어 도달 전: 각자 유니모달로 학습
        v_tokens = self.vision_transformers[i](v_tokens)
        a_tokens = self.audio_transformers[i](a_tokens)
    else:
        # [Mid-Fusion Phase] 융합 시작: 데이터 토큰 뒤에 병목 토큰을 이어 붙임
        v_input = torch.cat([v_tokens, b_tokens], dim=1)
        a_input = torch.cat([a_tokens, b_tokens], dim=1)
        
        # 각각의 트랜스포머 통과 (이 과정에서 데이터와 병목 토큰 간의 Attention 발생)
        v_out = self.vision_transformers[i](v_input)
        a_out = self.audio_transformers[i](a_input)
        
        # 통과 후 텐서 분리: 원래 데이터 토큰과 업데이트된 임시 병목 토큰
        v_tokens = v_out[:, :-self.bottleneck_size, :]
        v_b_out = v_out[:, -self.bottleneck_size:, :]
        
        a_tokens = a_out[:, :-self.bottleneck_size, :]
        a_b_out = a_out[:, -self.bottleneck_size:, :]
        
        # [핵심] 병목 토큰 대칭 업데이트: 시각/청각에서 얻은 임시 병목의 평균을 계산
        b_tokens = (v_b_out + a_b_out) / 2.0

4. 학습 진행 및 결과

고속 데이터로더(num_workers=12, pin_memory=True)를 적용한 결과, A100 환경에서 1 에포크당 약 37초라는 매우 빠른 속도로 학습을 완료할 수 있었습니다.

[Epoch 1] Loss: 0.5081, Accuracy: 84.91%
[Epoch 2] Loss: 0.1584, Accuracy: 95.47%
[Epoch 3] Loss: 0.1112, Accuracy: 96.49%
[Epoch 4] Loss: 0.0958, Accuracy: 97.21%
[Epoch 5] Loss: 0.0512, Accuracy: 98.41%
[Epoch 6] Loss: 0.0627, Accuracy: 98.00%
[Epoch 7] Loss: 0.0622, Accuracy: 98.12%
[Epoch 8] Loss: 0.0296, Accuracy: 99.27%
[Epoch 9] Loss: 0.0240, Accuracy: 99.35%
[Epoch 10] Loss: 0.0260, Accuracy: 99.18%

단 10 에포크 만에 훈련 정확도 99.18%, Loss 0.0260을 달성했습니다. RAVDESS 데이터셋의 규모가 크지 않아 모델이 학습 데이터에 빠르게 피팅된 경향이 있지만, 아키텍처가 의도한 대로 수렴하며 정상 작동한다는 것을 확실히 검증했습니다.


5. 추론 및 시각적 검증 (Visual Inference)

Loss와 Accuracy 숫자만으로는 모델이 실제로 멀티모달 데이터를 어떻게 이해하고 있는지 파악하기 어렵습니다. 이를 직관적으로 확인하기 위해 무작위로 5개의 샘플을 뽑아 검증 결과를 시각화하는 로직을 작성했습니다.

  • 왼쪽 패널 (Video Frame): 모델이 입력으로 받은 비디오의 중간 프레임 이미지입니다. 실제 정답 감정(Actual Emotion)이 함께 표시됩니다.
  • 가운데 패널 (Audio Mel-Spectrogram): 입력된 오디오 파형의 주파수 특성을 보여주는 멜 스펙트로그램 시각화입니다.
  • 오른쪽 패널 (Prediction): 최종적으로 모델이 8가지 감정에 대해 예측한 확률(%) 분포입니다. 정답 클래스는 초록색 막대로 표시되며, 오답일 경우 빨간색 막대로 표시되도록 구성하여 모델이 어떤 감정을 서로 헷갈려하는지 직관적으로 분석할 수 있습니다.

6. 회고 및 마무리

텍스트나 단일 이미지를 넘어, 비디오와 오디오라는 서로 다른 차원의 데이터를 하나의 모델 안에서 제어하는 방법을 배울 수 있었습니다.

특히 모든 토큰을 연결하는 기존의 Self-Attention 방식 대신, 소수의 '병목 토큰'만 활용하여 핵심 정보만 교환하게 만드는 아이디어가 연산량 감소는 물론 정보의 정제 측면에서도 매우 훌륭한 접근임을 코드로 직접 확인했습니다. 또한, 아무리 좋은 GPU를 사용하더라도 데이터 전처리 파이프라인(CPU/IO)이 받쳐주지 않으면 무용지물이 된다는 엔지니어링 측면의 중요한 교훈도 얻었습니다.

 

GitHub 링크

https://github.com/gonida1010/MBT-PyTorch-RAVDESS

 

GitHub - gonida1010/MBT-PyTorch-RAVDESS: Implementing the "Attention Bottlenecks for Multimodal Fusion" architecture in PyTorch

Implementing the "Attention Bottlenecks for Multimodal Fusion" architecture in PyTorch and testing emotion recognition performance on the Kaggle RAVDESS dataset. - gonida1010/MBT-PyTorch-...

github.com

'1. AI 논문 + 모델 분석 > AI 논문 분석' 카테고리의 다른 글

[논문 리뷰] Swin Transformer 정리  (0) 2026.02.23
[논문 리뷰] DETR: 객체 검출(Object Detection)의 End-to-End 파이프라인  (0) 2026.02.22
[논문 리뷰] Multimodal Bottleneck Transformer (MBT)  (0) 2026.02.15
[논문 리뷰] CLIP: 텍스트로 이미지를 이해하는 비전 모델 (OpenAI)  (0) 2026.02.12
[논문 리뷰] Bahdanau Attention: 정렬과 번역을 동시에 학습하는 신경망 기계 번역  (0) 2026.02.09
'1. AI 논문 + 모델 분석/AI 논문 분석' 카테고리의 다른 글
  • [논문 리뷰] Swin Transformer 정리
  • [논문 리뷰] DETR: 객체 검출(Object Detection)의 End-to-End 파이프라인
  • [논문 리뷰] Multimodal Bottleneck Transformer (MBT)
  • [논문 리뷰] CLIP: 텍스트로 이미지를 이해하는 비전 모델 (OpenAI)
고니3000원
고니3000원
공부 내용 정리, 자기발전 블로그 입니다. 기존 네이버 블로그에서 티스토리로 이전했습니다. https://blog.naver.com/pak1010pak
  • 고니3000원
    곤이의 공부 블로그
    고니3000원
  • 전체
    오늘
    어제
    • 분류 전체보기 (176)
      • 1. AI 논문 + 모델 분석 (19)
        • AI 논문 분석 (13)
        • AI 모델 분석 (6)
      • 2. 자료구조와 알고리즘 (16)
        • 2-1 자료구조와 알고리즘 (13)
        • 2-2 강화학습 알고리즘 (3)
      • 3. 자습 & 메모(실전, 실습, 프로젝트) (25)
        • 3-1 문제 해석 (4)
        • 3-2 메모(실전, 프로젝트) (14)
        • 3-3 배포 실전 공부 (7)
      • 4. [팀] 프로젝트 및 공모전 (14)
        • 4-1 팀 프로젝트(메모, 공부) (1)
        • 4-2 Meat-A-Eye (6)
        • 4-3 RL-Tycoon-Agent (3)
        • 4-4 구조물 안정성 물리 추론 AI 경진대회(D.. (4)
      • 5. [개인] 프로젝트 및 공모전 (0)
        • 4-1 귀멸의칼날디펜스(자바스크립트 활용) (5)
        • 4-2 바탕화면 AI 펫 프로그램 (4)
        • 4-3 개인 프로젝트(기타) (3)
      • 개념 정리 step1 (32)
        • Python 기초 (7)
        • DBMS (1)
        • HTML | CSS (3)
        • Git | GitHub (1)
        • JavaScript (5)
        • Node.js (5)
        • React (1)
        • 데이터 분석 (6)
        • Python Engineering (3)
      • 개념 정리 step2 (56)
        • Machine | Deep Learning (15)
        • 멀티모달(Multi-modal) (23)
        • 강화 학습 (10)
        • AI Agent (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

    • 네이버 곤이의 블로그(Naver->Tistory)
    • Github
  • 공지사항

  • 인기 글

  • 태그

    Vision
    공모전
    학습
    bottleneck
    OCR학습
    데이터분석
    강화 학습
    강화학습
    파인튜닝
    pandas
    알고리즘
    Grad-CAM
    javascript
    귀칼
    transformer
    논문 리뷰
    파이썬
    ViT
    프로젝트
    html
    github
    자바스크립트
    Python
    Ai
    OCR
    자료구조
    EfficientNet
    paddleocr
    Attention Is All You Need
    구현
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.5
고니3000원
[논문 구현] 멀티모달 병목 트랜스포머(MBT) PyTorch 구현 및 RAVDESS 감정 인식 테스트
상단으로

티스토리툴바