import torch
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import re

def centered_mean_unfold(x: torch.Tensor, k: int, pad_mode='reflect'):
    """
    Sliding-window mean of width ``2*k + 1`` centred on each element of ``x``.

    ``x`` is padded with ``k`` values on each side of its last dimension
    (using ``pad_mode``), then windows are taken along dimension 0 with
    stride 1 and averaged.

    NOTE(review): written for a 1-D ``x`` -- in that case the result has the
    same length as ``x``. Confirm callers never pass higher-rank input.

    Returns: tensor of per-window means.
    """
    window_size = 2 * k + 1
    # F.pad's non-constant modes want batch/channel dims in front, so add two
    # singleton dims for the call and strip them again right after.
    batched = x[None, None, :]
    padded = F.pad(batched, (k, k), mode=pad_mode)[0, 0]
    sliding = padded.unfold(0, window_size, 1)
    return sliding.mean(dim=-1)

def invert_mask(mask, dtype):
    """
    Turn a keep-mask into an additive mask of the given floating dtype.

    Positions where ``mask`` is truthy become zero; positions where it is
    falsy become ``torch.finfo(dtype).min``.

    mask:  tensor of bool, or of 0/1 (any non-zero value counts as "keep").
    dtype: floating-point torch dtype of the result (``torch.finfo`` rejects
           non-floating dtypes).

    Returns: tensor with the same shape as ``mask`` and dtype ``dtype``.
    """
    # Coerce to bool first: `~` is a logical NOT only for bool tensors. On an
    # integer 0/1 mask it is a bitwise NOT (~1 == -2, ~0 == -1), which would
    # produce +inf / +max instead of 0 / min after the multiplication.
    keep = mask.to(torch.bool)
    return (~keep).to(dtype) * torch.finfo(dtype).min

@dataclass
class DraftParams:
    """
    Parameter bundle for drafting.

    After construction two values are derived:

    * ``max_sample_tokens`` is set to ``max_depth * topk_len + 1``.
    * ``max_verify_tokens`` is capped at ``max_sample_tokens``; when left as
      ``None`` it is set equal to ``max_sample_tokens``.

    ``max_sample_tokens`` is a plain instance attribute, not a dataclass
    field, so it is absent from the generated ``__init__``/``__repr__``/``__eq__``.
    """
    # NOTE(review): presumably a sampling temperature -- confirm with the consumer.
    temperature: float = 1
    max_depth: int = 8
    topk_len: int = 1
    # None means "no explicit cap"; resolved to max_sample_tokens in __post_init__.
    max_verify_tokens: Optional[int] = None
    # Stored as given and not interpreted in this class; fresh dict per instance.
    generator_kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Compute ``max_sample_tokens`` and resolve ``max_verify_tokens``."""
        sample_budget = self.max_depth * self.topk_len + 1
        self.max_sample_tokens = sample_budget

        requested = self.max_verify_tokens
        if requested is None:
            self.max_verify_tokens = sample_budget
        else:
            self.max_verify_tokens = min(sample_budget, requested)


def build_layer_device_map(model):
    """
    Derive a per-layer device lookup from ``model.hf_device_map``.

    Every ``hf_device_map`` key containing a ``layers.<N>`` path component
    (e.g. "model.layers.12" or "model.model.layers.12.mlp") contributes an
    entry ``N -> device``. The device value is copied unchanged from
    ``hf_device_map``. If several keys name the same layer index, the one
    that comes last in iteration order wins.

    Returns: {layer_idx: device}, e.g. {0: "cuda:0", 1: "cuda:0", ...}

    Raises: RuntimeError if the model has no ``hf_device_map`` attribute, or
    if none of its keys contains a layer index.
    """
    if not hasattr(model, "hf_device_map"):
        raise RuntimeError("Model has no hf_device_map. Did you load with device_map='auto'?")

    # "layers" must be a whole dotted component followed by a numeric component.
    layer_pattern = re.compile(r"(?:^|\.)(?:layers)\.(\d+)(?:\.|$)")

    candidates = (
        (layer_pattern.search(name), device)
        for name, device in model.hf_device_map.items()
    )
    layer_device_map = {
        int(hit.group(1)): device
        for hit, device in candidates
        if hit
    }

    if layer_device_map:
        return layer_device_map

    # Nothing matched: surface a sample of the keys so the pattern can be adjusted.
    raise RuntimeError(f"Could not infer layer_device_map. hf_device_map keys sample: {list(model.hf_device_map)[:20]}")
