from sklearn.linear_model import Ridge, LinearRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
import random
pd.set_option('display.max_columns', None)
import tqdm
import math

# Seed Python's and NumPy's global RNGs so that later calls relying on them
# (e.g. a train/test split made without an explicit random_state) repeat.
selected_seed = 2025
random.seed(selected_seed)
np.random.seed(selected_seed)


# Datasets to process, in order.
datasets = 'School Chemical Landmine'.split()
# Per-run feature-set score table (filled in and written out by the main loop).
All_DF = {}
# Toggled per feature set to decide whether predictions are written to disk.
SAVE = False
# Regressor choice: 'LR' (plain LinearRegression) or 'Quadratic'
# (degree-2 PolynomialFeatures + grid-searched Ridge).
ModelName = 'Quadratic'
# dataset -> (best rounded score, RUN that achieved it); initialised lazily.
MAX_rsq, MAX_Corr = {}, {}
# Per-dataset run IDs. NOTE(review): not referenced elsewhere in this script.
BEST_RUN_DICT = dict(School=1, Chemical=3, Landmine=2)
# Numbered affinity runs plus the run-averaged affinity matrix ('AVG').
RUNS = [*range(1, 4), 'AVG', *range(4, 7)]
# School-dataset correlation / R^2 values, appended by the main loop.
School_Corr, School_Rsq = [], []
# ---------------------------------------------------------------------------
# Experiment: for every affinity run and dataset, regress directed pairwise
# task affinity on task-relation features, report R^2 and Pearson correlation
# on a held-out split, track the best scores, and optionally save predictions.
# ---------------------------------------------------------------------------

# Columns read from each dataset's pairwise task-feature file.
RELATION_COLUMNS = [
    'Task1', 'Task2',
    'DatasetSize_Ratio_t1', 'DatasetSize_Ratio_t2',
    'Total_Dataset_Size',
    'DatasetSize_Diff',
    'Distance_Diff',
    'Unified_Distance', 'Distance_Diff_over_Sum',
    'Distance_Diff_over_Prod',
    'Unified_Dist_over_Sum', 'Unified_Dist_over_Prod',
    'Energy_Distance',
    'Rank_based_Similarity', 'Graph_based_Similarity',
    'Cosine_Similarity',
    'Mean_Diff_L2',
    'Skewness_Diff_L2', 'PCA_Top_CosSim_Mean',
]

# Column layout of the feature vectors built by _build_directed_samples: the
# per-direction dataset-size ratio (named 'Dataset_Ratio') first, followed by
# the remaining relation columns in file order.
ALL_Columns = ['Dataset_Ratio'] + [
    col for col in RELATION_COLUMNS
    if col not in ('Task1', 'Task2', 'DatasetSize_Ratio_t1', 'DatasetSize_Ratio_t2')
]

# Grouping of the available features by kind (reference only; not used below).
feature_types = {
    "Dataset": ['Dataset_Ratio', 'Total_Dataset_Size', 'DatasetSize_Diff'],
    "Distance": ['Distance_Diff', 'Unified_Distance', 'Energy_Distance'],
    "NormDistance": ['Distance_Diff_over_Sum', 'Distance_Diff_over_Prod',
                     'Unified_Dist_over_Sum', 'Unified_Dist_over_Prod'],
    "Similarity": ['Cosine_Similarity', 'Rank_based_Similarity', 'Graph_based_Similarity'],
    "FeatureDiff": ['Mean_Diff_L2', 'Skewness_Diff_L2'],
}

# Manually selected promising feature subsets.
final_feature_sets = {'COMMON_FEAT': ['Total_Dataset_Size', 'Dataset_Ratio',
                                      'Distance_Diff_over_Prod', 'Unified_Dist_over_Prod',
                                      'Energy_Distance',
                                      'Rank_based_Similarity',
                                      'Mean_Diff_L2']}

# Ridge hyper-parameter grid searched when ModelName == 'Quadratic'.
_RIDGE_PARAM_GRID = {
    'alpha': np.logspace(-10, 4, 400),
    'fit_intercept': [True, False],
    'solver': ['auto', 'svd', 'cholesky'],
}


def _load_pairwise_affinity(dataset, run):
    """Read the task-by-task affinity CSV for ``dataset`` and ``run``.

    ``run == 'AVG'`` selects the run-averaged file; any other value selects
    the file for that numbered run.
    """
    if run == 'AVG':
        print('reading AVG file')
        path = f'../RESULTS/{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv'
    else:
        path = f'../RESULTS/{dataset}_Pairwise_Affinity_run_{run}_SGD_FIXED.csv'
    return pd.read_csv(path)


def _build_affinity_lookup(affinity_matrix, task_ids):
    """Map every ordered pair of distinct tasks to its affinity value.

    Entry ``(task_ids[i], task_ids[j])`` is ``affinity_matrix[i, j]``. The
    matrix is directed, so ``(t1, t2)`` and ``(t2, t1)`` may differ.
    NOTE(review): assumes row/column i of the matrix corresponds to the i-th
    sorted task ID and that the CSV has a header row and no index column.
    Confirm against the file that writes the affinity matrix.
    """
    lookup = {}
    for i, t1 in enumerate(task_ids):
        for j in range(i + 1, len(task_ids)):
            t2 = task_ids[j]
            lookup[(t1, t2)] = affinity_matrix[i, j]
            lookup[(t2, t1)] = affinity_matrix[j, i]  # reverse direction, not a mirror
    return lookup


def _build_directed_samples(relation_features, lookup):
    """Expand each unordered task pair into two directed regression samples.

    For direction (t1, t2) the feature vector keeps ``DatasetSize_Ratio_t1``
    (dropping the t2 ratio), and vice versa. Both directions therefore share
    the column layout in ``ALL_Columns``.

    Returns ``(X, y, pairs)``: the feature matrix, the affinity targets, and
    the (t1, t2) task pair of each row.
    """
    rows, targets, pairs = [], [], []
    for _, record in relation_features.iterrows():
        t1, t2 = int(record['Task1']), int(record['Task2'])
        pairs.append((t1, t2))
        targets.append(lookup[(t1, t2)])
        rows.append(record.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t2']).values)
        pairs.append((t2, t1))
        targets.append(lookup[(t2, t1)])
        rows.append(record.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t1']).values)
    return np.array(rows), np.array(targets), np.array(pairs)


def _fit_model(X_train, y_train, model_name):
    """Fit and return the regressor selected by ``model_name``.

    'Quadratic' grid-searches a Ridge model (5-fold CV, negative MSE) and
    returns the refit best estimator. Any other name fits a plain
    LinearRegression.
    """
    if model_name == 'Quadratic':
        grid = GridSearchCV(estimator=Ridge(), param_grid=_RIDGE_PARAM_GRID,
                            cv=5, n_jobs=-1, scoring='neg_mean_squared_error')
        grid.fit(X_train, y_train)
        return grid.best_estimator_
    model = LinearRegression()
    model.fit(X_train, y_train)
    return model


_all_results_frames = []
for RUN in RUNS:
    print(f'\n****************RUN: {RUN}')
    run_table = {}  # feature-set scores for this run, one column pair per dataset
    for dataset in datasets:
        MAX_rsq.setdefault(dataset, (-math.inf, None))
        MAX_Corr.setdefault(dataset, (-math.inf, None))
        print(f'\n******* Dataset: {dataset} *******')

        DataPath = f"../Dataset/{dataset.upper()}/"

        # Load files (same order as before so a missing file fails identically).
        pairwise_affinity = _load_pairwise_affinity(dataset, RUN)
        task_relation_features = pd.read_csv(
            f'{DataPath}Pairwise_Task_Features_{dataset}_FIXED.csv')[RELATION_COLUMNS]
        stl_res = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
        tasks = sorted(int(t) for t in stl_res['TASKS'])  # integer task IDs

        affinity_lookup = _build_affinity_lookup(pairwise_affinity.values, tasks)
        print(f'shape of affinity_lookup = {len(affinity_lookup)}')

        X, y, PAIRS = _build_directed_samples(task_relation_features, affinity_lookup)
        print(f'shape of x: {X.shape}, y = {len(y)}, Pairs = {len(PAIRS)}')

        results = []
        for set_name, selected_feats in final_feature_sets.items():
            indices = [ALL_Columns.index(f) for f in selected_feats]
            X_subset = X[:, indices]

            if ModelName == 'Quadratic':
                # Degree-2 expansion: the features, their squares and all
                # pairwise products (no bias column).
                poly = PolynomialFeatures(degree=2, include_bias=False)
                X_model = poly.fit_transform(X_subset)
            else:
                X_model = X_subset

            # Split features, labels and row indices together so each row's
            # task pair can be recovered. Note: only 25% of rows are used
            # for training (test_size=0.75). No random_state is passed, so the
            # split follows NumPy's globally seeded RNG.
            all_indices = np.arange(len(X_model))
            X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
                X_model, y, all_indices, test_size=0.75)
            train_pairs = PAIRS[idx_train]
            test_pairs = PAIRS[idx_test]

            # Standardise inputs and target using training statistics only.
            x_scaler = StandardScaler()
            y_scaler = StandardScaler()
            X_train = x_scaler.fit_transform(X_train)
            X_test = x_scaler.transform(X_test)
            y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
            y_test = y_scaler.transform(y_test.reshape(-1, 1)).ravel()

            model = _fit_model(X_train, y_train, ModelName)

            if ModelName == 'LR':
                print(f'\nsetname: {set_name}')
                for each_feature, coef in zip(selected_feats, model.coef_):
                    print(f'{each_feature} : Coeff = {coef}')

            # Evaluate in the standardised target space. R^2 and Pearson r are
            # unchanged by the shared affine scaling of y_test and y_pred.
            y_pred = model.predict(X_test)
            all_r2 = r2_score(y_test, y_pred)
            all_correlation = np.corrcoef(y_test, y_pred)[0, 1]
            print(f'{dataset.upper()}, R2 Score: {all_r2:0.4f}, '
                  f'correlation:  {all_correlation:0.4f},')

            # Predictions mapped back to the original affinity scale for saving.
            train_pred = y_scaler.inverse_transform(
                model.predict(X_train).reshape(-1, 1)).ravel()
            test_pred = y_scaler.inverse_transform(y_pred.reshape(-1, 1)).ravel()

            # NOTE(review): this also collects the 'AVG' run, so the School
            # summary statistics mix per-run and run-averaged results.
            if dataset == 'School' and set_name == 'COMMON_FEAT':
                School_Corr.append(all_correlation)
                School_Rsq.append(all_r2)

            results.append(
                {'Set_Name': set_name, 'Feature': selected_feats,
                 'Correlation': all_correlation, 'R2_Score': all_r2})

            is_common = set_name == 'COMMON_FEAT'
            if ModelName == 'Quadratic' and is_common:
                SAVE = True
                if MAX_rsq[dataset][0] <= all_r2:
                    MAX_rsq[dataset] = (round(all_r2, 5), RUN)
                if MAX_Corr[dataset][0] <= all_correlation:
                    MAX_Corr[dataset] = (round(all_correlation, 5), RUN)
            elif ModelName == 'LR' and is_common:
                SAVE = True
            else:
                SAVE = False

            if SAVE:
                # The run-averaged file carries no run tag in its name.
                run_suffix = '' if RUN == 'AVG' else f'_run_{RUN}'
                np.savez(
                    f'../RESULTS/{ModelName}_{dataset}_BEST_predictions_and_pairs'
                    f'{run_suffix}_{selected_seed}.npz',
                    train_pred=train_pred,
                    test_pred=test_pred,
                    train_pairs=train_pairs,
                    test_pairs=test_pairs
                )

        results_df = pd.DataFrame(results)
        results_df['Dataset'] = dataset

        if dataset == datasets[0]:
            run_table['Set_Name'] = list(results_df['Set_Name'])
            run_table['Features'] = list(results_df['Feature'])
        run_table[f'{dataset}_R2_Score'] = list(results_df['R2_Score'])
        run_table[f'{dataset}_Correlation'] = list(results_df['Correlation'])

        _all_results_frames.append(results_df)

    # NOTE(review): the output path does not include RUN, so each run
    # overwrites this CSV and only the last run's table survives.
    All_DF = pd.DataFrame(run_table)
    All_DF.to_csv(path_or_buf=f'../RESULTS/FeatureSet_Performance_{ModelName}.csv', index=False)

# Every dataset's results across all runs, stacked (no run column).
if _all_results_frames:
    all_results = pd.concat(_all_results_frames, axis=0)

# Best (rounded score, run) per dataset over all evaluated runs.
print(f'Max: {MAX_rsq}')
print(f'Max Corr: {MAX_Corr}')
# Mean and population standard deviation (np.std defaults to ddof=0) of the
# School scores over every entry of RUNS. The doubled backslash emits a
# literal LaTeX "\pm"; a single backslash is an invalid escape sequence.
print(f'School Avg corr: {np.mean(School_Corr):.4f} $\\pm$ {np.std(School_Corr):.4f}')
print(f'School Corr: {School_Corr}')

print(f'School R2 Score: {np.mean(School_Rsq):.4f} $\\pm$ {np.std(School_Rsq):.4f}')
print(f'School R2: {School_Rsq}')
