import torch
from transformers import DefaultDataCollator
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
import torch
from utils import seeded_rand, seeded_randint, get_batch_seed

@dataclass
class CollatorConfig:
    """
    Configuration for the data collator.

    Raises:
        ValueError: If ``softmasking_prob`` is outside [0, 1], or if
            ``0 <= min_prob <= max_prob <= 1`` does not hold.
    """
    # Explicit mask token id; None means "not overridden here".
    mask_token_id: Optional[int] = None
    # Probability knob for soft-masking.
    softmasking_prob: float = 0.5
    # Lower / upper bound of the masking-rate range.
    min_prob: float = 0.2
    max_prob: float = 0.8

    def __post_init__(self):
        # Fail at construction time rather than silently producing a
        # degenerate (inverted or out-of-range) masking-rate interval later.
        if not 0.0 <= self.softmasking_prob <= 1.0:
            raise ValueError(
                f"softmasking_prob must be in [0, 1], got {self.softmasking_prob!r}"
            )
        if not 0.0 <= self.min_prob <= self.max_prob <= 1.0:
            raise ValueError(
                "expected 0 <= min_prob <= max_prob <= 1, got "
                f"min_prob={self.min_prob!r}, max_prob={self.max_prob!r}"
            )

class dLLMDataCollator(DefaultDataCollator):
    """
    Batch-time padding and block-level noising/masking.

    Each call pads every ``input_ids`` to a common, randomly extended length,
    collates the features with ``DefaultDataCollator``, and then replaces all
    tokens of randomly selected blocks with ``mask_token_id``. Labels hold the
    original token id at masked positions and -100 everywhere else.

    NOTE(review): every random draw made during one call (pad length, ``t``,
    per-sample block flags, fallback block choice) is given the same
    ``batch_seed``. If ``seeded_rand`` / ``seeded_randint`` build a fresh
    generator from the seed on each call, these draws are correlated and
    samples with the same number of blocks get identical flags -- confirm
    against ``utils``.
    """

    # Upper bound handed to ``seeded_randint`` when drawing the extra padding.
    MAX_EXTRA_PAD = 50
    # The padded length is rounded up to a multiple of this value.
    PAD_MULTIPLE = 8

    def __init__(self, tokenizer, cfg: CollatorConfig):
        """
        Args:
            tokenizer: Supplies ``pad_token_id`` and, when ``cfg.mask_token_id``
                is None, ``mask_token_id``.
            cfg: Collator settings; ``cfg.mask_token_id`` takes precedence over
                the tokenizer's mask token id.

        Raises:
            ValueError: If neither ``cfg`` nor the tokenizer provides a mask
                token id.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.mask_token_id = tokenizer.mask_token_id if cfg.mask_token_id is None else cfg.mask_token_id
        if self.mask_token_id is None:
            raise ValueError(
                "mask_token_id must be set "
                "(cfg.mask_token_id and tokenizer.mask_token_id are both None)."
            )
        # Stored for consumers of the collator; not read by any method below.
        self.softmasking_prob = cfg.softmasking_prob
        self.min_prob = cfg.min_prob
        self.max_prob = cfg.max_prob

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate and process a batch of features.

        Args:
            features: One mapping per sample. Each must contain ``input_ids``
                (a 1-D tensor); ``block_ranges`` is required by the masking
                step and ``t`` is optional.

        Returns:
            The collated batch with ``input_ids`` replaced by the noised ids
            and ``labels``, ``block_ranges``, ``block_mask_indices`` (a list
            with one bool tensor per sample) and ``t`` added.

        Raises:
            ValueError: If the tokenizer has no ``pad_token_id`` or the
                collated batch has no ``block_ranges``.
        """
        batch_seed = get_batch_seed(features)

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError("tokenizer.pad_token_id must be set to pad input_ids.")

        pad_length = self.get_pad_length_for_batch(features, batch_seed)
        # Pad shallow copies so the caller's feature dicts are left untouched;
        # overwriting them would make a later collation of the same objects
        # start from already-padded tensors.
        padded_features = [
            {**f, "input_ids": self.prepad_input_ids(f["input_ids"], pad_length, pad_token_id)}
            for f in features
        ]

        batch = super().__call__(padded_features)

        noisy_batch, labels, block_ranges, block_mask_indices, t = self.forward_process(batch, batch_seed)
        labels, noisy_batch, block_mask_indices = self.add_at_least_one_mask(
            labels, noisy_batch, block_ranges, block_mask_indices, batch_seed
        )

        batch.update({
            "input_ids": noisy_batch.long(),
            "labels": labels.long(),
            "block_ranges": block_ranges,
            "block_mask_indices": block_mask_indices,      # optional, handy for debugging
            "t": t,
        })

        return batch

    # -- HELPERS ------------------------------------------------------------------------------

    # -- PREPARE INPUT IDS --------------------------------------------------------------------
    def get_pad_length_for_batch(self, features, batch_seed=None):
        """
        Determine the padding length for the batch.

        Takes the longest ``input_ids`` in ``features``, adds a seeded random
        extra length drawn with ``seeded_randint(MAX_EXTRA_PAD, ...)`` and
        rounds the result up to a multiple of ``PAD_MULTIPLE``.

        NOTE(review): whether the upper bound of the extra length is inclusive
        depends on ``seeded_randint`` -- confirm against ``utils``.
        """
        max_input_length_in_batch = max(len(f["input_ids"]) for f in features)

        extra_len = seeded_randint(
            self.MAX_EXTRA_PAD,
            size=(1,),
            device=features[0]["input_ids"].device,
            seed=batch_seed,
        )
        max_len = max_input_length_in_batch + int(extra_len.item())

        # Round up to the next multiple (no-op when already aligned).
        max_len += (-max_len) % self.PAD_MULTIPLE

        return max_len

    def prepad_input_ids(self, input_ids, pad_length, pad_token_id):
        """
        Right-pad a 1-D ``input_ids`` tensor to ``pad_length`` with ``pad_token_id``.

        Returns a new tensor with the dtype and device of ``input_ids``; the
        input tensor is not modified.
        """
        padded = torch.full((pad_length,), pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
        padded[:input_ids.size(0)] = input_ids

        return padded

    # -- NOISE DATA --------------------------------------------------------------------

    def forward_process(self, batch, batch_seed=None, eps=1e-3):
        """
        Apply the forward noising/masking process to input_ids at the BLOCK level.

        For every sample, each block is selected when a seeded uniform draw is
        below ``p_mask``; all tokens in ``[start, end)`` of a selected block
        are masked.

        Args:
            batch: Collated batch with ``input_ids`` of shape (B, N) and
                ``block_ranges`` indexable per sample, each entry a tensor of
                ``(start, end)`` rows.
            batch_seed: Seed forwarded to the seeded random helpers.
            eps: Lower clamp for ``t`` and floor of ``p_mask``.

        Returns:
            noisy_batch, labels, block_ranges, block_mask_indices, t

        Raises:
            ValueError: If ``batch`` has no ``block_ranges``.
        """
        input_ids = batch["input_ids"]
        if "block_ranges" not in batch:
            raise ValueError("batch must contain 'block_ranges' for block-level masking.")
        block_ranges = batch["block_ranges"]
        device = input_ids.device
        B, N = input_ids.shape

        p_mask, t = self.get_p_mask(batch, device, eps, batch_seed=batch_seed)

        # token-level mask
        mask_indices = torch.zeros((B, N), dtype=torch.bool, device=device)
        # block-level mask (list of tensors, since blocks vary per example)
        block_mask_indices = []

        for row in range(B):
            current_blocks = block_ranges[row]
            if not current_blocks.numel():
                block_mask_indices.append(torch.empty(0, dtype=torch.bool, device=device))
                continue

            # randomly decide which blocks to mask
            block_flags = seeded_rand((len(current_blocks),), device, seed=batch_seed) < p_mask
            block_mask_indices.append(block_flags)

            # mask all tokens in chosen blocks; tolist() yields plain Python
            # bools/ints so the loop does not handle 0-dim tensors one by one
            for is_masked, (start, end) in zip(block_flags.tolist(), current_blocks.tolist()):
                if is_masked:
                    mask_indices[row, start:end] = True

        # create noisy + label tensors (masked_fill is out-of-place)
        noisy_batch = input_ids.masked_fill(mask_indices, int(self.mask_token_id))
        labels = input_ids.masked_fill(~mask_indices, -100)

        return noisy_batch, labels, block_ranges, block_mask_indices, t

    def add_at_least_one_mask(self, labels, input_ids, block_ranges, block_mask_indices, batch_seed=None):
        """
        Ensure that each sample has at least one *block* masked.

        If no block was selected for a sample that has blocks, one block is
        chosen with ``seeded_randint`` and masked fully. Samples without any
        blocks are left untouched.

        ``labels``, ``input_ids`` and the per-sample flag tensors in
        ``block_mask_indices`` are modified in place and also returned.
        """
        B, _ = labels.shape
        device = labels.device

        for row in range(B):
            sample_blocks = block_ranges[row]
            block_flags = block_mask_indices[row]

            # skip if this sample already has masked blocks
            if len(block_flags) > 0 and block_flags.any():
                continue
            if not sample_blocks.numel():
                continue

            # randomly select one block
            # (presumably an index in [0, len(sample_blocks)) -- depends on seeded_randint)
            chosen = int(seeded_randint(len(sample_blocks), size=(1,), device=device, seed=batch_seed).item())
            start, end = sample_blocks[chosen].tolist()

            # mask that block: expose the targets in labels first, then
            # overwrite the inputs
            labels[row, start:end] = input_ids[row, start:end]
            input_ids[row, start:end] = self.mask_token_id
            block_flags[chosen] = True

        return labels, input_ids, block_mask_indices


    def get_p_mask(self, batch, device, eps=1e-3, batch_seed=None):
        """
        Get the masking probability ``p_mask`` and the noise level ``t``.

        ``t`` is taken from ``batch["t"]`` when present (averaged to a scalar
        if it has any dimensions); otherwise it is drawn as
        ``min_prob + (max_prob - min_prob) * u`` with ``u`` from
        ``seeded_rand``. In both cases ``t`` is clamped to at least ``eps``.

        Returns:
            (p_mask, t) where ``p_mask = (1 - eps) * t + eps``.
        """
        if "t" in batch:
            t = batch["t"].float()
            if t.ndim > 0:
                t = t.mean()
            t = t.clamp_min(eps)
        else:
            # Affine map of the seeded draw onto the configured
            # [min_prob, max_prob] range.
            span = self.max_prob - self.min_prob
            t = (span * seeded_rand((), device, seed=batch_seed) + self.min_prob).clamp_min(eps)

        # Linear in t; equals eps at t=0 and 1 at t=1.
        p_mask = (1 - eps) * t + eps
        return p_mask, t