import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_selection import f_regression, mutual_info_regression
from sklearn.model_selection import train_test_split
from sklearn import datasets, linear_model
from sklearn.metrics import r2_score
from helper import *
from helper_basis_symbolic import *
import os
from joblib import Parallel, delayed
import time
import numpy as np
import ot
from sklearn.feature_selection import RFE, RFECV
np.random.seed(42)  # Fix the seed for reproducibility
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

def fit_kde(trajs):
    """Fit the state and successor-state density models on ``trajs``.

    Returns:
        tuple: ``(kde_states, kde_succ_states)``, i.e. the results of
        ``fit_states(trajs)`` and ``fit_succ_states(trajs)`` (helper module),
        fitted in that order.
    """
    return fit_states(trajs), fit_succ_states(trajs)

def compute_log_probabilities(trajs, expert_ts, num_chunks, kde_states, kde_succ_states, use_marginal, n_batches=50):
    """Compute per-trajectory log-probabilities with the fitted density models.

    The marginal log-probabilities are always computed first with
    ``compute_prob_marginal``. When ``use_marginal`` is falsy they are then
    passed through ``compute_prob_seq`` together with ``kde_succ_states``.

    Args:
        trajs, expert_ts, num_chunks: forwarded unchanged to the helper
            functions.
        kde_states: state density model (first value returned by ``fit_kde``).
        kde_succ_states: successor-state density model (second value returned
            by ``fit_kde``); only used when ``use_marginal`` is falsy.
        use_marginal: if truthy, return the marginal result only.
        n_batches: forwarded as the batch count to both helpers. It was
            previously hard-coded; the default of 50 preserves the old
            behaviour.

    Returns:
        Whatever the last helper called returns (``compute_prob_marginal`` or
        ``compute_prob_seq``).
    """
    logP_tau = compute_prob_marginal(n_batches, trajs, expert_ts, num_chunks, kde_states)
    if not use_marginal:
        # Refine the marginal estimate using the successor-state model.
        logP_tau = compute_prob_seq(logP_tau, n_batches, trajs, num_chunks, kde_succ_states)
    return logP_tau

def filter_outliers(mu_tau, logP_tau, remove_outliers, n_drop=10):
    """Optionally drop the highest-scoring samples and shuffle what remains.

    When ``remove_outliers`` is truthy:

    1. Samples are ordered by ascending ``logP_tau``.
    2. The ``n_drop`` samples with the LARGEST values are discarded.
    3. The survivors are shuffled with the global ``np.random`` state, so
       results depend on any seed set by the caller.

    When ``remove_outliers`` is falsy, the inputs are returned unchanged (the
    same objects).

    Args:
        mu_tau: array indexed by sample along axis 0 (indexed as
            ``mu_tau[idx, :]``, so at least 2-D).
        logP_tau: per-sample scores aligned with ``mu_tau`` along axis 0.
            This assumes a 1-D array; TODO confirm against callers.
        remove_outliers: whether to filter at all.
        n_drop: number of top-scoring samples to discard (default 10, the
            previously hard-coded value). If it is greater than or equal to
            the number of samples, the result is empty.

    Returns:
        tuple: ``(mu, logP)``. In the filtering branch these are new arrays and
        ``logP`` is squeezed.
    """
    if not remove_outliers:
        return mu_tau, logP_tau

    # Use an explicit stop index: ``[:-n_drop]`` would be wrong for n_drop == 0,
    # and for short inputs a negative stop could keep too many rows.
    keep = max(len(logP_tau) - n_drop, 0)
    kept_indices = np.argsort(logP_tau)[:keep]
    kept_logP = logP_tau[kept_indices]
    kept_mu = mu_tau[kept_indices, :]

    # Shuffle rows of both arrays with the same permutation.
    perm = np.arange(len(kept_logP))
    np.random.shuffle(perm)
    kept_mu = kept_mu[perm]
    kept_logP = kept_logP[perm]
    return np.array(kept_mu), np.array(kept_logP).squeeze()


def _drop_redundant_features(X, y, feats_names, correlation_threshold=0.95):
    """Clean inf/NaN values and drop constant and highly inter-correlated columns.

    For every pair of columns whose absolute Pearson correlation exceeds
    ``correlation_threshold``, the one less correlated with ``y`` is dropped.
    Returns ``(X, y, feats_names)`` restricted to the surviving columns, with
    ``y`` cleaned of inf/NaN as a numpy array.
    """
    import pandas as pd

    df = pd.DataFrame(X, columns=feats_names)

    # Map +/-inf to NaN, then NaN to 0, so variances and correlations are finite.
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    df.fillna(0, inplace=True)
    y = pd.Series(y).replace([np.inf, -np.inf], np.nan).fillna(0).values

    # Remove zero-variance features.
    df = df.loc[:, df.var() != 0]
    feats_names = df.columns.tolist()

    correlation_matrix = df.corr().abs()
    correlation_matrix.fillna(0, inplace=True)

    # Keep only the strict upper triangle so each pair is considered once.
    upper_triangle = correlation_matrix.where(np.triu(np.ones(correlation_matrix.shape), k=1).astype(bool))
    high_correlation_pairs = [
        (column, index)
        for column in upper_triangle.columns
        for index in upper_triangle.index
        if upper_triangle.loc[index, column] > correlation_threshold
    ]

    target_correlation = df.apply(lambda x: x.corr(pd.Series(y)), axis=0).abs()
    to_drop = set()
    for col1, col2 in high_correlation_pairs:
        # Keep whichever member of the pair explains the target better.
        if target_correlation[col1] >= target_correlation[col2]:
            to_drop.add(col2)
        else:
            to_drop.add(col1)

    df = df.drop(columns=list(to_drop))
    feats_names = [feat for feat in feats_names if feat not in to_drop]
    return df.values, y, feats_names


def select_top_features(X, y, Xe, ye, all_variables, feats_names, score_T=0.7, n_selected=20, drop_features=False, env_name=None, notex=None, save=None, method=1, verbose=False, folder_path=None):
    """Select the most informative feature columns of ``X`` for predicting ``y``.

    Methods:
        1: 10-fold CV (shuffled, ``random_state=42``). In each fold, features
           whose F-statistic divided by the fold maximum exceeds ``score_T``
           are recorded. The ``n_selected`` most frequently recorded features
           are kept.
        2: Keep the 50 features with the largest F-statistic, then run RFE
           with a Ridge estimator down to ``n_selected`` features.
        3: Keep the ``n_selected`` features with the largest F-statistic.

    Args:
        X, y: data used for selection.
        Xe, ye: data used only for the verbose regression report; columns
            are indexed with the ORIGINAL feature positions.
        all_variables: forwarded to ``save_selected_basis_functions``.
        feats_names: list of column names of ``X``. It is copied, not mutated.
        score_T: normalized F-score threshold (method 1 only).
        n_selected: number of features to keep.
        drop_features: if truthy, first clean inf/NaN values and drop constant
            and highly (>0.95) inter-correlated columns.
        env_name, notex: used for the save path and the report label.
        save: if truthy, persist the selected basis functions.
        method: 1, 2 or 3 (see above).
        verbose: if truthy, fit/plot a regression on ``Xe``/``ye`` and print R^2.
        folder_path: forwarded to ``plot_regression_results``.

    Returns:
        Indices of the selected features into the feature list used for
        selection. When ``drop_features`` is truthy this is the REDUCED list,
        not the original ``feats_names``.

    Raises:
        ValueError: if ``method`` is not 1, 2 or 3. Previously this surfaced
            later as an UnboundLocalError.
    """
    import pandas as pd  # used below; harmless if also provided by a star import

    if method not in (1, 2, 3):
        raise ValueError(f"method must be 1, 2 or 3, got {method!r}")

    original_feats_names = feats_names.copy()

    if drop_features:
        X, y, feats_names = _drop_redundant_features(X, y, feats_names)

    num_feats = len(feats_names)

    if method == 1:
        kf = KFold(n_splits=10, shuffle=True, random_state=42)
        selected_features = []
        for train_index, test_index in kf.split(X):
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]

            f_test_cv, _ = f_regression(X_train, y_train)
            # Raw F-statistics of the LAST fold are kept for the report below.
            f_test_cv_orig = f_test_cv.copy()
            f_test_cv = np.abs(f_test_cv)
            f_test_cv = (f_test_cv / np.max(f_test_cv)).astype(float)

            ind_top_cv = [i for i in range(num_feats) if f_test_cv[i] > score_T]
            selected_features.extend(ind_top_cv)

        selected_features_counts = pd.Series(selected_features).value_counts()

        top_features_indices = selected_features_counts.head(n_selected).index.to_list()
        important_features = [feats_names[i] for i in top_features_indices]
        # F-statistics are non-negative; this is a relative score, not a sign.
        sign_importance = [f_test_cv_orig[i] for i in top_features_indices]
        print("    Top important features based on F-test and cross-validation:")
        print(selected_features_counts.head(n_selected))
        for feature, i, s in zip(important_features, top_features_indices, sign_importance):
            print(i, feature, s/np.max(f_test_cv_orig))

        ind_top = top_features_indices

    elif method == 2:
        f_test, _ = f_regression(X, y)
        f_test = np.abs(f_test)
        f_test = (f_test / np.max(f_test)).astype(float)

        # Pre-filter to the 50 best univariate features before the costlier RFE.
        filtered_features_indices = np.argsort(f_test)[::-1][:50]

        model = linear_model.Ridge()
        selector = RFE(estimator=model, n_features_to_select=n_selected, step=1)
        selector = selector.fit(X[:, filtered_features_indices], y)
        selected_indices = np.where(selector.support_)[0]

        ind_top = filtered_features_indices[selected_indices]
        important_features = [feats_names[i] for i in ind_top]

    else:  # method == 3
        f_test, _ = f_regression(X, y)
        f_test = np.abs(f_test)
        f_test = (f_test / np.max(f_test)).astype(float)

        filtered_features_indices = np.argsort(f_test)[::-1][:n_selected]
        important_features = [feats_names[i] for i in filtered_features_indices]

        ind_top = filtered_features_indices

    # Map selected names back to their positions in the ORIGINAL feature list.
    ind_top_original = [original_feats_names.index(feat) for feat in important_features]

    selected_basis = [original_feats_names[i] for i in ind_top_original]
    if save:
        print(f"    Saving basis functions: {selected_basis}")
        save_selected_basis_functions(selected_basis, all_variables, env_name, notex)

    if verbose:
        r2_train, r2_test = plot_regression_results(Xe, ye, ind_top_original, verbose, folder_path)
        print(f"\n  [{notex}] - ({len(ind_top)}) / ({len(feats_names)})- R2 train/test: {r2_train:.3f}/{r2_test:.3f}")

    return ind_top


def plot_regression_results(X, y, ind_top, verbose, folder_path):
    """Fit a LassoCV regressor on selected columns of ``X`` and report R^2.

    The data is split without shuffling (``test_size=1/5``), so the test set is
    the trailing fifth of the rows. When ``verbose`` is truthy, sorted true and
    predicted values are plotted for both splits. The signs of the Pearson
    correlations between each selected training column and ``y_train`` are
    always printed.

    Args:
        X: feature matrix (converted with ``np.array``); 2-D, indexed as
            ``X[:, ind_top]``.
        y: targets; converted with ``np.array`` and squeezed.
        ind_top: column indices of ``X`` to use.
        verbose: if truthy, show the diagnostic plot.
        folder_path: currently unused (figure saving is disabled).

    Returns:
        tuple: ``(r2_train, r2_test)``.
    """
    X = np.array(X)
    y = np.array(y).squeeze()

    regr = linear_model.LassoCV()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/5, random_state=85, shuffle=False)
    selected_train = X_train[:, ind_top].astype(np.float32)
    selected_test = X_test[:, ind_top].astype(np.float32)

    regr.fit(selected_train, y_train.astype(np.float32))
    y_pred_train = regr.predict(selected_train)
    y_pred_test = regr.predict(selected_test)

    # Compute each score once; reused for the titles and the return value.
    r2_train = r2_score(y_train, y_pred_train)
    r2_test = r2_score(y_test, y_pred_test)

    if verbose:
        fig, ax = plt.subplots(1, 2, figsize=(16, 8))
        ax[0].set_title('Train data - R score: %.2f' % r2_train)
        ax[1].set_title('Test data - R score: %.2f' % r2_test)

        sorted_indices_train = np.argsort(y_train)
        sorted_indices_test = np.argsort(y_test)

        ax[0].plot(y_train[sorted_indices_train], 'o', color="black", label='y_train')
        ax[0].plot(y_pred_train[sorted_indices_train], color="green", label='y_pred_train')
        ax[0].legend()

        ax[1].plot(y_test[sorted_indices_test], 'o', color="black", label='y_test')
        ax[1].plot(y_pred_test[sorted_indices_test], color="green", label='y_pred_test')
        ax[1].legend()

        # plt.savefig(f'{folder_path}/regression_results.png')
        plt.show()
        # Outside interactive mode, show() has either blocked until the window
        # was closed or (with a non-GUI backend such as Agg) returned at once.
        # Close explicitly so repeated calls don't accumulate open figures.
        if not plt.isinteractive():
            plt.close(fig)

    # rowvar=False: the columns of selected_train are the variables; the last
    # row/column of the result corresponds to y_train.
    correlations = np.corrcoef(selected_train, y_train, rowvar=False)[-1, :-1]
    signs = np.sign(correlations)
    print("    Correlations of features with the target:", signs)

    return r2_train, r2_test

def save_selected_basis_functions(selected_basis, all_variables, env_name, notex):
    """Persist ``selected_basis`` via ``save_symbolic_basis_functions_joblib``.

    The target file is ``tmp/<env_name>/<notex>/<env_name>_basis_<notex>.joblib``.
    Whether the directory is created is up to the helper (not visible here).
    """
    run_dir = f'tmp/{env_name}/{notex}'
    save_symbolic_basis_functions_joblib(
        selected_basis,
        all_variables,
        filepath=f'{run_dir}/{env_name}_basis_{notex}.joblib',
    )

def save_array_plot(array, filename):
    """Plot ``array`` as dots on a 16x8-inch figure and save it to ``filename``.

    The figure is always closed afterwards, even if saving fails. This keeps
    repeated calls from accumulating open pyplot figures.
    """
    fig = plt.figure(figsize=(16, 8))
    try:
        plt.plot(array, '.')
        plt.savefig(filename)
    finally:
        plt.close(fig)


def cross_validation_kde(env_name, normalize, num_chunks, gamma, portions=(0.5,), sample_sizes=(10000,)):
    """Compare the state KDE fit on all data with KDEs fit on data prefixes.

    For each fraction ``p`` in ``portions``, a state KDE is fit on the first
    ``int(p * len(trajs))`` trajectories. For each ``n`` in ``sample_sizes``,
    ``n`` samples are drawn from it and from the full-data KDE, and the
    ``ot.emd2`` transport cost between the two sample sets is printed. The
    cost uses the ``ot.dist`` cost matrix and empty weight lists, which POT
    treats as uniform weights.

    Args:
        env_name, normalize, num_chunks, gamma: forwarded to ``load_data``.
        portions: data fractions to test. The default ``(0.5,)`` matches the
            previously hard-coded value.
        sample_sizes: numbers of samples to draw. The default ``(10000,)``
            matches the previously hard-coded value.

    Returns:
        dict: maps ``(portion, n)`` to the OT cost, or ``None`` when the OT
        computation raised. The error is printed instead of being silently
        swallowed.
    """
    print(os.getcwd())

    trajs, rewards, expert_ts, non_expert_trajs, non_expert_ts = load_data(env_name, normalize, num_chunks, gamma)

    # The full-data KDE does not depend on the portion, so fit it only once.
    kde_states, _ = fit_kde(trajs)

    distances = {}
    for data_portion in portions:
        print(f'------- Using {data_portion} of the data')
        kde_states_small, _ = fit_kde(trajs[:int(data_portion*len(trajs))])

        for n in sample_sizes:
            # Sample n points from both KDEs and compare the two samples.
            sample_states = kde_states.sample(n)
            sample_states_small = kde_states_small.sample(n)
            print('shape: ', sample_states.shape)
            start_time = time.time()
            try:
                M = ot.dist(sample_states, sample_states_small)
                dist = ot.emd2([], [], M, numItermax=1000000, numThreads="max")
            except Exception as e:
                # Best-effort diagnostic: report the failure and continue,
                # rather than printing an undefined or stale distance.
                dist = None
                print(f"OT computation failed for n={n}: {e!r}")
            print("Time for OT: {:.2f} seconds".format(time.time() - start_time))
            if dist is not None:
                print(f"Distance between two samples: {dist} for n={n}")
            distances[(data_portion, n)] = dist

    # NOTE: the same comparison can be run on the successor-state KDE (the
    # second value returned by fit_kde); a previous commented-out version used
    # sample sizes 5000/10000 with numItermax=500000 and numThreads=10.
    return distances
    

