import os
import re
import math
import glob
import yaml
import torch
import shutil
import logging

import numpy as np

from pathlib import Path
from contextlib import nullcontext
from torch.nn.utils import clip_grad_norm_

from typing import Dict, List, Optional, Tuple, Union

from tensorboardX import SummaryWriter
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

from nn_ss.sound_synthesis2.utils.misc import instantiate_from_config, format_seconds
from nn_ss.trainer.scheduler import LRScheduler
import time

from utils.checkpoint import (
    save_checkpoint,
    remove_checkpoints,
    average_checkpoints,
    save_optimizer
)


class Executor:
    """Drive one epoch of training + evaluation and manage checkpoints.

    The executor wraps a (possibly DDP-wrapped) model and an ``LRScheduler``
    that owns the optimizer (``scheduler.optimizer``). Each call to
    :meth:`epoch` trains over the training loader, evaluates on the dev loader,
    passes the dev perplexity to ``scheduler.epoch``, and on rank 0 logs
    TensorBoard scalars and prunes old checkpoints.
    """

    def __init__(self,
                 model: Union[torch.nn.Module, torch.nn.parallel.DistributedDataParallel],
                 model_type: str,
                 checkpoint_dir: Union[Path, str],
                 scheduler: "LRScheduler",
                 num_checkpoints_to_average: int = 5,
                 writer: Optional["SummaryWriter"] = None,
                 scaler: Optional[GradScaler] = None) -> None:
        """
        Args:
            model: model to train; may be wrapped in DistributedDataParallel.
            model_type: name prefix used for checkpoint files/directories.
            checkpoint_dir: directory holding epoch and mid-epoch checkpoints.
            scheduler: LR scheduler exposing ``optimizer``, ``step``,
                ``zero_grad``, ``lr``, ``n_steps``, ``epoch``, ``is_topk`` and
                ``topk_list``.
            num_checkpoints_to_average: how many lowest-loss mid-epoch
                checkpoints are averaged at the end of each epoch.
            writer: optional TensorBoard writer (used on rank 0 only).
            scaler: GradScaler; required when ``configs["use_amp"]`` is set.
        """
        self.model = model
        self.model_type = model_type
        self.checkpoint_dir = checkpoint_dir
        self.scheduler = scheduler
        self.num_checkpoints_to_average = num_checkpoints_to_average
        self.writer = writer
        self.scaler = scaler

        self.saved_model = self.model
        # DDP's join() context lets ranks with uneven numbers of batches
        # finish the epoch without hanging; plain modules need no wrapper.
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model_context = self.model.join
        else:
            self.model_context = nullcontext

    @staticmethod
    def _safe_exp(value: float) -> float:
        """Return ``exp(value)``, or ``inf`` instead of raising on overflow."""
        try:
            return math.exp(value)
        except OverflowError:
            return float("inf")

    def epoch(self,
              epoch: int,
              train_dataloader: DataLoader,
              dev_dataloader: DataLoader,
              parameters: List,
              configs: Dict,
              device: torch.device):
        """Run one training epoch, evaluate, step the scheduler and checkpoint.

        Args:
            epoch: current epoch index (used in names and logs).
            train_dataloader: iterable of training batches (dicts).
            dev_dataloader: iterable of dev batches (dicts).
            parameters: parameters whose gradients are clipped.
            configs: run configuration; reads ``rank`` and ``keep_topk_dir``
                here and forwards the rest to :meth:`train`/:meth:`evaluate`.
            device: device the batch tensors are moved to.
        """
        rank = configs.get("rank", 0)
        keep_topk_dir = configs.get("keep_topk_dir", False)

        self.train(epoch=epoch,
                   train_dataloader=train_dataloader,
                   parameters=parameters,
                   configs=configs,
                   device=device)

        dev_metric_dict = self.evaluate(epoch=epoch,
                                        dev_dataloader=dev_dataloader,
                                        configs=configs,
                                        device=device)

        dev_loss = dev_metric_dict.get("dev_loss")
        dev_ppl = dev_metric_dict.get("dev_ppl")
        logging.info("| Epoch {} | Dev loss {} | Dev ppl {}".format(
            epoch, dev_loss, dev_ppl
        ))
        self.scheduler.epoch(dev_ppl)

        if rank != 0:
            return
        # The writer is optional (train() already checks for None).
        if self.writer is not None:
            self.writer.add_scalar("epoch/dev_loss", dev_loss, epoch)
            self.writer.add_scalar("epoch/dev_ppl", dev_ppl, epoch)

        self.save_epoch_checkpoints(epoch=epoch,
                                    dev_metric_dict=dev_metric_dict,
                                    keep_topk_dir=keep_topk_dir)

    def train(self,
              epoch: int,
              train_dataloader: DataLoader,
              parameters: List,
              configs: Dict,
              device: torch.device):
        """Train the model for one pass over ``train_dataloader``.

        Gradients are accumulated over ``accum_grad`` micro-batches and the
        optimizer is updated after every ``accum_grad``-th one. Under DDP the
        gradient all-reduce is suppressed (``no_sync``) on the non-update
        micro-batches. On rank 0, mid-epoch checkpoints are written to
        ``<checkpoint_dir>/<model_type>_<epoch>`` and the best of them are
        averaged at the end.

        ``configs`` keys read (defaults): ``rank`` (0), ``clip`` (5.0),
        ``distributed`` (False), ``log_interval`` (200), ``accum_grad`` (1),
        ``use_amp`` (False), ``save_steps`` (500), ``keep_last_k_ckpt`` (10).
        """
        self.model.train()

        optimizer = self.scheduler.optimizer

        rank = configs.get("rank", 0)
        clip = configs.get("clip", 5.0)
        distributed = configs.get("distributed", False)
        log_interval = configs.get("log_interval", 200)
        accum_grad = configs.get("accum_grad", 1)
        use_amp = configs.get("use_amp", False)
        save_steps = configs.get("save_steps", 500)
        keep_last_k_ckpt = configs.get("keep_last_k_ckpt", 10)
        if save_steps * accum_grad < log_interval:
            save_steps = log_interval + 1

        mid_checkpoint_dir = None
        if rank == 0:
            mid_checkpoint_dir = os.path.join(
                self.checkpoint_dir, "{}_{}".format(self.model_type, epoch)
            )
            logging.info(f" | Make {mid_checkpoint_dir} for saving mid checkpoints in one epoch")
            os.makedirs(mid_checkpoint_dir, exist_ok=True)

        logging.info(' | using accumulate grad, new batch size is {}'
                     ' times larger than before'.format(accum_grad))

        if use_amp:
            assert self.scaler is not None

        # Drop gradients left over from an incomplete accumulation window of a
        # previous epoch so the first update of this epoch is not polluted.
        self.scheduler.zero_grad()

        with self.model_context():
            total_losses = 0.0
            total_batch = 0
            epoch_start = time.time()
            itr_start = time.time()

            for batch_idx, batch_data in enumerate(train_dataloader):
                if batch_idx == 0:
                    logging.debug("time2 is %s %s", time.time(), epoch)
                # Time spent waiting on the data loader for this batch.
                data_time = time.time() - itr_start
                step_start = time.time()

                batch_data, batch_size = self.get_batch_data(batch_data, device)
                if batch_size <= 0:
                    itr_start = time.time()
                    continue

                # Update after the accum_grad-th micro-batch of each window
                # (batch_idx 0-based), so every update sees accum_grad batches.
                is_update_step = (batch_idx + 1) % accum_grad == 0

                # Skip DDP all-reduce on micro-batches that only accumulate;
                # gradients are synchronized on the update micro-batch.
                if distributed and not is_update_step:
                    context = self.model.no_sync
                else:
                    context = nullcontext

                with context():
                    # See https://pytorch.org/docs/stable/notes/amp_examples.html
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        outputs = self.model(batch=batch_data)
                        assert "loss" in outputs
                        total_loss = outputs["loss"] / accum_grad
                        assert isinstance(total_loss, torch.Tensor)

                    if use_amp:
                        self.scaler.scale(total_loss).backward()
                    else:
                        total_loss.backward()

                if is_update_step:
                    if rank == 0 and self.writer is not None:
                        step = self.scheduler.n_steps
                        self.writer.add_scalar("train_lr", self.scheduler.lr, step)
                        self.writer.add_scalar("train_loss", outputs["loss"], step)
                        self.writer.add_scalar("train_dur_loss", outputs["loss_dur"], step)
                        self.writer.add_scalar("train_codec_loss", outputs["loss_codec"], step)
                        self.writer.add_scalar("train_mdn_loss", outputs["loss_mdn"], step)
                        self.writer.add_scalar("train_acc", outputs["para_acc"], step)
                        self.writer.add_scalar("dur_acc", outputs["dur_acc"], step)

                    if use_amp:
                        # unscale_ first so clipping sees true gradient norms;
                        # scaler.step skips the optimizer step on inf/nan grads,
                        # and update() must follow unscale_() in each iteration.
                        # NOTE(review): this path calls scaler.step(optimizer)
                        # instead of self.scheduler.step(), so the scheduler's
                        # n_steps / lr may not advance here — confirm.
                        self.scaler.unscale_(optimizer)
                        clip_grad_norm_(parameters, clip)
                        self.scaler.step(optimizer)
                        self.scaler.update()
                    else:
                        grad_norm = clip_grad_norm_(parameters, clip)
                        # Skip the update entirely when gradients blew up.
                        if torch.isfinite(grad_norm):
                            self.scheduler.step()
                    self.scheduler.zero_grad()

                # Undo the 1/accum_grad scaling for reporting.
                total_losses += total_loss.item() * accum_grad
                total_batch += batch_size

                if batch_idx % log_interval == 0 and batch_idx > 0:
                    cur_batch_size = int(total_batch / log_interval) * accum_grad
                    cur_total_loss = total_losses / log_interval

                    log_str = " | Train Batch {}/{} | Batch size {} | Total loss {:.4f} Codec loss {:.4f} Duration loss {:.4f} MDN loss {:.4f} | acc {} dur_acc {} n_q {}".format(
                        epoch, batch_idx, cur_batch_size, cur_total_loss, outputs["loss_codec"], outputs["loss_dur"], outputs["loss_mdn"], outputs['para_acc'], outputs['dur_acc'], outputs['para_q']
                    )
                    log_str += " | PPL {:.3f}".format(self._safe_exp(cur_total_loss))
                    log_str += ' | lr {:.6f} | rank {} | step {}'.format(self.scheduler.lr, rank, self.scheduler.n_steps)
                    log_str += ' || data_time: {dt}s | fbward_time: {fbt}s |  epoch_time: {et} '.format(
                        dt=round(data_time, 1),
                        fbt=round(time.time() - step_start, 1),
                        et=format_seconds(time.time() - epoch_start),
                    )
                    logging.debug(log_str)

                    total_losses = 0.0
                    total_batch = 0

                if mid_checkpoint_dir and rank == 0:
                    self.save_mid_checkpoints(epoch=epoch,
                                              save_steps=save_steps,
                                              keep_last_k_ckpt=keep_last_k_ckpt,
                                              mid_checkpoint_dir=mid_checkpoint_dir,
                                              loss=total_loss)
                itr_start = time.time()

        if rank == 0:
            self.average_checkpoints(mid_checkpoint_dir)

    def evaluate(self,
                 epoch: int,
                 dev_dataloader: DataLoader,
                 configs: Dict,
                 device: torch.device):
        """Compute the batch-size-weighted mean dev loss and its perplexity.

        Batches whose loss is non-finite are excluded from the average. When
        no batch contributes, the loss is reported as 0.0 (perplexity 1.0).

        Returns:
            ``{"dev_loss": float, "dev_ppl": float}``
        """
        evaluated_model = self.saved_model
        evaluated_model.eval()

        rank = configs.get("rank", 0)
        log_interval = configs.get("log_interval", 200)

        total_losses = 0.0
        num_seen_seq = 0

        with torch.no_grad():
            for batch_idx, batch_data in enumerate(dev_dataloader):
                batch_data, batch_size = self.get_batch_data(batch_data, device)
                if batch_size <= 0:
                    continue

                outputs = evaluated_model(batch=batch_data)
                total_loss = outputs["loss"]
                assert isinstance(total_loss, torch.Tensor)

                if torch.isfinite(total_loss):
                    num_seen_seq += batch_size
                    total_losses += total_loss.item() * batch_size

                if batch_idx % log_interval == 0 and batch_idx > 0:
                    history_loss = total_losses / max(num_seen_seq, 1)
                    log_str = " | DEV Batch {}/{} | Total loss {:.4f}".format(
                        epoch, batch_idx, total_loss.item())
                    log_str += " | History Loss {:.4f} | History PPL {:.3f}".format(
                        history_loss, self._safe_exp(history_loss))
                    log_str += ' | lr {:.6f} | rank {} | step {}'.format(
                        self.scheduler.lr, rank, self.scheduler.n_steps)
                    logging.debug(log_str)

        # Guard the division without biasing the mean by a phantom sample.
        average_nwp_loss = total_losses / max(num_seen_seq, 1)
        average_nwp_ppl = self._safe_exp(average_nwp_loss)
        return {
            "dev_loss": average_nwp_loss,
            "dev_ppl": average_nwp_ppl
        }

    def get_batch_data(self,
                       batch_data: Dict,
                       device: Union[str, torch.device],
                       ) -> Tuple[Dict, int]:
        """Move the tensor entries of ``batch_data`` to ``device``.

        Non-tensor entries are dropped. The batch size is the first dimension
        of ``batch_data["target_semantics"]``; it is 0 when that key is absent,
        which callers treat as "skip this batch".

        Returns:
            ``(tensors_on_device, batch_size)``
        """
        return_batch_data = {
            k: v.to(device) for k, v in batch_data.items()
            if isinstance(v, torch.Tensor)
        }
        batch_size = 0
        if "target_semantics" in batch_data:
            batch_size = batch_data["target_semantics"].size(0)
        return return_batch_data, batch_size

    def average_checkpoints(self,
                            mid_checkpoint_dir: Union[Path, str]):
        """Average the lowest-loss mid-epoch checkpoints of one epoch.

        Each ``checkpoint-*.yaml`` in ``mid_checkpoint_dir`` that has a matching
        ``.pt`` file contributes its recorded ``steps`` and ``loss``. Up to
        ``num_checkpoints_to_average`` checkpoints with the smallest loss are
        averaged and saved as ``<model_type>_avg_<N>.pt`` in the same directory.
        """
        mid_checkpoint_dir = os.path.abspath(mid_checkpoint_dir)
        yaml_pattern = os.path.join(glob.escape(mid_checkpoint_dir), "checkpoint-*.yaml")

        train_losses = []
        for yaml_path in glob.glob(yaml_pattern):
            # "checkpoint-N.yaml" -> "checkpoint-N.pt"
            ckpt_path = yaml_path[:-len("yaml")] + "pt"
            if not os.path.exists(ckpt_path):
                continue
            with open(yaml_path, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
            train_losses.append([dic_yaml["steps"], dic_yaml["loss"]])

        if not train_losses:
            logging.warning(" | No checkpoints to average")
            return

        train_losses = np.array(train_losses)
        # Ascending by loss: best checkpoints first.
        sorted_train_losses = train_losses[np.argsort(train_losses[:, -1])]
        selected = sorted_train_losses[:self.num_checkpoints_to_average]

        logging.info(" | Best train losses: {}".format(str(selected[:, 1])))
        logging.info(" | Selected average steps: {}".format(
            str(selected[:, 0].astype(np.int64))
        ))

        checkpoints_list = [
            os.path.join(mid_checkpoint_dir, "checkpoint-{}.pt".format(int(step)))
            for step in selected[:, 0]
        ]
        logging.info(" | checkpoints to average: \n{}".format(
            "\n".join(checkpoints_list)))

        average_num = len(checkpoints_list)
        avg = average_checkpoints(checkpoints_list)

        saved_checkpoints = "{}_avg_{}.pt".format(self.model_type, str(average_num))
        saved_checkpoints_path = os.path.join(mid_checkpoint_dir, saved_checkpoints)
        logging.info(" | Checkpoint: save to checkpoint {}".format(saved_checkpoints_path))
        torch.save(avg, saved_checkpoints_path)

    def save_mid_checkpoints(self,
                             epoch: int,
                             save_steps: int,
                             keep_last_k_ckpt: int,
                             mid_checkpoint_dir: Union[Path, str],
                             loss: torch.Tensor) -> None:
        """Save ``checkpoint-<n_steps>.pt`` every ``save_steps`` optimizer steps.

        Called after every micro-batch; the existence check prevents writing
        the same step twice while ``n_steps`` has not moved. After saving,
        ``remove_checkpoints`` is asked to keep ``keep_last_k_ckpt`` files.
        """
        n_steps = self.scheduler.n_steps
        # Cheap modulo check first so the filesystem is only touched on save steps.
        if n_steps <= 0 or n_steps % save_steps != 0:
            return
        save_path = os.path.join(mid_checkpoint_dir, f"checkpoint-{n_steps}.pt")
        if os.path.exists(save_path):
            return
        save_checkpoint(self.saved_model, save_path, {
            "epoch": epoch,
            "steps": n_steps,
            "lr": self.scheduler.lr,
            "loss": loss.item()
        })
        remove_checkpoints(
            out_dir=mid_checkpoint_dir,
            topk=keep_last_k_ckpt
        )

    def save_epoch_checkpoints(self,
                               epoch: int,
                               dev_metric_dict: Dict,
                               keep_topk_dir: bool = False) -> None:
        """Save this epoch's model/optimizer if top-k, then prune the rest.

        Keeps ``<model_type>_<ep>.pt`` / ``optimizer_<ep>.pt`` for every epoch in
        ``scheduler.topk_list``. Subdirectories of ``checkpoint_dir`` are kept
        only for all top-k epochs (``keep_topk_dir``) or just the first entry
        of ``topk_list``; every other subdirectory is deleted. Only files whose
        whole name matches the checkpoint/optimizer naming scheme are removed.
        """
        if self.scheduler.is_topk:
            save_checkpoint_path = os.path.join(self.checkpoint_dir,
                                                "{}_{}.pt".format(self.model_type, epoch))
            save_optimizer_path = os.path.join(self.checkpoint_dir,
                                               "optimizer_{}.pt".format(epoch))

            save_checkpoint(self.saved_model, save_checkpoint_path, {
                "epoch": epoch,
                "step": self.scheduler.n_steps,
                "lr": self.scheduler.lr,
                "dev_metric": dev_metric_dict
            })
            save_optimizer(self.scheduler.optimizer, save_optimizer_path)

        top_k_epochs = [tp[0] for tp in self.scheduler.topk_list]
        if not top_k_epochs:
            # Nothing to anchor the keep-list on; deleting would wipe everything.
            logging.warning(" | Top-k list is empty, skip checkpoint cleanup")
            return

        if keep_topk_dir:
            keep_dirs = {"{}_{}".format(self.model_type, ep) for ep in top_k_epochs}
        else:
            keep_dirs = {"{}_{}".format(self.model_type, top_k_epochs[0])}
        keep_files = {"{}_{}.pt".format(self.model_type, ep) for ep in top_k_epochs} | \
            {"optimizer_{}.pt".format(ep) for ep in top_k_epochs}

        # Escape the model type and the dot, and match whole basenames only so
        # unrelated files (e.g. "lm_3.pt.bak") or parent paths never match.
        checkpoint_pattern = re.compile(rf"{re.escape(self.model_type)}_[0-9]+\.pt")
        optimizer_pattern = re.compile(r"optimizer_[0-9]+\.pt")

        for name in os.listdir(self.checkpoint_dir):
            path = os.path.join(self.checkpoint_dir, name)
            if os.path.isdir(path):
                if name not in keep_dirs:
                    logging.debug(f" | Remove {path}")
                    shutil.rmtree(path)
            elif (checkpoint_pattern.fullmatch(name) or optimizer_pattern.fullmatch(name)) \
                    and name not in keep_files:
                logging.debug(f" | Remove {path}")
                os.remove(path)