import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from weaver.utils import get_best_ensemble_answer, classify_difficulty, get_accuracy, calculate_majority_at_k, compute_pass_at_k_gt, get_best_ensemble_answer_p_problem, preprocess_score
from weaver.constants import DATASET_TO_REWARD_MODELS, DATASET_TO_LM_JUDGES, REWARD_MODELS_NAME_MAP, JUDGE_NAME_MAP

# Output directory (relative to the current working directory) for the
# artifacts this script writes; created at import time if it does not exist.
FIGURES_DIR = Path("figures3")
FIGURES_DIR.mkdir(parents=False, exist_ok=True)


from scipy.spatial.distance import cdist
import numpy as np


def find_closest_train_problem(X_train, X_test, metric="euclidean"):
    """
    Find the nearest train problem for each test problem.

    Each problem is summarised by the mean of its samples' verifier scores
    (mean over axis 1). Test problems are then matched to train problems by
    the pairwise distance between those summaries.

    Args:
        X_train: (num_train_problems, num_samples, num_verifiers) - Train set
        X_test: (num_test_problems, num_samples, num_verifiers) - Test set
        metric: Distance metric forwarded to ``scipy.spatial.distance.cdist``
            (default ``"euclidean"``).

    Returns:
        closest_train_idx: (num_test_problems,) - Row position in ``X_train`` of
            the closest train problem for each test problem. Ties resolve to
            the lowest position, following ``np.argmin``.
        min_distances: (num_test_problems,) - Distance from each test problem
            to that closest train problem.
    """
    # One vector per problem: average over the samples axis.
    train_repr = X_train.mean(axis=1)  # (num_train_problems, num_verifiers)
    test_repr = X_test.mean(axis=1)    # (num_test_problems, num_verifiers)

    pairwise = cdist(test_repr, train_repr, metric=metric)  # (num_test, num_train)

    closest_train_idx = np.argmin(pairwise, axis=1)
    # Read the distance at the chosen column rather than recomputing a min.
    min_distances = pairwise[np.arange(pairwise.shape[0]), closest_train_idx]
    return closest_train_idx, min_distances



def load_task_data(dataset_name, model_size):
    """Load the generations for a task and model size from the Hugging Face Hub.

    Args:
        dataset_name: Task name, e.g. "MATH-500", "GPQA" or "CodeContests_gonly".
        model_size: Model size tag used in the repo id, e.g. "8B" or "70B".
            For "MATH-500", every value other than "70B" selects the 8B repo.

    Returns:
        pandas.DataFrame built from the repo's "train" split for
        "CodeContests_gonly", and from its "data" split otherwise.
    """
    # Resolve the Hub repo id; MATH-500 and CodeContests_gonly use
    # non-standard naming.
    if dataset_name == "MATH-500":
        if model_size == "70B":
            dataset = "anonymous_research/MATH500_with_Llama_3.1_70B_Instruct"
        else:
            dataset = "anonymous_research/MATH-500_with_Llama_3.1_8B_Instruct"
    elif dataset_name == "CodeContests_gonly":
        dataset = f"anonymous_research/CodeContests_with_Llama_3.3_{model_size}_Instruct_GENERATIONS_ONLY"
    else:
        dataset = f"anonymous_research/{dataset_name}_with_Llama_3.1_{model_size}_Instruct"

    # CodeContests_gonly is published under a "train" split and stores its
    # correctness labels as "is_corrects". The other repos use a "data" split
    # with an "answer_correct" column. Callers pick the label column themselves.
    split = "train" if dataset_name == "CodeContests_gonly" else "data"
    return pd.DataFrame(load_dataset(dataset)[split])



def split_p_difficulty(X_data: np.ndarray, y_data: np.ndarray, test_split=0.2, random_state=42, thresholds=(0.1, 0.5)):
    """
    Split problems into train/test index sets, stratified by difficulty.

    A problem's difficulty is the mean of its correctness labels in ``y_data``,
    i.e. its fraction of correct responses. This mean is bucketed with
    ``np.digitize(..., right=True)`` against ``thresholds``. With the default
    thresholds the levels are:
    0 = Hard (mean <= 0.1), 1 = Medium (0.1 < mean <= 0.5), 2 = Easy (mean > 0.5).

    Args:
        X_data: Array whose first dimension is the problem axis; only
            ``X_data.shape[0]`` is used.
        y_data: (num_problems, num_responses) correctness labels.
        test_split: Fraction of problems placed in the test set.
        random_state: Seed forwarded to ``train_test_split``.
        thresholds: Increasing bucket edges. They yield
            ``len(thresholds) + 1`` difficulty levels.

    Returns:
        (train_idx, test_idx, assignments): the problem indices of each split,
        and the difficulty level of every problem.

    Raises:
        ValueError: if either split does not contain every difficulty level.
    """
    num_problems = X_data.shape[0]

    mean_correct = y_data.mean(axis=1)  # (num_problems,)
    assignments = np.digitize(mean_correct, thresholds, right=True)

    train_idx, test_idx = train_test_split(
        np.arange(num_problems),
        test_size=test_split,
        random_state=random_state,
        stratify=assignments,
    )

    # Results are later reported per difficulty level, so every level must
    # appear on both sides of the split. This is an explicit check rather than
    # an assert so that it still runs under `python -O`.
    num_levels = len(thresholds) + 1
    for split_name, idx in (("train", train_idx), ("test", test_idx)):
        n_present = len(np.unique(assignments[idx]))
        if n_present != num_levels:
            raise ValueError(
                f"{split_name} split contains {n_present} of {num_levels} difficulty levels"
            )

    return train_idx, test_idx, assignments


def train_and_evaluate_lr(X_data, y_data, test_split=0.2, random_state=42):
    """
    Fit one logistic regression per train problem. Score each test problem
    with the classifier of its nearest train problem.

    Problems are split with ``split_p_difficulty``. For each train problem
    whose labels contain both classes, a ``LogisticRegression`` is fit on that
    problem's response-by-verifier score matrix and scored on the same data.
    Each test problem is matched to the nearest retained train problem via
    ``find_closest_train_problem`` and scored with that problem's classifier.
    Test problems whose labels are all identical are skipped.

    Args:
        X_data: (num_problems, num_responses, num_verifiers) verifier scores.
        y_data: (num_problems, num_responses) correctness labels.
        test_split: Fraction of problems held out for testing.
        random_state: Seed for the train/test split.

    Returns:
        (train_df, test_df):
            train_df columns: problem, set, accuracy, clf, difficulty.
            test_df columns: problem, set, accuracy, close_train_idx (row
            position in train_df), distance, difficulty, closest_train_problem
            (the matched train problem's index into X_data).

    Raises:
        ValueError: if no train problem has both labels, so no classifier
            can be fit.
    """
    train_idx, test_idx, assignments = split_p_difficulty(
        X_data, y_data, test_split=test_split, random_state=random_state
    )

    # Fill NaN scores with the mean of the same (response slot, verifier) cell
    # across problems. Then map anything still non-finite (all-NaN cells, +/-inf)
    # to finite values so LogisticRegression accepts the input.
    # NOTE(review): the imputation statistics also include test problems.
    cell_means = np.nanmean(X_data, axis=0)
    X_data = np.nan_to_num(np.where(np.isnan(X_data), cell_means, X_data))

    train_rows = []
    for prob_idx in train_idx:
        X, y = X_data[prob_idx], y_data[prob_idx]
        if len(np.unique(y)) == 1:
            continue  # LogisticRegression needs both classes present.

        clf = LogisticRegression(random_state=0, max_iter=1000)
        clf.fit(X, y)
        train_rows.append({
            "problem": prob_idx,
            "set": "train",
            "accuracy": clf.score(X, y),
            "clf": clf,
            "difficulty": assignments[prob_idx],
        })

    if not train_rows:
        raise ValueError("No train problem has both correct and incorrect responses; cannot fit any classifier.")
    all_results = pd.DataFrame(train_rows)

    # Row i of all_results corresponds to train_prob_used[i].
    train_prob_used = all_results["problem"].to_numpy()
    closest_pos, closest_dist = find_closest_train_problem(X_data[train_prob_used], X_data[test_idx])

    test_rows = []
    for pos, prob_idx in enumerate(test_idx):
        X, y = X_data[prob_idx], y_data[prob_idx]
        if len(np.unique(y)) == 1:
            continue  # Accuracy on a single-class problem is uninformative here.

        c_pos = closest_pos[pos]
        clf = all_results["clf"].iloc[c_pos]
        test_rows.append({
            "problem": prob_idx,
            "set": "test",
            "accuracy": clf.score(X, y),
            "close_train_idx": c_pos,
            "distance": closest_dist[pos],
            "difficulty": assignments[prob_idx],
            "closest_train_problem": train_prob_used[c_pos],
        })

    # Fixed columns so that callers can index them even when every test problem
    # was skipped.
    test_columns = ["problem", "set", "accuracy", "close_train_idx", "distance", "difficulty", "closest_train_problem"]
    all_test_results = pd.DataFrame(test_rows, columns=test_columns)
    return all_results, all_test_results


def get_all_verifiers(dataset_name, model_size=None):
    """Return the verifier names configured for a task and model size.

    The Hub repo id is resolved with the same naming rules used for loading
    the data. That id is looked up in ``DATASET_TO_REWARD_MODELS`` and
    ``DATASET_TO_LM_JUDGES``; a missing entry contributes an empty list.

    Args:
        dataset_name: Task name, e.g. "MATH-500" or "CodeContests_gonly".
        model_size: Model size tag, e.g. "8B" or "70B". When omitted, it falls
            back to a module-level ``model_size`` variable, which is the loop
            variable of this script's ``__main__`` section. Older callers
            relied on that fallback implicitly.

    Returns:
        list: reward-model names followed by LM-judge names.

    Raises:
        ValueError: if ``model_size`` is not given and no module-level
            fallback exists.
    """
    if model_size is None:
        model_size = globals().get("model_size")
        if model_size is None:
            raise ValueError("get_all_verifiers() requires `model_size` (e.g. '8B' or '70B').")

    # Some special cases
    if dataset_name == "MATH-500":
        if model_size == "70B":
            dataset = "anonymous_research/MATH500_with_Llama_3.1_70B_Instruct"
        else:
            dataset = "anonymous_research/MATH-500_with_Llama_3.1_8B_Instruct"
    elif dataset_name == "CodeContests_gonly":
        dataset = f"anonymous_research/CodeContests_with_Llama_3.3_{model_size}_Instruct_GENERATIONS_ONLY"
    else:
        dataset = f"anonymous_research/{dataset_name}_with_Llama_3.1_{model_size}_Instruct"

    all_verifiers = DATASET_TO_REWARD_MODELS.get(dataset, [])
    all_judges = DATASET_TO_LM_JUDGES.get(dataset, [])
    return all_verifiers + all_judges


if __name__ == "__main__":
    import traceback

    # Full sweep for reference; assign these to `datasets` / `model_sizes`
    # to run everything. The current run is restricted to a subset.
    ALL_DATASETS = ["MATH-500", "AIMO", "MMLU-Pro", "GPQA", "MMLU-College", "AlpacaEval", "BBH", "ArenaHard"]
    ALL_MODEL_SIZES = ["8B", "70B"]
    datasets = ["MATH-500"]
    model_sizes = ["8B"]
    results = []

    # NOTE: `model_size` is intentionally a module-level loop variable;
    # get_all_verifiers(dataset_name) may read it as a global.
    for dataset_name in datasets:
        for model_size in model_sizes:
            try:
                df = load_task_data(dataset_name, model_size)

                # Correctness labels: (num_problems, num_responses).
                # CodeContests_gonly stores them under a different column name.
                correct_key = "is_corrects" if dataset_name == "CodeContests_gonly" else "answer_correct"
                y_data = np.stack(df[correct_key].values).astype(int)

                all_verifiers = get_all_verifiers(dataset_name)

                # Build X_data as (num_problems, num_responses, num_verifiers),
                # holding every verifier's score for every response.
                verifier_matrices = []
                verifier_names = []
                for verifier in all_verifiers:
                    if verifier == "mv_verifier":
                        continue
                    norm_scores = [preprocess_score(s, verifier) for s in df[verifier].values]
                    verifier_names.append(verifier)
                    verifier_matrices.append(np.asarray(norm_scores).squeeze())
                X_data = np.stack(verifier_matrices, axis=-1)

                # Remap verifier names so that they are consistent across
                # datasets. This also rejects verifiers absent from both maps.
                verifier_names2 = []
                for v in verifier_names:
                    if v in REWARD_MODELS_NAME_MAP:
                        verifier_names2.append(REWARD_MODELS_NAME_MAP[v])
                    elif v in JUDGE_NAME_MAP:
                        verifier_names2.append(JUDGE_NAME_MAP[v])
                    else:
                        raise NotImplementedError(f"Unknown verifier name: {v}")
                verifier_names = verifier_names2

                # Train and evaluate the per-problem logistic regressions.
                df_train, df_test = train_and_evaluate_lr(X_data, y_data)
                results.append({"dataset": dataset_name, "model_size": model_size, "train_results": df_train, "test_results": df_test})
                print("Results: Train Accuracy {:.2f} Test Accuracy {:.2f}".format(df_train["accuracy"].mean(), df_test["accuracy"].mean()))

                # Stratify by difficulty (0 = Hard, 1 = Medium, 2 = Easy).
                difficulty_levels = ["Hard", "Medium", "Easy"]
                for difficulty, level_name in enumerate(difficulty_levels):
                    df_train_diff = df_train[df_train["difficulty"] == difficulty]
                    df_test_diff = df_test[df_test["difficulty"] == difficulty]
                    print(
                        f"Results with difficulty {level_name}:\n"
                        f"\tTrain samples:{len(df_train_diff)} Train Accuracy {df_train_diff['accuracy'].mean():.2f} "
                        f"Test samples:{len(df_test_diff)}, Test Accuracy {df_test_diff['accuracy'].mean():.2f}"
                    )
                print(f"Completed: {dataset_name} - {model_size}")

            except Exception as e:
                # Keep the sweep going, but don't lose where the failure happened.
                print(f"Failed: {dataset_name} - {model_size} | Error: {e}")
                traceback.print_exc()

    # Save per-problem results in long form. Nested DataFrames and fitted
    # classifier objects cannot be written to CSV meaningfully; they would
    # appear as truncated reprs. Flatten them, drop the classifier column, and
    # tag each row with its dataset and model size.
    frames = [
        part.drop(columns="clf", errors="ignore").assign(dataset=entry["dataset"], model_size=entry["model_size"])
        for entry in results
        for part in (entry["train_results"], entry["test_results"])
    ]
    flat_results = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    flat_results.to_csv(FIGURES_DIR / "full_train_test_results.csv", index=False)
    print("All done!")