from .experiment import Experiment
import numpy as np
import torch
import os
import time
from torch.utils.data import DataLoader
import glob

from .models.tabdiff.main_modules import UniModMLP
from .models.tabdiff.main_modules import Model
from .models.tabdiff.unified_ctime_diffusion import UnifiedCtimeDiffusion
from .models.tabdiff.utils import set_seeds
from .models.tabdiff.trainer import Trainer


class Experiment_TabDiff(Experiment):
    """Experiment that trains, restores and samples a TabDiff diffusion model.

    Training is delegated to ``Trainer``; this class only assembles the
    backbone / denoiser / diffusion objects from ``self.config``, reloads the
    best EMA checkpoint found in ``self.workdir`` and converts generated
    samples into ``(categorical, numerical)`` NumPy arrays.
    """

    def __init__(self, config, args):
        super().__init__(config, args)

    def _build_diffusion(self, d_numerical):
        """Build the diffusion model shared by ``train`` and ``load_model``.

        Side effects: stores ``d_numerical`` on ``self`` and writes
        ``d_numerical`` / ``categories`` into ``self.config['unimodmlp_params']``.

        Args:
            d_numerical: Number of numerical feature columns.

        Returns:
            Tuple ``(diffusion, categories)`` where ``categories`` is the
            NumPy array of per-column class counts taken from the data
            processor (without the extra mask category).
        """
        categories = np.array(self.data_processor.X_cat_n_classes)
        self.d_numerical = d_numerical
        self.config['unimodmlp_params']['d_numerical'] = self.d_numerical
        # add one for the mask category
        self.config['unimodmlp_params']['categories'] = (categories + 1).tolist()

        backbone = UniModMLP(**self.config['unimodmlp_params'])
        model = Model(backbone, **self.config['diffusion_params']['edm_params'])
        model.to(self.device)

        diffusion = UnifiedCtimeDiffusion(
            num_classes=categories,
            num_numerical_features=self.d_numerical,
            denoise_fn=model,
            y_only_model=None,
            **self.config['diffusion_params'],
            device=self.device,
        )
        return diffusion, categories

    def train(self, **kwargs):
        """Train the diffusion model and optionally record the training time.

        Keyword Args:
            save_model: When truthy, the measured training duration is passed
                to ``self.save_train_time``. Defaults to ``False``.
        """
        save_model = kwargs.get('save_model', False)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders()
        # The processor's loader exposes its tensors via `.data`; index 0 is
        # read as the categorical part and index 1 as the numerical part.
        X_train_cat = train_loader.data[0]
        X_train_num = train_loader.data[1]

        # re-construct dataloader with correct data format:
        # numerical columns first, categorical columns after.
        train_loader = DataLoader(
            torch.cat((X_train_num, X_train_cat), dim=1),
            batch_size=self.config.train.batch_size,
            shuffle=True,
            num_workers=4,
        )

        diffusion, categories = self._build_diffusion(X_train_num.shape[1])
        num_params = sum(p.numel() for p in diffusion.parameters())
        print("Total parameters = ", num_params)
        diffusion.to(self.device)
        diffusion.train()

        # convert training steps to epochs
        steps_per_epoch = X_train_num.shape[0] / self.config.train.batch_size
        epochs = round(self.config.train.train_steps / steps_per_epoch)

        trainer = Trainer(
            diffusion,
            train_loader,
            self.d_numerical,
            categories,
            **self.config['train'],
            steps=epochs,
            sample_batch_size=self.config['sample']['batch_size'],
            num_samples_to_generate=1000,
            model_save_path=self.workdir,
            result_save_path=self.workdir,
            device=self.device,
            verbose=False,
        )
        training_start_time = time.monotonic()
        trainer.run_loop()
        training_duration = time.monotonic() - training_start_time
        if save_model:
            self.save_train_time(training_duration)

    def save_model(self):
        """No-op: nothing is saved here (``load_model`` reads checkpoints from ``self.workdir``)."""
        return

    def load_model(self):
        """Rebuild the diffusion model and load the best EMA checkpoint.

        On success the model is stored in ``self.diffusion`` in eval mode on
        ``self.device``.

        Raises:
            FileNotFoundError: If no ``best_ema_model_*`` file exists in
                ``self.workdir``.
        """
        set_seeds(self.seed)
        train_loader, _ = self.data_processor.get_data_loaders()
        X_train_num = train_loader.data[1]

        diffusion, _ = self._build_diffusion(X_train_num.shape[1])

        # load best EMA model checkpoint
        pattern = os.path.join(self.workdir, "best_ema_model_*")
        candidates = glob.glob(pattern)
        if not candidates:
            raise FileNotFoundError(
                f"No checkpoint matching '{pattern}' was found; train the model first."
            )
        # glob order is arbitrary, so pick deterministically: prefer the most
        # recently written file if several checkpoints are present.
        checkpoint_path = max(candidates, key=os.path.getmtime)
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # restored on whatever device this experiment is configured for.
        state_dicts = torch.load(checkpoint_path, map_location=self.device)
        diffusion._denoise_fn.load_state_dict(state_dicts['denoise_fn'])
        diffusion.num_schedule.load_state_dict(state_dicts['num_schedule'])
        diffusion.cat_schedule.load_state_dict(state_dicts['cat_schedule'])
        diffusion.to(self.device)
        diffusion.eval()

        # Only publish the model once its weights were loaded successfully, so
        # sample() can never run on a freshly initialised network.
        self.diffusion = diffusion

    def sample(self, num_samples, seed):
        """Generate synthetic rows with the loaded diffusion model.

        Requires ``load_model`` to have been called so that ``self.diffusion``
        and ``self.d_numerical`` are set.

        Args:
            num_samples: Number of rows to generate.
            seed: Seed passed to ``set_seeds`` before sampling.

        Returns:
            Tuple ``(X_cat_gen, X_cont_gen)`` of NumPy arrays holding the
            categorical and the numerical columns respectively.
        """
        batch_size = min(self.config.sample.batch_size, num_samples)
        set_seeds(seed, cuda_deterministic=True)
        with torch.no_grad():
            syn_data = self.diffusion.sample_all(
                num_samples, batch_size, keep_nan_samples=True, verbose=False
            )
        # Rows summing to exactly zero are reported as NaN instances
        # (presumably sample_all zeroes out NaN rows when keep_nan_samples=True
        # — TODO confirm). Convert to a plain int so the message shows a
        # number rather than "tensor(k)".
        num_all_zero_row = int((syn_data.sum(dim=1) == 0).sum())
        if num_all_zero_row:
            print(f"The generated samples contain {num_all_zero_row} Nan instances!")

        # get data into required shape: the first d_numerical columns are
        # numerical, the remaining ones categorical.
        X_cat_gen = syn_data[:, self.d_numerical:].cpu().numpy()
        X_cont_gen = syn_data[:, :self.d_numerical].cpu().numpy()

        return X_cat_gen, X_cont_gen