import math
import numpy as np
import pickle
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.gridspec import GridSpec
#plt.style.use('tableau-colorblind10')
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
from sklearn.utils import resample
import statsmodels.api as sm
import shap
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Import-time side effect: per the SHAP documentation this loads the JavaScript
# assets used by SHAP's interactive notebook plots.
# NOTE(review): none of the functions visible in this file draw interactive plots — confirm it is needed.
shap.initjs()

def _persistence_entropy(bars):
    """Return the persistence entropy of one homology dimension (0 if no finite bars)."""
    if len(bars) == 0:
        return 0
    lifetimes = bars[:, 1] - bars[:, 0]
    # Bars that never die have infinite lifetime and are excluded.
    lifetimes = lifetimes[np.isfinite(lifetimes)]
    if len(lifetimes) == 0:
        return 0
    normalized_lifetimes = lifetimes / np.sum(lifetimes)
    # The small epsilon keeps log() finite for zero-length bars.
    return -np.sum(normalized_lifetimes * np.log(normalized_lifetimes + 1e-10))


def _seven_number_summary(values):
    """Return [mean, min, Q1, median, Q3, max, std] of `values`."""
    return [
        np.mean(values), np.min(values),
        np.quantile(values, 0.25), np.quantile(values, 0.5), np.quantile(values, 0.75),
        np.max(values), np.std(values),
    ]


def _barcode_features(barcode):
    """Build the feature vector of a single barcode (see compute_statistics for the layout)."""
    bars_0 = barcode[0]
    bars_1 = barcode[1]

    deaths_0 = bars_0[:, 1]
    deaths_0 = deaths_0[~np.isinf(deaths_0)]  # drop the essential (never-dying) 0-bars
    births_1 = bars_1[:, 0]
    deaths_1 = bars_1[:, 1]
    persistences_1 = deaths_1 - births_1
    # NOTE(review): a 1-bar born at 0 yields inf/nan here; left as in the original.
    ratios_1 = deaths_1 / births_1

    features = []
    for values in (deaths_0, births_1, deaths_1, persistences_1, ratios_1):
        features.extend(_seven_number_summary(values))
    features.extend([
        np.sum(deaths_0), np.sum(persistences_1),
        len(bars_0), len(bars_1),
        _persistence_entropy(bars_0), _persistence_entropy(bars_1),
    ])
    return np.asarray(features)


def compute_statistics(clean_barcodes_layers, poisoned_barcodes_layers, layers, feature_names):
    """Compute summary statistics of persistence barcodes for every layer.

    For each layer, one row is produced per barcode: clean barcodes first,
    then poisoned barcodes. Each row holds, in this order:

    * mean, min, Q1, median, Q3, max, std of the finite 0-dimensional deaths,
    * the same seven statistics for 1-dimensional births, deaths,
      persistences (death - birth) and ratios (death / birth),
    * sum of finite 0-deaths, sum of 1-persistences,
    * number of 0-bars, number of 1-bars,
    * persistence entropy in dimension 0 and in dimension 1.

    Args:
        clean_barcodes_layers: mapping layer -> sequence of barcodes, where
            ``barcode[d]`` is an (n, 2) array of (birth, death) pairs for
            dimension ``d`` in {0, 1}.
        poisoned_barcodes_layers: same structure for the poisoned samples.
        layers: iterable of layer keys to process.
        feature_names: column names; its length must equal the number of
            features produced per barcode (41).

    Returns:
        dict mapping each layer to a ``pandas.DataFrame`` of shape
        (n_clean + n_poisoned, len(feature_names)).
    """
    statistics_layers_full = {}

    for layer in layers:
        clean_barcodes = clean_barcodes_layers[layer]
        poisoned_barcodes = poisoned_barcodes_layers[layer]
        n_clean = len(clean_barcodes)
        statistics = np.zeros((n_clean + len(poisoned_barcodes), len(feature_names)))

        for i, barcode in enumerate(clean_barcodes):
            statistics[i, :] = _barcode_features(barcode)

        # Poisoned rows go through the same helper so their persistence entropy
        # is computed from their own bars (previously the values of the last
        # clean barcode were reused).
        for i, barcode in enumerate(poisoned_barcodes):
            statistics[n_clean + i, :] = _barcode_features(barcode)

        statistics_layers_full[layer] = pd.DataFrame(statistics, columns=feature_names)
    return statistics_layers_full

def plot_crossed_correlations_layers(statistics_layers_full, layers, model, distance):
    """Draw absolute feature-correlation heatmaps, three layers per figure.

    One figure is created per group of up to three layers; each is saved to
    ``images/{model}/`` and then shown. Only the final figure gets a colorbar,
    taken from its last heatmap.

    Args:
        statistics_layers_full: mapping layer -> DataFrame of features.
        layers: sequence of layer keys (titles display ``layer + 1``).
        model: model name, used in the output path.
        distance: distance name, used in the output file name.
    """
    per_figure = 3
    layer_groups = [layers[start:start + per_figure] for start in range(0, len(layers), per_figure)]
    final_group = len(layer_groups) - 1

    for group_idx, group in enumerate(layer_groups):
        n_panels = len(group)
        fig = plt.figure(figsize=(4.2 if n_panels == 3 else 2.8, 1.6))

        colorbar_source = None
        for panel_idx, layer in enumerate(group):
            ax = fig.add_subplot(1, n_panels, panel_idx + 1)
            abs_corr = np.abs(statistics_layers_full[layer].corr())
            drawn = sns.heatmap(
                abs_corr,
                ax=ax,
                alpha=0.9,
                cmap='viridis',
                cbar=False,  # per-panel colorbars are suppressed
                xticklabels=False,
                yticklabels=False,
            )
            ax.set_aspect('equal')
            ax.set_title(f"Layer {layer + 1}", fontsize=10, verticalalignment='bottom')

            # Remember the very last heatmap of the very last figure.
            is_last_panel = panel_idx == n_panels - 1
            if group_idx == final_group and is_last_panel:
                colorbar_source = drawn

        plt.tight_layout()
        if colorbar_source:
            # Shrink the panels to free a strip on the right for the colorbar.
            fig.subplots_adjust(right=0.83)
            colorbar_axes = fig.add_axes([0.85, 0.15, 0.03, 0.6])  # [left, bottom, width, height]
            colorbar = fig.colorbar(colorbar_source.collections[0], cax=colorbar_axes, orientation='vertical')
            colorbar.ax.tick_params(labelsize=6)

        plt.savefig(f'images/{model}/cross_correlation_distance_{distance}_model_{model}_{group_idx}.png', transparent=True, dpi=300)
        plt.show()



def drop_features_high_correlation(statistics_layers_full, layers, model, distance, threshold_corr=0.95, verbose=False):
    """Drop, per layer, every feature highly correlated with an earlier one.

    For each layer the correlation matrix of the features is computed and a
    column is dropped when its correlation with any column to its left
    exceeds ``threshold_corr``. A text report of dropped and kept features is
    written to ``computed_data/used_features_{model}_{distance}.txt``
    (overwriting any previous report; the directory is created if missing).

    The columns are dropped **in place**: the DataFrames in
    ``statistics_layers_full`` are modified and the returned dict refers to
    those same objects. This is kept to preserve existing behaviour.

    Args:
        statistics_layers_full: mapping layer -> DataFrame of features.
        layers: iterable of layer keys to process.
        model: model name, used in the report file name and header.
        distance: distance name, used in the report file name and header.
        threshold_corr: correlation above which a feature is dropped.
        verbose: when True, print a per-layer summary.

    Returns:
        dict mapping each layer to its reduced DataFrame.
    """
    statistics_layers_short = {}

    report_path = f"computed_data/used_features_{model}_{distance}.txt"
    # Make sure the output directory exists instead of failing on open().
    os.makedirs(os.path.dirname(report_path), exist_ok=True)

    # Mode "w" truncates any pre-existing report, so no explicit removal is needed.
    with open(report_path, "w") as file:
        file.write(f"Features used in the analysis of {model} with {distance} distance:\n")

        for layer in layers:
            data = statistics_layers_full[layer]
            correlation_matrix = data.corr()

            # Keep only the strict upper triangle so each pair is examined once
            # and the first feature of a correlated pair is the one retained.
            upper = correlation_matrix.where(np.triu(np.ones(correlation_matrix.shape), k=1).astype(bool))

            # NOTE(review): the comparison is on signed correlation, so strongly
            # negative correlations are not treated as redundant — confirm intent.
            to_drop = [column for column in upper.columns if (upper[column] > threshold_corr).any()]

            data.drop(to_drop, axis=1, inplace=True)
            statistics_layers_short[layer] = data
            remaining = data.columns.tolist()

            if verbose:
                print(f'--- Dropping {len(to_drop)} features in layer {layer} with correlation over {threshold_corr}')
                print(f"{len(remaining)} features remain: {remaining}")
                # Diagnostic on one reference feature; skipped when that feature
                # is not part of the data instead of raising KeyError.
                reference_feature = 'Mean deaths 0-bars'
                if reference_feature in correlation_matrix.columns:
                    print(correlation_matrix[reference_feature] > 0.9)

            file.write(f'Dropping {len(to_drop)} features in layer {layer}: {to_drop}\n')
            file.write(f"Features used in layer {layer}: {remaining}\n")
    return statistics_layers_short

def PCA_analysis(statistics_layers, num_subsamples, size_subsamples, distance, model, layers, experiment, verbose=False):
    """Project each layer's standardized features onto two principal components and plot them.

    A ``'Class'`` column (0 for the first ``num_subsamples`` rows, 1 for the
    next ``num_subsamples`` rows) is written into every layer's DataFrame in
    place. All columns but the last are standardized, reduced with a
    2-component PCA and drawn as a scatter plot, three layers per figure.
    Each figure is saved under ``images/{model}/`` and shown.

    Args:
        statistics_layers: mapping layer -> DataFrame with
            ``2 * num_subsamples`` rows.
        num_subsamples: number of rows per class.
        size_subsamples: only used in the output file name.
        distance: only used in the output file name.
        model: model name, used in the output path.
        layers: sequence of layer keys (titles display ``layer + 1``).
        experiment: only used in the output file name.
        verbose: when True, print first-component loadings and the total
            explained variance ratio.

    Returns:
        dict mapping each layer to its standardized feature array.
    """
    layer_groups = [layers[start:start + 3] for start in range(0, len(layers), 3)]
    class_labels = [0] * num_subsamples + [1] * num_subsamples
    scatter_specs = (
        (slice(None, num_subsamples), 'blue', 'clean', 'o'),
        (slice(num_subsamples, None), 'orange', 'poisoned', 'P'),
    )

    X_scaled_list = {}
    for group_idx, group in enumerate(layer_groups):
        width = 4.8 if len(group) == 3 else 3.2
        fig = plt.figure(figsize=(width, 1.6))

        for panel, layer in enumerate(group, start=1):
            frame = statistics_layers[layer]
            frame['Class'] = class_labels  # stored in the caller's DataFrame

            # Everything except the last column is treated as a feature.
            X = frame.iloc[:, :-1]
            X_scaled = StandardScaler().fit_transform(X)
            X_scaled_list[layer] = X_scaled

            pca = PCA(n_components=2, whiten=False, random_state=42)
            X_pca = pca.fit_transform(X_scaled)

            ax = fig.add_subplot(1, len(group), panel)
            for rows, colour, name, marker in scatter_specs:
                ax.scatter(X_pca[rows, 0], X_pca[rows, 1],
                           color=colour, label=name, marker=marker, s=0.7, alpha=0.8)

            # Hide ticks and thin the frame.
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_linewidth(0.5)

            ax.legend(fontsize=8)
            ax.set_title(f'Layer {layer+1}', fontsize=10)

            if not verbose:
                continue
            first_axis = pd.DataFrame(pca.components_, columns=X.columns).iloc[0]
            print(f"Layer {layer}:")
            print("x-axis PCA Components (Directions):")
            print(np.abs(first_axis).sort_values())
            print("Explained Variance Ratio (Contribution):")
            print(np.sum(pca.explained_variance_ratio_))
            print("\n")

        plt.tight_layout()
        plt.savefig(f'images/{model}/PCA_statistics_{experiment}_N_subsamples_{num_subsamples}_size_{size_subsamples}_distance_{distance}_model_{model}_{group_idx}.png',
                    transparent=True, dpi=300)
        plt.show()
    return X_scaled_list

def CCA_analysis(statistics_layers, X_scaled_list, num_subsamples, size_subsamples, distance, model, layers, experiment, verbose=False):
    """Rank features by first-component CCA loadings and plot the top five per layer.

    For every layer, all columns but the last of its DataFrame are
    standardized and a 2-component CCA is fitted against
    ``X_scaled_list[layer]``. The absolute loadings of the first canonical
    component are used as importances. The CCA is refitted on 1000 bootstrap
    resamples to obtain the standard deviation (drawn as error bars) and a
    sign-based two-sided p-value for each weight and loading. Layers are drawn
    three per figure; each figure is saved under ``images/{model}/`` and shown.

    Args:
        statistics_layers: mapping layer -> DataFrame; the last column is
            excluded (presumably the 'Class' column — verify against caller).
        X_scaled_list: mapping layer -> array used as the second CCA view.
        num_subsamples: only used in the output file name.
        size_subsamples: only used in the output file name.
        distance: only used in the output file name.
        model: model name, used in the output path.
        layers: sequence of layer keys (titles display ``layer + 1``).
        experiment: only used in the output file name.
        verbose: when True, print weights, loadings, deviations and p-values.
    """
    n_bootstraps = 1000
    max_bars = 5

    def sign_p_values(samples):
        # Twice the smaller of the positive / negative share of the bootstrap draws.
        return [2 * min((samples[:, k] > 0).mean(), (samples[:, k] < 0).mean())
                for k in range(samples.shape[1])]

    layer_groups = [layers[start:start + 3] for start in range(0, len(layers), 3)]

    for group_idx, group in enumerate(layer_groups):
        width = 4.8 if len(group) == 3 else 3.6
        fig = plt.figure(figsize=(width, 3.6))

        for layer_idx, layer in enumerate(group):
            frame = statistics_layers[layer]
            y = X_scaled_list[layer]

            X = frame.iloc[:, :-1]
            X_scaled = StandardScaler().fit_transform(X)
            n_features_X = X_scaled.shape[1]

            cca = CCA(n_components=2)
            cca.fit(X_scaled, y)
            X_c, y_c = cca.transform(X_scaled, y)

            # Importances come from the fit on the full data, before bootstrapping.
            canonical_weights_X = np.abs(cca.x_weights_[:, 0])
            canonical_loadings_X = np.abs(cca.x_loadings_[:, 0])

            bootstrap_weights_X = np.zeros((n_bootstraps, n_features_X))
            bootstrap_loadings_X = np.zeros((n_bootstraps, n_features_X))
            for draw in range(n_bootstraps):
                X_resampled, Y_resampled = resample(X_scaled, y)
                cca.fit(X_resampled, Y_resampled)
                bootstrap_weights_X[draw, :] = cca.x_weights_[:, 0]
                bootstrap_loadings_X[draw, :] = cca.x_loadings_[:, 0]

            std_weights_X = bootstrap_weights_X.std(axis=0)
            std_loadings_X = bootstrap_loadings_X.std(axis=0)
            p_values_weights = sign_p_values(bootstrap_weights_X)
            p_values_loadings = sign_p_values(bootstrap_loadings_X)

            feature_importance_df = pd.DataFrame({
                'Feature': X.columns,
                'Importance': canonical_loadings_X,
                'Deviations': std_loadings_X,
                'P-values': p_values_loadings,
            })
            feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)

            if verbose:
                print("Canonical Weights (X):\n", canonical_weights_X)
                print("Canonical Loadings (X):\n", canonical_loadings_X)
                print("\nFeatures importance:", feature_importance_df)
                print("\nBootstrap Std Dev of Weights (X):", std_weights_X)
                print("Bootstrap Std Dev of Loadings (X):", std_loadings_X)
                print("\nP-Values for Weights (X):", p_values_weights)
                print("P-Values for Loadings (X):", p_values_loadings)

            # Show at most five bars, fewer when the layer has fewer features.
            num_features = min(max_bars, X.shape[1])
            top_names = feature_importance_df['Feature'].iloc[:num_features]
            top_scores = feature_importance_df['Importance'].iloc[:num_features]
            top_spread = feature_importance_df['Deviations'].iloc[:num_features]

            ax = fig.add_subplot(1, len(group), layer_idx + 1)
            ax.bar(top_names, top_scores, color='#fa2a55')
            ax.errorbar(top_names, top_scores, top_spread, fmt='none', color='Black', capsize=1)
            ax.set_xticks(range(num_features))
            ax.set_xticklabels(top_names, rotation=90, fontsize=8)
            ax.set_ylim(0, 1.1)
            if layer_idx != 0:
                # Only the left-most panel keeps its y ticks.
                ax.set_yticks([])
            ax.tick_params(axis='y', labelsize=6)

            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_linewidth(0.5)
            ax.set_title(f'Layer {layer+1}', fontsize=10)
            if layer_idx == 0:
                ax.set_ylabel('Loadings from CCA', fontsize=8)

        plt.tight_layout()
        plt.savefig(f'images/{model}/CCA_statistics_{experiment}_N_subsamples_{num_subsamples}_size_{size_subsamples}_distance_{distance}_model_{model}_{group_idx}.png', transparent=True, dpi=300)
        plt.show()

def regression_analysis(statistics_layers, num_subsamples, size_subsamples, distance, model, layers, experiment, verbose=False):
    """Fit a logistic regression per layer and plot its predicted probabilities in PCA space.

    For each layer all columns but the last are standardized, split 70/30
    (``random_state=42``) and used to train a ``LogisticRegression`` on the
    ``'Class'`` column. Every sample is then drawn at its 2-component PCA
    coordinates, coloured by the predicted probability of class 1, with test
    accuracy, test AUC-ROC and mean 5-fold CV score on the training split in
    a text box below the panel. Layers are drawn three per figure; only the
    final figure gets a colorbar. Each figure is saved under
    ``images/{model}/`` and shown.

    Args:
        statistics_layers: mapping layer -> DataFrame containing a
            ``'Class'`` column; the last column is excluded from the features.
        num_subsamples: only used in the output file name.
        size_subsamples: only used in the output file name.
        distance: only used in the output file name.
        model: model name, used in the output path.
        layers: sequence of layer keys (titles display ``layer + 1``).
        experiment: only used in the output file name.
        verbose: when True, print exponentiated coefficients and a
            classification report.
    """
    layer_groups = [layers[start:start + 3] for start in range(0, len(layers), 3)]
    final_group = len(layer_groups) - 1
    probability_cmap = LinearSegmentedColormap.from_list("custom_blue_orange", ['blue', 'orange'])
    box_style = dict(boxstyle="round", facecolor="white", alpha=0.8, linewidth=0.5)

    last_heatmap = None
    for group_idx, group in enumerate(layer_groups):
        fig = plt.figure(figsize=(5.2, 2.4))

        for panel_idx, layer in enumerate(group):
            frame = statistics_layers[layer]
            X = frame.iloc[:, :-1]
            y = frame['Class']

            # Standardize on the full data, then hold out 30% for testing.
            X_scaled = StandardScaler().fit_transform(X)
            X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)

            regression = LogisticRegression()
            regression.fit(X_train, y_train)

            if verbose:
                odds_ratios = pd.DataFrame(np.exp(regression.coef_), columns=X.columns)
                print(f"Coefficients for layer {layer}")
                print(odds_ratios.iloc[0])
                print(f"Classification report for layer {layer}")
                print(classification_report(y_test, regression.predict(X_test)))

            y_pred = regression.predict(X_test)
            y_pred_prob = regression.predict_proba(X_test)[:, 1]

            accuracy = accuracy_score(y_test, y_pred)
            auc = roc_auc_score(y_test, y_pred_prob)
            cv_score = np.mean(cross_val_score(regression, X_train, y_train, cv=5))

            X_pca = PCA(n_components=2, random_state=42).fit_transform(X_scaled)
            all_probabilities = regression.predict_proba(X_scaled)[:, 1]

            ax = fig.add_subplot(1, len(group), panel_idx + 1)
            scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], c=all_probabilities, s=0.7, cmap=probability_cmap, alpha=0.8)

            # The colorbar is built from the last panel of the last figure.
            if group_idx == final_group and panel_idx == len(group) - 1:
                last_heatmap = scatter

            ax.set_title(f'Layer {layer+1}', fontsize=10)
            ax.text(0.5, -0.1, f"Accuracy: {accuracy:.2f}\nAUC-ROC: {auc:.2f}\nCV: {cv_score:.2f}", transform=ax.transAxes,
                    fontsize=8, verticalalignment='top', horizontalalignment='center', bbox=box_style)

            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_linewidth(0.5)

        plt.tight_layout()
        # Leave room underneath the panels for the metric boxes.
        fig.subplots_adjust(bottom=0.3)
        if last_heatmap:
            fig.subplots_adjust(right=0.9)
            colorbar_axes = fig.add_axes([0.92, 0.3, 0.03, 0.6])  # [left, bottom, width, height]
            colorbar = fig.colorbar(last_heatmap, cax=colorbar_axes, orientation='vertical')
            colorbar.ax.tick_params(labelsize=6)

        plt.savefig(f'images/{model}/PCA_regression_{experiment}_N_subsamples_{num_subsamples}_size_{size_subsamples}_distance_{distance}_model_{model}_{group_idx}.png', transparent=True, dpi=300)
        plt.show()
def SHAP_analysis(statistics_layers, num_subsamples, size_subsamples, distance, model, layers, experiment):
    """Explain a per-layer logistic regression with SHAP beeswarm plots.

    For each layer all columns but the last are split 70/30
    (``random_state=42``), standardized with statistics from the training
    split, and used to train a ``LogisticRegression`` on the ``'Class'``
    column. SHAP values are computed on the test split and drawn as a
    beeswarm plot of the five most important features. Layers are drawn two
    per figure, each figure with one colorbar; figures are saved under
    ``images/{model}/`` and shown.

    Args:
        statistics_layers: mapping layer -> DataFrame containing a
            ``'Class'`` column; the last column is excluded from the features.
        num_subsamples: unused.
        size_subsamples: unused.
        distance: only used in the output file name.
        model: model name, used in the output path.
        layers: sequence of layer keys (titles display ``layer + 1``).
        experiment: only used in the output file name.
    """
    # Group layers in chunks of 2
    layers_rows = [layers[i:i + 2] for i in range(0, len(layers), 2)]

    shap_layers = {}

    for i, layers_row in enumerate(layers_rows):
        # A row holding a single layer gets a much wider canvas.
        fig_width = 5.6 if len(layers_row) == 2 else 24
        fig = plt.figure(figsize=(fig_width, 1.2))

        for j, layer in enumerate(layers_row):
            print(f"- SHAP analysis of layer {layer}:")
            data = statistics_layers[layer]
            X = data.iloc[:, :-1]
            y = data['Class']

            # Split in train and test
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

            # Scale with training statistics only
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            # Train a logistic regression
            regression = LogisticRegression()
            regression.fit(X_train_scaled, y_train)

            # Explain the test split, using the training split as background data
            explainer = shap.Explainer(regression, X_train_scaled, link=shap.links.logit)
            shap_values = explainer(X_test_scaled)
            shap_layers[layer] = shap_values

            # The explainer only saw bare arrays: restore the column names and
            # the unscaled feature values so the plot uses the original names
            # and colours by the original values.
            for idx, feature in enumerate(X_test.columns):
                shap_values.feature_names[idx] = feature
            shap_values.data = X_test.values

            fig.add_subplot(1, len(layers_row), j + 1)
            ax = shap.plots.beeswarm(shap_values, show=False, color_bar=False, color=plt.get_cmap("viridis"), max_display=5)
            ax.set_xlabel('SHAP values')
            ax.set_title(f"Layer {layer+1}", fontsize=14)
            ax.set_xticks([])
            fig.subplots_adjust(left=0.35, bottom=0.1)

        # One colorbar per figure, built from the last beeswarm drawn in it.
        fig.subplots_adjust(left=0.35, bottom=0.1, right=0.8, wspace=6)  # Make space for colorbar
        cbar_ax = fig.add_axes([0.85, 0.2, 0.03, 0.6])  # [left, bottom, width, height]
        cbar = fig.colorbar(ax.collections[0], cax=cbar_ax, orientation='vertical')
        cbar.ax.tick_params(labelsize=6)
        cbar.set_label("Feature value", size=10)

        plt.savefig(f'images/{model}/SHAP_regression_statistics_{experiment}_model_{model}_distance_{distance}_{i}.png', transparent=True, dpi=300)
        plt.show()