from .experiment import Experiment
from utils import set_seeds
import pandas as pd
from arfpy import arf
import pickle
import os
import time


class Experiment_ARF(Experiment):
    """Experiment that fits and samples from an ``arfpy`` adversarial random forest.

    Besides what the ``Experiment`` base class provides, this class keeps:

    * ``self.model`` -- the fitted ``arf.arf`` object (set by ``train`` or
      ``load_model``).
    * ``self.num_cat_feat`` -- number of categorical columns; they occupy the
      leading columns of the frame the model is trained on, and ``sample``
      uses the count to split generated data back into two blocks.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def train(self, **kwargs):
        """Fit the adversarial random forest on the training data.

        Args:
            **kwargs: Only ``save_model`` (default ``False``) is read. When
                truthy, the training duration and a model checkpoint are
                written after fitting.
        """
        save_model = kwargs.get('save_model', False)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders()
        # data[0] is treated as the categorical block, data[1] as the
        # continuous block.
        df_cat = pd.DataFrame(train_loader.data[0]).astype('category')
        self.num_cat_feat = df_cat.shape[1]
        df_cont = pd.DataFrame(train_loader.data[1])
        # ignore_index=True relabels the columns 0..n-1, so the labels of the
        # two frames cannot collide after concatenation.
        df = pd.concat((df_cat, df_cont), axis=1, ignore_index=True)

        # The timed section covers both forest construction and forde().
        training_start_time = time.monotonic()
        self.model = arf.arf(
            df,
            num_trees=self.config.model.num_trees,
            delta=self.config.model.delta,
            max_iters=self.config.model.max_iters,
            min_node_size=self.config.model.min_node_size,
            random_state=self.seed,
            n_jobs=self.config.model.n_jobs,
        )
        self.model.forde()
        training_duration = time.monotonic() - training_start_time

        if save_model:
            self.save_train_time(training_duration)
            self.save_model()

    def save_model(self):
        """Pickle the model and ``num_cat_feat`` to ``<workdir>/checkpoint.pkl``."""
        checkpoint = {'model': self.model, 'num_cat_feat': self.num_cat_feat}
        with open(os.path.join(self.workdir, 'checkpoint.pkl'), 'wb') as f:
            pickle.dump(checkpoint, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load_model(self):
        """Restore the model and ``num_cat_feat`` from ``<workdir>/checkpoint.pkl``."""
        set_seeds(self.seed)
        # The loaders themselves are not needed here; the call is kept for its
        # side effects (presumably it initialises the data preprocessor
        # objects -- confirm against data_processor).
        self.data_processor.get_data_loaders()

        # NOTE: unpickling can execute arbitrary code. Only load checkpoints
        # written by save_model() into a trusted workdir.
        with open(os.path.join(self.workdir, 'checkpoint.pkl'), 'rb') as f:
            checkpoint = pickle.load(f)

        self.model = checkpoint['model']
        self.num_cat_feat = checkpoint['num_cat_feat']

    def sample(self, num_samples, seed):
        """Generate synthetic rows and split them into categorical/continuous parts.

        Args:
            num_samples: Number of rows requested from ``forge``.
            seed: Seed passed to ``set_seeds`` before generating.

        Returns:
            Tuple ``(X_cat_gen, X_cont_gen)`` of NumPy arrays holding the first
            ``num_cat_feat`` columns and the remaining columns, respectively.
        """
        set_seeds(seed, cuda_deterministic=True)
        df = self.model.forge(n=num_samples)

        # Convert once and slice twice: each to_numpy() call on the whole
        # frame builds a new array, so calling it per slice doubles the work.
        generated = df.to_numpy()
        X_cat_gen = generated[:, :self.num_cat_feat]
        X_cont_gen = generated[:, self.num_cat_feat:]

        return X_cat_gen, X_cont_gen