import os
import torch
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from tester import Tester, NetworkTester, ReductionTester
from utils import set_deterministic


# Fixed seed shared by this script: passed to set_deterministic() when --seed
# is given, and used as random_state for the Ridge regressor and the 2-D
# reducer (PCA / TSNE) used for plotting.
_seed_ = 2024


class SimulatedTester(Tester):
    """Score and plot predicted latents against the simulated ground truth.

    The model-specific hooks ``_load_model`` and ``_transform`` and the
    attributes ``suffix_for_score``, ``suffix_for_plot`` and
    ``model_name_for_plot`` are not defined here; they are presumably provided
    by ``Tester`` or by the sibling tester classes (confirm in ``tester``).
    """

    def _set_output_dir(self):
        """Set ``self.output_dir`` from the parsed arguments.

        For "PCA", "TSNE" and "CEBRA_delta" a directory below
        ``args.output_path`` is built (and created if missing); "CEBRA_delta"
        gets an extra sub-directory encoding iterations, batch size and
        learning rate. Every other model writes into ``args.checkpoint_path``.
        """
        if self.args.model_name in ["PCA", "TSNE", "CEBRA_delta"]:
            self.output_dir = os.path.join(self.args.output_path, "Simulated_Data", self.args.dataset_name, f"{self.args.model_name}_ld{self.args.latent_dim}")
            if self.args.model_name in ["CEBRA_delta"]:
                self.output_dir = os.path.join(self.output_dir, f"iterations{self.args.iterations}_bs{self.args.batch_size}_lr{self.args.lr}")
            os.makedirs(self.output_dir, exist_ok=True)
        else:
            assert self.args.checkpoint_path is not None, \
                "--checkpoint-path is required for network models"
            self.output_dir = self.args.checkpoint_path

    def _preprocess_plot(self, z1_true, z2_true, z1_predict, z2_predict):
        """Align a predicted 2-D latent with the ground truth for plotting.

        Tries the eight combinations of per-axis sign flip and axis swap and
        keeps the one with the smallest mean Euclidean distance to
        ``(z1_true, z2_true)``. Ties keep the earliest candidate. The index of
        the chosen transform is printed.

        Returns:
            tuple: the transformed ``(z1, z2)`` as new arrays.
        """
        # Each entry is [sign for axis 1, sign for axis 2, swap the two axes].
        transforms_type = [
            [1, 1, False], [-1, 1, False], [1, -1, False], [-1, -1, False],
            [1, 1, True], [-1, 1, True], [1, -1, True], [-1, -1, True]
        ]

        def apply_transform(transform):
            # Flip signs first, then optionally swap the axes.
            sign1, sign2, swap = transform
            z1 = z1_predict * sign1
            z2 = z2_predict * sign2
            if swap:
                z1, z2 = z2, z1
            return z1, z2

        _min = np.inf
        k = 0
        for i, transform in enumerate(transforms_type):
            z1, z2 = apply_transform(transform)
            distance = np.mean(np.sqrt((z1_true - z1) ** 2 + (z2_true - z2) ** 2))
            if distance < _min:
                _min = distance
                k = i
        print(f"Plot Transform: {k}")
        return apply_transform(transforms_type[k])

    def test(self, x_true, u_true, z_true):
        """Evaluate ``args.repeat`` runs, save the scores and plot the best one.

        For every repeat the latent is computed with ``_transform``, a Ridge
        regression from predicted to true latent is fitted on the first
        ``num_sample // 5 * 4`` samples, and its ``score`` on the split chosen
        by ``args.data_split`` is recorded. The repeat with the highest score
        is plotted next to the ground truth and saved as a PNG in
        ``self.output_dir``.

        With ``args.only_plot`` the raw best latent is plotted (reduced to two
        dimensions with ``args.latent_reduc`` if it has more than two, then
        sign/axis aligned) and no scores are written. Otherwise the scores are
        saved as ``.npy`` and the regressed latent is plotted.

        Args:
            x_true: observations handed to ``_transform``.
            u_true: labels/conditions; used for colouring the scatter plot.
            z_true: ground-truth latent (a CPU tensor, ``.numpy()`` is called).

        Raises:
            ValueError: if ``args.repeat`` is smaller than 1 or
                ``args.latent_reduc`` names an unknown reducer.
        """
        if self.args.repeat < 1:
            # Without at least one run there is no best latent to plot.
            raise ValueError(f"repeat must be >= 1, got {self.args.repeat}")

        z_true = z_true.numpy()
        num_sample = len(z_true)
        # The first four fifths are the training part, the rest is the test part.
        split = num_sample // 5 * 4
        data_split = self.args.data_split

        def select_split(values):
            # Return the part of `values` selected by --data-split.
            if data_split == "train":
                return values[:split]
            if data_split == "test":
                return values[split:]
            return values

        z_true_train = z_true[:split]
        z_true = select_split(z_true)
        z_best_predict = None
        z_best_predict_regress = None
        max_score = -np.inf
        scores = np.zeros(self.args.repeat)

        for i in range(self.args.repeat):
            print(f"Start computing latent [repeat {i}]...")
            self._load_model(repeat=i)
            z_predict = self._transform(x_true, u_true, repeat=i)
            if self.args.model_name in ["swap_vae", "tidespl_vae"]:
                # These models split the latent in two halves; optionally
                # evaluate only the first ("content") or second ("style") half.
                if self.args.latent_space == "content":
                    z_predict = z_predict[:, :self.args.latent_dim // 2]
                elif self.args.latent_space == "style":
                    z_predict = z_predict[:, self.args.latent_dim // 2:]

            print("Start testing...")
            z_predict_train = z_predict[:split]
            z_predict = select_split(z_predict)
            regression_model = Ridge(alpha=1e-4, solver="svd", random_state=_seed_)
            regression_model.fit(z_predict_train, z_true_train)
            score = regression_model.score(z_predict, z_true)
            print(f"Linear Regression Score: {score}")
            scores[i] = score
            if score > max_score:
                max_score = score
                z_best_predict = z_predict
                z_best_predict_regress = regression_model.predict(z_predict)
            print()
        print(f"Average Linear Regression Score: {np.mean(scores)}")

        if self.args.only_plot:
            if z_best_predict.shape[-1] > 2:
                # Explicit lookup instead of eval() on a command-line string.
                reducers = {"PCA": PCA, "TSNE": TSNE}
                try:
                    reducer_cls = reducers[self.args.latent_reduc]
                except KeyError:
                    raise ValueError(f"unknown latent_reduc: {self.args.latent_reduc!r}") from None
                reduc = reducer_cls(n_components=2, random_state=_seed_)
                z_predict = reduc.fit_transform(z_best_predict)
            else:
                z_predict = z_best_predict
            z1_predict, z2_predict = self._preprocess_plot(z_true[:, 0], z_true[:, 1], z_predict[:, 0], z_predict[:, 1])
            z_predict[:, 0] = z1_predict
            z_predict[:, 1] = z2_predict
        else:
            np.save(os.path.join(self.output_dir, f"regression_score{self.suffix_for_score}.npy"), scores)
            z_predict = z_best_predict_regress

        print("Start ploting...")
        u = select_split(u_true).cpu().numpy()
        if self.args.classes > 0:
            # Discrete labels index directly into a fixed five-colour palette.
            my_color = np.array(["red", "orange", "green", "blue", "purple"])
            my_color = my_color[u]
        else:
            # Continuous labels are binned over [0, 2*pi] into 50 viridis colours.
            num_bins = 50
            my_color = plt.cm.viridis(np.linspace(0, 1, num_bins))
            bins = np.linspace(0, 2 * np.pi, num_bins)
            bins_center = (bins[1:] + bins[:-1]) / 2
            my_color = my_color[np.digitize(u, bins_center)]
        fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=200)
        ax = ax.flatten()
        ax[0].scatter(z_true[:, 0], z_true[:, 1], c=my_color, s=1, alpha=0.5)
        ax[0].set_title("Ground Truth")
        ax[1].scatter(z_predict[:, 0], z_predict[:, 1], c=my_color, s=1, alpha=0.5)
        ax[1].set_title(f"{self.model_name_for_plot}")
        ax[0].set_xlabel("Latent 1")
        ax[0].set_ylabel("Latent 2")

        plt.savefig(os.path.join(self.output_dir, f"z{self.suffix_for_plot}.png"))
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


class SimulatedNetworkTester(SimulatedTester, NetworkTester):
    """Simulated-data tester for the neural-network models."""

    def _load_dataloader(self, x_true, u_true):
        """Return the whole dataset and a loader over it.

        The loader uses ``args.batch_size`` and the DataLoader defaults
        otherwise.
        """
        dataset = torch.utils.data.TensorDataset(x_true, u_true)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size)
        return dataset, loader

    def _preprocess_inputs(self, inputs):
        """Turn one ``(x, u)`` batch into the keyword dict for the current model.

        Tensors are moved to ``self.device``. "lfads" and "tidespl_vae" get an
        extra leading dimension via ``unsqueeze(0)``; "swap_vae" and
        "tidespl_vae" receive the same batch under every input key. A model
        name not listed here gets ``inputs`` back unchanged.
        """
        model = self.args.model_name
        if model == "vae":
            return {"x": inputs[0].to(self.device)}
        if model == "lfads":
            return {"x": inputs[0].unsqueeze(0).to(self.device)}
        if model == "pivae":
            return {"x": inputs[0].to(self.device), "u": inputs[1].to(self.device)}
        if model == "swap_vae":
            # Both views are the same batch at test time.
            return {key: inputs[0].to(self.device) for key in ("x1", "x2")}
        if model == "tidespl_vae":
            # Anchor, positive and negative are all the same batch at test time.
            return {
                key: inputs[0].unsqueeze(0).to(self.device)
                for key in ("x", "x_pos", "x_neg")
            }
        return inputs


class SimulatedReductionTester(SimulatedTester, ReductionTester):
    """Simulated-data tester for the dimensionality-reduction baselines."""

    def _load_train_set(self, x_true, u_true):
        """Convert the tensors to NumPy and cut off the training part.

        Returns:
            tuple: ``(x_all, u_all, x_train, u_train)`` where the training
            arrays are the first ``len(x) // 5 * 4`` samples.
        """
        x_all, u_all = x_true.numpy(), u_true.numpy()
        train_end = len(x_all) // 5 * 4
        return x_all, u_all, x_all[:train_end], u_all[:train_end]


def get_args(argv=None):
    """Parse the command-line options of the simulated-data test script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, as before; passing a list allows the
            parser to be driven programmatically.

    Returns:
        argparse.Namespace: the parsed options.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # Dataset selection.
    parser.add_argument("--data-path", default="Simulated_Data/", type=str, help="dataset path")
    parser.add_argument("--data-dim", default=100, type=int, help="data dimension")
    parser.add_argument("--classes", default=0, type=int, help="number of classes")
    parser.add_argument("--version", default=3, type=int, help="version of dataset")

    # Model and latent-space selection.
    parser.add_argument("--model-name", default="tidespl_vae", type=str, help="name of model to train")
    parser.add_argument("--latent-dim", default=32, type=int, help="latent variable dimension")
    parser.add_argument("--latent-space", default="full", type=str, choices=["full", "content", "style"], help="latent variable space")
    parser.add_argument("--latent-reduc", default="PCA", type=str, choices=["PCA", "TSNE"], help="dimension reduction of latent variable space")
    parser.add_argument("--checkpoint-path", default=None, help="path to load checkpoint")

    # Training hyper-parameters (used in the output directory name for CEBRA_delta).
    parser.add_argument("--iterations", default=20000, type=int, help="number of iterations to train")
    parser.add_argument("--lr", default=5e-4, type=float, help="initial learning rate")

    # Evaluation data.
    parser.add_argument("--data-split", default="all", type=str, choices=["all", "train", "test"], help="split part of the dataset")
    parser.add_argument("--batch-size", default=320, type=int, help="number of images per gpu")

    # Run control and output.
    parser.add_argument("--seed", action="store_true", help="whether to use random seed")
    parser.add_argument("--repeat", default=10, type=int, help="experiment for repeat")
    parser.add_argument("--only-plot", action="store_true", help="directly plot latent without transform")
    parser.add_argument("--output-path", default="logs/", help="path to save outputs")
    parser.add_argument("--device", default="cuda:0", help="device for train and test")
    return parser.parse_args(argv)


def load_simulated_dataset(args):
    """Load the simulated dataset selected by ``args`` onto the CPU.

    The file name is ``{data_dim}d_poisson_v{version}.pt``, with an extra
    ``_{classes}cls`` part before the version when ``args.classes > 0``, and
    is looked up in ``args.data_path``.

    Returns:
        tuple: ``(x, u, z, dataset_name)`` where the first three are the
        "x", "u" and "z" entries of the loaded file.
    """
    class_tag = f"_{args.classes}cls" if args.classes > 0 else ""
    dataset_name = f"{args.data_dim}d_poisson{class_tag}_v{args.version}"

    dataset_file = os.path.join(args.data_path, f"{dataset_name}.pt")
    data = torch.load(dataset_file, map_location="cpu")

    return data["x"], data["u"], data["z"], dataset_name


def main():
    """Parse the options, load the dataset and run the matching tester."""
    args = get_args()
    if args.seed:
        set_deterministic(_seed_)
    print(args)

    x_true, u_true, z_true, dataset_name = load_simulated_dataset(args)
    # The testers read the dataset name from the argument namespace.
    args.dataset_name = dataset_name
    print(f"Dataset Size: {x_true.size(0)}")

    # Reduction baselines and network models use different tester classes.
    uses_reduction = args.model_name in ["PCA", "TSNE", "CEBRA_delta"]
    tester_cls = SimulatedReductionTester if uses_reduction else SimulatedNetworkTester
    simulated_data_tester = tester_cls(args.latent_dim, args)
    simulated_data_tester.test(x_true, u_true, z_true)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
