
# %matplotlib inline
# %config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
import numpy as np

import torch
# import torch.nn as nn
# import torch.nn.functional as F

# do PCA on the embedding matrix, only the digit embeddings
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

import pandas as pd
import pickle
from glob import glob
import os

from scipy.optimize import curve_fit

def to_np(x):
    """Return tensor ``x`` as a NumPy array (detached from autograd, moved to the CPU)."""
    return x.detach().cpu().numpy()


# Short alias for to_np.
V = to_np

import sys
sys.path.append('../')


# Define the logistic function
# def logistic(x, L, k, x0,y0):
#     return L / (1 + np.exp(-k * (np.log10(x) - np.log10(x0)))) + y0
def logistic(x, L, k, x0,y0):
    """Sigmoid in log10(x) that runs between the offset ``y0`` and ``tanh(L)``.

    x:  value(s) at which to evaluate the curve (passed through ``np.log10``).
    L:  the upper level of the curve is ``tanh(L)``.
    k:  steepness, measured per decade of ``x``.
    x0: midpoint; at ``x == x0`` the curve is halfway between ``y0`` and ``tanh(L)``.
    y0: offset the sigmoid is added to.
    """
    span = np.tanh(L) - y0
    decades_from_mid = np.log10(x) - np.log10(x0)
    return span / (1 + np.exp(-k * decades_from_mid)) + y0

def invlogistic(y, L, k, x0, y0):
    """Solve ``logistic(x, L, k, x0, y0) == y`` for ``x``.

    The parameters L, k, x0, y0 have the same meaning as in ``logistic``.
    """
    # logistic() is written in log10(x); dividing k by ln(10) re-expresses the
    # exponential as a plain power of x / x0.
    natural_slope = k / np.log(10)
    span = np.tanh(L) - y0
    ratio = span / (y - y0) - 1
    return x0 * ratio ** (-1 / natural_slope)

# Function to perform logistic regression on continuous accuracy data
def logistic_fit_continuous(df, operation='add', column='accuracy_final', learning_threshold = 0.8):
    """Fit ``logistic`` to accuracy versus model size for a single operation.

    Rows of ``df`` whose ``'ops'`` value equals ``operation`` are used, with
    ``'number_of_parameters'`` as x and ``column`` as y.

    Returns a dict:
    {
        'params': fitted [L, k, x0, y0]; None if there is no data or the fit fails,
        'transition_point': the fitted midpoint x0; None if there is no data or the fit fails,
        'learning_transition_point': x at which the fitted curve equals
            ``learning_threshold``. If the fit fails this is instead the smallest
            parameter count whose accuracy is >= ``learning_threshold`` (None when
            no sample reaches it). None if there is no data.
    }
    """
    subset = df[df['ops'] == operation]
    if subset.empty:
        print(f"No data found for operation: {operation}")
        return {'params': None, 'transition_point': None, 'learning_transition_point': None}

    sizes = subset['number_of_parameters'].values
    accuracies = subset[column].values

    # Order the samples by increasing model size.
    order = np.argsort(sizes)
    sizes, accuracies = sizes[order], accuracies[order]

    # Data-driven estimate: the smallest model that reaches the threshold.
    # It is only reported when the curve fit below fails.
    reached = np.flatnonzero(accuracies >= learning_threshold)
    if reached.size:
        empirical_point = sizes[reached[0]]
    else:
        print(f"No data points above the learning threshold ({learning_threshold}) for operation: {operation}.")
        empirical_point = None

    # Starting values for [L, k, x0, y0]: largest observed accuracy, unit
    # steepness, median model size, small offset.
    initial_guess = [np.max(accuracies), 1, np.median(sizes), 0.1]

    try:
        params, _ = curve_fit(logistic, sizes.ravel(), accuracies, p0=initial_guess, maxfev=10000)
        L, k, x0, y0 = params
        if k < 0:
            print(f"Warning: Negative steepness (k) found for operation: {operation}.")
        # The inflection point of the logistic curve sits at x0.
        transition_point = x0

        print(f"Results for operation: {operation}, using column: {column}")
        print(f"L (max): {L:.4f}, k (steepness): {k:.4f}, x0 (midpoint): {transition_point:.2f}, y0 (offset): {y0:.4f}")
        # Invert the fitted curve at the learning threshold.
        fitted_point = invlogistic(learning_threshold, L, k, x0, y0)
        return {'params': params, 'transition_point': transition_point, 'learning_transition_point': fitted_point}
    except Exception as e:
        print(f"Fitting failed for {operation}: {e}")
        return {'params': None, 'transition_point': None, 'learning_transition_point': empirical_point}
    
def plot_logistic_fit_continuous(df, params, transition_point, group='SUM', group_col='ops', column='accuracy_final'):
    """Plot one group's data together with its fitted logistic curve.

    Draws through ``plt`` (scatter, curve, optional transition marker, labels,
    legend, grid) and returns None. Nothing is drawn when ``params`` is None or
    when no row of ``df`` has ``df[group_col] == group``.

    params:           the fitted (L, k, x0, y0) passed to ``logistic``.
    transition_point: x position to mark with a star; only marked when it lies
                      in (0, 1.5 * largest parameter count of the group).
    """
    if params is None:
        return

    rows = df[df[group_col] == group]
    if rows.empty:
        return

    L, k, x0, y0 = params
    n_params = rows['number_of_parameters']

    # Measured points, coloured by their y value; marker size is 10 * n_layer.
    x, y = n_params.values, rows[column].values
    plt.scatter(x, y, c=y, cmap='coolwarm', s=rows['n_layer']*10, alpha=1)

    # Fitted curve on a log-spaced grid over the data range (lower end clipped to 1).
    x_min, x_max = n_params.min(), n_params.max()
    grid = np.logspace(np.log10(max(x_min, 1)), np.log10(x_max), 500)
    plt.plot(grid, logistic(grid, L, k, x0, y0), 'g--', lw=2, label='Logistic fit')

    # Star on the curve at the transition point.
    if 0 < transition_point < x_max * 1.5:
        plt.plot(transition_point, logistic(transition_point, L, k, x0, y0), 'y*', markersize=18,
                markeredgecolor='black', markeredgewidth=1,
                label=f'Transit.: {transition_point:.3g}'
                )

    plt.xscale('log')
    plt.xlabel('Number of Parameters')
    plt.ylabel('Test Accuracy')
    plt.title(f'Logistic Fit for {group}')
    plt.legend()
    plt.grid(True, alpha=0.3)

def get_chars_digits(data_file):
    """Read the vocabulary and the vocabulary positions of the input symbols.

    ``data_file`` is unpickled and must contain ``data['metadata']['vocab']``
    and ``data['metadata']['input_set']``. Each entry of ``input_set`` is
    converted with ``str`` and looked up in ``vocab`` with ``.index``.

    Returns ``(char_names, idx_digits)``: the vocab and the list of positions.
    """
    # NOTE: pickle can execute arbitrary code; only open data files you trust.
    with open(data_file, 'rb') as handle:
        payload = pickle.load(handle)
    print(f"Data loaded from {data_file}")

    char_names = payload['metadata']['vocab']
    idx_digits = [char_names.index(str(symbol)) for symbol in payload['metadata']['input_set']]

    print(f"Char names: {char_names}")
    print(f"Idx digits: {idx_digits}")
    return char_names, idx_digits

def get_model_file(row):
    """Return the path of the saved model file belonging to ``row``.

    Searches ``row['save_path']`` for files matching ``model_*<name>*.pt``,
    where ``<name>`` is ``row['name']``, and returns the first match reported
    by ``glob``.

    Raises:
        IndexError: if no file matches. The exception type is the one that
            indexing an empty glob result always produced; the message now
            names the pattern that was searched.
    """
    pattern = f"{row['save_path']}/model_*{row['name']}*.pt"
    matches = glob(pattern)
    if not matches:
        raise IndexError(f"No model file matches pattern: {pattern}")
    return matches[0]

def get_results_file(row):
    """Return the path of the saved results file belonging to ``row``.

    Searches ``row['save_path']`` for files matching ``results_*<name>*.pkl``,
    where ``<name>`` is ``row['name']``, and returns the first match reported
    by ``glob``.

    Raises:
        IndexError: if no file matches. The exception type is the one that
            indexing an empty glob result always produced; the message now
            names the pattern that was searched.
    """
    pattern = f"{row['save_path']}/results_*{row['name']}*.pkl"
    matches = glob(pattern)
    if not matches:
        raise IndexError(f"No results file matches pattern: {pattern}")
    return matches[0]

def read_model_results_files(row):
    """Load the saved model and the saved results that belong to ``row``.

    The file paths come from ``get_model_file(row)`` and
    ``get_results_file(row)``. The model file is read with ``torch.load``,
    mapped to ``cuda:0`` when CUDA is available and to the CPU otherwise; the
    results file is unpickled.

    Returns ``(model, results)``: the objects returned by ``torch.load`` and
    ``pickle.load`` respectively.
    """
    model_file = get_model_file(row)
    results_file = get_results_file(row)

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    model = torch.load(model_file, map_location=device)

    # Use a context manager so the file handle is closed even if unpickling fails.
    # NOTE: pickle can execute arbitrary code; only load results files you trust.
    with open(results_file, 'rb') as f:
        results = pickle.load(f)
    return model, results

def get_emb_corr_df(df, idx_digits):
    """Attach per-model embedding correlation matrices to the rows of ``df``.

    For every row the model is loaded with ``read_model_results_files`` and its
    ``'token_embedding_table.weight'`` entry is taken as the embedding matrix.
    Two matrices are computed with ``torch.corrcoef``: one over all rows of the
    embedding and one over the rows selected by ``idx_digits``.

    Rows whose files cannot be loaded, or whose correlations cannot be
    computed, are reported and skipped.

    Returns a new DataFrame (``df`` itself is not modified) holding the rows
    that were processed successfully, with two extra columns:
    ``'embedding_correlation'`` and ``'embedding_correlation_digits'``.
    """
    corr_col = []
    corr_digit_col = []
    kept_positions = []  # positional indices of the rows that were processed

    for pos, (i, row) in enumerate(df.iterrows()):
        print(f"Row {i}: {row['name']} ({row.id}) acc:{row.accuracy_final}", end='\t\t\t\r')
        try:
            # The results file is still loaded so that a row without one is skipped.
            model, _ = read_model_results_files(row)
            emb = model['token_embedding_table.weight']
            cor = to_np(torch.corrcoef(emb))
            cor_digits = to_np(torch.corrcoef(emb[idx_digits]))
        except Exception as e:
            print(f"Error reading model/results files: {e}: {row['name']}")
            continue
        # Record a row only after BOTH matrices exist, so the three lists can
        # never get out of step with each other.
        corr_col.append(cor)
        corr_digit_col.append(cor_digits)
        kept_positions.append(pos)

    # Select by position rather than by label: with duplicate index labels a
    # label-based lookup would return more rows than were processed.
    df_filtered = df.iloc[kept_positions].copy()
    df_filtered['embedding_correlation'] = corr_col
    df_filtered['embedding_correlation_digits'] = corr_digit_col
    return df_filtered

def get_avg_corrs(df_op):
    """Average the correlation matrices stored in ``df_op`` and print summary stats.

    The mean is taken over the entries of the ``embedding_correlation`` and
    ``embedding_correlation_digits`` columns (``np.mean`` along axis 0 of the
    column values).

    Returns ``(avg_corr, avg_corr_digits)``.
    """
    columns = (df_op.embedding_correlation, df_op.embedding_correlation_digits)
    avg_corr, avg_corr_digits = (np.mean(col.values, axis=0) for col in columns)

    # Summary of the all-token average, then of the digit-only average.
    v = avg_corr
    print(f"Corr All Chars: shape={v.shape}, Mean={v.mean():.2g}, std={v.std():.2g}, max={v.max():.2g}, min={v.min():.2g}")
    v = avg_corr_digits
    print(f"Corr Digits: shape={v.shape}, Mean={v.mean():.2g}, std={v.std():.2g}, max={v.max():.2g}, min={v.min():.2g}")
    return avg_corr, avg_corr_digits

def plot_corr_pca(avg_corr, avg_corr_digits, char_names, idx_digits, 
            num_pca=4, vmax=2, size=4.5,
            plot_digits=False, cmap='seismic', cmap_cor='bwr', show_nums = True, markersize=100, oddeven=False, mod =2):
    """Plot correlation heat map(s) next to a grid of PCA pair plots.

    Layout, left to right: a heat map of ``avg_corr`` labelled with
    ``char_names``; if ``plot_digits`` is set, a heat map of ``avg_corr_digits``
    labelled with ``char_names[idx_digits]``; and an upper-triangular grid in
    which row i / column j-1 shows PC(j+1) on x against PC(i+1) on y. The PCA
    is fitted on ``avg_corr_digits`` with ``num_pca`` components.

    vmax:      heat-map colour range is [-vmax, vmax].
    size:      figure height; the width is 1.15 * size per heat-map panel slot.
    cmap:      colormap of the PCA scatter points; ``cmap_cor`` is used for the heat maps.
    show_nums: write each point's row number on top of it.
    oddeven / mod: colour points by ``row number % mod`` instead of by row number.

    Returns the matplotlib figure.
    """
    n_panels = 3 if plot_digits else 2
    fig = plt.figure(figsize=(n_panels*size*1.15, size))

    def draw_heatmap(panel, matrix, labels, title):
        # One correlation heat map with a colour bar and labelled ticks.
        ax = plt.subplot(1, n_panels, panel)
        image = ax.imshow(matrix, cmap=cmap_cor, vmin=-vmax, vmax=vmax)
        plt.colorbar(image, ax=ax)
        ticks = range(len(labels))
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels, rotation=90)
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels)
        ax.set_title(title)

    draw_heatmap(1, avg_corr, char_names, 'Avg Embedding Correlation')
    if plot_digits:
        draw_heatmap(2, avg_corr_digits, np.array(char_names)[idx_digits], 'Avg Embed. Corr. of Digits')

    # PCA of the digit correlation matrix.
    pca = PCA(n_components=num_pca)
    scores = pca.fit_transform(avg_corr_digits)
    # Multiply every component by the sign of its first-row score so that
    # separately fitted plots are oriented the same way.
    scores = scores * np.sign(scores[0])[np.newaxis]

    k = pca.n_components_
    # The PCA grid starts further right when a third panel is present.
    left = 0.7 if n_panels == 3 else 0.53
    gs = plt.GridSpec(k-1, k-1, wspace=0.0, hspace=0.0,
                    left=left, right=0.98, bottom=0.05, top=0.87)

    # All pair plots share one axis range taken from the full score matrix.
    lim_lo, lim_hi = 1.2*scores.min(), 1.2*scores.max()

    point_ids = np.arange(scores.shape[0])
    colors = (point_ids % mod) if oddeven else point_ids

    for i in range(k):
        for j in range(i+1, k):
            ax = fig.add_subplot(gs[i, j-1],)
            xs, ys = scores[:, j], scores[:, i]
            ax.plot(xs, ys, '-', alpha=0.5)
            ax.scatter(xs, ys,
                    c=colors,
                    cmap=cmap, alpha=1.0, s=markersize, zorder=2)
            if show_nums:
                for number, (px, py) in enumerate(zip(xs, ys)):
                    ax.text(px, py, str(number), fontsize=10, ha='center', va='center', color='white')

            # Only the panels next to the diagonal carry axis names.
            if j == i+1:
                ax.set_xlabel(f'PC{j+1}')
                ax.set_ylabel(f'PC{i+1}')
            # Tick labels are hidden on every panel; the ticks themselves stay.
            ax.set_xticklabels([])
            ax.set_yticklabels([])

            ax.grid(True, alpha=0.3)
            ax.set_xlim(lim_lo, lim_hi)
            ax.set_ylim(lim_lo, lim_hi)

    # An invisible axes spanning the top grid row carries the PCA title.
    title_ax = fig.add_subplot(gs[0, :])
    title_ax.axis('off')
    title_ax.set_title('PCA of Digit Embeddings', )

    plt.tight_layout()
    return fig


def plot_avg_pca(avg_corr_digits, num_pca=4, cmap='seismic', show_nums = True, markersize=150, oddeven=False, mod=2):
    """Plot pairwise PCA projections of a digit correlation matrix.

    A PCA with ``num_pca`` components is fitted on ``avg_corr_digits``. The
    figure is an upper-triangular grid in which row i / column j-1 shows
    PC(j+1) on x against PC(i+1) on y; the unused lower-triangle axes are hidden.

    cmap:      colormap of the scatter points.
    show_nums: write each point's row number on top of it.
    markersize: scatter marker size.
    oddeven / mod: colour points by ``row number % mod`` instead of by row number.

    Returns the matplotlib figure.

    Raises:
        ValueError: if the fitted PCA has fewer than 2 components, since no
            pair of components can then be plotted.
    """
    pca = PCA(n_components=num_pca)
    cor_pca = pca.fit_transform(avg_corr_digits)
    # Multiply every component by the sign of its first-row score so that
    # separately fitted plots are oriented the same way.
    cor_pca = cor_pca * np.sign(cor_pca[0])[np.newaxis]

    k = pca.n_components_
    if k < 2:
        raise ValueError(f"plot_avg_pca needs at least 2 PCA components to form a pair, got {k}")

    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid
    # (num_pca=2), so axes[i, j] indexing below always works.
    fig, axes = plt.subplots(k-1, k-1, figsize=(2*(k-1), 2*(k-1)), squeeze=False)
    # All pair plots share one axis range taken from the full score matrix.
    x_min, x_max = 1.2*cor_pca.min(), 1.2*cor_pca.max()

    # Point colours do not depend on the component pair; compute them once.
    point_ids = np.arange(cor_pca.shape[0])
    colors = (point_ids % mod) if oddeven else point_ids

    # Hide the lower-triangle axes, which stay empty.
    for i in range(k-1):
        for j in range(i):
            axes[i, j].set_visible(False)

    # Plot PC[j] against PC[i] for every pair i < j.
    for i in range(k):
        for j in range(i+1, k):
            ax = axes[i, j-1]
            ax.plot(cor_pca[:, j], cor_pca[:, i], '-', alpha=0.5)
            ax.scatter(cor_pca[:, j], cor_pca[:, i],
                    c=colors,
                    cmap=cmap, alpha=1.0, s=markersize, zorder=2)
            if show_nums:
                # Write the row number on each point.
                for number, (px, py) in enumerate(zip(cor_pca[:, j], cor_pca[:, i])):
                    ax.text(px, py, str(number), fontsize=10, ha='center', va='center', color='white')

            if j == i+1:
                ax.set_xlabel(f'PC{j+1}')
                ax.set_ylabel(f'PC{i+1}')
            else:
                # Hide the tick labels only; the ticks themselves stay.
                ax.set_xticklabels([])
                ax.set_yticklabels([])

            ax.grid(True, alpha=0.3)
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(x_min, x_max)

    plt.tight_layout()
    fig.subplots_adjust(wspace=0.0, hspace=0.0)

    return fig

# plt.suptitle(f'{op} Avg. Embed Corr. of {key.capitalize()} Acc. ({df_acc_groups[key].accuracy_final.mean()*100:.0f}%)'+\
    #             f" mean size={df_acc_groups[key]['number_of_parameters'].mean()/1e3:.0f}K", 
    #             fontsize=16)
    
    # plt.tight_layout()#rect=[0, 0, 1, 0.95])  # Adjust layout to accommodate suptitle
    
# plt.suptitle(f'{op} Mean PCA of Embed. {key.capitalize()} Acc. ({df_op_select.accuracy_final.mean()*100:.0f}%)'+\
    #             f" size={df_op_select['number_of_parameters'].mean()/1e3:.0f}K", 
    #             fontsize=16)
    
# FIGS_DIR = f"../figs/{project_name}/"
# figs_pca_dir = os.path.join(FIGS_DIR, "pca_embedding/")
# os.makedirs(figs_pca_dir, exist_ok=True)
# # save as pca_embedding_<ops>_<acc>.pdf
# plt.savefig(figs_pca_dir+ f"./pca_embedding_{ops}_{row.accuracy_final}.pdf", bbox_inches='tight')