from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader

import soundfile
from soundfile import LibsndfileError

# Limit PyTorch's CPU intra-op thread pool to a single thread. This executes at
# import time, so it affects every process that imports this module (including
# the training process itself, not only data-loading workers).
# NOTE(review): presumably intended to avoid CPU oversubscription when several
# DataLoader workers run resampling concurrently — confirm.
torch.set_num_threads(1)


@dataclass(init=True, repr=True, eq=True)
class DataConfig:
    """Settings used to build the dataset and DataLoader for one data split.

    Attributes:
        filelist_path: Text file listing one audio file path per line.
        sampling_rate: Target sampling rate; audio at another rate is resampled.
        num_samples: Exact length (in samples) of every returned clip.
        batch_size: Number of clips per DataLoader batch.
        num_workers: Intended number of DataLoader worker processes.
    """

    filelist_path: str  # one audio path per line
    sampling_rate: int  # target rate in Hz after resampling
    num_samples: int  # clip length after crop / loop-padding
    batch_size: int  # clips per batch
    num_workers: int  # DataLoader worker processes


class VocosDataModule(LightningDataModule):
    """Lightning data module exposing train/validation loaders over ``VocosDataset``.

    Args:
        train_params: Configuration used to build the training DataLoader.
        val_params: Configuration used to build the validation DataLoader.
    """

    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    # NOTE: the misspelled name is kept so existing overrides/callers still work.
    def _get_dataloder(self, cfg: DataConfig, train: bool):
        """Build a DataLoader over a ``VocosDataset`` described by ``cfg``.

        Args:
            cfg: Split configuration; ``batch_size`` and ``num_workers`` are read here.
            train: Whether this is the training split; enables shuffling and is
                forwarded to the dataset (random gain / random crop).

        Returns:
            A ``torch.utils.data.DataLoader`` yielding batches of clips.
        """
        dataset = VocosDataset(cfg, train=train)
        # Honour the configured worker count. Set cfg.num_workers = 0 to load in
        # the main process (useful for surfacing the path of a file that fails).
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=train,
            pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training DataLoader."""
        return self._get_dataloder(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled validation DataLoader."""
        return self._get_dataloder(self.val_config, train=False)


class VocosDataset(Dataset):
    """Dataset of fixed-length mono waveform clips listed in a text file.

    Each non-blank line of ``cfg.filelist_path`` is the path of one audio file.
    Every item is downmixed to mono, peak-normalised with the sox ``norm``
    effect, resampled to ``cfg.sampling_rate`` and cropped or loop-padded to
    exactly ``cfg.num_samples`` samples.
    """

    def __init__(self, cfg: DataConfig, train: bool):
        """Read the file list.

        Args:
            cfg: Split configuration; ``filelist_path``, ``sampling_rate`` and
                ``num_samples`` are read here.
            train: If True, items use a random normalisation gain and a random
                crop; otherwise a fixed -3 dB gain and the leading segment.
        """
        with open(cfg.filelist_path, encoding="utf-8") as f:
            # Drop blank lines so they are never handed to soundfile as paths.
            self.filelist = [line for line in f.read().splitlines() if line.strip()]
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        """Return the number of audio files in the list."""
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Load, normalise, resample and length-fix the ``index``-th file.

        Returns:
            A 1-D float32 tensor of exactly ``self.num_samples`` samples.

        Raises:
            RuntimeError: If the file cannot be read or contains no samples;
                the message includes the offending path.
        """
        audio_path = self.filelist[index]

        # Re-raise read failures with the file path so a bad file is easy to find.
        try:
            y1, sr = soundfile.read(audio_path, dtype='float32', always_2d=False)
        except LibsndfileError as e:
            raise RuntimeError(f"LibsndfileError while reading: {audio_path}") from e
        except Exception as e:
            raise RuntimeError(f"Read failed for {audio_path}: {type(e).__name__}: {e}") from e

        # soundfile gives (T,) for mono and (T, C) for multi-channel audio,
        # so after unsqueeze the tensor is (1, T) or (1, T, C).
        y = torch.tensor(y1, dtype=torch.float32).unsqueeze(0)
        if y.ndim > 2:
            y = y.mean(dim=-1)  # downmix channels -> (1, T)

        # An empty file would otherwise fail inside sox or divide by zero in
        # the loop-padding below.
        if y.size(-1) == 0:
            raise RuntimeError(f"Audio file contains no samples: {audio_path}")

        # Peak-normalise to a level in dBFS: random in [-6, -1) for training,
        # fixed -3 for evaluation. (np.random.uniform expects low < high.)
        # NOTE(review): torchaudio.sox_effects is deprecated in newer torchaudio
        # releases — check against the installed version.
        gain = np.random.uniform(-6, -1) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])

        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)

        # Fix the length to exactly num_samples.
        length = y.size(-1)
        if length < self.num_samples:
            # Too short: extend by looping the clip's own samples.
            pad_length = self.num_samples - length
            padding_tensor = y.repeat(1, 1 + pad_length // length)
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            # Training: random crop for augmentation.
            start = np.random.randint(low=0, high=length - self.num_samples + 1)
            y = y[:, start:start + self.num_samples]
        else:
            # Evaluation: deterministic leading segment.
            y = y[:, :self.num_samples]

        return y[0]
