from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from transformers import PreTrainedTokenizerBase


def per_token_loss(model_outputs, batch_input):
    """Compute the un-reduced next-token cross-entropy for every position.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``, so every returned tensor has one column fewer than the inputs.

    Args:
        model_outputs: Object exposing a ``logits`` tensor indexed as
            ``[batch, position, vocab]``.
        batch_input: Mapping holding ``"attention_mask"`` and ``"labels"``
            tensors indexed as ``[batch, position]``.

    Returns:
        A tuple ``(loss, label_pad_mask, size_per_sample)``:

        * ``loss`` -- per-position cross-entropy reshaped to one row per
          example. It is NOT multiplied by the mask; positions whose label is
          the default ``ignore_index`` of the cross-entropy (-100) are 0.
        * ``label_pad_mask`` -- boolean tensor, ``True`` where the shifted
          attention mask is non-zero.
        * ``size_per_sample`` -- number of ``True`` mask entries per example.
    """
    attention = batch_input["attention_mask"]
    logits = model_outputs.logits

    # Drop the last prediction and the first label/mask column so that
    # position t lines up with the token it is supposed to predict (t + 1).
    predictions = logits[:, :-1, :]
    targets = batch_input["labels"][:, 1:]
    label_pad_mask = attention[:, 1:].ne(0)

    n_rows = predictions.size(0)
    vocab_size = predictions.size(-1)
    flat_loss = torch.nn.functional.cross_entropy(
        predictions.reshape(-1, vocab_size),
        targets.reshape(-1),
        reduction="none",
    )

    loss = flat_loss.view(n_rows, -1)
    size_per_sample = label_pad_mask.sum(dim=1)
    return (loss, label_pad_mask, size_per_sample)


def per_example_loss(model_outputs, batch_input):
    """Average the next-token loss over the attended tokens of each example.

    Args:
        model_outputs: Forwarded unchanged to ``per_token_loss``.
        batch_input: Forwarded unchanged to ``per_token_loss``.

    Returns:
        A tuple ``(loss, label_pad_mask)`` where ``loss`` holds one value per
        example and ``label_pad_mask`` is the boolean mask produced by
        ``per_token_loss``. An example without any attended target position
        gets a loss of 0 instead of NaN.
    """
    token_loss, label_pad_mask, size_per_sample = per_token_loss(
        model_outputs, batch_input
    )
    # The denominator only counts attended positions, so padded positions must
    # not leak into the numerator either. ``per_token_loss`` returns the raw
    # loss, which is non-zero at padded positions whose label is a real token
    # id rather than -100.
    attended_loss = token_loss.masked_fill(~label_pad_mask, 0.0)
    # Guard against 0 / 0 when an example has no attended target position
    # (e.g. a single-token sequence that was right-padded).
    token_counts = size_per_sample.clamp(min=1)
    loss = attended_loss.sum(dim=1) / token_counts
    return loss, label_pad_mask


@dataclass
class LeftPaddingCompatibleDataCollatorForLM:
    """Collate language-modelling examples into one left-padded batch.

    Every example is padded on the left up to the longest ``input_ids`` in
    the batch. Padding positions get the tokenizer's pad id in ``input_ids``,
    ``-100`` in ``labels`` and ``0`` in ``attention_mask``.

    The pad id is ``tokenizer.pad_token_id``; when that is ``None`` the
    collator falls back to ``tokenizer.eos_token_id``.
    """

    tokenizer: PreTrainedTokenizerBase

    @staticmethod
    def _as_list(values) -> List[int]:
        """Return *values* as a plain Python list (tensors via ``tolist``)."""
        if isinstance(values, torch.Tensor):
            return values.tolist()
        return list(values)

    def _resolve_pad_id(self):
        """Return the pad id, falling back to the eos id; may be ``None``."""
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        return pad_id

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad *examples* on the left and stack them into ``torch.long`` tensors.

        Args:
            examples: Items with an ``"input_ids"`` sequence (list, tensor or
                other iterable) and optionally a ``"labels"`` sequence. When
                ``"labels"`` is missing or ``None`` the input ids are used as
                labels.

        Returns:
            Dict with ``"input_ids"``, ``"labels"`` and ``"attention_mask"``.

        Raises:
            ValueError: If *examples* is empty, or if padding is required but
                the tokenizer has neither a pad nor an eos token id.
        """
        if not examples:
            raise ValueError("Cannot collate an empty list of examples.")

        # Convert once up front so tensors and one-shot iterables are handled
        # and the lengths used for padding match the data actually emitted.
        sequences = [self._as_list(example["input_ids"]) for example in examples]
        max_length = max(len(ids) for ids in sequences)

        pad_id = self._resolve_pad_id()
        needs_padding = any(len(ids) < max_length for ids in sequences)
        # Only complain when a pad id is really needed, so equal-length
        # batches keep working with tokenizers that define no pad token.
        if needs_padding and pad_id is None:
            raise ValueError("Tokenizer has no pad_token_id and no eos_token_id.")

        padded_input_ids = []
        labels = []
        attention_masks = []
        for example, input_ids in zip(examples, sequences):
            padding_length = max_length - len(input_ids)

            # Use existing labels if present, else reuse the input ids.
            if "labels" in example and example["labels"] is not None:
                example_labels = self._as_list(example["labels"])
            else:
                example_labels = input_ids

            padded_input_ids.append([pad_id] * padding_length + input_ids)
            # -100 keeps padded positions out of the cross-entropy loss.
            labels.append([-100] * padding_length + example_labels)
            attention_masks.append([0] * padding_length + [1] * len(input_ids))

        return {
            "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(attention_masks, dtype=torch.long),
        }


@dataclass
class DataCollatorForLM:
    """Collate language-modelling examples, padding on the tokenizer's side.

    The padding side is read from ``tokenizer.padding_side`` (``"right"`` when
    the attribute is absent). Padding positions get the pad id in
    ``input_ids``, ``-100`` in ``labels`` and ``0`` in ``attention_mask``.
    The pad id is ``tokenizer.pad_token_id``, falling back to
    ``tokenizer.eos_token_id`` when no pad token is defined.
    """

    tokenizer: PreTrainedTokenizerBase

    @staticmethod
    def _to_list(values) -> List[int]:
        """Return *values* as a plain Python list (tensors via ``tolist``)."""
        if isinstance(values, torch.Tensor):
            return values.tolist()
        return list(values)

    @staticmethod
    def _pad(values: List[int], fill: int, pad_len: int, pad_left: bool) -> List[int]:
        """Return *values* extended by ``pad_len`` copies of *fill* on one side."""
        filler = [fill] * pad_len
        return filler + values if pad_left else values + filler

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad *examples* to a common length and stack them into tensors.

        Args:
            examples: Items with an ``"input_ids"`` sequence (list, tensor or
                any other iterable), optionally ``"labels"`` (input ids are
                reused when missing or ``None``) and optionally a scalar
                ``"weight"`` (defaults to ``1.0``).

        Returns:
            Dict with ``torch.long`` tensors ``"input_ids"``, ``"labels"`` and
            ``"attention_mask"`` plus a ``torch.float`` tensor ``"weight"``
            holding one entry per example.

        Raises:
            ValueError: If the tokenizer has neither a pad nor an eos token
                id, or if *examples* is empty.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # Fall back to eos_token_id if available, else raise.
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
            if pad_id is None:
                raise ValueError("Tokenizer has no pad_token_id and no eos_token_id.")

        if not examples:
            raise ValueError("Cannot collate an empty list of examples.")

        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        # Convert every example exactly once. Converting separately for the
        # length pass and the padding pass would exhaust one-shot iterables
        # (generators, ``iter``/``map`` objects) and do the tensor-to-list
        # work twice.
        sequences = [self._to_list(example["input_ids"]) for example in examples]
        max_len = max(len(ids) for ids in sequences)

        padded_input_ids: List[List[int]] = []
        padded_labels: List[List[int]] = []
        padded_attn: List[List[int]] = []

        for example, input_ids in zip(examples, sequences):
            seq_len = len(input_ids)
            pad_len = max_len - seq_len

            if "labels" in example and example["labels"] is not None:
                labels = self._to_list(example["labels"])
            else:
                labels = input_ids  # teacher forcing

            padded_input_ids.append(self._pad(input_ids, pad_id, pad_len, pad_left))
            # -100 keeps padded positions out of the cross-entropy loss.
            padded_labels.append(self._pad(labels, -100, pad_len, pad_left))
            padded_attn.append(self._pad([1] * seq_len, 0, pad_len, pad_left))

        return {
            "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
            "labels": torch.tensor(padded_labels, dtype=torch.long),
            "attention_mask": torch.tensor(padded_attn, dtype=torch.long),
            "weight": torch.tensor(
                [example.get("weight", 1.0) for example in examples],
                dtype=torch.float,
            ),
        }
