import csv
import math
import os
import random
import copy
import numpy as np
import torch
import torch.nn.functional
import torchaudio
from PIL import Image
from scipy import signal
from torch.utils.data import Dataset
from torchvision import transforms
import librosa

from transformers import AutoTokenizer

import pandas as pd
import json
from typing import List, Any, Dict
from collections import Counter
import json
import re
import cv2
Image.MAX_IMAGE_PIXELS = 100000000

class AVDataset_CD(Dataset):
    """Audio-visual dataset rooted at ``../datasets/cremad/``.

    Every sample pairs a pre-computed audio tensor (``<id>.pt`` under
    ``pt_cremad/``) with two frames read from the sample's frame folder.
    Frames are expected to be named ``00000.jpg``, ``00001.jpg``, ...

    Args:
        mode: ``'train'`` reads ``train.csv`` and enables random frame
            selection plus augmentation; any other value reads ``test.csv``
            and is deterministic.
    """

    def __init__(self, mode='train'):
        classes = []
        self.data = []
        data2class = {}

        self.mode = mode
        self.visual_path = '../datasets/cremad/'
        self.audio_path = '../datasets/cremad/pt_cremad/'
        self.stat_path = '../datasets/cremad/stat.csv'
        self.train_txt = '../datasets/cremad/train.csv'
        self.test_txt = '../datasets/cremad/test.csv'
        csv_file = self.train_txt if mode == 'train' else self.test_txt

        # The first column of stat.csv holds the class names.
        with open(self.stat_path, encoding='UTF-8-sig') as f:
            for row in csv.reader(f):
                if row:  # tolerate blank lines instead of raising IndexError
                    classes.append(row[0])
        known_classes = set(classes)

        # Keep only samples with a known class whose audio tensor and frame
        # folder both exist on disk.
        with open(csv_file) as f:
            for item in csv.reader(f):
                if len(item) < 2:
                    continue
                if (item[1] in known_classes
                        and os.path.exists(self.audio_path + item[0] + '.pt')
                        and os.path.exists(self.visual_path + '/' + item[0])):
                    self.data.append(item[0])
                    data2class[item[0]] = item[1]

        print('data load over')
        print(len(self.data))

        self.classes = sorted(classes)

        self.data2class = data2class
        self._init_atransform()
        print('# of files = %d ' % len(self.data))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _visual_transform(self):
        """Build the image pipeline for the current ``self.mode``."""
        if self.mode == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    @staticmethod
    def _pick_frame_index(i, seg, train):
        """Return the frame index to read for segment ``i`` of length ``seg``.

        In train mode a random frame inside the segment is chosen, otherwise
        the segment midpoint. When ``seg`` is 0 (fewer frames than segments)
        the segment start is returned; the unclamped ``randint`` call would
        raise ``ValueError`` on the empty range.
        """
        start = i * seg
        if train:
            return random.randint(start, max(start, start + seg - 1))
        return start + seg // 2

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(fbank, images, label_index)`` for sample ``idx``.

        ``fbank`` is the stored audio tensor with a leading dimension added;
        ``images`` has shape ``(2, 3, 224, 224)``; ``label_index`` is the
        position of the sample's class in ``self.classes``.
        """
        datum = self.data[idx]

        # Audio
        fbank = torch.load(self.audio_path + datum + '.pt').unsqueeze(0)

        # Visual
        transf = self._visual_transform()
        folder_path = self.visual_path + datum
        file_num = len(os.listdir(folder_path))
        pick_num = 2
        seg = file_num // pick_num
        train = self.mode == 'train'
        image_arr = []

        for i in range(pick_num):
            index = self._pick_frame_index(i, seg, train)
            path = folder_path + '/' + str(index).zfill(5) + '.jpg'
            # Context manager closes the file handle once the frame is decoded.
            with Image.open(path) as img:
                image_arr.append(transf(img.convert('RGB')).unsqueeze(0))

        images = torch.cat(image_arr)

        return fbank, images, self.classes.index(self.data2class[datum])


class AVDataset_KS(Dataset):
    """Audio-visual dataset rooted at ``../datasets/KS/``.

    Every sample pairs a pre-computed audio tensor (``<id>.pt`` under
    ``pt_cremad/``) with three frames read from the sample's frame folder.
    Frames are expected to be named ``00000.jpg``, ``00001.jpg``, ...

    Args:
        mode: ``'train'`` reads ``train.csv`` and enables random frame
            selection plus augmentation; any other value reads ``test.csv``
            and is deterministic.
    """

    def __init__(self, mode='train'):
        classes = []
        self.data = []
        data2class = {}

        self.mode = mode
        self.visual_path = '../datasets/KS/'
        self.audio_path = '../datasets/KS/pt_cremad/'
        self.stat_path = '../datasets/KS/stat.csv'
        self.train_txt = '../datasets/KS/train.csv'
        self.test_txt = '../datasets/KS/test.csv'
        csv_file = self.train_txt if mode == 'train' else self.test_txt

        # The first column of stat.csv holds the class names.
        with open(self.stat_path, encoding='UTF-8-sig') as f:
            for row in csv.reader(f):
                if row:  # tolerate blank lines instead of raising IndexError
                    classes.append(row[0])
        known_classes = set(classes)

        # Keep only samples with a known class whose audio tensor and frame
        # folder both exist on disk.
        with open(csv_file) as f:
            for item in csv.reader(f):
                if len(item) < 2:
                    continue
                if (item[1] in known_classes
                        and os.path.exists(self.audio_path + item[0] + '.pt')
                        and os.path.exists(self.visual_path + '/' + item[0])):
                    self.data.append(item[0])
                    data2class[item[0]] = item[1]

        print('data load over')
        print(len(self.data))

        self.classes = sorted(classes)

        self.data2class = data2class
        self._init_atransform()
        print('# of files = %d ' % len(self.data))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _visual_transform(self):
        """Build the image pipeline for the current ``self.mode``."""
        if self.mode == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    @staticmethod
    def _pick_frame_index(i, seg, train):
        """Return the frame index to read for segment ``i`` of length ``seg``.

        In train mode a random frame inside the segment is chosen, otherwise
        the segment midpoint. When ``seg`` is 0 (fewer frames than segments)
        the segment start is returned; the unclamped ``randint`` call would
        raise ``ValueError`` on the empty range.
        """
        start = i * seg
        if train:
            return random.randint(start, max(start, start + seg - 1))
        return start + seg // 2

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(fbank, images, label_index)`` for sample ``idx``.

        ``fbank`` is the stored audio tensor with a leading dimension added;
        ``images`` has shape ``(3, 3, 224, 224)``; ``label_index`` is the
        position of the sample's class in ``self.classes``.
        """
        datum = self.data[idx]

        # Audio
        fbank = torch.load(self.audio_path + datum + '.pt').unsqueeze(0)

        # Visual
        transf = self._visual_transform()
        folder_path = self.visual_path + datum
        file_num = len(os.listdir(folder_path))
        pick_num = 3
        seg = file_num // pick_num
        train = self.mode == 'train'
        image_arr = []

        for i in range(pick_num):
            index = self._pick_frame_index(i, seg, train)
            path = folder_path + '/' + str(index).zfill(5) + '.jpg'
            # Context manager closes the file handle once the frame is decoded.
            with Image.open(path) as img:
                image_arr.append(transf(img.convert('RGB')).unsqueeze(0))

        images = torch.cat(image_arr)

        return fbank, images, self.classes.index(self.data2class[datum])

class AVDataset_AVE(Dataset):
    """Audio-visual dataset rooted at ``../datasets/AVE/AVE_Dataset/``.

    Every sample pairs a pre-computed audio tensor (``<id>.pt`` under
    ``pt_cremad/``) with three frames read from the sample's frame folder.
    Frames are expected to be named ``00000.jpg``, ``00001.jpg``, ...

    Args:
        mode: ``'train'`` reads ``train.csv`` and enables random frame
            selection plus augmentation; any other value reads ``test.csv``
            and is deterministic.
    """

    def __init__(self, mode='train'):
        classes = []
        self.data = []
        data2class = {}

        self.mode = mode
        self.visual_path = '../datasets/AVE/AVE_Dataset/'
        self.audio_path = '../datasets/AVE/AVE_Dataset/pt_cremad/'
        self.stat_path = '../datasets/AVE/AVE_Dataset/stat.csv'
        self.train_txt = '../datasets/AVE/AVE_Dataset/train.csv'
        self.test_txt = '../datasets/AVE/AVE_Dataset/test.csv'
        csv_file = self.train_txt if mode == 'train' else self.test_txt

        # The first column of stat.csv holds the class names.
        with open(self.stat_path, encoding='UTF-8-sig') as f:
            for row in csv.reader(f):
                if row:  # tolerate blank lines instead of raising IndexError
                    classes.append(row[0])
        known_classes = set(classes)

        # Keep only samples with a known class whose audio tensor and frame
        # folder both exist on disk.
        with open(csv_file) as f:
            for item in csv.reader(f):
                if len(item) < 2:
                    continue
                if (item[1] in known_classes
                        and os.path.exists(self.audio_path + item[0] + '.pt')
                        and os.path.exists(self.visual_path + '/' + item[0])):
                    self.data.append(item[0])
                    data2class[item[0]] = item[1]

        print('data load over')
        print(len(self.data))

        self.classes = sorted(classes)

        self.data2class = data2class
        self._init_atransform()
        print('# of files = %d ' % len(self.data))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _visual_transform(self):
        """Build the image pipeline for the current ``self.mode``."""
        if self.mode == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    @staticmethod
    def _pick_frame_index(i, seg, train):
        """Return the frame index to read for segment ``i`` of length ``seg``.

        In train mode a random frame inside the segment is chosen, otherwise
        the segment midpoint. When ``seg`` is 0 (fewer frames than segments)
        the segment start is returned; the unclamped ``randint`` call would
        raise ``ValueError`` on the empty range.
        """
        start = i * seg
        if train:
            return random.randint(start, max(start, start + seg - 1))
        return start + seg // 2

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(fbank, images, label_index)`` for sample ``idx``.

        ``fbank`` is the stored audio tensor with a leading dimension added;
        ``images`` has shape ``(3, 3, 224, 224)``; ``label_index`` is the
        position of the sample's class in ``self.classes``.
        """
        datum = self.data[idx]

        # Audio
        fbank = torch.load(self.audio_path + datum + '.pt').unsqueeze(0)

        # Visual
        transf = self._visual_transform()
        folder_path = self.visual_path + datum
        file_num = len(os.listdir(folder_path))
        pick_num = 3
        seg = file_num // pick_num
        train = self.mode == 'train'
        image_arr = []

        for i in range(pick_num):
            index = self._pick_frame_index(i, seg, train)
            path = folder_path + '/' + str(index).zfill(5) + '.jpg'
            # Context manager closes the file handle once the frame is decoded.
            with Image.open(path) as img:
                image_arr.append(transf(img.convert('RGB')).unsqueeze(0))

        images = torch.cat(image_arr)

        return fbank, images, self.classes.index(self.data2class[datum])


class AVDataset_MOSI(Dataset):
    """Audio / visual / text dataset rooted at ``../datasets/CMU-MOSI/``.

    Rows come from ``train.csv`` or ``test.csv``; the ``annotation`` column is
    the label and ``text`` the utterance. Audio tensors and frame folders are
    addressed as ``<video_id>/<clip_id>``.

    NOTE(review): the label -> id mapping is derived from the unique labels of
    the loaded split only, so train and test ids agree only if both CSV files
    contain the same label set.

    Args:
        mode: ``'train'`` selects the training CSV, random frame picks and
            augmentation; any other value selects the test CSV.
    """

    def __init__(self, mode='train'):
        if mode == 'train':
            csv_file = "../datasets/CMU-MOSI/train.csv"
        else:
            csv_file = "../datasets/CMU-MOSI/test.csv"

        self.data_df = pd.read_csv(csv_file, encoding='latin-1')
        self.tokenizer = AutoTokenizer.from_pretrained('./encoders/bert-base-uncased')
        self.split = mode

        self.emotion_to_id = {emo: i for i, emo in enumerate(sorted(self.data_df['annotation'].unique()))}

        print('data load over')
        print(len(self.data_df))

        self.classes = sorted(self.emotion_to_id)

        self._init_atransform()
        print('# of files = %d ' % len(self.data_df))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _visual_transform(self):
        """Build the image pipeline for the current ``self.split``."""
        if self.split == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def _tokenize(self, text):
        """Encode ``text`` to fixed-length (512) ``input_ids`` / ``attention_mask``."""
        encoded = self.tokenizer.encode_plus(
            text, add_special_tokens=True, max_length=512,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    @staticmethod
    def _pick_frame_index(i, seg, train):
        """Return the frame index to read for segment ``i`` of length ``seg``.

        In train mode a random frame inside the segment is chosen, otherwise
        the segment midpoint. When ``seg`` is 0 (fewer frames than segments)
        the segment start is returned; the unclamped ``randint`` call would
        raise ``ValueError`` on the empty range.
        """
        start = i * seg
        if train:
            return random.randint(start, max(start, start + seg - 1))
        return start + seg // 2

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        """Return ``(fbank, images, input_ids, attention_mask, label)``.

        ``images`` has shape ``(3, 3, 224, 224)``; ``label`` is a scalar long
        tensor. Labels missing from ``emotion_to_id`` fall back to id 0.
        """
        # Both splits read from the same audio and frame roots.
        audio_root = '../datasets/CMU-MOSI/pt_cremad'
        video_root = '../datasets/CMU-MOSI/frames/Raw_peak_frames/Raw_peak_frames'
        item = self.data_df.iloc[idx]

        utterance_text = item['text']
        emotion = item['annotation']
        clip_id = item['clip_id']
        video_id = item['video_id']

        # Missing text (NaN) is encoded as the empty string.
        text_input_str = str(utterance_text) if pd.notna(utterance_text) else ""
        input_ids, attention_mask = self._tokenize(text_input_str)

        # Audio
        audio_path = os.path.join(audio_root, f"{video_id}/{clip_id}")
        fbank = torch.load(audio_path + '.pt').unsqueeze(0)

        # Visual
        transf = self._visual_transform()
        folder_path = os.path.join(video_root, f"{video_id}/{clip_id}")
        file_num = len(os.listdir(folder_path))
        pick_num = 3
        seg = file_num // pick_num
        train = self.split == 'train'
        image_arr = []

        for i in range(pick_num):
            index = self._pick_frame_index(i, seg, train)
            path = folder_path + '/' + str(index).zfill(5) + '.jpg'
            # Context manager closes the file handle once the frame is decoded.
            with Image.open(path) as img:
                image_arr.append(transf(img.convert('RGB')).unsqueeze(0))

        images = torch.cat(image_arr)

        label_id = self.emotion_to_id.get(emotion, 0)
        label = torch.tensor(label_id, dtype=torch.long)

        return fbank, images, input_ids, attention_mask, label

class AVDataset_MELD(Dataset):
    """Audio / visual / text dataset rooted at ``../datasets/MELD/``.

    Rows come from ``train_sent_emo.csv`` or ``test_sent_emo.csv``; the
    ``Sentiment`` column is the label and ``Utterance`` the text. Audio
    tensors and frame folders are addressed as ``dia<Dialogue_ID>_utt<Utterance_ID>``.

    NOTE(review): the label -> id mapping is derived from the unique labels of
    the loaded split only, so train and test ids agree only if both CSV files
    contain the same label set.

    Args:
        mode: ``'train'`` selects the training CSV/roots, random frame picks
            and augmentation; any other value selects the test CSV/roots.
    """

    def __init__(self, mode='train'):
        if mode == 'train':
            csv_file = "../datasets/MELD/train_sent_emo.csv"
        else:
            csv_file = "../datasets/MELD/test_sent_emo.csv"

        self.data_df = pd.read_csv(csv_file, encoding='latin-1')
        self.tokenizer = AutoTokenizer.from_pretrained('./encoders/bert-base-uncased')
        self.split = mode

        self.emotion_to_id = {emo: i for i, emo in enumerate(sorted(self.data_df['Sentiment'].unique()))}

        print('data load over')
        print(len(self.data_df))

        self.classes = sorted(self.emotion_to_id)

        self._init_atransform()
        print('# of files = %d ' % len(self.data_df))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _visual_transform(self):
        """Build the image pipeline for the current ``self.split``."""
        if self.split == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def _tokenize(self, text):
        """Encode ``text`` to fixed-length (512) ``input_ids`` / ``attention_mask``."""
        encoded = self.tokenizer.encode_plus(
            text, add_special_tokens=True, max_length=512,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    @staticmethod
    def _pick_frame_index(i, seg, train):
        """Return the frame index to read for segment ``i`` of length ``seg``.

        In train mode a random frame inside the segment is chosen, otherwise
        the segment midpoint. When ``seg`` is 0 (fewer frames than segments)
        the segment start is returned; the unclamped ``randint`` call would
        raise ``ValueError`` on the empty range.
        """
        start = i * seg
        if train:
            return random.randint(start, max(start, start + seg - 1))
        return start + seg // 2

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        """Return ``(fbank, images, input_ids, attention_mask, label)``.

        ``images`` has shape ``(3, 3, 224, 224)``; ``label`` is a scalar long
        tensor. Labels missing from ``emotion_to_id`` fall back to id 0.
        """
        if self.split == 'train':
            audio_root = '../datasets/MELD/wav_pt/train'
            video_root = '../datasets/MELD/frames/train'
        else:
            audio_root = '../datasets/MELD/wav_pt/test'
            video_root = '../datasets/MELD/frames/test'
        item = self.data_df.iloc[idx]

        utterance_text = item['Utterance']
        emotion = item['Sentiment']
        # Dialogue and utterance ids together form the audio/frame file stem.
        dia_id = item['Dialogue_ID']
        utt_id = item['Utterance_ID']

        # Missing text (NaN) is encoded as the empty string.
        text_input_str = str(utterance_text) if pd.notna(utterance_text) else ""
        input_ids, attention_mask = self._tokenize(text_input_str)

        # Audio
        audio_path = os.path.join(audio_root, f"dia{dia_id}_utt{utt_id}")
        fbank = torch.load(audio_path + '.pt').unsqueeze(0)

        # Visual
        transf = self._visual_transform()
        folder_path = os.path.join(video_root, f"dia{dia_id}_utt{utt_id}")
        file_num = len(os.listdir(folder_path))
        pick_num = 3
        seg = file_num // pick_num
        train = self.split == 'train'
        image_arr = []

        for i in range(pick_num):
            index = self._pick_frame_index(i, seg, train)
            path = folder_path + '/' + str(index).zfill(5) + '.jpg'
            # Context manager closes the file handle once the frame is decoded.
            with Image.open(path) as img:
                image_arr.append(transf(img.convert('RGB')).unsqueeze(0))

        images = torch.cat(image_arr)

        label_id = self.emotion_to_id.get(emotion, 0)
        label = torch.tensor(label_id, dtype=torch.long)

        return fbank, images, input_ids, attention_mask, label

class AVDataset_FOOD(Dataset):
    """Image / text dataset rooted at ``../datasets/UPMC Food-101/``.

    Each CSV row is read as ``(image file name, title text, label name)``.
    Label ids follow the sorted sub-directory names of the split's image
    directory, and images are loaded from ``<img_dir>/<label>/<file name>``.

    Args:
        mode: ``'train'`` selects the training titles/images and enables
            augmentation; any other value selects the test titles/images.
    """

    def __init__(self, mode='train'):
        if mode == 'train':
            csvdata = '../datasets/UPMC Food-101/texts/train_titles.csv'
            self.img_dir = '../datasets/UPMC Food-101/images/train'
        else:
            csvdata = '../datasets/UPMC Food-101/texts/test_titles.csv'
            self.img_dir = '../datasets/UPMC Food-101/images/test'

        self.data = self._read_rows(csvdata)

        self.label_to_id = {label: i for i, label in enumerate(sorted(os.listdir(self.img_dir)))}
        self.id_to_label = {i: label for label, i in self.label_to_id.items()}

        # Not read anywhere in this class; kept for external callers.
        self.img_root = '../datasets/food101/images/'
        self.tokenizer = AutoTokenizer.from_pretrained('./encoders/bert-base-uncased')
        self.split = mode

        print('data load over')
        print(len(self.data))

        self.classes = sorted(self.label_to_id)

        self._init_atransform()
        print('# of files = %d ' % len(self.data))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    @staticmethod
    def _read_rows(csv_path):
        """Return every row of the UTF-8 CSV at ``csv_path`` as a list of lists.

        The file is opened in a ``with`` block so the handle is closed
        deterministically, and with ``newline=''`` as the ``csv`` module
        requires for correct handling of quoted embedded newlines.
        """
        with open(csv_path, encoding='utf-8', newline='') as f:
            return list(csv.reader(f))

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _visual_transform(self):
        """Build the image pipeline for the current ``self.split``."""
        if self.split == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(images, input_ids, attention_mask, label)`` for row ``idx``.

        ``images`` has shape ``(1, 3, 224, 224)``; the token tensors have
        length 512; ``label`` is a scalar long tensor.

        Raises:
            KeyError: if the row's label name is not a sub-directory of the
                image directory.
        """
        item_data_row = self.data[idx]

        img_path_rel = item_data_row[0]
        text_input_str = item_data_row[1]
        label_str = item_data_row[2]
        label_id = self.label_to_id[label_str]

        encoded_text = self.tokenizer.encode_plus(
            text_input_str, add_special_tokens=True, max_length=512,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        input_ids = encoded_text['input_ids'].squeeze(0)
        attention_mask = encoded_text['attention_mask'].squeeze(0)

        # Visual
        transf = self._visual_transform()
        img_full_path = self.img_dir + '/' + label_str + '/' + img_path_rel
        # Context manager closes the file handle once the image is decoded.
        with Image.open(img_full_path) as img:
            images = torch.cat([transf(img.convert('RGB')).unsqueeze(0)])

        return images, input_ids, attention_mask, torch.tensor(label_id, dtype=torch.long)

class AVDataset_Hateful(Dataset):
    """Image / text dataset read from JSON-lines files under ``../datasets/hm_data``.

    Every record must provide the keys ``img`` (path relative to the data
    root), ``text`` and ``label``. ``mode == 'train'`` loads ``train.jsonl``
    and enables augmentation; any other value loads ``test_unseen.jsonl``.
    """

    def __init__(self, mode='train'):
        self.img_root = '../datasets/hm_data'
        self.split = mode

        jsonl_file = (
            "../datasets/hm_data/train.jsonl"
            if mode == 'train'
            else "../datasets/hm_data/test_unseen.jsonl"
        )
        # One JSON object per line.
        with open(jsonl_file, 'r') as f:
            self.data = [json.loads(line) for line in f]

        self.tokenizer = AutoTokenizer.from_pretrained('./encoders/bert-base-uncased')

        print('data load over')
        print(len(self.data))

        self._init_atransform()

        # Fixed number of classes for this dataset.
        self.class_num = 2

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(images, input_ids, attention_mask, label)`` for record ``idx``."""
        record = self.data[idx]
        img_filename = record['img']
        meme_text = record['text']
        label = torch.tensor(record['label'], dtype=torch.long)

        # Text: pad / truncate to exactly 512 tokens.
        encoded = self.tokenizer.encode_plus(
            meme_text, add_special_tokens=True, max_length=512,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Visual: augmentation only while training.
        if self.split == 'train':
            spatial = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            spatial = [transforms.Resize(size=(224, 224))]
        pipeline = transforms.Compose(spatial + [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        picture = Image.open(os.path.join(self.img_root, img_filename)).convert('RGB')
        images = torch.cat([pipeline(picture).unsqueeze(0)])

        return images, input_ids, attention_mask, label

class AVDataset_IEMOCAP(Dataset):
    """Audio / visual / text dataset rooted at ``../datasets/IEMOCAP_full_release``.

    The split JSON (``train.json`` / ``test.json``) supplies ``meta_data`` (a
    list of samples with ``path`` and ``label``) and ``labels`` (label name ->
    id). Utterance timing and transcripts are parsed from the corpus'
    ``EmoEvaluation`` and ``transcriptions`` text files, and frames are decoded
    on the fly from the dialogue ``.avi`` files.

    Args:
        mode: ``'train'`` selects the training JSON, random frame picks and
            augmentation; any other value selects the test JSON.
    """

    def __init__(self, mode='train'):
        if mode == 'train':
            csv_file = '../datasets/IEMOCAP_full_release/train.json'
        else:
            csv_file = '../datasets/IEMOCAP_full_release/test.json'
        with open(csv_file, 'r') as f:
            samples = json.load(f)
        self.samples = samples['meta_data']
        self.tokenizer = AutoTokenizer.from_pretrained('./encoders/bert-base-uncased')
        self.split = mode
        self.root_dir = '../datasets/IEMOCAP_full_release/IEMOCAP_full_release'
        self.json_path = csv_file

        self.metadata_index = self._build_metadata_index()

        self.emotion_to_id = samples['labels']
        self.idx_to_label = {i: label for label, i in self.emotion_to_id.items()}

        print('data load over')
        print(len(self.samples))

        self.classes = sorted(self.emotion_to_id)

        self._init_atransform()
        print('# of files = %d ' % len(self.samples))
        print('# of classes = %d' % len(self.classes))

        self.class_num = len(self.classes)

    def _build_metadata_index(self) -> Dict:
        """Map utterance id -> ``{'label', 'start_time', 'end_time'[, 'text']}``.

        Sessions 1-5 are scanned. Timing and label come from the
        ``EmoEvaluation`` files; ``text`` is added from the ``transcriptions``
        files only for utterances that already have an evaluation entry.
        """
        metadata_index = {}

        # "[start - end]  <utterance id>  <label>  [...]"
        emo_regex = re.compile(r'\[(.*)\]\s+(Ses\d{2}[MF]_\w+_\w{4})\s+(\w+)\s+\[.*\]')

        # "<utterance id> [start-end]: <text>"
        trans_regex = re.compile(r'(Ses\d{2}[MF]_\w+_\w{4})\s+\[(.*)\]:\s+(.*)')

        for session in range(1, 6):
            session_dir = os.path.join(self.root_dir, f'Session{session}')

            emo_eval_dir = os.path.join(session_dir, 'dialog', 'EmoEvaluation')
            for filename in os.listdir(emo_eval_dir):
                if not filename.endswith('.txt'):
                    continue
                with open(os.path.join(emo_eval_dir, filename), 'r', encoding='latin-1') as f:
                    for line in f:
                        match = emo_regex.match(line.strip())
                        if not match:
                            continue
                        time_str, utt_id, label = match.groups()
                        start_time, end_time = map(float, time_str.split(' - '))
                        metadata_index[utt_id] = {
                            'label': label,
                            'start_time': start_time,
                            'end_time': end_time
                        }

            trans_dir = os.path.join(session_dir, 'dialog', 'transcriptions')
            for filename in os.listdir(trans_dir):
                if not filename.endswith('.txt'):
                    continue
                with open(os.path.join(trans_dir, filename), 'r', encoding='latin-1') as f:
                    for line in f:
                        match = trans_regex.match(line.strip())
                        if not match:
                            continue
                        utt_id, _, text = match.groups()
                        if utt_id in metadata_index:
                            metadata_index[utt_id]['text'] = text

        return metadata_index

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def _extract_video_frames(self, video_path: str, start_time: float, end_time: float) -> torch.Tensor:
        """Decode the frames of ``video_path`` between two timestamps (seconds).

        Returns:
            Float tensor of shape ``(num_frames, 3, H, W)`` in RGB order with
            values scaled to ``[0, 1]``.

        Raises:
            RuntimeError: if no frame could be decoded (unreadable file or a
                time window outside the video).
        """
        frames = []

        cap = cv2.VideoCapture(video_path)
        try:
            actual_fps = cap.get(cv2.CAP_PROP_FPS)

            start_frame = int(start_time * actual_fps)
            end_frame = int(end_time * actual_fps)

            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            current_frame = start_frame
            while cap.isOpened() and current_frame <= end_frame:
                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes to BGR, HxWxC; convert to RGB, CxHxW.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
                frames.append(frame_tensor)
                current_frame += 1
        finally:
            # Release the capture even if decoding/conversion raises.
            cap.release()

        if not frames:
            # torch.stack([]) would raise a RuntimeError with no hint of the cause.
            raise RuntimeError(
                f"no frames decoded from {video_path!r} for [{start_time}, {end_time}]"
            )
        return torch.stack(frames)

    @staticmethod
    def _pick_frame_index(i, seg, train):
        """Return the frame index to use for segment ``i`` of length ``seg``.

        In train mode a random frame inside the segment is chosen, otherwise
        the segment midpoint. When ``seg`` is 0 (fewer frames than segments)
        the segment start is returned; the unclamped ``randint`` call would
        raise ``ValueError`` on the empty range.
        """
        start = i * seg
        if train:
            return random.randint(start, max(start, start + seg - 1))
        return start + seg // 2

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(fbank, images, input_ids, attention_mask, label)``.

        ``images`` has shape ``(2, 3, 224, 224)``; ``label`` is a scalar long
        tensor. Labels missing from ``emotion_to_id`` fall back to id 0.

        Raises:
            KeyError: if the sample's utterance id has no metadata entry.
        """
        item = self.samples[idx]
        emotion = item['label']
        audio_relative_path = item['path']
        full_audio_path = os.path.join('../datasets/IEMOCAP_full_release/pt_cremad', audio_relative_path)
        # File name without extension is the utterance id.
        utterance_id = os.path.splitext(os.path.basename(audio_relative_path))[0]
        metadata = self.metadata_index[utterance_id]
        text_input_str = metadata.get('text', '')

        encoded_text = self.tokenizer.encode_plus(
            text_input_str, add_special_tokens=True, max_length=512,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        input_ids = encoded_text['input_ids'].squeeze(0)
        attention_mask = encoded_text['attention_mask'].squeeze(0)

        # Audio: the pre-computed tensor sits next to where the .wav would be.
        fbank = torch.load(full_audio_path.split('.wav')[0] + '.pt').unsqueeze(0)

        # Visual: frames are already float tensors, so no ToTensor step.
        if self.split == 'train':
            steps = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(size=(224, 224))]
        steps.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
        transf = transforms.Compose(steps)

        # The dialogue video name is the utterance id minus its last "_" part.
        parts = utterance_id.split('_')
        session_id = audio_relative_path.split('/')[0]
        dialogue_filename_base = '_'.join(parts[:-1])

        dialogue_avi_filename = f"{dialogue_filename_base}.avi"
        video_path = os.path.join(self.root_dir, session_id, 'dialog', 'avi', 'DivX', dialogue_avi_filename)

        start_time = metadata['start_time']
        end_time = metadata['end_time']
        video_frames = self._extract_video_frames(video_path, start_time, end_time)
        file_num = len(video_frames)
        pick_num = 2
        seg = file_num // pick_num
        train = self.split == 'train'
        image_arr = []

        for i in range(pick_num):
            index = self._pick_frame_index(i, seg, train)
            image_arr.append(transf(video_frames[index]).unsqueeze(0))

        images = torch.cat(image_arr)

        label_id = self.emotion_to_id.get(emotion, 0)
        label = torch.tensor(label_id, dtype=torch.long)

        return fbank, images, input_ids, attention_mask, label

class AVDataset_mmimdb(Dataset):
    """Image / text dataset with a float ``rating`` target, rooted at ``../datasets/mmimdb/``.

    ``new_split_for_regress.json`` maps a split name (``mode``) to a list of
    item ids. For every id, ``<id>.json`` supplies ``plot``, ``genres`` and
    ``rating``, and ``<id>.jpeg`` is the image. Augmentation is applied only
    when ``mode == 'train'``.
    """

    def __init__(self, mode='train'):
        self.img_root = '../datasets/mmimdb/dataset/'
        self.tokenizer = AutoTokenizer.from_pretrained('./encoders/bert-base-uncased')
        self.split = mode
        json_file = '../datasets/mmimdb/new_split_for_regress.json'

        with open(json_file, 'r') as f:
            splits = json.load(f)
        self.data = splits[self.split]

        # Genre names are collected here but not stored on the instance.
        genre_names = set()
        self.all_data = []
        for item_id in self.data:
            with open(self.img_root + '/' + item_id + '.json', 'r') as f:
                meta = json.load(f)
            genre_names.update(meta['genres'])

            plot = meta['plot']
            if len(plot) != 1:
                # Merge multi-part plots into a single-element list.
                plot = [' '.join(plot)]

            self.all_data.append({
                'img_filename': self.img_root + '/' + item_id + '.jpeg',
                'plot_summary': plot,
                'genres': meta['genres'],
                'rating': meta['rating'],
            })

        print(len(self.all_data))

        self._init_atransform()

    def _init_atransform(self):
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(images, input_ids, attention_mask, label)`` for entry ``idx``."""
        entry = self.all_data[idx]
        img_path = entry['img_filename']
        plot_summary = entry['plot_summary'][0]
        label = torch.tensor(entry['rating']).float()

        # Text: pad / truncate to exactly 512 tokens.
        encoded = self.tokenizer.encode_plus(
            str(plot_summary), add_special_tokens=True, max_length=512,
            padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt'
        )
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Visual: augmentation only while training.
        if self.split == 'train':
            spatial = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            spatial = [transforms.Resize(size=(224, 224))]
        pipeline = transforms.Compose(spatial + [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        picture = Image.open(img_path).convert('RGB')
        images = torch.cat([pipeline(picture).unsqueeze(0)])
        return images, input_ids, attention_mask, label

