import os
import numpy as np
import torch
from sklearn.neighbors import KNeighborsClassifier

from dataloader import ContrastiveLoader, SequentialLoader, ShuffleSequentialLoader
from trainer import Trainer
from utils import set_deterministic


# Fixed seed handed to set_deterministic() in main() when --seed is given.
_seed_ = 2024


class MouseTrainer(Trainer):
    """Trainer subclass for the mouse visual-stimulus neural dataset.

    Supplies log-directory naming, per-model batch preprocessing and a
    kNN-based classification score. How and when the ``Trainer`` base class
    invokes these methods is not visible in this file.
    """

    def _set_log(self):
        """Build ``self.logdir`` from the run arguments and create it on disk.

        The path encodes the stimulus, mouse id, time step, the optional
        trial-shuffle / neuron-sampling settings, the model name and its
        hyper-parameters. The directory is created if missing and
        ``str(self.args)`` is written to ``args.txt`` inside it.
        """
        logdir = f"Mouse_Visual_Data/{self.args.stimulus}/mouse{self.args.mouse_id}_ts{self.args.time_step}"
        suffix = ""
        seed_dir = ""
        if self.args.shuffle_trials:
            suffix += "_random_shuffle"
            seed_dir += f"seed{self.args.shuffle_seed}/"
        if self.args.sample_neurons:
            suffix += "_random_sample"
            seed_dir += f"neuron{self.args.sample_neurons}_seed{self.args.sample_seed}/"
        suffix += "/"
        logdir += f"{suffix}{seed_dir}"

        logdir += f"{self.args.model_name}_ld{self.args.latent_dim}"
        logdir += f"_aug{self.args.aug}" if self.args.aug > 0 else ""
        # Only the sequential models record the sequence length in the path.
        logdir += f"_sl{self.args.seq_len}/" if self.args.model_name in ["lfads", "tidespl_vae"] else "/"
        logdir += f"epochs{self.args.epochs}_bs{self.args.batch_size}_kldw{self.args.kld_weight}_"
        if self.args.model_name in ["swap_vae", "tidespl_vae"]:
            logdir += f"contw{self.args.cont_weight}_"
            if self.args.model_name in ["tidespl_vae"]:
                logdir += f"temp{self.args.temperature}_priorw{self.args.prior_weight}_"

        # NOTE(review): for any optimizer other than "adam" nothing is appended
        # here, so the directory name ends with "_" and does not record the
        # learning rate or weight decay. Left as-is to keep existing log paths.
        if self.args.opt == "adam":
            logdir += f"{self.args.opt}_lr{self.args.lr}_wd{self.args.weight_decay}"

        self.logdir = os.path.join(self.args.output_path, logdir)
        os.makedirs(self.logdir, exist_ok=True)

        with open(os.path.join(self.logdir, "args.txt"), 'w') as f:
            f.write(str(self.args))

    def _preprocess_inputs(self, inputs, train=True):
        """Turn a raw loader batch into the keyword dict for the chosen model.

        Args:
            inputs: indexable batch; which positions are read depends on
                ``self.args.model_name``.
            train: when False, the contrastive models ("swap_vae",
                "tidespl_vae") reuse ``inputs[0]`` for every view instead of
                reading the positive / negative entries.

        Returns:
            A dict of tensors moved to ``self.device``. For a model name not
            handled below, ``inputs`` is returned unchanged.
        """
        device = self.device

        def as_sequence(batch):
            # Drop the leading size-1 dim and swap the first two axes.
            # Presumably (1, batch, seq, feat) -> (seq, batch, feat); confirm
            # against the loader classes, which are not visible here.
            return batch.squeeze(0).permute(1, 0, 2).to(device)

        model_name = self.args.model_name
        if model_name in ["vae"]:
            inputs = {"x": inputs[0].to(device)}
        elif model_name in ["lfads"]:
            inputs = {"x": as_sequence(inputs[0])}
        elif model_name in ["pivae"]:
            inputs = {"x": inputs[0].to(device), "u": inputs[1].to(device)}
        elif model_name in ["swap_vae"]:
            if train:
                inputs = {
                    "x1": inputs[0].squeeze(0).to(device),
                    "x2": inputs[1].squeeze(0).to(device),
                }
            else:
                inputs = {"x1": inputs[0].to(device), "x2": inputs[0].to(device)}
        elif model_name in ["tidespl_vae"]:
            if train:
                inputs = {
                    "x": as_sequence(inputs[0]),
                    "x_pos": as_sequence(inputs[1]),
                    "x_neg": as_sequence(inputs[2]),
                }
            else:
                inputs = {
                    "x": as_sequence(inputs[0]),
                    "x_pos": as_sequence(inputs[0]),
                    "x_neg": as_sequence(inputs[0]),
                }

        return inputs

    def _compute_score(self, z_predict, u_true):
        """Score latents by kNN classification of the label within a tolerance.

        Latents are averaged over each group of ``time_step`` consecutive
        samples, min-max normalised per dimension, and a default
        ``KNeighborsClassifier`` is fitted and evaluated on that same data.

        Args:
            z_predict: tensor of latents; reshaped using ``args.trials``,
                ``args.classes``, ``args.time_step`` (and ``args.seq_len``
                for the sequential models).
            u_true: tensor of per-sample labels; every ``time_step``-th entry
                is kept as the label of its group.

        Returns:
            Fraction of groups whose predicted label is within 30 of the true
            label.
        """
        z_predict = z_predict.cpu().numpy()
        u_true = u_true.cpu().numpy()

        num_groups = self.args.trials * self.args.classes
        time_step = self.args.time_step

        if self.args.model_name in ["lfads", "tidespl_vae"]:
            # Sequential models emit seq_len latents per window; keep the last.
            z_predict = np.reshape(z_predict, (num_groups * time_step, self.args.seq_len, -1))[:, -1]
        z_predict = np.reshape(z_predict, (num_groups, time_step, -1))
        z_predict = np.mean(z_predict, axis=1)
        u_true = u_true[np.arange(0, len(u_true), time_step)]

        z_min = np.min(z_predict, axis=0, keepdims=True)
        z_range = np.max(z_predict, axis=0, keepdims=True) - z_min
        # A latent dimension that is constant over all groups has zero range;
        # dividing by it would give NaN, which KNeighborsClassifier.fit
        # rejects. Use 1 instead so such a dimension normalises to all zeros.
        z_range[z_range == 0] = 1
        z_predict = (z_predict - z_min) / z_range

        knn_model = KNeighborsClassifier()
        knn_model.fit(z_predict, u_true)
        u_predict = knn_model.predict(z_predict)
        u_errors = u_predict - u_true
        # Tolerance of 1 * 30 label steps; presumably a 1-second window at
        # 30 frames per second (TODO confirm against the dataset).
        score = np.sum(np.abs(u_errors) <= (1 * 30)) / len(u_true)
        print(f"Classification Score: {score}")
        return score


def get_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``. Option names
        use dashes on the command line and underscores as attributes
        (``--mouse-id`` -> ``args.mouse_id``).
    """
    import argparse

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        # dataset selection
        ("--data-path", dict(default="neural_dataset/", type=str, help="dataset path")),
        ("--stimulus", dict(default="allen_natural_movie_one", type=str, help="type of visual stimulus")),
        ("--time-step", dict(default=4, type=int, help="time steps")),
        ("--mouse-id", dict(default=16, type=int, help="mouse id")),
        # trial shuffling / neuron sub-sampling
        ("--shuffle-trials", dict(action="store_true", help="whether to use shuffled trials")),
        ("--shuffle-seed", dict(default=1111, type=int, help="random seed for shuffling")),
        ("--sample-neurons", dict(default=0, type=int, help="number of sampled neurons ('0' indicates no sampling)")),
        ("--sample-seed", dict(default=1111, type=int, help="random seed for sampling")),
        # model and loss weights
        ("--model-name", dict(default="tidespl_vae", type=str, help="name of model to train")),
        ("--latent-dim", dict(default=128, type=int, help="latent variable dimension")),
        ("--kld-weight", dict(default=1, type=float, help="weight of kld loss")),
        ("--cont-weight", dict(default=1, type=float, help="weight of contrastive loss")),
        ("--temperature", dict(default=0.2, type=float, help="temperature for infonce loss")),
        ("--prior-weight", dict(default=0, type=float, help="weight of prior penalty")),
        ("--aug", dict(default=2, type=int, help="whether to use positive pair")),
        ("--seq-len", dict(default=4, type=int, help="sequence length for sequential model")),
        # training schedule
        ("--epochs", dict(default=200, type=int, help="number of epochs to train")),
        ("--batch-size", dict(default=288, type=int, help="number of images per gpu")),
        # optimizer
        ("--opt", dict(default="adam", type=str, choices=["adam", "sgd"], help="optimizer")),
        ("--lr", dict(default=1e-4, type=float, help="initial learning rate")),
        ("--weight-decay", dict(default=0, type=float, help="weight decay (L2 penalty)")),
        # reproducibility / repetition
        ("--seed", dict(action="store_true", help="whether to use random seed")),
        ("--repeat", dict(default=10, type=int, help="experiment for repeat")),
        # output and runtime
        ("--output-path", dict(default="logs/", help="path to save outputs")),
        ("--print-freq", dict(default=50, type=int, help="print frequency")),
        ("--device", dict(default="cuda:0", help="device for train and test")),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def load_mouse_dataset(args):
    """Load one mouse's ``.npy`` recording and flatten it to per-sample rows.

    The file is looked up under ``args.data_path / args.stimulus`` and is
    named ``mouse{id}_ts{time_step}`` plus optional shuffle / sampling
    suffixes. The loaded array is indexed as (class, trial, step, neuron),
    reordered so trials come first, and flattened to one row per sample.

    Returns:
        ``(x_true, u_true, num_classes, num_trials, num_neurons)`` where
        ``x_true`` is a tensor with ``num_neurons`` columns and ``u_true``
        holds the class index of each row.
    """
    name_parts = [f"mouse{args.mouse_id}_ts{args.time_step}"]
    if args.shuffle_trials:
        name_parts.append(f"_trial_seed{args.shuffle_seed}")
    if args.sample_neurons:
        name_parts.append(f"_neuron{args.sample_neurons}_seed{args.sample_seed}")
    file_name = "".join(name_parts) + ".npy"

    recording = np.load(os.path.join(args.data_path, args.stimulus, file_name))
    num_classes, num_trials = recording.shape[0], recording.shape[1]
    num_neurons = recording.shape[-1]

    # Put trials on the outer axis, then collapse everything but neurons.
    trial_major = np.transpose(recording, (1, 0, 2, 3))
    samples = trial_major.reshape(num_trials, -1, num_neurons).reshape(-1, num_neurons)

    # Each class index is repeated time_step times, once per trial.
    labels_one_trial = np.repeat(np.arange(0, num_classes), args.time_step)
    labels_all_trials = np.tile(labels_one_trial, num_trials)

    return (
        torch.from_numpy(samples),
        torch.from_numpy(labels_all_trials),
        num_classes,
        num_trials,
        num_neurons,
    )


def set_pair_range(num_sample, args):
    """Build per-sample ``[low, high]`` index offsets for picking a pair.

    Every row starts as ``[-args.aug, args.aug]``. Samples are grouped into
    consecutive trials of ``args.classes * args.time_step`` rows, and the
    offsets are clipped near each trial boundary so they stay inside it:

    * "swap_vae": the first / last ``aug`` rows of a trial are clipped.
    * "tidespl_vae": the last ``seq_len - 1`` rows of a trial are set to
      ``[0, 0]``, and clipping is applied relative to row
      ``trial_len - seq_len``.

    Other model names keep the unclipped range. With ``args.aug == 0`` an
    all-zero integer array is returned.

    Returns:
        Array of shape ``(num_sample, 2)``.
    """
    if args.aug == 0:
        return np.zeros((num_sample, 2), dtype=int)

    trial_len = args.classes * args.time_step
    pair_range = np.tile(np.array([-args.aug, args.aug]), (num_sample, 1))
    trial_starts = np.arange(0, num_sample, trial_len)

    if args.model_name == "swap_vae":
        last_row = trial_len - 1
    elif args.model_name == "tidespl_vae":
        # Rows too close to the trial end to start a full sequence.
        tail_offsets = np.arange(trial_len - args.seq_len + 1, trial_len)
        tail_rows = (trial_starts[:, np.newaxis] + tail_offsets).reshape(-1)
        pair_range[tail_rows] = 0
        last_row = trial_len - args.seq_len
    else:
        return pair_range

    for step in range(args.aug):
        # `step` rows from the trial start: at most `step` rows backwards.
        pair_range[trial_starts + step, 0] = -step
        # `step` rows before the last usable row: at most `step` forwards.
        pair_range[trial_starts + last_row - step, 1] = step

    return pair_range


def load_dataloader(x_true, u_true, args):
    """Create train / test / full datasets and loaders for the chosen model.

    The first ``len(x_true) // 5 * 4`` samples form the training split and
    the remainder the test split; the "all" set covers every sample.

    Args:
        x_true: 2-D tensor of samples (one row per sample).
        u_true: 1-D tensor of labels aligned with ``x_true``.
        args: namespace providing ``model_name``, ``batch_size`` and, for the
            non-plain models, ``classes``, ``time_step``, ``seq_len``
            (plus ``aug`` for the contrastive models).

    Returns:
        ``(train_set, test_set, all_set, train_loader, test_loader,
        all_loader)``.

    Raises:
        ValueError: if ``args.model_name`` is not one of the supported names.
            Previously this surfaced as an ``UnboundLocalError`` at return.
    """
    pointwise_eval_models = ["vae", "pivae", "swap_vae"]
    sequential_eval_models = ["lfads", "tidespl_vae"]
    if args.model_name not in pointwise_eval_models + sequential_eval_models:
        raise ValueError(
            f"unsupported model_name {args.model_name!r}; expected one of "
            f"{pointwise_eval_models + sequential_eval_models}"
        )

    num_sample = len(x_true)
    split = num_sample // 5 * 4
    x_train, u_train = x_true[:split], u_true[:split]
    x_test, u_test = x_true[split:], u_true[split:]

    trial_len = args.classes * args.time_step

    # Training side. The project loaders receive batch_size themselves, so
    # they are wrapped in a DataLoader with its default batch size.
    if args.model_name in ["vae", "pivae"]:
        train_set = torch.utils.data.TensorDataset(x_train, u_train)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    elif args.model_name in ["lfads"]:
        train_set = ShuffleSequentialLoader(x_train, u_train, args.batch_size, trial_len, args.seq_len)
        train_loader = torch.utils.data.DataLoader(train_set)
    elif args.model_name in ["swap_vae"]:
        pair_range = set_pair_range(len(x_train), args)
        train_set = ContrastiveLoader(x_train, u_train, pair_range, args.batch_size, neg=False)
        train_loader = torch.utils.data.DataLoader(train_set)
    elif args.model_name in ["tidespl_vae"]:
        pair_range = set_pair_range(len(x_train), args)
        train_set = ContrastiveLoader(x_train, u_train, pair_range, args.batch_size, neg=True, time_step=trial_len, seq_len=args.seq_len)
        train_loader = torch.utils.data.DataLoader(train_set)

    # Evaluation side: the test split and the full dataset.
    if args.model_name in pointwise_eval_models:
        test_set = torch.utils.data.TensorDataset(x_test, u_test)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size)

        all_set = torch.utils.data.TensorDataset(x_true, u_true)
        all_loader = torch.utils.data.DataLoader(all_set, batch_size=args.batch_size)
    else:
        # Prepend seq_len - 1 zero rows (matching dtype/device via .to) so
        # the padded data is seq_len - 1 rows longer than the original.
        x_test = torch.cat([torch.zeros(args.seq_len - 1, x_test.size(1)).to(x_test), x_test], dim=0)
        u_test = torch.cat([torch.zeros(args.seq_len - 1).to(u_test), u_test])
        test_set = SequentialLoader(x_test, u_test, args.batch_size, args.time_step, args.seq_len, interval=1)
        test_loader = torch.utils.data.DataLoader(test_set)

        x_true = torch.cat([torch.zeros(args.seq_len - 1, x_true.size(1)).to(x_true), x_true], dim=0)
        u_true = torch.cat([torch.zeros(args.seq_len - 1).to(u_true), u_true])
        all_set = SequentialLoader(x_true, u_true, args.batch_size, args.time_step, args.seq_len, interval=1)
        all_loader = torch.utils.data.DataLoader(all_set)

    return train_set, test_set, all_set, train_loader, test_loader, all_loader


def main():
    """Run the full experiment: parse args, load data, train and score.

    The dataset dimensions returned by ``load_mouse_dataset`` are stored on
    ``args`` as ``classes``, ``trials`` and ``data_dim``. A single
    ``MouseTrainer`` is trained and tested ``args.repeat`` times, and the
    per-run scores are saved to ``classification_score.npy`` in the
    trainer's log directory.
    """
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    x_true, u_true, num_classes, num_trials, num_neurons = load_mouse_dataset(args)
    args.classes = num_classes
    args.trials = num_trials
    args.data_dim = num_neurons

    loaders = load_dataloader(x_true, u_true, args)
    train_set, test_set, _all_set, train_loader, test_loader, all_loader = loaders
    print(f"Train Size: {len(train_set)}  Test/Validation Size: {len(test_set)}")

    scores = np.zeros(args.repeat)
    trainer = MouseTrainer(args)
    for run in range(args.repeat):
        trainer.train(train_loader, test_loader, run)
        # Scoring uses the full dataset loader and the unpadded labels.
        scores[run] = trainer.test(all_loader, run, u_true=u_true)
        print()

    score_path = os.path.join(trainer.logdir, "classification_score.npy")
    np.save(score_path, scores)


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
