from sklearn.linear_model import Ridge, LinearRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
import random
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns

# Single seed shared by both random number generators used in this script:
# Python's `random` module and NumPy's legacy global generator.
seed_val = 2025
random.seed(seed_val)
np.random.seed(seed_val)

def modify_feature_name(feat):
    """Return a display-friendly version of a raw feature/column name.

    Underscores become spaces, then every occurrence of the token ``' L2 '``
    (with a space on both sides) is removed. Because the token must be
    surrounded by spaces, an ``L2`` at the very start or end of the name is
    left in place (e.g. ``'Mean_Diff_L2'`` -> ``'Mean Diff L2'``).
    """
    return feat.replace('_', ' ').replace(' L2 ', '')

# Regression model fitted on each individual feature. The driver loop below
# has branches for 'Quadratic', 'LR' and 'Ridge'.
# ModelName = 'LR'
ModelName = 'Quadratic'

# When truthy, the run-averaged pairwise-affinity file is loaded and output
# file names carry no RUN suffix; otherwise the per-RUN files are used.
AVG = 0

# Datasets to process; each name is used to build input and output paths.
datasets = ['School', 'Chemical', 'Landmine']

# Accumulator for the cross-dataset summary table (column name -> values).
All_DF = {}
def remove_highly_correlated_features(X_df, threshold=0.95):
    """Drop features whose absolute correlation with an earlier feature exceeds ``threshold``.

    The 'Task1' and 'Task2' columns are removed first. The strict upper
    triangle of the absolute correlation matrix is written to
    ``../RESULTS/{dataset}_COLINEARITY_MATRIX.csv``.

    NOTE: ``dataset`` is not a parameter; it is read from module scope at call
    time, so this must be called while the dataset loop variable is set.

    Returns:
        (reduced DataFrame, list of dropped column names)
    """
    X_df = X_df.drop(columns=['Task1', 'Task2'])
    corr_matrix = X_df.corr().abs()
    # Keep only entries above the diagonal so every pair is examined once and
    # the first feature of a correlated pair is the one that survives.
    upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
    upper.to_csv(f'../RESULTS/{dataset}_COLINEARITY_MATRIX.csv')

    to_drop = [column for column in upper.columns if (upper[column] > threshold).any()]
    print(f"Removing {len(to_drop)} features due to high collinearity: {to_drop}")
    return X_df.drop(columns=to_drop), to_drop


def remove_redundant_features_by_correlation(X_df, threshold=0.95):
    """Return the set of feature names judged redundant by pairwise correlation.

    For every pair of columns whose absolute correlation exceeds ``threshold``,
    the member with the higher mean absolute correlation to the remaining
    columns (the partner excluded, its own self-correlation included) is
    added to the result. ``X_df`` itself is not modified.
    """
    corr_matrix = X_df.corr().abs()
    to_drop = set()
    features = X_df.columns

    for i in range(len(features)):
        for j in range(i + 1, len(features)):
            if corr_matrix.iloc[i, j] > threshold:
                f1, f2 = features[i], features[j]

                # Mean absolute correlation of each feature with the others.
                f1_corr_mean = corr_matrix[f1].drop(f2).mean()
                f2_corr_mean = corr_matrix[f2].drop(f1).mean()

                # Drop the more redundant one (higher mean correlation).
                if f1_corr_mean > f2_corr_mean:
                    to_drop.add(f1)
                else:
                    to_drop.add(f2)

    print(f"Removing {len(to_drop)} features due to high pairwise collinearity.")
    print(f'features: {to_drop}')
    return to_drop


# Columns kept from the pairwise task-feature file (order matters: the
# per-feature loop below indexes the feature matrix by position).
SELECTED_COLUMNS = [
    'Task1', 'Task2',
    'Total_Dataset_Size',
    'DatasetSize_Ratio_t1', 'DatasetSize_Ratio_t2',
    'DatasetSize_Diff',
    'Unified_Distance', 'Distance_Diff_over_Sum',
    'Distance_Diff_over_Prod',
    'Unified_Dist_over_Sum', 'Unified_Dist_over_Prod',
    'Energy_Distance',
    'Mean_Diff_L2',
    'Cosine_Similarity',
    'PCA_Top_CosSim_Mean',
    'Rank_based_Similarity',
    'Graph_based_Similarity',
]

# Mapping from cleaned feature name (see modify_feature_name) to the LaTeX
# label written to the summary table. Loop-invariant, so built once.
Feature_name_dict = {
    'Dataset Ratio': r'$\texttt{Data-Ratio}$',
    'Total Dataset Size': r'$\texttt{Data-Size }$',
    'DatasetSize Diff': r'$\texttt{Size-Gap}$',

    'Distance Diff': '$|d_{t_i}-d_{t_j}|$',
    'Unified Distance': '$d_{(t_i+t_j)}$',

    'Distance Diff over Sum': r'$|d_{t_i}-d_{t_j}|\div d_{t_i}+d_{t_j}$',
    'Distance Diff over Prod': r'$|d_{t_i}-d_{t_j}| \div \sqrt{d_{t_i}\cdot d_{t_j}}$',
    'Unified Dist over Sum': r'$d_{(t_i+t_j)} \div d_{t_i}+d_{t_j}$',
    'Unified Dist over Prod': r'$d_{(t_i+t_j)} \div \sqrt{d_{t_i}\cdot d_{t_j}}$',

    'Energy Distance': r'$\texttt{Energy Distance}$',
    'Rank based Similarity': r'$\texttt{Rank-Div}$',
    'Graph based Similarity': r'$\texttt{Cross-Link}$',
    'Cosine Similarity': r'$\texttt{Cosine-Sim}$',
    'Mean Diff L2': r'$\texttt{Feature-Mean Gap}$',
    'Skewness Diff L2': r'$\texttt{Skewness Gap}$',
    'PCA Top CosSim Mean': r'$\texttt{PCA-Align}$',
}

for RUN in [4, 5, 6]:
    # Per-run accumulators. They must be reset for every RUN: otherwise the
    # figures and summary table labelled with one RUN would silently mix in
    # rows from all earlier runs.
    All_DF = {}
    run_results = []

    for dataset in datasets:
        print(f'\n******* Dataset: {dataset} *******')

        DataPath = f"../Dataset/{dataset.upper()}/"
        # Load files
        if AVG:
            pairwise_affinity = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv')
        else:
            pairwise_affinity = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_run_{RUN}_SGD_FIXED.csv')
        task_relation_features = pd.read_csv(f'{DataPath}Pairwise_Task_Features_{dataset}_FIXED.csv')
        print(f'shape of task_relation_features: {task_relation_features.shape}')
        print(f'shape of pairwise_Affinities: {pairwise_affinity.shape}')
        print(task_relation_features.columns)

        task_relation_features = task_relation_features[SELECTED_COLUMNS]

        print(f'Total columns: {task_relation_features.shape[1]}')
        # Names of the model features, in the same order as the columns of X
        # built below: the two directional ratio columns collapse into a
        # single 'Dataset_Ratio' slot right after the first feature.
        ALL_Columns = list(task_relation_features.columns)
        ALL_Columns.remove('Task1')
        ALL_Columns.remove('Task2')
        ALL_Columns.remove('DatasetSize_Ratio_t1')
        ALL_Columns.remove('DatasetSize_Ratio_t2')
        ALL_Columns = [ALL_Columns[0], 'Dataset_Ratio'] + ALL_Columns[1:]

        stl_res = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
        tasks = sorted(int(t) for t in stl_res['TASKS'])  # ensure integer task IDs
        print(f'tasks: {tasks}')
        task_num = len(tasks)

        # Off-diagonal entries only (reported for information).
        # NOTE(review): assumes the affinity file is a task_num x task_num
        # matrix with no extra index column, ordered like `tasks` -- confirm.
        pairwise_affinity_flat = np.array(pairwise_affinity).flatten()
        diagonal_indices = np.arange(task_num) * task_num + np.arange(task_num)
        filtered_pairwise_affinity = np.delete(pairwise_affinity_flat, diagonal_indices)
        print(f'shape of filtered_pairwise_affinity = {filtered_pairwise_affinity.shape}')

        # Lookup of ground-truth affinities for both directions of each pair.
        affinity_lookup = {}
        pairwise_affinity = pairwise_affinity.values
        for i in range(task_num):
            for j in range(i + 1, task_num):  # upper triangle; both directions stored
                t1 = tasks[i]
                t2 = tasks[j]
                affinity_lookup[(t1, t2)] = pairwise_affinity[i, j]
                affinity_lookup[(t2, t1)] = pairwise_affinity[j, i]

        # Every feature row yields two samples, one per direction; each keeps
        # the dataset-size ratio of its own direction and drops the other.
        affinities = []
        feature_matrix = []
        for _, row in task_relation_features.iterrows():
            t1, t2 = int(row['Task1']), int(row['Task2'])
            affinities.append(affinity_lookup[(t1, t2)])
            feature_matrix.append(row.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t2']).values)
            affinities.append(affinity_lookup[(t2, t1)])
            feature_matrix.append(row.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t1']).values)

        X = np.array(feature_matrix)
        y = np.array(affinities)

        print(f'shape of x: {X.shape}, y = {len(y)}')
        print(f'train size: {len(X)*0.25}, test size: {len(X)*0.75}')

        print(f'len(ALL_Columns): {len(ALL_Columns)}')
        results = []

        for idx, feature in enumerate(ALL_Columns):
            X_feature = X[:, idx].reshape(-1, 1)  # using one feature only

            if ModelName == 'Quadratic':
                # Row-wise expansion to (x, x^2); it learns nothing from the
                # data, so applying it before the split cannot leak.
                X_model = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X_feature)
            else:
                X_model = X_feature

            X_train, X_test, y_train, y_test = train_test_split(X_model, y, test_size=0.75)

            # Standardise inputs and target with statistics of the training
            # split only, and do it BEFORE model selection so hyper-parameters
            # are tuned on exactly the data the final model is fitted on.
            # (Standardising/centering on the full data beforehand would be
            # cancelled out by these scalers anyway, so it is not done.)
            scaler = StandardScaler()
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
            y_scaler = StandardScaler()
            y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
            y_test = y_scaler.transform(y_test.reshape(-1, 1)).ravel()

            if ModelName == 'Quadratic':
                # Grid of hyperparameters to search
                param_grid = {
                    'alpha': np.logspace(-10, 1, 300),
                    'fit_intercept': [True, False],
                    'solver': ['auto', 'svd', 'cholesky'],
                }
                grid_search = GridSearchCV(estimator=Ridge(),
                                           param_grid=param_grid,
                                           cv=5,  # 5-fold cross-validation
                                           scoring='r2',
                                           n_jobs=-1,  # Use all CPUs
                                           verbose=0)
                grid_search.fit(X_train, y_train)

                print("Best parameters found:", grid_search.best_params_)
                print("Best R² score:", grid_search.best_score_)
                model = grid_search.best_estimator_

            elif ModelName == 'LR':
                model = LinearRegression(fit_intercept=True)

            elif ModelName == 'Ridge':
                alpha = np.logspace(-6, 4, 200)
                grid = GridSearchCV(estimator=Ridge(), param_grid={'alpha': alpha}, cv=5)
                grid.fit(X_train, y_train)
                best_alpha = grid.best_params_['alpha']
                model = Ridge(alpha=best_alpha)

            else:
                # Fail loudly instead of reusing a stale/undefined `model`.
                raise ValueError(f'Unknown ModelName: {ModelName!r}')

            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)

            r2 = r2_score(y_test, y_pred)
            correlation = np.corrcoef(y_test, y_pred)[0][1]
            feature = modify_feature_name(feature)

            results.append({'Feature': feature, 'Correlation': correlation, 'R2_Score': r2,  'Correlation^2': correlation**2})

        # Save results to CSV
        results_df = pd.DataFrame(results)
        if AVG:
            results_df.to_csv(f'../RESULTS/{dataset}_Individual_Feature_R2_{ModelName}_FIXED.csv', index=False)
        else:
            results_df.to_csv(f'../RESULTS/{dataset}_Individual_Feature_R2_{ModelName}_FIXED_RUN_{RUN}.csv', index=False)
        print(results_df)
        # Tag rows with their dataset so the plots below can filter on it.
        results_df['Dataset'] = dataset

        # Translate cleaned feature names into their LaTeX display labels.
        NEW_FEATURE_NAME = []
        for each_feature in results_df['Feature']:
            print(each_feature, end='-->')
            print(Feature_name_dict[each_feature])
            NEW_FEATURE_NAME.append(Feature_name_dict[each_feature])

        # 'Features' is assigned first so it stays the first summary column.
        All_DF['Features'] = NEW_FEATURE_NAME
        All_DF[f'{dataset}_R2_Score'] = list(results_df['R2_Score'])
        All_DF[f'{dataset}_Correlation'] = list(results_df['Correlation'])

        run_results.append(results_df)

    # Results of this RUN only, stacked over datasets.
    all_results = pd.concat(run_results, axis=0)
    all_results.reset_index(drop=True, inplace=True)

    # One figure per metric, one panel per dataset (x labels rotated).
    for metric, title_part, bar_kwargs in (
        ('R2_Score', 'R2 Scores', {}),
        ('Correlation', 'Correlation', {'color': 'orange'}),
    ):
        # squeeze=False keeps `axes` 2-D so indexing works for one dataset too.
        fig, axes = plt.subplots(1, len(datasets), figsize=(18, 6), sharey=False, squeeze=False)
        axes = axes[0]
        fig.suptitle(f'RUN: {RUN}--Feature {title_part} (Labels vs. Predictions using {ModelName})', fontsize=16)

        for i, dataset in enumerate(datasets):
            df = all_results[all_results['Dataset'] == dataset]
            sns.barplot(data=df, x='Feature', y=metric, ax=axes[i], **bar_kwargs)
            axes[i].set_title(f'{dataset}:max: {max(df[metric]):0.5f}')
            axes[i].set_xlabel('Feature')
            if i == 0:
                axes[i].set_ylabel('R² Score' if metric == 'R2_Score' else 'Correlation')
            else:
                axes[i].set_ylabel('')
            axes[i].tick_params(axis='x', rotation=90)
            axes[i].grid(axis='y')

        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])

    print(f'All_DF : {len(All_DF)}')

    All_DF = pd.DataFrame(All_DF)
    if AVG:
        All_DF.to_csv(f'../RESULTS/Feature_PredictionPerformance_Correlation_{ModelName}_FIXED.csv', index=False)
    else:
        All_DF.to_csv(f'../RESULTS/Feature_PredictionPerformance_Correlation_{ModelName}_FIXED_RUN_{RUN}.csv', index=False)

# Display every figure created above (two per RUN).
plt.show()