import os
import glob
import time
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

from copy import deepcopy

from .utils import update_ema

from tqdm import tqdm

# from https://github.com/MinkaiXu/TabDiff/blob/main/tabdiff/trainer.py

BAR = "=============="
def print_with_bar(log_msg):
    """Print ``log_msg`` framed by ``BAR`` on both sides.

    When the framed message contains the substring "End", one extra
    newline is appended so a blank line follows it in the output.
    """
    framed = "".join((BAR, log_msg, BAR))
    trailer = "\n" if "End" in framed else ""
    print(framed + trailer)

class Trainer:
    def __init__(
            self, diffusion, 
            train_iter, d_numerical, categories,
            lr, weight_decay,
            steps, check_val_every,
            sample_batch_size, model_save_path, result_save_path,
            num_samples_to_generate=None,
            lr_scheduler='reduce_lr_on_plateau',
            reduce_lr_patience=100, factor=0.9, 
            ema_decay=0.997,
            closs_weight_schedule = "fixed",
            c_lambda = 1.0,
            d_lambda = 1.0,
            device=torch.device('cuda:1'),
            ckpt_path = None,
            y_only=False,
            verbose=True,
            **kwargs
    ):
        self.y_only = y_only
        self.diffusion = diffusion
        self.ema_model = deepcopy(self.diffusion._denoise_fn)
        for param in self.ema_model.parameters():
            param.detach_()
        self.ema_num_schedule = deepcopy(self.diffusion.num_schedule)
        for param in self.ema_num_schedule.parameters():
            param.detach_()
        self.ema_cat_schedule = deepcopy(self.diffusion.cat_schedule)
        for param in self.ema_cat_schedule.parameters():
            param.detach_()

        self.train_iter = train_iter
        self.steps = steps
        self.init_lr = lr
        self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay)
        self.ema_decay = ema_decay
        self.lr_scheduler = lr_scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', 
                                           factor=factor, patience=reduce_lr_patience)
        self.closs_weight_schedule = closs_weight_schedule
        self.c_lambda = c_lambda
        self.d_lambda = d_lambda

        self.d_numerical = d_numerical
        self.categories = categories
        self.sample_batch_size = sample_batch_size
        self.num_samples_to_generate = num_samples_to_generate
        self.check_val_every = check_val_every
        
        self.verbose = verbose
        self.device = device
        self.model_save_path = model_save_path
        self.result_save_path = result_save_path
        self.ckpt_path = ckpt_path
        if self.ckpt_path is not None:
            state_dicts = torch.load(self.ckpt_path, map_location=self.device)
            self.diffusion._denoise_fn.load_state_dict(state_dicts['denoise_fn'])
            self.diffusion.num_schedule.load_state_dict(state_dicts['num_schedule'])
            self.diffusion.cat_schedule.load_state_dict(state_dicts['cat_schedule'])   
            print(f"Weights are loaded from {self.ckpt_path}")     
        
        self.curr_epoch = int(os.path.basename(self.ckpt_path).split('_')[-1].split('.')[0]) if self.ckpt_path is not None else 0

    def _anneal_lr(self, step):
        frac_done = step / self.steps
        lr = self.init_lr * (1 - frac_done)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _run_step(self, x, closs_weight, dloss_weight):
        x = x.to(self.device)
        
        self.diffusion.train()

        self.optimizer.zero_grad()

        dloss, closs = self.diffusion.mixed_loss(x)

        loss = dloss_weight * dloss + closs_weight * closs
        loss.backward()
        self.optimizer.step()

        return dloss, closs
    
    def compute_loss(self):      # eval loss is not weighted
        curr_dloss = 0.0
        curr_closs = 0.0
        curr_count = 0
        data_iter = self.train_iter
        for batch in data_iter:
            x = batch.float().to(self.device)
            self.diffusion.eval()
            with torch.no_grad():
                batch_dloss, batch_closs = self.diffusion.mixed_loss(x)
            curr_dloss += batch_dloss.item() * len(x)
            curr_closs += batch_closs.item() * len(x)
            curr_count += len(x)
        mloss = np.around(curr_dloss / curr_count, 4)
        gloss = np.around(curr_closs / curr_count, 4)
        return mloss, gloss
    
    def run_loop(self):
        """Train from ``self.curr_epoch`` up to ``self.steps`` epochs.

        Each epoch:
          1. stops the loop if more than 30 minutes have elapsed;
          2. derives the loss weights from ``closs_weight_schedule``;
          3. runs ``_run_step`` on every batch of ``train_iter``;
          4. stops the loop if the mean continuous loss is NaN;
          5. applies the learning-rate policy named by ``lr_scheduler``;
          6. updates the three EMA modules via ``update_ema``;
          7. evaluates the EMA modules with ``compute_loss`` and, when the
             EMA loss improved and the epoch is past 4000 or is the last
             one, saves them as the new ``best_ema_model_*`` checkpoint.

        If no checkpoint was saved by the time the loop ends, the current EMA
        weights are saved anyway, labelled with the last EMA loss that was
        evaluated (``nan`` if none was) and the last epoch that finished.

        Raises:
            NotImplementedError: for an unknown ``closs_weight_schedule`` or
                ``lr_scheduler`` name.
        """
        saved_a_checkpoint = False
        closs_weight, dloss_weight = self.c_lambda, self.d_lambda
        best_ema_loss = np.inf
        # Labels for the fallback checkpoint. They are initialised up front so
        # the fallback also works when the loop body never reaches the EMA
        # evaluation (zero epochs to run, or an early break).
        last_ema_loss = float("nan")
        last_finished_epoch = self.curr_epoch
        start_time = time.monotonic()
        print_with_bar(f"Starting Training Loop, total number of epoch = {self.steps}")

        start_epoch = self.curr_epoch
        if start_epoch > 0:
            print_with_bar(f"Resuming training from epoch {start_epoch}, with validation check every {self.check_val_every} epoches")
        # The epoch-level bar is shown only when not verbose; verbose mode shows
        # one bar per epoch instead (see ``pbar`` below).
        for epoch in tqdm(range(start_epoch, self.steps), disable=self.verbose):
            self.curr_epoch = epoch+1
            is_last_epoch = (epoch + 1) == self.steps

            # Wall-clock budget of 30 minutes, checked before the per-epoch bar
            # is created so that no unfinished bar is left behind on exit.
            elapsed_minutes = (time.monotonic() - start_time) / 60
            if elapsed_minutes > 30:
                break

            # Compute the loss weights
            if self.closs_weight_schedule == "fixed":
                pass
            elif self.closs_weight_schedule == "anneal":
                frac_done = epoch / self.steps
                closs_weight = self.c_lambda * (1 - frac_done)
            else:
                raise NotImplementedError(f"The continuous loss weight schedule {self.closs_weight_schedule} is not implemneted")

            # Training step
            pbar = tqdm(self.train_iter, total=len(self.train_iter), disable=not self.verbose)
            pbar.set_description(f"Epoch {epoch+1}/{self.steps}")
            curr_dloss = 0.0
            curr_closs = 0.0
            curr_count = 0
            curr_lr = self.optimizer.param_groups[0]['lr']
            for batch in pbar:
                x = batch.float().to(self.device)
                batch_dloss, batch_closs = self._run_step(x, closs_weight, dloss_weight)
                curr_dloss += batch_dloss.item() * len(x)
                curr_closs += batch_closs.item() * len(x)
                curr_count += len(x)
                pbar.set_postfix({
                    "lr": curr_lr,
                    "DLoss": np.around(curr_dloss/curr_count, 4),
                    "CLoss": np.around(curr_closs/curr_count, 4),
                    "TotalLoss": np.around((curr_dloss + curr_closs)/curr_count, 4),
                    "closs_weight": closs_weight,
                    "dloss_weight": dloss_weight,
                })

            # Epoch-level training loss (unweighted means of the two terms)
            mloss = np.around(curr_dloss / curr_count, 4)
            gloss = np.around(curr_closs / curr_count, 4)
            total_loss = mloss + gloss
            if np.isnan(gloss):
                print('Finding Nan in gaussian loss')
                break

            # Adjust learning rate
            if self.lr_scheduler == 'reduce_lr_on_plateau':
                self.scheduler.step(total_loss)
            elif self.lr_scheduler == 'anneal':
                self._anneal_lr(epoch)
            elif self.lr_scheduler == 'fixed':
                pass
            else:
                raise NotImplementedError(f"LR scheduler with name '{self.lr_scheduler}' is not implemented")

            # Update EMA models
            update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters(), rate=self.ema_decay)
            update_ema(self.ema_num_schedule.parameters(), self.diffusion.num_schedule.parameters(), rate=self.ema_decay)
            update_ema(self.ema_cat_schedule.parameters(), self.diffusion.cat_schedule.parameters(), rate=self.ema_decay)

            # Evaluate the EMA modules; always hand the training modules back,
            # even if the evaluation raises.
            swapped_out = self.to_ema_model()
            try:
                ema_mloss, ema_gloss = self.compute_loss()
            finally:
                self.to_model(*swapped_out)
            ema_total_loss = ema_mloss + ema_gloss
            last_ema_loss = ema_total_loss
            last_finished_epoch = epoch + 1

            # Save the best ema ckpt (only past epoch 4000, or on the last epoch)
            if ema_total_loss < best_ema_loss and (self.curr_epoch > 4000 or is_last_epoch):
                best_ema_loss = ema_total_loss
                self._save_ema_checkpoint(ema_total_loss, epoch + 1)
                saved_a_checkpoint = True

        end_time = time.monotonic()
        print_with_bar(f"Ending Training Loop, total training time = {end_time - start_time}")

        if not saved_a_checkpoint:
            # Nothing was saved during the loop: persist the current EMA
            # weights so that a checkpoint always exists afterwards.
            self._save_ema_checkpoint(last_ema_loss, last_finished_epoch)

    def _save_ema_checkpoint(self, ema_total_loss, epoch):
        """Replace every ``best_ema_model_*`` file with the current EMA weights.

        The file is named ``best_ema_model_<loss rounded to 4 dp>_<epoch>.pt``
        inside ``self.model_save_path``.
        """
        # Remove all stale checkpoints, not just the first match, so that
        # exactly one ``best_ema_model_*`` file remains afterwards.
        for stale in glob.glob(os.path.join(self.model_save_path, "best_ema_model_*")):
            os.remove(stale)
        state_dicts = {
            'denoise_fn': self.ema_model.state_dict(),
            'num_schedule': self.ema_num_schedule.state_dict(),
            'cat_schedule': self.ema_cat_schedule.state_dict(),
        }
        torch.save(state_dicts, os.path.join(self.model_save_path, f'best_ema_model_{np.round(ema_total_loss,4)}_{epoch}.pt'))


    def sample_synthetic(self, num_samples, keep_nan_samples=True, ema=False):
        if ema:
            curr_model, curr_num_schedule, curr_cat_schedule = self.to_ema_model()
   
        syn_data = self.diffusion.sample_all(num_samples, self.sample_batch_size, keep_nan_samples=keep_nan_samples)
        print(f"Shape of the generated sample = {syn_data.shape}")
        
        if keep_nan_samples:
            num_all_zero_row = (syn_data.sum(dim=1) == 0).sum()
            if num_all_zero_row:
                print(f"The generated samples contain {num_all_zero_row} Nan instances!!!")
        
        return syn_data
    
    def to_ema_model(self):
        curr_model = self.diffusion._denoise_fn
        curr_num_schedule = self.diffusion.num_schedule
        curr_cat_schedule = self.diffusion.cat_schedule
        self.diffusion._denoise_fn = self.ema_model  # temporarily install the ema parameters into the model
        self.diffusion.num_schedule = self.ema_num_schedule
        self.diffusion.cat_schedule = self.ema_cat_schedule
        
        return curr_model, curr_num_schedule, curr_cat_schedule

    def to_model(self, curr_model, curr_num_schedule, curr_cat_schedule):
        self.diffusion._denoise_fn = curr_model      # give back the parameters
        self.diffusion.num_schedule = curr_num_schedule
        self.diffusion.cat_schedule = curr_cat_schedule
        