# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Callbacks

"""
import os
import json
import logging
import torch
from transformers.integrations import TensorBoardCallback
from transformers.trainer_callback import (
    TrainingArguments,
    TrainerState,
    TrainerControl,
    PrinterCallback, 
    TrainerCallback,
)
from transformers.trainer_utils import is_main_process

logger = logging.getLogger(__name__)


class EditableLogPrinterCallback(PrinterCallback):
    """Printer callback whose logged record can be extended with extra entries.

    Entries registered through :meth:`add_to_log` are merged into the ``logs``
    dict received by :meth:`on_log` before the record is written to the
    module-level ``logger``.
    """

    def __init__(self):
        super().__init__()
        # Extra key/value pairs merged into every record handled by ``on_log``.
        self.log = dict()

    def add_to_log(self, log):
        """Register extra entries to be printed with every log record.

        Args:
            log: dict of extra entries. Single-element ``torch.Tensor`` values
                are converted to Python scalars; all other values are stored
                unchanged. A non-dict argument is ignored with a warning.
        """
        if not isinstance(log, dict):
            # Lazy %-style argument: the previous call passed a stray extra
            # positional argument, which made the logging module fail to
            # format the record instead of emitting the warning.
            logger.warning('Expect log to be dict, got %s', type(log))
            return
        self.log.update({
            key: value.item()
            if isinstance(value, torch.Tensor) and value.numel() == 1
            else value
            for key, value in log.items()
        })

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Write the merged log record to the logger on the local main process.

        Args:
            args: training arguments (unused here).
            state: trainer state; supplies ``global_step`` and
                ``is_local_process_zero``.
            control: trainer control object (unused here).
            logs: record produced by the trainer; ``None`` is treated as empty.
        """
        if not state.is_local_process_zero:
            return
        # Work on a copy so the caller's dict is never mutated by the pops below.
        record = {**(logs or {}), **self.log}
        record.pop("total_flos", None)
        # in case training is not done, logs does not have 'epoch'
        epoch = record.pop('epoch', 'unknown')
        logger.info('[Epoch %s - Step %s]: %s.', epoch, state.global_step, record)


class EditableLogTensorBoardCallback(TensorBoardCallback):
    """TensorBoard callback whose logged record can be extended with extra entries.

    Entries registered through :meth:`add_to_log` are merged into the ``logs``
    dict received by :meth:`on_log` before it is forwarded to the parent
    ``TensorBoardCallback.on_log``.
    """

    def __init__(self):
        super().__init__()
        # Extra key/value pairs merged into every record handled by ``on_log``.
        self.log = dict()

    def add_to_log(self, log):
        """Register extra entries to be forwarded with every log record.

        Args:
            log: dict of extra entries. Single-element ``torch.Tensor`` values
                are converted to Python scalars (same handling as
                ``EditableLogPrinterCallback.add_to_log``); all other values
                are stored unchanged. A non-dict argument is ignored with a
                warning.
        """
        if not isinstance(log, dict):
            # Lazy %-style argument: the previous call passed a stray extra
            # positional argument, which made the logging module fail to
            # format the record instead of emitting the warning.
            logger.warning('Expect log to be dict, got %s', type(log))
            return
        self.log.update({
            key: value.item()
            if isinstance(value, torch.Tensor) and value.numel() == 1
            else value
            for key, value in log.items()
        })

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Merge the extra entries into ``logs`` and delegate to the parent.

        Args:
            args: training arguments, forwarded to the parent.
            state: trainer state, forwarded to the parent.
            control: trainer control object, forwarded to the parent.
            logs: record produced by the trainer; ``None`` is treated as empty.
        """
        merged = {**(logs or {}), **self.log}
        super().on_log(args, state, control, merged, **kwargs)


class HaltTrainingCallback(TrainerCallback):
    """Halt training once a given number of steps or epochs has been reached.

    Args:
        halt_step: training is stopped at the end of the first step for which
            ``state.global_step >= halt_step``. A negative value (the default)
            disables the step-based check.
        halt_epoch: training is stopped at the end of the first epoch for which
            ``state.epoch >= halt_epoch``. A negative value (the default)
            disables the epoch-based check.
    """

    def __init__(self, halt_step=-1, halt_epoch=-1):
        self.halt_step = halt_step
        self.halt_epoch = halt_epoch

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.

        Sets ``control.should_training_stop`` when ``halt_epoch`` is reached.
        """
        if self.halt_epoch < 0:
            return
        # ``state.epoch`` may not be populated yet; nothing to compare then.
        if state.epoch is None or state.epoch < self.halt_epoch:
            return
        control.should_training_stop = True
        if state.is_local_process_zero:
            logger.info('[Epoch %s]: halt_epoch reached. Stop training now.', state.epoch)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of a training step.

        Sets ``control.should_training_stop`` when ``halt_step`` is reached.
        """
        if self.halt_step < 0 or state.global_step < self.halt_step:
            return
        control.should_training_stop = True
        if state.is_local_process_zero:
            # The value reported is the global step, so label it as a step.
            logger.info('[Step %s]: halt_step reached. Stop training now.', state.global_step)


class PeftSaveCallback(TrainerCallback):
    """After a checkpoint is saved during training, add the adapter config to
    it and rename the saved weights file.

    Args:
        adapter_config: adapter configuration object exposing ``to_dict()``;
            its dict form is what gets written to ``adapter_config.json``.
    """

    _CKPT_PREFIX = "checkpoint-"

    def __init__(self, adapter_config):
        self.config = adapter_config.to_dict()

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Post-process the highest-numbered checkpoint folder in ``args.output_dir``.

        On the main process only:
          * write ``adapter_config.json`` into the folder if it is missing;
          * rename ``pytorch_model.bin`` to ``adapter_model.bin`` if present.
        """
        if not is_main_process(args.local_rank):
            return

        # Collect (step, folder name) for every "checkpoint-<int>" entry.
        # Entries with a non-numeric suffix (e.g. "checkpoint-best") are
        # skipped instead of aborting the save hook with a ValueError.
        numbered = []
        for entry in os.listdir(args.output_dir):
            if not entry.startswith(self._CKPT_PREFIX):
                continue
            try:
                step = int(entry[len(self._CKPT_PREFIX):])
            except ValueError:
                continue
            numbered.append((step, entry))
        if not numbered:
            return

        # Use the real directory entry rather than re-formatting the number,
        # so the path always refers to a name that was actually listed.
        _, latest_entry = max(numbered)
        ckpt_folder = os.path.join(args.output_dir, latest_entry)

        adapter_config_path = os.path.join(ckpt_folder, 'adapter_config.json')
        if not os.path.isfile(adapter_config_path):
            with open(adapter_config_path, 'w', encoding='utf-8') as handler:
                json.dump(self.config, handler, indent=2)

        ckpt_model_path = os.path.join(ckpt_folder, "pytorch_model.bin")
        if os.path.isfile(ckpt_model_path):
            target_file = os.path.join(ckpt_folder, "adapter_model.bin")
            # os.replace overwrites an existing target on every platform,
            # whereas os.rename fails on Windows when the target exists.
            os.replace(ckpt_model_path, target_file)