import numpy as np
import pandas as pd
import random
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

# Seed every global RNG the script could touch so reruns are reproducible.
# (Nothing visible in this file draws random numbers; seeding is insurance.)
SEED = 2025
np.random.seed(SEED)
random.seed(SEED)

# When truthy, the averaged pairwise-affinity matrix
# ('{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv') is read instead of the
# per-RUN one.
AVG = 0
# Columns taken from each dataset's pairwise task-feature CSV. The two
# directional ratio columns are collapsed into one 'Dataset_Ratio' feature
# when the (t1, t2) / (t2, t1) samples are built (see _build_samples).
FEATURE_COLUMNS = ['Task1', 'Task2',
                   'Total_Dataset_Size',
                   'DatasetSize_Ratio_t1', 'DatasetSize_Ratio_t2',
                   'DatasetSize_Diff',
                   'Unified_Distance', 'Distance_Diff_over_Sum',
                   'Distance_Diff_over_Prod',
                   'Unified_Dist_over_Sum', 'Unified_Dist_over_Prod',
                   'Energy_Distance',
                   'Mean_Diff_L2',
                   'Cosine_Similarity',
                   'PCA_Top_CosSim_Mean',
                   'Rank_based_Similarity',
                   'Graph_based_Similarity',
                   ]

# Columns that are identifiers or directional ratios, not stand-alone features.
_NON_FEATURE_COLUMNS = ('Task1', 'Task2', 'DatasetSize_Ratio_t1', 'DatasetSize_Ratio_t2')

# Readable feature name (output of _readable_feature_name) -> mathtext label
# written to the combined CSV. Raw strings keep the LaTeX backslashes literal.
FEATURE_LABELS = {
    'Dataset Ratio': r'$\texttt{Data-Ratio}$',
    'Total Dataset Size': r'$\texttt{Data-Size }$',
    'DatasetSize Diff': r'$\texttt{Size-Gap}$',

    'Distance Diff': r'$|d_{t_i}-d_{t_j}|$',
    'Unified Distance': r'$d_{(t_i+t_j)}$',

    'Distance Diff over Sum': r'$|d_{t_i}-d_{t_j}| \div (d_{t_i}+d_{t_j})$',
    'Distance Diff over Prod': r'$|d_{t_i}-d_{t_j}| \div (\sqrt{d_{t_i}\cdot d_{t_j}})$',
    'Unified Dist over Sum': r'$ d_{(t_i+t_j)} \div (d_{t_i}+d_{t_j})$',
    'Unified Dist over Prod': r'$d_{(t_i+t_j)} \div (\sqrt{d_{t_i}\cdot d_{t_j}})$',

    'Energy Distance': r'$\texttt{Energy Distance}$',
    'Rank based Similarity': r'$\texttt{Rank-Div}$',
    'Graph based Similarity': r'$\texttt{Cross-Link}$',
    'Cosine Similarity': r'$\texttt{Cosine-Sim}$',
    'Mean Diff': r'$\texttt{Feature-Mean Gap}$',
    'Skewness Diff': r'$\texttt{Skewness Gap}$',
    r'PCA Cosine Similarity ($\mu$)': r'$\texttt{PCA-Align}$',
}


def _feature_names(columns):
    """Return feature names in the same order as the columns of X.

    Identifier and directional-ratio columns are removed, and a single
    'Dataset_Ratio' entry is placed right after the first remaining column
    ('Total_Dataset_Size'). That is where the kept ratio column ends up in each
    sample row built by _build_samples.
    """
    names = [c for c in columns if c not in _NON_FEATURE_COLUMNS]
    return [names[0], 'Dataset_Ratio'] + names[1:]


def _readable_feature_name(column):
    """Turn a raw CSV column name into the key used by FEATURE_LABELS.

    Underscores become spaces and any ' L2' qualifier is dropped
    (e.g. 'Mean_Diff_L2' -> 'Mean Diff'). The PCA column gets a descriptive
    name.
    """
    name = column.replace('_', ' ').replace(' L2', '')
    if 'PCA Top CosSim Mean' in name:
        name = r'PCA Cosine Similarity ($\mu$)'
    return name


def _build_affinity_lookup(affinity_matrix, tasks):
    """Map each ordered task pair (t1, t2) to affinity_matrix[i, j].

    Row/column positions are taken to follow the order of ``tasks``.
    NOTE(review): confirm the affinity CSVs are written in sorted task-id order.
    Both directions are stored separately; the matrix is not assumed symmetric.
    """
    lookup = {}
    for i, t1 in enumerate(tasks):
        for j in range(i + 1, len(tasks)):  # each unordered pair once
            t2 = tasks[j]
            lookup[(t1, t2)] = affinity_matrix[i, j]
            lookup[(t2, t1)] = affinity_matrix[j, i]
    return lookup


def _build_samples(features_df, affinity_lookup):
    """Build the (X, y) samples used for the per-feature correlations.

    Each row (t1, t2) of ``features_df`` yields two samples:
    - affinity (t1, t2), with the row's features keeping 'DatasetSize_Ratio_t1';
    - affinity (t2, t1), with the row's features keeping 'DatasetSize_Ratio_t2'.

    Raises KeyError if a task pair has no entry in ``affinity_lookup``.
    """
    affinities = []
    feature_matrix = []
    for _, row in features_df.iterrows():
        t1, t2 = int(row['Task1']), int(row['Task2'])
        affinities.append(affinity_lookup[(t1, t2)])
        feature_matrix.append(row.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t2']).values)
        affinities.append(affinity_lookup[(t2, t1)])
        feature_matrix.append(row.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t1']).values)
    return np.array(feature_matrix, dtype=float), np.array(affinities)


datasets = ['School', 'Chemical', 'Landmine']
for RUN in [4, 5, 6]:
    # One bar-chart panel per dataset; squeeze=False keeps ax_bar 2-D even
    # when only one dataset is listed.
    fig_bar, ax_bar = plt.subplots(1, len(datasets), squeeze=False)
    All_DF = {'Features': []}
    All_DF.update({f'{name}_Correlation': [] for name in datasets})

    for col_idx, dataset in enumerate(datasets):
        print(f'\n******* Dataset: {dataset} *******')

        DataPath = f"../Dataset/{dataset.upper()}/"

        # Load files.
        # NOTE(review): with AVG set, every RUN re-reads the same averaged
        # matrix, yet the combined output below is still labelled by RUN.
        if AVG:
            affinity_file = f'../RESULTS/{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv'
        else:
            affinity_file = f'../RESULTS/{dataset}_Pairwise_Affinity_run_{RUN}_SGD_FIXED.csv'
        pairwise_affinity = pd.read_csv(affinity_file)

        task_relation_features = pd.read_csv(f'{DataPath}Pairwise_Task_Features_{dataset}_FIXED.csv')
        print(f'shape of task_relation_features: {task_relation_features.shape}')
        print(f'shape of pairwise_Affinities: {pairwise_affinity.shape}')
        print(task_relation_features.columns)

        task_relation_features = task_relation_features[FEATURE_COLUMNS]
        print(f'Total columns: {task_relation_features.shape[1]}')
        ALL_Columns = _feature_names(task_relation_features.columns)

        # Task IDs may be stored as floats; normalise them to sorted ints so
        # they match the int(...) keys used when building samples.
        stl_res = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
        tasks = sorted(int(t) for t in stl_res['TASKS'])
        print(f'tasks: {tasks}')
        task_num = len(tasks)

        # Diagnostic only: the number of off-diagonal affinity entries.
        pairwise_affinity = pairwise_affinity.values
        diagonal_indices = np.arange(task_num) * task_num + np.arange(task_num)
        filtered_pairwise_affinity = np.delete(pairwise_affinity.flatten(), diagonal_indices)
        print(f'shape of filtered_pairwise_affinity = {filtered_pairwise_affinity.shape}')

        # Ground-truth affinity for each ordered task pair.
        affinity_lookup = _build_affinity_lookup(pairwise_affinity, tasks)
        print(f'shape of affinity_lookup = {len(affinity_lookup)}')

        X, y = _build_samples(task_relation_features, affinity_lookup)
        print(f'shape of x: {X.shape}, y = {len(y)}')
        print(f'len(ALL_Columns) = {len(ALL_Columns)}')
        if X.ndim != 2 or X.shape[1] != len(ALL_Columns):
            raise ValueError(
                f'{dataset}: feature matrix shape {X.shape} does not match '
                f'{len(ALL_Columns)} feature names'
            )

        # Pearson correlation of each single feature with the affinity.
        correlations = {}
        pvalues = {}
        for idx, feature in enumerate(ALL_Columns):
            corr, pval = pearsonr(X[:, idx], y)
            correlations[feature] = corr
            pvalues[feature] = pval

        corr_val = [round(c, 5) for c in correlations.values()]
        p_val = [round(p, 5) for p in pvalues.values()]
        correlation_df = pd.DataFrame({
            'Feature': list(correlations.keys()),
            'Pearson_Correlation': corr_val,
            'P_Value': p_val,
        })

        print(correlation_df)

        All_DF[f'{dataset}_Correlation'] = corr_val

        # Bar chart of this dataset's per-feature correlations.
        FEATURES = [_readable_feature_name(f) for f in correlation_df['Feature']]
        ax = ax_bar[0, col_idx]
        ax.bar(correlation_df['Feature'], correlation_df['Pearson_Correlation'], color='orange')
        ax.set_xlabel('Individual Feature')
        ax.set_ylabel('Pearson_Correlation')
        # Categorical bars sit at x = 0..n-1. Pin those ticks before relabelling;
        # set_xticklabels alone triggers Matplotlib's fixed-ticks warning.
        ax.set_xticks(np.arange(len(FEATURES)))
        ax.set_xticklabels(FEATURES, rotation=90)
        ax.set_title(f'{dataset.upper()}')
        ax.grid(True)

        print(f'Features: {FEATURES}')
        NEW_FEATURE_NAME = []
        for each_feature in FEATURES:
            print(each_feature, end='-->')
            print(FEATURE_LABELS[each_feature])
            NEW_FEATURE_NAME.append(FEATURE_LABELS[each_feature])
        All_DF['Features'] = NEW_FEATURE_NAME

        # Optional: save this dataset's correlations.
        # NOTE(review): this path is not tagged with RUN, so each RUN
        # overwrites the previous run's per-dataset file.
        correlation_df.to_csv(f'../RESULTS/Feature_vs_Affinity_Correlation_{dataset}.csv', index=False)

    fig_bar.suptitle(f'RUN: {RUN} --- Correlation with Pairwise MTL Gains', fontsize=22, color='blue')
    fig_bar.tight_layout()
    fig_bar.show()

    All_DF = pd.DataFrame(All_DF)
    All_DF.to_csv(f'../RESULTS/Feature_vs_MTLGain_Correlation_Run_{RUN}.csv', index=False)


plt.show()