import os
import torch
from transformers import Trainer
from criterion import CausalCriterion, make_compute_metrics
from typing import Tuple, Dict, Optional, List
from models.utils import CausalLMOutputWithPastExtended
import numpy as np
from models.processor_wrapper import Phi4CausalQAProcessor
import json
from utils.tasks import CausalQATask

class SampleTrainer(Trainer):
    """``Trainer`` subclass whose loss is produced by an external criterion.

    The criterion is called as ``criterion(outputs, inputs, global_step=...)``
    and must return ``(loss, loss_dict)``. Every entry of ``loss_dict`` is
    logged on each call to ``compute_loss``; if it contains a tensor under
    ``"h_dag"``, its mean is cached on ``self._last_notears_h``.
    """

    # Keyword arguments that ``compute_loss`` always passes to the model
    # explicitly; they must never also arrive through the filtered batch.
    _EXPLICIT_FORWARD_KWARGS = ("labels", "return_dict", "use_cache")

    def __init__(self, *args,
                 criterion: CausalCriterion,
                 _last_notears_h: Optional[float] = None,
                 **kwargs):
        """Initialise the underlying ``Trainer`` and remember the criterion.

        Args:
            *args: Positional arguments forwarded to ``Trainer.__init__``.
            criterion: Callable used by ``compute_loss`` to turn model outputs
                and the batch into ``(loss, loss_dict)``.
            _last_notears_h: Initial value for the cached ``h_dag`` scalar.
            **kwargs: Keyword arguments forwarded to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Both values are read later by ``compute_loss``; without storing
        # them here the first training step fails with AttributeError.
        self.criterion = criterion
        self._last_notears_h = _last_notears_h

    @staticmethod
    def _to_loggable(value):
        """Convert a tensor to a Python scalar; return anything else as-is."""
        if not isinstance(value, torch.Tensor):
            return value
        if value.numel() == 1:
            return value.item()
        # ``.item()`` raises on multi-element tensors, so reduce them first.
        return value.detach().float().mean().item()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch: Optional[int] = None):
        """Run the model on a batch and compute the loss with ``self.criterion``.

        Args:
            model: The model to call.
            inputs: Mapping holding the batch; must contain ``"labels"``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only ``loss``.
            num_items_in_batch: Accepted for signature compatibility; not used
                by this implementation.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        labels = inputs["labels"]

        # ``keep_only_input_keys`` builds a fresh dict, so ``inputs`` itself is
        # left untouched and can still be handed to the criterion in full.
        model_inputs = self.keep_only_input_keys(inputs)
        # These are passed explicitly below; leaving them in the dict would
        # raise "got multiple values for keyword argument".
        for reserved in self._EXPLICIT_FORWARD_KWARGS:
            model_inputs.pop(reserved, None)

        outputs = model(**model_inputs, labels=labels, return_dict=True, use_cache=False)

        # Pass global_step to criterion (for NOTEARS warmup)
        global_step = int(getattr(self.state, "global_step", 0))
        loss, loss_dict = self.criterion(outputs, inputs, global_step=global_step)

        # Store last h(P) for NOTEARS updates (one scalar per forward)
        self._last_notears_h = None
        h_dag = loss_dict.get("h_dag")
        if isinstance(h_dag, torch.Tensor):
            try:
                self._last_notears_h = float(h_dag.detach().float().mean().cpu().item())
            except Exception:
                # Best effort: a failure to read h must not abort the step.
                self._last_notears_h = None

        self.log({name: self._to_loggable(value) for name, value in loss_dict.items()})

        if return_outputs:
            return loss, outputs
        return loss

    def keep_only_input_keys(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        '''
        Keep only the keys in inputs that are actually used by the model's forward method.
        This avoids issues with unused keys during training/evaluation.

        If the model exposes a ``signature`` attribute it is used as the
        collection of accepted names; otherwise the parameter names of
        ``self.model.forward`` are used. Returns a new dict; ``inputs`` is
        not modified.
        '''
        if hasattr(self.model, 'signature'):
            accepted = self.model.signature
        else:
            import inspect

            forward = self.model.forward
            try:
                accepted = set(inspect.signature(forward).parameters)
            except (TypeError, ValueError):
                # Signature not introspectable: fall back to the code object,
                # restricted to real parameters (co_varnames also lists locals).
                code = forward.__code__
                accepted = set(code.co_varnames[:code.co_argcount + code.co_kwonlyargcount])

        return {key: value for key, value in inputs.items() if key in accepted}
    