"""
Perform an experiment with a toy model.

Results will be saved in 'output/toy_model/{experiment_name}/' directory.
"""
import argparse
import os
import pickle
import pandas as pd
import tqdm
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


class NeuronLayer(nn.Module):
    """
    Simple neuron layer with Gaussian tuning curves on a circular input space.

    Each neuron has a preferred input (its peak); the peaks are spaced
    uniformly at ``i / n_neurons`` for ``i = 0 .. n_neurons - 1``. The input
    axis is treated as periodic with period 1, so inputs 0 and 1 produce the
    same activity.

    Args:
        n_neurons (int): Number of neurons (must be positive).
        tuning_width (float): Standard deviation of the Gaussian tuning curves.
    """
    def __init__(self, n_neurons, tuning_width):
        super().__init__()
        if n_neurons <= 0:
            raise ValueError(f"n_neurons must be positive, got {n_neurons}")
        self.n_neurons = n_neurons
        self.tuning_width = tuning_width
        # Build the peaks from an integer range and divide afterwards.
        # A float step (torch.arange(0, 1, 1/n_neurons)) makes the number of
        # elements depend on rounding, e.g. n_neurons=49 would yield 50 peaks.
        peaks = torch.arange(n_neurons, dtype=torch.get_default_dtype()) / n_neurons
        # Non-persistent buffer: follows .to(device) without changing the state_dict.
        self.register_buffer('peaks', peaks, persistent=False)

    def forward(self, x):
        """
        Compute the activity of every neuron for every input.

        Args:
            x (torch.Tensor): Inputs of shape (1,) or (n_samples,).

        Returns:
            torch.Tensor: Activity of shape (n_samples, n_neurons), with value
            1 where the input coincides with a neuron's peak.
        """
        x = x.unsqueeze(-1)  # shape: (n_samples, 1)
        dist = torch.abs(x - self.peaks)  # broadcast to (n_samples, n_neurons)
        dist = torch.minimum(dist, 1 - dist)  # circular distance (period 1)
        activity = torch.exp(-0.5 * (dist / self.tuning_width) ** 2)
        return activity  # (n_samples, n_neurons)
    
def retrieve(
        X: np.ndarray,
        Y: np.ndarray,
        target_feature: np.ndarray
    ) -> np.floating:
    """
    Retrieve the input x whose feature y is closest to the target feature.

    Args:
        X (np.ndarray): Candidate inputs; ``X[i]`` is the input that produced ``Y[i]``.
        Y (np.ndarray): Features of the candidates, one row per candidate.
        target_feature (np.ndarray): Feature to look up; must broadcast
            against the rows of ``Y``.

    Returns:
        The entry ``X[i]`` for the row ``i`` of ``Y`` with the smallest
        Euclidean distance to ``target_feature`` (first one on ties).
    """
    # Euclidean distance between every candidate feature and the target
    distances = np.linalg.norm(Y - target_feature, axis=1)
    # index of the nearest candidate
    nearest = np.argmin(distances)
    return X[nearest]


def add_noise_with_cosine_distance(
        feature: torch.Tensor,
        target_cos_dist: float
) -> tuple:
    """
    Add isotropic Gaussian noise so that the noised feature lies at roughly
    the requested cosine distance from the original feature.

    The noise standard deviation is chosen as
    ``std = ||f|| * sqrt((1 / c**2 - 1) / d)`` with ``c = 1 - target_cos_dist``
    and ``d`` the number of elements of ``feature``. The realised distance
    varies from draw to draw, so it is measured and returned as well.

    Args:
        feature (torch.Tensor): Feature to perturb (any shape).
        target_cos_dist (float): Desired cosine distance, in ``[0, 1)``.

    Returns:
        tuple: ``(noised_feature, std, cos_dist)`` where ``noised_feature`` has
        the shape of ``feature``, ``std`` is the noise standard deviation
        (float) and ``cos_dist`` is the measured cosine distance (float).

    Raises:
        ValueError: If ``target_cos_dist`` is outside ``[0, 1)``. A value of 1
            would need infinite noise (division by zero), and values outside
            ``[0, 1]`` make the formula produce NaN or a meaningless std.
    """
    if not 0.0 <= target_cos_dist < 1.0:
        raise ValueError(
            f"target_cos_dist must be in [0, 1), got {target_cos_dist}"
        )

    target_cos_sim = 1 - target_cos_dist
    flat = feature.flatten()
    norm = torch.linalg.norm(flat)
    d = flat.shape[0]
    std = norm * np.sqrt((1 / target_cos_sim**2 - 1) / d)
    noise = torch.randn_like(feature) * std
    noised_feature = feature + noise

    # measure the cosine distance that was actually realised by this draw
    actual_cos_sim = torch.nn.functional.cosine_similarity(
        flat, noised_feature.flatten(), dim=0
    ).item()
    cos_dist = 1 - actual_cos_sim

    return noised_feature, std.item(), cos_dist
    

def main(
        experimnet_name: str,
        n_neurons: int,
        tuning_width: float,
        n_samples: int,
        n_trial: int,
        target_cos_dists: list,
        seed: int
) -> None:
    """
    Perform a toy model experiment.

    Builds the neural manifold of a `NeuronLayer`, fits a 3-component PCA to
    it, perturbs the feature of the input x = 0.5 with noise at several target
    cosine distances, retrieves the nearest manifold input for each noised
    feature, and writes the data and figures to
    'output/toy_model/{experimnet_name}/':
    pca.pkl, X_Y.npz, readout_analysis.npz, readout_analysis.csv,
    tuning_curves.png, neural_manifold.png, neural_manifold_with_readout.png
    and feature_distance_and_error.png.

    Args:
        experimnet_name (str): Name of the experiment; used as the output
            sub-directory name. (The parameter name is misspelled, but it is
            passed by keyword from the script entry point.)
        n_neurons (int): Number of neurons
        tuning_width (float): Width of the tuning curves
        n_samples (int): Number of samples to calculate PCA
        n_trial (int): Number of trials to perform reconstruction
        target_cos_dists (list): Target cosine distances between the true
            feature and its noised versions; `n_trial` noised features are
            drawn for each value
        seed (int): Random seed
    """
    # Create the neuron layer and generate activity
    neuron_layer = NeuronLayer(n_neurons, tuning_width)

    # Calculate manifold and apply PCA
    X = torch.linspace(0, 1, n_samples)  # input samples
    Y = neuron_layer(X).detach().numpy()  # activity of the neurons (shape: (n_samples, n_neurons))
    X = X.detach().numpy()  
    pca = PCA(n_components=3)
    Y_pca = pca.fit_transform(Y)

    # prepare true feature
    true_x = 0.5
    true_feature = neuron_layer(torch.tensor([true_x]))
    # seed torch's RNG right before the noise is drawn
    torch.random.manual_seed(seed)  # for reproducibility

    # readout analysis: for each target distance, draw n_trial noised features
    # and retrieve the manifold input whose feature is nearest to each of them
    all_target_features = []
    results = []
    for target_cos_dist in tqdm.tqdm(target_cos_dists):
        for trial in range(n_trial):
            target_feature, std, cos_dist = add_noise_with_cosine_distance(true_feature, target_cos_dist)
            x_rec = retrieve(X, Y, target_feature.detach().numpy())
            error = x_rec - true_x  # signed retrieval error

            # store the results
            all_target_features.append(target_feature)
            results.append({
                'target_cos_dist': target_cos_dist,
                'cos_dist': cos_dist,
                'std': std,
                'trial': trial,
                'recon_x': x_rec.item(),
                'error': error.item(),
            })

    # turn results into pandas dataframe
    results = pd.DataFrame(results)

    # apply PCA to the true feature
    true_feature = true_feature.detach().cpu().numpy()
    true_feature_pca = pca.transform(true_feature.reshape(1, -1))
    true_feature_pca = true_feature_pca[0]

    # apply PCA to the target features (stacked in the same order as `results`)
    target_features = torch.vstack(all_target_features).detach().cpu().numpy()
    target_features_pca = pca.transform(target_features)

    # save the results
    save_dir = f'output/toy_model/{experimnet_name}/'
    os.makedirs(save_dir, exist_ok=True)

    # Save pca model
    with open(os.path.join(save_dir, 'pca.pkl'), 'wb') as f:
        pickle.dump(pca, f)
    # Save X, Y, and Y_pca
    np.savez(
        os.path.join(save_dir, 'X_Y.npz'),
        X=X,
        Y=Y,
        Y_pca=Y_pca
    )

    # save readout analysis results
    np.savez(
        os.path.join(save_dir, 'readout_analysis.npz'),
        true_feature=true_feature,
        true_feature_pca=true_feature_pca,
        target_features=target_features,
        target_features_pca=target_features_pca,
    )
    results.to_csv(
        os.path.join(save_dir, 'readout_analysis.csv'),
        index=False
    )

    # plot tuning curves (all neurons in gray, the middle-index neuron in red)
    fig, ax = plt.subplots(figsize=(6, 4))
    for i in range(n_neurons):
        ax.plot(X, Y[:, i], color='gray', alpha=0.5)
    i_half = n_neurons // 2
    ax.plot(X, Y[:, i_half], color='red')
    ax.set_xlabel('Input')
    ax.set_ylabel('Neuron Activity')
    ax.set_title('Tuning Curves of Neurons')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)
    plt.savefig(os.path.join(save_dir, 'tuning_curves.png'), dpi=300)
    plt.close(fig)

    # plot neural manifold in PCA space, colored by the input x
    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(Y_pca[:, 0], Y_pca[:, 1], Y_pca[:, 2], c=X, cmap='viridis', s=5)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('Manifold')
    cbar = plt.colorbar(ax.collections[0], ax=ax, pad=0.1)
    cbar.set_label('$x$')
    cbar.set_ticks([0, 0.5, 1])

    # get current axis limits so the readout plot below uses the same ranges
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    zlim = ax.get_zlim()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'neural_manifold.png'), dpi=300)
    plt.close(fig)

    # Plot manifold with readout representation
    # Set absolute error as color
    color = np.abs(results['error'].values)
    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(Y_pca[:, 0], Y_pca[:, 1], Y_pca[:, 2], c='gray', s=5, label='Manifold')
    ax.scatter(
        target_features_pca[:, 0], target_features_pca[:, 1], target_features_pca[:, 2], 
        c=color, s=5, label='Readout', cmap='Oranges_r'
    )
    ax.scatter(
        true_feature_pca[0], true_feature_pca[1], true_feature_pca[2],
        c='blue', s=5, label='True Feature'
    )
    # add colorbar (collections[1] is the 'Readout' scatter added second)
    cbar = plt.colorbar(ax.collections[1], ax=ax, pad=0.1)
    cbar.set_label('Error')
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('Manifold and readout representation')
    # reuse the limits of the manifold-only figure
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_zlim(zlim)
    ax.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'neural_manifold_with_readout.png'), dpi=300)
    plt.close(fig)

    # plot measured feature cosine distance against absolute retrieval error
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(results['cos_dist'], results['error'].abs(), alpha=0.5)
    ax.set_xlabel('Feature cosine distance')
    ax.set_ylabel('Absolute retrieval Error')
    ax.set_title('Feature distance and reconstruction error')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xticks(np.arange(0, 1.1, 0.25))
    ax.set_yticks(np.arange(0, 1.1, 0.25))
    ax.grid(True, which='major', linestyle='--', alpha=0.3, linewidth=0.5, color='gray')
    ax.set_aspect('equal', adjustable='box')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'feature_distance_and_error.png'), dpi=300)
    plt.close(fig)


if __name__ == '__main__':
    # Command-line entry point: parse the experiment settings and run main().
    parser = argparse.ArgumentParser(description='Toy model experiment')
    # Required: without a name the results would be written to
    # 'output/toy_model/None/'.
    parser.add_argument(
        '--experiment_name', type=str, required=True,
        help='Name of the experiment'
    )
    parser.add_argument('--n_neurons', default=100, type=int, help='Number of neurons')
    parser.add_argument('--tuning_width', default=0.1, type=float, help='Width of the tuning curves')
    parser.add_argument('--n_samples', default=100, type=int, help='Number of samples to calculate PCA')
    parser.add_argument('--n_trial', default=100, type=int, help='Number of trials to perform reconstruction')
    # nargs='+' with type=float parses "--target_cos_dists 0.1 0.5" into
    # [0.1, 0.5]; type=list would split a single string into characters.
    parser.add_argument(
        '--target_cos_dists',
        default=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        type=float, nargs='+',
        help='Target cosine distances (space-separated floats in [0, 1))'
    )
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    args = parser.parse_args()

    main(
        experimnet_name=args.experiment_name,
        n_neurons=args.n_neurons,
        tuning_width=args.tuning_width,
        n_samples=args.n_samples,
        n_trial=args.n_trial,
        target_cos_dists=args.target_cos_dists,
        seed=args.seed
    )
