# code taken from https://github.com/zipzou/hf-multitask-trainer

from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.optimizer import Optimizer as Optimizer
from torch.utils.data import Dataset, IterableDataset
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments

from flowdock.multitask_trainer.mixins import MultiTaskModuleMixin
from flowdock.multitask_trainer.state import AdditionalState

# Type alias for a batch collator: a callable that receives a list of samples
# and returns a dict with string keys. Used to annotate the `data_collator`
# argument of `HfMultiTaskTrainer.__init__`.
DataCollator = Callable[[List[Any]], Dict[str, Any]]


class HfMultiTaskTrainer(Trainer):
    """`Trainer` subclass that merges model-reported metrics into log entries.

    The trainer owns an `AdditionalState`. The model's `report_metrics`
    callable is pre-bound to that state, and every call to `log` merges the
    result of `AdditionalState.pop_metrics` into the logged dictionary.
    """

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, Module]] = None,
        args: Optional[TrainingArguments] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset, Any]] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], Any]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Optional[Tuple[Optimizer, LambdaLR]] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None
    ):
        """Create the trainer and hook metric reporting into `model`.

        All arguments are forwarded to `Trainer.__init__`; `tokenizer` is
        passed on under the keyword `processing_class`.

        When `model` is given, its class gains `MultiTaskModuleMixin` as an
        extra base (unless the model already is an instance of the mixin) and
        `model.report_metrics` is replaced by a partial whose first argument
        is this trainer's `AdditionalState`. A model that is only supplied
        through `model_init` is not patched here.
        """
        # Created before the model is patched because the partial below
        # captures it.
        self.additional_state = AdditionalState(args)
        if model is not None:
            self._ensure_multitask_mixin(model)
            self._bind_report_metrics(model)
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

    @staticmethod
    def _ensure_multitask_mixin(model) -> None:
        """Append `MultiTaskModuleMixin` to the bases of `model`'s class once."""
        # Re-appending a base that is already in the hierarchy raises
        # TypeError, which would break a second trainer built for a model of
        # the same class (or a model that inherits the mixin explicitly).
        if isinstance(model, MultiTaskModuleMixin):
            return
        model.__class__.__bases__ = model.__class__.__bases__ + (
            MultiTaskModuleMixin,
        )

    def _bind_report_metrics(self, model) -> None:
        """Bind `model.report_metrics` to this trainer's additional state."""
        reporter = model.report_metrics
        # A model reused from an earlier trainer still carries a partial bound
        # to that trainer's state. Strip it so exactly one state is passed.
        while (
            isinstance(reporter, partial)
            and len(reporter.args) == 1
            and not reporter.keywords
            and isinstance(reporter.args[0], AdditionalState)
        ):
            reporter = reporter.func
        model.report_metrics = partial(reporter, self.additional_state)

    @staticmethod
    def _infer_log_split(logs: Dict[str, float]) -> str:
        """Return "eval", "train" or "test" based on the loss keys in `logs`."""
        if (
            logs.get("eval_loss") is not None
            or logs.get("eval_pdbbind_loss") is not None
        ):
            return "eval"
        if logs.get("loss") is not None:
            return "train"
        # Anything without a recognised loss key is treated as a test log.
        return "test"

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge the additional metrics into `logs`, then record and emit them.

        Args:
            logs: Values to log. The dictionary is updated in place with the
                current epoch, optionally the number of input tokens seen,
                and the metrics popped from the additional state.
            start_time: Accepted for signature compatibility; not used here.

        Additional metrics keep their names for training logs and are
        prefixed with "eval_" or "test_" otherwise. The final entry (plus the
        current global step) is appended to `self.state.log_history`, and
        `logs` is handed to the callback handler's `on_log`.
        """
        if self.state.epoch is not None:
            logs["epoch"] = self.state.epoch
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        split = self._infer_log_split(logs)

        additional_state = getattr(self, "additional_state", None)
        if additional_state is None:
            additional_logs = {}
        else:
            additional_logs = additional_state.pop_metrics(
                gather_func=self._nested_gather
            )
            if split != "train":
                additional_logs = {
                    f"{split}_{name}": value
                    for name, value in additional_logs.items()
                }

        # Pop and re-insert "epoch" so it stays the last key after the merge.
        # Only restore it when there is a value: writing `None` back would
        # add a spurious "epoch": None entry to the history and callbacks.
        epoch = logs.pop("epoch", None)
        logs.update(additional_logs)
        if epoch is not None:
            logs["epoch"] = epoch

        output = {**logs, "step": self.state.global_step}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(
            self.args, self.state, self.control, logs
        )
        