import torch
import numpy as np
from sklearn.metrics import accuracy_score
from transformers import TrainerCallback, TrainerControl, TrainerState
from transformers import TrainingArguments
from .my_trainer import MyTrainer, MyTrainingArguments
from transformers.utils import logging
from typing import Dict
from transformers import get_scheduler
import math

logger = logging.get_logger(__name__)


def preprocess_logits_for_metrics(outputs, labels):
    """Shrink raw model outputs down to what ``compute_metrics`` consumes.

    The stock Trainer may hold on to more memory than needed when it keeps the
    full model outputs around during evaluation. As a workaround, only the
    argmax class ids and the regression logits are returned.

    Args:
        outputs: model output tuple, expected to be laid out as
            ``(loss_clf, loss_rg, logits, logits_for_regression, encoder_outputs)``.
        labels: not used here; expected to be
            ``(labels_for_classification, labels_for_regression)``.

    Returns:
        ``(predicted_ids, logits_for_regression)``. ``predicted_ids`` is the
        argmax of the classification logits over the last dimension.
        ``logits_for_regression`` is passed through unchanged.
    """
    classification_logits = outputs[2]
    regression_logits = outputs[3]
    return classification_logits.argmax(dim=-1), regression_logits


def compute_metrics(eval_preds, ignore_index=-100):
    """Compute the classification error rate and optional per-target regression errors.

    Everything arrives as numpy arrays.

    Args:
        eval_preds: ``((predicted_ids, logits_for_regression), labels)``.
            ``labels`` is either the classification label array, or a tuple
            whose first element holds the classification labels and whose last
            element holds the regression labels.
        ignore_index: label value marking positions that are excluded from
            every metric.

    Returns:
        dict containing:

        * ``"error rate"``: ``1 - accuracy`` over non-ignored positions.
        * ``"MSE{i}"`` / ``"MAE{i}"``: only when regression labels are
          present. One pair per index of the regression arrays' trailing axis,
          numbered from 1.
    """
    (predicted_ids, logits_for_regression), labels = eval_preds

    # Explicit sentinel instead of probing locals() for the name.
    labels_for_regression = None
    if isinstance(labels, tuple):
        labels, labels_for_regression = labels[0], labels[-1]

    # Boolean masking flattens to 1-D, so accuracy is simply the fraction of
    # matching ids among the non-ignored positions.
    keep = labels != ignore_index
    accuracy = np.mean(predicted_ids[keep] == labels[keep])
    metrics = {"error rate": 1 - accuracy}

    if labels_for_regression is not None:
        # np.where builds a new floating-point array with ignored positions set
        # to NaN. It never mutates the caller's labels, and it also works for
        # integer-typed labels, where in-place NaN assignment would raise.
        targets = np.where(labels_for_regression == ignore_index, np.nan, labels_for_regression)
        diff = logits_for_regression - targets
        # Reduce over the first two axes (presumably batch and sequence —
        # confirm against the model), leaving one value per trailing index.
        # nanmean skips the NaN (ignored) positions.
        mses = np.nanmean(diff ** 2, axis=(0, 1))
        maes = np.nanmean(np.abs(diff), axis=(0, 1))
        for i, (mse, mae) in enumerate(zip(mses, maes), start=1):
            metrics[f"MSE{i}"] = mse
            metrics[f"MAE{i}"] = mae

    return metrics


# def nested_detach(tensors):
#     if tensors is None: return None # EDIT: added this line

#     "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
#     if isinstance(tensors, (list, tuple)):
#         return type(tensors)(nested_detach(t) for t in tensors)
#     elif isinstance(tensors, Mapping):
#         return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
#     return tensors.detach()


class LimitStepsCallback(TrainerCallback):
    """Cap how many steps run in each epoch.

    The callback counts ``on_step_end`` calls since the last ``on_epoch_begin``.
    Once the count reaches the cap, it sets ``control.should_epoch_stop``,
    ``control.should_save`` and ``control.should_evaluate`` to ``True``.

    Args:
        max_steps_per_epoch: per-epoch step budget. A non-positive value
            disables the cap; it is stored as ``float("inf")``.
    """

    def __init__(self, max_steps_per_epoch: int):
        if max_steps_per_epoch > 0:
            limit = max_steps_per_epoch
        else:
            # Non-positive budget means "no limit".
            limit = float("inf")
        self.max_steps_per_epoch = limit
        self.current_epoch_steps = 0

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Start counting afresh for the new epoch."""
        self.current_epoch_steps = 0

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Count one step; once the budget is used up, request stop/save/evaluate."""
        self.current_epoch_steps += 1
        if self.current_epoch_steps < self.max_steps_per_epoch:
            return
        # Budget exhausted: end this epoch early and ask for a checkpoint and an eval.
        control.should_epoch_stop = True
        control.should_save = True
        control.should_evaluate = True
