import glob
import os
import time
from copy import deepcopy

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from .utils import update_ema

# from https://github.com/MinkaiXu/TabDiff/blob/main/tabdiff/trainer.py

BAR = "=============="


def print_with_bar(log_msg):
    """Print ``log_msg`` framed by ``BAR`` on both sides.

    Messages containing the substring ``"End"`` get one extra trailing newline
    so that the end of a phase is followed by a blank line in the console.
    """
    framed = "".join((BAR, log_msg, BAR))
    trailer = "\n" if "End" in framed else ""
    print(framed + trailer)


class Trainer:
    def __init__(
        self,
        diffusion,
        train_iter,
        d_numerical,
        categories,
        lr,
        weight_decay,
        steps,
        check_val_every,
        sample_batch_size,
        model_save_path,
        result_save_path,
        num_samples_to_generate=None,
        lr_scheduler="reduce_lr_on_plateau",
        reduce_lr_patience=100,
        factor=0.9,
        ema_decay=0.997,
        closs_weight_schedule="fixed",
        c_lambda=1.0,
        d_lambda=1.0,
        device=torch.device("cuda:1"),
        ckpt_path=None,
        y_only=False,
        verbose=True,
        **kwargs,
    ):
        self.y_only = y_only
        self.diffusion = diffusion
        self.ema_model = deepcopy(self.diffusion._denoise_fn)
        for param in self.ema_model.parameters():
            param.detach_()
        self.ema_num_schedule = deepcopy(self.diffusion.num_schedule)
        for param in self.ema_num_schedule.parameters():
            param.detach_()
        self.ema_cat_schedule = deepcopy(self.diffusion.cat_schedule)
        for param in self.ema_cat_schedule.parameters():
            param.detach_()

        self.train_iter = train_iter
        self.steps = steps
        self.init_lr = lr
        self.optimizer = torch.optim.AdamW(self.diffusion.parameters(), lr=lr, weight_decay=weight_decay)
        self.ema_decay = ema_decay
        self.lr_scheduler = lr_scheduler
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min", factor=factor, patience=reduce_lr_patience)
        self.closs_weight_schedule = closs_weight_schedule
        self.c_lambda = c_lambda
        self.d_lambda = d_lambda

        self.d_numerical = d_numerical
        self.categories = categories
        self.sample_batch_size = sample_batch_size
        self.num_samples_to_generate = num_samples_to_generate
        self.check_val_every = check_val_every

        self.verbose = verbose
        self.device = device
        self.model_save_path = model_save_path
        self.result_save_path = result_save_path
        self.ckpt_path = ckpt_path
        if self.ckpt_path is not None:
            state_dicts = torch.load(self.ckpt_path, map_location=self.device)
            self.diffusion._denoise_fn.load_state_dict(state_dicts["denoise_fn"])
            self.diffusion.num_schedule.load_state_dict(state_dicts["num_schedule"])
            self.diffusion.cat_schedule.load_state_dict(state_dicts["cat_schedule"])
            print(f"Weights are loaded from {self.ckpt_path}")

        self.curr_epoch = (
            int(os.path.basename(self.ckpt_path).split("_")[-1].split(".")[0]) if self.ckpt_path is not None else 0
        )

    def _anneal_lr(self, step):
        frac_done = step / self.steps
        lr = self.init_lr * (1 - frac_done)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _run_step(self, x, closs_weight, dloss_weight):
        x = x.to(self.device)

        self.diffusion.train()

        self.optimizer.zero_grad()

        dloss, closs = self.diffusion.mixed_loss(x)

        loss = dloss_weight * dloss + closs_weight * closs
        loss.backward()
        self.optimizer.step()

        return dloss, closs

    def compute_loss(self):  # eval loss is not weighted
        curr_dloss = 0.0
        curr_closs = 0.0
        curr_count = 0
        data_iter = self.train_iter
        for batch in data_iter:
            x = batch.float().to(self.device)
            self.diffusion.eval()
            with torch.no_grad():
                batch_dloss, batch_closs = self.diffusion.mixed_loss(x)
            curr_dloss += batch_dloss.item() * len(x)
            curr_closs += batch_closs.item() * len(x)
            curr_count += len(x)
        mloss = np.around(curr_dloss / curr_count, 4)
        gloss = np.around(curr_closs / curr_count, 4)
        return mloss, gloss

    def run_loop(self):
        """Train for up to ``self.steps`` epochs and checkpoint the best EMA weights.

        Per epoch, in order:
          1. stop if the wall-clock budget is exhausted;
          2. resolve the continuous-loss weight from ``closs_weight_schedule``;
          3. run one pass over ``train_iter`` via ``_run_step``;
          4. stop if the epoch's continuous loss is NaN;
          5. adjust the learning rate according to ``lr_scheduler``;
          6. update the EMA shadow modules and evaluate them on ``train_iter``;
          7. save the EMA weights when they improve on the best saved loss and
             the epoch is past the warm-up threshold (or is the last epoch).

        If no checkpoint was written by the time the loop exits, the current
        EMA weights are saved so that a ``best_ema_model_*`` file always exists.

        Raises:
            NotImplementedError: For an unknown ``closs_weight_schedule`` or
                ``lr_scheduler`` name.
        """
        max_minutes = 30  # wall-clock training budget, checked once per epoch
        min_epoch_for_best_ckpt = 4000  # "best" checkpoints start after this epoch

        saved_a_checkpoint = False
        closs_weight, dloss_weight = self.c_lambda, self.d_lambda
        best_ema_loss = np.inf
        ema_total_loss = None  # EMA loss of the most recently evaluated epoch
        start_time = time.monotonic()
        print_with_bar(f"Starting Training Loop, total number of epoch = {self.steps}")

        start_epoch = self.curr_epoch
        if start_epoch > 0:
            print_with_bar(
                f"Resuming training from epoch {start_epoch}, with validation check every {self.check_val_every} epoches"
            )
        # The epoch-level bar is shown only when the per-batch bars are off.
        for epoch in tqdm(range(start_epoch, self.steps), disable=self.verbose):
            self.curr_epoch = epoch + 1
            is_last_epoch = (epoch + 1) == self.steps

            # Check the budget before creating the batch bar so that a timeout
            # does not leave an empty progress bar behind.
            elapsed_minutes = (time.monotonic() - start_time) / 60
            if elapsed_minutes > max_minutes:
                break

            # Set up pbar
            pbar = tqdm(self.train_iter, total=len(self.train_iter), disable=not self.verbose)
            pbar.set_description(f"Epoch {epoch + 1}/{self.steps}")

            # Compute the loss weights ("fixed" keeps the initial weights)
            if self.closs_weight_schedule == "anneal":
                frac_done = epoch / self.steps
                closs_weight = self.c_lambda * (1 - frac_done)
            elif self.closs_weight_schedule != "fixed":
                raise NotImplementedError(
                    f"The continuous loss weight schedule {self.closs_weight_schedule} is not implemneted"
                )

            # Training Step
            curr_dloss = 0.0
            curr_closs = 0.0
            curr_count = 0
            curr_lr = self.optimizer.param_groups[0]["lr"]
            for batch in pbar:
                x = batch.float().to(self.device)
                batch_dloss, batch_closs = self._run_step(x, closs_weight, dloss_weight)
                curr_dloss += batch_dloss.item() * len(x)
                curr_closs += batch_closs.item() * len(x)
                curr_count += len(x)
                pbar.set_postfix(
                    {
                        "lr": curr_lr,
                        "DLoss": np.around(curr_dloss / curr_count, 4),
                        "CLoss": np.around(curr_closs / curr_count, 4),
                        "TotalLoss": np.around((curr_dloss + curr_closs) / curr_count, 4),
                        "closs_weight": closs_weight,
                        "dloss_weight": dloss_weight,
                    }
                )

            # Epoch-level training loss (unweighted, rounded to 4 decimals)
            mloss = np.around(curr_dloss / curr_count, 4)
            gloss = np.around(curr_closs / curr_count, 4)
            total_loss = mloss + gloss
            if np.isnan(gloss):
                print("Finding Nan in gaussian loss")
                break

            # Adjust learning rate
            if self.lr_scheduler == "reduce_lr_on_plateau":
                self.scheduler.step(total_loss)
            elif self.lr_scheduler == "anneal":
                self._anneal_lr(epoch)
            elif self.lr_scheduler == "fixed":
                pass
            else:
                raise NotImplementedError(f"LR scheduler with name '{self.lr_scheduler}' is not implemented")

            # Update EMA models
            update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters(), rate=self.ema_decay)
            update_ema(
                self.ema_num_schedule.parameters(), self.diffusion.num_schedule.parameters(), rate=self.ema_decay
            )
            update_ema(
                self.ema_cat_schedule.parameters(), self.diffusion.cat_schedule.parameters(), rate=self.ema_decay
            )

            # Compute EMA model loss
            ema_mloss, ema_gloss = self._evaluate_ema_loss()
            ema_total_loss = ema_mloss + ema_gloss

            # Save the best ema ckpt
            past_warmup = self.curr_epoch > min_epoch_for_best_ckpt or is_last_epoch
            if ema_total_loss < best_ema_loss and past_warmup:
                best_ema_loss = ema_total_loss
                self._save_ema_checkpoint(ema_total_loss, epoch + 1)
                saved_a_checkpoint = True

        end_time = time.monotonic()
        print_with_bar(f"Ending Training Loop, total training time = {end_time - start_time}")

        if not saved_a_checkpoint:
            # Guarantee a checkpoint exists even when no epoch qualified.
            if ema_total_loss is None:
                # No epoch reached the EMA evaluation (empty epoch range, or an
                # early stop in the first epoch): evaluate once for the file name.
                ema_mloss, ema_gloss = self._evaluate_ema_loss()
                ema_total_loss = ema_mloss + ema_gloss
            self._save_ema_checkpoint(ema_total_loss, self.curr_epoch)

    def _evaluate_ema_loss(self):
        """Return ``compute_loss()`` measured with the EMA modules swapped in."""
        swapped_out = self.to_ema_model()
        try:
            return self.compute_loss()
        finally:
            # Always hand the trainable modules back, even if evaluation fails.
            self.to_model(*swapped_out)

    def _save_ema_checkpoint(self, ema_total_loss, epoch_number):
        """Write the EMA state dicts and drop older ``best_ema_model_*`` files."""
        if self.model_save_path:
            os.makedirs(self.model_save_path, exist_ok=True)
        target = os.path.join(
            self.model_save_path, f"best_ema_model_{np.round(ema_total_loss, 4)}_{epoch_number}.pt"
        )
        stale = [
            path for path in glob.glob(os.path.join(self.model_save_path, "best_ema_model_*")) if path != target
        ]
        state_dicts = {
            "denoise_fn": self.ema_model.state_dict(),
            "num_schedule": self.ema_num_schedule.state_dict(),
            "cat_schedule": self.ema_cat_schedule.state_dict(),
        }
        # Write the new file first so a failed save never leaves us without a checkpoint.
        torch.save(state_dicts, target)
        for path in stale:
            os.remove(path)

    def sample_synthetic(self, num_samples, keep_nan_samples=True, ema=False):
        if ema:
            curr_model, curr_num_schedule, curr_cat_schedule = self.to_ema_model()

        syn_data = self.diffusion.sample_all(num_samples, self.sample_batch_size, keep_nan_samples=keep_nan_samples)
        print(f"Shape of the generated sample = {syn_data.shape}")

        if keep_nan_samples:
            num_all_zero_row = (syn_data.sum(dim=1) == 0).sum()
            if num_all_zero_row:
                print(f"The generated samples contain {num_all_zero_row} Nan instances!!!")

        return syn_data

    def to_ema_model(self):
        curr_model = self.diffusion._denoise_fn
        curr_num_schedule = self.diffusion.num_schedule
        curr_cat_schedule = self.diffusion.cat_schedule
        self.diffusion._denoise_fn = self.ema_model  # temporarily install the ema parameters into the model
        self.diffusion.num_schedule = self.ema_num_schedule
        self.diffusion.cat_schedule = self.ema_cat_schedule

        return curr_model, curr_num_schedule, curr_cat_schedule

    def to_model(self, curr_model, curr_num_schedule, curr_cat_schedule):
        self.diffusion._denoise_fn = curr_model  # give back the parameters
        self.diffusion.num_schedule = curr_num_schedule
        self.diffusion.cat_schedule = curr_cat_schedule
