import os
import torch
import numpy as np
from sklearn.linear_model import Ridge

from dataloader import SequentialLoader
from tester import Tester, NetworkTester, ReductionTester
from utils import set_deterministic


# Fixed seed: passed to set_deterministic() when --seed is given, and used as
# the random_state of the Ridge regressor in LorenzTester.test.
_seed_ = 2024


class LorenzTester(Tester):
    """Score latents inferred on the Lorenz data against the true latents.

    For every repeat a model is loaded, the inputs are mapped to latents, and a
    ridge regression from predicted to true latents is fitted on the first 4/5
    of the samples. The regression score on the requested data split is
    recorded per repeat and the array of scores is saved to ``self.output_dir``.
    """

    def _set_output_dir(self):
        """Set ``self.output_dir`` for the selected model.

        PCA, TSNE and the CEBRA_time_* models get a directory derived from the
        output path, the dataset name, the model name and the latent dimension
        (plus a hyper-parameter sub-directory for CEBRA); it is created if
        missing. Every other model writes into its checkpoint directory.

        Raises:
            ValueError: if a checkpoint-based model is selected but
                ``args.checkpoint_path`` is ``None``.
        """
        args = self.args
        cebra_models = ["CEBRA_time_1", "CEBRA_time_5", "CEBRA_time_10"]
        reduction_models = ["PCA", "TSNE"] + cebra_models

        if args.model_name not in reduction_models:
            # Raise explicitly instead of using ``assert``: assertions are
            # removed under ``python -O`` and the check would silently vanish.
            if args.checkpoint_path is None:
                raise ValueError(
                    f"--checkpoint-path is required for model '{args.model_name}'"
                )
            self.output_dir = args.checkpoint_path
            return

        dataset_name = f"Lorenz_con{args.conditions}"
        if args.shuffle:
            dataset_name = f"Lorenz_con{args.conditions}_shuffle"
        self.output_dir = os.path.join(
            args.output_path,
            "Simulated_Data",
            dataset_name,
            f"{args.model_name}_ld{args.latent_dim}",
        )
        if args.model_name in cebra_models:
            self.output_dir = os.path.join(
                self.output_dir,
                f"iterations{args.iterations}_bs{args.batch_size}_lr{args.lr}",
            )
        os.makedirs(self.output_dir, exist_ok=True)

    def test(self, x_true, z_true, u_true):
        """Run the repeated linear-regression evaluation and save the scores.

        Args:
            x_true: observations, forwarded to ``self._transform``.
            z_true: ground-truth latents; must provide ``.numpy()``.
            u_true: labels, forwarded to ``self._transform``.

        Side effects:
            Prints progress and scores, and writes
            ``regression_score{self.suffix_for_score}.npy`` (one score per
            repeat) into ``self.output_dir``.
        """
        args = self.args
        z_true = z_true.numpy()
        num_sample = len(z_true)
        # The first 4/5 of the rows always form the regression training set.
        train_end = num_sample // 5 * 4

        def select_split(rows):
            """Return the rows of the split requested by ``args.data_split``."""
            if args.data_split == "train":
                return rows[:train_end]
            if args.data_split == "test":
                return rows[train_end:]
            # Any other value (e.g. "all") keeps every row.
            return rows

        z_true_train = z_true[:train_end]
        z_true = select_split(z_true)

        # Latents of the best-scoring repeat; tracked for the optional save below.
        z_best_predict = None
        z_best_predict_regress = None
        max_score = -np.inf
        scores = np.zeros(args.repeat)

        for repeat_idx in range(args.repeat):
            print(f"Start computing latent [repeat {repeat_idx}]...")
            self._load_model(repeat=repeat_idx)
            z_predict = self._transform(x_true, u_true, repeat=repeat_idx)

            if args.model_name in ["lfads", "tidespl_vae"]:
                # Sequence models: unfold each row into seq_len steps, then
                # flatten so that every time step becomes its own row.
                z_predict = np.reshape(z_predict, (z_predict.shape[0], args.seq_len, -1))
                z_predict = np.reshape(z_predict, (-1, z_predict.shape[-1]))
            if args.model_name in ["swap_vae", "tidespl_vae"]:
                # The latent vector is split in halves: first half "content",
                # second half "style"; "full" keeps both.
                half = args.latent_dim // 2
                if args.latent_space == "content":
                    z_predict = z_predict[:, :half]
                elif args.latent_space == "style":
                    z_predict = z_predict[:, half:]

            print("Start testing...")
            z_predict_train = z_predict[:train_end]
            z_predict = select_split(z_predict)

            regression_model = Ridge(alpha=1e-4, solver="svd", random_state=_seed_)
            regression_model.fit(z_predict_train, z_true_train)
            score = regression_model.score(z_predict, z_true)
            print(f"Linear Regression Score: {score}")
            scores[repeat_idx] = score
            if score > max_score:
                max_score = score
                z_best_predict = z_predict
                z_best_predict_regress = regression_model.predict(z_predict)
            print()

        print(f"Average Linear Regression Score: {np.mean(scores)}")
        np.save(os.path.join(self.output_dir, f"regression_score{self.suffix_for_score}.npy"), scores)
        # np.save(os.path.join(self.output_dir, f"regress_latent{self.suffix_for_score}.npy"), z_best_predict_regress)


class LorenzNetworkTester(LorenzTester, NetworkTester):
    """Lorenz tester for neural-network models loaded from a checkpoint."""

    def _load_dataloader(self, x_true, u_true):
        """Build the dataset and loader matching the selected model.

        Args:
            x_true: observation tensor.
            u_true: label tensor.

        Returns:
            ``(dataset, loader)``.

        Raises:
            ValueError: if ``args.model_name`` is not one of the supported
                network models (previously this surfaced as an
                ``UnboundLocalError`` at the ``return``).
        """
        model_name = self.args.model_name
        if model_name in ["vae", "pivae", "swap_vae"]:
            all_set = torch.utils.data.TensorDataset(x_true, u_true)
            all_loader = torch.utils.data.DataLoader(all_set, batch_size=self.args.batch_size)
        elif model_name in ["lfads", "tidespl_vae"]:
            # NOTE(review): the literal 5 equals the default of --conditions;
            # confirm against SequentialLoader whether it should follow
            # self.args.conditions instead.
            all_set = SequentialLoader(x_true, u_true, 5, self.args.time_step, self.args.seq_len, interval=self.args.seq_len)
            # No batch_size given: the loader uses DataLoader's default batching.
            all_loader = torch.utils.data.DataLoader(all_set)
        else:
            raise ValueError(f"Unsupported model name for network testing: '{model_name}'")

        return all_set, all_loader

    def _preprocess_inputs(self, inputs):
        """Convert one loader batch into the keyword dict for the selected model.

        Args:
            inputs: a batch from the loader; ``inputs[0]`` holds the
                observations and ``inputs[1]`` the labels.

        Returns:
            A dict of tensors moved to ``self.device``; for a model name that
            matches none of the branches, ``inputs`` is returned unchanged.
        """
        batch = inputs
        model_name = self.args.model_name

        def as_sequence():
            # Drop the leading loader dimension and swap the first two axes
            # (presumably giving (time, batch, feature) - TODO confirm).
            # Called once per key so every entry is produced by its own .to() call.
            return batch[0].squeeze(0).permute(1, 0, 2).to(self.device)

        if model_name in ["vae"]:
            inputs = {"x": batch[0].to(self.device)}
        elif model_name in ["lfads"]:
            inputs = {"x": as_sequence()}
        elif model_name in ["pivae"]:
            inputs = {"x": batch[0].to(self.device), "u": batch[1].to(self.device)}
        elif model_name in ["swap_vae"]:
            # Both views are the same observations at test time.
            inputs = {"x1": batch[0].to(self.device), "x2": batch[0].to(self.device)}
        elif model_name in ["tidespl_vae"]:
            # Anchor, positive and negative all receive the same observations.
            inputs = {"x": as_sequence(), "x_pos": as_sequence(), "x_neg": as_sequence()}

        return inputs


class LorenzReductionTester(LorenzTester, ReductionTester):
    """Lorenz tester for the dimensionality-reduction models."""

    def _load_train_set(self, x_true, u_true):
        """Convert the tensors to NumPy and carve out the training portion.

        Labels are cast to float for the CEBRA_time_* models. The training
        portion is the first 4/5 of the samples.

        Returns:
            ``(x_all, u_all, x_train, u_train)`` as NumPy arrays.
        """
        features = x_true.numpy()
        labels = u_true.numpy()

        cebra_variants = ("CEBRA_time_1", "CEBRA_time_5", "CEBRA_time_10")
        if self.args.model_name in cebra_variants:
            labels = labels.astype("float")

        train_end = len(features) // 5 * 4
        return features, labels, features[:train_end], labels[:train_end]


def get_args():
    """Parse the command-line options of the Lorenz evaluation script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    import argparse

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_table = [
        # Dataset
        ("--data-path", dict(default="Simulated_Data/", type=str, help="dataset path")),
        ("--conditions", dict(default=5, type=int, help="number of conditions")),
        ("--shuffle", dict(action="store_true", help="whether to use shuffle data")),
        # Model
        ("--model-name", dict(default="tidespl_vae", type=str, help="name of model to train")),
        ("--latent-dim", dict(default=8, type=int, help="latent variable dimension")),
        ("--latent-space", dict(default="full", type=str, choices=["full", "content", "style"], help="latent variable space")),
        ("--latent-reduc", dict(default="PCA", type=str, choices=["PCA", "TSNE"], help="dimension reduction of latent variable space")),
        ("--checkpoint-path", dict(default=None, help="path to load checkpoint")),
        ("--seq-len", dict(default=50, type=int, help="sequence length for sequential model")),
        # Training hyper-parameters (used to locate outputs)
        ("--iterations", dict(default=20000, type=int, help="number of iterations to train")),
        ("--lr", dict(default=1e-3, type=float, help="initial learning rate")),
        # Evaluation
        ("--data-split", dict(default="test", type=str, choices=["all", "train", "test"], help="split part of the dataset")),
        ("--batch-size", dict(default=500, type=int, help="number of images per gpu")),
        # Runtime
        ("--seed", dict(action="store_true", help="whether to use random seed")),
        ("--repeat", dict(default=10, type=int, help="experiment for repeat")),
        ("--output-path", dict(default="logs/", help="path to save outputs")),
        ("--device", dict(default="cuda:0", help="device for train and test")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def load_lorenz_dataset(args):
    """Load the simulated Lorenz spikes and latents as flat per-time-step tensors.

    Reads ``lorenz_spike_data{suffix}.npy`` and ``lorenz_latent_data{suffix}.npy``
    from ``args.data_path`` (``suffix`` is ``"_shuffle"`` when ``args.shuffle``
    is set). Both arrays are indexed as (condition, trial, time, feature); the
    first ``args.conditions`` conditions are kept, the trial axis is moved to
    the front, and the leading three axes are flattened, so rows are ordered
    by trial, then condition, then time step.

    Args:
        args: namespace providing ``data_path``, ``shuffle`` and ``conditions``.

    Returns:
        ``(x_true, z_true, u_true, num_trials, time_step, num_neurons)`` where
        ``x_true`` is a float32 tensor of shape (rows, num_neurons), ``z_true``
        a float32 tensor of shape (rows, latent width) and ``u_true`` the
        integer condition index of every row.
    """
    suffix = "_shuffle" if args.shuffle else ""
    x_true = np.load(os.path.join(args.data_path, f"lorenz_spike_data{suffix}.npy")).astype(np.float32)
    z_true = np.load(os.path.join(args.data_path, f"lorenz_latent_data{suffix}.npy")).astype(np.float32)

    # Keep the requested conditions and put trials first:
    # (condition, trial, time, feature) -> (trial, condition, time, feature).
    x_true = x_true[:args.conditions].transpose((1, 0, 2, 3))
    z_true = z_true[:args.conditions].transpose((1, 0, 2, 3))

    # Use the sizes actually present after slicing: the file may hold fewer
    # conditions than requested, and labels must stay aligned with the rows.
    num_trials, num_conditions, time_step, num_neurons = x_true.shape
    # Read the latent width from the data instead of assuming 3.
    latent_dim = z_true.shape[-1]

    # One condition index per (trial, condition, time) entry.
    u_true = np.tile(
        np.arange(num_conditions)[np.newaxis, :, np.newaxis],
        (num_trials, 1, time_step),
    )

    # Flatten trial/condition/time into a single row axis.
    x_true = np.reshape(x_true, (-1, num_neurons))
    z_true = np.reshape(z_true, (-1, latent_dim))
    u_true = u_true.flatten()

    x_true = torch.from_numpy(x_true)
    z_true = torch.from_numpy(z_true)
    u_true = torch.from_numpy(u_true)

    return x_true, z_true, u_true, num_trials, time_step, num_neurons


def main():
    """Parse options, load the Lorenz dataset and run the matching tester."""
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    loaded = load_lorenz_dataset(args)
    x_true, z_true, u_true = loaded[:3]
    # Record the dataset dimensions on the namespace for the tester.
    args.trials, args.time_step, args.data_dim = loaded[3:]
    args.classes = args.conditions
    print(f"Dataset Size: {x_true.size(0)}")

    # Reduction models get their own tester; everything else is treated as a network.
    reduction_models = ("PCA", "TSNE", "CEBRA_time_1", "CEBRA_time_5", "CEBRA_time_10")
    is_reduction = args.model_name in reduction_models
    tester_cls = LorenzReductionTester if is_reduction else LorenzNetworkTester

    lorenz_data_tester = tester_cls(args.latent_dim, args)
    lorenz_data_tester.test(x_true, z_true, u_true)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
