import torch
import numpy as np
import os
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import sys


import os
import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from scipy.cluster.hierarchy import linkage, dendrogram
from sklearn.manifold import MDS

from network import Network, Network_muP
from utils import set_seed


def plot_degeneracy_metrics(tasks, gamma_all, seed_range, base_path,
                            thresh=0.015, labels=None, n_hidden=128,
                            xlabel='$\\gamma$'):
    """Draw the complete set of degeneracy figures for one sweep.

    The figures are produced in this order: loss curves, weight change
    from initialisation, representation alignment, kernel alignment,
    average dimensionality, dynamical similarity and weight degeneracy.
    Every helper opens and shows its own figure.

    Args:
        tasks: Task names; each helper uses them as sub-directory names
            below ``base_path``.
        gamma_all: Gamma values, zipped element-wise with ``tasks`` by the
            helpers that take both.
        seed_range: Iterable of seeds forwarded to every helper.
        base_path: Root directory of the saved results.
        thresh: Forwarded to ``plot_loss_curves``.
        labels: Legend labels for the loss curves.  When ``None`` one
            label per entry of ``gamma_all`` is generated.
        n_hidden: Forwarded to ``weight_change_from_init``.
        xlabel: Axis label forwarded to every helper except the loss plot.
    """
    legend_labels = labels
    if legend_labels is None:
        legend_labels = [f"$\\gamma$ = {g}" for g in gamma_all]

    plot_loss_curves(tasks, seed_range, base_path, thresh, legend_labels)
    weight_change_from_init(seed_range, tasks, gamma_all, base_path,
                            n_hidden, xlabel)

    # These two helpers share the (seeds, tasks, gammas, path, label) order.
    for alignment_plot in (plot_representation_alignment,
                           plot_kernel_alignment):
        alignment_plot(seed_range, tasks, gamma_all, base_path, xlabel)

    # These three helpers share the (tasks, seeds, path, label) order.
    for per_task_plot in (plot_average_dimensionality,
                          plot_dynamic_similarity,
                          plot_weight_degeneracy):
        per_task_plot(tasks, seed_range, base_path, xlabel)


def plot_loss_curves(tasks, seed_range, base_path,
                     thresh=0.01, labels=None):
    """Plot the seed-averaged loss curve of every task with a +/- 1 std band.

    For each task the files ``<base_path>/<task>/losses/seed_<seed>.npy``
    are loaded; seeds whose file is missing are skipped.  Tasks without
    any loss file are reported on stdout and left out of the figure.

    Args:
        tasks: Task names (sub-directories of ``base_path``).
        seed_range: Iterable of seeds to look for.
        base_path: Root directory of the saved results.
        thresh: Accepted for interface compatibility; not used here.
        labels: Legend labels indexed like ``tasks``; defaults to ``tasks``.
    """
    legend = tasks if labels is None else labels
    plt.figure(figsize=(8, 5))
    palette = plt.cm.viridis(np.linspace(0, 1, len(tasks)))

    for task_idx, (task, colour) in enumerate(zip(tasks, palette)):
        loss_dir = os.path.join(base_path, task, 'losses')
        curves = []
        for seed in seed_range:
            seed_file = os.path.join(loss_dir, f'seed_{seed}.npy')
            try:
                curves.append(np.load(seed_file))
            except FileNotFoundError:
                continue

        if len(curves) == 0:
            print(f"No losses for {task}")
            continue

        # One row per loaded array; statistics are taken across rows.
        stacked = np.vstack(curves)
        centre = stacked.mean(axis=0)
        spread = stacked.std(axis=0)
        steps = np.arange(len(centre))

        plt.plot(steps, centre, label=legend[task_idx], color=colour)
        plt.fill_between(steps, centre - spread, centre + spread,
                         color=colour, alpha=0.2)

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.show()


def weight_change_from_init(seed_arr, task_arr, gamma_all,
                            base_path, n_hidden=128, ylabel='$\\gamma$'):
    """Plot how far the trained weights moved away from their initialisation.

    For every ``(gamma, task)`` pair and every seed, a fresh ``Network_muP``
    is built right after ``set_seed(seed)`` and compared with the checkpoint
    ``<base_path>/<task>/weights/seed_<seed>.pt``.  The reported value is the
    sum over the parameters ``J``, ``U`` and ``w_readout`` of the Euclidean
    (Frobenius) norm of ``initial - trained``.  Seeds without a checkpoint
    are skipped.

    One line is drawn per ``(gamma, task)`` pair; the x position of a point
    is its index among the seeds that were found.

    Args:
        seed_arr: Iterable of seeds.
        task_arr: Task names, paired element-wise with ``gamma_all``.
        gamma_all: Gamma value used to rebuild the initial network of the
            matching task.
        base_path: Root directory of the saved results.
        n_hidden: Hidden size used to rebuild the initial network.
        ylabel: Text for the x-axis label (the name is historical).
    """
    tracked = ('J', 'U', 'w_readout')
    data = []
    for g, task in zip(gamma_all, task_arr):
        vals = []
        params = {
            'n_hidden': n_hidden,
            'dim_input': 1,
            'dim_output': 1,
            'rnn_type': 'vrnn',
            'device': 'cpu',
            'trainable_ratio': 1.0,
            'muP_param': True,
            'gain': 0.6,
            'gamma': g,
            'tau': 1
        }
        for seed in seed_arr:
            set_seed(seed)
            model = Network_muP(params)
            # detach() is required: calling .numpy() on a tensor that still
            # requires grad raises a RuntimeError.
            initial = {
                name: p.detach().cpu().numpy()
                for name, p in model.named_parameters()
                if name in tracked
            }
            try:
                # NOTE(review): torch.load unpickles the file; only point
                # base_path at checkpoints from a trusted source.
                trained = torch.load(
                    os.path.join(base_path, task, 'weights', f'seed_{seed}.pt'),
                    map_location='cpu'
                )
            except FileNotFoundError:
                continue
            # Pair tensors by name rather than by position so the result does
            # not depend on the order in which the model registers parameters.
            # norm() without `ord` is the 2-norm of the flattened difference,
            # i.e. the Frobenius norm for matrices, and also accepts vectors.
            diff = sum(
                np.linalg.norm(
                    initial[name] - trained[name].detach().cpu().numpy()
                )
                for name in tracked
                if name in initial
            )
            vals.append(diff)
        data.append(vals)

    plt.figure(figsize=(6, 4))
    for vals in data:
        plt.plot(vals, marker='o')
    plt.xlabel(ylabel)
    plt.ylabel('||ΔW|| (Fro)')
    plt.tight_layout()
    plt.show()


def representation_alignment(h0, hf):
    """Return the normalised trace overlap of the Gram matrices ``h.T @ h``.

    The result is ``trace(Rf @ R0) / (||Rf||_F * ||R0||_F)`` with
    ``R0 = h0.T @ h0`` and ``Rf = hf.T @ hf``.

    NOTE: this module defines ``representation_alignment`` a second time
    further down, built on ``h @ h.T`` instead.  That later definition
    rebinds the name, so it is the one callers get once the module has
    been fully imported.
    """
    gram_initial = h0.T @ h0
    gram_final = hf.T @ hf
    overlap = np.trace(gram_final @ gram_initial)
    scale = np.linalg.norm(gram_final, 'fro') * np.linalg.norm(gram_initial, 'fro')
    return overlap / scale


def plot_representation_alignment(seed_arr, task_arr, gamma_all,
                                  base_path, ylabel):
    """Plot the representation alignment between initial and final activity.

    For each task the arrays ``<base_path>/<task>/hxs/seed_<seed>_initial.npy``
    and ``.../seed_<seed>.npy`` are loaded and the slice ``[:, -1, :]`` of
    both is passed to ``representation_alignment``.  Seeds with a missing
    file are skipped.  One line is drawn per task; the x position of a point
    is its index among the seeds that were found.

    Args:
        seed_arr: Iterable of seeds.
        task_arr: Task names (sub-directories of ``base_path``).
        gamma_all: Zipped with ``task_arr``; only its length matters here.
        base_path: Root directory of the saved results.
        ylabel: Text for the x-axis label (the name is historical).
    """
    plt.figure(figsize=(6, 4))
    for _gamma, task in zip(gamma_all, task_arr):
        hxs_dir = os.path.join(base_path, task, 'hxs')
        scores = []
        for seed in seed_arr:
            try:
                initial = np.load(os.path.join(hxs_dir, f'seed_{seed}_initial.npy'))
                final = np.load(os.path.join(hxs_dir, f'seed_{seed}.npy'))
            except FileNotFoundError:
                continue
            # presumably the arrays are (batch, time, hidden), so this keeps
            # the last time step -- TODO confirm against the code saving them
            scores.append(
                representation_alignment(initial[:, -1, :], final[:, -1, :])
            )
        plt.plot(scores, marker='o')

    plt.xlabel(ylabel)
    plt.ylabel('Representation Alignment')
    plt.tight_layout()
    plt.show()


def plot_kernel_alignment(seed_arr, task_arr, gamma_all,
                          base_path, ylabel):
    """Plot the kernel alignment between the initial and the final NTK.

    For each task the matrices ``<base_path>/<task>/ntk/seed_<seed>_initial.npy``
    and ``.../seed_<seed>.npy`` are loaded and compared with the module's
    ``kernel_alignment`` helper, so the formula lives in exactly one place.
    Seeds with a missing file are skipped.  One line is drawn per task; the
    x position of a point is its index among the seeds that were found.

    Args:
        seed_arr: Iterable of seeds.
        task_arr: Task names (sub-directories of ``base_path``).
        gamma_all: Zipped with ``task_arr``; only its length matters here.
        base_path: Root directory of the saved results.
        ylabel: Text for the x-axis label (the name is historical).
    """
    plt.figure(figsize=(6, 4))
    for _gamma, task in zip(gamma_all, task_arr):
        ntk_dir = os.path.join(base_path, task, 'ntk')
        vals = []
        for seed in seed_arr:
            # Only the file access is guarded, so a numerical problem in the
            # alignment itself is never mistaken for a missing file.
            try:
                K0 = np.load(os.path.join(ntk_dir, f'seed_{seed}_initial.npy'))
                Kf = np.load(os.path.join(ntk_dir, f'seed_{seed}.npy'))
            except FileNotFoundError:
                continue
            vals.append(kernel_alignment(K0, Kf))
        plt.plot(vals, marker='o')
    plt.xlabel(ylabel)
    plt.ylabel('Kernel Alignment')
    plt.tight_layout()
    plt.show()


def plot_average_dimensionality(task_arr, seed_arr,
                                base_path, ylabel,
                                thres=0.99, rotation=0):
    """Plot the mean PCA dimensionality of the hidden activity per task.

    For every seed the array ``<base_path>/<task>/hxs/seed_<seed>.npy`` is
    flattened to ``(-1, h.shape[2])`` and decomposed with PCA.  The
    dimensionality is the number of leading components needed for the
    cumulative explained-variance ratio to exceed ``thres``; if it never
    does, all components are counted.  Missing files are skipped silently;
    unreadable or malformed files are skipped with a message on stdout.

    Args:
        task_arr: Task names (sub-directories of ``base_path``); also used
            as tick labels.
        seed_arr: Iterable of seeds.
        base_path: Root directory of the saved results.
        ylabel: Text for the x-axis label (the name is historical).
        thres: Cumulative explained-variance ratio to exceed.
        rotation: Rotation of the tick labels.
    """
    dims = []
    for task in task_arr:
        ds = []
        for seed in seed_arr:
            path = os.path.join(base_path, task, 'hxs', f'seed_{seed}.npy')
            try:
                h = np.load(path)
                X = h.reshape(-1, h.shape[2])
                # PCA cannot return more components than min(samples,
                # features); asking for exactly that many keeps the full
                # spectrum without failing on short recordings.
                pca = PCA(n_components=min(X.shape))
                pca.fit(X)
            except FileNotFoundError:
                continue
            except (OSError, ValueError, IndexError) as err:
                # Keep the best-effort behaviour, but say what was dropped.
                print(f"Skipping {path}: {err}")
                continue
            cum = np.cumsum(pca.explained_variance_ratio_)
            crossed = cum > thres
            if crossed.any():
                # argmax is the 0-based index of the first crossing, so the
                # number of components is one more than that.
                ds.append(int(np.argmax(crossed)) + 1)
            else:
                ds.append(len(cum))
        dims.append(ds)

    # Tasks without any data are plotted as NaN instead of going through
    # numpy's empty-mean / divide-by-zero warnings.
    means = [np.mean(d) if d else np.nan for d in dims]
    ses = [np.std(d) / np.sqrt(len(d)) if d else np.nan for d in dims]
    plt.figure(figsize=(6, 4))
    x = np.arange(len(task_arr))
    plt.errorbar(x, means, yerr=ses, fmt='-o')
    plt.xticks(x, task_arr, rotation=rotation)
    plt.xlabel(ylabel)
    plt.ylabel('Dimensionality')
    plt.tight_layout()
    plt.show()


def plot_dynamic_similarity(task_arr, seed_range,
                            base_path, ylabel):
    """Plot the saved dynamical-similarity scores of every task.

    The scores of all seeds, read from
    ``<base_path>/DSA/<task>/network_<seed>.npy``, are concatenated in seed
    order and drawn as one line per task.  Missing files are skipped
    silently; unreadable files are skipped with a message on stdout.

    Args:
        task_arr: Task names (sub-directories of ``<base_path>/DSA``).
        seed_range: Iterable of seeds.
        base_path: Root directory of the saved results.
        ylabel: Text for the x-axis label (the name is historical).
    """
    plt.figure(figsize=(6, 4))
    for task in task_arr:
        scores = []
        for seed in seed_range:
            path = os.path.join(base_path, 'DSA', task, f'network_{seed}.npy')
            try:
                s = np.load(path)
            except FileNotFoundError:
                continue
            except (OSError, ValueError) as err:
                print(f"Skipping {path}: {err}")
                continue
            # atleast_1d lets a file holding a single scalar contribute one
            # point; extending with a 0-d array would raise a TypeError.
            scores.extend(np.atleast_1d(s))
        plt.plot(scores, marker='o')
    plt.xlabel(ylabel)
    plt.ylabel('Dynamical Degeneracy')
    plt.tight_layout()
    plt.show()


def plot_weight_degeneracy(task_arr, seed_range, base_path,
                           ylabel, n_hidden=128):
    """Plot the mean weight degeneracy of every task with a standard-error bar.

    The values of all seeds, read from
    ``<base_path>/W_degeneracy/<task>/<seed>.npy``, are pooled per task.
    Missing files are skipped silently; unreadable files are skipped with a
    message on stdout.  A task without any data is reported and plotted as
    NaN, so it still gets its position on the axis.

    Args:
        task_arr: Task names; also used as the x positions.
        seed_range: Iterable of seeds.
        base_path: Root directory of the saved results.
        ylabel: Text for the x-axis label (the name is historical).
        n_hidden: Accepted for interface compatibility; not used here.
    """
    plt.figure(figsize=(6, 4))
    for task in task_arr:
        vals = []
        for seed in seed_range:
            path = os.path.join(base_path, 'W_degeneracy', task, f'{seed}.npy')
            try:
                arr = np.load(path)
            except FileNotFoundError:
                continue
            except (OSError, ValueError) as err:
                print(f"Skipping {path}: {err}")
                continue
            # atleast_1d lets a file holding a single scalar contribute one
            # value; extending with a 0-d array would raise a TypeError.
            vals.extend(np.atleast_1d(arr))
        if vals:
            mean, se = np.mean(vals), np.std(vals) / np.sqrt(len(vals))
        else:
            # Same NaN result as before, without the empty-mean and
            # divide-by-zero warnings.
            print(f"No weight degeneracy data for {task}")
            mean, se = np.nan, np.nan
        plt.errorbar([task], [mean], yerr=[se], fmt='o')
    plt.xlabel(ylabel)
    plt.ylabel('Weight Degeneracy')
    plt.tight_layout()
    plt.show()


def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    When CUDA is available, cuDNN is additionally switched to
    deterministic mode with benchmarking disabled.

    NOTE: the module also imports ``set_seed`` from ``utils``; this
    definition rebinds that name.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if not torch.cuda.is_available():
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    
def check_network_outputs(args, predicted, output):
    """
    Draws predicted vs. target traces for up to 10 trials on a 2x5 grid.

    The figure is only built here: it is neither shown nor saved, and the
    figure name assembled at the end is not used by this function.

    Args:
        args: Object whose ``task_name`` and ``seed`` attributes are read to
            assemble the figure name.
        predicted (torch.Tensor): The predicted outputs from the model;
            squeezed and transposed before being indexed by trial.
        output (torch.Tensor): The target outputs; squeezed and indexed by
            trial.
    """
    # Work on numpy copies of both inputs
    predicted = np.array(predicted).squeeze().T
    output = np.array(output).squeeze()

    # Never show more than the 10 panels of the grid
    shown = min(10, predicted.shape[0])
    fig, grid = plt.subplots(2, 5, figsize=(15, 6))
    panels = grid.flatten()

    for trial, panel in enumerate(panels[:shown]):
        panel.plot(predicted[trial], label='Predicted', color='blue')
        panel.plot(output[trial], label='Target', color='orange', linestyle='--')
        panel.set_title(f'Trial {trial+1}')
        panel.legend()

    # Drop the panels that received no trial
    for spare in panels[shown:]:
        fig.delaxes(spare)

    plt.tight_layout()

    # NOTE(review): built but unused -- presumably meant for saving the figure
    figname = f'{args.task_name}_seed_{args.seed}_network_outputs'
    
    
    
def get_dynamics(args, task, check_output=False):
    """Run the task's model on the cached test inputs and return its activations.

    Test inputs and targets are read from
    ``<save_path>/task_inputs/<task_name>.npy`` and
    ``<save_path>/task_outputs/<task_name>.npy``.  If either file cannot be
    loaded, both are regenerated with ``task.generate_test_data()`` and
    written back, creating the cache directories when necessary.

    For task names containing ``'sine'`` the inputs are instead rebuilt from
    ``task[i][0]``, reshaped to ``(args.sequence_length, args.num_channels)``.

    Args:
        args: Object providing ``save_path``, ``task_name`` and ``device``
            (plus ``sequence_length`` / ``num_channels`` for sine tasks).
        task: Object providing ``model.get_activations`` and
            ``generate_test_data``; indexable with a length for sine tasks.
        check_output: When true, plot predictions against the targets with
            ``check_network_outputs``.

    Returns:
        The second value returned by ``task.model.get_activations``.
    """
    inputs_file = os.path.join(args.save_path, 'task_inputs', f'{args.task_name}.npy')
    outputs_file = os.path.join(args.save_path, 'task_outputs', f'{args.task_name}.npy')
    try:
        inputs = np.load(inputs_file)
        outputs = np.load(outputs_file)
    except (OSError, ValueError, EOFError):
        # Cache missing or unreadable: regenerate it.
        inputs, outputs = task.generate_test_data()
        inputs = inputs.cpu().numpy()
        outputs = outputs.cpu().numpy()
        for cache_file, array in ((inputs_file, inputs), (outputs_file, outputs)):
            # np.save does not create parent directories.
            os.makedirs(os.path.dirname(cache_file), exist_ok=True)
            np.save(cache_file, array)

    if 'sine' in args.task_name:
        inputs = [
            task[i][0].reshape(args.sequence_length, args.num_channels)
            for i in range(len(task))
        ]

    inputs = torch.from_numpy(np.array(inputs).astype(np.float32)).to(args.device)

    prediction, activations = task.model.get_activations(inputs)

    if check_output:
        check_network_outputs(args, prediction, outputs)

    return activations


def get_ntk(args, task):
    """Build an empirical NTK Gram matrix from recurrent-weight gradients.

    The cached test inputs are loaded from
    ``<save_path>/task_inputs/<task_name>.npy``; the file must already exist
    (``get_dynamics`` writes it).  For task names containing ``'sine'`` the
    inputs are instead rebuilt from ``task[i][0]``.

    The network is stepped through the sequence one time index at a time,
    starting from a zero hidden state.  At every step each scalar element of
    the output is differentiated with respect to ``network.J``; if that is
    not possible, ``network.recurrent.weight_hh_l0`` is used instead.

    Args:
        args: Object providing ``save_path``, ``task_name`` and ``device``
            (plus ``sequence_length`` / ``num_channels`` for sine tasks).
        task: Object whose ``model`` is the network to differentiate.

    Returns:
        torch.Tensor: Square matrix ``G @ G.T`` where ``G`` has one row per
        scalar output, ordered by time step first and then by position
        within that step's flattened output.
    """
    inputs = np.load(os.path.join(args.save_path, 'task_inputs', f'{args.task_name}.npy'))

    if 'sine' in args.task_name:
        inputs = [
            task[i][0].reshape(args.sequence_length, args.num_channels)
            for i in range(len(task))
        ]

    inputs = torch.from_numpy(np.array(inputs).astype(np.float32)).to(args.device)

    network = task.model
    device = network.params['device']

    batch_size = inputs.shape[0]
    h = torch.zeros([1, batch_size, network.params['n_hidden']], dtype=torch.float32).to(device)

    gradients = []

    # Forward pass, one time step at a time, collecting per-output gradients
    for idx in range(inputs.shape[1]):
        x = inputs[:, idx, :].unsqueeze(1)
        outputs, h = network(x, h)

        for output in outputs.view(-1, outputs.shape[-1]):  # one row per sample
            for scalar_output in output:  # one scalar per output channel
                try:
                    grad = torch.autograd.grad(scalar_output, network.J, retain_graph=True)
                except (AttributeError, RuntimeError):
                    # No usable `J` on this network (attribute missing, or not
                    # part of the graph): fall back to the built-in RNN weight.
                    # Anything else, e.g. KeyboardInterrupt, now propagates.
                    grad = torch.autograd.grad(scalar_output, network.recurrent.weight_hh_l0, retain_graph=True)
                flattened_grad = torch.cat([g.flatten() for g in grad])
                gradients.append(flattened_grad)

    gradients = torch.stack(gradients)  # Shape: [num_scalar_outputs, num_params]
    ntk = gradients @ gradients.T  # Shape: [num_scalar_outputs, num_scalar_outputs]

    return ntk
    

def save_ntk(args, task, initial_save=False):
    """Compute the NTK with ``get_ntk``, store it as ``.npy`` and return it.

    The matrix is written below ``<save_path>/<task_name>/ntk`` (created if
    needed) as ``seed_<seed>_initial.npy`` when ``initial_save`` is true and
    as ``seed_<seed>.npy`` otherwise.

    Args:
        args: Object providing ``save_path``, ``task_name`` and ``seed``.
        task: Forwarded to ``get_ntk``.
        initial_save: Selects the ``_initial`` file name.

    Returns:
        The NTK exactly as returned by ``get_ntk``.
    """
    ntk = get_ntk(args, task)

    ntk_dir = os.path.join(args.save_path, args.task_name, 'ntk')
    os.makedirs(ntk_dir, exist_ok=True)

    file_name = f"seed_{args.seed}_initial.npy" if initial_save else f"seed_{args.seed}.npy"
    np.save(os.path.join(ntk_dir, file_name), ntk.cpu().numpy())
    return ntk


def kernel_alignment(ntk_0, ntk_f):
    """
    Compute Kernel Alignment (KA) between two NTKs using NumPy arrays.

    KA = trace(ntk_f @ ntk_0) / (||ntk_f||_F * ||ntk_0||_F)

    Args:
        ntk_0: NTK before training (square NumPy array).
        ntk_f: NTK after training (square NumPy array of the same size).

    Returns:
        Kernel Alignment (KA).  NaN if either matrix has zero norm.
    """
    # trace(A @ B) == sum_ij A[i, j] * B[j, i]; contracting directly is
    # O(n^2) and avoids building the full O(n^3) matrix product only to
    # read its diagonal.
    numerator = np.einsum('ij,ji->', ntk_f, ntk_0)
    norm_f = np.linalg.norm(ntk_f, ord='fro')
    norm_0 = np.linalg.norm(ntk_0, ord='fro')
    return numerator / (norm_f * norm_0)


def representation_alignment(h0: np.ndarray, hf: np.ndarray) -> float:
    """
    Computes representation alignment between h0 and hf.

    RA = trace(Rf @ R0) / (||Rf||_F * ||R0||_F), with R = h @ h.T.

    Parameters:
    -----------
    h0 : np.ndarray of shape (m, N)
        Hidden activity (e.g., before training).  The Gram matrix is taken
        over the rows, so rows index the m items being compared.
    hf : np.ndarray of shape (m, N')
        Hidden activity (e.g., after training), with the same number of
        rows as h0.

    Returns:
    --------
    float
        The representation alignment.  Both Gram matrices are positive
        semi-definite, so the value lies between 0 and 1; it is NaN when
        either input is all zeros.
    """
    # Gram matrices over the rows
    R0 = h0 @ h0.T  # shape (m, m)
    Rf = hf @ hf.T  # shape (m, m)

    # Numerator: trace(Rf @ R0) == sum_ij Rf[i, j] * R0[j, i], contracted
    # directly in O(m^2) instead of forming the O(m^3) matrix product.
    numerator = np.einsum('ij,ji->', Rf, R0)

    # Denominator: product of Frobenius norms of Rf and R0
    denominator = np.linalg.norm(Rf, 'fro') * np.linalg.norm(R0, 'fro')

    return numerator / denominator
