import os
import numpy as np
import torch
from sklearn.linear_model import Ridge

from dataloader import ContrastiveLoader, SequentialLoader, ShuffleSequentialLoader
from trainer import Trainer
from utils import set_deterministic


# Fixed seed: passed to set_deterministic() in main() when --seed is given, and
# used as the Ridge random_state in LorenzTrainer._compute_score.
_seed_ = 2024


class LorenzTrainer(Trainer):
    """``Trainer`` subclass wired to the simulated Lorenz dataset.

    Overrides three hooks: the log-directory layout (``_set_log``), conversion
    of one loader batch into model keyword inputs (``_preprocess_inputs``), and
    the latent-decoding score (``_compute_score``).
    """

    def _set_log(self):
        """Derive ``self.logdir`` from the run settings, create it, and save ``args.txt``.

        The directory lives under ``args.output_path`` and encodes the dataset
        variant, model, and training hyperparameters, so runs with different
        settings get different directories.
        """
        args = self.args
        shuffle_tag = "_shuffle" if args.shuffle else ""
        logdir = f"Simulated_Data/Lorenz_con{args.conditions}{shuffle_tag}/"

        logdir += f"{args.model_name}_ld{args.latent_dim}"
        if args.aug > 0:
            logdir += f"_aug{args.aug}"
        if args.model_name in ("lfads", "tidespl_vae"):
            # Sequential models also depend on the window length.
            logdir += f"_sl{args.seq_len}"
        logdir += "/"

        logdir += f"epochs{args.epochs}_bs{args.batch_size}_kldw{args.kld_weight}_"
        if args.model_name in ("swap_vae", "tidespl_vae"):
            logdir += f"contw{args.cont_weight}_"
        if args.model_name == "tidespl_vae":
            logdir += f"temp{args.temperature}_priorw{args.prior_weight}_"

        # Record the optimizer settings for every optimizer, not only "adam":
        # otherwise e.g. "sgd" runs with different lr / weight decay collide in
        # one directory (and the name ends with a dangling "_").
        logdir += f"{args.opt}_lr{args.lr}_wd{args.weight_decay}"

        self.logdir = os.path.join(args.output_path, logdir)
        os.makedirs(self.logdir, exist_ok=True)

        with open(os.path.join(self.logdir, "args.txt"), "w", encoding="utf-8") as f:
            f.write(str(args))

    def _preprocess_inputs(self, inputs, train=True):
        """Convert one loader batch into the keyword inputs for ``args.model_name``.

        Args:
            inputs: sequence of tensors from a loader. Element 0 is always the
                data batch. For "pivae", element 1 fills the ``u`` slot. For
                contrastive models in training, element 1 fills ``x2`` /
                ``x_pos``, and for "tidespl_vae" element 2 fills ``x_neg``.
            train: when False, contrastive models reuse ``inputs[0]`` for
                every view.

        Returns:
            dict of tensors moved to ``self.device``. ``inputs`` is returned
            unchanged for an unrecognised model name.
        """
        device = self.device
        model = self.args.model_name

        def time_major(t):
            # Drop the leading size-1 axis (the sequential/contrastive datasets
            # are wrapped in DataLoader with its default batch_size=1) and swap
            # the next two axes -- presumably (batch, time, dim) -> (time, batch, dim);
            # TODO confirm against the model.
            return t.squeeze(0).permute(1, 0, 2).to(device)

        if model == "vae":
            return {"x": inputs[0].to(device)}
        if model == "lfads":
            return {"x": time_major(inputs[0])}
        if model == "pivae":
            return {"x": inputs[0].to(device), "u": inputs[1].to(device)}
        if model == "swap_vae":
            if train:
                return {"x1": inputs[0].squeeze(0).to(device), "x2": inputs[1].squeeze(0).to(device)}
            x = inputs[0].to(device)
            return {"x1": x, "x2": x}
        if model == "tidespl_vae":
            x = time_major(inputs[0])
            if train:
                return {"x": x, "x_pos": time_major(inputs[1]), "x_neg": time_major(inputs[2])}
            return {"x": x, "x_pos": x, "x_neg": x}
        return inputs

    def _compute_score(self, z_predict, z_true):
        """Fit a ridge regression from predicted to true latents and return its R^2.

        The regression is fit and scored on the same data, so the score is an
        in-sample measure of how linearly decodable ``z_true`` is from
        ``z_predict``.

        Args:
            z_predict: tensor of model latents. For "lfads"/"tidespl_vae" each
                row is split into ``args.seq_len`` rows first.
            z_true: tensor of ground-truth latents, one row per sample.

        Returns:
            float: ``Ridge.score`` (coefficient of determination).
        """
        z_predict = z_predict.cpu().numpy()
        z_true = z_true.cpu().numpy()

        if self.args.model_name in ("lfads", "tidespl_vae"):
            # One row per time step, presumably to align row-for-row with z_true.
            z_predict = z_predict.reshape(z_predict.shape[0] * self.args.seq_len, -1)

        # random_state is ignored by the "svd" solver; kept for parity with other solvers.
        regression_model = Ridge(alpha=1e-4, solver="svd", random_state=_seed_)
        regression_model.fit(z_predict, z_true)
        score = regression_model.score(z_predict, z_true)
        print(f"Linear Regression Score: {score}")
        return score


def get_args(argv=None):
    """Parse command-line options for a Lorenz training run.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # Data
    parser.add_argument("--data-path", default="Simulated_Data/", type=str, help="dataset path")
    parser.add_argument("--conditions", default=5, type=int, help="number of conditions")
    parser.add_argument("--shuffle", action="store_true", help="use the *_shuffle variant of the dataset files")

    # Model
    parser.add_argument("--model-name", default="tidespl_vae", type=str, help="name of model to train")
    parser.add_argument("--latent-dim", default=8, type=int, help="latent variable dimension")
    parser.add_argument("--kld-weight", default=1, type=float, help="weight of kld loss")
    parser.add_argument("--cont-weight", default=1, type=float, help="weight of contrastive loss")
    parser.add_argument("--temperature", default=1, type=float, help="temperature for infonce loss")
    parser.add_argument("--prior-weight", default=0.05, type=float, help="weight of prior penalty")
    parser.add_argument("--aug", default=5, type=int, help="maximum time offset used when sampling positive pairs (0 = no offset)")
    parser.add_argument("--seq-len", default=50, type=int, help="sequence length for sequential model")

    # Training
    parser.add_argument("--epochs", default=100, type=int, help="number of epochs to train")
    parser.add_argument("--batch-size", default=500, type=int, help="number of samples per batch")

    # Optimizer
    parser.add_argument("--opt", default="adam", type=str, choices=["adam", "sgd"], help="optimizer")
    parser.add_argument("--lr", default=1e-3, type=float, help="initial learning rate")
    # NOTE: int default is kept on purpose; it appears as "wd0" in log paths.
    parser.add_argument("--weight-decay", default=0, type=float, help="weight decay (L2 penalty)")

    # Reproducibility
    parser.add_argument("--seed", action="store_true", help="use a fixed seed for deterministic runs")
    parser.add_argument("--repeat", default=5, type=int, help="number of times to repeat the experiment")

    # Output / runtime
    parser.add_argument("--output-path", default="logs/", type=str, help="path to save outputs")
    parser.add_argument("--print-freq", default=20, type=int, help="print frequency")

    parser.add_argument("--device", default="cuda:0", help="device for train and test")
    return parser.parse_args(argv)


def load_lorenz_dataset(args):
    """Load the simulated Lorenz spikes/latents and flatten them to one row per time step.

    Reads ``lorenz_spike_data{suffix}.npy`` and ``lorenz_latent_data{suffix}.npy``
    from ``args.data_path``, where ``suffix`` is ``"_shuffle"`` when
    ``args.shuffle`` is set. The first ``args.conditions`` entries of axis 0
    are kept. The first two axes are then swapped, so rows are ordered by
    axis-1 index (trial), then condition, then time. The code assumes an
    on-disk layout of (condition, trial, time, feature) -- TODO confirm.

    Returns:
        x_true: float32 tensor of shape (trials*conditions*time, num_neurons).
        z_true: float32 tensor of shape (trials*conditions*time, latent_dim).
        u_true: tensor of condition indices, one per row.
        num_trials, time_step, num_neurons: sizes read from the spike array.

    Raises:
        ValueError: if either file holds fewer than ``args.conditions``
            conditions. Otherwise the labels would be longer than the data.
    """
    suffix = "_shuffle" if args.shuffle else ""
    x_path = os.path.join(args.data_path, f"lorenz_spike_data{suffix}.npy")
    z_path = os.path.join(args.data_path, f"lorenz_latent_data{suffix}.npy")
    x_true = np.load(x_path)
    z_true = np.load(z_path)

    for path, arr in ((x_path, x_true), (z_path, z_true)):
        if arr.shape[0] < args.conditions:
            raise ValueError(
                f"requested {args.conditions} conditions but {path} only contains {arr.shape[0]}"
            )

    # Slice before casting so only the selected conditions are converted.
    x_true = x_true[:args.conditions].astype(np.float32).transpose((1, 0, 2, 3))
    z_true = z_true[:args.conditions].astype(np.float32).transpose((1, 0, 2, 3))
    num_trials = x_true.shape[0]
    time_step = x_true.shape[2]
    num_neurons = x_true.shape[3]
    latent_dim = z_true.shape[-1]  # derived from the data instead of hard-coding 3

    # Condition label for every (trial, condition, time) position, same order as x_true.
    u_true = np.tile(np.arange(args.conditions)[np.newaxis, :, np.newaxis], (num_trials, 1, time_step))

    x_true = torch.from_numpy(x_true.reshape(-1, num_neurons))
    z_true = torch.from_numpy(z_true.reshape(-1, latent_dim))
    u_true = torch.from_numpy(u_true.reshape(-1))

    return x_true, z_true, u_true, num_trials, time_step, num_neurons


def set_pair_range(num_sample, args):
    """Return per-sample ``(min_offset, max_offset)`` bounds for picking a paired sample.

    Row ``i`` bounds the time offset allowed relative to sample ``i``
    (presumably consumed by ``ContrastiveLoader`` to draw a positive from
    ``[i + min_offset, i + max_offset]`` -- confirm there). Samples are grouped
    into consecutive segments of ``args.time_step`` rows. The final segment may
    be shorter when ``num_sample`` is not a multiple of ``time_step``.

    - ``args.aug == 0``: all bounds are 0.
    - "swap_vae": ``[-aug, aug]`` clipped to stay inside the sample's segment.
    - "tidespl_vae": a window of ``args.seq_len`` rows starting at the paired
      index must also fit in the segment. Rows that cannot start such a window
      get ``(0, 0)``.
    - any other model: the unclipped ``[-aug, aug]``.

    Returns:
        int ndarray of shape (num_sample, 2).
    """
    if args.aug == 0:
        return np.zeros((num_sample, 2), dtype=int)

    if args.model_name not in ("swap_vae", "tidespl_vae"):
        return np.stack(
            [np.full(num_sample, -args.aug, dtype=int), np.full(num_sample, args.aug, dtype=int)],
            axis=1,
        )

    index = np.arange(num_sample)
    offset = index % args.time_step  # position inside the sample's segment
    # Last row of each sample's segment. min() handles a truncated final
    # segment, which previously caused an IndexError.
    seg_last = np.minimum(index - offset + args.time_step, num_sample) - 1
    if args.model_name == "tidespl_vae":
        # Latest start index whose seq_len window still fits in the segment.
        last_start = seg_last - args.seq_len + 1
    else:
        last_start = seg_last

    pair_range = np.stack(
        [-np.minimum(args.aug, offset), np.minimum(args.aug, last_start - index)],
        axis=1,
    ).astype(int)
    # Rows that cannot start a full window are not paired.
    pair_range[index > last_start] = 0
    return pair_range


def load_dataloader(x_true, u_true, args):
    """Split the samples and build train / test / full datasets and loaders.

    The first ``len(x_true) // 5 * 4`` rows form the training split and the
    rest the test split. The "all" dataset covers every row.

    Args:
        x_true: sample tensor, one row per time step.
        u_true: condition labels aligned with ``x_true``.
        args: run options. Reads ``model_name`` and ``batch_size``. Also
            reads ``time_step``/``seq_len`` for sequential and contrastive
            models.

    Returns:
        (train_set, test_set, all_set, train_loader, test_loader, all_loader)

    Raises:
        ValueError: if ``args.model_name`` is not a supported model. Previously
            this surfaced as an opaque ``UnboundLocalError``.
    """
    model = args.model_name
    supported = ("vae", "pivae", "lfads", "swap_vae", "tidespl_vae")
    if model not in supported:
        raise ValueError(f"unsupported model_name {model!r}; expected one of {supported}")

    split = len(x_true) // 5 * 4
    x_train, u_train = x_true[:split], u_true[:split]
    x_test, u_test = x_true[split:], u_true[split:]

    if model in ("vae", "pivae"):
        train_set = torch.utils.data.TensorDataset(x_train, u_train)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    elif model == "lfads":
        train_set = ShuffleSequentialLoader(x_train, u_train, args.batch_size, args.time_step, args.seq_len)
        # The dataset yields whole batches; DataLoader's default batch_size=1 wraps each one.
        train_loader = torch.utils.data.DataLoader(train_set)
    else:  # "swap_vae" / "tidespl_vae"
        pair_range = set_pair_range(len(x_train), args)
        if model == "swap_vae":
            train_set = ContrastiveLoader(x_train, u_train, pair_range, args.batch_size, neg=False)
        else:
            train_set = ContrastiveLoader(
                x_train, u_train, pair_range, args.batch_size,
                neg=True, time_step=args.time_step, seq_len=args.seq_len,
            )
        train_loader = torch.utils.data.DataLoader(train_set)

    if model in ("vae", "pivae", "swap_vae"):
        test_set = torch.utils.data.TensorDataset(x_test, u_test)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size)

        all_set = torch.utils.data.TensorDataset(x_true, u_true)
        all_loader = torch.utils.data.DataLoader(all_set, batch_size=args.batch_size)
    else:  # "lfads" / "tidespl_vae"
        # NOTE(review): 5 sits in the slot ShuffleSequentialLoader uses for
        # batch_size -- presumably a fixed evaluation batch size; confirm.
        test_set = SequentialLoader(x_test, u_test, 5, args.time_step, args.seq_len, interval=args.seq_len)
        test_loader = torch.utils.data.DataLoader(test_set)

        all_set = SequentialLoader(x_true, u_true, 5, args.time_step, args.seq_len, interval=args.seq_len)
        all_loader = torch.utils.data.DataLoader(all_set)

    return train_set, test_set, all_set, train_loader, test_loader, all_loader


def main():
    """Script entry point: run ``args.repeat`` train/test rounds on the Lorenz data.

    Parses the CLI and optionally fixes the seed. It then loads and splits the
    data, and trains/tests one ``LorenzTrainer`` repeatedly. Each round's
    regression score is saved to ``<logdir>/regression_score.npy``.
    """
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    spikes, latents, labels, n_trials, n_steps, n_neurons = load_lorenz_dataset(args)
    # Downstream code reads these sizes from args.
    args.trials = n_trials
    args.time_step = n_steps
    args.data_dim = n_neurons
    args.classes = args.conditions

    loaders = load_dataloader(spikes, labels, args)
    train_set, test_set = loaders[0], loaders[1]
    train_loader, test_loader, all_loader = loaders[3], loaders[4], loaders[5]
    print(f"Train Size: {len(train_set)}  Test/Validation Size: {len(test_set)}")

    scores = np.zeros(args.repeat)
    trainer = LorenzTrainer(args)
    for run in range(args.repeat):
        trainer.train(train_loader, test_loader, run)
        scores[run] = trainer.test(all_loader, run, z_true=latents)
        print()

    score_file = os.path.join(trainer.logdir, "regression_score.npy")
    np.save(score_file, scores)


if __name__ == "__main__":
    main()
