from .experiment import Experiment
from ctgan import TVAE
from ctgan.synthesizers.tvae import Encoder
import numpy as np
import os
import pickle
import time
from utils import total_trainable_pars, set_seeds


class Experiment_TVAE(Experiment):
    """Experiment that trains and samples from a ctgan ``TVAE`` synthesizer.

    Categorical and continuous training features are concatenated column-wise
    (categorical columns first) before fitting. Generated samples are split back
    into the two groups using ``self.num_cat_feat``.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def train(self, **kwargs):
        """Fit a TVAE model on the training split.

        Keyword Args:
            save_model (bool): when True, pickle the fitted model into
                ``self.workdir`` and record the training duration via
                ``self.save_train_time``. Defaults to False.
        """
        save_model = kwargs.get('save_model', False)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders()
        # NOTE(review): assumes train_loader.data is (categorical array,
        # continuous array), each 2-D with one row per sample — confirm.
        X_cat_train = train_loader.data[0]
        X_cont_train = train_loader.data[1]
        self.num_cat_feat = X_cat_train.shape[1]

        # Categorical columns are placed first, so their indices are simply
        # 0 .. num_cat_feat - 1 in the stacked matrix.
        X_train = np.column_stack((X_cat_train, X_cont_train))
        categorical_features = list(range(self.num_cat_feat))

        # Convert the configured number of optimisation steps into epochs.
        batch_size = min(X_train.shape[0], self.config.model.batch_size)
        # NOTE(review): this is a fractional estimate; if TVAE keeps the final
        # partial batch, the real per-epoch step count is ceil(n / batch_size).
        steps_per_epoch = X_train.shape[0] / batch_size
        # Rounding yields 0 whenever train_steps < steps_per_epoch / 2, which
        # would leave the model untrained; always train for at least one epoch.
        epochs = max(1, round(self.config.model.train_steps / steps_per_epoch))
        print(f"Training for {epochs} epochs.")

        self.model = TVAE(embedding_dim=self.config.model.emb_dim,
                          compress_dims=self.config.model.compress_dims,
                          decompress_dims=self.config.model.decompress_dims,
                          batch_size=batch_size,
                          epochs=epochs,
                          cuda=self.config.model.cuda,
                          verbose=False)

        training_start_time = time.monotonic()
        self.model.fit(X_train, categorical_features)
        training_duration = time.monotonic() - training_start_time

        # Count trainable parameters. The decoder is read from the fitted model.
        # The encoder is rebuilt from the fitted model's own hyper-parameters
        # (presumably because the fitted TVAE does not expose its encoder),
        # so the count matches the architecture that was actually trained.
        data_dim = self.model.transformer.output_dimensions
        encoder = Encoder(data_dim, self.model.compress_dims, self.model.embedding_dim)
        encoder_params = total_trainable_pars(encoder)
        decoder_params = total_trainable_pars(self.model.decoder)
        print(f"Total parameters: {encoder_params + decoder_params}")

        if save_model:
            self.save_model()
            self.save_train_time(training_duration)

    def save_model(self):
        """Pickle ``self.model`` to ``<self.workdir>/model.pkl``."""
        with open(os.path.join(self.workdir, 'model.pkl'), 'wb') as f:
            pickle.dump(self.model, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load_model(self):
        """Restore a pickled model from ``<self.workdir>/model.pkl``.

        The training data loaders are read only to recover ``num_cat_feat``,
        which is not stored in the pickle but is needed by ``sample``.

        Security: ``pickle.load`` can execute arbitrary code; only load model
        files produced by this experiment or another trusted source.
        """
        set_seeds(self.seed)
        train_loader, _ = self.data_processor.get_data_loaders()
        X_cat_train = train_loader.data[0]
        self.num_cat_feat = X_cat_train.shape[1]
        with open(os.path.join(self.workdir, 'model.pkl'), 'rb') as f:
            self.model = pickle.load(f)

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` synthetic rows and split them by feature type.

        Requires ``train`` or ``load_model`` to have been called first, so that
        ``self.model`` and ``self.num_cat_feat`` are set.

        Args:
            num_samples: number of rows to generate.
            seed: seed passed to ``set_seeds`` before sampling.

        Returns:
            tuple: ``(X_cat_gen, X_cont_gen)`` — the first ``num_cat_feat``
            columns and the remaining columns of the generated data.
        """
        set_seeds(seed, cuda_deterministic=True)
        # NOTE(review): assumes TVAE.sample returns a 2-D array (numpy input
        # was used for fitting) that supports [:, a:b] slicing — confirm.
        gen_data = self.model.sample(num_samples)

        # Undo the column stacking done in ``train``: categorical first.
        X_cat_gen = gen_data[:, :self.num_cat_feat]
        X_cont_gen = gen_data[:, self.num_cat_feat:]

        return X_cat_gen, X_cont_gen