import os
import json
import torch
from transformers import Trainer, TrainingArguments
from transformers.utils import logging
from typing import Dict, Union, Any, Optional, List
from dataclasses import dataclass, field
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Module-level logger obtained from transformers.utils.logging, used by the trainer below.
logger = logging.get_logger(__name__)


@dataclass
class PermutationLogArgs:
    """
    Placeholder argument group for permutation loss logging.

    Kept as its own (currently field-less) dataclass so that HfArgumentParser can parse it
    alongside TrainingArguments and any other custom argument groups. The former
    ``permutation_log_dir`` and ``dataset_name_for_log`` fields were dropped because logging
    destinations and run names are now handled by wandb.
    """


@dataclass
class PermutationLossLoggingTrainingArguments(TrainingArguments):
    """
    TrainingArguments variant used together with PermutationLossLoggingTrainer.

    Overrides the default of ``dataloader_pin_memory`` to ``False``. The ``@dataclass``
    decorator is what makes this override effective: without it, the ``field(...)`` below is
    just a plain class attribute that the inherited dataclass ``__init__`` shadows with the
    parent's default, and HfArgumentParser (which reads ``dataclasses.fields``) never sees it.

    The former permutation-specific log directory / dataset-name options were removed because
    logging is handled by wandb.
    """

    dataloader_pin_memory: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to pin memory in DataLoader."},
    )


class PermutationLossLoggingTrainer(Trainer):
    """
    Custom Trainer that logs loss for each permutation index separately using Weights & Biases.
    Includes a method to evaluate loss on all permutations at the end of training.

    Per-permutation training losses are emitted through ``Trainer.log`` only when ``"wandb"``
    is in ``args.report_to``, on the main process, and only for batches processed while the
    model is in training mode.

    The data collator is expected to add a ``permutation_idx`` entry to each batch. For
    :meth:`evaluate_all_permutations` it must also expose ``permutations_list`` and
    ``fixed_permutation_index``; ``per_sample_permutation`` is toggled if present.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-permutation losses are only emitted when wandb reporting is enabled; warn early on
        # the main process if no wandb run is active at construction time.
        if "wandb" in self.args.report_to and self.is_world_process_zero():
            import wandb

            if wandb.run is None:
                # Relying on the Trainer's native wandb integration (or an explicit wandb.init()
                # in the calling script) is preferred over initializing a run here.
                logger.warning(
                    "Warning: wandb.run is None in PermutationLossLoggingTrainer init. "
                    "Ensure wandb is initialized correctly in the main script "
                    "and 'wandb' is in TrainingArguments.report_to."
                )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the loss with the standard ``Trainer.compute_loss`` and log it per permutation.

        Args:
            model: The (possibly wrapped) model being trained or evaluated.
            inputs: Batch dict produced by the data collator. A ``permutation_idx`` entry, if
                present, is read but left in place, so it is forwarded to the model together
                with the other inputs.
            return_outputs: If True, return ``(loss, outputs)`` instead of only the loss.
            num_items_in_batch: Forwarded to ``Trainer.compute_loss`` when provided, so the
                Trainer's gradient-accumulation loss normalization stays consistent with
                ``training_step``.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # Read (not pop) the collator-provided index; the model still receives it as an input.
        permutation_idx = inputs.get("permutation_idx", None)

        # Delegating to the base implementation keeps label smoothing (including its model
        # unwrapping and causal-LM label shift), compute_loss_func, past_index handling and
        # num_items_in_batch normalization in sync with the installed Trainer version.
        # Older Trainer versions neither pass nor accept num_items_in_batch, so it is only
        # forwarded when the caller actually supplied it.
        loss_kwargs = {} if num_items_in_batch is None else {"num_items_in_batch": num_items_in_batch}
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **loss_kwargs)

        if (
            permutation_idx is not None
            and self.is_in_train
            # is_in_train stays True during evaluations run inside train(); skip those batches
            # so eval losses do not leak into the training per-permutation curves.
            and model.training
            and "wandb" in self.args.report_to
            and self.is_world_process_zero()
        ):
            self._log_permutation_loss(permutation_idx, loss)

        return (loss, outputs) if return_outputs else loss

    def _log_permutation_loss(self, permutation_idx, loss):
        """Log ``loss`` under ``loss_perm_*`` key(s) for the permutation(s) present in a batch."""
        # Under nn.DataParallel the loss holds one value per device (Trainer averages it later
        # in training_step); average here too so .item() works on a single scalar.
        loss_value = loss.detach().mean().item()

        if torch.is_tensor(permutation_idx):
            # reshape(-1) also handles a 0-dim tensor (one index for the whole batch).
            unique_perms = torch.unique(permutation_idx.detach().reshape(-1))
            if unique_perms.numel() == 0:
                return
            if unique_perms.numel() == 1:
                log_data = {f"loss_perm_{unique_perms[0].item()}": loss_value}
            else:
                # Mixed-permutation batch: only the batch-level loss is available here, so the
                # same value is recorded under every permutation present in the batch.
                log_data = {f"loss_perm_{perm.item()}_in_batch": loss_value for perm in unique_perms}
        else:
            # Fallback for a non-tensor permutation_idx (e.g. a plain int).
            log_data = {f"loss_perm_{permutation_idx}": loss_value}

        self.log(log_data)

    # training_step can remain as default, as compute_loss is the core part for loss handling.
    # If more complex per-permutation logic is needed at each step (e.g. optimizer changes),
    # then training_step might need overriding.

    def evaluate_all_permutations(
        self,
        eval_dataset: Optional[Dataset] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[int, float]:
        """
        Evaluate the model once per permutation defined in the data collator.

        For each index of ``data_collator.permutations_list`` the collator is pinned to that
        permutation (``fixed_permutation_index``) with ``per_sample_permutation`` disabled, and
        ``self.evaluate`` is run. The collator's original settings are restored afterwards, even
        if an evaluation raises.

        Args:
            eval_dataset: Dataset to evaluate on; ``None`` lets ``self.evaluate`` fall back to
                the Trainer's own ``eval_dataset``.
            metric_key_prefix: Prefix passed to ``self.evaluate``; the loss is read from the
                ``f"{metric_key_prefix}_loss"`` metric.

        Returns:
            Mapping from permutation index to its evaluation loss (``nan`` when the loss key is
            missing from the returned metrics). Empty when the collator is not configured for
            permutations or defines none.
        """
        collator = self.data_collator
        if (
            collator is None
            or not hasattr(collator, "permutations_list")
            or not hasattr(collator, "fixed_permutation_index")
        ):
            logger.error("Data collator is not configured for evaluating all permutations. Skipping.")
            return {}

        num_permutations = len(collator.permutations_list)
        if num_permutations == 0:
            logger.warning("No permutations found in the data collator. Cannot evaluate all permutations.")
            return {}

        original_fixed_perm_index = collator.fixed_permutation_index
        original_per_sample_permutation = getattr(collator, "per_sample_permutation", False)
        loss_key = f"{metric_key_prefix}_loss"
        all_permutation_losses: Dict[int, float] = {}

        # NOTE(review): the collator is mutated in this process only. If evaluation DataLoader
        # workers are persistent (dataloader_persistent_workers), they may keep a stale copy of
        # the collator — confirm that option is off when using this method.
        # Disable per_sample_permutation so every batch uses the pinned permutation.
        collator.per_sample_permutation = False
        logger.info(f"Starting evaluation for all {num_permutations} permutations.")
        try:
            for perm_idx in tqdm(range(num_permutations), desc="Evaluating Permutations"):
                collator.fixed_permutation_index = perm_idx
                logger.info(f"Evaluating with permutation index: {perm_idx}")

                # Every permutation is evaluated with the same metric_key_prefix, so each run
                # reports under the same metric names (e.g. "eval_loss").
                metrics = self.evaluate(
                    eval_dataset=eval_dataset,
                    metric_key_prefix=metric_key_prefix,
                )
                if loss_key in metrics:
                    all_permutation_losses[perm_idx] = metrics[loss_key]
                    logger.info(f"Loss for permutation {perm_idx}: {metrics[loss_key]}")
                else:
                    logger.warning(
                        f"Could not find loss key '{loss_key}' in evaluation metrics for permutation {perm_idx}. Metrics: {metrics}"
                    )
                    all_permutation_losses[perm_idx] = float("nan")
        finally:
            # Restore original settings in collator, even if an evaluation failed.
            collator.fixed_permutation_index = original_fixed_perm_index
            collator.per_sample_permutation = original_per_sample_permutation

        logger.info("Finished evaluating all permutations.")
        return all_permutation_losses
