import torch
from typing import List, Dict, Any, Iterable
import torch.nn.functional as F
from transformers.generation.utils import GenerationConfig
import copy
from transformers import LogitsProcessorList



def _norm_tok_str(s: str) -> str:
    """Replace the "▁" and "Ġ" marker characters in a token string with a plain space.

    NOTE(review): these look like the leading-space markers used by subword
    tokenizers; every other character is returned unchanged.
    """
    for marker in ("▁", "Ġ"):
        s = s.replace(marker, " ")
    return s

def build_semantic_sets(tokenizer) -> Dict[str, set]:
    """Classify every vocabulary id into the token categories used by the coordinate FSM.

    Each id in ``range(tokenizer.vocab_size)`` is converted to its token string,
    normalized with ``_norm_tok_str`` and stripped of surrounding whitespace.

    Args:
        tokenizer: Object exposing ``vocab_size`` and ``convert_ids_to_tokens``.

    Returns:
        Dict with these keys, each mapping to a set of token ids:
            ``digits``        stripped text consists only of ASCII digits 0-9
            ``zero``          stripped text is exactly "0" (subset of ``digits``)
            ``leading_zero``  multi-digit text starting with "0" (subset of ``digits``)
            ``comma`` / ``lbr`` / ``rbr``  stripped text is ",", "[" or "]"
            ``space``         normalized text starts with a space, or is blank
        A token can be in ``space`` and in one other category at the same time.

    Raises:
        RuntimeError: If no digit token is found in the vocabulary.
    """
    digits, comma, lbr, rbr, space = set(), set(), set(), set(), set()
    zero, leading_zero = set(), set()
    punctuation = {",": comma, "[": lbr, "]": rbr}

    for tid in range(tokenizer.vocab_size):
        raw = tokenizer.convert_ids_to_tokens(tid) or ""
        normalized = _norm_tok_str(raw)
        core = normalized.strip()

        # str.isdigit() alone also accepts characters such as "²" or "①";
        # coordinates are plain decimal numbers, so require ASCII as well.
        if core.isascii() and core.isdigit():
            digits.add(tid)
            if core == "0":
                zero.add(tid)
            elif core.startswith("0"):
                leading_zero.add(tid)
        elif core in punctuation:
            punctuation[core].add(tid)

        # Checked independently of the branch above: " 12" is both a digit
        # token and a token carrying a leading space.
        if normalized.startswith(" ") or not core:
            space.add(tid)

    if not digits:
        raise RuntimeError("未能在词表中识别到任何数字token。请检查tokenizer，或自定义数字集合判定逻辑。")
    return {
        "digits": digits,
        "comma": comma,
        "lbr": lbr,
        "rbr": rbr,
        "space": space,
        "zero": zero,
        "leading_zero": leading_zero,
    }


def digits_only_logp(logits: torch.Tensor, digit_ids: set) -> torch.Tensor:
    """Return log-probabilities renormalized over the digit tokens only.

    All positions outside ``digit_ids`` are set to ``-inf`` before the
    log-softmax, so the result is a distribution over digit tokens with
    ``-inf`` everywhere else.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary. Any number
            of leading dimensions is accepted.
        digit_ids: Non-empty collection of vocabulary ids to keep.

    Returns:
        float32 tensor with the same shape as ``logits``.

    Raises:
        ValueError: If ``digit_ids`` is empty (every entry would be masked and
            the log-softmax would be NaN).
    """
    if not digit_ids:
        raise ValueError("digit_ids must contain at least one token id")

    # float32 keeps the softmax numerically stable for half-precision logits.
    logits32 = logits.to(torch.float32)
    keep = torch.tensor(sorted(digit_ids), device=logits32.device, dtype=torch.long)

    # Index along the last axis so every row of a batch is masked the same way.
    masked = torch.full_like(logits32, float("-inf"))
    masked[..., keep] = logits32[..., keep]
    return masked.log_softmax(dim=-1)


class CoordFSM:
    """Finite-state tracker for a bracketed coordinate pair such as ``[x, y]``.

    States advance OPEN -> X_DIGITS -> AFTER_COMMA -> Y_DIGITS -> CLOSED.
    A token that has no transition from the current state is ignored.
    """

    OPEN, X_DIGITS, AFTER_COMMA, Y_DIGITS, CLOSED = range(5)

    # Lookup order for classifying a token id: the first matching set wins.
    _CATEGORIES = (
        ("digits", "DIGIT"),
        ("comma", "COMMA"),
        ("lbr", "LBR"),
        ("rbr", "RBR"),
        ("space", "SPACE"),
    )

    # (current state, token category) -> next state.
    _TRANSITIONS = {
        (OPEN, "LBR"): X_DIGITS,
        (X_DIGITS, "DIGIT"): X_DIGITS,
        (X_DIGITS, "COMMA"): AFTER_COMMA,
        (AFTER_COMMA, "SPACE"): AFTER_COMMA,
        (AFTER_COMMA, "DIGIT"): Y_DIGITS,
        (Y_DIGITS, "DIGIT"): Y_DIGITS,
        (Y_DIGITS, "RBR"): CLOSED,
    }

    def __init__(self, sem_sets: Dict[str, set]):
        """Start in OPEN, using ``sem_sets`` (category name -> set of ids) for lookups."""
        self.sem = sem_sets
        self.state = self.OPEN

    def _sem_of(self, tid: int) -> str:
        """Return the category label of ``tid``, or "OTHER" if it is in no set."""
        for key, label in self._CATEGORIES:
            if tid in self.sem[key]:
                return label
        return "OTHER"

    def in_digit_phase(self) -> bool:
        """Return True while the x or the y digits are being emitted."""
        return self.state == self.X_DIGITS or self.state == self.Y_DIGITS

    def step(self, tid: int):
        """Consume one token id and move to the next state if a transition exists."""
        category = self._sem_of(tid)
        self.state = self._TRANSITIONS.get((self.state, category), self.state)


def aggregate(tensors: Iterable[torch.Tensor], method="median", trim_ratio=0.2) -> torch.Tensor:
    """Combine same-shaped tensors element-wise into a single tensor.

    Args:
        tensors: Tensors of identical shape; they are stacked on a new leading
            axis and reduced over it.
        method: ``"median"`` (``torch.median``, which returns the lower of the
            two middle values for an even count), ``"trimmed_mean"``, or any
            other value for a plain mean.
        trim_ratio: For ``"trimmed_mean"``, the fraction of values dropped from
            EACH end after sorting. The trim count is clamped so that at least
            one value always remains and both ends lose the same number.

    Returns:
        The single input itself when only one tensor is given, otherwise the
        reduced tensor with the shape of one input.
    """
    xs = list(tensors)
    if len(xs) == 1:
        return xs[0]

    stacked = torch.stack(xs, dim=0)  # [K, *input_shape]
    if method == "median":
        return stacked.median(dim=0).values

    if method == "trimmed_mean":
        count = stacked.size(0)
        # Clamp to [0, (count - 1) // 2]: an oversized ratio must not leave an
        # empty or one-sided slice.
        k = min(max(int(count * trim_ratio), 0), (count - 1) // 2)
        ordered, _ = torch.sort(stacked, dim=0)
        return ordered[k:count - k].mean(dim=0)

    return stacked.mean(dim=0)


def project_to_logits(model, h_last_step: torch.Tensor) -> torch.Tensor:
    """Map a hidden state to vocabulary logits through the model's output head.

    If the model has a ``model.norm`` sub-module it is applied first; the result
    is then passed to ``model.lm_head``.

    NOTE(review): the norm is applied to whatever hidden state is passed in.
    If the caller passes a state that is already normalized, it is normalized
    twice — confirm this is intended for the layers being tapped.
    """
    has_final_norm = hasattr(model, "model") and hasattr(model.model, "norm")
    hidden = model.model.norm(h_last_step) if has_final_norm else h_last_step
    return model.lm_head(hidden)


def _check_flash_attention2_ok(model, out):
    """Fail fast when flash-attention 2 is active but no hidden states came back.

    Args:
        model: Object whose optional ``config`` records the attention backend.
        out: Forward-pass output; its ``hidden_states`` attribute is inspected.

    Raises:
        RuntimeError: If the configured attention implementation is
            ``"flash_attention_2"`` and ``out.hidden_states`` is missing or None.
    """
    config = getattr(model, "config", None)
    # NOTE(review): Hugging Face configs expose the backend as
    # ``_attn_implementation``; the public spelling is checked first to keep
    # the previous behavior for configs that define it.
    attn_impl = (
        getattr(config, "attn_implementation", None)
        or getattr(config, "_attn_implementation", None)
    )
    if attn_impl != "flash_attention_2":
        return

    if getattr(out, "hidden_states", None) is None:
        raise RuntimeError(
            "attn_implementation is 'flash_attention_2' but the forward pass "
            "returned no hidden_states; layer-wise decoding requires "
            "output_hidden_states=True to be honoured by the model."
        )


import inspect

def _supports_cache_pos(model):
    """Return True if ``model.forward`` declares a ``cache_position`` parameter."""
    forward_params = inspect.signature(model.forward).parameters
    return "cache_position" in forward_params

def _cache_len(past_kv):
    """Return the number of positions already stored in a key/value cache.

    Args:
        past_kv: Either a cache object exposing ``get_seq_length()`` or a
            legacy nested sequence where ``past_kv[0][0]`` is the first
            layer's key tensor.

    Returns:
        The cached sequence length as an int.
    """
    get_seq_length = getattr(past_kv, "get_seq_length", None)
    if callable(get_seq_length):
        # Prefer the cache object's own accessor over integer indexing.
        return int(get_seq_length())

    first_layer_key = past_kv[0][0]
    # assumes keys are laid out (..., seq_len, head_dim) — TODO confirm
    return first_layer_key.shape[-2]


@torch.no_grad()
def scd_contrastive_with_shuffle_layerwise(
    model,
    processor,
    *,
    text: str,
    image_inputs: Any,
    video_inputs: Any = None,
    seeds: List[int] = (15, 23, 42),
    tap_layers: Iterable[int] = (-1,),
    layer_taus: float | Iterable[float] = 1.5,
    alpha_digit: float = 0.35,
    alpha_other: float = 0.0,
    center_B: bool = True,
    agg_method_paths: str = "median",
    agg_method_layers: str = "median",
    max_new_tokens: int = 32,
    decay: float = 0.5
):
    """Greedy-decode a coordinate answer with digit-level contrastive scoring.

    Two kinds of forward paths share the same prompt:

    * path A runs with ``model.visual.vision_pe_shuffle = False``;
    * one path B per seed runs with ``vision_pe_shuffle = True``, scope
      ``'per_image'`` and that seed.
      (NOTE(review): the shuffle itself is implemented by the vision tower,
      which is not visible here.)

    At every step the next token is the argmax of path A's log-probabilities.
    While the coordinate FSM is inside a number (or right after the comma),
    the digit tokens are rescored as ``logp_A - alpha * logp_B`` where
    ``logp_B`` is the digit-only log-probability of the B paths, projected
    from the tapped hidden layers and aggregated over seeds and then layers.
    ``alpha`` starts at ``alpha_digit`` and is multiplied by ``decay`` for
    every digit already emitted in the current number. Tokens "0" and
    zero-prefixed digit tokens are forbidden as the first digit of x and y.

    Args:
        model: Causal LM exposing ``visual``, ``generation_config``,
            ``_prepare_special_tokens`` and ``_get_logits_processor``.
        processor: Callable building model inputs; usually carries ``tokenizer``.
        text: Prompt text (already templated).
        image_inputs: Passed to ``processor(images=...)``.
        video_inputs: Passed to ``processor(videos=...)``.
        seeds: Shuffle seeds, one B path each. Empty disables the contrast.
        tap_layers: Indices into ``hidden_states`` used to build ``logp_B``.
        layer_taus: One temperature for all tapped layers, or one per layer.
            Values are floored at ``1e-6``.
        alpha_digit: Contrast strength on digit tokens.
        alpha_other: If > 0, additionally subtracts this multiple of the first
            seed's full-vocabulary log-probabilities during the digit phases.
        center_B: Subtract the mean digit log-probability from ``logp_B``.
        agg_method_paths: ``aggregate`` method across seeds.
        agg_method_layers: ``aggregate`` method across tapped layers.
        max_new_tokens: Maximum number of generated tokens.
        decay: Per-digit multiplicative decay of ``alpha_digit``.

    Returns:
        ``(decoded, logits_dict_list)``: the decoded continuation and, per
        step, a dict mapping the top-10 token strings to their post-contrast
        probabilities.

    Raises:
        ValueError: No tokenizer is reachable, or ``layer_taus`` and
            ``tap_layers`` differ in length.
        RuntimeError: The model lacks ``_prepare_special_tokens``, or a helper
            (vocabulary scan, flash-attention check) fails.
    """
    tokenizer = getattr(processor, "tokenizer", None) or getattr(model, "tokenizer", None)
    if tokenizer is None:
        raise ValueError("无法获取 tokenizer（通常在 processor.tokenizer）。")

    # Materialize once: both are iterated many times below.
    seeds = tuple(seeds)
    tap_layers = tuple(tap_layers)
    try:
        taus = list(layer_taus)
    except TypeError:
        # Not iterable -> a single temperature shared by every tapped layer.
        taus = [layer_taus] * len(tap_layers)
    if len(taus) != len(tap_layers):
        raise ValueError(
            f"layer_taus has {len(taus)} entries but tap_layers has "
            f"{len(tap_layers)}; pass one temperature per tapped layer or a single float."
        )
    taus = [float(max(t, 1e-6)) for t in taus]

    device = model.device
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in inputs.items()}

    sem_sets = build_semantic_sets(tokenizer)
    fsm = CoordFSM(sem_sets)
    digit_ids = sem_sets["digits"]
    comma_ids = sem_sets["comma"]

    # Loop-invariant index tensors, shaped [1, N] for gather/scatter on [1, V].
    digit_idx = torch.tensor(list(digit_ids), device=device, dtype=torch.long).unsqueeze(0)
    ban_ids = list(sem_sets.get("zero", [])) + list(sem_sets.get("leading_zero", []))
    ban_idx = (
        torch.tensor(ban_ids, device=device, dtype=torch.long).unsqueeze(0) if ban_ids else None
    )

    logits_dict_list = []
    topk_k = 10

    # Generated sequence starts as the prompt of the first (only) batch row.
    generated = inputs["input_ids"][0:1, :]
    prompt_len = generated.size(1)

    # Greedy generation config cloned from the model's own defaults.
    gen_cfg = GenerationConfig.from_dict(model.generation_config.to_dict())
    gen_cfg.do_sample = False
    gen_cfg.num_beams = 1
    gen_cfg.temperature = 1.0  # no effect under greedy decoding; kept for consistency

    if not hasattr(model, "_prepare_special_tokens"):
        raise RuntimeError("当前模型未暴露 _prepare_special_tokens，请改用手动注入（方法B）。")
    model._prepare_special_tokens(gen_cfg, device=generated.device)

    logits_processor = model._get_logits_processor(
        generation_config=gen_cfg,
        input_ids_seq_length=generated.shape[-1],
        encoder_input_ids=None,
        prefix_allowed_tokens_fn=None,
        logits_processor=LogitsProcessorList(),  # empty; custom processors could be pre-seeded here
        device=generated.device,
        model_kwargs={},
    )

    # Stop on the tokenizer's EOS and on every EOS id of the generation config.
    eos_ids = set()
    tok_eos = getattr(tokenizer, "eos_token_id", None)
    if tok_eos is not None:
        eos_ids.add(int(tok_eos))
    cfg_eos = getattr(gen_cfg, "eos_token_id", None)
    if cfg_eos is not None:
        if torch.is_tensor(cfg_eos):
            cfg_eos = cfg_eos.flatten().tolist()
        elif isinstance(cfg_eos, int):
            cfg_eos = [cfg_eos]
        eos_ids.update(int(e) for e in cfg_eos)

    supports_cache_pos = _supports_cache_pos(model)

    # Remember the vision tower's shuffle settings so they can be restored:
    # this function toggles them and must not leave the model in shuffle mode.
    visual = model.visual
    shuffle_attrs = ("vision_pe_shuffle", "vision_pe_shuffle_scope", "vision_pe_shuffle_seed")
    missing = object()
    saved_flags = {name: getattr(visual, name, missing) for name in shuffle_attrs}

    try:
        # ---- Prefill: path A (clean) ----
        _scd_set_vision_shuffle(model, None)
        out_A = model(**inputs, use_cache=True, output_hidden_states=True)
        _check_flash_attention2_ok(model, out_A)
        past_A = out_A.past_key_values
        logits_A = out_A.logits[:, -1, :]  # [1, V]

        # ---- Prefill: one path B per seed (shuffled) ----
        past_B: Dict[int, Any] = {}
        logits_B_last: Dict[int, torch.Tensor] = {}
        hs_B_seed: Dict[int, tuple] = {}
        for s in seeds:
            _scd_set_vision_shuffle(model, s)
            out_B = model(**inputs, use_cache=True, output_hidden_states=True)
            _check_flash_attention2_ok(model, out_B)
            past_B[s] = out_B.past_key_values
            logits_B_last[s] = out_B.logits[:, -1, :]
            hs_B_seed[s] = out_B.hidden_states

        # Digits already emitted in the current x / y number (None = none yet).
        x_exp = None
        y_exp = None

        for step in range(max_new_tokens):
            prev_state = fsm.state

            proc_logits_A = logits_processor(generated, logits_A.clone())
            logp_A = proc_logits_A.to(torch.float32).log_softmax(dim=-1)

            contrast_active = bool(seeds) and (
                fsm.in_digit_phase() or prev_state == fsm.AFTER_COMMA
            )
            if contrast_active:
                lpB_digits_all = _scd_digit_reference_logp(
                    model,
                    logits_processor,
                    generated,
                    hs_B_seed,
                    seeds,
                    tap_layers,
                    taus,
                    digit_ids,
                    digit_idx,
                    center_B,
                    agg_method_paths,
                    agg_method_layers,
                )  # [1, V]

                if prev_state == fsm.X_DIGITS:
                    current_alpha = alpha_digit * (decay ** (x_exp or 0))
                elif prev_state == fsm.Y_DIGITS:
                    current_alpha = alpha_digit * (decay ** (y_exp or 0))
                else:
                    # AFTER_COMMA: the first y digit gets the undecayed strength.
                    current_alpha = alpha_digit

                # Only the digit entries are rescored; all others keep logp_A.
                score = logp_A.clone()
                idx = digit_idx.to(score.device)
                score.scatter_(
                    -1,
                    idx,
                    logp_A.gather(-1, idx) - current_alpha * lpB_digits_all.gather(-1, idx),
                )
                if alpha_other > 0.0:
                    ref_s = seeds[0]
                    lpB_other = logits_B_last[ref_s].to(torch.float32).log_softmax(dim=-1)
                    score = score - alpha_other * lpB_other
            else:
                score = logp_A

            # Forbid "0" / zero-prefixed tokens as the first digit of x or y.
            is_first_x_digit = prev_state == fsm.X_DIGITS and not x_exp
            is_first_y_digit = prev_state == fsm.AFTER_COMMA
            if (is_first_x_digit or is_first_y_digit) and ban_idx is not None:
                score.scatter_(-1, ban_idx.to(score.device), float("-inf"))

            # Record the top-k post-contrast probabilities for inspection.
            probs_post = F.softmax(score[0], dim=-1)
            topk = torch.topk(probs_post, k=min(topk_k, probs_post.numel()))
            tokens = tokenizer.convert_ids_to_tokens(topk.indices.tolist())
            logits_dict_list.append(dict(zip(tokens, topk.values.tolist())))

            next_token = torch.argmax(score, dim=-1)  # [1]
            next_id = int(next_token.item())
            is_digit = next_id in digit_ids

            if prev_state == fsm.X_DIGITS and is_digit:
                x_exp = (x_exp or 0) + 1
            elif prev_state == fsm.AFTER_COMMA and is_digit:
                y_exp = 1
            elif prev_state == fsm.Y_DIGITS and is_digit:
                y_exp = (y_exp or 0) + 1
            if next_id in comma_ids:
                y_exp = None

            fsm.step(next_id)
            generated = torch.cat([generated, next_token.view(1, 1)], dim=-1)

            if next_id in eos_ids:
                break
            if step == max_new_tokens - 1:
                # Nothing will consume another set of logits; skip the forwards.
                break

            # ---- Advance every path by the chosen token ----
            _scd_set_vision_shuffle(model, None)
            out_A = _scd_decode_step(model, next_token, past_A, supports_cache_pos)
            past_A = out_A.past_key_values
            logits_A = out_A.logits[:, -1, :]

            for s in seeds:
                _scd_set_vision_shuffle(model, s)
                out_B = _scd_decode_step(model, next_token, past_B[s], supports_cache_pos)
                past_B[s] = out_B.past_key_values
                logits_B_last[s] = out_B.logits[:, -1, :]
                hs_B_seed[s] = out_B.hidden_states
    finally:
        for name, value in saved_flags.items():
            if value is not missing:
                setattr(visual, name, value)
        if saved_flags["vision_pe_shuffle"] is missing:
            visual.vision_pe_shuffle = False

    gen_only = generated[:, prompt_len:]
    if hasattr(processor, "batch_decode"):
        decoded = processor.batch_decode(gen_only, skip_special_tokens=True)[0]
    else:
        decoded = tokenizer.decode(gen_only[0], skip_special_tokens=True)

    return decoded, logits_dict_list


def _scd_set_vision_shuffle(model, seed=None):
    """Put the vision tower in clean mode (``seed=None``) or per-image shuffle mode."""
    visual = model.visual
    if seed is None:
        visual.vision_pe_shuffle = False
        return
    visual.vision_pe_shuffle = True
    visual.vision_pe_shuffle_scope = 'per_image'
    visual.vision_pe_shuffle_seed = int(seed)


def _scd_decode_step(model, next_token, past, use_cache_pos):
    """Run one cached forward pass that feeds a single new token to ``model``."""
    extra = {}
    if use_cache_pos:
        extra["cache_position"] = torch.tensor(
            [_cache_len(past)], device=model.device, dtype=torch.long
        )
    return model(
        input_ids=next_token.view(1, 1),
        use_cache=True,
        past_key_values=past,
        output_hidden_states=True,
        **extra,
    )


def _scd_digit_reference_logp(
    model,
    logits_processor,
    generated,
    hidden_by_seed,
    seeds,
    tap_layers,
    taus,
    digit_ids,
    digit_idx,
    center,
    agg_paths,
    agg_layers,
):
    """Build the B-path digit-only log-probs: aggregate over seeds, then over tapped layers."""
    per_layer = []
    for layer_idx, tau in zip(tap_layers, taus):
        per_seed = []
        for s in seeds:
            # Last position of the first batch row, kept 2-D: [1, H].
            h = hidden_by_seed[s][layer_idx][0, -1, :].unsqueeze(0)
            layer_logits = project_to_logits(model, h) / tau  # [1, V]
            layer_logits = logits_processor(generated, layer_logits)
            per_seed.append(digits_only_logp(layer_logits, digit_ids))

        lp_layer = aggregate(per_seed, method=agg_paths)  # [1, V]
        if center:
            # Shift digit entries to zero mean; non-digit entries are untouched.
            idx = digit_idx.to(lp_layer.device)
            mean_val = lp_layer.gather(-1, idx).mean(dim=-1, keepdim=True)
            on_digits = torch.zeros_like(lp_layer).scatter(-1, idx, 1.0)
            lp_layer = lp_layer - mean_val * on_digits
        per_layer.append(lp_layer)

    return aggregate(per_layer, method=agg_layers)
