import math
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import torch
import numpy as np 
import tqdm 
import copy 
import pandas as pd 

def _repeat_kv_heads(head_accs, num_groups):
    """Return a copy of ``head_accs`` with every key/value head repeated ``num_groups`` times.

    The i-th key of ``head_accs`` (in iteration order) is written to integer
    keys ``i * num_groups`` .. ``i * num_groups + num_groups - 1`` of the copy,
    so integer key ``h`` holds the entry of the ``h // num_groups``-th key.
    Keys of ``head_accs`` that are not overwritten are kept as they were.
    Presumably this lines grouped-query KV heads up with the query heads that
    share them (consecutive sharing) -- TODO confirm per model.
    """
    expanded = copy.deepcopy(head_accs)
    kv_keys = list(head_accs.keys())
    for new_head in range(len(kv_keys) * num_groups):
        expanded[new_head] = head_accs[kv_keys[new_head // num_groups]]
    return expanded


def _test_accuracy_frame(op_accs):
    """Return a DataFrame holding the ``'test'`` entry of every cell of ``op_accs``.

    ``op_accs`` is handed to ``pd.DataFrame`` unchanged, so for a
    ``{layer: {head: cell}}`` mapping the columns are layers and the index
    is heads.
    """
    frame = pd.DataFrame(op_accs)
    # DataFrame.applymap was renamed to DataFrame.map in pandas 2.1 and now
    # emits a FutureWarning; use the new name when it exists.
    elementwise = frame.map if hasattr(pd.DataFrame, "map") else frame.applymap
    return elementwise(lambda cell: cell['test'])


def _fraction_at_or_above(values, thresholds):
    """Return, for each threshold ``T``, the fraction of ``values`` that are >= T.

    ``values`` is flattened first; NaN entries count towards the denominator
    but never satisfy ``>= T``.
    """
    flat = np.asarray(values).ravel()
    return [np.count_nonzero(flat >= t) / flat.size for t in thresholds]


def Heatmaps_And_Ratios(HS, num_groups, layers, model_choice, dataset_choice):
    """Plot per-head test accuracy as heatmaps plus an accuracy-threshold curve.

    Draws a 1x5 figure:
      * column 0: for each representation (Query, Value, Key, Attention Head
        Output), the fraction of (layer, head) cells whose test accuracy is
        >= T, for T swept over 100 points in [0.5, 1];
      * columns 1-4: (layer x head) heatmaps of test accuracy for each
        representation, sharing one colorbar over [0.5, 1].

    Parameters
    ----------
    HS : dict
        Maps a classifier type to a dict whose ``'accs'`` entry is indexed as
        ``accs[op][layer][head]`` for op in ``'q', 'v', 'k', 'attn_output'``;
        each cell must support ``cell['test']``. NOTE(review): only the LAST
        key of ``HS`` is plotted because the figure has a single row.
    num_groups : int
        How many times each value/key head is repeated so that the ``'v'``/
        ``'k'`` heatmaps have as many head columns as ``'q'``.
    layers : iterable
        Layers of ``accs['v']`` / ``accs['k']`` to expand. Presumably this must
        cover every layer present there -- un-expanded layers would leave
        NaN cells, and the ``'test'`` lookup fails on NaN. TODO confirm.
    model_choice, dataset_choice : str
        Shown in the figure title.

    Raises
    ------
    ValueError
        If ``HS`` is empty.
    """
    if not HS:
        raise ValueError("HS is empty: nothing to plot")

    ops = ['q', 'v', 'k', 'attn_output']
    op_to_name = {"q": "Query", "v": "Value", "attn_output": "Attention Head Output", "k": "Key"}
    colors = ['#1f77b4', '#ff7f0e', '#d62728', '#2ca02c']
    thresholds = np.linspace(.5, 1, num=100)

    # Single-row figure: only the last classifier type is shown.
    cls_type = list(HS.keys())[-1]
    accs = copy.deepcopy(HS[cls_type]['accs'])  # never mutate the caller's HS

    # Repeat value/key heads so all four heatmaps share the same head axis
    # (presumably grouped-query attention).
    for layer in layers:
        for op in ['v', 'k']:
            accs[op][layer] = _repeat_kv_heads(accs[op][layer], num_groups)

    dfs_test = {op: _test_accuracy_frame(accs[op]) for op in ops}

    fig, axes = plt.subplots(1, 5, figsize=(8 * 5, 8))
    threshold_ax, heatmap_axes = axes[0], axes[1:]

    # Columns 1-4: transpose so layers run along y (first row at the bottom)
    # and heads along x.
    images = []
    for ax, op in zip(heatmap_axes, ops):
        images.append(ax.imshow(dfs_test[op].T, origin='lower', aspect='auto', cmap='coolwarm', vmin=.5, vmax=1))
        ax.set_title(op_to_name[op], fontweight="bold", fontsize=30)
        ax.set_xlabel("Head", fontweight='bold', fontsize=20)
    heatmap_axes[0].set_ylabel("Layer", fontweight='bold', fontsize=20)

    # Column 0: fraction of cells at or above each accuracy threshold.
    for op, color in zip(ops, colors):
        threshold_ax.plot(thresholds, _fraction_at_or_above(dfs_test[op].values, thresholds),
                          label=op_to_name[op], color=color)
    threshold_ax.set_title("% Heads vs Acc. Thresh.", fontweight="bold", fontsize=30)
    threshold_ax.set_xlabel("Acc. Thresh. (T)", fontweight='bold', fontsize=20)
    threshold_ax.set_ylabel("% Heads (Acc >= T)", fontweight='bold', fontsize=20)

    for ax in axes:
        ax.tick_params(axis='both', labelsize=16)
        plt.setp([*ax.get_xticklabels(), *ax.get_yticklabels()], fontweight='bold')

    fig.suptitle(f'Representation Space Accuracy\n Model: {model_choice}, Data: {dataset_choice}', fontsize=30, fontweight='bold')
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave space for the suptitle

    # One shared colorbar is enough: every heatmap uses vmin=.5, vmax=1.
    cbar = fig.colorbar(images[0], ax=axes, location='right', shrink=1)
    cbar.ax.set_title("Acc", fontsize=30, fontweight='bold', pad=10)
    for label in cbar.ax.get_yticklabels():
        label.set_fontsize(20)
        label.set_fontweight('bold')

    threshold_ax.legend()



################################################################################################
##### Visualizes + Saves 2D PCA for each layer + method, plot resultant logreg acc + line  #####
################################################################################################
# TODO: make the attention-head index `h_idx` (formerly hard-coded as `h_idx = 0`; currently unused) a parameter of this method.
# NOTE: Deprecated in favor of the other PCA method.

def PCA_over_layers(results, layers=None, save_name="pirate"):
    """Plot a 2-D PCA of positive/negative activations per layer with a logistic-regression boundary.

    For every ``agg`` in ``results`` and every ``key`` in ``results[agg]``, one
    figure is drawn with one subplot per entry of ``layers``:
      * the ``'neg'`` (blue) and ``'pos'`` (red) activations of that layer are
        flattened per sample and projected onto their first two principal
        components;
      * a logistic regression is fit on the 2-D projection, and its decision
        boundary is drawn as a dotted green line;
      * the subplot title shows the layer and the regression's training
        accuracy.
    Each figure is saved to ``f"{save_name}_{key}_{agg}.png"``.

    NOTE: deprecated in favor of the other PCA method.

    Parameters
    ----------
    results : dict
        ``results[agg][key]['pos']`` / ``['neg']`` are torch tensors indexed as
        ``[sample, layer, ...]``. This function also WRITES into each
        ``results[agg][key]``:
          ``'pca train acc'``           -- accuracy for the last layer in ``layers``;
          ``'pca train acc per layer'`` -- ``{layer: accuracy}`` for every layer.
    layers : sequence of int, optional
        Layer indices to plot. Defaults to ``list(range(1, 32))``.
    save_name : str
        Filename prefix for the saved PNGs.

    Raises
    ------
    ValueError
        If ``layers`` is empty.
    """
    if layers is None:
        layers = list(range(1, 32))
    if len(layers) == 0:
        raise ValueError("layers must contain at least one layer index")

    n_cols = 6
    n_rows = math.ceil(len(layers) / n_cols)

    for agg in results.keys():  # e.g. mean or last
        for key in results[agg].keys():
            entry = results[agg][key]
            # squeeze=False keeps a 2-D array even for a single row, so ravel() is always valid.
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), squeeze=False)
            axes = axes.ravel()
            per_layer_acc = entry['pca train acc per layer'] = {}

            for layer_idx, layer in enumerate(tqdm.tqdm(layers)):
                ax = axes[layer_idx]
                n_pos, n_neg = len(entry['pos']), len(entry['neg'])
                pos = entry['pos'][:, layer, ...].reshape(n_pos, -1)
                neg = entry['neg'][:, layer, ...].reshape(n_neg, -1)
                # TODO: attention-head-specific indexing (select a single head instead of flattening all of them).

                # PCA of negatives (first n_neg rows) and positives stacked together.
                embeddings = torch.vstack((neg, pos)).detach().cpu().numpy()
                data_pca = PCA(n_components=2).fit_transform(embeddings)
                neg_2d, pos_2d = data_pca[:n_neg], data_pca[n_neg:]
                ax.scatter(neg_2d[:, 0], neg_2d[:, 1], c='blue', alpha=0.7, s=100)
                ax.scatter(pos_2d[:, 0], pos_2d[:, 1], c='red', alpha=0.7, s=100)

                # Logistic regression on the 2-D projection; accuracy is measured on the training points.
                labels = [0] * n_neg + [1] * n_pos
                clf = LogisticRegression().fit(data_pca, labels)
                accuracy = accuracy_score(labels, clf.predict(data_pca))
                per_layer_acc[layer] = accuracy
                entry['pca train acc'] = accuracy  # legacy key: ends up holding the last layer's accuracy

                # Decision boundary w0*x + w1*y + b = 0  =>  y = -(w0*x + b) / w1.
                w0, w1 = clf.coef_[0]
                b = clf.intercept_[0]
                x_values = np.linspace(data_pca[:, 0].min(), data_pca[:, 0].max(), 101)
                scatter_ylim = ax.get_ylim()
                # w1 == 0 means a vertical boundary, which y(x) cannot express; matplotlib skips non-finite points.
                with np.errstate(divide='ignore', invalid='ignore'):
                    y_values = -(w0 * x_values + b) / w1
                ax.plot(x_values, y_values, color='green', linestyle='dotted')
                ax.set_ylim(scatter_ylim)  # a steep boundary would otherwise squash the scatter

                ax.set_title(f"Layer {layer}, Acc {round(accuracy, 2)}", fontweight='bold')

            # Hide grid cells that received no layer.
            for ax in axes[len(layers):]:
                ax.set_axis_off()

            # Column 0 of the projection is on x, column 1 on y.
            axes[0].set_xlabel("PC 1", fontweight='bold')
            axes[0].set_ylabel("PC 2", fontweight='bold')
            fig.suptitle(f"{key}_{agg}")
            fig.subplots_adjust(hspace=0.5)  # extra vertical space between subplot rows

            fig.savefig(f"{save_name}_{key}_{agg}.png")