"""
Data loading utilities for cost-sensitive learning experiments.

Provides a unified interface to load datasets with train/val/test splits.
Each loader returns a `Dataset` object whose splits (`DatasetSplit`) share standardized fields.
"""

from dataclasses import dataclass
from typing import Dict, Optional, Any, Tuple, Union, List
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from data.jigsaw import load_jigsaw_frame


@dataclass
class DatasetSplit:
    """
    One partition (train, val, or test) of a dataset.

    The loaders in this module fill every field by indexing with the same
    positions, so entry ``i`` of ``X``, ``delta``, ``abs_delta`` and (when
    present) ``ids`` / ``indices`` describes the same example as ``y[i]``.
    ``len(split)`` is the number of labels; an empty split is therefore falsy.

    Attributes:
        X: Per-example features -- a list of strings (texts or image paths),
            a NumPy array, or a DataFrame, depending on the loader
        y: Binary labels (y_star)
        delta: Signed delta value per example
        abs_delta: Magnitude of the delta (|delta|)
        ids: Optional external example identifiers
        indices: Optional positions of each example in the full loaded
            dataset (for pre-computed embedding lookup)
    """
    X: Union[np.ndarray, List[str], pd.DataFrame]
    y: np.ndarray
    delta: np.ndarray
    abs_delta: np.ndarray
    # Both optional fields default to "not provided".
    ids: Optional[np.ndarray] = None
    indices: Optional[np.ndarray] = None

    def __len__(self) -> int:
        # Labels are always present, so they define the size of the split.
        n_labels = len(self.y)
        return n_labels


@dataclass
class Dataset:
    """
    A named dataset bundled with its train/val/test splits.

    Attributes:
        name: Short dataset identifier (e.g., 'jigsaw', 'turkey')
        train: Split used for fitting
        val: Validation split; may be None when no validation data exists
        test: Held-out evaluation split
        metadata: Free-form loader information (split sizes, source path,
            ...); its 'feature_type' entry backs the ``feature_type`` property
    """
    name: str
    train: DatasetSplit
    val: Optional[DatasetSplit]
    test: DatasetSplit
    metadata: Dict[str, Any]

    @property
    def feature_type(self) -> str:
        """
        Kind of features stored in each split's ``X``.

        Returns ``metadata['feature_type']`` (the loaders in this module use
        'text', 'image', or 'tabular'), or 'unknown' when that key is absent.
        """
        fallback = 'unknown'
        return self.metadata.get('feature_type', fallback)


def load_jigsaw(
    path: str = "data/jigsaw/train.csv",
    sample_size: Optional[int] = None,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
) -> Dataset:
    """
    Load Jigsaw toxicity dataset.

    The frame is split into train / val / test, stratified on ``y_star``.
    Each split records ``indices`` (positions in the loaded, possibly
    subsampled, frame) so callers can look up pre-computed embeddings.

    Args:
        path: Path to the Jigsaw CSV file
        sample_size: Optional subsample size (None = use all); forwarded to
            ``load_jigsaw_frame`` as ``sample``
        val_size: Fraction for validation (0 to skip validation)
        test_size: Fraction for test (must be positive)
        seed: Random seed, forwarded to ``load_jigsaw_frame`` and used for
            both splitting steps

    Returns:
        Dataset with text features and toxicity labels. ``val`` is None when
        ``val_size`` is 0.

    Raises:
        ValueError: If ``val_size`` is negative or ``test_size`` is not
            positive; ``train_test_split`` also raises ValueError for other
            unusable split sizes.
    """
    # Fail early with a clear message: a negative val_size would silently
    # shrink the held-out pool, and a non-positive test_size always ends in
    # an opaque error from train_test_split.
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")
    if test_size <= 0:
        raise ValueError(f"test_size must be positive, got {test_size}")

    # Load full frame
    df = load_jigsaw_frame(path=path, sample=sample_size, random_state=seed)

    # Extract arrays
    texts = df['comment_text'].tolist()
    y = df['y_star'].to_numpy()
    delta = df['delta_signed'].to_numpy()
    abs_delta = df['abs_delta'].to_numpy()

    # Positions into df; these are what gets split and stored as `indices`.
    indices = np.arange(len(df))

    # Split: train / (val + test)
    val_test_size = val_size + test_size
    train_idx, val_test_idx = train_test_split(
        indices, test_size=val_test_size, random_state=seed, stratify=y
    )

    # Split val_test into val and test
    if val_size > 0:
        # Relative share of the held-out pool that goes to validation.
        val_relative = val_size / val_test_size
        val_idx, test_idx = train_test_split(
            val_test_idx,
            test_size=(1 - val_relative),
            random_state=seed,
            stratify=y[val_test_idx],
        )
    else:
        val_idx = np.array([], dtype=int)
        test_idx = val_test_idx

    def make_split(idx: np.ndarray) -> DatasetSplit:
        # Gather one split; `indices` enables pre-computed embedding lookup.
        return DatasetSplit(
            X=[texts[i] for i in idx],
            y=y[idx],
            delta=delta[idx],
            abs_delta=abs_delta[idx],
            indices=idx,
        )

    return Dataset(
        name='jigsaw',
        train=make_split(train_idx),
        val=make_split(val_idx) if len(val_idx) > 0 else None,
        test=make_split(test_idx),
        metadata={
            'feature_type': 'text',
            'n_total': len(df),
            'n_train': len(train_idx),
            'n_val': len(val_idx),
            'n_test': len(test_idx),
            'path': path,
            'sample_size': sample_size,
        }
    )


def load_turkey(
    path: str = "data/turkey_cost_table.csv",
    sample_size: Optional[int] = None,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    image_root: str = "data/turkey",
) -> Dataset:
    """
    Load Turkey injury classification dataset.

    Image locations are built from the CSV's ``image_path`` column. Every
    ``"Turkey/"`` substring is removed, and the result is prefixed with
    ``image_root`` and a ``/``. The splits are stratified on ``y_star``.

    Args:
        path: Path to the turkey cost table CSV
        sample_size: Optional subsample size (None, or >= the number of rows,
            uses all rows)
        val_size: Fraction for validation (0 to skip validation)
        test_size: Fraction for test (must be positive)
        seed: Random seed for subsampling and splitting
        image_root: Directory that image paths are resolved under; the
            default reproduces the historical ``data/turkey/...`` paths.
            An empty string leaves the stripped paths unprefixed.

    Returns:
        Dataset with image paths and injury labels. ``val`` is None when
        ``val_size`` is 0.

    Raises:
        ValueError: If ``val_size`` is negative or ``test_size`` is not
            positive; ``train_test_split`` also raises ValueError for other
            unusable split sizes.
    """
    # Fail early: a negative val_size would silently shrink the held-out
    # pool, and a non-positive test_size always fails inside train_test_split.
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")
    if test_size <= 0:
        raise ValueError(f"test_size must be positive, got {test_size}")

    df = pd.read_csv(path)

    # Subsample if requested
    if sample_size is not None and sample_size < len(df):
        df = df.sample(n=sample_size, random_state=seed).reset_index(drop=True)

    # Resolve image files under image_root (trailing slashes normalised).
    # NOTE(review): str.replace strips every "Turkey/" occurrence, not only a
    # leading directory; kept as-is so existing paths resolve identically.
    prefix = image_root.rstrip("/") + "/" if image_root else ""
    image_paths = [prefix + p.replace("Turkey/", "") for p in df['image_path'].tolist()]
    y = df['y_star'].to_numpy()
    delta = df['delta_signed'].to_numpy()
    abs_delta = df['abs_delta'].to_numpy()

    indices = np.arange(len(df))

    # Split: train / (val + test)
    val_test_size = val_size + test_size
    train_idx, val_test_idx = train_test_split(
        indices, test_size=val_test_size, random_state=seed, stratify=y
    )

    # Split val_test into val and test
    if val_size > 0:
        # Relative share of the held-out pool that goes to validation.
        val_relative = val_size / val_test_size
        val_idx, test_idx = train_test_split(
            val_test_idx,
            test_size=(1 - val_relative),
            random_state=seed,
            stratify=y[val_test_idx],
        )
    else:
        val_idx = np.array([], dtype=int)
        test_idx = val_test_idx

    def make_split(idx: np.ndarray) -> DatasetSplit:
        # Gather the image paths and label arrays for one split.
        return DatasetSplit(
            X=[image_paths[i] for i in idx],
            y=y[idx],
            delta=delta[idx],
            abs_delta=abs_delta[idx],
        )

    return Dataset(
        name='turkey',
        train=make_split(train_idx),
        val=make_split(val_idx) if len(val_idx) > 0 else None,
        test=make_split(test_idx),
        metadata={
            'feature_type': 'image',
            'n_total': len(df),
            'n_train': len(train_idx),
            'n_val': len(val_idx),
            'n_test': len(test_idx),
            'path': path,
            'sample_size': sample_size,
            'image_root': image_root,
        }
    )


def load_nhanes(
    path: str = "data/nhanes_cost_table.csv",
    sample_size: Optional[int] = None,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    feature_cols: Optional[List[str]] = None,
) -> Dataset:
    """
    Load NHANES hypertension classification dataset.

    Features are an explicit allow-list of CSV columns. ``SEQN`` is kept as
    the example ids, and ``y_star`` / ``delta_signed`` / ``abs_delta`` as the
    targets; none of these are used as features. The splits are stratified on
    ``y_star``.

    Args:
        path: Path to the NHANES cost table CSV
        sample_size: Optional subsample size (None, or >= the number of rows,
            uses all rows)
        val_size: Fraction for validation (0 to skip validation)
        test_size: Fraction for test (must be positive)
        seed: Random seed for subsampling and splitting
        feature_cols: Columns to use as features, in order. Defaults to
            ['RIDAGEYR', 'RIAGENDR', 'RIDRETH3', 'BMXBMI'].

    Returns:
        Dataset with tabular features and hypertension labels. ``val`` is
        None when ``val_size`` is 0. ``metadata['feature_names']`` lists the
        columns used.

    Raises:
        ValueError: If ``feature_cols`` is empty, ``val_size`` is negative,
            or ``test_size`` is not positive; ``train_test_split`` also
            raises ValueError for other unusable split sizes.
    """
    # NHANES variable codes; presumably age, gender, race/ethnicity and BMI
    # per the NHANES codebook -- confirm before relying on that reading.
    if feature_cols is None:
        feature_cols = ['RIDAGEYR', 'RIAGENDR', 'RIDRETH3', 'BMXBMI']
    else:
        # Copy so metadata never aliases (or is mutated through) the caller's list.
        feature_cols = list(feature_cols)
    if not feature_cols:
        raise ValueError("feature_cols must name at least one column")

    # Fail early: a negative val_size would silently shrink the held-out
    # pool, and a non-positive test_size always fails inside train_test_split.
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")
    if test_size <= 0:
        raise ValueError(f"test_size must be positive, got {test_size}")

    df = pd.read_csv(path)

    # Subsample if requested
    if sample_size is not None and sample_size < len(df):
        df = df.sample(n=sample_size, random_state=seed).reset_index(drop=True)

    X = df[feature_cols].to_numpy()
    y = df['y_star'].to_numpy()
    delta = df['delta_signed'].to_numpy()
    abs_delta = df['abs_delta'].to_numpy()
    ids = df['SEQN'].to_numpy()

    indices = np.arange(len(df))

    # Split: train / (val + test)
    val_test_size = val_size + test_size
    train_idx, val_test_idx = train_test_split(
        indices, test_size=val_test_size, random_state=seed, stratify=y
    )

    # Split val_test into val and test
    if val_size > 0:
        # Relative share of the held-out pool that goes to validation.
        val_relative = val_size / val_test_size
        val_idx, test_idx = train_test_split(
            val_test_idx,
            test_size=(1 - val_relative),
            random_state=seed,
            stratify=y[val_test_idx],
        )
    else:
        val_idx = np.array([], dtype=int)
        test_idx = val_test_idx

    def make_split(idx: np.ndarray) -> DatasetSplit:
        # Gather feature rows, targets and SEQN ids for one split.
        return DatasetSplit(
            X=X[idx],
            y=y[idx],
            delta=delta[idx],
            abs_delta=abs_delta[idx],
            ids=ids[idx],
        )

    return Dataset(
        name='nhanes',
        train=make_split(train_idx),
        val=make_split(val_idx) if len(val_idx) > 0 else None,
        test=make_split(test_idx),
        metadata={
            'feature_type': 'tabular',
            'feature_names': feature_cols,
            'n_total': len(df),
            'n_train': len(train_idx),
            'n_val': len(val_idx),
            'n_test': len(test_idx),
            'path': path,
            'sample_size': sample_size,
        }
    )


def load_inaturalist(
    path: str = "data/inaturalist_cost_table.csv",
    sample_size: Optional[int] = None,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
) -> Dataset:
    """
    Load iNaturalist natural vs controlled environment dataset.

    The CSV's ``image_path`` values are used unchanged as features. The
    splits are stratified on ``y_star``.

    Args:
        path: Path to the iNaturalist cost table CSV
        sample_size: Optional subsample size (None, or >= the number of rows,
            uses all rows)
        val_size: Fraction for validation (0 to skip validation)
        test_size: Fraction for test (must be positive)
        seed: Random seed for subsampling and splitting

    Returns:
        Dataset with image paths and natural/controlled labels. ``val`` is
        None when ``val_size`` is 0.

    Raises:
        ValueError: If ``val_size`` is negative or ``test_size`` is not
            positive; ``train_test_split`` also raises ValueError for other
            unusable split sizes.
    """
    # Fail early: a negative val_size would silently shrink the held-out
    # pool, and a non-positive test_size always fails inside train_test_split.
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")
    if test_size <= 0:
        raise ValueError(f"test_size must be positive, got {test_size}")

    df = pd.read_csv(path)

    # Subsample if requested
    if sample_size is not None and sample_size < len(df):
        df = df.sample(n=sample_size, random_state=seed).reset_index(drop=True)

    image_paths = df['image_path'].tolist()
    y = df['y_star'].to_numpy()
    delta = df['delta_signed'].to_numpy()
    abs_delta = df['abs_delta'].to_numpy()

    indices = np.arange(len(df))

    # Split: train / (val + test)
    val_test_size = val_size + test_size
    train_idx, val_test_idx = train_test_split(
        indices, test_size=val_test_size, random_state=seed, stratify=y
    )

    # Split val_test into val and test
    if val_size > 0:
        # Relative share of the held-out pool that goes to validation.
        val_relative = val_size / val_test_size
        val_idx, test_idx = train_test_split(
            val_test_idx,
            test_size=(1 - val_relative),
            random_state=seed,
            stratify=y[val_test_idx],
        )
    else:
        val_idx = np.array([], dtype=int)
        test_idx = val_test_idx

    def make_split(idx: np.ndarray) -> DatasetSplit:
        # Gather the image paths and label arrays for one split.
        return DatasetSplit(
            X=[image_paths[i] for i in idx],
            y=y[idx],
            delta=delta[idx],
            abs_delta=abs_delta[idx],
        )

    return Dataset(
        name='inaturalist',
        train=make_split(train_idx),
        val=make_split(val_idx) if len(val_idx) > 0 else None,
        test=make_split(test_idx),
        metadata={
            'feature_type': 'image',
            'n_total': len(df),
            'n_train': len(train_idx),
            'n_val': len(val_idx),
            'n_test': len(test_idx),
            'path': path,
            'sample_size': sample_size,
        }
    )


def load_synthetic(
    n_train: int = 1000,
    n_test: int = 10000,
    d: int = 50,
    noise_scale: float = 0.1,
    sample_size: Optional[int] = None,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    **kwargs,
) -> Dataset:
    """
    Generate synthetic dataset with known linear boundary.

    Data is generated as:
      - w ~ N(0, I), normalized to unit norm
      - X ~ N(0, I)
      - delta = X @ w + noise, with noise ~ N(0, noise_scale^2)
      - y = 1[delta >= 0]

    The first ``n_train`` draws form the train(+val) pool, and the remaining
    ``n_test`` draws form the test split. When ``val_size > 0``, the pool is
    shuffled and its first ``int(n_train * val_size)`` entries become the
    validation split.

    Args:
        n_train: Number of training samples (train + val pool)
        n_test: Number of test samples
        d: Feature dimension
        noise_scale: Noise std for delta
        sample_size: Ignored (use n_train instead)
        val_size: Fraction for validation (from train)
        test_size: Ignored (use n_test instead)
        seed: Random seed
        **kwargs: Accepted and ignored, so generic loader options pass
            through without error

    Returns:
        Dataset with tabular features and synthetic labels. ``val`` is None
        when ``val_size <= 0`` or when ``int(n_train * val_size)`` is 0, the
        same convention the file-backed loaders use for an empty validation
        split. ``metadata['w']`` holds the true weight vector.
    """
    rng = np.random.default_rng(seed)

    # NOTE: the order of RNG draws below fixes the generated data for a seed;
    # reordering them would change every dataset this function produces.
    w = rng.normal(size=d)
    w = w / np.linalg.norm(w)

    # Generate features
    n_total = n_train + n_test
    X_all = rng.normal(size=(n_total, d))

    # Generate deltas (signed margins)
    delta_all = X_all @ w + rng.normal(scale=noise_scale, size=n_total)

    # Labels from sign of delta
    y_all = (delta_all >= 0).astype(int)
    abs_delta_all = np.abs(delta_all)

    # First n_train rows -> train(+val) pool, remaining n_test rows -> test.
    X_trainval, X_test = X_all[:n_train], X_all[n_train:]
    y_trainval, y_test = y_all[:n_train], y_all[n_train:]
    delta_trainval, delta_test = delta_all[:n_train], delta_all[n_train:]
    abs_delta_trainval, abs_delta_test = abs_delta_all[:n_train], abs_delta_all[n_train:]

    val_split = None
    if val_size > 0:
        n_val = int(n_train * val_size)
        # The permutation is drawn whenever val_size > 0 (even if n_val is 0),
        # so the train ordering for a given seed is unchanged.
        perm = rng.permutation(n_train)
        val_idx, train_idx = perm[:n_val], perm[n_val:]

        X_train = X_trainval[train_idx]
        y_train = y_trainval[train_idx]
        delta_train = delta_trainval[train_idx]
        abs_delta_train = abs_delta_trainval[train_idx]

        # Report an empty validation split as None, like the other loaders,
        # instead of returning a zero-length DatasetSplit.
        if n_val > 0:
            val_split = DatasetSplit(
                X=X_trainval[val_idx],
                y=y_trainval[val_idx],
                delta=delta_trainval[val_idx],
                abs_delta=abs_delta_trainval[val_idx],
            )
    else:
        X_train = X_trainval
        y_train = y_trainval
        delta_train = delta_trainval
        abs_delta_train = abs_delta_trainval

    train_split = DatasetSplit(
        X=X_train,
        y=y_train,
        delta=delta_train,
        abs_delta=abs_delta_train,
    )

    test_split = DatasetSplit(
        X=X_test,
        y=y_test,
        delta=delta_test,
        abs_delta=abs_delta_test,
    )

    return Dataset(
        name='synthetic',
        train=train_split,
        val=val_split,
        test=test_split,
        metadata={
            'feature_type': 'tabular',
            'n_total': n_total,
            'n_train': len(train_split),
            # Explicit None check: DatasetSplit truthiness depends on __len__.
            'n_val': len(val_split) if val_split is not None else 0,
            'n_test': len(test_split),
            'd': d,
            'noise_scale': noise_scale,
            'w': w,  # Store the true weight vector
        }
    )


# Dataset registry: public dataset name -> loader function (insertion order
# is the order reported by list_datasets()).
_DATASET_LOADERS = dict(
    jigsaw=load_jigsaw,
    turkey=load_turkey,
    nhanes=load_nhanes,
    inaturalist=load_inaturalist,
    synthetic=load_synthetic,
)


def load_dataset(
    name: str,
    sample_size: Optional[int] = None,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    **kwargs,
) -> Dataset:
    """
    Load a dataset by name.

    Looks up the loader registered under ``name`` in ``_DATASET_LOADERS`` and
    forwards every argument to it unchanged.

    Args:
        name: Registered dataset name: 'jigsaw', 'turkey', 'nhanes',
            'inaturalist' or 'synthetic' (see ``list_datasets()``)
        sample_size: Optional subsample size (ignored by 'synthetic')
        val_size: Fraction for validation split
        test_size: Fraction for test split (ignored by 'synthetic', whose
            test size is set by ``n_test``)
        seed: Random seed
        **kwargs: Additional loader-specific arguments, e.g. ``path`` for the
            file-backed datasets or ``n_train`` / ``n_test`` / ``d`` /
            ``noise_scale`` for 'synthetic'

    Returns:
        Dataset with train/val/test splits

    Raises:
        ValueError: If dataset name is not recognized
        TypeError: If ``kwargs`` contains an argument the selected loader
            does not accept
    """
    try:
        loader = _DATASET_LOADERS[name]
    except KeyError:
        available = ', '.join(_DATASET_LOADERS.keys())
        raise ValueError(f"Unknown dataset: {name}. Available: {available}") from None

    return loader(
        sample_size=sample_size,
        val_size=val_size,
        test_size=test_size,
        seed=seed,
        **kwargs,
    )


def list_datasets() -> List[str]:
    """
    Return the registered dataset names, in registration order.

    A fresh list is built on every call, so callers may modify it freely.
    """
    return [*_DATASET_LOADERS]
