import logging
import numpy as np
import torch
import os
import heapq

from time import time
from torch import optim
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup

from model.utils import *

class Trainer(object):
    """Training driver that optimises a model and tracks code collisions.

    The wrapped ``model`` is expected to provide the three entry points this
    class calls: ``model(batch)`` returning ``(out, rq_loss, indices)``,
    ``model.compute_loss(out, rq_loss, xs=batch)`` returning
    ``(loss, reconstruction_loss)`` and ``model.get_indices(batch)``.

    Checkpoints are written below a run-specific sub-directory of
    ``args.ckpt_dir``; the same directory receives the ``hrqvae.log`` file.
    """

    def __init__(self, args, model, data_num):
        """Set up bookkeeping, logging, optimizer and LR scheduler.

        Args:
            args: Namespace-like configuration object. The attributes read
                here are ``lr``, ``learner``, ``lr_scheduler_type``,
                ``weight_decay``, ``epochs``, ``warmup_epochs``,
                ``save_limit``, ``eval_step``, ``device``, ``ckpt_dir``,
                ``beta``, ``num_emb_list`` and ``sk_epsilons``.
            model: The ``torch.nn.Module`` to train.
            data_num: Multiplied with ``warmup_epochs`` / ``epochs`` to get
                the scheduler step counts — presumably the number of batches
                per epoch; confirm against the caller.
        """
        self.args = args
        self.model = model

        self.lr = args.lr
        self.learner = args.learner
        self.lr_scheduler_type = args.lr_scheduler_type

        self.weight_decay = args.weight_decay
        self.epochs = args.epochs
        self.warmup_steps = args.warmup_epochs * data_num
        self.max_steps = args.epochs * data_num

        self.save_limit = args.save_limit
        # Min-heap of (-collision_rate, path): the root is the *worst* of the
        # retained "best" periodic checkpoints.
        self.best_save_heap = []
        # FIFO of the most recent periodic checkpoints, oldest first.
        self.newest_save_queue = []
        self.eval_step = min(args.eval_step, self.epochs)
        self.device = torch.device(args.device)

        # NOTE(review): the directory name hard-codes three codebook levels
        # (num_emb_list[0..2]) and the third sk epsilon; shorter lists raise
        # IndexError here.
        saved_model_dir = "{}_beta_{:.3f}_codebook_[{},{},{}]_sk_{:.3f}".format(
            get_local_time(),
            args.beta,
            args.num_emb_list[0],
            args.num_emb_list[1],
            args.num_emb_list[2],
            args.sk_epsilons[2],
        )
        self.ckpt_dir = os.path.join(args.ckpt_dir, saved_model_dir)
        ensure_dir(self.ckpt_dir)
        self.logger = self.setup_logger(self.ckpt_dir)

        self.best_loss = np.inf
        self.best_collision_rate = np.inf
        self.best_loss_ckpt = "best_loss_model.pth"
        self.best_collision_ckpt = "best_collision_model.pth"

        # Move the model to its device *before* building the optimizer so
        # that any optimizer state created eagerly lives on the same device
        # as the parameters.
        self.model = self.model.to(self.device)
        self.optimizer = self._build_optimizer()
        self.scheduler = self._get_scheduler()

    def setup_logger(self, ckpt_dir):
        """Return the shared ``"hrqvae"`` logger writing to ``ckpt_dir``.

        The logger is process-global, so handlers from an earlier call may
        already be attached. A file handler that points at a different log
        file is closed and replaced (otherwise a second run in the same
        process would keep logging into the first run's directory); an
        existing console handler is reused so messages are not duplicated.

        Args:
            ckpt_dir: Existing directory in which ``hrqvae.log`` is created.

        Returns:
            The configured ``logging.Logger``.
        """
        logger = logging.getLogger("hrqvae")
        logger.setLevel(logging.INFO)

        # Unified log format for both outputs.
        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        # FileHandler stores its target as an absolute path in baseFilename.
        log_path = os.path.abspath(os.path.join(ckpt_dir, "hrqvae.log"))

        has_file_handler = False
        has_console_handler = False
        for handler in list(logger.handlers):
            # FileHandler subclasses StreamHandler, so test for it first.
            if isinstance(handler, logging.FileHandler):
                if handler.baseFilename == log_path:
                    has_file_handler = True
                else:
                    logger.removeHandler(handler)
                    handler.close()
            elif isinstance(handler, logging.StreamHandler):
                has_console_handler = True

        if not has_file_handler:
            # File output.
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        if not has_console_handler:
            # Console output.
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

        return logger

    def _build_optimizer(self):
        """Create the optimizer named by ``self.learner`` (case-insensitive).

        Supported names are ``adam``, ``sgd``, ``adagrad``, ``adamw`` and
        ``rmsprop``. Any other name logs a warning and falls back to Adam.

        Returns:
            A ``torch.optim.Optimizer`` over all model parameters.
        """
        params = self.model.parameters()
        learner = self.learner.lower()
        optimizer_classes = {
            "adam": optim.Adam,
            "sgd": optim.SGD,
            "adagrad": optim.Adagrad,
            "adamw": optim.AdamW,
            "rmsprop": optim.RMSprop,
        }

        optimizer_cls = optimizer_classes.get(learner)
        if optimizer_cls is None:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            # NOTE(review): the fallback does not apply self.weight_decay,
            # unlike the explicit "adam" choice — kept as in the original.
            return optim.Adam(params, lr=self.lr)

        optimizer = optimizer_cls(
            params, lr=self.lr, weight_decay=self.weight_decay
        )
        if learner == "adagrad":
            # Make sure any tensor state Adagrad already holds is on the
            # training device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
        return optimizer

    def _get_scheduler(self):
        """Create the warm-up LR scheduler selected by ``lr_scheduler_type``.

        Returns:
            The scheduler for ``"linear"`` or ``"constant"``.

        Raises:
            ValueError: If ``lr_scheduler_type`` is neither of those.
        """
        if self.lr_scheduler_type == "linear":
            scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=self.max_steps,
            )
        elif self.lr_scheduler_type == "constant":
            scheduler = get_constant_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.warmup_steps,
            )
        else:
            raise ValueError("Unsupported learning rate scheduler type: {}".format(self.lr_scheduler_type))
        return scheduler

    def _check_nan(self, loss):
        """Raise ``ValueError`` if any element of ``loss`` is NaN."""
        if torch.isnan(loss).any():
            raise ValueError("Training loss is nan")

    def _train_epoch(self, train_data, epoch_idx):
        """Run one optimisation pass over ``train_data``.

        Each batch is moved to the training device, gradients are clipped to
        a max norm of 1.0, and both optimizer and scheduler are stepped once
        per batch.

        Args:
            train_data: Sized iterable of batches exposing ``.to(device)``.
            epoch_idx: Epoch number shown in the progress bar.

        Returns:
            ``(total_loss, total_recon_loss)`` — the per-batch losses summed
            (not averaged) over the epoch.

        Raises:
            ValueError: If a batch loss is NaN.
        """
        self.model.train()

        total_loss = 0.0
        total_recon_loss = 0.0
        iter_data = tqdm(train_data,
                         total=len(train_data),
                         ncols=100,
                         desc=set_color(f"Training Epoch {epoch_idx}", "pink"))

        for data in iter_data:
            data = data.to(self.device)
            self.optimizer.zero_grad()
            out, rq_loss, _ = self.model(data)
            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
            self._check_nan(loss)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            total_loss += loss.item()
            total_recon_loss += loss_recon.item()

        return total_loss, total_recon_loss

    @torch.no_grad()
    def _vaild_epoch(self, vaild_data):
        """Compute the code collision rate over ``vaild_data``.

        Every sample's index vector (as returned by ``model.get_indices``)
        is turned into a ``"-"``-joined string; the collision rate is the
        fraction of samples that do not contribute a new distinct code.

        Args:
            vaild_data: Sized iterable of batches exposing ``.to(device)``.

        Returns:
            ``(num_samples - num_distinct_codes) / num_samples``. When the
            loader yields no samples a warning is logged and ``0.0`` is
            returned instead of dividing by zero.
        """
        self.model.eval()

        iter_data = tqdm(vaild_data,
                         total=len(vaild_data),
                         ncols=100,
                         desc=set_color("Validating ", "pink"))

        unique_codes = set()
        num_sample = 0
        for data in iter_data:
            num_sample += len(data)
            data = data.to(self.device)
            indices = self.model.get_indices(data)
            # Flatten to one row of code indices per sample.
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            unique_codes.update(
                "-".join(str(int(idx)) for idx in row) for row in indices
            )

        if num_sample == 0:
            self.logger.warning(
                "Validation data is empty; reporting a collision rate of 0."
            )
            return 0.0

        return (num_sample - len(unique_codes)) / num_sample

    def _save_checkpoint(self, epoch, collision_rate = 1, ckpt_file=None):
        """Serialise the training state into ``self.ckpt_dir``.

        Args:
            epoch: Epoch index stored in the checkpoint (and in the default
                file name).
            collision_rate: Only used to build the default file name.
            ckpt_file: Explicit file name; when falsy the name
                ``epoch_<epoch>_collision_<rate>_model.pth`` is used.

        Returns:
            The full path of the written checkpoint.
        """
        if ckpt_file:
            ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)
        else:
            ckpt_path = os.path.join(
                self.ckpt_dir,
                'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate),
            )

        state = {
            "args": self.args,
            "epoch": epoch,
            "best_loss": self.best_loss,
            "best_collision_rate": self.best_collision_rate,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, ckpt_path, pickle_protocol=4)

        self.logger.info(
            set_color("Saving current", "blue") + f": {ckpt_path}"
        )

        return ckpt_path

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
        """Format the coloured one-line training summary for an epoch."""
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
        train_loss_output += ", "
        train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
        return train_loss_output + "]"

    def _register_checkpoint(self, collision_rate, ckpt_path):
        """Record a periodic checkpoint and report which files became stale.

        A periodic checkpoint is retained while it is among the
        ``save_limit`` most recent ones (``newest_save_queue``) or among the
        retained lowest-collision ones (``best_save_heap``). This method only
        updates that bookkeeping; it does not touch the file system.

        Args:
            collision_rate: Collision rate measured for the checkpoint.
            ckpt_path: Path of the checkpoint that was just written.

        Returns:
            List of checkpoint paths that are no longer referenced by either
            structure and should be deleted. Each path appears at most once.
            With ``save_limit <= 0`` nothing is retained, so the new
            checkpoint itself is returned.
        """
        if self.save_limit <= 0:
            return [ckpt_path]

        now_save = (-collision_rate, ckpt_path)
        if len(self.newest_save_queue) < self.save_limit:
            self.newest_save_queue.append(now_save)
            heapq.heappush(self.best_save_heap, now_save)
            return []

        stale_paths = []
        old_save = self.newest_save_queue.pop(0)
        self.newest_save_queue.append(now_save)

        # The heap root carries the largest collision rate among the "best"
        # saves; replace it when the new checkpoint is strictly better.
        if collision_rate < -self.best_save_heap[0][0]:
            bad_save = heapq.heapreplace(self.best_save_heap, now_save)
            if bad_save not in self.newest_save_queue:
                stale_paths.append(bad_save[1])

        # The evicted oldest save may be the very entry dropped from the heap
        # above; avoid reporting the same file twice.
        if old_save not in self.best_save_heap and old_save[1] not in stale_paths:
            stale_paths.append(old_save[1])

        return stale_paths

    def fit(self, data):
        """Train for ``self.epochs`` epochs with periodic evaluation.

        Every ``self.eval_step`` epochs the collision rate is measured on the
        same ``data``, the best-loss / best-collision checkpoints are
        refreshed when improved, and a periodic checkpoint is written and
        rotated according to ``save_limit``.

        Args:
            data: Sized iterable of batches used for training and evaluation.

        Returns:
            ``(best_loss, best_collision_rate)`` observed at evaluation
            steps.
        """
        for epoch_idx in range(self.epochs):
            # train
            training_start_time = time()
            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
            )
            self.logger.info(train_loss_output)
            self.logger.info(f"Epoch {epoch_idx + 1}/{self.epochs}")
            self.logger.info(f"Train loss: {train_loss}.")
            self.logger.info(f"Reconstruction loss: {train_recon_loss}.")

            # eval
            if (epoch_idx + 1) % self.eval_step != 0:
                continue

            valid_start_time = time()
            collision_rate = self._vaild_epoch(data)

            if train_loss < self.best_loss:
                self.best_loss = train_loss
                self._save_checkpoint(epoch=epoch_idx,
                                      ckpt_file=self.best_loss_ckpt)

            if collision_rate < self.best_collision_rate:
                self.best_collision_rate = collision_rate
                self._save_checkpoint(epoch_idx,
                                      collision_rate=collision_rate,
                                      ckpt_file=self.best_collision_ckpt)

            valid_end_time = time()
            valid_score_output = (
                set_color("epoch %d evaluating", "green")
                + " ["
                + set_color("time", "blue")
                + ": %.2fs, "
                + set_color("collision_rate", "blue")
                + ": %f]"
            ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)

            self.logger.info(valid_score_output)

            ckpt_path = self._save_checkpoint(epoch_idx,
                                              collision_rate=collision_rate)
            for stale_path in self._register_checkpoint(collision_rate, ckpt_path):
                delete_file(stale_path)
            self.logger.info(f"Collision_rate: {collision_rate}.")

        return self.best_loss, self.best_collision_rate
