import os
import torch
import numpy as np
from sklearn.linear_model import Ridge

from dataloader import ContrastiveLoader
from trainer import Trainer
from utils import set_deterministic


# Fixed seed: passed to set_deterministic() in main() when --seed is given, and
# used as random_state for the Ridge regressor in SimulatedTrainer._compute_score.
_seed_ = 2024


class SimulatedTrainer(Trainer):
    """Trainer specialisation for the simulated-data experiments.

    Provides the log-directory layout, the per-model mapping from a batch to
    model keyword arguments, and the latent-recovery score. When these hooks
    are called is decided by the ``Trainer`` base class, which is defined
    outside this file.
    """

    def _set_log(self):
        """Build ``self.logdir`` from the hyper-parameters, create it and dump the args.

        The directory name encodes dataset, model, latent size, augmentation
        flag, epochs, batch size, loss weights and optimizer settings. The
        string form of ``self.args`` is written to ``args.txt`` inside it.
        """
        args = self.args
        parts = [
            f"Simulated_Data/{args.dataset_name}/",
            f"{args.model_name}_ld{args.latent_dim}",
            "_aug/" if args.aug else "/",
            f"epochs{args.epochs}_bs{args.batch_size}_kldw{args.kld_weight}_",
        ]
        if args.model_name in ("swap_vae", "tidespl_vae"):
            parts.append(f"contw{args.cont_weight}_")
        if args.model_name == "tidespl_vae":
            parts.append(f"temp{args.temperature}_priorw{args.prior_weight}_")

        # Record the optimizer settings for every optimizer, not only adam:
        # otherwise runs with another optimizer end in a dangling "_" and runs
        # that differ only in lr / weight decay share (and overwrite) one folder.
        parts.append(f"{args.opt}_lr{args.lr}_wd{args.weight_decay}")

        self.logdir = os.path.join(args.output_path, "".join(parts))
        os.makedirs(self.logdir, exist_ok=True)

        with open(os.path.join(self.logdir, "args.txt"), "w", encoding="utf-8") as f:
            f.write(str(args))

    def _preprocess_inputs(self, inputs, train=True):
        """Turn a batch (sequence of tensors) into the keyword dict for the model.

        Args:
            inputs: Indexable batch; ``inputs[0]`` is the observation tensor,
                ``inputs[1]`` / ``inputs[2]`` are used by some models (see below).
            train: Only affects ``swap_vae`` and ``tidespl_vae``. When False,
                the single observation tensor is reused for every view.

        Returns:
            A dict of tensors moved to ``self.device``. If ``model_name`` is
            not one of the handled names, ``inputs`` is returned unchanged.
        """
        model = self.args.model_name
        device = self.device

        if model in ["vae"]:
            inputs = {"x": inputs[0].to(device)}
        elif model in ["lfads"]:
            # A leading dimension of size 1 is added for this model.
            inputs = {"x": inputs[0].unsqueeze(0).to(device)}
        elif model in ["pivae"]:
            inputs = {"x": inputs[0].to(device), "u": inputs[1].to(device)}
        elif model in ["swap_vae"]:
            if train:
                # NOTE(review): squeeze(0) presumably drops the extra leading
                # dimension added by the training DataLoader around an
                # already-batched dataset — confirm against ContrastiveLoader.
                inputs = {
                    "x1": inputs[0].squeeze(0).to(device),
                    "x2": inputs[1].squeeze(0).to(device),
                }
            else:
                # No second view at evaluation time: both slots get the same data.
                inputs = {"x1": inputs[0].to(device), "x2": inputs[0].to(device)}
        elif model in ["tidespl_vae"]:
            if train:
                inputs = {
                    "x": inputs[0].to(device),
                    "x_pos": inputs[1].to(device),
                    "x_neg": inputs[2].to(device),
                }
            else:
                # Evaluation: the observation stands in for anchor, positive
                # and negative, each with a leading dimension of size 1 added.
                inputs = {
                    "x": inputs[0].unsqueeze(0).to(device),
                    "x_pos": inputs[0].unsqueeze(0).to(device),
                    "x_neg": inputs[0].unsqueeze(0).to(device),
                }

        return inputs

    def _compute_score(self, z_predict, z_true):
        """Score how well ``z_true`` is linearly recoverable from ``z_predict``.

        A Ridge regression is fitted from the predicted latents to the true
        latents and scored on the same data (an in-sample score), which is
        printed and returned.

        Args:
            z_predict: Tensor of inferred latents (regression inputs).
            z_true: Tensor of ground-truth latents (regression targets).

        Returns:
            The value of ``Ridge.score`` (coefficient of determination).
        """
        # detach() so tensors that still track gradients can be converted;
        # Tensor.numpy() refuses tensors with requires_grad=True.
        predicted = z_predict.detach().cpu().numpy()
        target = z_true.detach().cpu().numpy()

        regression_model = Ridge(alpha=1e-4, solver="svd", random_state=_seed_)
        regression_model.fit(predicted, target)
        score = regression_model.score(predicted, target)
        print(f"Linear Regression Score: {score}")
        return score


def get_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the process command line, as before.

    Returns:
        The ``argparse.Namespace`` holding all options.
    """
    import argparse

    parser = argparse.ArgumentParser()

    # Dataset selection.
    parser.add_argument("--data-path", default="Simulated_Data/", type=str, help="dataset path")
    parser.add_argument("--data-dim", default=100, type=int, help="data dimension")
    parser.add_argument("--classes", default=0, type=int, help="number of classes")
    parser.add_argument("--version", default=3, type=int, help="version of dataset")

    # Model and loss weights.
    parser.add_argument("--model-name", default="tidespl_vae", type=str, help="name of model to train")
    parser.add_argument("--latent-dim", default=32, type=int, help="latent variable dimension")
    parser.add_argument("--kld-weight", default=1, type=float, help="weight of kld loss")
    parser.add_argument("--cont-weight", default=1, type=float, help="weight of contrastive loss")
    parser.add_argument("--temperature", default=1, type=float, help="temperature for infonce loss")
    parser.add_argument("--prior-weight", default=1, type=float, help="weight of prior penalty")
    parser.add_argument("--aug", action="store_true", help="whether to use positive pair")

    # Training schedule.
    parser.add_argument("--epochs", default=500, type=int, help="number of epochs to train")
    parser.add_argument("--batch-size", default=320, type=int, help="number of samples per batch")

    # Optimizer.
    parser.add_argument("--opt", default="adam", type=str, choices=["adam", "sgd"], help="optimizer")
    parser.add_argument("--lr", default=5e-4, type=float, help="initial learning rate")
    parser.add_argument("--weight-decay", default=0, type=float, help="weight decay (L2 penalty)")

    # Reproducibility and repetition.
    parser.add_argument("--seed", action="store_true", help="fix the random seed for deterministic runs")
    parser.add_argument("--repeat", default=10, type=int, help="experiment for repeat")

    # Output and runtime.
    parser.add_argument("--output-path", default="logs/", help="path to save outputs")
    parser.add_argument("--print-freq", default=100, type=int, help="print frequency")

    parser.add_argument("--device", default="cuda:0", help="device for train and test")

    return parser.parse_args(argv)


def load_simulated_dataset(args):
    """Load the simulated dataset file selected by ``args``.

    The file name is derived from ``args.data_dim``, ``args.version`` and,
    when ``args.classes`` is positive, the class count. The file is read from
    ``args.data_path`` onto the CPU and must contain the keys ``"x"``,
    ``"u"`` and ``"z"``.

    Returns:
        ``(x_true, u_true, z_true, dataset_name)``. When ``args.classes`` is
        not positive, ``u_true`` gets an extra trailing dimension of size 1.
    """
    has_classes = args.classes > 0
    dataset_name = (
        f"{args.data_dim}d_poisson_{args.classes}cls_v{args.version}"
        if has_classes
        else f"{args.data_dim}d_poisson_v{args.version}"
    )

    file_path = os.path.join(args.data_path, f"{dataset_name}.pt")
    payload = torch.load(file_path, map_location="cpu")

    x_true = payload["x"]
    u_true = payload["u"]
    if not has_classes:
        # Without class labels, u is given a trailing singleton dimension.
        u_true = u_true.unsqueeze(-1)
    z_true = payload["z"]

    return x_true, u_true, z_true, dataset_name


def set_pair_range(num_sample, args):
    """Build the per-sample ``[lower, upper]`` offset table for pair sampling.

    NOTE(review): the table is handed to ``ContrastiveLoader``; presumably the
    offsets bound which neighbouring sample may be drawn as a positive pair —
    confirm against that class.

    Args:
        num_sample: Number of training samples.
        args: Namespace providing ``aug``, ``dataset_name``, ``data_dim`` and
            ``classes``.

    Returns:
        An integer array with two columns (lower offset, upper offset):
          * all zeros when ``args.aug`` is false or for the ``v1`` datasets;
          * ``v2``: ``[-3, 3]`` everywhere, clipped at the first and last three
            positions of every 32-sample group so offsets stay in the group
            (an incomplete trailing group raises ``IndexError``);
          * ``v3``: the pattern ``[0, 3], [-1, 2], [-2, 1], [-3, 0]`` repeated
            ``num_sample // 4`` times.

    Raises:
        ValueError: If augmentation is requested for a dataset name that has
            no pair-range rule (previously an ``UnboundLocalError``).
    """
    if not args.aug:
        return np.zeros((num_sample, 2), dtype=int)

    name = args.dataset_name
    base = f"{args.data_dim}d_poisson"

    if name in (f"{base}_{args.classes}cls_v1", f"{base}_v1"):
        return np.zeros((num_sample, 2), dtype=int)

    if name == f"{base}_v2":
        pair_range = np.stack(
            [np.full(num_sample, -3, dtype=int), np.full(num_sample, 3, dtype=int)],
            axis=1,
        )
        group_starts = np.arange(0, num_sample, 32)
        for i in range(3):
            # Position i of a group may reach back at most i samples ...
            pair_range[group_starts + i, 0] = -i
            # ... and position 31 - i may reach forward at most i samples.
            pair_range[group_starts + 31 - i, 1] = i
        return pair_range

    if name == f"{base}_v3":
        num_groups = num_sample // 4
        left = np.tile(np.arange(0, -4, -1)[np.newaxis, :], (num_groups, 1)).flatten()
        right = np.tile(np.arange(3, -1, -1)[np.newaxis, :], (num_groups, 1)).flatten()
        return np.stack([left, right], axis=1)

    raise ValueError(
        f"No pair range is defined for dataset {name!r}; "
        "run without --aug or use a supported dataset version."
    )


def load_dataloader(x_true, u_true, args):
    """Split the data 80/20 in order and build train, test and full loaders.

    The first ``len(x_true) // 5 * 4`` samples form the training split and the
    rest the test split; no shuffling is applied before splitting.

    Args:
        x_true: Observations, indexable along the first dimension.
        u_true: Auxiliary variable aligned with ``x_true``.
        args: Namespace providing ``model_name`` and ``batch_size`` (plus the
            fields read by ``set_pair_range`` for the contrastive models).

    Returns:
        ``(train_set, test_set, all_set, train_loader, test_loader, all_loader)``.

    Raises:
        ValueError: If ``args.model_name`` is not a supported model
            (previously an ``UnboundLocalError`` at the return statement).
    """
    split = len(x_true) // 5 * 4
    x_train, x_test = x_true[:split], x_true[split:]
    u_train, u_test = u_true[:split], u_true[split:]

    model = args.model_name
    if model in ("vae", "lfads", "pivae"):
        train_set = torch.utils.data.TensorDataset(x_train, u_train)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    elif model in ("swap_vae", "tidespl_vae"):
        pair_range = set_pair_range(len(x_train), args)
        # Only tidespl_vae asks the contrastive dataset for negatives.
        train_set = ContrastiveLoader(
            x_train, u_train, pair_range, args.batch_size, neg=(model == "tidespl_vae")
        )
        # The dataset receives the batch size itself, so the DataLoader keeps
        # its default settings here.
        train_loader = torch.utils.data.DataLoader(train_set)
    else:
        raise ValueError(f"Unsupported model name: {model!r}")

    test_set = torch.utils.data.TensorDataset(x_test, u_test)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size)

    all_set = torch.utils.data.TensorDataset(x_true, u_true)
    all_loader = torch.utils.data.DataLoader(all_set, batch_size=args.batch_size)

    return train_set, test_set, all_set, train_loader, test_loader, all_loader


def main():
    """Run the full experiment: parse args, load data, train/test repeatedly, save scores.

    For each of ``args.repeat`` runs the trainer is trained on the train/test
    loaders and then tested on the loader over all samples against the true
    latents. The per-run scores are saved as ``regression_score.npy`` in the
    trainer's log directory.
    """
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    x_true, u_true, z_true, dataset_name = load_simulated_dataset(args)
    # Later steps read the dataset name from the args namespace.
    args.dataset_name = dataset_name

    loaders = load_dataloader(x_true, u_true, args)
    train_set, test_set, _all_set, train_loader, test_loader, all_loader = loaders
    print(f"Train Size: {len(train_set)}  Test/Validation Size: {len(test_set)}")

    scores = np.zeros(args.repeat)
    trainer = SimulatedTrainer(args)
    for run in range(args.repeat):
        trainer.train(train_loader, test_loader, run)
        scores[run] = trainer.test(all_loader, run, z_true=z_true)
        print()

    score_file = os.path.join(trainer.logdir, "regression_score.npy")
    np.save(score_file, scores)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
