import pickle
import time

import numpy as np
from ctgan import CTGAN
from ctgan.synthesizers.ctgan import Discriminator

from utils import set_seeds, total_trainable_pars

from .experiment import Experiment


class Experiment_CTGAN(Experiment):
    """Experiment that trains a CTGAN synthesizer and samples from it.

    Categorical and continuous training features are concatenated column-wise
    (categorical columns first) before fitting. Generated samples are split back
    into the two groups along the same boundary in ``sample``.
    """

    # CTGAN's discriminator groups `pac` samples together, so every batch size
    # must be a multiple of it. The value is passed explicitly to CTGAN so the
    # batch-size rounding below and the model configuration cannot drift apart.
    _PAC = 10

    def __init__(self, config, args):
        super().__init__(config, args)

    def train(self, **kwargs):
        """Fit a CTGAN model on the training split.

        Keyword Args:
            save_model: if True, pickle the fitted model to the work directory
                and record the training duration. Defaults to False.

        Raises:
            ValueError: if the effective batch size (the configured batch size,
                capped at the number of training rows) is smaller than the PAC
                size, so no valid batch size exists.
        """
        should_save = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders()
        x_cat_train = train_loader.data[0]
        x_cont_train = train_loader.data[1]
        self.num_cat_feat = x_cat_train.shape[1]
        # Categorical columns come first; `sample` relies on this ordering.
        x_train = np.column_stack((x_cat_train, x_cont_train))
        categorical_features = list(range(self.num_cat_feat))
        num_rows = x_train.shape[0]

        batch_size = self._pac_batch_size(self.config.model.batch_size, num_rows)
        epochs = self._steps_to_epochs(self.config.model.train_steps, num_rows, batch_size)
        print(f"Training for {epochs} epochs.")

        self.model = CTGAN(
            embedding_dim=self.config.model.emb_dim,
            generator_dim=self.config.model.generator_dim,
            discriminator_dim=self.config.model.discriminator_dim,
            generator_lr=self.config.model.generator_lr,
            discriminator_lr=self.config.model.discriminator_lr,
            batch_size=batch_size,
            epochs=epochs,
            pac=self._PAC,
            cuda=self.config.model.cuda,
            verbose=False,
        )

        training_start_time = time.monotonic()
        self.model.fit(x_train, categorical_features)
        training_duration = time.monotonic() - training_start_time

        # Count trainable parameters. A discriminator with the fitted model's
        # architecture is rebuilt only for counting (presumably because CTGAN
        # does not keep its discriminator on the instance after `fit`).
        data_dim = self.model._transformer.output_dimensions
        discriminator = Discriminator(
            data_dim + self.model._data_sampler.dim_cond_vec(), self.model._discriminator_dim, pac=self.model.pac
        )
        discriminator_params = total_trainable_pars(discriminator)
        generator_params = total_trainable_pars(self.model._generator)
        print(f"Total parameters: {discriminator_params + generator_params}")

        if should_save:
            self.save_model()
            self.save_train_time(training_duration)

    @classmethod
    def _pac_batch_size(cls, requested, num_rows):
        """Cap `requested` at `num_rows` and round it down to a multiple of the PAC size."""
        batch_size = min(requested, num_rows)
        batch_size -= batch_size % cls._PAC
        if batch_size < cls._PAC:
            # Without this check a zero batch size surfaces later as a
            # ZeroDivisionError (or a CTGAN assertion) with no useful context.
            raise ValueError(
                f"CTGAN needs a batch size of at least {cls._PAC} (its pac size), "
                f"but got batch_size={requested} with {num_rows} training rows."
            )
        return batch_size

    @staticmethod
    def _steps_to_epochs(train_steps, num_rows, batch_size):
        """Convert a target number of training steps into CTGAN epochs (at least 1)."""
        # NOTE(review): mirrors CTGAN.fit, which runs max(len(data) // batch_size, 1)
        # iterations per epoch in the ctgan releases checked; confirm against the
        # installed version. Floor division keeps epochs * steps_per_epoch close to
        # `train_steps`.
        steps_per_epoch = max(num_rows // batch_size, 1)
        # Never train for zero epochs; that would silently yield an untrained model.
        return max(1, round(train_steps / steps_per_epoch))

    def save_model(self):
        """Pickle the fitted CTGAN model to ``<workdir>/model.pkl``."""
        with (self.workdir / "model.pkl").open("wb") as f:
            pickle.dump(self.model, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load_model(self):
        """Restore the model pickled by ``save_model`` and the state ``sample`` needs.

        ``num_cat_feat`` is recomputed from the training split because ``sample``
        uses it to split generated columns.

        Security: ``pickle.load`` can execute arbitrary code, so only load
        ``model.pkl`` files this experiment produced itself.
        """
        # NOTE(review): unlike `train`/`sample`, this does not pass
        # cuda_deterministic=True; confirm whether that is intentional.
        set_seeds(self.seed)
        train_loader, _ = self.data_processor.get_data_loaders()
        x_cat_train = train_loader.data[0]
        self.num_cat_feat = x_cat_train.shape[1]

        with (self.workdir / "model.pkl").open("rb") as f:
            self.model = pickle.load(f)

    def sample(self, num_samples, seed):
        """Generate `num_samples` rows and split them into categorical/continuous parts.

        Returns:
            Tuple ``(x_cat_gen, x_cont_gen)``: the first ``num_cat_feat`` columns
            and the remaining columns of the generated data. This matches the
            column order used when fitting in ``train``.
        """
        set_seeds(seed, cuda_deterministic=True)
        gen_data = self.model.sample(num_samples)

        # Bring generated data into the required format.
        x_cat_gen = gen_data[:, : self.num_cat_feat]
        x_cont_gen = gen_data[:, self.num_cat_feat :]

        return x_cat_gen, x_cont_gen
