import os, copy, librosa, numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms


# Class names served by the dataset below, kept in alphabetical order.
# NOTE(review): presumably integer label i in the split files corresponds to
# LABELS_31[i] (labels are parsed as 0-30) -- confirm against the data.
LABELS_31 = [
    "blowing_nose",
    "blowing_out_candles",
    "bowling",
    "chopping_wood",
    "dribbling_basketball",
    "laughing",
    "mowing_lawn",
    "playing_accordion",
    "playing_bagpipes",
    "playing_bass_guitar",
    "playing_clarinet",
    "playing_drums",
    "playing_guitar",
    "playing_harmonica",
    "playing_keyboard",
    "playing_organ",
    "playing_piano",
    "playing_saxophone",
    "playing_trombone",
    "playing_trumpet",
    "playing_violin",
    "playing_xylophone",
    "ripping_paper",
    "shoveling_snow",
    "shuffling_cards",
    "singing",
    "stomping_grapes",
    "tap_dancing",
    "tapping_guitar",
    "tapping_pen",
    "tickling",
]
# Number of classes, derived from the list so the two can never disagree.
N_CLASS = len(LABELS_31)


class KSDataset(Dataset):
    """Audio-visual clip dataset indexed from ``my_train.txt`` / ``my_test.txt``.

    Each item is a tuple ``(spec, visual, label_idx)``:

    * ``spec`` -- ``float32`` tensor of shape ``(1 + N_FFT // 2, T)`` holding
      ``log(|STFT| + 1e-7)`` of the first ``CLIP_SECONDS`` seconds of the
      clip's wav (loop-padded when the recording is shorter).
    * ``visual`` -- tensor of shape ``(3, N_FRAMES, 224, 224)``: ``N_FRAMES``
      frames picked at a fixed stride from the clip's frame directory, each
      passed through the mode-specific image transform.
    * ``label_idx`` -- the integer class index read from the split file.

    Split files hold comma-separated ``clip_id, <ignored>, class_idx`` lines.
    A clip is kept only when ``<a_dir>/<clip_id>.wav`` is a file and
    ``<v_dir>/<clip_id>`` is a directory.
    """

    CLIP_SECONDS = 3        # length of the audio window, in seconds
    N_FRAMES = 3            # number of video frames sampled per clip
    N_FFT = 512             # STFT window size
    HOP_LENGTH = 353        # STFT hop size, in samples
    _MEAN = [0.485, 0.456, 0.406]   # ImageNet per-channel mean
    _STD = [0.229, 0.224, 0.225]    # ImageNet per-channel std

    def __init__(self, args,
                 mode: str = 'train',
                 sr: int = 22_050,
                 keep_fps: int = 3,
                 ks_root: str = '/'):
        """Index the clips of both splits and keep the one selected by *mode*.

        Args:
            args: accepted for interface compatibility; not read by this class.
            mode: ``'train'`` or ``'test'`` (case-insensitive). Selects which
                split's clips are served and which image transform is applied.
            sr: sample rate passed to ``librosa.load``.
            keep_fps: selects the ``frames-{keep_fps}FPS`` frame directory.
            ks_root: dataset root containing the frame directory, the
                ``audio_wav`` directory and the split files. Defaults to
                ``'/'``, the value that used to be hard-coded.

        Raises:
            ValueError: if *mode* is neither ``'train'`` nor ``'test'``.
        """
        super().__init__()
        mode = mode.lower()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.mode = mode
        self.sr = sr

        self.ks_root = ks_root
        self.v_dir = os.path.join(ks_root, f'frames-{keep_fps}FPS')
        self.a_dir = os.path.join(ks_root, 'audio_wav')
        self.train_txt = os.path.join(ks_root, 'my_train.txt')
        self.test_txt = os.path.join(ks_root, 'my_test.txt')

        # Both splits are indexed so ``data2label`` covers every usable clip.
        # If a clip_id appears in both files, the test-split label wins
        # because it is read second.
        self.data2label = {}
        clips_train = self._read_split(self.train_txt)
        clips_test = self._read_split(self.test_txt)

        self.av_files = clips_train if self.mode == 'train' else clips_test
        print(f'#files = {len(self.av_files)},  #classes = {N_CLASS}')

        # NOTE(review): the training transform is applied to each frame
        # separately, so every frame draws its own random crop/flip. Change
        # this if temporally consistent augmentation is intended.
        self.tr_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])
        self.tr_eval = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])

    # ------------------------------------------------------------------ index

    @staticmethod
    def _parse_line(line):
        """Parse one split-file line into ``(clip_id, cls_idx)``.

        Returns ``None`` for blank lines and for malformed lines, which are
        reported with a warning. A line is malformed if it does not have
        exactly three comma-separated fields or its last field is not an int.
        """
        line = line.strip()
        if not line:
            return None
        try:
            clip_id, _, cls_idx = [s.strip() for s in line.split(',')]
            return clip_id, int(cls_idx)    # class index, expected 0-30
        except ValueError:
            print('[WARN] malformed line:', line)
            return None

    def _has_media(self, clip_id):
        """Return True when both the clip's wav file and its frame directory exist."""
        wav_ok = os.path.isfile(os.path.join(self.a_dir, f'{clip_id}.wav'))
        return wav_ok and os.path.isdir(os.path.join(self.v_dir, clip_id))

    def _read_split(self, txt_path):
        """Return the usable clip ids listed in *txt_path*, in file order.

        Records each kept clip's label in ``self.data2label``. Clips whose
        audio or frames are missing on disk are skipped silently.
        """
        clips = []
        with open(txt_path, encoding='utf-8') as f:
            for line in f:
                parsed = self._parse_line(line)
                if parsed is None:
                    continue
                clip_id, cls_idx = parsed
                if not self._has_media(clip_id):
                    continue
                self.data2label[clip_id] = cls_idx
                clips.append(clip_id)
        return clips

    # ------------------------------------------------------------------ items

    def __len__(self):
        return len(self.av_files)

    @staticmethod
    def _fit_length(wav, target_len):
        """Loop-pad or truncate a 1-D waveform to exactly *target_len* samples.

        An empty waveform is zero-filled rather than tiled; tiling it used to
        raise ``ZeroDivisionError``.
        """
        if len(wav) == 0:
            return np.zeros(target_len, dtype=wav.dtype)
        if len(wav) < target_len:
            reps = -(-target_len // len(wav))   # integer ceil division
            wav = np.tile(wav, reps)
        return wav[:target_len]

    @staticmethod
    def _frame_indices(n_frames, pick):
        """Return *pick* indices into ``range(n_frames)`` at a fixed stride.

        The stride is ``max(1, n_frames // pick)``. Indices are clamped to the
        last frame, so clips with fewer than *pick* frames repeat their final
        frame.

        Raises:
            ValueError: if *n_frames* < 1.
        """
        if n_frames < 1:
            raise ValueError('cannot pick frames from an empty clip')
        stride = max(1, n_frames // pick)
        return [min(i * stride, n_frames - 1) for i in range(pick)]

    def _load_spec(self, clip_id):
        """Load the clip's wav and return its log-magnitude STFT as a float32 tensor."""
        wav_path = os.path.join(self.a_dir, f'{clip_id}.wav')
        wav, _ = librosa.load(wav_path, sr=self.sr)
        wav = self._fit_length(wav, self.sr * self.CLIP_SECONDS)

        spec = librosa.stft(wav, n_fft=self.N_FFT, hop_length=self.HOP_LENGTH)
        spec = np.log(np.abs(spec) + 1e-7).astype(np.float32)   # (F, T)
        return torch.from_numpy(spec)

    def _load_frames(self, clip_id):
        """Load ``N_FRAMES`` transformed frames stacked as ``(C, N_FRAMES, H, W)``.

        Raises:
            FileNotFoundError: if the clip's frame directory is empty. This is
                deliberately not ``IndexError``, which would silently end plain
                ``for x in dataset`` iteration.
        """
        clip_dir = os.path.join(self.v_dir, clip_id)
        # NOTE(review): lexicographic order; frame names without zero-padding
        # (frame_10 before frame_2) would be picked out of temporal order.
        frames = sorted(os.listdir(clip_dir))
        if not frames:
            raise FileNotFoundError(f'no frames found in {clip_dir}')

        tfm = self.tr_train if self.mode == 'train' else self.tr_eval
        imgs = []
        for i in self._frame_indices(len(frames), self.N_FRAMES):
            # Close the file handle promptly; convert() returns a detached copy.
            with Image.open(os.path.join(clip_dir, frames[i])) as im:
                img = im.convert('RGB')
            imgs.append(tfm(img).unsqueeze(1))      # (C, 1, H, W)
        return torch.cat(imgs, 1)                   # (C, N_FRAMES, H, W)

    def __getitem__(self, idx):
        """Return ``(spec, visual, label_idx)`` for the *idx*-th clip of the split."""
        clip_id = self.av_files[idx]
        spec = self._load_spec(clip_id)
        visual = self._load_frames(clip_id)
        label_idx = self.data2label[clip_id]        # already an int
        return spec, visual, label_idx
