from typing import Dict, Union, Any, Optional, List, Tuple

from torch.utils.data import DataLoader
from transformers.trainer import Trainer, _is_peft_model, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
import torch
import torch.nn as nn
from transformers.trainer_utils import EvalLoopOutput


class LLAMA_Trainer(Trainer):
    """Trainer whose loss comes from two forward passes per batch.

    The first pass runs without gradients with ``cls=False`` and only supplies
    ``past_key_values``; the second pass consumes that cache together with the
    labels (``cls=True``) and produces the loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Args:
            model: Callable model accepting the ``cls`` keyword. Its output
                must be indexable by ``'past_key_values'`` (first pass) and
                ``'loss'`` (second pass).
            inputs: Batch mapping. ``"labels"`` is popped from it, so the
                mapping is mutated in place; the remaining entries are
                forwarded to the first pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Unused by this implementation.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        # Both branches of the former label_smoother check popped the labels,
        # so the pop is unconditional.
        labels = inputs.pop("labels")

        # Only the cache of the first pass is used, so no graph is built for it.
        with torch.no_grad():
            outputs_dict = model(**inputs, labels=None, cls=False)

        outputs = model(past_key_values=outputs_dict['past_key_values'], labels=labels, cls=True)
        loss = outputs['loss']
        return (loss, outputs) if return_outputs else loss

    def compute_metrics(self, eval_pred):
        """Return ``{"accuracy": ...}`` for an evaluation prediction object.

        ``eval_pred`` must expose ``predictions`` and ``label_ids``; the
        predicted class is the argmax over the last axis of ``predictions``.

        NOTE(review): ``Trainer.__init__`` assigns its ``compute_metrics``
        argument to ``self.compute_metrics``, which presumably shadows this
        method unless it is passed explicitly - confirm at the call site.
        """
        print(f'eval_pred: {eval_pred}')
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
        preds = predictions.argmax(-1)  # use argmax for classification
        # assumes array-like operands whose ``==`` is element-wise - TODO confirm
        accuracy = (preds == labels).mean()
        return {"accuracy": accuracy}


class LLAMA_Trainer_NO_GIST(Trainer):
    """Trainer that obtains the loss from a single ``cls=True`` forward pass."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Args:
            model: Callable model accepting the ``cls`` keyword and exposing a
                ``training`` flag; its output must be indexable by ``'loss'``.
            inputs: Batch mapping. ``"labels"`` is popped from it, so the
                mapping is mutated in place; the remaining entries are
                forwarded to the model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Unused by this implementation.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        # Both branches of the former label_smoother check popped the labels,
        # so the pop is unconditional.
        labels = inputs.pop("labels")

        # This pass produces the loss itself. Running it under an unconditional
        # no_grad left the loss without a grad_fn, so it could not be
        # back-propagated. Gradients are now disabled only when the model is
        # not in training mode or the caller has already turned them off.
        track_grad = torch.is_grad_enabled() and model.training
        with torch.set_grad_enabled(track_grad):
            outputs = model(**inputs, labels=labels, cls=True)

        loss = outputs['loss']
        return (loss, outputs) if return_outputs else loss

    def compute_metrics(self, eval_pred):
        """Return ``{"accuracy": ...}`` for an evaluation prediction object.

        ``eval_pred`` must expose ``predictions`` and ``label_ids``; the
        predicted class is the argmax over the last axis of ``predictions``.

        NOTE(review): ``Trainer.__init__`` assigns its ``compute_metrics``
        argument to ``self.compute_metrics``, which presumably shadows this
        method unless it is passed explicitly - confirm at the call site.
        """
        print(f'eval_pred: {eval_pred}')
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
        preds = predictions.argmax(-1)  # use argmax for classification
        # assumes array-like operands whose ``==`` is element-wise - TODO confirm
        accuracy = (preds == labels).mean()
        return {"accuracy": accuracy}


class LLAVA_Trainer(Trainer):
    """Trainer whose loss comes from two forward passes per batch.

    The first pass runs without gradients with ``cls=False`` and only supplies
    ``past_key_values``; the second pass consumes that cache together with the
    labels (``cls=True``) and produces the loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Args:
            model: Callable model accepting the ``cls`` and ``return_dict``
                keywords. Its output must be indexable by
                ``'past_key_values'`` (first pass) and ``'loss'`` (second
                pass).
            inputs: Batch mapping. ``"labels"`` is popped from it, so the
                mapping is mutated in place; the remaining entries are
                forwarded to the first pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Unused by this implementation.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        # Both branches of the former label_smoother check popped the labels,
        # so the pop is unconditional.
        labels = inputs.pop("labels")

        # Only the cache of the first pass is used, so no graph is built for it.
        with torch.no_grad():
            outputs_dict = model(**inputs,
                                 return_dict=True,
                                 labels=None,
                                 cls=False)

        outputs = model(past_key_values=outputs_dict['past_key_values'], labels=labels, cls=True)
        loss = outputs['loss']
        return (loss, outputs) if return_outputs else loss

    def compute_metrics(self, eval_pred):
        """Return ``{"accuracy": ...}`` for an evaluation prediction object.

        ``eval_pred`` must expose ``predictions`` and ``label_ids``; the
        predicted class is the argmax over the last axis of ``predictions``.

        NOTE(review): ``Trainer.__init__`` assigns its ``compute_metrics``
        argument to ``self.compute_metrics``, which presumably shadows this
        method unless it is passed explicitly - confirm at the call site.
        """
        print(f'eval_pred: {eval_pred}')
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
        preds = predictions.argmax(-1)
        # assumes array-like operands whose ``==`` is element-wise - TODO confirm
        accuracy = (preds == labels).mean()
        return {"accuracy": accuracy}


class LLAVA_Trainer_NO_GIST(Trainer):
    """Trainer that obtains the loss from a single ``cls=True`` forward pass
    and exposes the model's ``cls_logits`` as ``logits``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Args:
            model: Callable model accepting the ``cls`` keyword and exposing a
                ``training`` flag; its output must be indexable by ``'loss'``
                and ``'cls_logits'``.
            inputs: Batch mapping. ``"labels"`` is popped from it, so the
                mapping is mutated in place; the remaining entries are
                forwarded to the model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Unused by this implementation.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true,
            where ``outputs`` is a dict with exactly the keys ``'loss'`` and
            ``'logits'``.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        # Both branches of the former label_smoother check popped the labels,
        # so the pop is unconditional.
        labels = inputs.pop("labels")

        # This pass produces the loss itself. Running it under an unconditional
        # no_grad left the loss without a grad_fn, so it could not be
        # back-propagated. Gradients are now disabled only when the model is
        # not in training mode or the caller has already turned them off.
        track_grad = torch.is_grad_enabled() and model.training
        with torch.set_grad_enabled(track_grad):
            model_outputs = model(**inputs, labels=labels, cls=True)

        loss = model_outputs['loss']
        # Repack so that only the loss and the classification logits are returned.
        outputs = {'loss': loss, 'logits': model_outputs['cls_logits']}
        return (loss, outputs) if return_outputs else loss

    def compute_metrics(self, eval_pred):
        """Return ``{"accuracy": ...}`` for an evaluation prediction object.

        ``eval_pred`` must expose ``predictions`` and ``label_ids``; the
        predicted class is the argmax over the last axis of ``predictions``.

        NOTE(review): ``Trainer.__init__`` assigns its ``compute_metrics``
        argument to ``self.compute_metrics``, which presumably shadows this
        method unless it is passed explicitly - confirm at the call site.
        """
        print(f'eval_pred: {eval_pred}')
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
        preds = predictions.argmax(-1)
        # assumes array-like operands whose ``==`` is element-wise - TODO confirm
        accuracy = (preds == labels).mean()
        return {"accuracy": accuracy}


class QWen_Trainer(Trainer):
    """Trainer that builds a key/value cache with ``model.generate`` and then
    computes the loss from a ``cls=True`` forward pass over that cache.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Args:
            model: Model providing ``generate`` and a forward call accepting
                the ``cls`` keyword. The generate output must be indexable by
                ``'past_key_values'`` and the forward output by ``'loss'``.
            inputs: Batch mapping. ``"labels"`` is popped from it, so the
                mapping is mutated in place; the remaining entries are
                forwarded to ``model.generate``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Unused by this implementation.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        # Both branches of the former label_smoother check popped the labels,
        # so the pop is unconditional.
        labels = inputs.pop("labels")

        # The generation step only supplies the cache, so no graph is built.
        with torch.no_grad():
            generated = model.generate(**inputs,
                                       output_hidden_states=True, return_dict_in_generate=True,
                                       max_new_tokens=1)
            # Rebuild the cache as a list of lists, dropping the last entry
            # along dimension 2 of every tensor (presumably the sequence
            # axis - TODO confirm against the model's cache layout).
            past_key_values = [
                [cached[:, :, :-1, :] for cached in layer]
                for layer in generated['past_key_values']
            ]

        outputs = model(past_key_values=past_key_values, labels=labels, cls=True)
        loss = outputs['loss']
        return (loss, outputs) if return_outputs else loss

    def compute_metrics(self, eval_pred):
        """Return ``{"accuracy": ...}`` for an evaluation prediction object.

        ``eval_pred`` must expose ``predictions`` and ``label_ids``; the
        predicted class is the argmax over the last axis of ``predictions``.

        NOTE(review): ``Trainer.__init__`` assigns its ``compute_metrics``
        argument to ``self.compute_metrics``, which presumably shadows this
        method unless it is passed explicitly - confirm at the call site.
        """
        print(f'eval_pred: {eval_pred}')
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
        preds = predictions.argmax(-1)
        # assumes array-like operands whose ``==`` is element-wise - TODO confirm
        accuracy = (preds == labels).mean()
        return {"accuracy": accuracy}
