# =============================================
# File: vggsound_dataloaders.py
# ---------------------------------------------
# Retrieval-consistent datasets and DataLoaders for VGGSound
# - AUDIO: 10 s → log-Mel fbank (1000, 128) with 25 ms Hann window / 10 ms hop, no aug
# - VIDEO: 10 s → 3 FPS → 30 RGB frames resized to 224×224, no aug
# - CSV layout: stat.csv (class name in column 0) and {train|test}.csv (rows: video_file,class_name)
# - Media layout: {data_path}/{split}/{class name with spaces → '_'}/{video id}.{wav|mp4}
# - Labels are indices into sorted(class_names)
# =============================================
from __future__ import annotations
import os, csv
from dataclasses import dataclass
from typing import List, Tuple

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import librosa


# ----------------------------
# Common helpers
# ----------------------------

def _load_classes(csv_dir: str) -> List[str]:
    classes: List[str] = []
    with open(os.path.join(csv_dir, 'stat.csv'), newline='') as f:
        for row in csv.reader(f):
            if row:
                classes.append(row[0])
    return sorted(classes)


def _load_split(csv_dir: str, data_root: str, split: str, classes: List[str]) -> List[Tuple[str, int]]:
    items: List[Tuple[str, int]] = []
    with open(os.path.join(csv_dir, f'{split}.csv'), newline='') as f:
        for vid, cls in csv.reader(f):
            cls_dir = cls.replace(' ', '_')
            label = classes.index(cls)
            base = os.path.join(data_root, split, cls_dir, vid[:-3])  # remove 'mp4'
            items.append((base, label))
    return items


# ----------------------------
# AUDIO dataset (10 s → log-mel 1000×128)
# ----------------------------
@dataclass
class AudioCfg:
    """Settings for the VGGSound audio dataset (paths, split, log-mel parameters)."""

    # Directory containing stat.csv and {split}.csv.
    csv_path: str = './csv'
    # Root of the media tree: {data_path}/{split}/{class_dir}/...
    data_path: str = './data'
    # Which split to read ('train' or 'test').
    split: str = 'train'
    # Target sample rate in Hz passed to librosa.load.
    sr: int = 16_000
    # Waveform length, in seconds, after center crop / zero pad.
    seconds: int = 10
    # Number of mel bands.
    n_mels: int = 128
    # STFT window and hop lengths in milliseconds (converted to samples with sr).
    win_ms: float = 25.0
    hop_ms: float = 10.0


class VGGSoundAudio(Dataset):
    """Produces (1, 1000, n_mels) float32 tensor + label (int); (1, 1000, 128) by default.

    Each item loads ``<base>wav`` with librosa (mono, resampled to ``cfg.sr``),
    center-crops or zero-pads the waveform to ``cfg.seconds`` seconds, computes a
    power mel spectrogram (Hann window of ``cfg.win_ms``, hop ``cfg.hop_ms``),
    converts it to dB relative to the clip's maximum, and center-crops or pads
    it along time to exactly 1000 frames.
    """

    # Number of time frames in every returned spectrogram.
    _TARGET_FRAMES = 1000

    def __init__(self, cfg: AudioCfg):
        self.cfg = cfg
        self.classes = _load_classes(cfg.csv_path)
        self.items = _load_split(cfg.csv_path, cfg.data_path, cfg.split, self.classes)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        base, y = self.items[idx]
        wav = base + 'wav'
        x = self._logmel_1000x128(wav)  # (1000, n_mels)
        x = torch.from_numpy(x).unsqueeze(0)  # (1, 1000, n_mels)
        return x, y

    @staticmethod
    def _center_fit(arr: np.ndarray, length: int, pad_value: float = 0.0) -> np.ndarray:
        """Center-crop or symmetrically pad *arr* along axis 0 to exactly *length*.

        When padding, the extra ``pad // 2`` goes before and the remainder after,
        filled with *pad_value*. When cropping, the window starts at
        ``(n - length) // 2``. Arrays already of the right length are returned as-is.
        """
        n = arr.shape[0]
        if n < length:
            pad = length - n
            widths = [(pad // 2, pad - pad // 2)] + [(0, 0)] * (arr.ndim - 1)
            return np.pad(arr, widths, mode='constant', constant_values=pad_value)
        if n > length:
            st = (n - length) // 2
            return arr[st:st + length]
        return arr

    # --- feature extractor ---
    def _logmel_1000x128(self, wav_path: str) -> np.ndarray:
        """Return the (1000, n_mels) float32 log-mel of *wav_path* (time-major)."""
        sr = self.cfg.sr
        samples, _ = librosa.load(wav_path, sr=sr, mono=True)
        # Zero-pad (silence) or center-crop the waveform to cfg.seconds.
        samples = self._center_fit(samples, self.cfg.seconds * sr)
        n_fft = int(sr * (self.cfg.win_ms / 1000.0))
        hop = int(sr * (self.cfg.hop_ms / 1000.0))
        mel = librosa.feature.melspectrogram(y=samples, sr=sr, n_fft=n_fft, hop_length=hop,
                                             win_length=n_fft, window='hann', n_mels=self.cfg.n_mels,
                                             center=True, power=2.0, fmin=0.0, fmax=sr/2)
        logmel = librosa.power_to_db(mel, ref=np.max).T.astype(np.float32)  # (T, n_mels)
        # With ref=np.max the loudest bin is 0 dB and everything else is below it,
        # so padding with 0 would add "maximally loud" frames. Pad with the clip's
        # own floor instead. With the default config T is 1001, so this only crops.
        return self._center_fit(logmel, self._TARGET_FRAMES, pad_value=float(logmel.min()))


# ----------------------------
# VIDEO dataset (10 s → 30 RGB frames @3FPS)
# ----------------------------
@dataclass
class VideoCfg:
    """Settings for the VGGSound video dataset (paths, split, clip length, frame size)."""

    # Directory containing stat.csv and {split}.csv.
    csv_path: str = './csv'
    # Root of the media tree: {data_path}/{split}/{class_dir}/...
    data_path: str = './data'
    # Which split to read ('train' or 'test').
    split: str = 'train'
    # Time span, in seconds, over which the 30 frames are sampled.
    seconds: int = 10
    # Output frame size given as (W, H).
    size: tuple[int, int] = (224, 224)


class VGGSoundVideo(Dataset):
    """Produces (T=30, 3, H, W) float32 in [0,1] + label (int).

    Frames are taken at 30 evenly spaced timestamps over ``cfg.seconds`` seconds
    (3 FPS for the default 10 s). Each frame is converted BGR -> RGB and resized to
    ``cfg.size``, which is given as (W, H). A frame that cannot be read repeats
    the previous frame, or is black if none has been read yet. A file that cannot
    be opened yields an all-zero clip.
    """

    def __init__(self, cfg: VideoCfg):
        self.cfg = cfg
        self.classes = _load_classes(cfg.csv_path)
        self.items = _load_split(cfg.csv_path, cfg.data_path, cfg.split, self.classes)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        base, y = self.items[idx]
        mp4 = base + 'mp4'
        x = self._read_30_frames(mp4)  # (30, 3, H, W)
        x = torch.from_numpy(x)
        return x, y

    @staticmethod
    def _frame_indices(num: int, seconds: float, src_fps: float, total: int) -> List[int]:
        """Map *num* evenly spaced timestamps in [0, seconds) to frame indices.

        With a positive *src_fps* the index is ``round(t * src_fps)``. Otherwise
        timestamps are spread proportionally over ``total - 1``. Indices are
        clamped to ``[0, total - 1]`` only when *total* is positive. If the frame
        count is unreported (<= 0), clamping to ``[0, 0]`` would turn the clip into
        repeats of frame 0. Out-of-range seeks instead fail to read and fall back
        to the previous frame.
        """
        last = total - 1
        indices: List[int] = []
        for t in np.linspace(0, seconds, num=num, endpoint=False):
            if src_fps > 0:
                idx = int(round(t * src_fps))
            else:
                idx = int(round((t / seconds) * max(last, 0)))
            if last >= 0:
                idx = min(max(idx, 0), last)
            indices.append(idx)
        return indices

    def _read_30_frames(self, mp4_path: str) -> np.ndarray:
        """Decode 30 frames from *mp4_path* as a (30, 3, H, W) float32 array in [0, 1]."""
        T = 30
        H, W = self.cfg.size[1], self.cfg.size[0]
        cap = cv2.VideoCapture(mp4_path)
        try:
            if not cap.isOpened():
                return np.zeros((T, 3, H, W), dtype=np.float32)
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            src_fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
            frames = []
            for idx in self._frame_indices(T, self.cfg.seconds, src_fps, total):
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ok, frame = cap.read()
                if ok and frame is not None:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (W, H))  # dsize is (width, height)
                elif frames:
                    frame = frames[-1]  # repeat the last good frame
                else:
                    frame = np.zeros((H, W, 3), dtype=np.uint8)
                frames.append(frame)
        finally:
            # Release the decoder even if cvtColor/resize raises.
            cap.release()
        arr = np.stack(frames, axis=0).astype(np.float32) / 255.0  # (30, H, W, 3)
        arr = np.transpose(arr, (0, 3, 1, 2))  # (30, 3, H, W)
        return arr


# ----------------------------
# Dataloaders (no augmentation)
# ----------------------------

def make_audio_loader(cfg: AudioCfg, batch_size: int, shuffle: bool, num_workers: int = 8) -> DataLoader:
    """Wrap a VGGSoundAudio dataset built from *cfg* in a pinned-memory DataLoader.

    Uses PyTorch's default collation; no augmentation is applied.
    """
    dataset = VGGSoundAudio(cfg)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_kwargs)


def _collate_video(batch):
    xs, ys = zip(*batch)
    x = torch.stack(xs, dim=0)  # (N,T,3,H,W)
    y = torch.tensor(ys, dtype=torch.long)
    return x, y


def make_video_loader(cfg: VideoCfg, batch_size: int, shuffle: bool, num_workers: int = 8) -> DataLoader:
    """Wrap a VGGSoundVideo dataset built from *cfg* in a pinned-memory DataLoader.

    Batches are assembled by ``_collate_video``; no augmentation is applied.
    """
    dataset = VGGSoundVideo(cfg)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=_collate_video,
    )
    return DataLoader(dataset, **loader_kwargs)

