import os
import torch
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from dataloader import SequentialLoader
from tester import Tester, NetworkTester, ReductionTester
from utils import set_deterministic


# Fixed seed handed to set_deterministic() when the --seed flag is given.
_seed_ = 2024


class MouseTester(Tester):
    """Score latent representations of mouse visual data by k-NN decoding.

    Latents are averaged over the time steps of each sample, min-max scaled
    per dimension, and scored by how often a k-nearest-neighbour classifier
    predicts a label within ``time_window * 30`` of the true label.
    """

    def _set_output_dir(self):
        """Set ``self.output_dir`` for the configured model.

        Reduction-style models (PCA, TSNE, CEBRA variants) get a directory
        derived from the dataset and model options, which is created if it is
        missing. Every other model writes next to its checkpoint, so
        ``args.checkpoint_path`` must be provided.
        """
        cebra_models = ["CEBRA_time_delta_4", "CEBRA_time_4"]
        if self.args.model_name in ["PCA", "TSNE"] + cebra_models:
            mouse_dataset_name = f"mouse{self.args.mouse_id}_ts{self.args.time_step}"
            suffix = ""
            seed_dir = ""
            if self.args.shuffle_trials:
                suffix += "_random_shuffle"
                seed_dir += f"seed{self.args.shuffle_seed}/"
            if self.args.sample_neurons:
                suffix += "_random_sample"
                seed_dir += f"neuron{self.args.sample_neurons}_seed{self.args.sample_seed}/"
            suffix += "/"
            mouse_dataset_name += f"{suffix}{seed_dir}"

            self.output_dir = os.path.join(
                self.args.output_path,
                "Mouse_Visual_Data",
                self.args.stimulus,
                mouse_dataset_name,
                f"{self.args.model_name}_ld{self.args.latent_dim}",
            )
            if self.args.model_name in cebra_models:
                # CEBRA results additionally depend on the training hyper-parameters.
                self.output_dir = os.path.join(
                    self.output_dir,
                    f"iterations{self.args.iterations}_bs{self.args.batch_size}_lr{self.args.lr}",
                )
            os.makedirs(self.output_dir, exist_ok=True)
        else:
            assert self.args.checkpoint_path is not None, \
                "--checkpoint-path is required for network models"
            self.output_dir = self.args.checkpoint_path

    def _compute_score(self, z_predict, label):
        """Return the k-NN decoding accuracy of ``z_predict`` against ``label``.

        Args:
            z_predict: array of latents that can be reshaped to
                ``(trials * classes, time_step, dim)``; for ``lfads`` and
                ``tidespl_vae`` there is an extra ``seq_len`` axis of which
                only the last position is kept.
            label: 1-D array with one label per time step; every
                ``time_step``-th entry is used as the label of a sample.

        Returns:
            Fraction of evaluated samples whose predicted label lies within
            ``time_window * 30`` of the true label.

        Raises:
            ValueError: if ``args.data_split`` is not ``all``, ``train`` or
                ``test``.
        """
        args = self.args
        num_samples = args.trials * args.classes
        if args.model_name in ["lfads", "tidespl_vae"]:
            # Sequential models emit one latent per sequence position; keep the last one.
            z_predict = np.reshape(z_predict, (num_samples * args.time_step, args.seq_len, -1))
            z_predict = np.reshape(z_predict[:, -1], (num_samples, args.time_step, -1))
        else:
            z_predict = np.reshape(z_predict, (num_samples, args.time_step, -1))
        z_predict = np.mean(z_predict, axis=1)

        if args.model_name in ["swap_vae", "tidespl_vae"]:
            # These models split the latent vector into two halves.
            if args.latent_space == "content":
                z_predict = z_predict[..., :args.latent_dim // 2]
            elif args.latent_space == "style":
                z_predict = z_predict[..., args.latent_dim // 2:]

        label = label[np.arange(0, len(label), args.time_step)]

        # Min-max scale every latent dimension. A constant dimension has zero
        # range; dividing by it would produce NaN/inf and make the k-NN fit
        # fail, so such dimensions are mapped to 0 instead.
        z_min = np.min(z_predict, axis=0, keepdims=True)
        z_range = np.max(z_predict, axis=0, keepdims=True) - z_min
        z_range[z_range == 0] = 1
        z_predict = (z_predict - z_min) / z_range

        # First 4/5 of the samples train the classifier; the remainder is
        # halved into validation (model selection) and test (reporting).
        split_point = len(label) // 5 * 4
        train_idx = np.arange(0, split_point)
        holdout_idx = np.arange(split_point, len(label))
        val_idx = holdout_idx[:len(holdout_idx) // 2]
        test_idx = holdout_idx[len(holdout_idx) // 2:]

        if args.data_split == "test":
            best_n = 1
            n_choice = [1, 3, 5, 7, 9, 11, 13, 17, 19]
            select_idx, report_idx = val_idx, test_idx
        elif args.data_split == "train":
            best_n = 3
            n_choice = [3, 5, 7, 9, 11, 13, 17, 19]
            select_idx = report_idx = train_idx
        elif args.data_split == "all":
            best_n = 3
            n_choice = [3, 5, 7, 9, 11, 13, 17, 19]
            select_idx = report_idx = slice(None)
        else:
            raise ValueError(f"Unsupported data split: {args.data_split!r}")

        # NOTE(review): 30 is presumably the stimulus frame rate (frames per
        # second), making the tolerance a number of frames - confirm.
        tolerance = args.time_window * 30

        def accuracy(n_neighbors, idx):
            # Fit on the training split and score the samples selected by idx.
            knn_model = KNeighborsClassifier(n_neighbors=n_neighbors)
            knn_model.fit(z_predict[train_idx], label[train_idx])
            label_errors = knn_model.predict(z_predict[idx]) - label[idx]
            return np.sum(np.abs(label_errors) <= tolerance) / len(label_errors)

        # Pick the neighbour count with the best selection-split accuracy;
        # ties keep the smaller (earlier) value.
        max_score = 0
        for n in n_choice:
            score = accuracy(n, select_idx)
            if score > max_score:
                max_score = score
                best_n = n

        return accuracy(best_n, report_idx)

    def test(self, x_true, u_true):
        """Score every repeat, print the results and save the score array.

        For each of ``args.repeat`` repetitions the model is loaded, the
        latents are computed and scored with ``_compute_score``. The scores
        are saved as ``classification_score<suffix>_tw<time_window>.npy`` in
        ``self.output_dir``.

        Args:
            x_true: input data passed through to ``_transform``.
            u_true: torch tensor of labels; converted to NumPy for scoring.
        """
        # NOTE(review): _load_model, _transform and suffix_for_score are not
        # defined in this class - presumably provided by the base testers.
        z_best_predict = None
        max_score = -np.inf
        scores = np.zeros(self.args.repeat)
        for i in range(self.args.repeat):
            print(f"Start computing latent [repeat {i}]...")
            self._load_model(repeat=i)
            z_predict = self._transform(x_true, u_true, repeat=i)

            print("Start testing...")
            score = self._compute_score(z_predict, u_true.cpu().numpy())
            print(f"Classification Score: {score}")
            scores[i] = score
            if score > max_score:
                max_score = score
                z_best_predict = z_predict
            print()
        print(f"Average Classification Score: {np.mean(scores)}")
        np.save(os.path.join(self.output_dir, f"classification_score{self.suffix_for_score}_tw{self.args.time_window}.npy"), scores)
        # Saving the best-scoring latent is currently disabled:
        # np.save(os.path.join(self.output_dir, f"latent{self.suffix_for_score}_tw{self.args.time_window}.npy"), z_best_predict)


class MouseNetworkTester(MouseTester, NetworkTester):
    """Tester for the neural-network models (vae, pivae, swap_vae, lfads, tidespl_vae)."""

    def _load_dataloader(self, x_true, u_true):
        """Build the dataset and loader that cover the whole recording.

        Args:
            x_true: 2-D tensor of inputs (samples along dim 0).
            u_true: 1-D tensor of labels aligned with ``x_true``.

        Returns:
            Tuple ``(all_set, all_loader)``.

        Raises:
            ValueError: if ``args.model_name`` is not a supported network model.
        """
        model_name = self.args.model_name
        if model_name in ["vae", "pivae", "swap_vae"]:
            all_set = torch.utils.data.TensorDataset(x_true, u_true)
            all_loader = torch.utils.data.DataLoader(all_set, batch_size=self.args.batch_size)
        elif model_name in ["lfads", "tidespl_vae"]:
            # Prepend seq_len - 1 zero rows/labels in front of the data
            # (presumably so each original sample can end a full-length window).
            pad_len = self.args.seq_len - 1
            x_true = torch.cat([torch.zeros(pad_len, x_true.size(1)).to(x_true), x_true], dim=0)
            u_true = torch.cat([torch.zeros(pad_len).to(u_true), u_true])
            all_set = SequentialLoader(x_true, u_true, self.args.batch_size, self.args.time_step, self.args.seq_len, interval=1)
            # The batch size is passed to SequentialLoader, so the DataLoader
            # keeps its default settings.
            all_loader = torch.utils.data.DataLoader(all_set)
        else:
            # Previously this fell through to the return statement and failed
            # with an UnboundLocalError.
            raise ValueError(f"Unsupported model name for network testing: {model_name!r}")

        return all_set, all_loader

    def _preprocess_inputs(self, inputs):
        """Convert a loader batch into the keyword dict expected by the model.

        Args:
            inputs: batch from the loader; ``inputs[0]`` holds the data and
                ``inputs[1]`` the labels.

        Returns:
            Dict of tensors moved to ``self.device``; for an unrecognised
            model name the batch is returned unchanged.
        """
        model_name = self.args.model_name

        def as_sequence():
            # Drop the leading DataLoader batch axis and swap the next two axes.
            return inputs[0].squeeze(0).permute(1, 0, 2).to(self.device)

        if model_name in ["vae"]:
            inputs = {"x": inputs[0].to(self.device)}
        elif model_name in ["lfads"]:
            inputs = {"x": as_sequence()}
        elif model_name in ["pivae"]:
            inputs = {"x": inputs[0].to(self.device), "u": inputs[1].to(self.device)}
        elif model_name in ["swap_vae"]:
            # Both views are the same batch at test time.
            inputs = {"x1": inputs[0].to(self.device), "x2": inputs[0].to(self.device)}
        elif model_name in ["tidespl_vae"]:
            # Anchor, positive and negative are all the same batch at test time.
            inputs = {"x": as_sequence(), "x_pos": as_sequence(), "x_neg": as_sequence()}

        return inputs


class MouseReductionTester(MouseTester, ReductionTester):
    """Tester for the reduction-style models (PCA, TSNE and the CEBRA variants)."""

    def _load_train_set(self, x_true, u_true):
        """Return the full data together with its leading 4/5 training part.

        For the CEBRA models the labels are replaced by features loaded from
        ``dino_feature.pt``, repeated ``time_step`` times per row and then
        tiled ``trials`` times. For all other models the tensors are
        converted to NumPy arrays.

        Returns:
            Tuple ``(x_true, u_true, x_train, u_train)``.
        """
        if self.args.model_name in ["CEBRA_time_delta_4", "CEBRA_time_4"]:
            feature_path = os.path.join(self.args.data_path, self.args.stimulus, "dino_feature.pt")
            features = torch.load(feature_path)
            # Repeat each feature row once per time step, then flatten to 2-D.
            features = features[:, None, :].repeat((1, self.args.time_step, 1))
            features = features.flatten(0, 1)
            # Tile the per-trial block for every trial.
            u_true = features.repeat((self.args.trials, 1))
        else:
            x_true, u_true = x_true.numpy(), u_true.numpy()

        train_end = len(x_true) // 5 * 4
        return x_true, u_true, x_true[:train_end], u_true[:train_end]


def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    import argparse

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = [
        # Dataset selection.
        ("--data-path", dict(default="neural_dataset/", type=str, help="dataset path")),
        ("--stimulus", dict(default="allen_natural_movie_one", type=str, help="type of visual stimulus")),
        ("--time-step", dict(default=4, type=int, help="time steps")),
        ("--mouse-id", dict(default=16, type=int, help="mouse id")),
        # Dataset variants.
        ("--shuffle-trials", dict(action="store_true", help="whether to use shuffled trials")),
        ("--shuffle-seed", dict(default=1111, type=int, help="random seed for shuffling")),
        ("--sample-neurons", dict(default=0, type=int, help="number of sampled neurons ('0' indicates no sampling)")),
        ("--sample-seed", dict(default=1111, type=int, help="random seed for sampling")),
        # Model and latent space.
        ("--model-name", dict(default="tidespl_vae", type=str, help="name of model to train")),
        ("--latent-dim", dict(default=128, type=int, help="latent variable dimension")),
        ("--latent-space", dict(default="full", type=str, choices=["full", "content", "style"], help="latent variable space")),
        ("--latent-reduc", dict(default="PCA", type=str, choices=["PCA", "TSNE"], help="dimension reduction of latent variable space")),
        ("--checkpoint-path", dict(default=None, help="path to load checkpoint")),
        ("--seq-len", dict(default=4, type=int, help="sequence length for sequential model")),
        # Scoring.
        ("--time-window", dict(default=1.0, type=float, help="for computing frame accuracy within a time window")),
        # Training hyper-parameters.
        ("--iterations", dict(default=20000, type=int, help="number of iterations to train")),
        ("--lr", dict(default=1e-4, type=float, help="initial learning rate")),
        ("--data-split", dict(default="test", type=str, choices=["all", "train", "test"], help="split part of the dataset")),
        ("--batch-size", dict(default=288, type=int, help="number of images per gpu")),
        # Run control.
        ("--seed", dict(action="store_true", help="whether to use random seed")),
        ("--repeat", dict(default=10, type=int, help="experiment for repeat")),
        ("--output-path", dict(default="logs/", help="path to save outputs")),
        ("--device", dict(default="cuda:0", help="device for train and test")),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def load_mouse_dataset(args):
    """Load one mouse recording and flatten it into per-time-step samples.

    The ``.npy`` file is read from ``<data_path>/<stimulus>/`` and is indexed
    as ``(classes, trials, <time axis>, neurons)``. The result is ordered
    trial-major, then class, then time.

    Args:
        args: namespace providing ``data_path``, ``stimulus``, ``mouse_id``,
            ``time_step``, ``shuffle_trials``, ``shuffle_seed``,
            ``sample_neurons`` and ``sample_seed``.

    Returns:
        Tuple ``(x_true, u_true, num_classes, num_trials, num_neurons)`` where
        ``x_true`` is a 2-D tensor with ``num_neurons`` columns and ``u_true``
        holds the class index of every row.
    """
    # Optional file-name suffixes for the shuffled / neuron-sampled variants.
    name_parts = [f"mouse{args.mouse_id}_ts{args.time_step}"]
    if args.shuffle_trials:
        name_parts.append(f"_trial_seed{args.shuffle_seed}")
    if args.sample_neurons:
        name_parts.append(f"_neuron{args.sample_neurons}_seed{args.sample_seed}")
    file_path = os.path.join(args.data_path, args.stimulus, "".join(name_parts) + ".npy")

    recording = np.load(file_path)
    num_classes, num_trials = recording.shape[0], recording.shape[1]
    num_neurons = recording.shape[-1]

    # Put trials first, then collapse every leading axis into one sample axis.
    samples = np.transpose(recording, (1, 0, 2, 3)).reshape(-1, num_neurons)

    # Each class index covers time_step consecutive rows; the pattern repeats per trial.
    trial_labels = np.repeat(np.arange(0, num_classes), args.time_step)
    all_labels = np.tile(trial_labels, num_trials)

    return torch.from_numpy(samples), torch.from_numpy(all_labels), num_classes, num_trials, num_neurons


def main():
    """Parse the options, load the dataset and run the matching tester."""
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    # The dataset dimensions are stored on args so the testers can read them.
    x_true, u_true, num_classes, num_trials, num_neurons = load_mouse_dataset(args)
    args.classes = num_classes
    args.trials = num_trials
    args.data_dim = num_neurons
    print(f"Dataset Size: {x_true.size(0)}")

    # Reduction-style models and network models use different tester classes.
    reduction_models = ["PCA", "TSNE", "CEBRA_time_delta_4", "CEBRA_time_4"]
    tester_cls = MouseReductionTester if args.model_name in reduction_models else MouseNetworkTester
    mouse_data_tester = tester_cls(args.latent_dim, args)

    mouse_data_tester.test(x_true, u_true)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
