import pickle
import time

import numpy as np
from ForestDiffusion import ForestDiffusionModel

from utils import set_seeds

from .experiment import Experiment


class Experiment_ForestDiffusion(Experiment):
    """Experiment wrapper that trains and samples from a ForestDiffusion model.

    Training data is laid out as ``[categorical columns | numerical columns]``;
    ``sample`` relies on that ordering to split generated rows back into the
    two groups.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    @staticmethod
    def _count_categories(x_cat):
        """Return the number of distinct observed values in each column of ``x_cat``.

        ``x_cat`` must be 2-D and its column slices must expose ``.unique()``
        (presumably a ``torch.Tensor`` coming from the data loader -- TODO confirm).
        """
        return [len(x_cat[:, i].unique()) for i in range(x_cat.shape[1])]

    def train(self, **kwargs):
        """Fit an unconditional ForestDiffusion (flow) model on the training split.

        Keyword Args:
            save_model: if True, pickle the fitted model to ``workdir/model.pkl``
                and record the training wall-clock duration via
                ``save_train_time``. Defaults to False.

        Side effects:
            Sets ``self.n_classes_cat`` and ``self.model``.
        """
        save_model = kwargs.get("save_model", False)
        set_seeds(self.seed)

        train_loader, _ = self.data_processor.get_data_loaders()
        x_cat_trn, x_num_trn = train_loader.data

        # Number of observed classes for each categorical column.
        self.n_classes_cat = self._count_categories(x_cat_trn)

        # Categorical columns are stacked first in x_trn below, so an index into
        # n_classes_cat is also the column index into x_trn.
        # Binary columns: exactly 2 observed classes.
        bin_indexes = [i for i, k in enumerate(self.n_classes_cat) if k == 2]
        # Multi-class categorical columns: >= 3 observed classes.
        cat_indexes = [i for i, k in enumerate(self.n_classes_cat) if k > 2]
        # NOTE(review): columns with a single observed value land in neither
        # list (presumably modeled as continuous by ForestDiffusion -- confirm).

        x_trn = np.column_stack((x_cat_trn, x_num_trn))

        n_t = 50  # number of noise levels (the lower the worse the performance; use default from paper)
        duplicate_K = 100  # leave as default
        max_depth = 7  # recommended to leave at default
        n_estimators = 100  # recommended to leave at default
        n_batch = 8
        # For quick smoke tests, n_t = 5 and n_estimators = 10 train much faster.

        # Note: we train an UNconditional model to keep the comparison fair;
        # otherwise we would need to train separate models for each class label.
        training_start_time = time.monotonic()
        self.model = ForestDiffusionModel(
            x_trn,
            n_t=n_t,
            duplicate_K=duplicate_K,
            n_estimators=n_estimators,
            max_depth=max_depth,
            bin_indexes=bin_indexes,
            cat_indexes=cat_indexes,
            diffusion_type="flow",
            n_jobs=16,  # same as for ARF
            seed=self.seed,
            n_batch=n_batch,
            gpu_hist=False,
        )

        training_duration = time.monotonic() - training_start_time

        if save_model:
            self.save_model()
            self.save_train_time(training_duration)

    def save_model(self):
        """Pickle ``self.model`` to ``workdir/model.pkl``."""
        with (self.workdir / "model.pkl").open("wb") as f:
            pickle.dump(self.model, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load_model(self):
        """Restore ``self.n_classes_cat`` and unpickle ``self.model`` from ``workdir``.

        The data loaders are built first so the data processor's encoders are
        initialised and the per-column class counts can be recomputed.
        """
        # init data processor encoders
        # NOTE(review): train() calls get_data_loaders() without
        # mean_impute=False; categorical columns are presumably unaffected by
        # mean imputation, so the class counts should match -- confirm.
        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=False)
        x_cat_trn, _ = train_loader.data

        self.n_classes_cat = self._count_categories(x_cat_trn)

        # Unpickling can execute arbitrary code: only load model files this
        # experiment wrote itself.
        with (self.workdir / "model.pkl").open("rb") as f:
            self.model = pickle.load(f)

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` synthetic rows from the fitted model.

        Args:
            num_samples: number of rows to generate.
            seed: seed passed to ``set_seeds`` before generation.

        Returns:
            Tuple ``(x_cat, x_num)``: the categorical columns cast to int64 and
            the numerical columns cast to float64.

        Requires ``train`` or ``load_model`` to have been called first so that
        ``self.model`` and ``self.n_classes_cat`` are set.
        """
        set_seeds(seed)

        gen_data = self.model.generate(num_samples)
        # Columns were stacked as [categorical | numerical] during training.
        n_cat = len(self.n_classes_cat)
        x_cat_gen = gen_data[:, :n_cat]
        x_cont_gen = gen_data[:, n_cat:]

        return x_cat_gen.astype("int64"), x_cont_gen.astype("float64")
