import torch
import numpy as np
import torch.nn.functional as F
import time
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
import seaborn as sns
import re
import math
from prompt import gsm8k_prompt
import math

def parse_flops_str(s):
    """Convert a human-readable FLOPs string such as '1.23 GFLOPs' to a float (1.23e9).

    The first number found in ``s`` is parsed (plain decimal or scientific
    notation) and multiplied by the magnitude of the unit prefix that directly
    follows it, optionally separated by whitespace: ``T`` (1e12), ``G`` (1e9),
    ``M`` (1e6), ``K`` or ``k`` (1e3). Without a recognised prefix the number
    is returned unchanged.

    Args:
        s: Text containing a number and an optional unit prefix.

    Returns:
        The value as a float.

    Raises:
        ValueError: If ``s`` contains no number.
    """
    # Require at least one digit so a stray "." (e.g. in "approx.") is never
    # handed to float(), and accept an exponent so "1e3" is not cut to "1".
    match = re.search(r"((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)\s*([TGMKk]?)", s)
    if match is None:
        raise ValueError(f"no number found in FLOPs string: {s!r}")

    multipliers = {"T": 1e12, "G": 1e9, "M": 1e6, "K": 1e3, "k": 1e3}
    number = float(match.group(1))
    # Only the character right after the number counts as the prefix; letters
    # elsewhere in the string are ignored.
    return number * multipliers.get(match.group(2), 1)


def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` for Gumbel-max style categorical sampling.

    Returns ``exp(logits) / (-log(u)) ** temperature`` with ``u`` drawn
    uniformly via ``torch.rand_like``, computed in float64. With
    ``temperature == 0`` the input is returned untouched (no noise, no cast).

    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves
    perplexity score but reduces generation quality, hence float64.
    """
    if temperature == 0:
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    denominator = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(logits64) / denominator


def get_num_transfer_tokens(mask_index, steps):
    """Split the number of masked entries as evenly as possible over ``steps``.

    The total count of ``True`` entries in ``mask_index`` (summed over the
    whole tensor) is divided by ``steps``; the first ``remainder`` steps get
    one extra token. Returns an int64 tensor of shape ``(steps,)`` on the same
    device as ``mask_index``.
    """
    total_masked = mask_index.sum().item()
    per_step, leftover = divmod(total_masked, steps)

    schedule = torch.full((steps,), per_step, dtype=torch.int64, device=mask_index.device)
    # Early steps absorb the remainder, one extra token each.
    gets_extra = torch.arange(steps, device=mask_index.device) < leftover
    return schedule + gets_extra.to(torch.int64)

def get_ordered_positions(length):
    """Return every index of ``range(length)`` exactly once, coarse to fine.

    Indices are produced level by level: for depth ``d`` the candidates are
    ``(i * length) // 2**d`` for odd ``i`` (the midpoint first, then the
    quarter points, and so on). Candidates already emitted are skipped, and
    any index never produced by a level is appended at the end in ascending
    order.

    Args:
        length: Number of positions.

    Returns:
        A list of ``length`` distinct ints; empty when ``length <= 0``.
    """
    if length <= 0:
        # log2 is undefined here and there is nothing to order.
        return []

    seen = set()
    result = []
    # (length - 1).bit_length() equals ceil(log2(length)) for length >= 1 and
    # stays in exact integer arithmetic.
    levels = (length - 1).bit_length()
    for depth in range(1, levels + 1):
        step = 1 << depth
        for i in range(1, step, 2):
            pos = (i * length) // step
            if pos not in seen:
                result.append(pos)
                seen.add(pos)

    result.extend(i for i in range(length) if i not in seen)
    return result

def get_preorder_bst_positions(length):
    """List ``range(length)`` in pre-order of a balanced binary search tree.

    Each half-open range ``[lo, hi)`` contributes its midpoint
    ``(lo + hi) // 2`` first, followed by the left half and then the right
    half.
    """
    order = []
    pending = [(0, length)]  # half-open ranges still to visit
    while pending:
        lo, hi = pending.pop()
        if lo >= hi:
            continue
        mid = (lo + hi) // 2
        order.append(mid)
        # Push right before left so the left subtree is popped (visited) first.
        pending.append((mid + 1, hi))
        pending.append((lo, mid))
    return order

@torch.no_grad()
def generate(model,
            prompt,
            tokenizer,
            steps: int = 1024,
            gen_length: int = 1024,
            block_length: int = 8,
            temperature: float = 0.,
            cfg_scale: float = 0.,
            remasking: str = 'low_confidence',
            mask_id=126336,
            eos_id=126081):
    """Block-wise masked-diffusion decoding.

    The generation region is split into ``gen_length // block_length`` blocks
    that are filled left to right. Each block gets ``steps // num_blocks``
    denoising steps; at every step the model is run on the whole sequence and
    a scheduled number of masked positions inside the current block are
    committed to their predicted token.

    Args:
        model: Callable returning an object with ``.logits``; must expose ``.device``.
        prompt: Token ids indexed as ``[1, prompt_len]``.
        tokenizer: Not used by this function; kept for interface compatibility.
        steps: Total number of denoising steps; must be a multiple of the block count.
        gen_length: Number of tokens to generate; must be a multiple of ``block_length``.
        block_length: Size of each block.
        temperature: Passed to ``add_gumbel_noise`` (0 disables noise).
        cfg_scale: Classifier-free guidance scale; values > 0 add a second,
            prompt-masked forward pass.
        remasking: How positions are ranked for commitment:
            ``'low_confidence'`` (softmax probability of the predicted token),
            ``'random'`` (uniform random scores) or ``'auto_regression'``
            (leftmost masked positions of the block first).
        mask_id: Token id marking a masked position.
        eos_id: Token id that ends generation.

    Returns:
        ``(tokens, total_step)`` where ``tokens`` is ``x[:, prompt_len:]`` and
        ``total_step`` is the number of denoising steps that were run.

    Raises:
        NotImplementedError: For an unknown ``remasking`` value.
    """
    device = model.device
    prompt_len = prompt.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = prompt.clone()
    # Built from the full-length x (not from prompt) so it has the shape needed
    # to index un_x in the CFG branch.
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0, "gen_length must be a multiple of block_length"
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0, "steps must be a multiple of the number of blocks"
    steps_per_block = steps // num_blocks

    total_step = 0

    for block_idx in range(num_blocks):
        block_start = prompt_len + block_idx * block_length
        block_end = block_start + block_length

        block_mask_index = (x[:, block_start: block_end] == mask_id)
        # Spread this block's masked tokens over the steps the block actually
        # gets; using the global step count would leave the block partly
        # masked whenever steps_per_block < block_length.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for step_idx in range(steps_per_block):
            total_step += 1
            mask_index = (x == mask_id)
            k = int(num_transfer_tokens[step_idx])

            if cfg_scale > 0.:
                # Classifier-free guidance: second pass with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            noisy = add_gumbel_noise(logits, temperature)
            x0 = torch.argmax(noisy, dim=-1)

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            elif remasking == 'auto_regression':
                # Only the k leftmost still-masked positions of the block get a
                # finite score, so they are the ones committed this step.
                x0_p = torch.full((x0.shape[0], x0.shape[1]), -float('inf'), device=x0.device)
                for j in range(x0.shape[0]):
                    masked_in_block = torch.nonzero(
                        mask_index[j, block_start: block_end], as_tuple=False).squeeze(-1)
                    x0_p[j, block_start + masked_in_block[:k]] = 1.0
            else:
                raise NotImplementedError(remasking)

            # Never commit anything to the right of the current block.
            x0_p[:, block_end:] = -np.inf
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=k)
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

        # Once EOS shows up in the generated part, pad the rest and stop.
        if eos_id in x[:, prompt_len:block_end]:
            x[:, block_end:] = eos_id
            break
    return x[:, prompt_len:], total_step

@torch.no_grad()
def generate_wavefront(model,
                       prompt,
                       tokenizer,
                       steps: int = 1024,
                       gen_length: int = 1024,
                       r: int = 4,                 # neighbourhood radius
                       K: int = 1,                 # tokens committed per step
                       F_max: int = 8,            # maximum frontier size
                       threshold: float = 0.7,     # convergence threshold (unused)
                       lambda1: float = 0.1,       # anchor proximity weight (unused)
                       temperature: float = 0.,
                       cfg_scale: float = 0.,
                       remasking: str = "low_confidence",
                       mask_id=126336,
                       eos_id=126081):
    """Wavefront decoding: grow a small frontier of candidate positions.

    A boolean frontier ``F_mask`` starts as the first ``r + 1`` generation
    positions. Every step runs the model once, commits the ``K`` frontier
    positions with the highest top-1 probability, adds their neighbours within
    radius ``r`` to the frontier, drops positions that are no longer masked and
    trims the frontier to the ``F_max`` most confident entries. If the frontier
    is empty while masked tokens remain, it is re-seeded with the ``F_max``
    most confident masked positions.

    Args:
        model: Callable returning an object with ``.logits``; must expose ``.device``.
        prompt: Token ids indexed as ``[1, prompt_len]``.
        tokenizer: Not used; kept for interface compatibility.
        steps: Maximum number of forward passes.
        gen_length: Number of positions appended after the prompt.
        r: Neighbourhood radius used to expand the frontier.
        K: Number of positions committed per step.
        F_max: Maximum number of positions kept in the frontier.
        threshold, lambda1, temperature, remasking: Accepted but not used by
            this implementation (decoding is greedy).
        cfg_scale: Classifier-free guidance scale; values > 0 add a
            prompt-masked forward pass.
        mask_id: Token id marking a masked position.
        eos_id: Token id; everything after its first occurrence in the
            generated part is filled with ``eos_id``.

    Returns:
        ``(x, step)``: the full ``[1, prompt_len + gen_length]`` sequence
        (prompt included; may still contain ``mask_id`` if the step budget ran
        out) and the number of steps executed.
    """
    device = model.device
    prompt_len = prompt.shape[1]
    seq_len = prompt_len + gen_length

    x = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = prompt.clone()
    # Built from the full-length x (not from prompt) so it has the shape needed
    # to index un_x in the CFG branch.
    prompt_index = (x != mask_id)

    F_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
    F_mask[prompt_len:min(prompt_len + r + 1, seq_len)] = True

    step = 0
    while step < steps and (x == mask_id).any():
        step += 1
        mask_index = (x == mask_id)

        # ----------- classifier-free guidance -----------
        if cfg_scale > 0.:
            un_x = x.clone()
            un_x[prompt_index] = mask_id
            x_ = torch.cat([x, un_x], dim=0)
            logits = model(x_).logits
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
        else:
            logits = model(x).logits  # [1, L, V]

        probs = F.softmax(logits, dim=-1)
        top1 = torch.argmax(probs, dim=-1)           # [1, L]
        top1_conf = torch.gather(probs, -1, top1.unsqueeze(-1)).squeeze(-1)  # [1, L]

        # ----------- re-seed an empty frontier -----------
        if not F_mask.any():
            mask_conf = top1_conf[mask_index]   # [M]
            mask_pos = torch.nonzero(mask_index, as_tuple=True)[1]
            sorted_idx = torch.argsort(mask_conf, descending=True)
            chosen = sorted_idx[:min(F_max, mask_pos.size(0))]
            F_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
            F_mask[mask_pos[chosen]] = True

        # ----------- score frontier positions -----------
        scores = torch.full_like(top1_conf[0], -1e9)
        scores[F_mask] = top1_conf[0, F_mask]

        # Commit the K most confident frontier positions.
        _, sel_idx = torch.topk(scores, k=min(K, F_mask.sum().item()))
        if sel_idx.numel() == 0:
            break
        x[0, sel_idx] = top1[0, sel_idx]

        # ----------- neighbourhood expansion -----------
        new_F_mask = F_mask.clone()
        for offset in range(-r, r + 1):
            nb = sel_idx + offset
            valid_nb = (nb >= prompt_len) & (nb < seq_len)
            new_F_mask[nb[valid_nb]] = True
        F_mask = new_F_mask

        # ----------- EOS handling -----------
        # Done before filtering the frontier: positions overwritten with EOS
        # must leave the frontier, otherwise a later step would be spent
        # re-selecting a position that is already final.
        generated = x[0, prompt_len:]
        if (generated == eos_id).any():
            eos_pos = (generated == eos_id).nonzero(as_tuple=True)[0][0].item()
            x[0, prompt_len + eos_pos + 1:] = eos_id

        # ----------- keep only still-masked positions -----------
        F_mask &= (x[0] == mask_id)

        # ----------- cap the frontier at F_max -----------
        if F_mask.sum() > F_max:
            pruned_scores = torch.full_like(top1_conf[0], -1e9)
            pruned_scores[F_mask] = top1_conf[0, F_mask]
            _, keep_idx = torch.topk(pruned_scores, k=F_max)
            new_mask = torch.zeros_like(F_mask)
            new_mask[keep_idx] = True
            F_mask = new_mask

    return x, step


@torch.no_grad()
def blockdiffusion_ppl(model,
                        prompt: torch.LongTensor,
                        continuation: torch.LongTensor,
                        steps: int = 1024,
                        block_length: int = 8,
                        temperature: float = 0.,
                        cfg_scale: float = 0.,
                        remasking: str = 'low_confidence',
                        mask_id=126336,
                        eos_id=126081,
                        tokenizer=None):
    """Masked-reconstruction NLL / PPL along a block-wise decoding trajectory.

    The continuation starts fully masked and is revealed block by block, left
    to right. Each block gets ``max(1, steps // num_blocks)`` steps and its
    masked tokens are spread evenly over them. At every step the model is run
    on the current sequence, the scheduled number of masked positions inside
    the block with the highest confidence are revealed using the model's greedy
    prediction, and the negative log-probability of the ground-truth token at
    each revealed position is recorded (once per position).

    Args:
        model: Callable returning an object with ``.logits``.
        prompt: ``[1, Lp]`` conditioning tokens.
        continuation: ``[1, Lc]`` ground-truth continuation tokens.
        steps: Total step budget shared by all blocks.
        block_length: Block size; a shorter trailing block is allowed.
        temperature: Accepted for interface compatibility; prediction is greedy.
        cfg_scale: Classifier-free guidance scale; values > 0 add a
            prompt-masked forward pass.
        remasking: ``'low_confidence'`` ranks positions by the probability of
            the predicted token; any other value gives all positions equal score.
        mask_id: Token id marking a masked position.
        eos_id: Token id; when it appears in the decoded sequence after a
            block, evaluation stops and later positions are not scored.
        tokenizer: Not used.

    Returns:
        ``(avg_nll, ppl)``; both are ``inf`` when no token was scored.

    Raises:
        ValueError: If ``block_length`` is not positive.
    """
    if block_length <= 0:
        raise ValueError(f"block_length must be positive, got {block_length}")

    device = prompt.device
    Lp, Lc = prompt.shape[1], continuation.shape[1]
    seq_len = Lp + Lc
    ground_truth = torch.cat([prompt, continuation], dim=1)

    # Prompt visible, continuation fully masked.
    x = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    x[:, :Lp] = prompt.clone()

    # Ceil division so a trailing partial block is evaluated too and a
    # continuation shorter than one block does not divide by zero.
    num_blocks = -(-Lc // block_length)
    steps_per_block = max(1, steps // num_blocks) if num_blocks > 0 else 0

    total_nll = 0.0
    total_tokens = 0

    # Marks positions whose NLL has already been counted.
    nll_recorded = torch.zeros(seq_len, dtype=torch.bool, device=device)

    for block_idx in range(num_blocks):
        block_start = Lp + block_idx * block_length
        block_end = min(block_start + block_length, seq_len)

        block_masked = int((x[0, block_start:block_end] == mask_id).sum())
        base, extra = divmod(block_masked, steps_per_block)

        for step_idx in range(steps_per_block):
            mask_index = (x == mask_id)

            # Only masked positions inside the current block may be revealed.
            candidates = mask_index[0].clone()
            candidates[:block_start] = False
            candidates[block_end:] = False

            quota = base + (1 if step_idx < extra else 0)
            quota = min(quota, int(candidates.sum()))
            if quota == 0:
                # Nothing to reveal this step, so skip the forward pass.
                continue

            # ---- classifier-free guidance ----
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[:, :Lp] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            # ---- greedy candidate tokens ----
            probs = F.softmax(logits, dim=-1)
            x0 = torch.argmax(logits, dim=-1)

            # ---- rank candidates ----
            if remasking == 'low_confidence':
                conf = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)[0]
            else:
                conf = torch.ones(seq_len, dtype=torch.float, device=device)
            conf = conf.masked_fill(~candidates, float('-inf'))
            _, reveal = torch.topk(conf, k=quota)

            # ---- score the ground truth at positions revealed for the first time ----
            fresh = reveal[~nll_recorded[reveal]]
            if fresh.numel() > 0:
                p_true = probs[0][fresh, ground_truth[0, fresh]]
                nll = -torch.log(p_true.double().clamp_min(1e-12))
                total_nll += nll.sum().item()
                total_tokens += fresh.numel()
                nll_recorded[fresh] = True

            # ---- commit the predictions ----
            x[0, reveal] = x0[0, reveal]

        # EOS check: stop once the decoded sequence contains EOS.
        if (x[0, Lp:block_end] == eos_id).any():
            break

    # ---- PPL ----
    if total_tokens > 0:
        avg_nll = total_nll / total_tokens
        ppl = math.exp(avg_nll)
    else:
        avg_nll, ppl = float("inf"), float("inf")

    return avg_nll, ppl



@torch.no_grad()
def wavefront_ppl(model,
                  prompt: torch.LongTensor,
                  continuation: torch.LongTensor,
                  steps: int = 1024,
                  r: int = 4,
                  K: int = 1,
                  F_max: int = 8,
                  temperature: float = 0.,
                  cfg_scale: float = 0.,
                  mask_id=126336,
                  eos_id=126081,
                  tokenizer=None):
    """Masked-reconstruction NLL / PPL along a wavefront decoding trajectory.

    The continuation starts fully masked and is decoded with the wavefront
    rule (commit the ``K`` most confident frontier positions, expand the
    frontier by radius ``r``, cap it at ``F_max``). At every step the negative
    log-probability of the ground-truth token is accumulated for every
    continuation position that is still masked, so a position contributes once
    per step for as long as it stays masked.

    Args:
        model: Callable returning an object with ``.logits``.
        prompt: ``[1, Lp]`` conditioning tokens.
        continuation: ``[1, Lc]`` ground-truth continuation tokens.
        steps: Maximum number of forward passes.
        r: Neighbourhood radius used to expand the frontier.
        K: Number of positions committed per step.
        F_max: Maximum frontier size.
        temperature: Accepted but not used (decoding is greedy).
        cfg_scale: Classifier-free guidance scale; values > 0 add a
            prompt-masked forward pass.
        mask_id: Token id marking a masked position.
        eos_id: Token id; everything after its first occurrence in the decoded
            continuation is filled with ``eos_id``.
        tokenizer: Not used.

    Returns:
        ``(avg_nll, ppl)``; both are ``inf`` when no token was scored.
    """
    device = prompt.device
    Lp, Lc = prompt.shape[1], continuation.shape[1]
    seq_len = Lp + Lc

    # Ground truth for the whole sequence and, separately, for the continuation.
    ground_truth = torch.cat([prompt, continuation], dim=1)  # [1, seq_len]
    true_ids = ground_truth[0, Lp:].unsqueeze(-1)            # [Lc, 1]

    # Prompt visible, continuation fully masked.
    x = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    x[:, :Lp] = prompt.clone()

    # Initial frontier: the first r + 1 continuation positions.
    F_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
    F_mask[Lp:min(Lp + r + 1, seq_len)] = True

    step, total_nll, total_tokens = 0, 0.0, 0

    while step < steps and (x[:, Lp:] == mask_id).any():
        step += 1
        mask_index = (x == mask_id)

        # ---- classifier-free guidance ----
        if cfg_scale > 0.:
            un_x = x.clone()
            un_x[:, :Lp] = mask_id
            x_ = torch.cat([x, un_x], dim=0)
            logits = model(x_).logits
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
        else:
            logits = model(x).logits  # [1, L, V]

        probs = F.softmax(logits, dim=-1)  # [1, L, V]

        # ---- NLL accumulation over still-masked continuation positions ----
        # One gather instead of a Python loop with an .item() call per token.
        cont_masked = mask_index[0, Lp:]
        p_true = probs[0, Lp:].gather(-1, true_ids).squeeze(-1)
        nll = -torch.log(p_true[cont_masked].double().clamp_min(1e-12))
        total_nll += nll.sum().item()
        total_tokens += int(cont_masked.sum())

        # ---- wavefront update ----
        top1 = torch.argmax(probs, dim=-1)
        top1_conf = torch.gather(probs, -1, top1.unsqueeze(-1)).squeeze(-1)

        # Re-seed an empty frontier with the most confident masked positions.
        if not F_mask.any():
            mask_conf = top1_conf[mask_index]
            mask_pos = torch.nonzero(mask_index, as_tuple=True)[1]
            sorted_idx = torch.argsort(mask_conf, descending=True)
            chosen = sorted_idx[:min(F_max, mask_pos.size(0))]
            F_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
            F_mask[mask_pos[chosen]] = True

        scores = torch.full_like(top1_conf[0], -1e9)
        scores[F_mask] = top1_conf[0, F_mask]
        _, sel_idx = torch.topk(scores, k=min(K, F_mask.sum().item()))
        if sel_idx.numel() == 0:
            break

        x[0, sel_idx] = top1[0, sel_idx]

        # Neighbourhood expansion, restricted to the continuation.
        new_F_mask = F_mask.clone()
        for offset in range(-r, r + 1):
            nb = sel_idx + offset
            valid_nb = (nb >= Lp) & (nb < seq_len)
            new_F_mask[nb[valid_nb]] = True
        F_mask = new_F_mask

        # EOS handling (continuation only). Done before filtering the frontier:
        # positions overwritten with EOS must leave the frontier, otherwise a
        # later step is spent re-selecting a final position and the tokens
        # still masked are scored one extra time.
        decoded = x[0, Lp:]
        if (decoded == eos_id).any():
            eos_pos = (decoded == eos_id).nonzero(as_tuple=True)[0][0].item()
            x[0, Lp + eos_pos + 1:] = eos_id

        # Keep only still-masked positions in the frontier.
        F_mask &= (x[0] == mask_id)

        # Cap the frontier at F_max.
        if F_mask.sum() > F_max:
            pruned_scores = torch.full_like(top1_conf[0], -1e9)
            pruned_scores[F_mask] = top1_conf[0, F_mask]
            _, keep_idx = torch.topk(pruned_scores, k=F_max)
            new_mask = torch.zeros_like(F_mask)
            new_mask[keep_idx] = True
            F_mask = new_mask

    # ---- final PPL ----
    if total_tokens > 0:
        avg_nll = total_nll / total_tokens
        ppl = math.exp(avg_nll)
    else:
        avg_nll, ppl = float("inf"), float("inf")

    return avg_nll, ppl



def main():
    """Load LLaDA-8B-Instruct, decode one sample coding prompt and print the result.

    Prints the wall-clock time of the ``generate`` call followed by the
    decoded generation.
    """
    device = 'cuda'

    model = AutoModel.from_pretrained(
        'GSAI-ML/LLaDA-8B-Instruct',
        trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
    # Same id as the default mask_id of the generators in this file.
    tokenizer.pad_token_id = 126336

    prompt = '''\n\nfrom typing import List\n\n\ndef factorize(n: int) -> List[int]:\n    """ Return list of prime factors of given integer in the order from smallest to largest.\n    Each of the factors should be listed number of times corresponding to how many times it appeares in factorization.\n    Input number should be equal to the product of all factors\n    >>> factorize(8)\n    [2, 2, 2]\n    >>> factorize(25)\n    [5, 5]\n    >>> factorize(70)\n    [2, 5, 7]\n    """'''

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}, ]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    start_time = time.perf_counter()
    # generate() expects the tokenised prompt tensor, not the template string.
    out, _ = generate(
        model,
        input_ids,
        tokenizer,
        steps=512,
        gen_length=512,
        block_length=8,
        temperature=0.0,
        cfg_scale=0.,
    )

    print(f"Cost {time.perf_counter() - start_time}")
    print(tokenizer.decode(out[0]))

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
