import os
import torch
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from dataloader import SequentialLoader
from tester import Tester, NetworkTester, ReductionTester
from utils import set_deterministic


# Fixed seed: passed to set_deterministic() when --seed is given and used as
# the random_state of the 2-D reducer that prepares latents for plotting.
_seed_ = 2024
# Maps the requested number of classes (--classes) to the indices that are
# selected along the first axis of the loaded dataset array.
class_index_dict = {
    5: [19, 5, 39, 35, 26]
}


class MouseTester(Tester):
    """Score and visualise latents inferred for the mouse visual-stimulus dataset.

    The class relies on attributes and hooks that are not defined here
    (``self.args``, ``self._load_model``, ``self._transform``,
    ``self.suffix_for_score``, ``self.suffix_for_plot`` and
    ``self.model_name_for_plot``); they are presumably supplied by ``Tester``
    or a sibling mix-in -- confirm against ``tester.py``.

    Throughout, the first 5 time steps of every trial are discarded and, per
    class, the first 4/5 of the data is used for fitting while the remaining
    1/5 is held out.
    """

    def _set_output_dir(self):
        """Set ``self.output_dir`` from the command-line arguments.

        For the reduction models (PCA, TSNE, CEBRA_*) a directory tree is
        derived from the dataset/model settings and created if missing.  For
        every other model the checkpoint directory is reused.

        Raises:
            ValueError: if a non-reduction model is used without
                ``--checkpoint-path``.
        """
        args = self.args
        cebra_models = ["CEBRA_discrete_5", "CEBRA_time_1", "CEBRA_time_5", "CEBRA_time_10"]
        if args.model_name in ["PCA", "TSNE"] + cebra_models:
            mouse_dataset_name = f"mouse{args.mouse_id}_cls{args.classes}_ts{args.time_step}"
            suffix = ""
            seed_dir = ""
            if args.shuffle_trials:
                suffix += "_random_shuffle"
                seed_dir += f"seed{args.shuffle_seed}/"
            if args.sample_neurons:
                suffix += "_random_sample"
                seed_dir += f"neuron{args.sample_neurons}_seed{args.sample_seed}/"
            suffix += "/"
            mouse_dataset_name += f"{suffix}{seed_dir}"

            self.output_dir = os.path.join(
                args.output_path,
                "Mouse_Visual_Data",
                args.stimulus,
                mouse_dataset_name,
                f"{args.model_name}_ld{args.latent_dim}",
            )
            if args.model_name in cebra_models:
                # CEBRA runs are additionally keyed by their training hyper-parameters.
                self.output_dir = os.path.join(
                    self.output_dir,
                    f"iterations{args.iterations}_bs{args.batch_size}_lr{args.lr}",
                )
            os.makedirs(self.output_dir, exist_ok=True)
        else:
            # An explicit exception instead of ``assert`` so the check
            # survives ``python -O``.
            if args.checkpoint_path is None:
                raise ValueError("--checkpoint-path is required for this model")
            self.output_dir = args.checkpoint_path

    def _compute_score(self, z_predict, label, mode="unite"):
        """Fit a KNN classifier on the latents and return its accuracy.

        Args:
            z_predict: latent array that can be reshaped to one row per trial
                and time step (or per sliding window for ``lfads`` /
                ``tidespl_vae``).
            label: 1-D array with one class label per time step.
            mode: ``"unite"`` concatenates all kept time steps of a trial into
                one sample; ``"separate"`` treats every kept time step as its
                own sample.

        Returns:
            Tuple ``(score, label_predict)`` for the split chosen by
            ``args.data_split`` (``"all"``, ``"train"`` or ``"test"``).

        Raises:
            ValueError: on an unknown ``mode`` or ``args.data_split``.
        """
        args = self.args
        n_trials_total = args.classes * args.trials

        if args.model_name in ["lfads", "tidespl_vae"]:
            # Sequential models emit one latent sequence per sliding window;
            # keep only the latent of the last position of each window.
            n_windows = args.time_step - args.seq_len + 1
            z_predict = np.reshape(z_predict, (n_trials_total * n_windows, args.seq_len, -1))
            z_predict = np.reshape(z_predict[:, -1], (n_trials_total, n_windows, -1))
            # Drop windows ending within the first 5 time steps.  The start is
            # clamped at 0: for seq_len > 6 a negative start would silently
            # keep only the trailing windows.
            z_predict = z_predict[:, max(0, 5 - args.seq_len + 1):]
        else:
            z_predict = np.reshape(z_predict, (n_trials_total, args.time_step, -1))
            z_predict = z_predict[:, 5:]

        if args.model_name in ["swap_vae", "tidespl_vae"]:
            # These models split the latent vector into two halves.
            if args.latent_space == "content":
                z_predict = z_predict[..., :args.latent_dim // 2]
            elif args.latent_space == "style":
                z_predict = z_predict[..., args.latent_dim // 2:]

        if mode == "unite":
            z_predict = np.reshape(z_predict, (n_trials_total, -1))
            # One label per trial: take the label of its first time step.
            label = label[np.arange(0, len(label), args.time_step)]

            n_train_trials = args.trials // 5 * 4
            class_start = np.arange(0, n_trials_total, args.trials)
            train_idx = (np.arange(n_train_trials)[:, np.newaxis] + class_start).flatten()
            # Trial-major order: halving the flat array splits the held-out
            # trials, not the classes.
            holdout_idx = (np.arange(n_train_trials, args.trials)[:, np.newaxis] + class_start).flatten()
            half = len(holdout_idx) // 2
            val_idx = holdout_idx[:half]
            test_idx = holdout_idx[half:]
        elif mode == "separate":
            steps = args.time_step - 5
            z_predict = np.reshape(z_predict, (n_trials_total * steps, -1))
            label = np.reshape(label, (n_trials_total, args.time_step))
            label = np.reshape(label[:, 5:], (n_trials_total * steps))

            per_class = args.trials * steps
            n_train = per_class // 5 * 4
            class_start = np.arange(0, len(label), per_class)[:, np.newaxis]
            train_idx = (class_start + np.arange(n_train)).flatten()
            # Held-out indices are class-major here, so the validation/test
            # halves are taken per class.  Halving the flattened array would
            # put the first classes in validation and the last ones in test.
            holdout = class_start + np.arange(n_train, per_class)
            half = holdout.shape[1] // 2
            val_idx = holdout[:, :half].flatten()
            test_idx = holdout[:, half:].flatten()
        else:
            raise ValueError(f"Unknown mode: {mode!r} (expected 'unite' or 'separate')")

        # Min-max normalise every feature.  Constant features have a zero
        # range; dividing by it would yield NaN, which KNN rejects, so their
        # range is replaced by 1 (the feature then becomes all zeros).
        z_min = np.min(z_predict, axis=0, keepdims=True)
        z_range = np.max(z_predict, axis=0, keepdims=True) - z_min
        z_range[z_range == 0] = 1
        z_predict = (z_predict - z_min) / z_range

        # Samples used to choose n_neighbors and to report the final score.
        if args.data_split == "all":
            select_idx = final_idx = slice(None)
        elif args.data_split == "train":
            select_idx = final_idx = train_idx
        elif args.data_split == "test":
            select_idx, final_idx = val_idx, test_idx
        else:
            raise ValueError(f"Unknown data split: {args.data_split!r}")

        z_train = z_predict[train_idx]
        label_train = label[train_idx]

        max_score = 0
        best_n = 3
        for n in [3, 5, 7, 9, 11, 13, 17, 19]:
            knn_model = KNeighborsClassifier(n_neighbors=n)
            knn_model.fit(z_train, label_train)
            score = knn_model.score(z_predict[select_idx], label[select_idx])

            if score > max_score:
                max_score = score
                best_n = n

        knn_model = KNeighborsClassifier(n_neighbors=best_n)
        knn_model.fit(z_train, label_train)
        label_predict = knn_model.predict(z_predict[final_idx])
        score = knn_model.score(z_predict[final_idx], label[final_idx])

        return score, label_predict

    def test(self, x_true, u_true):
        """Score ``args.repeat`` models, save the scores and optionally plot.

        Args:
            x_true: input data forwarded to ``self._transform``.
            u_true: label tensor; converted with ``.cpu().numpy()`` for scoring.

        Side effects:
            Writes ``classification_score{suffix}.npy`` into ``self.output_dir``
            and, with ``--plot``, the figure produced by :meth:`plot` for the
            best-scoring repeat.
        """
        z_best_predict = None
        max_score = -np.inf
        scores = np.zeros(self.args.repeat)
        labels_predict = []
        for i in range(self.args.repeat):
            print(f"Start computing latent [repeat {i}]...")
            self._load_model(repeat=i)
            z_predict = self._transform(x_true, u_true, repeat=i)

            print("Start testing...")
            score, label_predict = self._compute_score(z_predict, u_true.cpu().numpy())
            # NOTE(review): for the "train"/"test" splits the predictions come
            # back trial-major, so this reshape does not group them by class.
            # Confirm before re-enabling the disabled save below.
            labels_predict.append(label_predict.reshape((self.args.classes, -1)))
            print(f"Classification Score: {score}")
            scores[i] = score
            if score > max_score:
                max_score = score
                z_best_predict = z_predict
            print()
        print(f"Average Classification Score: {np.mean(scores)}")
        np.save(os.path.join(self.output_dir, f"classification_score{self.suffix_for_score}.npy"), scores)
        # Optional dumps, disabled by default:
        # labels_predict = np.stack(labels_predict, axis=0)
        # np.save(os.path.join(self.output_dir, f"labels_predict{self.suffix_for_score}.npy"), labels_predict)
        # np.save(os.path.join(self.output_dir, f"latent{self.suffix_for_score}.npy"), z_best_predict)
        if self.args.plot:
            self.plot(z_best_predict, u_true.cpu().numpy())

    def plot(self, z_predict, u_true):
        """Reduce the latents to 2-D and save a 2x3 figure (train row, test row).

        Columns show single time points, per-trial trajectories and
        trial-averaged trajectories.  The figure is written to
        ``z{suffix_for_plot}.png`` in ``self.output_dir``.

        NOTE(review): the trajectory panels appear to assume that
        ``args.trials`` is a multiple of 5, so that each per-class train/test
        block holds whole trials -- confirm.

        Args:
            z_predict: latent array in the layout accepted by
                :meth:`_compute_score`.
            u_true: 1-D integer label array with one label per time step.

        Raises:
            ValueError: if ``args.latent_reduc`` names an unknown reducer.
        """
        print("Start ploting...")
        args = self.args
        n_trials_total = args.classes * args.trials
        steps = args.time_step - 5  # time steps kept per trial
        half_dim = args.latent_dim // 2

        if args.model_name in ["lfads", "tidespl_vae"]:
            n_windows = args.time_step - args.seq_len + 1
            z_predict = np.reshape(z_predict, (n_trials_total * n_windows, args.seq_len, -1))
            if args.model_name in ["tidespl_vae"]:
                if args.latent_space == "content":
                    z_predict = z_predict[..., :half_dim]
                elif args.latent_space == "style":
                    z_predict = z_predict[..., half_dim:]
            z_predict = np.reshape(z_predict[:, -1], (n_trials_total, n_windows, -1))
            # Clamp the start at 0: a negative start (seq_len > 6) would keep
            # only the trailing windows.
            z_predict = np.reshape(z_predict[:, max(0, 5 - args.seq_len + 1):], (n_trials_total * steps, -1))
        else:
            z_predict = np.reshape(z_predict, (n_trials_total, args.time_step, -1))
            if args.model_name in ["swap_vae"]:
                if args.latent_space == "content":
                    z_predict = z_predict[..., :half_dim]
                elif args.latent_space == "style":
                    z_predict = z_predict[..., half_dim:]
            z_predict = np.reshape(z_predict[:, 5:], (n_trials_total * steps, -1))

        u_true = np.reshape(u_true, (n_trials_total, args.time_step))
        u_true = np.reshape(u_true[:, 5:], (n_trials_total * steps))

        # Per class: first 4/5 of the samples for train, the rest for test.
        num_sample = len(u_true)
        per_class = args.trials * steps
        n_train = per_class // 5 * 4
        class_start = np.arange(0, num_sample, per_class)[:, np.newaxis]
        train_idx = (class_start + np.arange(n_train)).flatten()
        test_idx = (class_start + np.arange(n_train, per_class)).flatten()

        # Explicit lookup instead of eval(): only known reducers can be built.
        reducers = {"PCA": PCA, "TSNE": TSNE}
        if args.latent_reduc not in reducers:
            raise ValueError(f"Unknown latent reduction: {args.latent_reduc!r}")
        reduc = reducers[args.latent_reduc](n_components=2, random_state=_seed_)
        # The reducer is fitted on all samples so both rows share one embedding.
        z_predict_reduc = reduc.fit_transform(z_predict)
        # np.save(os.path.join(self.output_dir, f"reduce_latent_all_{self.args.latent_reduc}{self.suffix_for_score}.npy"), z_predict_reduc)

        z_predict_reduc_train = z_predict_reduc[train_idx]
        z_predict_reduc_test = z_predict_reduc[test_idx]

        u_true_train = u_true[train_idx]
        u_true_test = u_true[test_idx]

        my_color = plt.cm.rainbow(np.linspace(0, 1, args.classes))
        color_train = my_color[u_true_train]
        color_test = my_color[u_true_test]

        fig, ax = plt.subplots(2, 3, figsize=(12, 8), dpi=150, gridspec_kw={"hspace": 0.3})
        ax = ax.flatten()

        # Column 1: every (trial, time step) sample as a single point.
        ax[0].scatter(z_predict_reduc_train[:, 0], z_predict_reduc_train[:, 1], c=color_train, s=2, alpha=0.5)
        ax[0].set_title(f"{self.model_name_for_plot} (train)\nOne Time Point for One Trial")
        ax[3].scatter(z_predict_reduc_test[:, 0], z_predict_reduc_test[:, 1], c=color_test, s=2, alpha=0.5)
        ax[3].set_title(f"{self.model_name_for_plot} (test)\nOne Time Point for One Trial")
        ax[3].set_xlabel("Latent 1")
        ax[3].set_ylabel("Latent 2")

        # Column 2: one trajectory per trial ('s' marks the start, '*' the end).
        for i in range(args.classes):
            for j in range(args.trials // 5 * 4):
                l_index = steps * j + n_train * i
                r_index = l_index + steps
                ax[1].plot(z_predict_reduc_train[l_index: r_index, 0], z_predict_reduc_train[l_index: r_index, 1], c=my_color[i], marker='*', markersize=3, markevery=[-1], alpha=0.3)
                ax[1].plot(z_predict_reduc_train[l_index, 0], z_predict_reduc_train[l_index, 1], c=my_color[i], marker='s', markersize=1, alpha=0.3)
            for j in range(args.trials // 5):
                l_index = steps * j + per_class // 5 * i
                r_index = l_index + steps
                ax[4].plot(z_predict_reduc_test[l_index: r_index, 0], z_predict_reduc_test[l_index: r_index, 1], c=my_color[i], marker='*', markersize=3, markevery=[-1], alpha=0.3)
                ax[4].plot(z_predict_reduc_test[l_index, 0], z_predict_reduc_test[l_index, 1], c=my_color[i], marker='s', markersize=1, alpha=0.3)
        ax[1].set_title(f"{self.model_name_for_plot} (train)\nTrajectory for One Trial")
        ax[4].set_title(f"{self.model_name_for_plot} (test)\nTrajectory for One Trial")

        # Column 3: trajectories averaged over the trials of each class.
        z_predict_reduc_train_mean = np.mean(np.reshape(z_predict_reduc_train, (args.classes, args.trials // 5 * 4, steps, -1)), axis=1)
        z_predict_reduc_test_mean = np.mean(np.reshape(z_predict_reduc_test, (args.classes, args.trials // 5, steps, -1)), axis=1)
        for i in range(args.classes):
            ax[2].plot(z_predict_reduc_train_mean[i, :, 0], z_predict_reduc_train_mean[i, :, 1], c=my_color[i], marker='*', markersize=6, markevery=[-1])
            ax[2].plot(z_predict_reduc_train_mean[i, 0, 0], z_predict_reduc_train_mean[i, 0, 1], c=my_color[i], marker='s', markersize=3)
            ax[2].scatter(z_predict_reduc_train_mean[i, 1:-1, 0], z_predict_reduc_train_mean[i, 1:-1, 1], color=my_color[i], marker='o', s=5)
            ax[5].plot(z_predict_reduc_test_mean[i, :, 0], z_predict_reduc_test_mean[i, :, 1], c=my_color[i], marker='*', markersize=6, markevery=[-1])
            ax[5].plot(z_predict_reduc_test_mean[i, 0, 0], z_predict_reduc_test_mean[i, 0, 1], c=my_color[i], marker='s', markersize=3)
            ax[5].scatter(z_predict_reduc_test_mean[i, 1:-1, 0], z_predict_reduc_test_mean[i, 1:-1, 1], color=my_color[i], marker='o', s=5)
        ax[2].set_title(f"{self.model_name_for_plot} (train)\nTrajectory for Average-Trial")
        ax[5].set_title(f"{self.model_name_for_plot} (test)\nTrajectory for Average-Trial")

        plt.savefig(os.path.join(self.output_dir, f"z{self.suffix_for_plot}.png"))
        # Release the figure; pyplot otherwise keeps it alive until exit.
        plt.close(fig)


class MouseNetworkTester(MouseTester, NetworkTester):
    """Tester for the network models (vae, pivae, swap_vae, lfads, tidespl_vae)."""

    def _load_dataloader(self, x_true, u_true):
        """Build the dataset and loader over the complete data.

        Args:
            x_true: input tensor.
            u_true: label tensor.

        Returns:
            Tuple ``(all_set, all_loader)``.

        Raises:
            ValueError: if ``args.model_name`` is not one of the supported
                network models (previously this surfaced as an unbound-local
                error at the ``return``).
        """
        model_name = self.args.model_name
        if model_name in ["vae", "pivae", "swap_vae"]:
            all_set = torch.utils.data.TensorDataset(x_true, u_true)
            all_loader = torch.utils.data.DataLoader(all_set, batch_size=self.args.batch_size)
        elif model_name in ["lfads", "tidespl_vae"]:
            # Batching is handed to SequentialLoader (it receives batch_size),
            # so the DataLoader keeps its default batch size.
            all_set = SequentialLoader(x_true, u_true, self.args.batch_size, self.args.time_step, self.args.seq_len, interval=1, cutoff=True)
            all_loader = torch.utils.data.DataLoader(all_set)
        else:
            raise ValueError(f"Unsupported model name for dataloader: {model_name!r}")

        return all_set, all_loader

    def _preprocess_inputs(self, inputs):
        """Turn one loader batch into the keyword dict for the selected model.

        Args:
            inputs: batch from the loader; ``inputs[0]`` is the data and
                ``inputs[1]`` the labels.

        Returns:
            Dict of tensors moved to ``self.device``.  For a model name not
            handled below, ``inputs`` is returned unchanged.
        """
        model_name = self.args.model_name
        if model_name in ["vae"]:
            inputs = {"x": inputs[0].to(self.device)}
        elif model_name in ["lfads"]:
            # Drop the leading DataLoader axis and swap the next two axes.
            # NOTE(review): presumably (batch, seq, dim) -> (seq, batch, dim); confirm.
            inputs = {"x": inputs[0].squeeze(0).permute(1, 0, 2).to(self.device)}
        elif model_name in ["pivae"]:
            inputs = {"x": inputs[0].to(self.device), "u": inputs[1].to(self.device)}
        elif model_name in ["swap_vae"]:
            # Both views are the unaugmented input at test time.
            inputs = {"x1": inputs[0].to(self.device), "x2": inputs[0].to(self.device)}
        elif model_name in ["tidespl_vae"]:
            # Anchor, positive and negative are all the unaugmented input; the
            # view is built once but each entry still gets its own .to() call.
            sequence = inputs[0].squeeze(0).permute(1, 0, 2)
            inputs = {
                "x": sequence.to(self.device),
                "x_pos": sequence.to(self.device),
                "x_neg": sequence.to(self.device),
            }

        return inputs


class MouseReductionTester(MouseTester, ReductionTester):
    """Tester for the reduction models (PCA, TSNE, CEBRA_*)."""

    def _load_train_set(self, x_true, u_true):
        """Convert the tensors to NumPy and pick the training subset.

        Per class, the first 4/5 of the ``trials * time_step`` samples are
        selected.  Labels are cast to float for the ``CEBRA_time_*`` models.

        NOTE(review): the selected rows come out interleaved across classes
        (sample 0 of every class, then sample 1, ...), not class by class --
        confirm that downstream fitting does not depend on row order.

        Returns:
            Tuple ``(x_true, u_true, x_train, u_train)`` of NumPy arrays.
        """
        features = x_true.numpy()
        targets = u_true.numpy()
        if self.args.model_name in ("CEBRA_time_1", "CEBRA_time_5", "CEBRA_time_10"):
            targets = targets.astype("float")

        samples_per_class = self.args.trials * self.args.time_step
        train_per_class = samples_per_class // 5 * 4
        class_offsets = np.arange(0, len(features), samples_per_class)
        within_class = np.arange(train_per_class)[:, np.newaxis]
        selected = (within_class + class_offsets).flatten()

        return features, targets, features[selected], targets[selected]


def get_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings.  ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing a list allows programmatic use.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # Dataset selection.
    parser.add_argument("--data-path", default="neural_dataset/", type=str, help="dataset path")
    parser.add_argument("--stimulus", default="allen_natural_scenes", type=str, help="type of visual stimulus")
    parser.add_argument("--time-step", default=25, type=int, help="time steps")
    parser.add_argument("--mouse-id", default=16, type=int, help="mouse id")
    parser.add_argument("--classes", default=5, type=int, help="number of classes")

    # Dataset variants (shuffled trials / sub-sampled neurons).
    parser.add_argument("--shuffle-trials", action="store_true", help="whether to use shuffled trials")
    parser.add_argument("--shuffle-seed", default=1111, type=int, help="random seed for shuffling")
    parser.add_argument("--sample-neurons", default=0, type=int, help="number of sampled neurons ('0' indicates no sampling)")
    parser.add_argument("--sample-seed", default=1111, type=int, help="random seed for sampling")

    # Model and latent space.
    parser.add_argument("--model-name", default="tidespl_vae", type=str, help="name of model to train")
    parser.add_argument("--latent-dim", default=128, type=int, help="latent variable dimension")
    parser.add_argument("--latent-space", default="full", type=str, choices=["full", "content", "style"], help="latent variable space")
    parser.add_argument("--latent-reduc", default="PCA", type=str, choices=["PCA", "TSNE"], help="dimension reduction of latent variable space")
    parser.add_argument("--checkpoint-path", default=None, help="path to load checkpoint")
    parser.add_argument("--seq-len", default=5, type=int, help="sequence length for sequential model")

    # Training hyper-parameters (used to locate CEBRA output directories).
    parser.add_argument("--iterations", default=5000, type=int, help="number of iterations to train")
    parser.add_argument("--lr", default=1e-4, type=float, help="initial learning rate")

    # Evaluation.
    parser.add_argument("--data-split", default="test", type=str, choices=["all", "train", "test"], help="split part of the dataset")
    parser.add_argument("--batch-size", default=250, type=int, help="number of images per gpu")

    # Run control and outputs.
    parser.add_argument("--seed", action="store_true", help="whether to use random seed")
    parser.add_argument("--repeat", default=10, type=int, help="experiment for repeat")
    parser.add_argument("--plot", action="store_true", help="plot low dimension latent")
    parser.add_argument("--output-path", default="logs/", help="path to save outputs")
    parser.add_argument("--device", default="cuda:0", help="device for train and test")
    return parser.parse_args(argv)


def load_mouse_dataset(args):
    """Load the recordings selected by ``args`` and build per-sample labels.

    The ``.npy`` file name is derived from the mouse id, the time-step count
    and the optional shuffle/sampling suffixes.  Only the entries listed in
    ``class_index_dict[args.classes]`` are kept along the first axis; the
    result is flattened to one row per sample with the last axis preserved.

    Returns:
        Tuple ``(x_true, u_true, num_trials, num_neurons)``: ``x_true`` and
        ``u_true`` are tensors created with ``torch.from_numpy``,
        ``num_trials`` is the size of axis 1 of the loaded array and
        ``num_neurons`` the size of its last axis.
    """
    name_parts = [f"mouse{args.mouse_id}_ts{args.time_step}"]
    if args.shuffle_trials:
        name_parts.append(f"_trial_seed{args.shuffle_seed}")
    if args.sample_neurons:
        name_parts.append(f"_neuron{args.sample_neurons}_seed{args.sample_seed}")
    name_parts.append(".npy")
    file_path = os.path.join(args.data_path, args.stimulus, "".join(name_parts))

    recordings = np.load(file_path)
    num_trials = recordings.shape[1]
    num_neurons = recordings.shape[-1]

    selected = recordings[class_index_dict[args.classes]]
    samples = selected.reshape(args.classes, -1, num_neurons).reshape(-1, num_neurons)
    # Labels 0..classes-1, each repeated once per (trial, time step) sample.
    labels = np.repeat(np.arange(0, args.classes), num_trials * args.time_step)

    return torch.from_numpy(samples), torch.from_numpy(labels), num_trials, num_neurons


def main():
    """Parse the options, load the dataset and run the matching tester."""
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    x_true, u_true, num_trials, num_neurons = load_mouse_dataset(args)
    # The testers read the dataset dimensions from the argument namespace.
    args.trials = num_trials
    args.data_dim = num_neurons
    print(f"Dataset Size: {x_true.size(0)}")

    reduction_models = ("PCA", "TSNE", "CEBRA_discrete_5", "CEBRA_time_1", "CEBRA_time_5", "CEBRA_time_10")
    tester_cls = MouseReductionTester if args.model_name in reduction_models else MouseNetworkTester
    mouse_data_tester = tester_cls(args.latent_dim, args)

    mouse_data_tester.test(x_true, u_true)


# Run the evaluation only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
