from .experiment import Experiment
import numpy as np
import torch
import os
import time

from copy import deepcopy
from tqdm import tqdm
from .models.tabddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion
from .models.tabddpm.utils import set_seeds, cycle

           

class Experiment_TabDDPM(Experiment):
    """
    Based on the synthcity implementation of TabDDPM.
    Adjusted to ensure same architectures as well as unconditional generation:
    no conditioning is passed to the diffusion model (``cond=None`` during
    training, ``None`` during sampling).
    """

    # Default wall-clock budget (minutes) after which training stops early.
    MAX_TRAIN_MINUTES = 30

    def __init__(self, config, args):
        super().__init__(config, args)

    def _build_diffusion(self, train_loader):
        """Construct the GaussianMultinomialDiffusion model from the config.

        Side effect: sets ``self.n_num_feat`` to the number of numerical
        features, which ``sample`` uses to split generated rows.

        :param train_loader: training loader; ``train_loader.data[1]`` is read
            as the numerical training features (presumably a 2-D array of
            shape (n_rows, n_num_features) -- TODO confirm).
        :return: the diffusion model, moved to ``self.device``.
        """
        X_num_train = train_loader.data[1]
        self.n_num_feat = X_num_train.shape[1]

        diffusion = GaussianMultinomialDiffusion(
            num_numerical_features=self.n_num_feat,
            num_categorical_features=tuple(self.data_processor.X_cat_n_classes),
            model_params=self.config.model.model_params,
            num_timesteps=self.config.model.num_timesteps,
            dim_emb=self.config.model.dim_emb,
            gaussian_loss_type=self.config.model.gaussian_loss_type,
            scheduler=self.config.model.scheduler,
            device=self.device,
        )
        return diffusion.to(self.device)

    def train(self, **kwargs):
        """Train the diffusion model for ``config.model.train_steps`` steps.

        Training stops early once the wall-clock budget is exhausted.

        :param save_model: if True, save the training time and the model
            weights after training (default False).
        :param max_train_minutes: wall-clock training budget in minutes
            (default ``MAX_TRAIN_MINUTES``, i.e. 30).
        """
        save_model = kwargs.get('save_model', False)
        max_train_minutes = kwargs.get('max_train_minutes', self.MAX_TRAIN_MINUTES)
        set_seeds(self.seed, cuda_deterministic=True)

        train_loader, _ = self.data_processor.get_data_loaders()
        train_iter = cycle(train_loader)
        self.diffusion = self._build_diffusion(train_loader)

        # EMA copy of the denoiser; its weights are updated manually in
        # _update_ema, so they are detached from autograd.
        self.ema_model = deepcopy(self.diffusion.denoise_fn)
        for param in self.ema_model.parameters():
            param.detach_()

        self.optimizer = torch.optim.AdamW(
            self.diffusion.parameters(), lr=self.config.model.lr, weight_decay=self.config.model.weight_decay
        )

        step = 0
        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0
        curr_count = 0
        self.n_iter = self.config.model.train_steps
        log_steps = self.config.model.log_steps
        pbar = tqdm(range(self.n_iter), desc="Steps")
        training_start_time = time.monotonic()

        while step < self.n_iter:

            elapsed_minutes = (time.monotonic() - training_start_time) / 60
            if elapsed_minutes > max_train_minutes:
                break

            self.optimizer.zero_grad()
            x_cat, x_num = next(train_iter)
            x = torch.column_stack((x_num, x_cat)).to(self.device) # this order of concat is needed for GaussianMultinomialDiffusion to work properly
            loss_multi, loss_gauss = self.diffusion.mixed_loss(x, cond=None)
            loss = loss_multi + loss_gauss
            loss.backward()
            self.optimizer.step()

            # Accumulate sample-weighted losses for the running average.
            curr_count += len(x)
            curr_loss_multi += loss_multi.item() * len(x)
            curr_loss_gauss += loss_gauss.item() * len(x)

            # Pair EMA weights with the denoiser they were copied from, so the
            # zip cannot misalign if the diffusion module owns other parameters.
            self._update_ema(
                self.ema_model.parameters(), self.diffusion.denoise_fn.parameters()
            )
            pbar.update(1)
            step += 1
            self._anneal_lr(step)

            if step % log_steps == 0:
                mloss = np.around(curr_loss_multi / curr_count, 4)
                gloss = np.around(curr_loss_gauss / curr_count, 4)
                avg_loss = mloss + gloss
                pbar.set_description(f"Losses (last {log_steps} steps): {avg_loss:.4f}")
                curr_count = 0
                curr_loss_gauss = 0.0
                curr_loss_multi = 0.0

        pbar.close()
        training_duration = time.monotonic() - training_start_time
        if save_model:
            self.save_train_time(training_duration)
            self.save_model()

    def _anneal_lr(self, step: int) -> None:
        """Linearly decay the learning rate from ``config.model.lr`` to 0 over ``self.n_iter`` steps."""
        frac_done = step / self.n_iter
        lr = self.config.model.lr * (1 - frac_done)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _update_ema(self, target_params, source_params, rate=0.999):
        """
        Update target parameters to be closer to those of source parameters using
        an exponential moving average.
        :param target_params: the target parameter sequence.
        :param source_params: the source parameter sequence (must be in the same
            order as ``target_params``).
        :param rate: the EMA rate (closer to 1 means slower).
        """
        for targ, src in zip(target_params, source_params):
            targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate)

    def save_model(self):
        """Save the diffusion model's state dict to ``<workdir>/model.pt``."""
        torch.save(self.diffusion.state_dict(), os.path.join(self.workdir, "model.pt"))

    def load_model(self):
        """Rebuild the diffusion model and load weights from ``<workdir>/model.pt``.

        The checkpoint is mapped onto ``self.device``, so GPU-saved weights also
        load on CPU-only hosts. The model is left in eval mode.
        """
        set_seeds(self.seed)
        train_loader, _ = self.data_processor.get_data_loaders()
        self.diffusion = self._build_diffusion(train_loader)
        checkpoint = torch.load(
            os.path.join(self.workdir, "model.pt"), map_location=self.device
        )
        self.diffusion.load_state_dict(checkpoint)
        self.diffusion.eval()

    def sample(self, num_samples, seed):
        """Generate ``num_samples`` rows unconditionally.

        :param num_samples: number of rows to generate.
        :param seed: seed applied before sampling, for reproducibility.
        :return: tuple ``(X_cat_gen, X_cont_gen)`` of numpy arrays: the
            categorical columns (after the first ``n_num_feat`` columns) and the
            numerical columns (the first ``n_num_feat`` columns).
        """
        set_seeds(seed, cuda_deterministic=True)
        # A freshly trained model is still in training mode; switch to eval so
        # layers such as dropout behave deterministically during generation.
        self.diffusion.eval()
        with torch.no_grad():
            sample = self.diffusion.sample_all(num_samples, None,
                                               max_batch_size=self.config.model.sample_batch_size)
        sample = sample.detach().cpu().numpy()

        # Numerical columns come first, matching the concat order used in training.
        X_cat_gen = sample[:, self.n_num_feat:]
        X_cont_gen = sample[:, :self.n_num_feat]

        return X_cat_gen, X_cont_gen