"""
Test dataset loading and verifier configuration

Example to run specific test:
    pytest test_dataset.py -k "test_load_datasets"
    pytest test_dataset.py -k "test_load_datasets and MATH-500 and 8B"
    pytest test_dataset.py -k "test_verifier_cfg and reward_models and small and all"  (selects nothing while that test is disabled as no_test_verifier_cfg)
    pytest test_dataset.py -k "test_distance and MATH-500 and 8B and mean_verifier_distance"
    pytest test_dataset.py -k "test_different_train_split_values"

"""
import sys
import os
try:
    import weaver
except:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from weaver.dataset import VerificationDataset
from weaver.constants import DATASET_TO_HF
from dataclasses import dataclass
import numpy as np

import pytest


# Entries that must never be used to parametrize the tests in this module.
DATA_IGNORE_LIST = ["CodeContests_gonly"]

# sorted() instead of list(): iterating a set of strings has no stable order
# across interpreter runs (string hashing is randomized), which would make
# both the parametrized test order and the ALL_DATASETS[0] lookup below vary
# from run to run.
# NOTE(review): assumes the keys are strings (they are compared against the
# string filters below) and therefore sortable - confirm.
ALL_DATASETS = sorted(set(DATASET_TO_HF.keys()) - set(DATA_IGNORE_LIST))
# Model sizes are taken from the first dataset's entry only.
ALL_MODEL_SIZES = sorted(set(DATASET_TO_HF[ALL_DATASETS[0]].keys()) - set(DATA_IGNORE_LIST))


# Allow-lists applied on top of the lists above; an empty list keeps everything.
DATASET_FILTER = ["MATH-500-v2"]
MODEL_SIZE_FILTER = ["8B"]
if DATASET_FILTER:
    ALL_DATASETS = [d for d in ALL_DATASETS if d in DATASET_FILTER]
if MODEL_SIZE_FILTER:
    ALL_MODEL_SIZES = [m for m in ALL_MODEL_SIZES if m in MODEL_SIZE_FILTER]


def get_verifier_cfg(verifier_type, verifier_size, verifier_subset):
    """Build a lightweight verifier-selection config object.

    The returned object stores the three arguments as the attributes
    ``verifier_type``, ``verifier_size`` and ``verifier_subset`` and also
    offers a dict-style ``get(key, default)`` accessor over its attributes.
    """
    class VerifierConfig:
        def __init__(self, kind, size, subset):
            self.verifier_type = kind
            self.verifier_size = size
            self.verifier_subset = subset

        def get(self, key, default=None):
            """Return attribute ``key``, or ``default`` when it does not exist."""
            try:
                return getattr(self, key)
            except AttributeError:
                return default

    return VerifierConfig(verifier_type, verifier_size, verifier_subset)


def get_data_cfg():
    """Return a fresh dict of default settings for the tests in this module.

    The tests unpack the result as keyword arguments to
    ``VerificationDataset`` and override individual keys as needed. A new
    dict (with a new verifier config object) is created on every call.
    """
    quantile_params = {"output_distribution": "normal", "n_quantiles": 100}
    default_verifiers = get_verifier_cfg("all", "all", [])
    return dict(
        train_split=1.0,
        train_queries=1,
        train_samples=1,
        random_seed=0,
        nan_replacement=0,
        reward_threshold=None,
        normalize_type="per_problem",
        normalize_method="minmax",
        closest_train_problem_method="mean_verifier_distance",
        closest_train_problem_metric_type="euclidean",
        verifier_cfg=default_verifiers,
        mv_as_verifier=False,
        fixed_test_split=None,
        same_train_test=False,
        train_split_bins=1,
        normalize_params=quantile_params,
    )



@pytest.mark.parametrize("dataset_name", ALL_DATASETS)
@pytest.mark.parametrize("model_size", ALL_MODEL_SIZES)
def test_load_datasets(dataset_name, model_size):
    """Check that a VerificationDataset can be constructed with the default config.

    Runs once per (dataset_name, model_size) combination. Any error raised
    while constructing the dataset is re-raised with the dataset name and
    model size added to the message.
    """
    print(f"Testing dataset: {dataset_name}, model size: {model_size}")  # Print start

    data_cfg = get_data_cfg()
    try:
        dataset = VerificationDataset(dataset_name, model_size, **data_cfg)
        assert dataset is not None, f"Failed to load {dataset_name} for {model_size}"
    except Exception as e:
        # "from e" records the original error as the explicit cause, so the
        # traceback reports it as the reason for this failure rather than as
        # an unrelated error raised during handling.
        raise Exception(f"Error loading {dataset_name} with model size {model_size}: {e}") from e


#@pytest.mark.parametrize("verifier_type", ["reward_models", "judges"])
#@pytest.mark.parametrize("verifier_size", ["small", "medium", "large"])
def no_test_verifier_cfg(verifier_type, verifier_size):
    """
    Test all combinations of verifier_type, verifier_size, verifier_subset except for all

    Currently disabled: the name does not start with ``test`` (so pytest's
    default collection rules skip it) and nothing in this module supplies
    ``verifier_type`` / ``verifier_size``.
    """
    dataset_name = "MATH-500"
    model_size = "8B"
    data_cfg = get_data_cfg()
    data_cfg["verifier_cfg"] = get_verifier_cfg(verifier_type, verifier_size, [])
    print(f"Testing dataset: {dataset_name}, model size: {model_size}")  # Print start
    try:
        dataset = VerificationDataset(dataset_name, model_size, **data_cfg)
        # Only the first returned value is checked; the second is not needed here.
        df, _ = dataset.load_task_data()
        assert df is not None, f"Failed to load {dataset_name} for {model_size}"
        print(f"Passed: {dataset_name} with model size {model_size} {verifier_type} {verifier_size}")  # Print pass
    except Exception as e:
        # Keep the original error attached as the explicit cause of the re-raise.
        raise Exception(f"Error loading {dataset_name} with model size {model_size} {verifier_type} {verifier_size}: {e}") from e


@pytest.mark.parametrize("dataset_name", ALL_DATASETS)
@pytest.mark.parametrize("model_size", ALL_MODEL_SIZES)
@pytest.mark.parametrize("train_queries", [2])
@pytest.mark.parametrize("closest_train_problem_method", ["mean_verifier_distance", "SBERT"])
def test_distance(dataset_name, model_size, closest_train_problem_method, train_queries):
    """
    When train set = test set, the closest train problem to each test problem is itself
    """
    data_cfg = get_data_cfg()
    data_cfg["verifier_cfg"] = get_verifier_cfg("all", "all", [])
    data_cfg["train_queries"] = train_queries
    data_cfg["same_train_test"] = True
    data_cfg["closest_train_problem_method"] = closest_train_problem_method
    dataset = VerificationDataset(dataset_name, model_size, **data_cfg)

    # Column 0 of closest_train_idxs is used as the position (within train_idx)
    # of the nearest train problem for each test problem.
    closest_train_idxs = dataset.closest_train_idxs[:, 0]
    test_idx = dataset.test_idx
    train_idx = dataset.train_idx[closest_train_idxs]

    # np.array_equal (as used by the split tests in this module) also reports
    # a clean failure when the two index arrays differ in shape, instead of
    # erroring inside all(...) on a non-elementwise comparison result.
    assert np.array_equal(train_idx, test_idx), \
        (f"Closest train problem failed for {dataset_name} model size {model_size} methods {closest_train_problem_method}")


@pytest.mark.parametrize("dataset_name", ["AIMO-v2"])
@pytest.mark.parametrize("model_size", ["8B"])
def test_different_seeds_produce_different_splits(dataset_name, model_size):
    """Changing only ``random_seed`` must change the train/test partition."""
    cfg = get_data_cfg()
    cfg["train_split"] = 0.8

    # One dataset per seed; every other setting is shared.
    cfg["random_seed"] = 0
    dataset1 = VerificationDataset(dataset_name, model_size, **cfg)
    cfg["random_seed"] = 1
    dataset2 = VerificationDataset(dataset_name, model_size, **cfg)

    train_idx1, test_idx1 = dataset1.train_idx, dataset1.test_idx
    train_idx2, test_idx2 = dataset2.train_idx, dataset2.test_idx

    print(f"train_idx1: {train_idx1}")
    print(f"train_idx2: {train_idx2}")
    print(f"test_idx1: {test_idx1}")
    print(f"test_idx2: {test_idx2}")

    # The two seeds must not yield the same partition (train and test are
    # both compared).
    assert not np.array_equal(train_idx1, train_idx2), "Train splits should be different with different seeds"
    assert not np.array_equal(test_idx1, test_idx2), "Test splits should be different with different seeds"

    splits = ((train_idx1, test_idx1), (train_idx2, test_idx2))

    # Within each split, no index may appear on both sides.
    for train_idx, test_idx in splits:
        assert set(train_idx).isdisjoint(test_idx), "Train and test indices should be disjoint"

    # The union is as large as the two sides combined exactly when no index
    # is repeated within a side or shared between the sides.
    for train_idx, test_idx in splits:
        combined = set(train_idx).union(test_idx)
        assert len(combined) == len(train_idx) + len(test_idx), "All indices should be covered"


@pytest.mark.parametrize("dataset_name", ["AIMO-v2"])
@pytest.mark.parametrize("model_size", ["8B"])
def test_same_seed_produces_same_splits(dataset_name, model_size):
    """Building the dataset twice with one seed must give the same split."""
    cfg = get_data_cfg()
    cfg.update(train_split=0.8, random_seed=0)

    # Two independent constructions from the identical configuration.
    first = VerificationDataset(dataset_name, model_size, **cfg)
    second = VerificationDataset(dataset_name, model_size, **cfg)

    first_train, first_test = first.train_idx, first.test_idx
    second_train, second_test = second.train_idx, second.test_idx

    # Both halves of the split have to match element for element.
    assert np.array_equal(first_train, second_train), "Train splits should be identical with same seed"
    assert np.array_equal(first_test, second_test), "Test splits should be identical with same seed"


@pytest.mark.parametrize("train_samples", [1.0])  # only the "use everything" value is exercised
@pytest.mark.parametrize("train_queries", [1.0])  # only the "use everything" value is exercised
@pytest.mark.parametrize("dataset_name", ["MATH-500-v2"])
@pytest.mark.parametrize("model_size", ["8B"])
@pytest.mark.parametrize("train_split", [0.5, 1.0])
def test_different_train_split_values(dataset_name, model_size, train_split, train_samples, train_queries):
    """
    Test that different train_split, train_samples, train_queries produce
    consistent train-test shapes.

    A baseline dataset is loaded with no splitting or subsampling; the
    parametrized dataset's shapes are then compared (within a tolerance of
    one, to absorb rounding) against values derived from the baseline.
    Assertion failures propagate normally; run pytest with ``--pdb`` to
    inspect a failure interactively.
    """
    data_cfg = get_data_cfg()
    # 1) Load baseline dataset with 1.0 for everything (no subsampling).
    data_cfg["train_split"]   = 1.0
    data_cfg["train_samples"] = 1
    data_cfg["train_queries"] = 1

    dataset_baseline = VerificationDataset(dataset_name, model_size, **data_cfg)

    # Baseline shapes
    base_train_problems = dataset_baseline.train_data[0].shape[0]  # (rows)
    base_train_samples  = dataset_baseline.train_data[0].shape[1]  # (columns)
    base_verifiers      = len(dataset_baseline.verifier_names)

    # 2) Load new dataset with parametrized values
    data_cfg["train_split"]   = train_split
    data_cfg["train_samples"] = train_samples
    data_cfg["train_queries"] = train_queries

    dataset_new = VerificationDataset(dataset_name, model_size, **data_cfg)

    # New shapes
    new_train_problems = dataset_new.train_data[0].shape[0]
    new_train_samples  = dataset_new.train_data[0].shape[1]
    new_test_problems  = dataset_new.test_data[0].shape[0]
    new_verifiers      = len(dataset_new.verifier_names)

    # Compute expected shapes based on the splitting and subsampling logic:
    #
    # 1) train_split -> how many rows go to train vs. test (unless train_split=1.0 => test = train).
    #
    if float(train_split) < 1.0:
        expected_train_problems = int(round(base_train_problems * train_split))
        expected_test_problems  = int(round(base_train_problems * (1.0 - train_split)))
    else:
        # train_split = 1 => train == test == full
        expected_train_problems = base_train_problems
        expected_test_problems  = base_train_problems

    # 2) train_queries -> fraction or absolute number of the *training problems*.
    #    Exactly 1 leaves the count from step 1 untouched.
    #
    if train_queries < 1:
        expected_train_problems = int(round(expected_train_problems * train_queries))
    elif train_queries > 1:
        # If user requests more problems than exist, clamp to the max
        expected_train_problems = min(expected_train_problems, int(train_queries))

    # 3) train_samples -> fraction or absolute number of the *sample dimension*.
    #
    #    If < 1.0 => fraction of baseline #samples
    #    If > 1.0 => absolute # of samples, clamped to maximum available
    #
    if train_samples < 1:
        expected_train_samples = int(round(base_train_samples * train_samples))
    elif train_samples > 1:
        # clamp to the maximum possible
        expected_train_samples = min(base_train_samples, int(train_samples))
    else:
        # train_samples = 1.0 => full
        expected_train_samples = base_train_samples

    # 4) Check shapes. The assertions are intentionally not wrapped in a
    #    try/except so that a mismatch fails the test instead of being
    #    swallowed.
    #
    # Verifiers
    assert new_verifiers == base_verifiers, (
        f"Number of verifiers should remain the same. Expected {base_verifiers}, got {new_verifiers}."
    )

    # Problems (rows)
    assert abs(new_train_problems - expected_train_problems) <= 1, (
        f"Train problems mismatch. Expected {expected_train_problems}, got {new_train_problems} "
        f"(train_split={train_split}, train_queries={train_queries})."
    )
    assert abs(new_test_problems - expected_test_problems) <= 1, (
        f"Test problems mismatch. Expected {expected_test_problems}, got {new_test_problems} "
        f"(train_split={train_split})."
    )

    # Samples (columns)
    assert abs(new_train_samples - expected_train_samples) <= 1, (
        f"Train samples mismatch. Expected {expected_train_samples}, got {new_train_samples} "
        f"(train_samples={train_samples})."
    )