import glob
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from .experiment import Experiment
from .models.tabdiff.main_modules import Model, UniModMLP
from .models.tabdiff.trainer import Trainer
from .models.tabdiff.unified_ctime_diffusion import UnifiedCtimeDiffusion
from .models.tabdiff.utils import set_seeds


class Experiment_TabDiff(Experiment):
    """Experiment wrapper that trains and samples from a TabDiff diffusion model.

    Training data is taken from ``self.data_processor`` and laid out as
    ``[numerical columns | categorical columns]``; generated samples are split
    back into the two parts at column ``self.d_numerical``.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def _build_diffusion(self, x_train_num):
        """Build the denoising network and wrap it in a ``UnifiedCtimeDiffusion``.

        Shared by :meth:`train` and :meth:`load_model` so both construct the
        model in exactly the same way.

        Side effects: sets ``self.d_numerical`` and writes ``d_numerical`` and
        ``categories`` into ``self.config["unimodmlp_params"]``.

        Args:
            x_train_num: Numerical training features; only ``shape[1]`` (the
                number of numerical columns) is used.

        Returns:
            Tuple ``(diffusion, categories)`` where ``categories`` is the numpy
            array built from ``self.data_processor.X_cat_n_classes``.
        """
        categories = np.array(self.data_processor.X_cat_n_classes)
        self.d_numerical = x_train_num.shape[1]
        self.config["unimodmlp_params"]["d_numerical"] = self.d_numerical
        self.config["unimodmlp_params"]["categories"] = (categories + 1).tolist()  # add one for the mask category

        backbone = UniModMLP(**self.config["unimodmlp_params"])
        model = Model(backbone, **self.config["diffusion_params"]["edm_params"])
        model.to(self.device)

        diffusion = UnifiedCtimeDiffusion(
            num_classes=categories,
            num_numerical_features=self.d_numerical,
            denoise_fn=model,
            y_only_model=None,
            **self.config["diffusion_params"],
            device=self.device,
        )
        return diffusion, categories

    def train(self, **kwargs):
        """Train the diffusion model on the data processor's training split.

        Keyword Args:
            save_model (bool): When true, the measured training duration is
                recorded via ``self.save_train_time``. Defaults to ``False``.
        """
        save_model = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders()
        x_train_cat = train_loader.data[0]
        x_train_num = train_loader.data[1]

        # re-construct dataloader with correct data format: numerical columns
        # first, then categorical ones. The batch size is clamped so it never
        # exceeds the number of training rows.
        batch_size = min(self.config.train.batch_size, x_train_num.shape[0])
        train_loader = DataLoader(
            torch.cat((x_train_num, x_train_cat), dim=1),
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
        )

        diffusion, categories = self._build_diffusion(x_train_num)
        num_params = sum(p.numel() for p in diffusion.parameters())
        print("Total parameters = ", num_params)
        diffusion.to(self.device)
        diffusion.train()

        # convert training steps to epochs
        # NOTE(review): this uses the configured batch size, not the clamped
        # ``batch_size`` above, so for datasets smaller than one configured
        # batch the epoch count grows accordingly - confirm this is intended.
        steps_per_epoch = x_train_num.shape[0] / self.config.train.batch_size
        epochs = round(self.config.train.train_steps / steps_per_epoch)

        trainer = Trainer(
            diffusion,
            train_loader,
            self.d_numerical,
            categories,
            **self.config["train"],
            steps=epochs,
            sample_batch_size=self.config["sample"]["batch_size"],
            num_samples_to_generate=1000,
            model_save_path=self.workdir,
            result_save_path=self.workdir,
            device=self.device,
            verbose=False,
        )
        training_start_time = time.monotonic()
        trainer.run_loop()
        training_duration = time.monotonic() - training_start_time
        if save_model:
            self.save_train_time(training_duration)

    def save_model(self):
        """Do nothing.

        ``Trainer`` is given ``self.workdir`` as its ``model_save_path`` and
        :meth:`load_model` reads its checkpoint from there, so presumably the
        trainer writes checkpoints itself during training.
        """
        return

    def load_model(self):
        """Rebuild the diffusion model and load the best EMA checkpoint.

        Looks for files matching ``best_ema_model_*`` in ``self.workdir`` and
        stores the restored model, in eval mode, as ``self.diffusion``.

        Raises:
            FileNotFoundError: If no matching checkpoint exists.
        """
        set_seeds(self.seed)
        train_loader, _ = self.data_processor.get_data_loaders()
        x_train_num = train_loader.data[1]

        self.diffusion, _ = self._build_diffusion(x_train_num)

        # load best EMA model checkpoint
        pattern = os.path.join(self.workdir, "best_ema_model_*")
        checkpoints = glob.glob(pattern)
        if not checkpoints:
            raise FileNotFoundError(
                f"No TabDiff checkpoint matching '{pattern}' was found; run train() first."
            )
        # glob order is arbitrary; if several checkpoints are present, take the
        # most recently written one (assumed to supersede older "best" files).
        checkpoint_path = max(checkpoints, key=os.path.getmtime)
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # restored on whatever device this experiment runs on.
        state_dicts = torch.load(checkpoint_path, map_location=self.device)
        self.diffusion._denoise_fn.load_state_dict(state_dicts["denoise_fn"])
        self.diffusion.num_schedule.load_state_dict(state_dicts["num_schedule"])
        self.diffusion.cat_schedule.load_state_dict(state_dicts["cat_schedule"])
        self.diffusion.to(self.device)
        self.diffusion.eval()

    def sample(self, num_samples, seed):
        """Generate synthetic rows with the loaded diffusion model.

        Requires :meth:`load_model` to have been called first, since it sets
        ``self.diffusion``.

        Args:
            num_samples: Number of rows to generate.
            seed: Seed passed to ``set_seeds`` before sampling.

        Returns:
            Tuple ``(x_cat_gen, x_cont_gen)`` of numpy arrays holding the
            categorical and the numerical columns of the generated data.
        """
        batch_size = min(self.config.sample.batch_size, num_samples)
        set_seeds(seed, cuda_deterministic=True)
        with torch.no_grad():
            syn_data = self.diffusion.sample_all(num_samples, batch_size, keep_nan_samples=True, verbose=False)
        # Rows whose entries sum to zero are reported as NaN instances
        # (presumably how sample_all marks failed samples - TODO confirm).
        # Converted to a plain int so the message does not print "tensor(...)".
        num_all_zero_row = int((syn_data.sum(dim=1) == 0).sum())
        if num_all_zero_row:
            print(f"The generated samples contain {num_all_zero_row} Nan instances!")

        # get data into required shape: the first d_numerical columns are
        # numerical, the remaining ones categorical
        x_cat_gen = syn_data[:, self.d_numerical :]
        x_cont_gen = syn_data[:, : self.d_numerical]
        x_cat_gen = x_cat_gen.cpu().numpy()
        x_cont_gen = x_cont_gen.cpu().numpy()

        return x_cat_gen, x_cont_gen
