import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Dict, List, Tuple
import numpy as np
import random
import math

def compute_quantile(scores, alpha):
    """Pick a threshold from `scores` at the rank determined by `alpha`.

    The scores are ranked from largest to smallest and the value at 0-based
    position ``ceil((n + 1) * (1 - alpha)) - 1`` is returned, where ``n`` is
    the number of scores.

    Args:
        scores: Sized iterable of numeric scores. It is not modified.
        alpha: Level used in the rank formula above
            (presumably a miscoverage level in (0, 1) - confirm with callers).

    Returns:
        The selected score, or the sentinel -1000 when the computed position
        lies beyond the last score.
    """
    count = len(scores)

    # Sorting the negated values ascending ranks the originals descending.
    negated_sorted = sorted(-value for value in scores)

    rank = math.ceil((count + 1) * (1 - alpha)) - 1

    # Too few samples for the requested level: fall back to the sentinel
    # (per the original note, every claim must then simply be omitted).
    if rank >= count:
        return -1000

    return -negated_sorted[rank]

def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    Args:
        seed: Value handed to every seeding function below.
    """
    # Same order as before: stdlib random, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are seeded explicitly only when CUDA is usable.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # NOTE(review): no explicit MPS seeding is done here; the original author's
    # note says torch.manual_seed covers MPS - confirm for the torch version in use.

def split_indices(indices: List[int], cal_ratio: float = 0.5) -> Tuple[List[int], List[int]]:
    """Cut `indices` in two at position ``int(len(indices) * cal_ratio)``.

    Order is preserved and nothing is shuffled: the leading slice is returned
    first and the remainder second.
    """
    cut = int(len(indices) * cal_ratio)
    head = indices[:cut]
    tail = indices[cut:]
    return head, tail


def generate_kfold_splits(n_samples: int, n_folds: int = 20, seed: int = 42):
    """
    Build the held-out (test) index list for every fold of a shuffled K-fold split.

    Args:
        n_samples: Number of samples to partition; indices run 0..n_samples-1
        n_folds: How many folds to create (default 20)
        seed: Passed to KFold as random_state so the shuffle is repeatable (default 42)

    Returns:
        A list with one entry per fold, each entry being that fold's test
        indices as a plain Python list

    Example:
        folds = generate_kfold_splits(202, n_folds=20, seed=42)
        test_idx = folds[fold_idx]
    """
    # Function-local import: scikit-learn is only loaded when this is called.
    from sklearn.model_selection import KFold

    splitter = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    sample_ids = list(range(n_samples))

    # Only the test half of each (train, test) pair is kept.
    return [held_out.tolist() for _, held_out in splitter.split(sample_ids)]


def split_dataset(dataset, train_ratio: float = 0.7, val_ratio: float = 0.15, seed: int = 42,
                  fold_idx: int = None, n_folds: int = None):
    """
    Split dataset into train/val/test indices.

    Args:
        dataset: Any object supporting len(); only its length is used
        train_ratio: Ratio of training data (default 0.7)
        val_ratio: Ratio of validation data (default 0.15)
        seed: Random seed for shuffling / fold generation (default 42)
        fold_idx: Current fold index for K-fold CV (optional)
        n_folds: Total number of folds for K-fold CV (optional)

    Returns:
        Tuple of (train_idx, val_idx, test_idx), each a list of integer indices

    Raises:
        IndexError: In K-fold mode, if fold_idx is not a valid index into the
            list of generated folds.

    Modes:
        * K-fold mode (both fold_idx and n_folds given): the test set is fold
          `fold_idx` of generate_kfold_splits(n, n_folds, seed). The remaining
          indices are kept in ascending order (not shuffled) and cut so that
          train gets the fraction train_ratio / (train_ratio + val_ratio) of
          them and val gets the rest.
        * Shuffle mode (otherwise, including when only one of the two is
          given): indices 0..n-1 are shuffled and cut at int(n * train_ratio)
          and int(n * (train_ratio + val_ratio)).

    Side effects:
        Shuffle mode calls random.seed(seed), i.e. it reseeds the global
        `random` module state.
    """
    n = len(dataset)

    if fold_idx is not None and n_folds is not None:
        # K-fold mode: random partitioning with fixed seed
        fold_assignments = generate_kfold_splits(n, n_folds=n_folds, seed=seed)
        test_idx = fold_assignments[fold_idx]

        # Build the set once so each membership test is O(1) instead of
        # scanning the test list for every index.
        held_out = set(test_idx)
        remaining = [i for i in range(n) if i not in held_out]

        # Only the relative size of train vs. val matters here, because the
        # test share is fixed by the fold size.
        train_end = int(len(remaining) * train_ratio / (train_ratio + val_ratio))
        train_idx = remaining[:train_end]
        val_idx = remaining[train_end:]

        return train_idx, val_idx, test_idx

    # Original random shuffle mode. The global RNG is seeded on purpose to
    # keep existing behaviour (see "Side effects" in the docstring).
    indices = list(range(n))
    random.seed(seed)
    random.shuffle(indices)

    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))

    return indices[:train_end], indices[train_end:val_end], indices[val_end:]


def get_data_batch(dataset, indices: List[int], noise_dict: Dict[int, float]):
    """Gather the x, y and noise entries selected by `indices`.

    Args:
        dataset: Object exposing indexable `x` and `y` attributes.
        indices: Positions to pull from `dataset.x` / `dataset.y`; the same
            values are used as keys into `noise_dict`.
        noise_dict: Mapping from index to its noise value.

    Returns:
        A 3-tuple of lists (xs, ys, noises), each ordered like `indices`.
    """
    # Three separate passes, in this order: x first, then y, then noise.
    xs = [dataset.x[idx] for idx in indices]
    ys = [dataset.y[idx] for idx in indices]
    noises = [noise_dict[idx] for idx in indices]
    return (xs, ys, noises)



