import os
import time
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from utils import LossForPrint

from model.VAE import *
from model.LFADS import *
from model.PIVAE import *
from model.SwapVAE import *
from model.TiDeSPLVAE import *


class Trainer:
    """Base trainer for the latent-variable models imported at the top of this file.

    Builds the model named by ``args.model_name``, trains it with Adam while
    logging losses to TensorBoard, and extracts latents for scoring.

    Subclasses must implement:
        * ``_set_log`` -- called from ``__init__``; ``train`` and ``test``
          read ``self.logdir``, so it presumably sets that attribute.
        * ``_preprocess_inputs`` -- turns a raw loader batch into model kwargs.
        * ``_compute_score`` -- scores the latents collected by ``test``.
    """

    def __init__(self, args):
        """Build the model, set up logging, and move the model to ``args.device``.

        Args:
            args: Configuration namespace (e.g. parsed CLI arguments). Fields
                read by this class include ``model_name``, ``data_dim``,
                ``latent_dim``, ``device``, ``lr``, ``epochs``,
                ``print_freq`` and the per-model loss weights.
        """
        self.args = args
        self._load_model()
        self._set_log()
        self.device = torch.device(args.device)
        self.model.to(self.device)

    def _load_model(self):
        """Instantiate ``self.model`` from ``args.model_name``.

        Constructor kwargs are derived from ``args.data_dim`` and
        ``args.latent_dim`` (plus ``args.classes`` for ``pivae``).

        Raises:
            ValueError: if ``tidespl_vae`` is requested with ``args.aug <= 0``,
                or if ``model_name`` does not name a callable in this module.
        """
        print("Creating model")
        args = self.args
        name = args.model_name
        model_args = {"input_dim": args.data_dim}
        if name == "vae":
            model_args["latent_dim"] = args.latent_dim
        elif name == "lfads":
            half = args.latent_dim // 2
            model_args.update({
                "encod_input_dim": args.data_dim,
                "factor_dim": args.latent_dim,
                "g0_enc_dim": half,
                "g0_dim": half,
                "con_enc_dim": half,
                "con_dim": half,
                "u_dim": max(args.latent_dim // 16, 2),
            })
        elif name == "pivae":
            model_args.update({
                "latent_dim": args.latent_dim,
                # With classes <= 0 a single label dimension and a
                # non-discrete prior are used.
                "label_dim": args.classes if args.classes > 0 else 1,
                "discrete_prior": args.classes > 0,
                "observation_model": "poisson",
            })
        elif name == "swap_vae":
            model_args.update({
                "content_dim": args.latent_dim // 2,
                "style_dim": args.latent_dim // 2,
            })
        elif name == "tidespl_vae":
            # Explicit check rather than `assert`, which is stripped under -O.
            # NOTE(review): presumably augmentation is what supplies the
            # `x_pos` view this model's loss consumes -- confirm.
            if args.aug <= 0:
                raise ValueError("tidespl_vae requires args.aug > 0")
            model_args.update({
                "content_dim": args.latent_dim // 2,
                "style_dim": args.latent_dim // 2,
                "hidden_state_dim": args.latent_dim,
            })

        # Resolve the class by plain name lookup in this module's globals
        # (populated by the `from model.* import *` lines) instead of eval():
        # a valid name yields the same class, but an arbitrary string can no
        # longer execute code.
        model_cls = globals().get(name)
        if not callable(model_cls):
            raise ValueError(f"Unknown model_name: {name!r}")
        self.model = model_cls(**model_args)

    def _set_log(self):
        """Set up logging for this run (subclass hook, called from ``__init__``).

        ``train`` and ``test`` read ``self.logdir``, so implementations
        presumably set it here.
        """
        raise NotImplementedError()

    def _preprocess_inputs(self, inputs, train=True):
        """Turn a raw loader batch into the kwargs dict for ``self.model(**...)``.

        The returned dict is also passed to ``_prepare_for_loss``, so it must
        carry the keys that method reads for the active model (``x``;
        ``x1``/``x2``; or ``x``/``x_pos``).

        Args:
            inputs: One batch as yielded by a data loader.
            train: True during optimisation, False for evaluation and
                latent extraction.
        """
        raise NotImplementedError()

    def _prepare_for_loss(self, inputs):
        """Build the extra kwargs ``compute_loss`` needs besides the model outputs.

        Args:
            inputs: Batch dict returned by ``_preprocess_inputs``.

        Returns:
            dict of targets and loss weights for ``self.model.compute_loss``.

        Raises:
            ValueError: if ``args.model_name`` is not a supported model.
        """
        args = self.args
        name = args.model_name
        if name in ("vae", "lfads", "pivae"):
            return {"x": inputs["x"], "kld_weight": args.kld_weight}
        if name == "swap_vae":
            # SwapVAE takes its alignment weight from args.cont_weight.
            return {
                "x1": inputs["x1"],
                "x2": inputs["x2"],
                "kld_weight": args.kld_weight,
                "align_weight": args.cont_weight,
            }
        if name == "tidespl_vae":
            return {
                "x": inputs["x"],
                "x_pos": inputs["x_pos"],
                "kld_weight": args.kld_weight,
                "cont_weight": args.cont_weight,
                "temperature": args.temperature,
                "prior_weight": args.prior_weight,
            }
        raise ValueError(f"Unsupported model_name for loss: {name!r}")

    def _get_latent(self, outputs):
        """Select the latent representation from a model's output dict.

        * ``vae`` / ``pivae``: the ``z_mu`` output.
        * ``swap_vae``: ``z1_content`` concatenated with ``z1_style_mu`` on
          the last axis.
        * ``lfads`` / ``tidespl_vae``: the latent tensor is permuted with
          ``(1, 0, 2)`` (presumably time-major ``(T, B, D)`` to
          ``(B, T, D)``) and its last two axes are flattened together.

        Any other model name returns ``outputs`` unchanged.
        """
        name = self.args.model_name
        if name in ("vae", "pivae"):
            return outputs["z_mu"]
        if name == "lfads":
            return outputs["f"].permute(1, 0, 2).flatten(1, 2)
        if name == "swap_vae":
            return torch.cat((outputs["z1_content"], outputs["z1_style_mu"]), dim=-1)
        if name == "tidespl_vae":
            z = torch.cat((outputs["z_content"], outputs["z_style_mu"]), dim=-1)
            return z.permute(1, 0, 2).flatten(1, 2)
        return outputs

    def _compute_score(self, z_predict, **kwargs):
        """Score the latent matrix built by ``test`` (subclass hook).

        Args:
            z_predict: Latents from all batches, shaped ``(rows, features)``.
            **kwargs: Extra keyword arguments forwarded from ``test``.
        """
        raise NotImplementedError()

    def _forward_loss(self, batch, train):
        """Preprocess ``batch``, run the model and return its ``compute_loss`` dict."""
        inputs = self._preprocess_inputs(batch, train=train)
        outputs = self.model(**inputs)
        outputs.update(self._prepare_for_loss(inputs))
        return self.model.compute_loss(**outputs)

    def _train_epoch(self, loader, optimizer, loss_dict):
        """Run one optimisation pass over ``loader``; store batch-averaged losses in ``loss_dict``."""
        self.model.train()
        num_batch = 0
        for batch in loader:
            losses = self._forward_loss(batch, train=True)

            optimizer.zero_grad()
            losses["loss"].backward()
            optimizer.step()

            loss_dict.update(losses)
            num_batch += 1
        loss_dict.compute(lambda x: x / num_batch)

    def _eval_epoch(self, loader, loss_dict):
        """Evaluate on ``loader`` without gradients; store batch-averaged losses in ``loss_dict``."""
        self.model.eval()
        num_batch = 0
        with torch.inference_mode():
            for batch in loader:
                loss_dict.update(self._forward_loss(batch, train=False))
                num_batch += 1
        loss_dict.compute(lambda x: x / num_batch)

    def train(self, train_loader, test_loader, repeat=0):
        """Train from freshly initialised weights, logging losses every epoch.

        Each call re-initialises the weights (``model.init_weight()``) and
        creates a new Adam optimiser, so repeated calls are independent
        runs. Per-epoch mean ``loss`` values go to TensorBoard as
        ``train_loss`` / ``test_loss``. A summary is printed every
        ``args.print_freq`` epochs and after the last epoch.

        Args:
            train_loader: Iterable of raw training batches.
            test_loader: Iterable of raw evaluation batches, evaluated after
                every epoch under ``torch.inference_mode()``.
            repeat: Run index, used in the event-file suffix and log messages.
        """
        self.writer = SummaryWriter(self.logdir, filename_suffix=f".repeat{repeat}")
        try:
            self.model.init_weight()
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

            print(f"Start training [repeat {repeat}]...")
            train_loss_dict = LossForPrint()
            test_loss_dict = LossForPrint()
            last_epoch = self.args.epochs - 1
            start_time = time.time()
            for epoch in range(self.args.epochs):
                self._train_epoch(train_loader, optimizer, train_loss_dict)
                self._eval_epoch(test_loader, test_loss_dict)

                self.writer.add_scalar("train_loss", train_loss_dict.get_loss("loss"), epoch)
                self.writer.add_scalar("test_loss", test_loss_dict.get_loss("loss"), epoch)
                if epoch % self.args.print_freq == 0 or epoch == last_epoch:
                    # Elapsed time covers every epoch since the previous print.
                    print_results = f"Epoch[{epoch}]:  total time={time.time() - start_time:.3f}s\n"
                    print_results += train_loss_dict.process_print("Train") + "\n"
                    print_results += test_loss_dict.process_print("Test")
                    print(print_results)
                    start_time = time.time()
                train_loss_dict.clear()
                test_loss_dict.clear()
        finally:
            # Close the event file even when training raises.
            self.writer.close()

    def test(self, all_loader, repeat=0, **kwargs):
        """Extract latents for every batch, score them, and save a checkpoint.

        Args:
            all_loader: Iterable of raw batches to encode.
            repeat: Run index, used in the checkpoint file name.
            **kwargs: Forwarded to ``_compute_score``.

        Returns:
            Whatever ``_compute_score`` returns. It is also stored in
            ``<logdir>/checkpoint_{repeat}.pth`` along with the model
            ``state_dict``, ``args`` and ``repeat``.
        """
        print(f"Start testing [repeat {repeat}]...")
        self.model.eval()
        latents = []
        with torch.inference_mode():
            for batch in all_loader:
                inputs = self._preprocess_inputs(batch, train=False)
                z = self._get_latent(self.model(**inputs))
                # Flatten each batch to (rows, features) so batches of
                # different sizes (e.g. a short final batch) can be joined.
                latents.append(z.reshape(-1, z.size(-1)))
        # Concatenate outside inference mode so _compute_score gets an
        # ordinary tensor. For equal-size batches this equals
        # torch.stack(...).view(-1, features).
        z_predict = torch.cat(latents, dim=0)

        score = self._compute_score(z_predict, **kwargs)
        checkpoint = {
            "model": self.model.state_dict(),
            "args": self.args,
            "score": score,
            "repeat": repeat
        }
        torch.save(checkpoint, os.path.join(self.logdir, f"checkpoint_{repeat}.pth"))
        return score
