import os, librosa, numpy as np, copy
from PIL  import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class AVEDataset(Dataset):
    """Audio-visual dataset over AVE clips: audio spectrogram + sampled RGB frames.

    Clip lists are read from ``trainSet.txt`` / ``testSet.txt`` under
    ``ave_root``. Each non-empty line is ``&``-separated, with the class label
    in the first field and the clip id in the second. A clip is kept only if
    both ``audio_wav/<clip_id>.wav`` (file) and ``frames-3FPS/<clip_id>``
    (directory) exist. The class vocabulary is built from BOTH splits, so label
    indices agree between a train and a test instance.

    Each item is ``(spec, visual, label_idx)``:
      * ``spec``: float32 tensor ``(n_fft // 2 + 1, T)``, the log-magnitude
        STFT of the first ``clip_seconds`` seconds of audio (wrap-padded when
        the recording is shorter).
      * ``visual``: float tensor ``(3, n_frames, 224, 224)``, ImageNet-normalised
        frames sampled at a fixed stride from the clip's sorted frame files.
      * ``label_idx``: int index into ``self.classes``.
    """

    # Number of frames sampled per clip.
    n_frames = 3
    # Seconds of audio (after resampling to ``sr``) fed to the STFT.
    clip_seconds = 3
    # STFT parameters passed to ``librosa.stft``.
    n_fft = 512
    hop_length = 353

    def __init__(self, args, mode='train', sr=22_050, ave_root='/'):
        """Index the requested split and build the frame transforms.

        Args:
            args: not used by this class; kept for caller compatibility.
            mode: 'train' or 'test' (case-insensitive). Selects the split and
                whether random augmentation is applied to frames.
            sr: sample rate passed to ``librosa.load``.
            ave_root: dataset root directory. Defaults to '/', the previously
                hard-coded value.

        Raises:
            ValueError: if ``mode`` is not 'train' or 'test'.
            OSError: if a split file cannot be opened.
        """
        super().__init__()
        mode = mode.lower()
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.mode = mode
        self.sr = sr

        self.v_dir = os.path.join(ave_root, 'frames-3FPS')
        self.a_dir = os.path.join(ave_root, 'audio_wav')

        self.train_txt = os.path.join(ave_root, 'trainSet.txt')
        self.test_txt = os.path.join(ave_root, 'testSet.txt')

        # Filled by _read_split; a set while collecting, a sorted list afterwards.
        self.classes = set()
        self.data2class = {}
        # Both splits are always read so the class vocabulary is shared.
        train_clips = self._read_split(self.train_txt)
        test_clips = self._read_split(self.test_txt)

        self.av_files = train_clips if self.mode == 'train' else test_clips
        self.classes = sorted(self.classes)
        # O(1) label lookup per item instead of list.index().
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        print(f'#files = {len(self.av_files)},  #classes = {len(self.classes)}')

        self.tr_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])
        self.tr_eval = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

    def _read_split(self, txt_path):
        """Parse one split file and return the clip ids whose media exist.

        Side effects: adds each kept clip's label to ``self.classes`` (must be
        a set at call time) and records ``clip_id -> label`` in
        ``self.data2class``. Blank lines are skipped. Lines with fewer than
        two ``&`` fields are skipped with a warning.
        """
        clips = []
        with open(txt_path, 'r', encoding='utf-8') as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    continue
                parts = line.split('&')
                if len(parts) < 2:
                    print('[WARN] malformed line:', line)
                    continue
                label, clip_id = parts[0], parts[1]

                a_path = os.path.join(self.a_dir, f'{clip_id}.wav')
                v_path = os.path.join(self.v_dir, clip_id)
                if not (os.path.isfile(a_path) and os.path.isdir(v_path)):
                    continue

                self.classes.add(label)
                self.data2class[clip_id] = label
                clips.append(clip_id)
        return clips

    @staticmethod
    def _wrap_pad(wav, target_len):
        """Return exactly ``target_len`` samples of ``wav``.

        Longer inputs are truncated. Shorter inputs are tiled (wrap-padded).
        An empty input yields zeros instead of dividing by zero.
        """
        if len(wav) == 0:
            return np.zeros(target_len, dtype=wav.dtype)
        if len(wav) < target_len:
            repeat = -(-target_len // len(wav))  # integer ceil division
            wav = np.tile(wav, repeat)
        return wav[:target_len]

    @staticmethod
    def _select_frame_indices(n_available, pick):
        """Pick ``pick`` frame indices at stride ``max(1, n_available // pick)``.

        Indices are clamped to the last frame, so short clips repeat it.
        Requires ``n_available >= 1``.
        """
        if n_available <= 0:
            raise ValueError('no frames available to select from')
        stride = max(1, n_available // pick)
        return [min(i * stride, n_available - 1) for i in range(pick)]

    def _load_spectrogram(self, clip_id):
        """Load, pad/trim and STFT the clip's audio; return a float32 (F, T) tensor."""
        wav_path = os.path.join(self.a_dir, f'{clip_id}.wav')
        wav, _ = librosa.load(wav_path, sr=self.sr)
        wav = self._wrap_pad(wav, self.sr * self.clip_seconds)

        spec = librosa.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length)
        # Small epsilon keeps log() finite on silent bins.
        spec = np.log(np.abs(spec) + 1e-7).astype(np.float32)
        return torch.from_numpy(spec)

    def _load_frames(self, clip_id):
        """Load ``n_frames`` transformed frames stacked as (C, n_frames, H, W)."""
        clip_dir = os.path.join(self.v_dir, clip_id)
        frame_list = sorted(os.listdir(clip_dir))
        if not frame_list:
            # Deliberately not IndexError: when __getitem__ raises IndexError,
            # Python's legacy sequence iteration treats it as end-of-data.
            raise FileNotFoundError(f'no frames found in {clip_dir}')

        tfm = self.tr_train if self.mode == 'train' else self.tr_eval
        imgs = []
        for i in self._select_frame_indices(len(frame_list), self.n_frames):
            # Context manager closes the underlying file handle.
            with Image.open(os.path.join(clip_dir, frame_list[i])) as im:
                img = im.convert('RGB')
            imgs.append(tfm(img).unsqueeze(1))  # (C, 1, H, W)
        return torch.cat(imgs, dim=1)           # (C, n_frames, H, W)

    def __len__(self):
        return len(self.av_files)

    def __getitem__(self, idx):
        """Return ``(spec, visual, label_idx)`` for the ``idx``-th clip of the split."""
        clip_id = self.av_files[idx]
        spec = self._load_spectrogram(clip_id)
        visual = self._load_frames(clip_id)
        label_idx = self.class_to_idx[self.data2class[clip_id]]
        return spec, visual, label_idx