import torch
import numpy as np
import torch.nn.functional as F
import time
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
import seaborn as sns
import re
import math
from prompt import gsm8k_prompt

def parse_flops_str(s):
    """Parse a human-readable FLOPs string such as ``'1.23 GFLOPs'`` into a float.

    The first number in ``s`` is multiplied by the magnitude prefix that
    directly follows it (``T``, ``G``, ``M``, ``K`` or ``k``), so
    ``'1.23 GFLOPs'`` becomes ``1.23e9``. Reading the prefix next to the
    number avoids false hits on unrelated letters, e.g. the ``M`` in
    ``'3.5 KMac'`` or the ``T`` in ``'Total: 1.5 GFLOPs'``.

    If no prefix follows the number, the whole string is scanned for an
    upper-case ``T``/``G``/``M``/``K`` (in that priority order) so that
    header-style inputs like ``'GFLOPs: 2'`` keep working. With no prefix at
    all the bare number is returned.

    Args:
        s: Text containing a non-negative decimal number and an optional prefix.

    Returns:
        The value as a float.

    Raises:
        ValueError: If ``s`` contains no number.
    """
    multipliers = {"T": 1e12, "G": 1e9, "M": 1e6, "K": 1e3}

    match = re.search(r"(\d+\.?\d*|\.\d+)\s*([TGMKk])?", s)
    if match is None:
        raise ValueError(f"no number found in FLOPs string: {s!r}")

    number = float(match.group(1))
    prefix = match.group(2)
    if prefix is not None:
        return number * multipliers[prefix.upper()]

    # Legacy fallback: prefix somewhere else in the string (dict order is
    # T, G, M, K, i.e. the same priority the original if/elif chain used).
    for letter, factor in multipliers.items():
        if letter in s:
            return number * factor
    return number


def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` for Gumbel-max style categorical sampling.

    With ``temperature == 0`` the input is returned untouched (greedy
    decoding). Otherwise the result is ``exp(logits) / (-log(u)) ** temperature``
    with ``u`` drawn uniformly from [0, 1), so taking the argmax of the result
    yields a sample.

    Everything is computed in float64: according to arXiv:2409.02908,
    low-precision Gumbel-max improves perplexity for masked diffusion models
    but hurts generation quality.
    """
    if temperature != 0:
        logits64 = logits.to(torch.float64)
        uniform = torch.rand_like(logits64, dtype=torch.float64)
        scaled_noise = (-torch.log(uniform)) ** temperature
        return logits64.exp() / scaled_noise
    return logits


def get_num_transfer_tokens(mask_index, steps):
    """Split the number of masked positions as evenly as possible over ``steps``.

    Returns an int64 tensor of length ``steps`` (on ``mask_index``'s device)
    whose entries sum to ``mask_index.sum()``; when the count does not divide
    evenly, the leading entries each receive one extra token.
    """
    masked_total = mask_index.sum().item()
    per_step, leftover = divmod(masked_total, steps)

    schedule = torch.full((steps,), per_step, dtype=torch.int64, device=mask_index.device)
    schedule[:leftover] += 1
    return schedule

def get_ordered_positions(length):
    """Return every index of ``range(length)`` in coarse-to-fine bisection order.

    Depth ``d`` (1, 2, ...) contributes the positions ``(i * length) // 2**d``
    for odd ``i``: the midpoint first, then the quarter points, then the
    eighths, and so on, skipping positions already emitted. Indices never hit
    by that subdivision (index 0, for instance) are appended at the end in
    ascending order, so the result is always a permutation of
    ``range(length)``.

    Example: ``get_ordered_positions(8) == [4, 2, 6, 1, 3, 5, 7, 0]``.

    Args:
        length: Number of positions; must be a non-negative int.

    Returns:
        A list of ints. Empty for ``length == 0``.

    Raises:
        ValueError: If ``length`` is negative.
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")
    if length == 0:
        # log2(0) is undefined; an empty range simply has no positions.
        return []

    seen = set()
    ordered = []
    # (length - 1).bit_length() == ceil(log2(length)) for length >= 1, computed
    # exactly in integers instead of through a float logarithm.
    deepest = (length - 1).bit_length()
    for depth in range(1, deepest + 1):
        step = 1 << depth
        for i in range(1, step, 2):
            # i < step, hence pos < length always holds.
            pos = (i * length) // step
            if pos not in seen:
                seen.add(pos)
                ordered.append(pos)

    ordered.extend(i for i in range(length) if i not in seen)
    return ordered

def get_preorder_bst_positions(length):
    """Return the indices of ``range(length)`` in pre-order of a balanced BST.

    The half-open interval ``[0, length)`` is split at its midpoint; the
    midpoint is emitted first, followed by the whole left half and then the
    whole right half, each processed the same way.

    Example: ``get_preorder_bst_positions(8) == [4, 2, 1, 0, 3, 6, 5, 7]``.
    """
    order = []
    pending = [(0, length)]
    while pending:
        lo, hi = pending.pop()
        if lo >= hi:
            continue
        mid = (lo + hi) // 2
        order.append(mid)
        # Push the right half first so the left half is popped (visited) first.
        pending.append((mid + 1, hi))
        pending.append((lo, mid))
    return order

@torch.no_grad()
def generate(model,
             prompt,
             tokenizer,
             steps: int = 1024,
             gen_length: int = 1024,
             block_length: int = 8,
             temperature: float = 0.,
             cfg_scale: float = 0.,
             remasking: str = 'low_confidence',
             mask_id=126336,
             eos_id=126081):
    """Fill a masked continuation of ``prompt`` block by block.

    The sequence starts as ``prompt`` followed by ``gen_length`` mask tokens.
    The generated region is processed left to right in blocks of
    ``block_length`` tokens; each block gets ``steps // num_blocks`` model
    calls, and every call commits a scheduled number of the block's masked
    positions, chosen by the ``remasking`` strategy. Generation stops early
    once a finished block contains ``eos_id``; everything after that block is
    then set to ``eos_id``.

    Args:
        model: Callable with a ``.device`` attribute; ``model(ids).logits`` is
            used as per-position vocabulary logits.
        prompt: Token-id tensor indexed as (1, prompt_len).
        tokenizer: Unused here; kept so all generators share one signature.
        steps: Total number of model calls; must be a multiple of the number
            of blocks.
        gen_length: Number of tokens to generate; must be a multiple of
            ``block_length``.
        block_length: Size of each block.
        temperature: Gumbel-noise temperature; 0 means greedy argmax.
        cfg_scale: Classifier-free-guidance scale; values > 0 add a second,
            prompt-masked forward pass per step.
        remasking: ``'low_confidence'`` (commit the most confident
            predictions), ``'random'`` or ``'auto_regression'`` (commit the
            left-most masked positions of the block).
        mask_id: Token id that marks a masked position.
        eos_id: Token id that ends generation.

    Returns:
        ``(tokens, total_step)``: the generated region ``x[:, prompt_len:]``
        and the number of denoising steps that were run.

    Raises:
        AssertionError: If the divisibility requirements above are violated.
        NotImplementedError: For an unknown ``remasking`` value.
    """
    device = model.device
    prompt_len = prompt.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = prompt.clone()
    # Must have the shape of x (not of prompt): it indexes the full-length
    # un_x in the CFG branch below.
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    total_step = 0

    for block_idx in range(num_blocks):
        block_start = prompt_len + block_idx * block_length
        block_end = block_start + block_length

        block_mask_index = (x[:, block_start: block_end] == mask_id)
        # The schedule has to span exactly the steps this block receives;
        # spreading it over all `steps` would leave masks unfilled whenever
        # steps_per_block < block_length.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block).tolist()

        for step_idx in range(steps_per_block):
            total_step += 1
            mask_index = (x == mask_id)
            k = num_transfer_tokens[step_idx]

            if cfg_scale > 0.:
                # Classifier-free guidance: contrast the prompted run with a
                # run whose prompt is masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            noisy = add_gumbel_noise(logits, temperature)
            x0 = torch.argmax(noisy, dim=-1)

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            elif remasking == 'auto_regression':
                x0_p = torch.full((x0.shape[0], x0.shape[1]), -float('inf'), device=x0.device)
                for j in range(x0.shape[0]):
                    # Column offsets (within the block) of row j's masked
                    # positions, left to right.
                    masked_cols = torch.nonzero(
                        mask_index[j, block_start: block_end], as_tuple=True)[0]
                    x0_p[j, block_start + masked_cols[:k]] = 1.0
            else:
                raise NotImplementedError(remasking)

            # Never commit tokens beyond the current block.
            x0_p[:, block_end:] = -np.inf
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=k)
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

        if eos_id in x[:, prompt_len:block_end]:
            x[:, block_end:] = eos_id
            break
    return x[:, prompt_len:], total_step

@torch.no_grad()
def generate_wavefront(model,
                       prompt,
                       tokenizer,
                       steps: int = 1024,
                       gen_length: int = 1024,
                       r: int = 4,                 # neighbourhood radius
                       K: int = 1,                 # tokens committed per step
                       F_max: int = 8,             # maximum frontier size
                       threshold: float = 0.7,     # convergence threshold (unused)
                       lambda1: float = 0.1,       # anchor-proximity weight (unused)
                       temperature: float = 0.,
                       cfg_scale: float = 0.,
                       mask_id=126336,
                       eos_id=126081):
    """Decode by growing a frontier ("wavefront") of candidate positions.

    The frontier starts as the first ``r + 1`` generated positions. Each step
    runs the model once, commits the ``K`` most confident frontier positions
    to their argmax token, adds every still-masked position within ``r`` of a
    committed one, and trims the frontier to the ``F_max`` most confident
    entries. If the frontier runs empty while masks remain, it is re-seeded
    with the ``F_max`` most confident masked positions. After the first
    generated ``eos_id``, all later positions are set to ``eos_id``.

    Args:
        model: Callable with a ``.device`` attribute; ``model(ids).logits`` is
            used as per-position vocabulary logits.
        prompt: Token-id tensor indexed as (1, prompt_len).
        tokenizer: Unused.
        steps: Upper bound on the number of model calls.
        gen_length: Number of positions to generate.
        r: Neighbourhood radius used to seed and expand the frontier.
        K: Number of positions committed per step.
        F_max: Maximum frontier size.
        threshold, lambda1, temperature: Accepted but currently unused.
        cfg_scale: Classifier-free-guidance scale; values > 0 add a second,
            prompt-masked forward pass per step.
        mask_id: Token id that marks a masked position.
        eos_id: Token id that ends generation.

    Returns:
        ``(x, step)``: the FULL sequence including the prompt (unlike the
        other generators in this file, which strip it) and the number of steps
        that were run.
    """
    device = model.device
    prompt_len = prompt.shape[1]
    seq_len = prompt_len + gen_length

    x = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = prompt.clone()
    # Must have the shape of x (not of prompt): it indexes the full-length
    # un_x in the CFG branch below.
    prompt_index = (x != mask_id)

    F_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
    F_mask[prompt_len:min(prompt_len + r + 1, seq_len)] = True

    step = 0
    while step < steps and (x == mask_id).any():
        step += 1
        mask_index = (x == mask_id)

        # ----------- classifier-free guidance -----------
        if cfg_scale > 0.:
            un_x = x.clone()
            un_x[prompt_index] = mask_id
            x_ = torch.cat([x, un_x], dim=0)
            logits = model(x_).logits
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
        else:
            logits = model(x).logits  # [1, L, V]

        probs = F.softmax(logits, dim=-1)
        top1 = torch.argmax(probs, dim=-1)                                   # [1, L]
        top1_conf = torch.gather(probs, -1, top1.unsqueeze(-1)).squeeze(-1)  # [1, L]

        if not F_mask.any():
            # Frontier exhausted: re-seed it with the most confident masked
            # positions anywhere in the sequence.
            mask_conf = top1_conf[mask_index]   # [M]
            mask_pos = torch.nonzero(mask_index, as_tuple=True)[1]
            by_confidence = torch.argsort(mask_conf, descending=True)
            chosen = by_confidence[:min(F_max, mask_pos.size(0))]
            F_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
            F_mask[mask_pos[chosen]] = True

        # ----------- score frontier positions -----------
        scores = torch.full_like(top1_conf[0], -1e9)
        scores[F_mask] = top1_conf[0, F_mask]

        # Commit the K best frontier positions.
        _, sel_idx = torch.topk(scores, k=min(K, int(F_mask.sum())))
        if sel_idx.numel() == 0:
            break
        x[0, sel_idx] = top1[0, sel_idx]

        # ----------- EOS handling -----------
        # Done before the frontier is refreshed so that positions just
        # overwritten with EOS are dropped from it; a stale entry would be
        # selected on a later step and waste a forward pass.
        eos_hits = (x[0, prompt_len:] == eos_id).nonzero(as_tuple=True)[0]
        if eos_hits.numel() > 0:
            x[0, prompt_len + int(eos_hits[0]) + 1:] = eos_id

        # ----------- neighbourhood expansion -----------
        new_F_mask = F_mask.clone()
        for offset in range(-r, r + 1):
            nb = sel_idx + offset
            valid_nb = (nb >= prompt_len) & (nb < seq_len)
            new_F_mask[nb[valid_nb]] = True
        F_mask = new_F_mask

        # ----------- keep only still-masked positions -----------
        F_mask &= (x[0] == mask_id)

        # ----------- cap the frontier at F_max -----------
        if F_mask.sum() > F_max:
            pruned_scores = torch.full_like(top1_conf[0], -1e9)
            pruned_scores[F_mask] = top1_conf[0, F_mask]
            _, keep_idx = torch.topk(pruned_scores, k=F_max)
            new_mask = torch.zeros_like(F_mask)
            new_mask[keep_idx] = True
            F_mask = new_mask

    return x, step

@torch.no_grad()
def generate_in_sentence(model,
                         prompt,
                         tokenizer,
                         steps: int = 1024,
                         gen_length: int = 1024,
                         temperature: float = 0.,
                         cfg_scale: float = 0.,
                         mask_id=126336,
                         eos_id=126081):
    """Decode one token per step inside the current "sentence" window.

    Each step runs the model once and locates the first still-masked
    position. The window extends from there up to and including the first
    position whose predicted token is a boundary token (``","``, ``"."``,
    ``"\\n"`` or ``eos_id``), or 8 positions if no boundary is predicted.
    Within the window, the masked position with the highest argmax
    probability is committed. After the first generated ``eos_id``, all later
    positions are set to ``eos_id``.

    Args:
        model: Callable with a ``.device`` attribute; ``model(ids).logits`` is
            used as per-position vocabulary logits.
        prompt: Token-id tensor indexed as (1, prompt_len).
        tokenizer: Provides ``encode(str)``; the first id of each encoding is
            taken as the boundary token id.
            NOTE(review): if the tokenizer prepends a special token, that id
            would be picked up instead -- confirm against the tokenizer used.
        steps: Upper bound on the number of model calls.
        gen_length: Number of positions to generate.
        temperature: Accepted but unused; decoding here is always greedy.
        cfg_scale: Classifier-free-guidance scale; values > 0 add a second,
            prompt-masked forward pass per step.
        mask_id: Token id that marks a masked position.
        eos_id: Token id that ends generation.

    Returns:
        ``(tokens, total_step)``: the generated region ``x[:, prompt_len:]``
        and the number of model calls that were made.
    """
    device = model.device
    prompt_len = prompt.shape[1]
    seq_len = prompt_len + gen_length

    x = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = prompt.clone()
    # Must have the shape of x (not of prompt): it indexes the full-length
    # un_x in the CFG branch below.
    prompt_index = (x != mask_id)

    # Token ids that close a sentence window.
    boundary_tokens = {
        tokenizer.encode(",")[0],
        tokenizer.encode(".")[0],
        tokenizer.encode("\n")[0],
        eos_id,
    }

    total_step = 0
    for _ in range(steps):
        mask_index = (x == mask_id)
        if not mask_index.any():
            break  # every token has been denoised
        # Count only iterations that actually call the model.
        total_step += 1

        if cfg_scale > 0.:
            un_x = x.clone()
            un_x[prompt_index] = mask_id
            x_ = torch.cat([x, un_x], dim=0)
            logits = model(x_).logits
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
        else:
            logits = model(x).logits

        p = F.softmax(logits, dim=-1)
        pred_token = torch.argmax(p, dim=-1)
        conf = torch.gather(p, -1, pred_token.unsqueeze(-1)).squeeze(-1)

        # First position that is still masked.
        start_pos = int(torch.nonzero(mask_index[0], as_tuple=True)[0][0])

        # The window ends right after the first predicted boundary token;
        # fall back to a fixed width of 8 when none is predicted.
        predicted = pred_token[0].tolist()
        boundary_pos = next(
            (i + 1 for i in range(start_pos, seq_len) if predicted[i] in boundary_tokens),
            start_pos + 8,
        )
        end_pos = min(boundary_pos, seq_len)

        # Commit the most confident masked position in the window (the first
        # one on ties). start_pos itself is masked, so there is always one.
        window_masked = mask_index[0, start_pos:end_pos].tolist()
        window_conf = conf[0, start_pos:end_pos].tolist()
        candidates = [off for off, is_masked in enumerate(window_masked) if is_masked]
        best_pos = start_pos + max(candidates, key=window_conf.__getitem__)
        x[0, best_pos] = pred_token[0, best_pos]

        # Only look for EOS in the generated region: an EOS id inside the
        # prompt must not truncate the output.
        eos_hits = (x[0, prompt_len:] == eos_id).nonzero(as_tuple=True)[0]
        if eos_hits.numel() > 0:
            x[:, prompt_len + int(eos_hits[0]) + 1:] = eos_id

    return x[:, prompt_len:], total_step

@torch.no_grad()
def generate_before(model,
                    prompt,
                    tokenizer,
                    steps: int = 1024,
                    gen_length: int = 1024,
                    block_length: int = 8,
                    temperature: float = 0.,
                    cfg_scale: float = 0.,
                    remasking: str = 'low_confidence',
                    mask_id=126336,
                    eos_id=126081):
    """Block-wise decoding variant with position-ordered remasking strategies.

    Differences from ``generate``:

    * Before decoding, the prompt tokens are replaced by the model's own
      argmax predictions at the prompt positions.
    * ``remasking`` additionally accepts ``'binary_search'`` and ``'bst'``,
      which commit block positions in a fixed, precomputed order.
    * After a block's scheduled steps, a fallback pass fills any position of
      the block that is still masked, one token per step: each masked position
      is tried on its own (rest of the block and everything after it masked)
      and the candidate whose proposed token gets the highest logit is kept.

    Args:
        model: Callable with a ``.device`` attribute; ``model(ids).logits`` is
            used as per-position vocabulary logits.
        prompt: Token-id tensor indexed as (1, prompt_len).
        tokenizer: Unused here; kept so all generators share one signature.
        steps: Number of scheduled model calls; must be a multiple of the
            number of blocks.
        gen_length: Number of tokens to generate; must be a multiple of
            ``block_length``.
        block_length: Size of each block.
        temperature: Gumbel-noise temperature; 0 means greedy argmax.
        cfg_scale: Classifier-free-guidance scale; values > 0 add a second,
            prompt-masked forward pass per step.
        remasking: ``'low_confidence'``, ``'random'``, ``'auto_regression'``,
            ``'binary_search'`` or ``'bst'``.
        mask_id: Token id that marks a masked position.
        eos_id: Token id that ends generation.

    Returns:
        ``(tokens, total_step)``: the generated region ``x[:, prompt_len:]``
        and the number of denoising steps that were run (the initial prompt
        reconstruction pass is not counted).

    Raises:
        AssertionError: If the divisibility requirements above are violated.
        NotImplementedError: For an unknown ``remasking`` value.
    """
    device = model.device
    prompt_len = prompt.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = prompt.clone()
    # Full-length mask of the prompt positions, taken before the prompt is
    # rewritten below; used by the CFG branch.
    prompt_index = (x != mask_id)

    # Replace the prompt by the model's argmax reconstruction of it.
    prompt_logits = model(x).logits
    fixed_prompt = torch.argmax(prompt_logits, dim=-1)
    x[:, :prompt_len] = fixed_prompt[:, :prompt_len]

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    # NOTE(review): 'bst' selects get_ordered_positions and every other value
    # (including 'binary_search') selects get_preorder_bst_positions; the
    # names suggest these may be swapped -- confirm the intent.
    position_order = get_ordered_positions(block_length) if remasking == "bst" else get_preorder_bst_positions(block_length)
    total_step = 0

    for block_idx in range(num_blocks):
        block_start = prompt_len + block_idx * block_length
        block_end = block_start + block_length

        block_mask_index = (x[:, block_start: block_end] == mask_id)
        # The schedule has to span exactly the steps this block receives.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block).tolist()

        # ---------------- scheduled pass ----------------
        for step_idx in range(steps_per_block):
            total_step += 1
            mask_index = (x == mask_id)
            k = num_transfer_tokens[step_idx]

            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            noisy = add_gumbel_noise(logits, temperature)
            x0 = torch.argmax(noisy, dim=-1)

            if remasking == 'binary_search' or remasking == 'bst':
                x0_p = torch.full((x0.shape[0], x0.shape[1]), -float('inf'), device=x0.device)
                # Everything scheduled up to and including this step is
                # eligible, following the precomputed position order.
                revealed = sum(num_transfer_tokens[:step_idx + 1])
                for pos in position_order[:revealed]:
                    x0_p[:, block_start + pos] = 1.0
            elif remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            elif remasking == 'auto_regression':
                x0_p = torch.full((x0.shape[0], x0.shape[1]), -float('inf'), device=x0.device)
                for j in range(x0.shape[0]):
                    # Column offsets (within the block) of row j's masked
                    # positions, left to right.
                    masked_cols = torch.nonzero(
                        mask_index[j, block_start: block_end], as_tuple=True)[0]
                    x0_p[j, block_start + masked_cols[:k]] = 1.0
            else:
                raise NotImplementedError(remasking)

            # Never commit tokens beyond the current block.
            x0_p[:, block_end:] = -np.inf
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=k)
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

        # ---------------- fallback pass ----------------
        # Only does work if the scheduled pass left masked positions behind.
        for _ in range(steps_per_block):
            # Positions of the current block that are still masked.
            block_mask = (x[0, block_start:block_end] == mask_id)
            if not block_mask.any():
                break
            # Count only iterations that actually call the model.
            total_step += 1

            # Proposal tokens from the current state. x does not change while
            # the candidates are scored, so one forward pass suffices and the
            # committed token is exactly the one that was scored.
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[:, :prompt_len] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits_full = model(x_).logits  # (2, seq_len, vocab)
                logits_cond, logits_uncond = torch.chunk(logits_full, 2, dim=0)
                logits_full = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
            else:
                logits_full = model(x).logits  # (1, seq_len, vocab)
            x0_full = torch.argmax(add_gumbel_noise(logits_full, temperature), dim=-1)

            # Shared starting point for all candidates: the whole block and
            # everything after it masked.
            base_candidate = x.clone()
            base_candidate[:, block_start:] = mask_id

            candidate_logits = []
            candidate_positions = []
            for offset in torch.nonzero(block_mask, as_tuple=True)[0].tolist():
                col = block_start + offset
                x_candidate = base_candidate.clone()
                x_candidate[:, col] = x0_full[:, col]

                logits_candidate = model(x_candidate).logits  # (1, seq_len, vocab)

                token_id = x_candidate[:, col].unsqueeze(-1)  # (1, 1)
                token_logits = torch.gather(logits_candidate[:, col, :], -1, token_id).squeeze(-1)  # (1,)

                candidate_logits.append(token_logits)
                candidate_positions.append(col)

            # Commit the candidate whose proposed token scored the highest logit.
            best = int(torch.argmax(torch.cat(candidate_logits)))
            best_pos = candidate_positions[best]
            x[:, best_pos] = x0_full[:, best_pos]

        if eos_id in x[:, prompt_len:block_end]:
            x[:, block_end:] = eos_id
            break
    return x[:, prompt_len:], total_step

def main():
    """Load LLaDA-8B-Instruct, decode one code-completion prompt and print it.

    Requires a CUDA device and access to the ``GSAI-ML/LLaDA-8B-Instruct``
    checkpoint. Prints the wall-clock time of the generation call, the number
    of decoding steps and the decoded output.
    """
    device = 'cuda'

    model = AutoModel.from_pretrained(
        'GSAI-ML/LLaDA-8B-Instruct',
        trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
    tokenizer.pad_token_id = 126336

    prompt = '''\n\nfrom typing import List\n\n\ndef factorize(n: int) -> List[int]:\n    """ Return list of prime factors of given integer in the order from smallest to largest.\n    Each of the factors should be listed number of times corresponding to how many times it appeares in factorization.\n    Input number should be equal to the product of all factors\n    >>> factorize(8)\n    [2, 2, 2]\n    >>> factorize(25)\n    [5, 5]\n    >>> factorize(70)\n    [2, 5, 7]\n    """'''

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}, ]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    start_time = time.perf_counter()
    # The generator needs the token-id tensor, not the prompt string.
    # Other remasking strategies (e.g. remasking='binary_search') can be
    # passed as a keyword argument.
    out, total_steps = generate_before(
        model,
        input_ids,
        tokenizer,
        steps=512,
        gen_length=512,
        block_length=8,
        temperature=0.0,
        cfg_scale=0.,
    )

    print(f"Cost {time.perf_counter() - start_time}")
    print(f"Total steps: {total_steps}")
    print(tokenizer.decode(out[0]))

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
