import argparse
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from calib.utils import math_equal


def extract_first_boxed(text: Optional[str]) -> Optional[str]:
    """Return the contents of the first ``\\boxed{...}`` (or ``\\fbox{...}``) in ``text``.

    ``\\fbox`` is searched for only when ``text`` contains no ``\\boxed`` at all.
    Nested braces are balanced, so ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}``.

    Returns ``None`` when ``text`` is ``None``, no marker is present, the first
    marker is not immediately followed by ``{`` (e.g. ``\\boxed 5``), or the
    braces after it never close.
    """
    if text is None:
        return None
    marker = "\\boxed"
    start = text.find(marker)
    if start < 0:
        marker = "\\fbox"
        start = text.find(marker)
        if start < 0:
            return None
    # Only the brace-delimited form is supported; the brace must follow the
    # marker directly. Using the actual marker here (not a hard-coded
    # "\\boxed{") is what lets the \fbox fallback produce a result.
    open_idx = start + len(marker)
    if open_idx >= len(text) or text[open_idx] != "{":
        return None
    depth = 0
    for pos in range(open_idx, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[open_idx + 1 : pos]
    # Unbalanced: the opening brace was never closed.
    return None


def _sanitize_filename_segment(segment: str) -> str:
    """Sanitize a string to be safely embedded in a filename."""
    safe: List[str] = []
    for ch in segment:
        if ch.isalnum() or ch in ("-", "_"):
            safe.append(ch)
        else:
            safe.append("_")
    return "".join(safe)[:48]


def _get_from_args(obj: Any, name: str, default: Any = None) -> Any:
    if obj is None:
        return default
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _merge_args(orig_args: Any, updates: Dict[str, Any]):
    base: Dict[str, Any] = {}
    if orig_args is not None:
        try:
            base = dict(vars(orig_args))
        except Exception:
            if isinstance(orig_args, dict):
                base = dict(orig_args)
    base.update({k: v for k, v in updates.items() if v is not None})
    try:
        return argparse.Namespace(**base)
    except Exception:
        return base


def split_tokens_by_delimiter(token_ids: List[int], delimiter_token_id: int) -> List[List[int]]:
    """Split ``token_ids`` into segments separated by ``delimiter_token_id``.

    Delimiters are dropped, and consecutive or leading/trailing delimiters yield
    empty segments. The result always has ``count(delimiter) + 1`` segments. An
    empty (or ``None``) input yields ``[[]]``.
    """
    pieces: List[List[int]] = [[]]
    for tok in token_ids or []:
        if tok == delimiter_token_id:
            # Close the current segment and start a fresh one.
            pieces.append([])
        else:
            pieces[-1].append(tok)
    return pieces


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Options left at ``None`` are resolved later in ``main`` from the ``args``
    stored in the input bundle, then from hard-coded defaults.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Load a completions file produced by completions.py, perform paragraph-level "
            "rollouts, and save results in the same format as completions_paragraph.py."
        )
    )
    parser.add_argument("--input_path", type=str, required=True, help="Path to the saved completions .pt file.")
    parser.add_argument("--model_name", type=str, default=None, help="Model name or path to use for rollouts. Overrides saved args if provided.")
    parser.add_argument("--dtype", type=str, default=None, help="Data type for vLLM (e.g., bfloat16). Overrides saved args if provided.")
    parser.add_argument("--max_model_len", type=int, default=None, help="Maximum model length. Overrides saved args if provided.")
    parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling for rollouts. If not set, uses saved args value or -1.")
    parser.add_argument("--top_p", type=float, default=None, help="Top-p sampling for rollouts. If not set, uses saved args value or 1.0.")

    parser.add_argument("--rollout_append_text", type=str, default="\n\n**Final Answer**\n\\boxed", help="Text appended before rolling out from each paragraph.")
    parser.add_argument("--rollout_max_new_tokens", type=int, default=100, help="Maximum tokens to generate per paragraph-level rollout.")
    parser.add_argument("--rollout_temperature", type=float, default=0.0, help="Sampling temperature for paragraph-level rollouts.")
    parser.add_argument(
        "--paragraph_delimiter_token_id",
        type=int,
        default=None,
        help="Token ID used to split paragraphs in completions. If not provided, will attempt to get from saved args.",
    )
    parser.add_argument("--output_dir", type=str, default="outputs/completions_rollouts", help="Directory to save rollout results.")
    parser.add_argument(
        "--max_prompts_per_generation",
        type=int,
        default=500,
        help="Maximum number of original prompts whose paragraph-level rollouts are sent to vLLM in one generation batch.",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode to print rollout prompts and responses.")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of tensor parallel size.")
    parser.add_argument("--data_parallel_size", type=int, default=1, help="Number of data parallel size.")
    return parser


def _drop_out_of_vocab_tokens(
    completion_ids: List[List[List[int]]], num_prompts: int, vocab_size: int
) -> int:
    """Remove token IDs ``>= vocab_size`` from the first ``num_prompts`` entries, in place.

    Each inner completion list is replaced by a filtered copy inside its
    per-prompt list. Returns the total number of token IDs removed.
    """
    removed = 0
    for i in range(num_prompts):
        per_prompt = completion_ids[i]
        for j, ids in enumerate(per_prompt):
            kept = [token_id for token_id in ids if token_id < vocab_size]
            removed += len(ids) - len(kept)
            per_prompt[j] = kept
    return removed


def _build_rollout_prompts(
    prompts_ids: List[List[int]],
    completion_ids: List[List[List[int]]],
    prompt_indices: range,
    delimiter_token_id: int,
    append_token_ids: List[int],
) -> Tuple[List[List[int]], List[Tuple[int, int, int]]]:
    """Build one rollout prompt per paragraph of every completion of the given prompts.

    The prompt for paragraph ``k`` is made of three parts, in order:
    - the original prompt tokens;
    - paragraphs ``0..k`` re-joined with ``delimiter_token_id`` (no delimiter
      after paragraph ``k``);
    - ``append_token_ids``.

    Returns the prompt token lists and a parallel list of
    ``(prompt_idx, completion_idx, paragraph_idx)`` tuples.
    """
    rollout_prompts: List[List[int]] = []
    index_map: List[Tuple[int, int, int]] = []
    for i in prompt_indices:
        for j, ids in enumerate(completion_ids[i]):
            segments = split_tokens_by_delimiter(ids, delimiter_token_id)
            # Grow the prefix incrementally instead of re-joining paragraphs
            # 0..k from scratch for every k (quadratic in the paragraph count).
            prefix = list(prompts_ids[i])
            for k, segment in enumerate(segments):
                if k > 0:
                    prefix.append(delimiter_token_id)
                prefix.extend(segment)
                rollout_prompts.append(prefix + append_token_ids)
                index_map.append((i, j, k))
    return rollout_prompts, index_map


def _score_rollouts(
    rollout_texts: List[List[List[str]]], answers: List[str], append_text: str
) -> Tuple[List[List[List[Optional[str]]]], List[List[List[Any]]]]:
    """Parse the boxed answer from every rollout and score it against the reference.

    ``append_text`` is prepended before parsing because the rollout prompt ended
    with it (by default ``\\boxed``). The generated text alone would then lack
    the marker that ``extract_first_boxed`` looks for.

    Returns ``(responses, rewards)``, both nested ``[prompt][completion][paragraph]``.
    Rewards are whatever ``math_equal`` returns.
    """
    prefix = append_text or ""
    responses: List[List[List[Optional[str]]]] = []
    rewards: List[List[List[Any]]] = []
    for i, per_prompt in enumerate(rollout_texts):
        prompt_responses = []
        prompt_rewards = []
        for paragraph_texts in per_prompt:
            parsed_answers_per_paragraph = []
            rewards_per_paragraph = []
            for text in paragraph_texts:
                parsed_answer = extract_first_boxed(prefix + (text or ""))
                parsed_answers_per_paragraph.append(parsed_answer)
                rewards_per_paragraph.append(math_equal(parsed_answer, answers[i]))
            prompt_responses.append(parsed_answers_per_paragraph)
            prompt_rewards.append(rewards_per_paragraph)
        responses.append(prompt_responses)
        rewards.append(prompt_rewards)
    return responses, rewards


def main():
    """Run paragraph-level rollouts over a saved completions bundle and save the results.

    Each completion is split into paragraphs on ``paragraph_delimiter_token_id``.
    For each paragraph ``k``:
    - the model is prompted with prompt + paragraphs ``0..k`` + ``--rollout_append_text``;
    - the continuation is parsed for a ``\\boxed{...}`` answer;
    - that answer is scored against the bundle's ``answers`` with ``math_equal``.

    The input bundle is re-saved with ``rollout_texts``, ``rollout_token_ids``,
    ``rollout_responses``, ``rollout_rewards`` and merged ``args`` added.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.max_prompts_per_generation < 1:
        parser.error("--max_prompts_per_generation must be >= 1")

    # Load existing completions bundle.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects from
    # the file; only point --input_path at bundles you trust.
    bundle: Dict[str, Any] = torch.load(args.input_path, map_location="cpu", weights_only=False)
    completions: List[List[str]] = bundle["completions"]
    completion_ids: List[List[List[int]]] = bundle.get("completion_ids")
    # Not used below; indexing keeps 'prompts_text' a required key (KeyError if missing).
    prompts_text: List[str] = bundle["prompts_text"]
    prompts_ids: List[List[int]] = bundle.get("prompt_ids")
    answers: List[str] = bundle.get("answers")
    saved_args = bundle.get("args")

    if answers is None:
        raise ValueError("Input bundle must contain 'answers'.")

    if completion_ids is None:
        raise ValueError("Input bundle must contain 'completion_ids' for token-based splitting.")

    if prompts_ids is None:
        raise ValueError("Input bundle must contain 'prompt_ids' for token-based processing.")

    N = len(completions)
    if N == 0:
        raise ValueError("No completions found in input bundle.")
    M = len(completions[0])
    # Fail early with a clear message instead of an IndexError deep in the rollout loop.
    for key, values in (("completion_ids", completion_ids), ("prompt_ids", prompts_ids), ("answers", answers)):
        if len(values) < N:
            raise ValueError(f"Input bundle '{key}' has {len(values)} entries but 'completions' has {N}.")

    # Resolve effective parameters, preferring explicit CLI overrides, falling back to saved args, then defaults
    model_name = args.model_name or _get_from_args(saved_args, "model_name")
    if model_name is None:
        raise ValueError("Model name must be provided via --model_name or present in saved args.")
    dtype = args.dtype or _get_from_args(saved_args, "dtype", "bfloat16")
    max_model_len = args.max_model_len or _get_from_args(saved_args, "max_model_len", 4096)
    top_k = args.top_k if args.top_k is not None else _get_from_args(saved_args, "top_k", -1)
    top_p = args.top_p if args.top_p is not None else _get_from_args(saved_args, "top_p", 1.0)

    # Get paragraph delimiter token ID
    paragraph_delimiter_token_id = args.paragraph_delimiter_token_id
    if paragraph_delimiter_token_id is None:
        paragraph_delimiter_token_id = _get_from_args(saved_args, "paragraph_delimiter_token_id")
        if paragraph_delimiter_token_id is None:
            raise ValueError("Paragraph delimiter token ID must be provided via --paragraph_delimiter_token_id or present in saved args.")

    # Tokenizer for decoding (debug output) and for encoding the append text.
    # trust_remote_code matches the LLM constructed below, so models that ship
    # custom code load consistently for both.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Drop token IDs the tokenizer does not know (>= len(tokenizer)).
    vocab_size = len(tokenizer)
    print(f"Tokenizer vocabulary size: {vocab_size}")
    invalid_tokens_count = _drop_out_of_vocab_tokens(completion_ids, N, vocab_size)
    if invalid_tokens_count > 0:
        print(f"Filtered out {invalid_tokens_count} invalid token IDs (>= {vocab_size})")

    if args.debug:
        print("\n=== DEBUG: Token Splitting Info ===")
        print("Using token-based input to vLLM (no decode/encode round trip!)")
        print(f"Paragraph delimiter token ID: {paragraph_delimiter_token_id}")
        # Show what the delimiter token decodes to
        delimiter_text = tokenizer.decode([paragraph_delimiter_token_id], skip_special_tokens=False)
        print(f"Delimiter token decodes to: {repr(delimiter_text)}")

        # Show the splitting of the first completion (N > 0 is guaranteed above).
        if len(completion_ids[0]) > 0:
            segments = split_tokens_by_delimiter(completion_ids[0][0], paragraph_delimiter_token_id)
            print(f"First completion has {len(segments)} paragraph segments:")
            for seg_idx, segment in enumerate(segments[:3]):  # Show first 3 segments
                segment_text = tokenizer.decode(segment, skip_special_tokens=False)
                print(f"  Segment {seg_idx}: {len(segment)} tokens -> {repr(segment_text[:100])}")
        print("=" * 50)

    dataset_name = _get_from_args(saved_args, "dataset_name", "unknown_dataset")
    num_prompts = _get_from_args(saved_args, "num_prompts", N)
    num_completions_per_prompt = _get_from_args(saved_args, "num_completions_per_prompt", M)
    generation_temperature = _get_from_args(saved_args, "temperature", 1.0)
    max_completion_length = _get_from_args(saved_args, "max_completion_length", 4096)
    max_answer_chars = _get_from_args(saved_args, "max_answer_chars", None)
    difficulty = _get_from_args(saved_args, "difficulty", None)

    append_seg = _sanitize_filename_segment(args.rollout_append_text)
    delim_seg = f"token_{paragraph_delimiter_token_id}"
    out_filename = os.path.join(
        args.output_dir,
        f"completions_rollouts_{model_name.replace('/', '_')}_{dataset_name}_{num_prompts}_{num_completions_per_prompt}_{generation_temperature}_{max_completion_length}_{max_answer_chars}_{difficulty}_append_{append_seg}_delim_{delim_seg}_roll_{args.rollout_max_new_tokens}.pt",
    )

    print(f"Will save to {out_filename}")

    # Initialize model for rollouts
    llm = LLM(
        model=model_name,
        dtype=dtype,
        gpu_memory_utilization=0.9,
        max_model_len=max_model_len,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        data_parallel_size=args.data_parallel_size,
        trust_remote_code=True,
    )

    rollout_sampling_params = SamplingParams(
        n=1,
        temperature=args.rollout_temperature,
        repetition_penalty=1.0,
        top_k=top_k,
        top_p=top_p,
        min_p=0.0,
        max_tokens=args.rollout_max_new_tokens,
    )

    start_time = time.time()

    # Pre-size the [N][M][P] result containers (P = paragraphs in that
    # completion) so batches can fill them by index. The number of completions
    # per prompt is taken from the data itself, not from saved args, so a
    # mismatch cannot cause an IndexError or silently skip completions.
    rollout_texts: List[List[List[str]]] = []
    rollout_token_ids: List[List[List[List[int]]]] = []
    for i in range(N):
        prompt_texts: List[List[str]] = []
        prompt_token_ids: List[List[List[int]]] = []
        for ids in completion_ids[i]:
            num_paragraphs = len(split_tokens_by_delimiter(ids, paragraph_delimiter_token_id))
            prompt_texts.append([""] * num_paragraphs)
            prompt_token_ids.append([[] for _ in range(num_paragraphs)])
        rollout_texts.append(prompt_texts)
        rollout_token_ids.append(prompt_token_ids)

    # Tokenize the rollout append text once; it is identical for every prompt.
    rollout_append_token_ids = tokenizer.encode(args.rollout_append_text, add_special_tokens=False)

    # Process in batches of original prompts; each batch expands to one rollout
    # prompt per (completion, paragraph) pair.
    batch_size = args.max_prompts_per_generation
    num_batches = (N + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start_prompt_idx = batch_idx * batch_size
        end_prompt_idx = min(start_prompt_idx + batch_size, N)

        batch_rollout_prompt_token_ids, batch_index_map = _build_rollout_prompts(
            prompts_ids,
            completion_ids,
            range(start_prompt_idx, end_prompt_idx),
            paragraph_delimiter_token_id,
            rollout_append_token_ids,
        )
        if not batch_rollout_prompt_token_ids:
            continue

        print(f"Processing batch {batch_idx + 1}/{num_batches}: prompts {start_prompt_idx}-{end_prompt_idx-1} ({len(batch_rollout_prompt_token_ids)} rollout prompts)...")

        if args.debug:
            print(f"\n=== DEBUG: Batch {batch_idx + 1} Rollout Prompts ===")
            for debug_idx, prompt_tokens in enumerate(batch_rollout_prompt_token_ids[:3]):  # Show first 3 prompts
                i, j, k = batch_index_map[debug_idx]
                print(f"[{debug_idx}] Prompt {i}, Completion {j}, Paragraph {k}:")
                print(f"Token count: {len(prompt_tokens)}")
                # Decode and show the last 200 chars for debugging
                prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=False)
                print(f"Prompt text: {repr(prompt_text[-200:])}")
                print(f"Last 10 tokens: {prompt_tokens[-10:]}")
            if len(batch_rollout_prompt_token_ids) > 3:
                print(f"... and {len(batch_rollout_prompt_token_ids) - 3} more prompts")
            print("=" * 50)

        # Pre-tokenized prompts are passed as {"prompt_token_ids": ...} inputs;
        # the `prompt_token_ids=` keyword of LLM.generate is deprecated/removed
        # in newer vLLM releases.
        # NOTE(review): prompts longer than max_model_len are likely rejected by
        # vLLM for the whole call — confirm whether skipping/truncation is needed.
        batch_outputs = llm.generate(
            [{"prompt_token_ids": ids} for ids in batch_rollout_prompt_token_ids],
            sampling_params=rollout_sampling_params,
        )
        for local_idx, ((i, j, k), output) in enumerate(zip(batch_index_map, batch_outputs)):
            generated_text = output.outputs[0].text
            # Copy into a plain list so the saved bundle does not depend on
            # vLLM's container type.
            generated_token_ids = list(output.outputs[0].token_ids)
            rollout_texts[i][j][k] = generated_text
            rollout_token_ids[i][j][k] = generated_token_ids

            if args.debug and local_idx < 3:  # Show first 3 responses
                print(f"\n=== DEBUG: Response {local_idx} ===")
                print(f"Prompt {i}, Completion {j}, Paragraph {k}")
                print(f"Generated text: {repr(generated_text)}")
                print(f"Generated token IDs: {generated_token_ids}")
                print(f"Number of generated tokens: {len(generated_token_ids)}")
                print("=" * 30)

    # Parse answers and compute rewards ([N][M][P] each).
    rollout_responses, rollout_rewards = _score_rollouts(rollout_texts, answers, args.rollout_append_text)

    flat_rewards = [val for per_prompt in rollout_rewards for per_completion in per_prompt for val in per_completion]
    total_paragraphs = len(flat_rewards)
    total_correct = sum((float(val) for val in flat_rewards), 0.0)
    avg_reward = (total_correct / total_paragraphs) if total_paragraphs > 0 else float("nan")
    print(f"Average paragraph-level reward: {avg_reward:.4f}  (correct={int(total_correct)}/{total_paragraphs})")

    # Show statistics about rollout generation
    total_rollout_tokens = sum(
        len(ids) for per_prompt in rollout_token_ids for per_completion in per_prompt for ids in per_completion
    )
    print(f"Total rollout tokens generated: {total_rollout_tokens}")
    print(f"Average rollout tokens per paragraph: {total_rollout_tokens / total_paragraphs if total_paragraphs > 0 else 0:.1f}")

    if args.debug:
        print("\n=== DEBUG: Final Results Sample ===")
        for i in range(min(2, N)):  # Show first 2 prompts
            print(f"Prompt {i}:")
            for j in range(min(2, len(rollout_responses[i]))):  # Show first 2 completions
                print(f"  Completion {j}:")
                for k in range(min(3, len(rollout_responses[i][j]))):  # Show first 3 paragraphs
                    parsed_answer = rollout_responses[i][j][k]
                    reward = rollout_rewards[i][j][k]
                    token_ids_k = rollout_token_ids[i][j][k]
                    print(f"    Paragraph {k}: answer='{parsed_answer}' reward={reward:.1f}")
                    print(f"    Generated text: {repr(rollout_texts[i][j][k][:100])}")
                    print(f"    Generated token count: {len(token_ids_k)}")
                    if len(token_ids_k) > 10:
                        print(f"    Generated token IDs: {token_ids_k[:10]}...")
                    else:
                        print(f"    Generated token IDs: {token_ids_k}")
        print("=" * 50)

    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.2f} seconds")

    # Merge args and save
    merged_args = _merge_args(
        saved_args,
        {
            "model_name": model_name,
            "dtype": dtype,
            "max_model_len": max_model_len,
            "top_k": top_k,
            "top_p": top_p,
            "rollout_append_text": args.rollout_append_text,
            "rollout_max_new_tokens": args.rollout_max_new_tokens,
            "rollout_temperature": args.rollout_temperature,
            "paragraph_delimiter_token_id": paragraph_delimiter_token_id,
            "output_dir": args.output_dir,
        },
    )

    # Shallow copy: note that bundle["completion_ids"] was filtered in place above.
    save_bundle = dict(bundle)
    save_bundle.update(
        {
            "rollout_texts": rollout_texts,
            "rollout_token_ids": rollout_token_ids,
            "rollout_responses": rollout_responses,
            "rollout_rewards": rollout_rewards,
            "args": merged_args,
        }
    )

    out_dir = os.path.dirname(out_filename)
    if out_dir:  # os.makedirs("") raises when --output_dir is empty
        os.makedirs(out_dir, exist_ok=True)
    torch.save(save_bundle, out_filename)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()


