from typing import Optional, Dict, Union, Any, List, Tuple
from transformers import Trainer
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class CustomTrainer(Trainer):
    """Trainer subclass with plain DataLoaders and periodic loss printing.

    Differences from the stock ``Trainer`` that are visible in this class:

    * three extra required keyword arguments (``num_chunks``, ``max_length``,
      ``my_tokenizer``) are stored on the instance;
    * the train/eval dataloaders are built directly from the datasets with
      the default sampler and default collate function (the trainer's
      ``data_collator`` is not used, and the train loader does not shuffle);
    * ``compute_loss`` reads the loss from ``outputs['loss']`` and prints it
      whenever ``global_step`` is a multiple of ``LOSS_PRINT_INTERVAL``.
    """

    # The loss is printed when ``state.global_step % LOSS_PRINT_INTERVAL == 0``
    # (which includes step 0). Override on a subclass/instance to change it.
    LOSS_PRINT_INTERVAL = 501

    def __init__(self, **kwargs):
        """Pop the custom keyword arguments, then initialise the base Trainer.

        Args:
            **kwargs: Must contain ``num_chunks``, ``max_length`` and
                ``my_tokenizer``; everything else is forwarded unchanged to
                ``Trainer.__init__``.

        Raises:
            KeyError: If any of the three custom keyword arguments is missing.
        """
        # The custom keys must be removed before delegating, because the
        # remaining kwargs are passed verbatim to the base constructor.
        self.num_chunks = kwargs.pop("num_chunks")
        self.max_length = kwargs.pop("max_length")
        self.my_tokenizer = kwargs.pop("my_tokenizer")
        super().__init__(**kwargs)
        print(f"Using device: {self.args.device}")

    def training_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            *args: Any,
            **kwargs: Any,
    ) -> torch.Tensor:
        """Delegate to ``Trainer.training_step``.

        Extra positional/keyword arguments are forwarded untouched so the
        override keeps working with base-class versions that pass additional
        arguments (e.g. ``num_items_in_batch``) as well as those that do not.
        """
        return super().training_step(model, inputs, *args, **kwargs)

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Delegate to ``Trainer.prediction_step`` without modification."""
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

    def get_train_dataloader(self):
        """Build the training DataLoader from ``self.train_dataset``.

        Uses ``args.per_device_train_batch_size`` and pinned memory. No
        sampler, shuffling or collate function is configured here, so
        DataLoader defaults apply.

        Raises:
            ValueError: If no training dataset was given to the trainer.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            pin_memory=True,
        )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """Build the evaluation DataLoader.

        Args:
            eval_dataset: Dataset to evaluate on. When ``None``, falls back to
                ``self.eval_dataset``.

        Returns:
            A DataLoader using ``args.per_device_eval_batch_size`` and pinned
            memory, with DataLoader defaults for everything else.

        Raises:
            ValueError: If neither ``eval_dataset`` nor ``self.eval_dataset``
                is set.
        """
        # Compare against None explicitly: truthiness would invoke __len__ on
        # map-style datasets and silently replace an explicitly passed empty
        # dataset with the trainer-level one.
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        return DataLoader(
            eval_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            pin_memory=True,
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model and return the loss found under ``outputs['loss']``.

        Args:
            model: The model to call; invoked as ``model(**inputs)``.
            inputs: Mapping of keyword arguments for the model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted for compatibility with base-class
                versions that pass it to ``compute_loss``; not used here.

        Returns:
            ``outputs['loss']`` as returned by the model (no reduction is
            applied here), optionally paired with the full model outputs.
        """
        outputs = model(**inputs)
        loss = outputs['loss']
        step = self.state.global_step
        if step % self.LOSS_PRINT_INTERVAL == 0:
            # Only the printed value is averaged; the returned loss is left
            # exactly as the model produced it.
            print({'loss': torch.mean(loss).item(),
                   'steps': step})
        return (loss, outputs) if return_outputs else loss
