import os
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import Trainer, PreTrainedModel
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import logging, is_peft_available, is_safetensors_available, SAFE_WEIGHTS_NAME, WEIGHTS_NAME

from src.utils.utils import extract_and_clean_state_dict

# Optional dependencies: each import only runs when the package is installed.
if is_peft_available():
    from peft import PeftModel

# `is_safetensors_available` is a function and has to be *called*: the bare
# function object is always truthy, which would make this guard a no-op and
# turn a missing optional dependency into an ImportError at module import time.
if is_safetensors_available():
    import safetensors.torch

logger = logging.get_logger(__name__)


class AlignChatRecognitionTrainer(Trainer):
    """
    `Trainer` subclass that forwards loss settings and two shared modules to the model.

    Before every training / prediction step is delegated to `Trainer`, the batch dictionary is extended with
    `use_cache=False`, the `loss_ratios` and `loss_types` read from `self.args`, and the `embed_tokens` / `proj_out`
    modules that were passed to the constructor.
    """

    def __init__(
            self,
            embed_tokens: Optional[nn.Embedding] = None,
            proj_out: Optional[nn.Module] = None,
            **kwargs
        ):
        """
        Build the trainer and remember the modules that are injected into every batch.

        Args:
            embed_tokens (`nn.Embedding`, *optional*):
                Module stored on the trainer and added to each batch under the key `embed_tokens`.
            proj_out (`nn.Module`, *optional*):
                Module stored on the trainer and added to each batch under the key `proj_out`.
            **kwargs:
                Forwarded unchanged to `Trainer.__init__`.
        """
        super().__init__(**kwargs)

        self.embed_tokens = embed_tokens
        self.proj_out = proj_out

    def _add_forward_kwargs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Add the extra keyword arguments shared by training and prediction to `inputs`.

        The dictionary is modified in place (and also returned for convenience), so the caller's batch object
        carries the extra keys after the call.
        """
        inputs['use_cache'] = False
        inputs['loss_ratios'] = self.args.loss_ratios
        inputs['loss_types'] = self.args.loss_types

        inputs['embed_tokens'] = self.embed_tokens
        inputs['proj_out'] = self.proj_out
        return inputs

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        The batch is first extended with the extra forward arguments (see `_add_forward_kwargs`) and then handed
        to `Trainer.training_step`.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Modified in place.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            num_items_in_batch (*optional*):
                Passed through unchanged to `Trainer.training_step`.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        return super().training_step(model, self._add_forward_kwargs(inputs), num_items_in_batch)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys = None):
        """
        Perform an evaluation step on `model` using `inputs`.

        The batch is first extended with the extra forward arguments (see `_add_forward_kwargs`) and then handed
        to `Trainer.prediction_step`. Adding keys to a dictionary creates no tensors, so no gradient / inference
        context is needed here; disabling gradients for the forward pass is left to the parent implementation.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Modified in place.
            prediction_loss_only (`bool`):
                Passed through unchanged to `Trainer.prediction_step`.
            ignore_keys (*optional*):
                Passed through unchanged to `Trainer.prediction_step`.

        Return:
            Whatever `Trainer.prediction_step` returns (a `(loss, logits, labels)` tuple in upstream `transformers`).
        """
        return super().prediction_step(model, self._add_forward_kwargs(inputs), prediction_loss_only, ignore_keys)


class AlignChatResponseTrainer(Trainer):
    """
    `Trainer` subclass for a model that exposes an `audio_model` sub-module.

    Differences from the stock `Trainer` that are visible in this class:

    - every training / prediction batch is extended with `use_cache=False` and the `loss_ratios` / `loss_types`
      read from `self.args`;
    - `get_batch_samples` counts target tokens separately for `decoder_labels` and `labels` and returns both
      counts as a tuple;
    - `_save` writes only `model.audio_model` into the checkpoint directory.
    """

    def _add_loss_config(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Add the extra keyword arguments shared by training and prediction to `inputs`.

        The dictionary is modified in place (and also returned for convenience).
        """
        inputs['use_cache'] = False
        inputs['loss_ratios'] = self.args.loss_ratios
        inputs['loss_types'] = self.args.loss_types
        return inputs

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        The batch is first extended with the loss configuration (see `_add_loss_config`) and then handed to
        `Trainer.training_step`.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Modified in place.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            num_items_in_batch (*optional*):
                Passed through unchanged to `Trainer.training_step`. When produced by `get_batch_samples` of this
                class it is either `None` or a `(decoder_labels_count, labels_count)` tuple.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        return super().training_step(model, self._add_loss_config(inputs), num_items_in_batch)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys = None):
        """
        Perform an evaluation step on `model` using `inputs`.

        The batch is first extended with the loss configuration (see `_add_loss_config`) and then handed to
        `Trainer.prediction_step`. Adding keys to a dictionary creates no tensors, so no gradient / inference
        context is needed here; disabling gradients for the forward pass is left to the parent implementation.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Modified in place.
            prediction_loss_only (`bool`):
                Passed through unchanged to `Trainer.prediction_step`.
            ignore_keys (*optional*):
                Passed through unchanged to `Trainer.prediction_step`.

        Return:
            Whatever `Trainer.prediction_step` returns (a `(loss, logits, labels)` tuple in upstream `transformers`).
        """
        return super().prediction_step(model, self._add_loss_config(inputs), prediction_loss_only, ignore_keys)

    @staticmethod
    def _count_target_tokens(batch_samples, key: str):
        """
        Sum, over all batches, the number of entries of `batch[key]` that differ from `-100`.

        Returns `None` when the values do not support the tensor operations used for counting (`TypeError` /
        `AttributeError`), mirroring the best-effort behaviour of the upstream `Trainer`.
        """
        # -100 is presumably the ignore index used by the loss; such positions are excluded from the count.
        try:
            return sum(batch[key].ne(-100).sum() for batch in batch_samples)
        except (TypeError, AttributeError):
            return None

    def _finalize_token_count(self, count, device):
        """
        Optionally aggregate `count` across processes, then move it to `device`.

        When `args.average_tokens_across_devices` is set the count is gathered through the accelerator and summed.
        A tensor result is moved to `device` and, if `args.n_gpu > 1` and it is 0-dimensional, reshaped to 1-dim.
        """
        if self.args.average_tokens_across_devices:
            count = self.accelerator.gather(count).sum()

        if torch.is_tensor(count):
            count = count.to(device)

            if self.args.n_gpu > 1 and count.dim() == 0:
                # In the DataParallel case, convert the scalar tensor into a 1-dim tensor
                count = count.unsqueeze(0)

        return count

    def get_batch_samples(self, epoch_iterator, num_batches, device):
        """
        Pull up to `num_batches` batches from `epoch_iterator` and count their target tokens.

        Args:
            epoch_iterator:
                Iterator yielding batches; iteration stops early when it is exhausted.
            num_batches (`int`):
                Maximum number of batches to fetch.
            device:
                Device the token counts are moved to.

        Return:
            `(batch_samples, num_items_in_batch)` where `batch_samples` is the list of fetched batches and
            `num_items_in_batch` is either `None` or the tuple `(decoder_labels_count, labels_count)`. The counts
            are only produced when the first batch contains both `decoder_labels` and `labels`, the trainer
            forwards `num_items_in_batch` to the model or to `compute_loss_func`, and both counts could be computed.
        """
        batch_samples = []
        for _ in range(num_batches):
            try:
                batch_samples.append(next(epoch_iterator))
            except StopIteration:
                break

        count_num_items_in_batch = (
            len(batch_samples) > 0
            and "decoder_labels" in batch_samples[0]
            and "labels" in batch_samples[0]
            and (
                # num_items_in_batch is passed to model forward
                # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3757
                self.model_accepts_loss_kwargs
                # num_items_in_batch is passed to compute_loss_func
                # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3773
                or self.compute_loss_func is not None
                # num_items_in_batch is also verified if (self.model_accepts_loss_kwargs or self.compute_loss_func)
                # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
            )
        )
        if not count_num_items_in_batch:
            return batch_samples, None

        # For now we don't support object detection
        decoder_labels_count = self._count_target_tokens(batch_samples, "decoder_labels")
        labels_count = self._count_target_tokens(batch_samples, "labels")

        # Both counts are required; if either could not be computed, no count is reported at all.
        if decoder_labels_count is None or labels_count is None:
            return batch_samples, None

        # Keep the order (decoder_labels first, then labels) so cross-process gathers line up on every rank.
        num_items_in_batch = (
            self._finalize_token_count(decoder_labels_count, device),
            self._finalize_token_count(labels_count, device),
        )

        return batch_samples, num_items_in_batch

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Save `model.audio_model`, the processing class / tokenizer and the training arguments to `output_dir`.

        Args:
            output_dir (`str`, *optional*):
                Target directory; defaults to `self.args.output_dir`. Created if it does not exist.
            state_dict (*optional*):
                State dict to save. Under DeepSpeed it is first passed through `extract_and_clean_state_dict`.
                When `None` and the top-level model is not a supported pretrained class, the state dict of
                `self.model.audio_model` is used.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        if self.is_deepspeed_enabled:
            # DeepSpeed is used but we only need to save model.audio_model.
            # The helper is called with an empty new prefix, i.e. the "audio_model." prefix is presumably
            # stripped from the keys (see `extract_and_clean_state_dict` for the exact filtering rules).
            state_dict = extract_and_clean_state_dict(state_dict, old_prefix="audio_model.", new_prefix="")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, supported_classes):
            self.model.audio_model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.audio_model.state_dict()

            audio_model = self.accelerator.unwrap_model(self.model).audio_model
            if isinstance(audio_model, supported_classes):
                audio_model.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.processing_class is not None:
            self.processing_class.save_pretrained(output_dir)
        elif (
            self.data_collator is not None
            and hasattr(self.data_collator, "tokenizer")
            and self.data_collator.tokenizer is not None
        ):
            logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
            self.data_collator.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
