import time
from sklearn.linear_model import Ridge, LinearRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
import random
pd.set_option('display.max_columns', None)
import tqdm
import time
# Seed both the stdlib and the NumPy global random generators with one value.
seed = 2021
random.seed(seed)
np.random.seed(seed)

# --- Experiment configuration -------------------------------------------
# Regression model to evaluate: 'Quadratic' or 'LR' (the two values the
# experiment loop below tests for).
ModelName = 'Quadratic'
SAVE = False
# Identifiers of the affinity runs to process.
RUNS = list(range(1, 7))
datasets = ['School', 'Chemical', 'Landmine']
# Container for the per-run summary columns.
All_DF = {}
# ---------------------------------------------------------------------------
# Experiment: predict pairwise task affinity from task-pair features.
#
# For every run and every dataset this section
#   1. loads the pairwise-affinity matrix, the task-pair feature table and the
#      STL result file (the latter only supplies the task IDs),
#   2. turns every feature row (Task1, Task2) into two directed samples,
#      (t1 -> t2) and (t2 -> t1),
#   3. for each training fraction, fits either a Ridge model on degree-2
#      polynomial features ('Quadratic') or a plain LinearRegression and
#      records R2 / correlation on a held-out half of the samples,
#   4. writes one summary CSV per run.
# ---------------------------------------------------------------------------

# Feature groups, kept for reference; this section does not read them.
feature_types = {
    "Dataset": ['Dataset_Ratio', 'Total_Dataset_Size', 'DatasetSize_Diff'],
    "Distance": ['Distance_Diff', 'Unified_Distance', 'Energy_Distance'],
    "NormDistance": ['Distance_Diff_over_Sum', 'Distance_Diff_over_Prod',
                     'Unified_Dist_over_Sum', 'Unified_Dist_over_Prod'],
    "Similarity": ['Cosine_Similarity', 'Rank_based_Similarity', 'Graph_based_Similarity'],
    "FeatureDiff": ['Mean_Diff_L2', 'Skewness_Diff_L2']
}

promising_feature_sets = {'COMMON_3': ['Dataset_Ratio', 'Total_Dataset_Size',
                                       'Distance_Diff_over_Prod', 'Unified_Dist_over_Prod',
                                       'Energy_Distance',
                                       'Rank_based_Similarity',
                                       'Mean_Diff_L2']}

# Training fractions, expressed relative to the FULL sample set.
TRAIN_SIZES = [0.05, 0.075, 0.10, 0.125, 0.15, 0.175, 0.20, 0.225, 0.25, 0.275,
               0.3, 0.325, 0.35, 0.375, 0.4, 0.425, 0.45, 0.475, 0.5]

# One results frame per (run, dataset); concatenated once after the loops.
per_dataset_frames = []

for RUN in RUNS:
    AVG = RUN == 'AVG'
    print(f'\n****************RUN: {RUN}')

    # Fresh column store for every run so nothing carries over from the
    # previous run's summary.
    run_columns = {}

    for dataset in datasets:
        print(f'\n******* Dataset: {dataset} *******')

        DataPath = f"../Dataset/{dataset.upper()}/"

        # Load files
        if AVG:
            pairwise_affinity = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_Avg_SGD_FIXED.csv')
        else:
            pairwise_affinity = pd.read_csv(f'../RESULTS/{dataset}_Pairwise_Affinity_run_{RUN}_SGD_FIXED.csv')

        task_relation_features = pd.read_csv(f'{DataPath}Pairwise_Task_Features_{dataset}_FIXED.csv')
        task_relation_features = task_relation_features[['Task1', 'Task2',
                                                         'DatasetSize_Ratio_t1', 'DatasetSize_Ratio_t2',
                                                         'Total_Dataset_Size',
                                                         'DatasetSize_Diff',
                                                         'Distance_Diff',
                                                         'Unified_Distance', 'Distance_Diff_over_Sum',
                                                         'Distance_Diff_over_Prod',
                                                         'Unified_Dist_over_Sum', 'Unified_Dist_over_Prod',
                                                         'Energy_Distance',
                                                         'Rank_based_Similarity', 'Graph_based_Similarity',
                                                         'Cosine_Similarity',
                                                         'Mean_Diff_L2',
                                                         'Skewness_Diff_L2', 'PCA_Top_CosSim_Mean'
                                                         ]]

        # Column names of the sample matrix built below: the two per-task ratio
        # columns collapse into a single leading 'Dataset_Ratio' column, the
        # remaining feature columns keep their order.
        ALL_Columns = list(task_relation_features.columns)
        ALL_Columns.remove('Task1')
        ALL_Columns.remove('Task2')
        ALL_Columns.remove('DatasetSize_Ratio_t1')
        ALL_Columns.remove('DatasetSize_Ratio_t2')
        ALL_Columns = ['Dataset_Ratio'] + ALL_Columns

        stl_res = pd.read_csv(f'../RESULTS/{dataset}_FIXED_STL_Avg.csv')
        tasks = sorted(int(t) for t in stl_res['TASKS'])  # ensure integer task IDs
        task_num = len(tasks)

        # Ground-truth lookup keyed by directed task pair.
        # NOTE(review): this assumes row/column k of the affinity CSV belongs
        # to the k-th smallest task ID - confirm against the file layout.
        affinity_lookup = {}
        pairwise_affinity = pairwise_affinity.values
        for i, t1 in enumerate(tasks):
            for j in range(i + 1, task_num):  # visit each unordered pair once
                t2 = tasks[j]
                affinity_lookup[(t1, t2)] = pairwise_affinity[i, j]
                # Reverse direction is read from [j, i]; the matrix is not
                # required to be symmetric.
                affinity_lookup[(t2, t1)] = pairwise_affinity[j, i]
        print(f'shape of affinity_lookup = {len(affinity_lookup)}')

        # Each feature row yields two samples: (t1 -> t2) keeps the t1 ratio,
        # (t2 -> t1) keeps the t2 ratio; all other features are shared.
        affinities = []
        feature_matrix = []
        PAIRS = []
        for _, row in task_relation_features.iterrows():
            t1, t2 = int(row['Task1']), int(row['Task2'])

            PAIRS.append((t1, t2))
            affinities.append(affinity_lookup[(t1, t2)])
            feature_matrix.append(row.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t2']).values)

            PAIRS.append((t2, t1))
            affinities.append(affinity_lookup[(t2, t1)])
            feature_matrix.append(row.drop(['Task1', 'Task2', 'DatasetSize_Ratio_t1']).values)

        X = np.array(feature_matrix)
        y = np.array(affinities)

        # NOTE(review): this scaler is fitted on ALL samples before the
        # train/test split, so test-set statistics influence the features
        # (matters for the 'Quadratic' expansion). Left as-is to keep reported
        # numbers unchanged - consider fitting on the training part only.
        scaler = StandardScaler()
        X = scaler.fit_transform(X)

        PAIRS = np.array(PAIRS)
        print(f'shape of x: {X.shape}, y = {len(y)}, Pairs = {len(PAIRS)}')

        # Feature selection and polynomial expansion do not depend on the
        # training fraction, so do them once per dataset.
        selected_feats = promising_feature_sets['COMMON_3']
        indices = [ALL_Columns.index(f) for f in selected_feats]
        X_subset = X[:, indices]
        if ModelName == 'Quadratic':
            poly = PolynomialFeatures(degree=2, include_bias=False)
            X_model = poly.fit_transform(X_subset)
        else:
            X_model = X_subset
        n_samples = len(X_model)
        all_indices = np.arange(n_samples)

        results = []
        test_size = 0.5
        for train_size in TRAIN_SIZES:
            # Step 1: hold out `test_size` of all samples for evaluation.
            X_temp, X_test, y_temp, y_test, idx_temp, idx_test = train_test_split(
                X_model, y, all_indices, test_size=test_size)

            # Step 2: sub-sample the remaining half. Keeping a share of
            # 2 * train_size of that half equals `train_size` of the full set;
            # the largest setting (0.5) uses the whole remaining half.
            if train_size != TRAIN_SIZES[-1]:
                X_train, _, y_train, _, idx_train, _ = train_test_split(
                    X_temp, y_temp, idx_temp, test_size=(1 - train_size * 2))
            else:
                X_train, y_train, idx_train = X_temp, y_temp, idx_temp

            if ModelName == 'Quadratic':
                print(f'shapes: {len(X_train)}, {len(y_train)}, {len(idx_train)}, test: {len(X_test)}, {len(y_test)}')

            # Task pairs behind each split (not consumed further in this section).
            train_pairs = PAIRS[idx_train]
            test_pairs = PAIRS[idx_test]

            # Timed region: scaling, model selection, fitting and scoring.
            # perf_counter is monotonic, unlike wall-clock time.time().
            timeStart = time.perf_counter()
            x_scaler = StandardScaler()
            y_scaler = StandardScaler()

            # Scalers are fitted on the training part only.
            X_train = x_scaler.fit_transform(X_train)
            X_test = x_scaler.transform(X_test)

            y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
            y_test = y_scaler.transform(y_test.reshape(-1, 1)).ravel()

            if ModelName == 'Quadratic':
                # Pick the ridge penalty by 5-fold CV on the training part.
                alpha = np.logspace(-6, 4, 200)
                grid = GridSearchCV(estimator=Ridge(), param_grid={'alpha': alpha}, cv=5)
                grid.fit(X_train, y_train)
                best_alpha = grid.best_params_['alpha']
                model = Ridge(alpha=best_alpha)
            else:
                model = LinearRegression()

            model.fit(X_train, y_train)
            if ModelName == 'LR':
                for idx, each_feature in enumerate(selected_feats):
                    print(f'{each_feature} : Coeff = {model.coef_[idx]}')

            # Evaluate (in standardized target units).
            y_pred = model.predict(X_test)
            all_r2 = r2_score(y_test, y_pred)
            time_required = time.perf_counter() - timeStart
            all_correlation = np.corrcoef(y_test, y_pred)[0, 1]
            print(f'{dataset.upper()}, train_size: {train_size:0.3f},  R2 Score: {all_r2:0.4f}, '
                  f'correlation:  {all_correlation:0.4f},'
                  f'correlation^2:  {(all_correlation**2):0.4f}')

            # Predictions mapped back to the original affinity scale
            # (not consumed further in this section).
            train_pred = y_scaler.inverse_transform(model.predict(X_train).reshape(-1, 1)).ravel()
            test_pred = y_scaler.inverse_transform(y_pred.reshape(-1, 1)).ravel()

            results.append({
                'Train_Size': train_size, 'Test_Size': test_size,
                'Correlation': all_correlation, 'R2_Score': all_r2,
                'Time_required': time_required})

        results_df = pd.DataFrame(results)
        results_df['Dataset'] = dataset

        # The size columns are identical for every dataset; store them once.
        if 'Train_Size' not in run_columns:
            run_columns['Train_Size'] = list(results_df['Train_Size'])
            run_columns['Test_Size'] = list(results_df['Test_Size'])

        run_columns[f'{dataset}_R2_Score'] = list(results_df['R2_Score'])
        run_columns[f'{dataset}_Correlation'] = list(results_df['Correlation'])
        run_columns[f'{dataset}_Time_required'] = list(results_df['Time_required'])

        per_dataset_frames.append(results_df)

    # One wide table per run: size columns followed by three metric columns
    # per dataset.
    All_DF = pd.DataFrame(run_columns)
    All_DF.to_csv(path_or_buf=f'../RESULTS/FeatureSet_Performance_{ModelName}_for_RT_run_{RUN}_SEED_{seed}.csv', index=False)

# Long-format table over every (run, dataset); built with a single concat
# rather than probing globals() and re-concatenating on every iteration.
if per_dataset_frames:
    all_results = pd.concat(per_dataset_frames, axis=0)

