import argparse
import json
import os
import random
from typing import Dict, Any

import torch
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

from data import get_dataset
from dllm_generation import dllm_original_generation
from extract_judge_answer import extract_answer, extract_true_answer, judge_answer


def parse_args(argv=None):
    """Parse command-line options for the DLLM (original generation) evaluation.

    Args:
        argv: Optional sequence of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic invocation).

    Returns:
        argparse.Namespace holding core, common-generation, LLaDA-specific
        and DREAM-specific options.
    """
    parser = argparse.ArgumentParser(description="Evaluate diffusion language models (original generation)")
    # Core
    parser.add_argument("--dataset", type=str, default="openai/gsm8k", help="Dataset to evaluate")
    # Both are mandatory: without a model nothing can be loaded, and without a
    # backend type the generation kwargs cannot be chosen.
    parser.add_argument("--model_name_or_path", type=str, required=True, help="HF model id or local path")
    parser.add_argument("--dllm_type", type=str, required=True, choices=["llada", "dream"], help="DLLM backend type")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to store outputs")
    parser.add_argument("--start_data_idx", type=int, default=0, help="Start index (inclusive)")
    parser.add_argument("--end_data_idx", type=int, default=1319, help="End index (exclusive)")
    parser.add_argument("--solver_prompt_idx", type=int, default=0, help="Which solver prompt to use")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--device", type=str, default=None,
                        help="Torch device (default: cuda if available, else cpu)")

    # Common gen params
    parser.add_argument("--steps", type=int, default=128, help="Sampling steps for DLLM")
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=0.0)

    # LLaDA-specific
    parser.add_argument("--llada_block_length", type=int, default=32)
    parser.add_argument("--llada_cfg_scale", type=float, default=0.0)
    parser.add_argument("--llada_remasking", type=str, default="low_confidence", choices=["low_confidence", "random"])
    parser.add_argument("--llada_mask_id", type=int, default=126336)

    # DREAM-specific
    parser.add_argument("--dream_top_p", type=float, default=0.95)
    parser.add_argument("--dream_alg", type=str, default="entropy")
    parser.add_argument("--dream_alg_temp", type=float, default=0.0)

    return parser.parse_args(argv)


def set_seed(seed: int):
    """Seed the random number generators used during evaluation.

    Seeds Python's ``random`` module, torch's default generator and all CUDA
    devices. Also forces deterministic cuDNN kernels and disables cuDNN autotuning.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. subprocesses). The current process's str-hash randomization was
    # fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)


def build_gen_kwargs(args) -> Dict[str, Any]:
    """Assemble backend-specific keyword arguments for DLLM generation.

    Args:
        args: Parsed CLI namespace. Reads ``dllm_type``, the common options
            (``steps``, ``max_new_tokens``, ``temperature``), and either the
            ``llada_*`` or the ``dream_*`` options for the selected backend.

    Returns:
        Dict of generation kwargs.

        - LLaDA: adds ``block_length``, ``cfg_scale``, ``remasking`` and ``mask_id``.
        - DREAM: adds ``top_p``, ``alg`` and ``alg_temp``.

    Raises:
        ValueError: If ``args.dllm_type`` is neither ``"llada"`` nor ``"dream"``.
    """
    if args.dllm_type == "llada":
        return dict(
            steps=args.steps,
            max_new_tokens=args.max_new_tokens,
            block_length=args.llada_block_length,
            temperature=args.temperature,
            cfg_scale=args.llada_cfg_scale,
            remasking=args.llada_remasking,
            mask_id=args.llada_mask_id,
        )
    if args.dllm_type == "dream":
        return dict(
            steps=args.steps,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.dream_top_p,
            alg=args.dream_alg,
            alg_temp=args.dream_alg_temp,
        )
    # Reject anything else (including None) rather than silently using the
    # DREAM parameters for an unknown backend.
    raise ValueError(f"Unsupported dllm_type: {args.dllm_type!r} (expected 'llada' or 'dream')")


def _write_json(path: str, obj: Any, **dump_kwargs) -> None:
    """Write ``obj`` as UTF-8 JSON to ``path``, overwriting any existing file.

    Extra keyword arguments are forwarded to ``json.dump``.
    """
    # Explicit encoding: with ensure_ascii=False json.dump emits raw non-ASCII
    # characters, which a non-UTF-8 platform default encoding may fail on.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, **dump_kwargs)


def _short_name(name_or_path: str) -> str:
    """Return the last component of a model/dataset id or path, ignoring trailing separators."""
    return os.path.basename(os.path.normpath(name_or_path))


def main(args):
    """Run original DLLM generation over a dataset slice and score the answers.

    The run proceeds in four steps:

    1. Load the tokenizer and model from ``args.model_name_or_path``.
    2. Format the dataset with that tokenizer.
    3. Generate a continuation for every example in
       ``[args.start_data_idx, args.end_data_idx)`` and judge it against the
       extracted ground truth.
    4. Write results under
       ``<output_dir>/dllm-<type>-<model>-<dataset>-SolIdx<idx>/``:

       - ``results.json``: per-example records, rewritten after every example
         so partial progress survives an interrupted run.
       - ``summary.json``: final accuracy, counts, run settings and the
         generation kwargs.

    Examples whose ground-truth answer cannot be extracted are skipped and
    excluded from the accuracy denominator.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    # `is not None` rather than truthiness so that --seed 0 is honoured.
    if args.seed is not None:
        set_seed(args.seed)

    # Resolve generation kwargs before the expensive model load so that an
    # invalid backend configuration fails fast.
    gen_kwargs = build_gen_kwargs(args)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer first so data formatting uses correct chat template
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device).eval()

    # Load dataset with the same tokenizer to format messages properly
    dataset = get_dataset(args.dataset, tokenizer=tokenizer, prompt_idx=args.solver_prompt_idx)
    if len(dataset) > 0:
        print(f"Example: {dataset[0]}")

    model_name = _short_name(args.model_name_or_path)
    data_name = _short_name(args.dataset)
    tag = f"dllm-{args.dllm_type}"
    out_dir = os.path.join(args.output_dir, f"{tag}-{model_name}-{data_name}-SolIdx{args.solver_prompt_idx}")
    os.makedirs(out_dir, exist_ok=True)
    results_path = os.path.join(out_dir, "results.json")

    start = max(0, args.start_data_idx)
    end = min(args.end_data_idx, len(dataset))

    print(f"Using gen params: {gen_kwargs}")

    results_list = []
    correct = 0
    total = 0
    skipped = 0

    # Per-example logs go through tqdm.write so they do not mangle the bar.
    pbar = tqdm(range(start, end), desc="Evaluating", unit="samples")
    for i in pbar:
        example = dataset[i]
        question = example["question"]
        true_answer = extract_true_answer(example["answer"], name=args.dataset)
        if true_answer is None:
            skipped += 1
            continue

        tqdm.write(f"[#{i}] Q: {question}")
        tqdm.write(f"GT: {true_answer}")

        # Original DLLM generation (continuation only); input ids are unused here.
        gen_text, _ = dllm_original_generation(
            dllm_type=args.dllm_type,
            model=model,
            tokenizer=tokenizer,
            device=device,
            input_text=example["formatted"],
            **gen_kwargs,
        )

        tqdm.write(f"Generated: {gen_text}")

        # Extract and judge; an unextractable answer counts as incorrect
        # (judge_answer is only consulted when extraction succeeded).
        extracted = extract_answer(gen_text, data_name=args.dataset, prompt_idx=args.solver_prompt_idx,
                                   model_name=args.model_name_or_path)
        is_correct = extracted is not None and bool(
            judge_answer(gen_text, true_answer, data_name=args.dataset, prompt_idx=args.solver_prompt_idx)
        )

        correct += int(is_correct)
        total += 1

        results_list.append({
            "question": question,
            "true_answer": true_answer,
            "generated": gen_text,
            "extracted_answer": extracted,
            "data_idx": i,
            "correct": is_correct,
        })
        # Rewrite the whole file each step so partial results are on disk.
        _write_json(results_path, results_list, indent=2, ensure_ascii=False)

        # total >= 1 here, so the division is safe.
        pbar.set_postfix({'acc': f"{correct / total:.4f}"})

    final_acc = correct / total if total > 0 else 0.0
    print(f"Final accuracy: {final_acc:.4f} ({correct}/{total})")
    if skipped:
        print(f"Skipped {skipped} example(s) without an extractable ground-truth answer")

    summary = {
        "final_accuracy": final_acc,
        "correct": correct,
        "total": total,
        "skipped": skipped,
        "dllm_type": args.dllm_type,
        "model_name_or_path": args.model_name_or_path,
        "dataset": args.dataset,
        "solver_prompt_idx": args.solver_prompt_idx,
        "start_data_idx": start,
        "end_data_idx": end,
        "gen_kwargs": gen_kwargs,
    }
    _write_json(os.path.join(out_dir, "summary.json"), summary, indent=2)


if __name__ == "__main__":
    # Script entry point: parse the CLI, echo every resolved option so the run
    # log records its configuration, then launch the evaluation.
    args = parse_args()
    for option, value in vars(args).items():
        print(f"-- {option}: {value}")
    main(args)

