from transformers.trainer_callback import TrainerCallback
import math
import importlib.util
import torch
import torch.distributed as dist
from trainer import MetricLogger
from typing import Any, Dict, List, Optional, Tuple, Union
from hexa.utils.dist_utils import gather
import os
import json
import pprint


def rewrite_logs(d):
    """Namespace flat Trainer log keys by split for grouped display.

    ``eval_<name>`` becomes ``eval/<name>`` and ``test_<name>`` becomes
    ``test/<name>``; every other key is prefixed with ``train/``. The
    insertion order of ``d`` is preserved.
    """
    # Checked in order; the first matching prefix wins.
    prefix_table = (("eval_", "eval/"), ("test_", "test/"))
    rewritten = {}
    for key, value in d.items():
        for old_prefix, new_prefix in prefix_table:
            if key.startswith(old_prefix):
                rewritten[new_prefix + key[len(old_prefix):]] = value
                break
        else:
            rewritten["train/" + key] = value
    return rewritten

def is_tensorboard_available():
    """Return True if either ``tensorboard`` or ``tensorboardX`` can be found.

    Only checks that a module spec exists; nothing is imported.
    """
    for module_name in ("tensorboard", "tensorboardX"):
        if importlib.util.find_spec(module_name) is not None:
            return True
    return False


class TokenizerSaveCallback(TrainerCallback):
    """Save the tokenizer into ``args.output_dir`` whenever the Trainer saves.

    Only the world-process-zero rank writes, so multi-process runs do not
    write the same files concurrently.
    """

    def on_save(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return
        # Older transformers releases pass the tokenizer as ``tokenizer``; newer
        # ones pass it as ``processing_class``. The Trainer may also pass None
        # when it has no tokenizer, in which case there is nothing to save.
        tokenizer = kwargs.get("tokenizer")
        if tokenizer is None:
            tokenizer = kwargs.get("processing_class")
        if tokenizer is not None:
            tokenizer.save_pretrained(args.output_dir)


class EpochLossCallback(TrainerCallback):
    """Add an epoch-level running mean of the training loss to each log entry.

    ``logs['loss']`` is treated as the mean per-step loss since the previous
    log. It is weighted by the number of global steps that elapsed since that
    log and accumulated, so ``logs['mean_loss']`` holds the step-weighted mean
    loss over the current epoch so far.
    """

    def __init__(self):
        super().__init__()
        self.epoch_loss_sum = 0.0  # step-weighted sum of losses this epoch
        self.epoch_step_count = 0  # number of steps covered by epoch_loss_sum
        self.prev_epoch = 0  # kept for backward compatibility; not read here
        self.prev_global_step_count = 0  # state.global_step at the last loss log

    def on_train_begin(self, args, state, control, **kwargs):
        # When resuming from a checkpoint, global_step does not start at 0.
        # Without this sync the first log would be weighted by every step
        # since 0 and would skew the epoch mean.
        self.prev_global_step_count = state.global_step

    def on_epoch_begin(self, args, state, control, **kwargs):
        self.epoch_loss_sum = 0.0
        self.epoch_step_count = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Entries without a 'loss' key (e.g. evaluation logs) are ignored.
        if logs is None or 'loss' not in logs:
            return
        # NOTE(review): a log interval that straddles an epoch boundary is
        # attributed entirely to the new epoch.
        step_diff = state.global_step - self.prev_global_step_count
        self.epoch_loss_sum += logs['loss'] * step_diff
        self.epoch_step_count += step_diff

        # max(..., 1) guards against division by zero when no steps elapsed.
        logs['mean_loss'] = self.epoch_loss_sum / max(self.epoch_step_count, 1)

        self.prev_global_step_count = state.global_step


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Only the world-process-zero rank writes events. The writer class is
    taken from `torch.utils.tensorboard` when importable, otherwise from
    `tensorboardX`.

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.

    Raises:
        RuntimeError: If neither `tensorboard` nor `tensorboardX` is installed.
    """

    def __init__(self, tb_writer=None):
        if not is_tensorboard_available():
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                " install tensorboardX."
            )
        # Prefer PyTorch's bundled writer and fall back to tensorboardX. If
        # neither imports, the writer class stays None and logging is a no-op.
        try:
            from torch.utils.tensorboard import SummaryWriter  # noqa: F401
        except ImportError:
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                SummaryWriter = None
        self._SummaryWriter = SummaryWriter
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        """Create the writer in ``log_dir`` (default ``args.logging_dir``)."""
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        # During hyperparameter search each trial logs to its own subdirectory.
        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero or logs is None:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is None:
            return

        # The module has no logger of its own, so create one locally.
        import logging
        logger = logging.getLogger(__name__)

        # Keys are written as-is; `rewrite_logs` could be applied here to
        # namespace them by split (train/eval/test).
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.tb_writer.add_scalar(k, v, state.global_step)
            else:
                logger.warning(
                    "Trainer is attempting to log a value of "
                    f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                    "This invocation of Tensorboard's writer.add_scalar() "
                    "is incorrect so we dropped this attribute."
                )
        self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None


class LoggingCallback(TrainerCallback):
    """Record per-step training statistics into a ``MetricLogger``.

    At the end of every step this logs ``ups`` (1 per step),
    ``total_train_updates`` (the trainer's ``global_step``) and ``gpu_mem``
    (peak allocated GPU memory fraction; see ``_gpu_usage``).
    """

    def __init__(self, logger: MetricLogger):
        self.logger = logger

    def _gpu_usage(self, opt):
        """
        Compute peak GPU memory usage since the previous call.

        Sums ``torch.cuda.max_memory_allocated`` (memory allocated to
        tensors, not the allocator's cached/reserved pool) over the selected
        devices. It then resets the peak counters, so each call measures the
        window since the last one.

        :param opt: object whose optional ``gpu`` attribute selects a device;
            ``-1`` or a missing attribute means every locally visible GPU.
        :return: Peak allocated memory as a fraction of total device memory,
            or ``0.0`` when no CUDA device is visible.
        """
        gpu = getattr(opt, 'gpu', -1)
        if gpu == -1:
            # use all gpus available locally
            devices = range(torch.cuda.device_count())
        else:
            devices = [gpu]
        memory_avail = 0
        memory_used = 0
        for dev in devices:
            memory_avail += torch.cuda.get_device_properties(dev).total_memory
            memory_used += torch.cuda.max_memory_allocated(dev)
            torch.cuda.reset_peak_memory_stats(dev)
        if memory_avail == 0:
            # No visible device (e.g. a CPU-only run): avoid ZeroDivisionError.
            return 0.0
        return memory_used / memory_avail

    def on_step_begin(self, args, state, control, **kwargs):
        pass

    def on_step_end(self, args, state, control, **kwargs):
        self.logger.log('ups', 1)
        self.logger.log('total_train_updates', state.global_step)
        model = kwargs['model']
        # Prefer a custom ``opt`` namespace on the model and fall back to its
        # config. getattr replaces the former blanket ``except Exception``.
        opt = getattr(model, 'opt', None)
        if opt is None:
            opt = model.config
        self.logger.log('gpu_mem', self._gpu_usage(opt))


class EvalLoggingCallback(TrainerCallback):
    """Merge per-rank evaluation metrics into a single JSON file.

    On every evaluation each rank dumps its ``metrics`` dict to
    ``{output_dir}/{logfile_name}_{rank}.json``. Once all ranks have
    written, rank 0 combines the files:

    - values whose key contains ``'exs'`` are summed;
    - every other key is averaged over the ranks that reported it.

    Rank 0 then pretty-prints the result, writes it to
    ``{output_dir}/{logfile_name}.json`` and deletes the per-rank files.
    Both distributed and single-process runs are supported. Multi-node runs
    need ``output_dir`` on a filesystem shared by all ranks.
    """

    def __init__(self, logger: MetricLogger, logfile_name='log'):
        self.logger = logger
        self.logfile_name = logfile_name

    def gather_metrics(self, metrics: Dict[str, Union[int, float]]):
        """All-gather each scalar metric across ranks (not used by ``on_evaluate``).

        Requires an initialized process group and a CUDA device. Keys are
        processed in sorted order so every rank issues collectives in the same
        sequence. Each value in the returned dict is the concatenation of
        ``gather``'s outputs — presumably one entry per rank; confirm against
        ``hexa.utils.dist_utils.gather``.
        """
        gathered_metrics = {key: {} for key in metrics.keys()}
        keys = sorted(metrics.keys())

        for k in keys:
            v = metrics[k]
            torch.distributed.barrier()
            gathered_metrics[k] = torch.cat(gather(torch.FloatTensor([v]).cuda()))

        return gathered_metrics

    def _rank_file_path(self, output_dir, rank):
        """Return the path of the per-rank metrics file for ``rank``."""
        return os.path.join(output_dir, f'{self.logfile_name}_{rank}.json')

    def _merge_rank_files(self, output_dir, world_size):
        """Read every per-rank file and return ``(merged_metrics, file_paths)``."""
        totals = {}
        counts = {}
        paths = []
        for rank in range(world_size):
            path = self._rank_file_path(output_dir, rank)
            with open(path, 'r') as fin:
                local_data = json.load(fin)
            for key, value in local_data.items():
                totals[key] = totals.get(key, 0) + value
                counts[key] = counts.get(key, 0) + 1
            paths.append(path)

        # Keys containing 'exs' (example counts) stay summed; everything else
        # becomes a mean over the ranks that reported that key.
        for key in totals:
            if 'exs' not in key:
                totals[key] /= counts[key]
        return totals, paths

    def on_evaluate(self, args, state, control, **kwargs):
        print('Evaluation end.')
        metrics = kwargs.get('metrics')
        assert metrics is not None, "on_evaluate expects a 'metrics' kwarg"

        # Take rank/world size from the process group rather than
        # ``args.local_rank``: local_rank is per-node, so file names would
        # collide across nodes. Without an initialized group (single process),
        # ``barrier()`` would raise, so treat the run as one rank.
        distributed = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if distributed else 0
        world_size = dist.get_world_size() if distributed else 1

        with open(self._rank_file_path(args.output_dir, rank), 'w') as fout:
            json.dump(metrics, fout)

        # Every rank must finish writing before rank 0 reads the files.
        if distributed:
            dist.barrier()
        if rank != 0:
            return

        merged, paths = self._merge_rank_files(args.output_dir, world_size)
        pprint.pprint(merged, indent=4)

        merged_path = os.path.join(args.output_dir, f'{self.logfile_name}.json')
        with open(merged_path, 'w') as fout:
            json.dump(merged, fout, indent=4)

        for path in paths:
            os.remove(path)