import math
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import r2_score

# Datasets processed by this script, in the order they are evaluated.
datasets = ['School', 'Chemical', 'Landmine']

# Task identifiers for each known dataset. The identifiers are the values
# that appear in the saved train/test pair arrays; their position in each
# list is the row/column used in the affinity matrices.
# Note that 'Parkinsons' is defined here but is not listed in `datasets`.
TASKS_DICT = {
    # Task ids 1..139 (1-based).
    'School': [*range(1, 140)],
    # Task ids 0..28 (0-based).
    'Landmine': [*range(0, 29)],
    # Explicit, non-contiguous set of task ids.
    'Chemical': [
        2, 5, 6, 9, 10, 12, 18, 20, 22, 24,
        25, 27, 28, 30, 46, 52, 55, 57, 59, 61,
        67, 70, 76, 78, 80, 81, 83, 84, 85, 86,
        87, 89, 90, 91, 92,
    ],
    # Task ids 1..42 (1-based).
    'Parkinsons': [*range(1, 43)],
}

# Evaluation driver: for every dataset and run, rebuild the full predicted
# pairwise-affinity matrix from the saved train/test predictions, write it to
# disk (CSV + NPY), and report Pearson correlation and R2 between the
# ground-truth affinities and the predictions on the test pairs.
ModelName = 'Quadratic'
print(f'MODEL: {ModelName}')
for dataset in datasets:
    print(f'\n******* Dataset: {dataset} *******')
    selected_seed = 2025
    suffix = f'_{selected_seed}'

    # Map each task id to its row/column position in the affinity matrices.
    TASKS = TASKS_DICT[dataset]
    task_num = len(TASKS)
    task_to_index = {task: i for i, task in enumerate(TASKS)}

    # Per-dataset accumulators: best run by test R2, and per-run test metrics.
    best_run = None
    max_rsq = -math.inf
    AVG_CORR, AVG_R2 = [], []

    # Put 'AVG' in this list to evaluate the averaged-affinity files instead
    # of (or in addition to) the individual runs.
    RUNS = [1, 2, 3]
    for run in RUNS:
        # === Load prediction & pair files ===
        if run == 'AVG':
            data = np.load(f'../RESULTS/{ModelName}_{dataset}_BEST_predictions_and_pairs{suffix}.npz')
            pairwise_affinity = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv')
        else:
            data = np.load(f'../RESULTS/{ModelName}_{dataset}_BEST_predictions_and_pairs_run_{run}{suffix}.npz')
            pairwise_affinity = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_run_{run}_SGD_FIXED.csv')

        # NOTE(review): assumes the CSV holds exactly a task_num x task_num
        # grid (no extra index column) whose rows/columns follow the order of
        # TASKS -- confirm against the script that writes these files.
        pairwise_affinity = np.array(pairwise_affinity)
        if pairwise_affinity.size != task_num * task_num:
            # Fail early with a clear message; the flat indexing below is only
            # meaningful for a task_num x task_num grid.
            raise ValueError(
                f'Affinity matrix for dataset {dataset!r}, run {run!r} has shape '
                f'{pairwise_affinity.shape}; expected {task_num * task_num} entries '
                f'({task_num} x {task_num}).'
            )
        affinity_flat = pairwise_affinity.flatten()

        train_pred, test_pred = data['train_pred'], data['test_pred']
        train_pairs, test_pairs = data['train_pairs'], data['test_pairs']

        # === Scatter predictions into a dense (task_num x task_num) matrix ===
        # Train pairs are written first and test pairs second, so a pair that
        # appears in both ends up holding its test prediction.
        pred_matrix = np.zeros((task_num, task_num))
        for pairs, preds in ((train_pairs, train_pred), (test_pairs, test_pred)):
            for idx, pair in enumerate(pairs):
                pred_matrix[task_to_index[pair[0]], task_to_index[pair[1]]] = preds[idx]

        # === Sanity check: the matrix must reproduce the test predictions ===
        # Check the first test pair, then a random sample of up to 10 pairs
        # (capped so that test sets with fewer than 10 pairs do not make
        # random.sample raise ValueError).
        t1, t2 = test_pairs[0]
        assert test_pred[0] == pred_matrix[task_to_index[t1], task_to_index[t2]]
        sample_size = min(10, len(test_pairs))
        for each_idx in random.sample(range(len(test_pairs)), sample_size):
            t1, t2 = test_pairs[each_idx]
            assert test_pred[each_idx] == pred_matrix[task_to_index[t1], task_to_index[t2]]

        # === Save outputs ===
        df_pred_matrix = pd.DataFrame(pred_matrix, index=TASKS, columns=TASKS)
        df_pred_matrix.to_csv(f'../RESULTS/New_{ModelName}_{dataset}_Predicted_Pairwise_Affinity_{run}{suffix}.csv')
        np.save(f'../RESULTS/New_{ModelName}_{dataset}_Predicted_Pairwise_Affinity_{run}{suffix}.npy', pred_matrix)

        # === Flatten ground-truth vs predicted (excluding diagonal) ===
        # In a flattened n x n matrix the diagonal entry (k, k) sits at k*n + k.
        diagonal_indices = np.arange(task_num) * task_num + np.arange(task_num)
        matched_affinity = np.delete(affinity_flat, diagonal_indices)
        matched_preds = np.delete(pred_matrix.flatten(), diagonal_indices)

        # === Compute correlation on all pairs ===
        # These all-pairs metrics are computed but not reported or
        # accumulated; only the test-set metrics below feed the summary.
        corr = np.corrcoef(matched_affinity, matched_preds)[0, 1]
        rsq = r2_score(matched_affinity, matched_preds)

        # === Compute correlation on test set only ===
        # Look up the ground-truth affinity of each test pair, in the same
        # order as test_pred, so the two vectors are aligned element-wise.
        flat_indices = [
            task_to_index[t1] * task_num + task_to_index[t2]
            for t1, t2 in test_pairs
        ]
        matched_affinity_test = affinity_flat[flat_indices]

        corr_test = np.corrcoef(matched_affinity_test, test_pred)[0, 1]
        rsq_test = r2_score(matched_affinity_test, test_pred)

        AVG_CORR.append(corr_test)
        AVG_R2.append(rsq_test)

        print(f'TEST PAIRS RUN = {run}\t Corr(Test): {corr_test:.3f}\t R2(Test): {rsq_test:.3f}')

        # Track the run with the highest test-set R2.
        if rsq_test > max_rsq:
            max_rsq = rsq_test
            best_run = run

    # Summary in LaTeX-friendly "mean$\pm$ std" form. The backslash is escaped
    # because "\p" is not a valid escape sequence in a non-raw string literal.
    print(f'\n*** {ModelName} | Dataset {dataset} | Avg Corr: {np.mean(AVG_CORR):.2f}$\\pm$ {np.std(AVG_CORR):0.2f}, '
          f'Avg R2: {np.mean(AVG_R2):.2f}$\\pm$ {np.std(AVG_R2):0.2f}')
    print(f'Best run = {best_run} with R2 = {max_rsq:.3f}')

