from .experiment import Experiment
import torch
import os
import time
import pickle
import numpy as np

from .models.cdtd import CDTD
from .utils import set_seeds


class Experiment_CDTD(Experiment):
    """Experiment that trains, tunes, persists and samples from a CDTD model.

    The methods read ``self.config``, ``self.seed``, ``self.workdir``,
    ``self.device`` and ``self.data_processor`` and call ``self.evaluate`` /
    ``self.save_train_time``; none of these are defined in this class, so
    they are presumably provided by the ``Experiment`` base class.
    """

    # Default ``rel_cat_weight`` candidates swept by ``tune_cdtd``.
    _REL_CAT_WEIGHT_GRID = (0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0,
                            2.5, 3.0, 3.5, 4.0)

    def __init__(self, config, args):
        """Forward ``config`` and ``args`` unchanged to the base class."""
        super().__init__(config, args)

    def _cdtd_train_data(self):
        """Return ``(X_cat_train, X_num_train)`` taken from the train loader."""
        train_loader, _ = self.data_processor.get_data_loaders()
        return train_loader.data[0], train_loader.data[1]

    def _build_cdtd(self, X_cat_train, X_num_train, **model_kwargs):
        """Construct a ``CDTD`` from ``self.config.model``.

        Extra keyword arguments (e.g. ``rel_cat_weight``) are forwarded to the
        ``CDTD`` constructor; when none are given the constructor's own
        defaults apply.
        """
        cfg = self.config.model
        return CDTD(
            X_cat_train,
            X_num_train,
            cfg.cat_emb_dim,
            cfg.mlp_emb_dim,
            cfg.mlp_n_layers,
            cfg.mlp_n_units,
            cfg.sigma_data_cat,
            cfg.sigma_data_cont,
            cfg.sigma_min_cat,
            cfg.sigma_min_cont,
            cfg.sigma_max_cat,
            cfg.sigma_max_cont,
            cfg.cat_emb_init_sigma,
            cfg.timewarp_type,
            cfg.timewarp_weight_low_noise,
            **model_kwargs
        )

    def _fit_cdtd(self, X_cat_train, X_num_train):
        """Fit ``self.model`` using the settings in ``self.config.training``."""
        cfg = self.config.training
        self.model.fit(
            X_cat_train,
            X_num_train,
            cfg.num_steps_train,
            cfg.num_steps_lr_warmup,
            cfg.batch_size,
            cfg.lr,
            self.seed,
            cfg.ema_decay,
            cfg.log_steps,
        )

    def train(self, **kwargs):
        """Build a fresh CDTD model and fit it on the training data.

        Keyword Args:
            save_model (bool): when true, the measured fit duration is passed
                to ``self.save_train_time`` and the model weights are written
                via ``self.save_model``. Defaults to ``False``.
        """
        save_model = kwargs.get('save_model', False)

        # Seed before touching the loaders / model so both are created under
        # the same seeded state as before.
        set_seeds(self.seed, cuda_deterministic=True)

        X_cat_train, X_num_train = self._cdtd_train_data()
        self.model = self._build_cdtd(X_cat_train, X_num_train)

        num_params = sum(p.numel() for p in self.model.diff_model.parameters())
        print("Total parameters = ", num_params)

        # Only the ``fit`` call is timed; data loading and model construction
        # are deliberately outside the measured window.
        training_start_time = time.monotonic()
        self._fit_cdtd(X_cat_train, X_num_train)
        training_duration = time.monotonic() - training_start_time

        if save_model:
            self.save_train_time(training_duration)
            self.save_model()

    def tune_cdtd(self, grid=None):
        """Train and evaluate one model per ``rel_cat_weight`` candidate.

        For the candidate at position ``i`` the model is trained, saved and
        evaluated, and the resulting ``eval_results.pkl`` in ``self.workdir``
        is re-saved as ``eval_results_{i}.pkl`` so later runs do not
        overwrite it.

        Args:
            grid: optional iterable of ``rel_cat_weight`` values. Defaults to
                the built-in grid ``_REL_CAT_WEIGHT_GRID``.
        """
        if grid is None:
            grid = self._REL_CAT_WEIGHT_GRID

        for i, rel_cat_weight in enumerate(grid):
            # Re-seed on every run so each candidate starts from the same
            # seeded state.
            set_seeds(self.seed, cuda_deterministic=True)

            X_cat_train, X_num_train = self._cdtd_train_data()
            self.model = self._build_cdtd(
                X_cat_train, X_num_train, rel_cat_weight=rel_cat_weight
            )
            self._fit_cdtd(X_cat_train, X_num_train)

            self.save_model()
            self.evaluate()

            # load results and save with different name
            with open(self.workdir / 'eval_results.pkl', 'rb') as f:
                eval_results = pickle.load(f)

            with open(self.workdir / f'eval_results_{i}.pkl', 'wb') as f:
                pickle.dump(eval_results, f)

    def save_model(self):
        """Write the ``diff_model`` state dict to ``<workdir>/model.pt``."""
        torch.save(self.model.diff_model.state_dict(), os.path.join(self.workdir, "model.pt"))

    def load_model(self):
        """Rebuild the CDTD model and restore weights from ``<workdir>/model.pt``.

        The restored ``diff_model`` is moved to ``self.device`` and switched
        to eval mode.
        """
        set_seeds(self.seed)

        X_cat_train, X_num_train = self._cdtd_train_data()
        self.model = self._build_cdtd(X_cat_train, X_num_train)

        # ``map_location`` lets a checkpoint written on another device (e.g. a
        # GPU) be loaded on the current one instead of failing to deserialize.
        checkpoint = torch.load(
            os.path.join(self.workdir, "model.pt"), map_location=self.device
        )
        self.model.diff_model.load_state_dict(checkpoint)
        self.model.diff_model.to(self.device)
        self.model.diff_model.eval()

    def sample(self, num_samples, seed):
        """Draw ``num_samples`` rows from the current model.

        Uses ``generation_steps`` and ``generation_batch_size`` from
        ``self.config.model`` and forwards ``seed`` to the model's sampler.

        Returns:
            The ``(X_cat_gen, X_num_gen)`` pair returned by
            ``self.model.sample``.
        """
        X_cat_gen, X_num_gen = self.model.sample(num_samples, self.config.model.generation_steps,
                                                 self.config.model.generation_batch_size, seed)
        return X_cat_gen, X_num_gen