import pickle
import time

import pandas as pd
import torch
from arfpy import arf

from experiments.models.lowres.encoder import Discretizer
from utils import set_seeds

from .experiment import Experiment


class Experiment_ARF(Experiment):
    """Experiment wrapper around an Adversarial Random Forest (``arfpy.arf``).

    Two modes, selected by ``config.model.use_as_lowres``:

    * ``False``: the ARF is fit on the categorical training features followed
      by the continuous ones, and ``sample`` returns both parts.
    * ``True``: the continuous features are discretized with a ``Discretizer``
      and the ARF is fit on categorical columns only (original categorical
      features followed by the discretized groups); ``sample`` then returns
      ``None`` for the continuous part.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def _mean_impute(self):
        """Return the ``mean_impute`` flag used whenever data loaders are built.

        Shared by ``train`` and ``load_model`` so the data processor is always
        set up with the same setting.
        """
        return not self.config.model.use_as_lowres

    def _build_train_frame(self, x_cat_trn, x_num_trn):
        """Assemble the training DataFrame handed to ``arf.arf``.

        Args:
            x_cat_trn: categorical training features (first element of the
                train loader's ``data``).
            x_num_trn: numerical training features (second element of the
                train loader's ``data``).

        Returns:
            ``(df, num_cat_feat)``: the first ``num_cat_feat`` columns of ``df``
            have dtype ``category``; any remaining columns are continuous.
        """
        if self.config.model.use_as_lowres:
            encoder = Discretizer(
                x_num_trn,
                variant="dt",
                seed=self.seed,
                k_max=10,
                max_depth=8,
            )
            groups, _ = encoder.encode(x_num_trn)
            # remove extra missingness indicator (not used for TabCascade) from cat features
            x_cat_no_miss = x_cat_trn[:, : len(self.data_processor.cat_cols)]
            # NOTE(review): torch.column_stack assumes both inputs are tensors
            x_cat_all = torch.column_stack((x_cat_no_miss, groups))
            return pd.DataFrame(x_cat_all).astype("category"), x_cat_all.shape[1]

        df_cat = pd.DataFrame(x_cat_trn).astype("category")
        df_cont = pd.DataFrame(x_num_trn)
        # categorical columns first, continuous after; ``sample`` relies on this layout
        df = pd.concat((df_cat, df_cont), axis=1, ignore_index=True)
        return df, df_cat.shape[1]

    def train(self, **kwargs):
        """Fit the ARF (``arf.arf`` followed by ``forde``) on the training split.

        Keyword Args:
            save_model (bool): if true, save the measured training time and a
                pickle checkpoint to ``self.workdir``. Defaults to ``False``.
        """
        save_model = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders(mean_impute=self._mean_impute())
        x_cat_trn, x_num_trn = train_loader.data[0], train_loader.data[1]
        df, self.num_cat_feat = self._build_train_frame(x_cat_trn, x_num_trn)

        # only the ARF fit and density estimation are timed, not data preparation
        training_start_time = time.monotonic()
        self.model = arf.arf(
            df,
            num_trees=self.config.model.num_trees,
            delta=self.config.model.delta,
            max_iters=self.config.model.max_iters,
            min_node_size=self.config.model.min_node_size,
            random_state=self.seed,
            n_jobs=self.config.model.n_jobs,
        )
        self.model.forde()
        training_duration = time.monotonic() - training_start_time

        if save_model:
            self.save_train_time(training_duration)
            self.save_model()

    def save_model(self):
        """Pickle the fitted ARF and ``num_cat_feat`` to ``checkpoint.pkl`` in ``self.workdir``."""
        checkpoint = {"model": self.model, "num_cat_feat": self.num_cat_feat}
        with (self.workdir / "checkpoint.pkl").open("wb") as f:
            pickle.dump(checkpoint, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load_model(self):
        """Restore the ARF and ``num_cat_feat`` from ``checkpoint.pkl`` in ``self.workdir``.

        The data loaders are rebuilt first, with the same ``mean_impute``
        setting as ``train``, so the data processor's preprocessing objects are
        initialized as they were at training time. The loaders are discarded.

        NOTE: uses ``pickle.load``; only load checkpoints from trusted sources.
        """
        # init data preprocessor objects (called for its side effect only)
        set_seeds(self.seed)
        self.data_processor.get_data_loaders(mean_impute=self._mean_impute())
        with (self.workdir / "checkpoint.pkl").open("rb") as f:
            checkpoint = pickle.load(f)
        self.model = checkpoint["model"]
        self.num_cat_feat = checkpoint["num_cat_feat"]

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` synthetic rows with the fitted ARF (``forge``).

        Args:
            num_samples: number of rows to generate.
            seed: seed passed to ``set_seeds`` before generation.

        Returns:
            ``(x_cat_gen, x_cont_gen)`` as NumPy arrays. In low-res mode all
            generated columns are returned as ``x_cat_gen`` and ``x_cont_gen``
            is ``None``. Otherwise the first ``self.num_cat_feat`` columns are
            the categorical part and the rest the continuous part; this assumes
            ``forge`` keeps the training column order (categorical first).
        """
        set_seeds(seed, cuda_deterministic=True)
        generated = self.model.forge(n=num_samples).to_numpy()

        if self.config.model.use_as_lowres:
            return generated, None
        return generated[:, : self.num_cat_feat], generated[:, self.num_cat_feat :]
