"""
Test that model performance is reproducible given fixed data and model configurations.
"""
from types import SimpleNamespace
import pytest
import numpy as np
import pandas as pd
import sys
import os
# If the `weaver` package is not importable, add the parent of this file's
# directory (presumably the repository root) to sys.path so the imports below
# can resolve it from a source checkout. Only ImportError is caught: any other
# failure while importing an installed `weaver` should surface unchanged.
try:
    import weaver
except ImportError:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from weaver.dataset import VerificationDataset, ClusteringDataset
from weaver.models import Model
from typing import Dict


use_continuous = True

if use_continuous:
    reward_threshold = None
else:
    reward_threshold = 0.5

def get_data_cfg() -> SimpleNamespace:
    """
    Build the dataset configuration used by the reproducibility test.

    Returns:
        A SimpleNamespace whose attributes the test forwards as keyword
        arguments to ``VerificationDataset`` (via ``**vars(...)``).
    """
    verifier_cfg = {
        "verifier_type": "all",
        "verifier_size": "all",
        "verifier_subset": [],
    }
    return SimpleNamespace(
        train_split=1.0,
        train_queries=1,
        train_samples=1,
        random_seed=0,
        nan_replacement=0,
        # When not None, the test binarizes verifier scores with this value.
        reward_threshold=reward_threshold,
        normalize_type="all_problems",
        normalize_method="minmax",
        closest_train_problem_method="mean_verifier_distance",
        closest_train_problem_metric_type="euclidean",
        verifier_cfg=verifier_cfg,
        mv_as_verifier=False,
        same_train_test=False,
        train_split_bins=3,
        normalize_params={"tmp": None},
    )



def get_model_config(model_type: str) -> dict:
    """
    Build the model configuration for ``model_type``.

    The result is a plain nested dict, not a SimpleNamespace. The test indexes
    it (``cfg["model_params"]``) and unpacks it with ``**`` into ``Model``.
    A fresh dict is built on every call, so callers may mutate it in place.

    Args:
        model_type: Model type identifier stored under ``"model_type"``.

    Returns:
        Dict with keys ``model_type``, ``model_class``, ``model_params`` and
        ``cluster_cfg``.
    """
    model_cfg = {
        "model_type": model_type,
        "model_class": "cluster",
        "model_params": {
            "use_continuous": use_continuous,
            "k": 2,
            "seed": 0,
            "binarize_threshold": reward_threshold,
            "metric": "scores",
            "n_epochs": 5000,
            "mu_epochs": 10000,
            "log_train_every": 1000,
            "lr": 0.0001,
            "use_deps": None,
            "use_label_on_test": True,
            "deps_data_fraction": 1.0,
            "drop_imbalanced_verifiers": "small",
            "cb_args": {
                "class_balance": "labels",
            },
        },
        "cluster_cfg": {
            "n_clusters": 2,
            "cluster_type": "by_difficulty",
        },
    }
    return model_cfg

@pytest.mark.parametrize(
    "model_type",
    ["weak_supervision"]
)
def test_reproducibility(model_type: str) -> None:
    """
    Check that a model trained with the fixed configs reaches the expected accuracy.

    Loads the MATH-500-v2 / 70B verification dataset, clusters the training
    problems, fits one sub-model per cluster, then evaluates every training
    problem. Asserts that the mean ``top1_positive`` is at least
    ``min_top1_positive``.

    Args:
        model_type: Type of model to test (forwarded to ``get_model_config``).
    """
    min_top1_positive = 0.92

    # Load the real verification dataset (not synthetic data).
    dataset_name = "MATH-500-v2"
    model_size = "70B"
    data_cfg = get_data_cfg()
    data = VerificationDataset(dataset_name, model_size, **vars(data_cfg))
    verifier_names = data.verifier_names

    model_cfg = get_model_config(model_type)
    # NOTE(review): presumably Model expects attribute access on cb_args,
    # hence the dict -> SimpleNamespace conversion; confirm in weaver.models.
    model_cfg["model_params"]["cb_args"] = SimpleNamespace(
        **model_cfg["model_params"]["cb_args"]
    )

    # Cluster the training problems; the model gets one sub-model per cluster.
    clusters = ClusteringDataset(**model_cfg["cluster_cfg"])
    clusters.compute_clusters(data, mode="train")
    num_models = len(clusters.train_clusters)

    # Binarize verifier scores only when a threshold is configured.
    if data_cfg.reward_threshold is not None:
        data.binarize_verifiers(clusters, split="train")
        data.binarize_verifiers(clusters, split="test")

    model = Model(verifier_names,
                  clusters=clusters,
                  **model_cfg,
                  num_models=num_models)

    # Fit one sub-model per training cluster on that cluster's problems.
    X_train, y_train = data.train_data
    for group_idx in range(len(model.clusters.train_clusters)):
        cluster_idxs = model.clusters.train_clusters[group_idx]
        model.fit(X_train[cluster_idxs], y_train[cluster_idxs], group_idx=group_idx)
        # NOTE(review): semantics of `is_test` are defined in weaver.models.
        model.models[group_idx].is_test = True

    # Evaluate every training problem (train_split is 1.0 in get_data_cfg).
    num_train_problems = len(data.train_idx)
    all_results = []
    for idx in range(num_train_problems):
        problem_idx = idx if model.model_class in ["per_problem", "cluster"] else None
        outputs = model.calculate_metrics(X_train[idx], y_train[idx], problem_idx=problem_idx)
        if model.model_class == "cluster":
            outputs["cluster_id"] = model.problem_idx_to_group_idx(problem_idx)
        all_results.append(outputs)

    # Validate metrics.
    results_df = pd.DataFrame(all_results)
    top1_positive = results_df["top1_positive"].mean()
    assert top1_positive >= min_top1_positive, f"Top 1 positive accuracy is {top1_positive}"