import random, os, sys
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import torch
import argparse
from tqdm import tqdm
# Resolve this script's absolute path, the directory that contains it, and
# that directory's parent.
current_script_path = os.path.abspath(__file__)
scripts_dir = os.path.dirname(current_script_path)
project_root = os.path.dirname(scripts_dir)
# Put the parent directory at the front of sys.path (only once), presumably so
# the `utils` and `src` packages imported right below resolve when this file is
# executed directly as a script.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from utils.eval_utils import query_extract, load_dataset, eval, countdown_check
from src.llama_template import llama_prompt

# Seed Python's stdlib `random` module for repeatable runs. Only this RNG is
# seeded here; no torch seed is set in the visible part of this file.
random.seed(42)

def _resolve_mask_id(model, tokenizer):
    """Look up the mask-token id to use for decoding.

    Sources are tried in order and the first one that yields a non-None id
    wins: ``model.config.mask_token_id``, ``tokenizer.mask_token_id``, the
    ``'<|mask|>'`` entry of ``tokenizer.vocab``, then the first of
    ``'<|mask|>'``, ``'[MASK]'``, ``'<mask>'`` present in
    ``tokenizer.get_vocab()``. When nothing matches, 126336 is returned
    (the default this file uses for LLaDA).
    """
    # A config attribute that exists but is None must not short-circuit the
    # tokenizer/vocab probes, so test the value rather than mere presence.
    config_mask_id = getattr(getattr(model, 'config', None), 'mask_token_id', None)
    if config_mask_id is not None:
        return config_mask_id

    tokenizer_mask_id = getattr(tokenizer, 'mask_token_id', None)
    if tokenizer_mask_id is not None:
        return tokenizer_mask_id

    vocab_attr = getattr(tokenizer, 'vocab', None)
    if vocab_attr is not None and '<|mask|>' in vocab_attr:
        return vocab_attr['<|mask|>']

    if hasattr(tokenizer, 'get_vocab'):
        vocab = tokenizer.get_vocab()
        for token in ('<|mask|>', '[MASK]', '<mask>'):
            if token in vocab:
                return vocab[token]

    return 126336  # LLaDA default


def generate(model, tokenizer, input, task, steps, gen_length, block_length, temperature, mode, lambd, alpha, baseline_name, thread, gamma, num_remask_tokens, position_temperature=0.1, candidate_number=8, heuristic='confidence', return_cumulative_entropy=False, tokens_per_step=None, mask_id=None, is_dream=False, result_path=None, use_shot=True):
    """Build the chat prompt for one sample and decode it with the chosen sampler.

    Args:
        model: Model handed to the ``src.generate`` samplers; its ``device`` is
            used to place the prompt tensor.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``decode``,
            ``encode`` and ``batch_decode``.
        input: One dataset record. For ``task == 'sudoku'`` its ``'Puzzle'``
            entry is read as 16 characters, with ``'0'`` marking an empty cell.
        task: Task name forwarded to ``query_extract``; ``'sudoku'`` enables
            the in-prompt answer template described below.
        steps, gen_length, block_length, temperature: Forwarded to the sampler
            (overridden for sudoku, see below).
        mode: One of ``'original'``, ``'pc_sampler'``, ``'eb_sampler'``,
            ``'fast_dllm'``, ``'entropy'``, ``'margin'``, ``'info-gain'``.
        lambd, alpha, baseline_name: Forwarded in ``'pc_sampler'`` and
            ``'info-gain'`` modes. ``'original'``/``'entropy'``/``'margin'``
            use ``lambd=0.0, alpha=100`` and forward only ``baseline_name``.
        thread: Passed as ``threshold`` in ``'fast_dllm'`` mode.
        gamma: Forwarded in ``'eb_sampler'`` mode.
        num_remask_tokens: Accepted for interface compatibility; not used here.
        position_temperature, candidate_number, heuristic, tokens_per_step:
            Forwarded in ``'info-gain'`` mode only.
        return_cumulative_entropy: In ``'info-gain'`` mode, also return the
            cumulative entropy reported by the sampler.
        mask_id: Mask-token id; auto-detected from the model/tokenizer when None.
        is_dream: Forwarded to every sampler.
        result_path: Forwarded as ``save_monotone_residual_path`` in
            ``'info-gain'`` mode.
        use_shot: Forwarded to ``query_extract``.

    Returns:
        The decoded answer string, or ``(answer, cumulative_entropy)`` when
        ``return_cumulative_entropy`` is true and ``mode == 'info-gain'``.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """
    query = query_extract(input, task, use_shot=use_shot)
    messages = [{"role": "user", "content": query}]
    user_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # Use the caller's mask_id when given, otherwise auto-detect it.
    if mask_id is None:
        mask_id = _resolve_mask_id(model, tokenizer)

    # Text form of the mask token, used to build the sudoku template.
    mask_token = tokenizer.decode([mask_id])

    # Sudoku special case: append an answer template directly after the prompt.
    # The template is 4 newline-separated rows of 4 cells; given digits are
    # kept and every '0' cell is replaced by the mask token.
    prefilled_positions = None
    sudoku_answer_len = 0
    original_prompt_len = 0  # prompt length in tokens, excluding the template

    if task == 'sudoku':
        puzzle_str = input['Puzzle']  # 16-character digit string

        # Measured before the template is appended. The answer slice at the
        # end assumes tokenizing prompt+template keeps these prefix tokens
        # unchanged -- TODO confirm for the tokenizer in use.
        original_prompt_len = len(tokenizer(user_input)['input_ids'])

        answer_lines = [
            ''.join(
                mask_token if puzzle_str[row * 4 + col] == '0' else puzzle_str[row * 4 + col]
                for col in range(4)
            )
            for row in range(4)
        ]
        answer_lines.append("</answer>")
        answer_template = '\n'.join(answer_lines)

        user_input = user_input + answer_template

        # Token length of the template alone (not of the whole input).
        sudoku_answer_len = len(tokenizer.encode(answer_template, add_special_tokens=False))

        # Nothing is generated past the prompt: the answer lives inside it.
        gen_length = 0

    prompt = tokenizer(user_input)['input_ids']
    prompt = torch.tensor(prompt).to(model.device).unsqueeze(0)

    # For sudoku, one decoding step per masked cell, all in a single block.
    if task == 'sudoku':
        num_masks = (prompt == mask_id).sum().item()
        steps = num_masks
        block_length = num_masks

    cumulative_entropy = None
    # Modes that are the base sampler with candidate_number=1 and a fixed heuristic.
    fixed_heuristics = {'original': 'confidence', 'entropy': 'neg_entropy', 'margin': 'margin'}

    if mode in fixed_heuristics:
        # Aliased so the import does not shadow this function's own name.
        from src.generate import generate as base_generate
        out = base_generate(model, prompt, steps, gen_length, block_length, lambd=0.0, alpha=100, baseline_name=baseline_name, temperature=temperature, cfg_scale=0., remasking='low_confidence', mask_id=mask_id, candidate_number=1, position_temperature=0.0, heuristic=fixed_heuristics[mode], prefilled_positions=prefilled_positions, is_dream=is_dream)
    elif mode == 'pc_sampler':
        # Base sampler with the PC heuristic and the caller's lambd/alpha.
        # Arguments stay positional exactly as the original call passed them.
        from src.generate import generate as base_generate
        out = base_generate(model, prompt, steps, gen_length, block_length, lambd, alpha, baseline_name, temperature, cfg_scale=0., remasking='low_confidence', mask_id=mask_id, candidate_number=1, position_temperature=0.0, heuristic='pc', prefilled_positions=prefilled_positions, is_dream=is_dream)
    elif mode == 'eb_sampler':
        from src.generate import generate_with_eb_sampler
        out = generate_with_eb_sampler(model, prompt, gamma, gen_length, temperature, cfg_scale=0., mask_id=mask_id, is_dream=is_dream)
    elif mode == 'fast_dllm':
        from src.generate import generate_with_fast_dllm
        # Only the first element of the returned value is the token tensor used here.
        out = generate_with_fast_dllm(model, prompt, steps, gen_length, block_length, temperature, remasking='low_confidence', mask_id=mask_id, threshold=thread, is_dream=is_dream)[0]
    elif mode == 'info-gain':
        from src.generate import generate_with_info_gain
        result = generate_with_info_gain(model, prompt, steps, gen_length, block_length, temperature, cfg_scale=0., remasking='low_confidence', mask_id=mask_id,
                               position_temperature=position_temperature, candidate_number=candidate_number,
                               lambd=lambd, alpha=alpha, baseline_name=baseline_name,
                               prefilled_positions=prefilled_positions, heuristic=heuristic,
                               return_cumulative_entropy=return_cumulative_entropy,
                               tokens_per_step=tokens_per_step, is_dream=is_dream,
                               save_monotone_residual_path=result_path)
        if return_cumulative_entropy:
            out, cumulative_entropy = result
        else:
            out = result
    else:
        raise NotImplementedError(f"Mode {mode} not implemented.")

    # Extract the answer text.
    if task == 'sudoku' and sudoku_answer_len > 0:
        # Sudoku: the answer is the template region at the tail of the prompt.
        answer = tokenizer.batch_decode(out[:, original_prompt_len:], skip_special_tokens=True)[0]
    else:
        answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]

    if return_cumulative_entropy and mode == 'info-gain':
        return answer, cumulative_entropy
    return answer

def llama_generate(model, tokenizer, input, task, gen_length, use_shot=True):
    """Greedy-decode an answer using the llama prompt template.

    The query built by ``query_extract`` is wrapped with ``llama_prompt``,
    tokenized, and passed to ``model.generate`` with sampling disabled and a
    total length cap of prompt length plus ``gen_length``. Only the tokens
    after the prompt are decoded and returned as a string.
    """
    rendered_prompt = llama_prompt(query_extract(input, task, use_shot=use_shot), task)
    token_ids = tokenizer(rendered_prompt)['input_ids']
    # Add a leading batch dimension of size 1.
    prompt = torch.tensor(token_ids).to(model.device).unsqueeze(0)
    prompt_len = prompt.shape[1]
    generated = model.generate(inputs=prompt, max_length=gen_length + prompt_len, do_sample=False)
    decoded = tokenizer.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)
    return decoded[0]

def mistral_generate(model, tokenizer, input, task, gen_length, use_shot=True):
    """Greedy-decode an answer through the tokenizer's chat template.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``batch_decode``.
        input: One dataset record. For ``task == 'mbpp'`` its ``'prompt'`` and
            ``'entry_point'`` entries are read.
        task: Task name forwarded to ``query_extract``; ``'mbpp'`` enables
            code-fence post-processing.
        gen_length: Passed to ``model.generate`` as ``max_new_tokens``.
        use_shot: Forwarded to ``query_extract``.

    Returns:
        The decoded completion. For ``'mbpp'`` it is the extracted code,
        prefixed with the record's prompt in which ``entry_point`` has been
        renamed to ``entry_point + '_prompt'``.
    """
    query = query_extract(input, task, use_shot=use_shot)
    conversation = [{"role": "user", "content": query}]
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )

    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt = inputs['input_ids']
    out = model.generate(**inputs, max_new_tokens=gen_length, do_sample=False)
    answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
    if task == 'mbpp':
        _, fence, fenced = answer.partition('```python')
        if fence:
            # Keep only the body of the first ```python block.
            answer = fenced.split('```')[0]
        else:
            # No ```python fence in the completion: indexing the split result
            # used to raise IndexError here. Fall back to stripping any bare
            # fences, as Qwen_25_generate does.
            answer = answer.replace('```', '\n')
        answer = input['prompt'].replace(input['entry_point'], input['entry_point']+'_prompt') + answer
    return answer

def Qwen_25_generate(model, tokenizer, input, task, gen_length, use_shot=True):
    """Greedy-decode an answer via the tokenizer's chat template (text form).

    The chat template is rendered to a string, tokenized as a batch of one,
    and decoded with ``max_new_tokens=gen_length`` and sampling disabled. For
    ``task == 'mbpp'`` code fences in the completion are replaced by newlines
    and the result is appended to the record's prompt, in which
    ``entry_point`` has been renamed to ``entry_point + '_prompt'``.
    """
    chat = [{"role": "user", "content": query_extract(input, task, use_shot=use_shot)}]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([rendered], return_tensors="pt").to(model.device)
    generated = model.generate(**model_inputs, max_new_tokens=gen_length, do_sample=False)

    # Drop the prompt tokens so only the completion is decoded.
    prompt_len = model_inputs['input_ids'].shape[1]
    answer = tokenizer.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)[0]
    if task != 'mbpp':
        return answer

    code_body = answer.replace('```python', '\n').replace('```', '\n')
    renamed_prompt = input['prompt'].replace(input['entry_point'], input['entry_point'] + '_prompt')
    return renamed_prompt + code_body

def main(args):
    """Run one evaluation: load data and model, answer every sample, score.

    Args:
        args: Namespace providing ``task``, ``model_name``, ``device``,
            ``gen_length``, ``steps``, ``block_length``, ``temperature``,
            ``mode``, ``lambd``, ``alpha``, ``baseline_name``, ``thread``,
            ``gamma``, ``num_remask_tokens``, ``data_path``, ``result_path``
            and ``no_shot``. ``position_temperature``, ``candidate_number``,
            ``heuristic`` and ``tokens_per_step`` are optional.

    Side effects:
        Prints progress and per-sample results, calls ``eval`` with the
        collected answers, and (when cumulative entropies were collected)
        appends a statistics section to ``result_path``.
    """
    task = args.task
    model_name = args.model_name
    device = args.device
    gen_length = args.gen_length
    steps = args.steps
    block_length = args.block_length
    temperature = args.temperature
    mode = args.mode
    lambd = args.lambd
    alpha = args.alpha
    baseline_name = args.baseline_name
    thread = args.thread
    gamma = args.gamma
    num_remask_tokens = args.num_remask_tokens
    data_path = args.data_path
    result_path = args.result_path
    use_shot = not args.no_shot  # --no_shot disables the few-shot examples

    dataset = load_dataset(data_path, task)

    print('----------------- Load model -------------------')

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Pick the loader class from the model name (hub id or local path):
    # LLaDA and Dream use AutoModel, everything else AutoModelForCausalLM.
    model_name_lower = model_name.lower()
    # NOTE(review): 'llama' is matched here on purpose per the original comment
    # ("local paths may contain llama"), but it also sends real Llama
    # checkpoints through AutoModel while the loop below routes them to
    # llama_generate -- confirm this is intended.
    is_llada = 'llada' in model_name_lower or 'llama' in model_name_lower
    is_dream = 'dream' in model_name_lower

    model_cls = AutoModel if (is_llada or is_dream) else AutoModelForCausalLM
    model = model_cls.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    model.eval()

    # Resolve mask_id: model.config first (Dream), then the tokenizer, then
    # the vocabulary, then the LLaDA default. A config attribute that exists
    # but is None must fall through to the next source.
    mask_id = None
    config_mask_id = getattr(getattr(model, 'config', None), 'mask_token_id', None)
    if config_mask_id is not None:
        mask_id = config_mask_id
        print(f"[INFO] mask_id from model.config: {mask_id}")
    elif getattr(tokenizer, 'mask_token_id', None) is not None:
        mask_id = tokenizer.mask_token_id
        print(f"[INFO] mask_id from tokenizer: {mask_id}")
    else:
        if hasattr(tokenizer, 'get_vocab'):
            vocab = tokenizer.get_vocab()
            for token in ['<|mask|>', '[MASK]', '<mask>']:
                if token in vocab:
                    mask_id = vocab[token]
                    print(f"[INFO] mask_id from vocab ('{token}'): {mask_id}")
                    break
        if mask_id is None:
            mask_id = 126336
            print(f"[INFO] Using default mask_id (LLaDA): {mask_id}")

    print(f"[INFO] Model type: {'Dream' if is_dream else 'LLaDA' if is_llada else 'Other'}")
    print(f"[INFO] is_dream={is_dream}, mask_id={mask_id}")

    print('----------------- Start Answering -------------------')

    results = []
    cumulative_entropies = []  # one cumulative entropy per sample, when available

    # Cumulative-entropy reporting is requested only in info-gain mode.
    return_ce = (mode == 'info-gain')

    for idx, sample in enumerate(tqdm(dataset, desc="Generating")):
        cumulative_entropy = None

        if 'llama' in model_name:
            answer = llama_generate(model, tokenizer, sample, task, gen_length, use_shot=use_shot)
        elif 'mistral' in model_name:
            answer = mistral_generate(model, tokenizer, sample, task, gen_length, use_shot=use_shot)
        elif 'Qwen2.5' in model_name:
            answer = Qwen_25_generate(model, tokenizer, sample, task, gen_length, use_shot=use_shot)
        else:
            # Give every sample its own path for the monotone-residual dump.
            sample_monotone_path = None
            if mode == 'info-gain' and result_path:
                result_dir = os.path.dirname(result_path) or '.'
                result_basename = os.path.basename(result_path)
                sample_monotone_path = os.path.join(result_dir, f"{result_basename}_sample_{idx}_monotone_residual.json")
            # NOTE(review): the candidate_number fallback of 0 differs from the
            # command-line default of 8; kept as in the original.
            result = generate(model, tokenizer, sample, task, steps, gen_length, block_length, temperature, mode, lambd, alpha, baseline_name, thread, gamma, num_remask_tokens,
                            position_temperature=getattr(args, 'position_temperature', 0.1),
                            candidate_number=getattr(args, 'candidate_number', 0),
                            heuristic=getattr(args, 'heuristic', 'confidence'),
                            return_cumulative_entropy=return_ce,
                            tokens_per_step=getattr(args, 'tokens_per_step', None),
                            mask_id=mask_id,
                            is_dream=is_dream,
                            result_path=sample_monotone_path,
                            use_shot=use_shot)
            if return_ce and isinstance(result, tuple):
                answer, cumulative_entropy = result
            else:
                answer = result

        # Show the generated result.
        print(f"\n{'='*80}")
        print(f"Sample {idx + 1}/{len(dataset)}")
        print(f"{'='*80}")
        if task == 'sudoku':
            print(f"Puzzle: {sample.get('Puzzle', 'N/A')}")
            print(f"Generated Answer: {answer}")
            print(f"Expected Solution: {sample.get('Solution', 'N/A')}")
        elif task == 'countdown':
            print(f"Input: {sample.get('input', 'N/A')}")
            print(f"Generated Answer: {answer}")
            print(f"Expected Output: {sample.get('output', 'N/A')}")
            # Print the per-sample verdict. The target is the last
            # comma-separated field; ''.split(',') yields [''], so a missing
            # or empty field has to be caught at the int() conversion.
            input_nums = sample.get('input', '').split(',')
            try:
                target = int(input_nums[-1])
            except ValueError:
                target = None
            is_correct = countdown_check(answer, sample.get('output', ''), target=target)
            print(f">>> Evaluation: {'✓ CORRECT' if is_correct else '✗ WRONG'}")
        else:
            print(f"Generated Answer: {answer[:200]}..." if len(answer) > 200 else f"Generated Answer: {answer}")

        # Report the cumulative entropy as it becomes available.
        if cumulative_entropy is not None:
            print(f">>> Cumulative Entropy: {cumulative_entropy:.4f}")
            cumulative_entropies.append(cumulative_entropy)
            # Running mean over the samples seen so far.
            avg_entropy = sum(cumulative_entropies) / len(cumulative_entropies)
            print(f">>> Average Cumulative Entropy (so far): {avg_entropy:.4f} (from {len(cumulative_entropies)} samples)")
        else:
            print(f">>> Cumulative Entropy: N/A (not available for this mode)")
        print(f"{'='*80}\n")

        results.append(answer)

    eval(task, results, dataset, result_path, args)

    # Append the cumulative-entropy statistics to the result file.
    if cumulative_entropies:
        with open(result_path, 'a', encoding='utf-8') as file:
            file.write("\n----------------- Cumulative Entropy Statistics -------------------\n")
            file.write(f"Total Samples: {len(cumulative_entropies)}\n")
            file.write(f"Mean Cumulative Entropy: {sum(cumulative_entropies) / len(cumulative_entropies):.4f}\n")
            file.write(f"Min Cumulative Entropy: {min(cumulative_entropies):.4f}\n")
            file.write(f"Max Cumulative Entropy: {max(cumulative_entropies):.4f}\n")
            file.write(f"All Cumulative Entropies: {cumulative_entropies}\n")

    print('----------------- Finish -------------------')

if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()

    # Options that need nothing beyond a type and a default, in display order.
    for flag, kind, default in (
        ('--task', str, 'humaneval'),
        ('--model_name', str, 'GSAI-ML/LLaDA-8B-Instruct'),
        ('--device', str, 'cuda:0'),
        ('--gen_length', int, 256),
        ('--steps', int, 256),
        ('--block_length', int, 32),
        ('--temperature', float, 0.),
        ('--mode', str, 'original'),
        ('--lambd', float, 0.25),
        ('--alpha', float, 100),
        ('--baseline_name', str, '../data/baseline/reference_corpus.json'),
        ('--thread', float, 0.9),
        ('--gamma', float, 0.01),
        ('--num_remask_tokens', int, 10),
    ):
        parser.add_argument(flag, type=kind, default=default)

    # Info-Gain sampler options. Per their help texts: position sampling
    # temperature and candidate count (small values fall back to PC-Sampler),
    # the heuristic function, and the number of tokens decoded per step.
    parser.add_argument(
        '--position_temperature',
        type=float,
        default=0.1,
        help='位置采样温度，<=0 时退化为 PC-Sampler',
    )
    parser.add_argument(
        '--candidate_number',
        type=int,
        default=8,
        help='候选动作数量，<=1 时退化为 PC-Sampler',
    )
    parser.add_argument(
        '--heuristic',
        type=str,
        default='confidence',
        choices=['pc', 'confidence', 'neg_entropy', 'margin', 'uniform'],
        help='启发函数类型（仅 Info-Gain Sampler 模式）: pc(PC值), confidence(置信度), neg_entropy(负熵), margin(边际), uniform(随机)',
    )
    parser.add_argument(
        '--tokens_per_step',
        type=int,
        default=None,
        help='每步解码的 token 数量 (K)。若设置，则 steps = num_masks // K',
    )

    # Flag that turns off the few-shot examples, then the input/output paths.
    parser.add_argument(
        '--no_shot',
        action='store_true',
        help='不使用 shot（示例），仅使用问题本身',
    )
    parser.add_argument('--data_path', type=str, default='./data/humaneval.jsonl')
    parser.add_argument('--result_path', type=str, default='../results/humaneval_results')

    args = parser.parse_args()
    main(args)