import argparse
import json
import os
import random
from typing import Dict, Any

import torch
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

from data import get_dataset
from dllm_generation import dllm_original_generation
from dllm_latent_seek import dllm_latent_seek_generation
from rewards.dllm_reward import DllmRewardModel
from extract_judge_answer import extract_answer, extract_true_answer, judge_answer


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the DLLM Latent-Seek evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse (exit code 2) when a required option is
            missing or a value is not among the declared choices.
    """
    parser = argparse.ArgumentParser(description="DLLM Latent-Seek Optimization")
    # Core
    parser.add_argument("--dataset", type=str, default="openai/gsm8k", help="Dataset to evaluate")
    # The two options below have no usable default, so fail fast at parse time
    # instead of passing None further down the pipeline.
    parser.add_argument("--model_name_or_path", type=str, required=True, help="HF model id or local path")
    parser.add_argument("--dllm_type", type=str, required=True, choices=["llada", "dream"],
                        help="DLLM backend type")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to store outputs")
    parser.add_argument("--start_data_idx", type=int, default=0, help="Start index (inclusive)")
    parser.add_argument("--end_data_idx", type=int, default=1319, help="End index (exclusive)")
    parser.add_argument("--solver_prompt_idx", type=int, default=0, help="Which solver prompt to use")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default=None)

    # Common gen params for original generation
    parser.add_argument("--steps", type=int, default=128, help="Sampling steps for DLLM")
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=0.0)

    # LLaDA-specific
    parser.add_argument("--llada_block_length", type=int, default=32)
    parser.add_argument("--llada_cfg_scale", type=float, default=0.0)
    parser.add_argument("--llada_remasking", type=str, default="low_confidence", choices=["low_confidence", "random"])
    parser.add_argument("--llada_mask_id", type=int, default=126336)

    # DREAM-specific
    parser.add_argument("--dream_top_p", type=float, default=0.95)
    parser.add_argument("--dream_alg", type=str, default="entropy")
    parser.add_argument("--dream_alg_temp", type=float, default=0.0)

    # Latent-Seek optimization params
    parser.add_argument("--lr", type=float, default=0.03)
    parser.add_argument("--k", type=float, default=0.1, help="Fraction of answer tokens to optimize (cap 300)")
    parser.add_argument("--max_num_steps", type=int, default=10)
    parser.add_argument("--reward_threshold", type=float, default=-0.2)
    parser.add_argument("--start_index_in_answer", type=int, default=0)

    # Tail generation params (override for the tail only)
    parser.add_argument("--tail_steps", type=int, default=128)
    parser.add_argument("--tail_block_length", type=int, default=32)

    # Reward format (optional)
    parser.add_argument("--rule_format_string", type=str, default=None)

    # Reward type
    parser.add_argument(
        "--reward_type",
        type=str,
        default="dllm_verifier",
        choices=["dllm_verifier", "self_confidence"],
        help="Choose reward backend: DLLM verifiers or self-confidence logits",
    )
    parser.add_argument(
        "--conf_measure",
        type=str,
        default="top1",
        choices=["top1", "gap"],
        help="Confidence measure for self reward",
    )
    parser.add_argument(
        "--conf_aggregator",
        type=str,
        default="mean",
        choices=["mean", "min"],
        help="Aggregate token confidences",
    )

    return parser.parse_args(argv)


def set_seed(seed: int):
    """Seed the Python and PyTorch random number generators for reproducibility.

    Also switches cuDNN to deterministic mode with benchmarking disabled and
    exports ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: hash randomization of the running interpreter is fixed at startup,
    # so this export only influences child processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)


def build_gen_kwargs(args) -> Dict[str, Any]:
    """Assemble the keyword arguments for baseline generation.

    When ``args.dllm_type`` equals ``"llada"`` the LLaDA-specific options
    (block length, CFG scale, remasking strategy, mask id) are included.
    Any other value yields the DREAM option set (top-p, algorithm and
    algorithm temperature).

    Args:
        args: Parsed command-line namespace.

    Returns:
        A fresh dictionary of generation keyword arguments.
    """
    if args.dllm_type != "llada":
        # DREAM (also the fallback for any non-"llada" value)
        return {
            "steps": args.steps,
            "max_new_tokens": args.max_new_tokens,
            "temperature": args.temperature,
            "top_p": args.dream_top_p,
            "alg": args.dream_alg,
            "alg_temp": args.dream_alg_temp,
        }

    return {
        "steps": args.steps,
        "max_new_tokens": args.max_new_tokens,
        "block_length": args.llada_block_length,
        "temperature": args.temperature,
        "cfg_scale": args.llada_cfg_scale,
        "remasking": args.llada_remasking,
        "mask_id": args.llada_mask_id,
    }


def main(args):
    """Run baseline DLLM generation and Latent-Seek optimization over a dataset slice.

    For each example with index in ``[start_data_idx, end_data_idx)`` (clamped
    to the dataset size) this generates a baseline answer, refines it with
    ``dllm_latent_seek_generation``, and judges the optimized text against the
    ground truth. Examples whose ground-truth answer cannot be extracted are
    skipped and do not count towards the accuracy.

    Outputs are written under
    ``<output_dir>/dllm-lseek-<dllm_type>-<model>-<dataset>-SolIdx<idx>``:
    ``results.json`` is rewritten after every processed example so partial
    progress is kept, and ``summary.json`` is written once at the end.

    Args:
        args: Namespace with the options defined by ``parse_args``.
    """
    # Compare against None rather than relying on truthiness so that an
    # explicit seed of 0 still seeds the generators.
    if args.seed is not None:
        set_seed(args.seed)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device).eval()

    dataset = get_dataset(args.dataset, tokenizer=tokenizer, prompt_idx=args.solver_prompt_idx)
    # Guard the preview so an empty dataset does not abort the run here.
    if len(dataset) > 0:
        print(f"Example: {dataset[0]}")

    model_name = args.model_name_or_path.split("/")[-1]
    data_name = args.dataset.split("/")[-1]
    tag = f"dllm-lseek-{args.dllm_type}"
    out_dir = os.path.join(
        args.output_dir,
        f"{tag}-{model_name}-{data_name}-SolIdx{args.solver_prompt_idx}",
    )
    os.makedirs(out_dir, exist_ok=True)
    results_path = os.path.join(out_dir, "results.json")
    summary_path = os.path.join(out_dir, "summary.json")

    results_list = []
    correct = 0
    total = 0

    start = max(0, args.start_data_idx)
    end = min(args.end_data_idx, len(dataset))

    gen_kwargs = build_gen_kwargs(args)
    print(f"Using base gen params: {gen_kwargs}")

    # Build reward model
    if args.reward_type == "self_confidence":
        # Imported lazily: only needed for this reward backend.
        from rewards.self_confidence_reward import SelfConfidenceRewardModel

        reward_model = SelfConfidenceRewardModel(
            dllm_type=args.dllm_type,
            model=model,
            tokenizer=tokenizer,
            device=device,
            measure=args.conf_measure,
            aggregator=args.conf_aggregator,
            llada_mask_id=args.llada_mask_id,
        )
    else:
        # Reward model backed by the same DLLM via verifier prompts
        reward_model = DllmRewardModel(
            dllm_type=args.dllm_type,
            model=model,
            tokenizer=tokenizer,
            device=device,
            rule_format_string=args.rule_format_string,
            gen_steps=min(args.steps, 64),
            max_new_tokens=64,
            llada_block_length=args.llada_block_length,
            temperature=args.temperature,
            dream_top_p=args.dream_top_p,
            dream_alg=args.dream_alg,
            dream_alg_temp=args.dream_alg_temp,
        )

    # Iterate dataset
    pbar = tqdm(range(start, end), desc="Latent-Seek", unit="samples")
    for i in pbar:
        example = dataset[i]
        question = example["question"]
        true_answer = extract_true_answer(example["answer"], name=args.dataset)
        if true_answer is None:
            # No usable ground truth: skip without counting the example.
            continue

        print(f"[#{i}] Q: {question}")
        print(f"GT: {true_answer}")

        # Original DLLM generation (continuation only)
        base_gen_text, _ = dllm_original_generation(
            dllm_type=args.dllm_type,
            model=model,
            tokenizer=tokenizer,
            device=device,
            input_text=example["formatted"],
            **gen_kwargs,
        )
        print(f"Original: {base_gen_text}")

        # Latent-Seek optimization using DLLM for reward + tail
        opt_text, reward_hist, orig_len, opt_len, upd_len = dllm_latent_seek_generation(
            dllm_type=args.dllm_type,
            reward_model=reward_model,
            model=model,
            tokenizer=tokenizer,
            device=device,
            question=question,
            input_text=example["formatted"],
            original_answer=base_gen_text,
            start_index_in_answer=args.start_index_in_answer,
            max_num_steps=args.max_num_steps,
            lr=args.lr,
            k=args.k,
            reward_threshold=args.reward_threshold,
            base_gen_kwargs=gen_kwargs,
            tail_steps=args.tail_steps,
            tail_block_length=args.tail_block_length,
        )
        print(f"Optimized: {opt_text}")

        # Extract and judge
        extracted_base = extract_answer(
            base_gen_text,
            data_name=args.dataset,
            prompt_idx=args.solver_prompt_idx,
            model_name=args.model_name_or_path,
        )
        extracted_opt = extract_answer(
            opt_text,
            data_name=args.dataset,
            prompt_idx=args.solver_prompt_idx,
            model_name=args.model_name_or_path,
        )
        # Only the optimized output is judged; it counts as wrong when no
        # answer could be extracted from it.
        is_correct = False
        if extracted_opt is not None:
            is_correct = judge_answer(opt_text, true_answer, data_name=args.dataset, prompt_idx=args.solver_prompt_idx)

        correct += int(is_correct)
        total += 1

        # Save incremental results
        result_entry = {
            "question": question,
            "true_answer": true_answer,
            "original": base_gen_text,
            "optimized": opt_text,
            "extracted_original": extracted_base,
            "extracted_optimized": extracted_opt,
            "reward_history": reward_hist,
            "orig_len": orig_len,
            "opt_len": opt_len,
            "update_len": upd_len,
            "data_idx": i,
            "correct": bool(is_correct),
        }
        results_list.append(result_entry)
        # ensure_ascii=False emits raw non-ASCII characters, so the file must be
        # opened as UTF-8 explicitly instead of relying on the locale default.
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results_list, f, indent=2, ensure_ascii=False)

        # total was just incremented, so the division is always safe here.
        pbar.set_postfix({'acc': f"{correct / total:.4f}"})

    final_acc = correct / total if total > 0 else 0.0
    print(f"Final accuracy (optimized): {final_acc:.4f} ({correct}/{total})")

    summary = {
        "final_accuracy": final_acc,
        "correct": correct,
        "total": total,
        "dllm_type": args.dllm_type,
        "model_name_or_path": args.model_name_or_path,
        "dataset": args.dataset,
        "gen_kwargs": gen_kwargs,
        "lr": args.lr,
        "k": args.k,
        "max_num_steps": args.max_num_steps,
        "reward_threshold": args.reward_threshold,
        "start_index_in_answer": args.start_index_in_answer,
        "tail_steps": args.tail_steps,
        "tail_block_length": args.tail_block_length,
    }
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)


if __name__ == "__main__":
    # Script entry point: parse the CLI, echo every resolved option, then run.
    args = parse_args()
    for option_name, option_value in vars(args).items():
        print(f"-- {option_name}: {option_value}")
    main(args)
