import pickle
import time

import torch

from .experiment import Experiment
from .models.cdtd import CDTD
from .utils import set_seeds


class Experiment_CDTD(Experiment):
    """Experiment that trains, tunes, saves, loads and samples a CDTD model.

    Model hyperparameters are read from ``self.config.model`` and training
    hyperparameters from ``self.config.training``.
    """

    # Candidate ``rel_cat_weight`` values tried by ``tune_cdtd`` when no grid
    # is supplied.
    DEFAULT_REL_CAT_WEIGHT_GRID = (
        0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.5, 3.0, 3.5, 4.0,
    )

    def __init__(self, config, args):
        super().__init__(config, args)

    # ------------------------------------------------------------------ #
    # Private helpers shared by train / tune_cdtd / load_model
    # ------------------------------------------------------------------ #
    def _load_train_data(self):
        """Return ``(x_cat_train, x_num_train)`` taken from the train loader's ``data``."""
        train_loader, _ = self.data_processor.get_data_loaders()
        return train_loader.data[0], train_loader.data[1]

    def _build_model(self, x_cat_train, x_num_train, **extra_kwargs):
        """Construct a CDTD model from ``self.config.model``.

        Args:
            x_cat_train: categorical training data, passed through to CDTD.
            x_num_train: numerical training data, passed through to CDTD.
            **extra_kwargs: additional keyword arguments forwarded to the CDTD
                constructor (e.g. ``rel_cat_weight``). When omitted, CDTD's own
                defaults apply.

        Returns:
            A freshly constructed ``CDTD`` instance.
        """
        cfg = self.config.model
        return CDTD(
            x_cat_train,
            x_num_train,
            cfg.cat_emb_dim,
            cfg.mlp_emb_dim,
            cfg.mlp_n_layers,
            cfg.mlp_n_units,
            cfg.sigma_data_cat,
            cfg.sigma_data_cont,
            cfg.sigma_min_cat,
            cfg.sigma_min_cont,
            cfg.sigma_max_cat,
            cfg.sigma_max_cont,
            cfg.cat_emb_init_sigma,
            cfg.timewarp_type,
            cfg.timewarp_weight_low_noise,
            **extra_kwargs,
        )

    def _batch_size(self, x_cat_train):
        """Return the configured batch size, capped at ``x_cat_train.shape[0]``.

        The cap prevents requesting a batch larger than the training set
        (assumes dim 0 of ``x_cat_train`` is the row dimension).
        """
        return min(self.config.training.batch_size, x_cat_train.shape[0])

    def _fit(self, x_cat_train, x_num_train):
        """Fit ``self.model`` using the training hyperparameters from config."""
        cfg = self.config.training
        self.model.fit(
            x_cat_train,
            x_num_train,
            cfg.num_steps_train,
            cfg.num_steps_lr_warmup,
            self._batch_size(x_cat_train),
            cfg.lr,
            self.seed,
            cfg.ema_decay,
            cfg.log_steps,
        )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def train(self, **kwargs):
        """Seed, build and fit a fresh CDTD model on the training data.

        Prints the parameter count of ``self.model.diff_model``.

        Keyword Args:
            save_model (bool): if True, store the wall-clock training duration
                via ``save_train_time`` and write the weights via
                ``save_model``. Defaults to False.
        """
        should_save = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)

        x_cat_train, x_num_train = self._load_train_data()
        self.model = self._build_model(x_cat_train, x_num_train)

        num_params = sum(p.numel() for p in self.model.diff_model.parameters())
        print("Total parameters = ", num_params)

        training_start_time = time.monotonic()
        self._fit(x_cat_train, x_num_train)
        training_duration = time.monotonic() - training_start_time
        if should_save:
            self.save_train_time(training_duration)
            self.save_model()

    def tune_cdtd(self, grid=None):
        """Train, save and evaluate one CDTD model per ``rel_cat_weight`` value.

        For the i-th grid value, the model is trained and saved, and
        ``self.evaluate()`` is called. The ``eval_results.pkl`` then present in
        ``self.workdir`` is copied to ``eval_results_{i}.pkl``, so every
        candidate's results are kept. Files are indexed by position in
        ``grid``, not by the weight value.

        Args:
            grid: iterable of ``rel_cat_weight`` values (presumably the relative
                weight of the categorical part of the model — confirm in
                CDTD). Defaults to ``DEFAULT_REL_CAT_WEIGHT_GRID``.
        """
        if grid is None:
            grid = self.DEFAULT_REL_CAT_WEIGHT_GRID

        for i, rel_cat_weight in enumerate(grid):
            # Re-seed per candidate so every run starts from the same RNG state.
            set_seeds(self.seed, cuda_deterministic=True)
            x_cat_train, x_num_train = self._load_train_data()
            self.model = self._build_model(
                x_cat_train, x_num_train, rel_cat_weight=rel_cat_weight
            )
            # Same clamped batch size as ``train`` so small datasets work here too.
            self._fit(x_cat_train, x_num_train)

            self.save_model()
            self.evaluate()

            # evaluate() is expected to (re)write eval_results.pkl; keep a
            # per-candidate copy before the next iteration overwrites it.
            with (self.workdir / "eval_results.pkl").open("rb") as f:
                eval_results = pickle.load(f)

            with (self.workdir / f"eval_results_{i}.pkl").open("wb") as f:
                pickle.dump(eval_results, f)

    def save_model(self):
        """Write ``self.model.diff_model``'s state_dict to ``workdir/model.pt``."""
        torch.save(self.model.diff_model.state_dict(), self.workdir / "model.pt")

    def load_model(self):
        """Rebuild the CDTD model and restore its weights from ``workdir/model.pt``.

        The architecture is reconstructed from the training data exactly as in
        ``train``. The saved ``diff_model`` weights are loaded, moved to
        ``self.device`` and switched to eval mode.
        """
        set_seeds(self.seed)
        x_cat_train, x_num_train = self._load_train_data()
        self.model = self._build_model(x_cat_train, x_num_train)

        # map_location lets a checkpoint written on one device (e.g. GPU) be
        # loaded on a machine where that device is unavailable (e.g. CPU-only).
        checkpoint = torch.load(self.workdir / "model.pt", map_location=self.device)
        self.model.diff_model.load_state_dict(checkpoint)
        self.model.diff_model.to(self.device)
        self.model.diff_model.eval()

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` rows with ``self.model``.

        Generation steps and generation batch size come from
        ``self.config.model``.

        Returns:
            ``(x_cat_gen, x_num_gen)`` as returned by ``CDTD.sample``.
        """
        x_cat_gen, x_num_gen = self.model.sample(
            num_samples,
            self.config.model.generation_steps,
            self.config.model.generation_batch_size,
            seed,
        )
        return x_cat_gen, x_num_gen
