import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
import ast
import random

DATASETS = ['School', 'Chemical', 'Landmine']

def _pairwise_subset_losses(task_groups, task_scores, tasks):
    """Average each task's loss over the sampled groups it shares with a partner.

    Args:
        task_groups: One parsed ``Task_group`` value per CSV row; each must
            support ``in`` tests against task ids.
        task_scores: One parsed ``Individual_Task_Score`` value per CSV row,
            indexable by task id.
        tasks: Task ids to consider.

    Returns:
        Nested dict where ``losses[partner][target]`` is the mean loss of
        ``target`` over all rows whose group contains both ``partner`` and
        ``target``, or ``None`` when no row contains both.
    """
    losses = {partner: {target: None for target in tasks} for partner in tasks}
    for partner in tasks:
        for target in tasks:
            if partner == target:
                continue
            shared = [
                scores[target]
                for group, scores in zip(task_groups, task_scores)
                if partner in group and target in group
            ]
            if shared:
                losses[partner][target] = np.mean(shared)
    return losses


def measure_Affinity(dataset, single_results, Tot_Run_list, Estimate_Version, Arch):
    """
    Compute estimated pairwise task affinities for given runs and version.

    For every run in ``Tot_Run_list`` the matching fast-estimation CSV under
    ``Results/`` is read, and for each ordered pair ``(t1, t2)`` the affinity
    ``(stl_loss[t2] - loss_of_t2_when_grouped_with_t1) / stl_loss[t2]`` is
    computed. Affinities are then averaged over runs.

    A pair that never shares a sampled group in a given run contributes
    nothing for that run; if no run covers the pair, its mean is NaN.
    Diagonal entries are filled with the constant 10000.

    Relies on the module-level ``TASKS`` list being set by the caller.

    Args:
        dataset: Dataset name used in the input and output file names.
        single_results: DataFrame with ``TASKS`` and ``Total_Loss`` columns
            holding the single-task loss of each task.
        Tot_Run_list: Run ids to average over.
        Estimate_Version: Subset/version id used in the file names.
        Arch: Architecture suffix used in the input file name.

    Returns:
        DataFrame of shape (num_tasks, num_tasks). Because it is built from a
        dict of dicts, column ``t1`` / row ``t2`` holds the affinity computed
        for the ordered pair ``(t1, t2)``. The same frame is written (without
        index) to
        ``Results/Estimated_Task_Affinity_{dataset}_{Estimate_Version}_M_{len(Tot_Run_list)}.csv``.
    """
    # Single-task losses do not depend on the run, so look them up once.
    stl_loss_by_task = {
        task: single_results.loc[single_results['TASKS'] == task, 'Total_Loss'].values[0]
        for task in TASKS
    }

    affinities_per_run = {t1: {t2: [] for t2 in TASKS} for t1 in TASKS}

    for run in Tot_Run_list:
        estimates = pd.read_csv(
            f'Results/{dataset}_Fast_Estimation_w_Time_run_{run}_SGD{Arch}_subsetNum_{Estimate_Version}.csv'
        )

        # Parse the stringified columns once per run rather than once per pair.
        task_groups = [ast.literal_eval(x) for x in estimates['Task_group']]
        task_scores = [ast.literal_eval(x) for x in estimates['Individual_Task_Score']]
        subset_losses = _pairwise_subset_losses(task_groups, task_scores, TASKS)

        for t1 in TASKS:
            for t2 in TASKS:
                if t1 == t2:
                    affinities_per_run[t1][t2].append(10000)
                    continue
                mtl_loss = subset_losses[t1][t2]
                if mtl_loss is None:
                    # Pair never co-occurred in this run: there is no estimate,
                    # so do not feed a placeholder into the mean.
                    continue
                stl_loss = stl_loss_by_task[t2]
                affinities_per_run[t1][t2].append((stl_loss - mtl_loss) / stl_loss)

    # Mean over the runs that produced an estimate; NaN when none did.
    mean_affinity = {
        t1: {t2: (np.mean(vals) if vals else np.nan) for t2, vals in row.items()}
        for t1, row in affinities_per_run.items()
    }

    Estimated_Task_Affinity_df = pd.DataFrame(mean_affinity)
    Estimated_Task_Affinity_df.columns = TASKS

    # Save CSV with version and number of runs in the filename
    Estimated_Task_Affinity_df.to_csv(
        f'Results/Estimated_Task_Affinity_{dataset}_{Estimate_Version}_M_{len(Tot_Run_list)}.csv',
        index=False,
    )
    return Estimated_Task_Affinity_df


# Number of estimation runs averaged per version. It has to be bound before
# the version loop below samples M runs; it also selects which
# "Estimated_Task_Affinity_..._M_{M}.csv" files are read back for evaluation.
M = 5

for dataset in DATASETS:
    print(f"\nProcessing dataset: {dataset}")
    Avg_Corr = []
    Avg_Rsq = []

    single_results = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
    # measure_Affinity reads this module-level list.
    TASKS = [int(t) for t in single_results['TASKS']]

    # Define versions and runs (ensure different runs per version if needed)
    versions_runs = {
        0: [1, 2, 3, 4, 5],
        1: [1, 2, 3, 4, 5],
        2: [1, 2, 3, 4, 5],
        3: [1, 2, 3, 4, 5],
        4: [1, 2, 3, 4, 5],
    }

    Arch_default = '_Arch_Arch_1'

    # Compute affinities for all versions
    for Estimate_Version, run_list in versions_runs.items():
        # Draw exactly M runs so the file written by measure_Affinity
        # ("..._M_{len(run_list)}.csv") is the one read back below.
        # random.sample raises ValueError if fewer than M runs are listed.
        run_list = random.sample(run_list, M)

        print(f'Estimate_Version: {Estimate_Version}, runs: {run_list}')
        measure_Affinity(dataset, single_results, run_list, Estimate_Version, Arch_default)

    # Compute correlation with actual pairwise affinities
    # NOTE(review): these accumulators are not reset per pair_run, so every
    # value appended to Avg_Corr / Avg_Rsq is a cumulative mean over all pair
    # runs seen so far - confirm whether per-pair-run statistics were intended.
    tmp_corr = []
    tmp_rsq = []
    for pair_run in [1, 2, 3, 4, 5]:
        actual_pairwise_affinities = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_run_{pair_run}_SGD_FIXED.csv')
        actual_pairwise_affinities = np.array(actual_pairwise_affinities, dtype=float)

        for Estimate_Version in versions_runs.keys():
            estimated_affinities = pd.read_csv(f'Results/Estimated_Task_Affinity_{dataset}_{Estimate_Version}_M_{M}.csv')
            estimated_affinities = np.array(estimated_affinities, dtype=float)

            # Align tasks for non-Chemical datasets
            # NOTE(review): the estimated matrix is re-indexed by task id - 1
            # as well; confirm it is laid out by task id rather than by
            # position in TASKS, otherwise this permutes it a second time.
            if dataset != 'Chemical':
                idx = np.array(TASKS) - 1
                estimated_affinities = estimated_affinities[idx, :][:, idx]
                actual_pairwise_affinities_aligned = actual_pairwise_affinities[idx, :][:, idx]
            else:
                actual_pairwise_affinities_aligned = actual_pairwise_affinities

            # Evaluation: compare off-diagonal entries only, dropping pairs
            # for which no estimate is available (NaN).
            mask = ~np.eye(len(TASKS), dtype=bool)
            est_flat = estimated_affinities[mask]
            true_flat = actual_pairwise_affinities_aligned[mask]

            valid = ~np.isnan(est_flat)
            est_flat = est_flat[valid]
            true_flat = true_flat[valid]

            corr = np.corrcoef(est_flat, true_flat)[0, 1]
            r2 = r2_score(true_flat, est_flat)

            tmp_corr.append(corr)
            tmp_rsq.append(r2)

        Avg_Corr.append(np.mean(tmp_corr))
        Avg_Rsq.append(np.mean(tmp_rsq))
        # NOTE(review): Estimate_Version here is simply the last version of the
        # inner loop; the statistics cover every version processed so far.
        # The R2 spread is taken from tmp_rsq (not tmp_corr).
        print(f"Dataset: {dataset}, Version: {Estimate_Version}, Corr={np.mean(tmp_corr):.4f} $\\pm${np.std(tmp_corr):.2f}, R2={np.mean(tmp_rsq):.4f} $\\pm${np.std(tmp_rsq):.2f}")
    print(f'Avg Corr: {np.mean(Avg_Corr):0.4f} $\\pm$ {np.std(Avg_Corr):0.2f}')
    print(f'Avg Rsq: {np.mean(Avg_Rsq):0.2f} $\\pm$ {np.std(Avg_Rsq):0.2f}')
