import os
import torch
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from cebra import CEBRA
import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset

from model.VAE import *
from model.LFADS import *
from model.PIVAE import *
from model.SwapVAE import *
from model.TiDeSPLVAE import *
from model.CEBRAforMovie import *


class Tester:
    """Common base for model testers.

    Stores the run configuration and derives the labels used to name plot
    and score outputs. Subclasses implement the output directory, model
    loading, the data-to-latent transform and the test entry point.
    """

    def __init__(self, num_latent, args):
        self.num_latent = num_latent
        self.args = args
        self.device = torch.device(args.device)

        # Both hooks run inside the base constructor, so they cannot rely on
        # state a subclass only sets after calling ``super().__init__``.
        self._set_output_dir()
        self._set_show_info()

    def _set_output_dir(self):
        """Subclass hook, called from ``__init__``, to configure the output directory."""
        raise NotImplementedError()

    def _set_show_info(self):
        """Compute ``model_name_for_plot``, ``suffix_for_plot`` and ``suffix_for_score``."""
        opts = self.args
        title = opts.model_name
        plot_bits = []
        score_bits = []

        # Models with a split latent space are labelled with the space in use.
        if opts.model_name in ["swap_vae", "tidespl_vae"]:
            space_tag = f"_{opts.latent_space}"
            title += f" - {opts.latent_space}"
            plot_bits.append(space_tag)
            score_bits.append(space_tag)

        # A missing ``only_plot`` option counts as plotting mode.
        plotting = "only_plot" not in opts or opts.only_plot
        if not plotting:
            plot_bits.append("_regress")
        elif opts.latent_dim > 2:
            # Latents wider than 2-D are tagged with the lower-cased ``latent_reduc`` setting.
            plot_bits.append(f"_{opts.latent_reduc.lower()}")

        split_tag = f"_{opts.data_split}"
        plot_bits.append(split_tag)
        score_bits.append(split_tag)

        self.model_name_for_plot = title
        self.suffix_for_plot = "".join(plot_bits)
        self.suffix_for_score = "".join(score_bits)

    def _load_model(self):
        """Subclass hook: build or restore ``self.model``."""
        raise NotImplementedError()

    def _transform(self):
        """Subclass hook: map input data to latent coordinates."""
        raise NotImplementedError()

    def test(self):
        """Subclass hook: run the evaluation."""
        raise NotImplementedError()


class NetworkTester(Tester):
    """Tester for neural latent-variable models restored from checkpoints.

    The model named by ``args.model_name`` is built in ``__init__``;
    ``_load_model`` restores ``checkpoint_<repeat>.pth`` from
    ``args.checkpoint_path`` and ``_transform`` encodes the data batch by
    batch. Subclasses must provide ``_set_output_dir``, ``_load_dataloader``
    and ``_preprocess_inputs``.
    """

    def __init__(self, num_latent, args):
        super().__init__(num_latent, args)
        self._set_model()

    def _set_model(self):
        """Instantiate ``self.model`` for ``args.model_name``.

        Constructor arguments are derived from ``args.data_dim`` and
        ``args.latent_dim`` (plus ``args.classes`` for ``pivae``).

        Raises:
            ValueError: if ``args.checkpoint_path`` is None, or if the model
                name does not resolve to a name in this module's namespace.
        """
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if self.args.checkpoint_path is None:
            raise ValueError("args.checkpoint_path must be set to test a network model")

        model_args = {
            "input_dim": self.args.data_dim
        }
        if self.args.model_name in ["vae"]:
            model_args.update({
                "latent_dim": self.args.latent_dim
            })
        elif self.args.model_name in ["lfads"]:
            model_args.update({
                "encod_input_dim": self.args.data_dim,
                "factor_dim": self.args.latent_dim,
                "g0_enc_dim": self.args.latent_dim // 2,
                "g0_dim": self.args.latent_dim // 2,
                "con_enc_dim": self.args.latent_dim // 2,
                "con_dim": self.args.latent_dim // 2,
                "u_dim": max(self.args.latent_dim // 16, 2)
            })
        elif self.args.model_name in ["pivae"]:
            model_args.update({
                "latent_dim": self.args.latent_dim,
                # With no discrete classes, a single continuous label is used.
                "label_dim": self.args.classes if self.args.classes > 0 else 1,
                "discrete_prior": self.args.classes > 0,
                "observation_model": "poisson"
            })
        elif self.args.model_name in ["swap_vae"]:
            model_args.update({
                "content_dim": self.args.latent_dim // 2,
                "style_dim": self.args.latent_dim // 2
            })
        elif self.args.model_name in ["tidespl_vae"]:
            model_args.update({
                "content_dim": self.args.latent_dim // 2,
                "style_dim": self.args.latent_dim // 2,
                "hidden_state_dim": self.args.latent_dim
            })

        # The constructors arrive via the ``model.*`` star-imports. Look the
        # name up in module scope instead of ``eval``-ing a command-line string.
        try:
            model_cls = globals()[self.args.model_name]
        except KeyError:
            raise ValueError(f"Unknown model_name {self.args.model_name!r}") from None
        self.model = model_cls(**model_args)

    def _load_model(self, repeat=0):
        """Restore weights from ``checkpoint_<repeat>.pth`` and move the model to ``self.device``.

        The checkpoint is expected to store the state dict under the ``"model"`` key.
        """
        checkpoint = torch.load(os.path.join(self.args.checkpoint_path, f"checkpoint_{repeat}.pth"), map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.model.to(self.device)

    def _load_dataloader(self, x_true, u_true):
        """Subclass hook: return ``(dataset, loader)`` iterating over ``(x_true, u_true)``."""
        raise NotImplementedError()

    def _preprocess_inputs(self, inputs):
        """Subclass hook: turn one loader batch into keyword arguments for ``self.model``."""
        raise NotImplementedError()

    def _get_latent(self, outputs):
        """Extract the latent tensor from the model's output dict.

        For ``lfads`` and ``tidespl_vae`` the first two axes are swapped and
        the trailing two flattened. Presumably those outputs are time-major
        ``(time, batch, dim)``; TODO confirm. Unknown model names return
        ``outputs`` unchanged.
        """
        if self.args.model_name in ["vae"]:
            outputs = outputs["z_mu"]
        elif self.args.model_name in ["lfads"]:
            outputs = outputs["f"].permute(1, 0, 2).flatten(1, 2)
        elif self.args.model_name in ["pivae"]:
            outputs = outputs["z_mu"]
        elif self.args.model_name in ["swap_vae"]:
            outputs = torch.cat((outputs["z1_content"], outputs["z1_style_mu"]), dim=-1)
        elif self.args.model_name in ["tidespl_vae"]:
            outputs = torch.cat((outputs["z_content"], outputs["z_style_mu"]), dim=-1).permute(1, 0, 2).flatten(1, 2)

        return outputs

    def _transform(self, x_true, u_true, repeat=0):
        """Encode the data into a 2-D numpy array of latents.

        Args:
            x_true: input data handed to ``_load_dataloader``.
            u_true: labels. When ``args.classes <= 0`` they get a trailing
                singleton dimension.
            repeat: accepted for interface parity; unused here.

        Returns:
            ``numpy.ndarray`` of shape ``(rows, latent_width)``. Rows follow
            loader order.

        Raises:
            RuntimeError: if the loader yields no batches.
        """
        u_true = u_true if self.args.classes > 0 else u_true.unsqueeze(-1)
        _, all_loader = self._load_dataloader(x_true, u_true)

        self.model.eval()
        batches = []
        with torch.inference_mode():
            for inputs in all_loader:
                inputs = self._preprocess_inputs(inputs)
                latent = self._get_latent(self.model(**inputs))
                # Flatten per batch and concatenate (not stack) so a smaller
                # final batch does not break the shape check.
                batches.append(latent.reshape(-1, latent.size(-1)))
        if not batches:
            raise RuntimeError("dataloader yielded no batches; nothing to encode")

        return torch.cat(batches, dim=0).cpu().numpy()


class ReductionTester(Tester):
    """Tester for classical reductions (PCA, t-SNE) and CEBRA embeddings.

    ``_load_model`` builds the estimator named by ``args.model_name``. The
    standard CEBRA variants are restored from ``checkpoint_<repeat>.pt`` in
    ``self.output_dir`` when that file exists. ``_transform`` fits the
    estimator when needed and returns the embedding. Subclasses provide
    ``_set_output_dir`` and ``_load_train_set``.
    """

    # Variant-specific ``CEBRA`` hyperparameters. The hyperparameters shared by
    # every variant come from ``_shared_cebra_kwargs``.
    _CEBRA_VARIANTS = {
        "CEBRA_delta": {
            "model_architecture": "offset1-model",
            "device": "cuda_if_available",
            "criterion": "infonce",
            "distance": "cosine",
            "conditional": "delta",
            "temperature": 1.0,
            "temperature_mode": "constant",
            "time_offsets": 1,
            "delta": 0.1,
            "optimizer": "adam",
        },
        "CEBRA_discrete_5": {
            "model_architecture": "offset5-model",
        },
        "CEBRA_time_1": {
            "model_architecture": "offset1-model",
            "conditional": "time",
            "time_offsets": 1,
        },
        "CEBRA_time_5": {
            "model_architecture": "offset5-model",
            "conditional": "time",
            "time_offsets": 3,
        },
        "CEBRA_time_10": {
            "model_architecture": "offset10-model",
            "conditional": "time",
            "time_offsets": 5,
        },
    }

    # ``CEBRAforMovie`` variants. These are always constructed fresh in
    # ``_load_model``; an existing checkpoint is loaded in ``_transform``.
    _CEBRA_MOVIE_VARIANTS = {
        "CEBRA_time_delta_4": {
            "model_architecture": "resample1-model",
            "conditional": "time_delta",
            "time_offsets": 2,
        },
        "CEBRA_time_4": {
            "model_architecture": "resample1-model",
            "conditional": "time",
            "time_offsets": 2,
        },
    }

    def _checkpoint_file(self, repeat):
        """Return the CEBRA checkpoint path for run ``repeat`` inside ``self.output_dir``."""
        return os.path.join(self.output_dir, f"checkpoint_{repeat}.pt")

    def _shared_cebra_kwargs(self):
        """Return the CEBRA hyperparameters common to every variant, read from ``args``."""
        return {
            "max_iterations": self.args.iterations,
            "batch_size": self.args.batch_size,
            "learning_rate": self.args.lr,
            "output_dimension": self.args.latent_dim,
            "num_hidden_units": self.args.data_dim,
            "verbose": True,
        }

    def _load_model(self, repeat=0):
        """Build (or, for standard CEBRA variants, restore) ``self.model``.

        Unrecognised model names leave ``self.model`` untouched.
        """
        model_name = self.args.model_name
        if model_name in ("PCA", "TSNE"):
            reducer_cls = {"PCA": PCA, "TSNE": TSNE}[model_name]
            self.model = reducer_cls(n_components=self.num_latent)
        elif model_name in self._CEBRA_VARIANTS:
            checkpoint_file = self._checkpoint_file(repeat)
            if os.path.exists(checkpoint_file):
                self.model = CEBRA.load(checkpoint_file)
            else:
                self.model = CEBRA(**self._CEBRA_VARIANTS[model_name], **self._shared_cebra_kwargs())
        elif model_name in self._CEBRA_MOVIE_VARIANTS:
            self.model = CEBRAforMovie(**self._CEBRA_MOVIE_VARIANTS[model_name], **self._shared_cebra_kwargs())

    def _load_train_set(self, x_true, u_true):
        """Subclass hook: return ``(x_true, u_true, x_train, u_train)``."""
        raise NotImplementedError()

    def _transform(self, x_true, u_true, repeat=0):
        """Fit ``self.model`` if needed and return the embedding of ``x_true``.

        CEBRA models are trained on the training split only when no
        checkpoint exists for ``repeat``; the freshly trained model is then
        saved.

        Raises:
            ValueError: if ``args.model_name`` is not a supported reducer.
        """
        u_true = u_true if self.args.classes > 0 else u_true.unsqueeze(-1)
        x_true, u_true, x_train, u_train = self._load_train_set(x_true, u_true)
        model_name = self.args.model_name

        if model_name == "PCA":
            self.model.fit(x_train)
            return self.model.transform(x_true)
        if model_name == "TSNE":
            # sklearn's TSNE only offers fit_transform (no out-of-sample transform).
            return self.model.fit_transform(x_true)
        if model_name in self._CEBRA_VARIANTS:
            checkpoint_file = self._checkpoint_file(repeat)
            if not os.path.exists(checkpoint_file):
                self.model.fit(x_train, u_train)
                self.model.save(checkpoint_file)
            return self.model.transform(x_true)
        if model_name in self._CEBRA_MOVIE_VARIANTS:
            checkpoint_file = self._checkpoint_file(repeat)
            if os.path.exists(checkpoint_file):
                # ``x_true.size(1)`` assumes a torch tensor whose dim 1 is the
                # feature width -- TODO confirm against CEBRAforMovie.load.
                self.model.load(x_true.size(1), checkpoint_file)
            else:
                train_dataset = cebra_sklearn_dataset.SklearnDataset(x_train, (u_train,))
                self.model.fit(train_dataset)
                self.model.save(checkpoint_file)
            dataset = cebra_sklearn_dataset.SklearnDataset(x_true, (u_true,))
            return self.model.transform(dataset)

        raise ValueError(f"Unsupported model_name {model_name!r} for ReductionTester")
