import torch
from typing import Tuple, List, Dict, Any

from dllm_generation import dllm_original_generation


def _build_tail_gen_kwargs(
    dllm_type: str,
    base_gen_kwargs: Dict[str, Any],
    *,
    tail_steps: int,
    tail_block_length: int = None,
    temperature: float = None,
    cfg_scale: float = None,
    top_p: float = None,
    alg: str = None,
    alg_temp: float = None,
) -> Dict[str, Any]:
    """
    Build kwargs for tail generation, respecting the backend type and overriding only supported keys.
    """
    dllm_type = dllm_type.lower()
    out = dict(base_gen_kwargs) if base_gen_kwargs else {}
    out["steps"] = tail_steps

    if temperature is not None:
        out["temperature"] = temperature

    if dllm_type == "llada":
        if tail_block_length is not None:
            out["block_length"] = tail_block_length
        if cfg_scale is not None:
            out["cfg_scale"] = cfg_scale
        # mask_id and remasking stay the same as base kwargs
    elif dllm_type == "dream":
        if top_p is not None:
            out["top_p"] = top_p
        if alg is not None:
            out["alg"] = alg
        if alg_temp is not None:
            out["alg_temp"] = alg_temp
    else:
        raise ValueError(f"Unsupported dllm_type: {dllm_type}")

    return out


def _prepare_tail_kwargs(
    dllm_type: str,
    base_gen_kwargs: Dict[str, Any],
    *,
    tail_steps: int,
    tail_block_length: int,
    tail_gen_length: int,
) -> Dict[str, Any]:
    """Build the tail-generation kwargs, aligning length and steps for LLaDA.

    For "llada" the generated length is rounded up to a multiple of the block
    length and the step count is rounded up to a multiple of the number of
    blocks. For any other backend the requested length is passed through as is.

    Raises:
        ValueError: If the backend is unsupported (raised by
            ``_build_tail_gen_kwargs``) or the resolved block length is not
            positive.
    """
    base = base_gen_kwargs or {}
    tail_kwargs = _build_tail_gen_kwargs(
        dllm_type,
        base_gen_kwargs,
        tail_steps=tail_steps,
        tail_block_length=tail_block_length,
        temperature=base.get("temperature"),
        cfg_scale=base.get("cfg_scale"),
        top_p=base.get("top_p"),
        alg=base.get("alg"),
        alg_temp=base.get("alg_temp"),
    )

    if dllm_type.lower() != "llada":
        tail_kwargs["max_new_tokens"] = tail_gen_length
        return tail_kwargs

    # tail_kwargs already holds the override or the value inherited from the
    # base kwargs, so it is read from there; this also works when
    # base_gen_kwargs is None.
    block_length = tail_kwargs.get("block_length")
    block_length = 32 if block_length is None else int(block_length)
    if block_length <= 0:
        raise ValueError(f"block_length must be positive, got {block_length}")
    # Pass the block length explicitly so it matches the rounding done below.
    tail_kwargs["block_length"] = block_length

    # Round the length up to a whole number of blocks.
    num_blocks = max(1, -(-tail_gen_length // block_length))
    # Use at least one step per block, rounded up to a multiple of num_blocks.
    steps = max(int(tail_kwargs.get("steps", 128)), num_blocks)
    steps = -(-steps // num_blocks) * num_blocks

    tail_kwargs["steps"] = steps
    tail_kwargs["max_new_tokens"] = num_blocks * block_length
    return tail_kwargs


def dllm_latent_seek_generation(
    *,
    dllm_type: str,
    reward_model,
    model,
    tokenizer,
    device: str,
    question: str,
    input_text: str,
    original_answer: str,
    start_index_in_answer: int = 0,
    max_num_steps: int = 10,
    lr: float = 0.03,
    k: float = 0.1,
    reward_threshold: float = -0.2,
    # Tail generation controls (use DLLM again)
    base_gen_kwargs: Dict[str, Any] = None,
    tail_steps: int = 128,
    tail_block_length: int = 32,
) -> Tuple[str, List[float], int, int, int]:
    """
    Latent-Seek optimization for DLLMs. This first version mirrors the AR-LLM approach
    by directly optimizing a discrete token segment via a learnable categorical distribution,
    then regenerating the tail with the DLLM.

    The gradient signal is REINFORCE-style: the (black-box) reward scales the
    log-probability of the greedily chosen tokens; no gradient flows through the
    reward itself. After every optimizer step the greedy tokens are re-read from
    the *updated* logits, the tail is regenerated, and the resulting answer is
    scored, so each reward always belongs to the tokens whose log-probability is
    used in the following step.

    Args:
        dllm_type: Backend name ("llada" or "dream", case-insensitive).
        reward_model: Object exposing ``get_reward(question, answer, input_text)``.
        model: Model handed through to ``dllm_original_generation``.
        tokenizer: Tokenizer that is callable and provides ``encode``/``decode``.
        device: Device for the prompt ids and the learnable logits.
        question: Question text passed to the reward model.
        input_text: Prompt text; tokenized and also passed to the reward model.
        original_answer: Answer to improve.
        start_index_in_answer: Token index in the answer where the optimized
            span begins. Must be non-negative.
        max_num_steps: Maximum number of optimization steps.
        lr: Adam learning rate for the span logits.
        k: Fraction of the answer's token count to optimize (capped at 300 tokens).
        reward_threshold: Optimization stops once the current reward exceeds this.
        base_gen_kwargs: Baseline kwargs for tail generation. May be None.
        tail_steps: Step count requested for tail generation.
        tail_block_length: Block length for tail generation (used for LLaDA).

    Returns:
        ``(final_answer_str, reward_history, original_answer_length,
        optimized_answer_length, update_length)``. When optimization is skipped
        the original answer is returned with ``update_length == 0``.

    Raises:
        ValueError: If ``start_index_in_answer`` is negative, the backend is
            unsupported, or the LLaDA block length is not positive.
    """
    if start_index_in_answer < 0:
        # A negative index would silently produce inconsistent slices below.
        raise ValueError(
            f"start_index_in_answer must be non-negative, got {start_index_in_answer}"
        )

    # Prepare prompt and original answer tokens
    prompt_ids = tokenizer([input_text], return_tensors="pt").to(device).input_ids
    answer_ids = tokenizer.encode(original_answer, add_special_tokens=False)
    original_answer_token_len = len(answer_ids)

    # Initialize reward from DLM-backed reward model
    initial_reward = reward_model.get_reward(question, original_answer, input_text)
    current_reward = float(initial_reward)
    reward_history: List[float] = [current_reward]
    print(f"-- Original Output: {original_answer} -- Initial Reward: {initial_reward}")

    # Result returned by every "nothing to optimize" early exit.
    skipped = (original_answer, reward_history, original_answer_token_len, original_answer_token_len, 0)

    # Determine update span
    update_length = min(int(k * original_answer_token_len), 300)
    if update_length <= 0:
        print(
            f"Update Length is {update_length} (k * original_len = {k * original_answer_token_len}). Skipping optimization."
        )
        return skipped

    if start_index_in_answer >= original_answer_token_len:
        print(
            f"start_index_in_answer ({start_index_in_answer}) >= original_answer_token_len ({original_answer_token_len}). Skipping."
        )
        return skipped

    actual_update_end_idx = min(start_index_in_answer + update_length, original_answer_token_len)
    actual_update_length = actual_update_end_idx - start_index_in_answer
    if actual_update_length <= 0:
        print("Actual update length is 0. Skipping optimization.")
        return skipped

    print(
        f"Optimizing {actual_update_length} tokens starting at index {start_index_in_answer} of the answer."
    )

    # Build token segments
    before_opt_ids = answer_ids[:start_index_in_answer]
    init_opt_ids = answer_ids[start_index_in_answer:actual_update_end_idx]

    # Learnable categorical distribution over the vocab for each optimized position,
    # biased towards the original tokens so optimization starts near them.
    vocab_size = tokenizer.vocab_size if hasattr(tokenizer, "vocab_size") else len(tokenizer)
    init_logits = torch.zeros(actual_update_length, vocab_size, device=device)
    for i, tok_id in enumerate(init_opt_ids):
        if tok_id < vocab_size:
            init_logits[i, tok_id] = 2.0  # small positive bias

    opt_logits = torch.nn.Parameter(init_logits)
    optimizer = torch.optim.Adam([opt_logits], lr=lr)

    # Loop-invariant pieces, computed once.
    positions = torch.arange(actual_update_length, device=device)
    prompt_token_list = prompt_ids[0].tolist()
    prefix_text = tokenizer.decode(before_opt_ids, skip_special_tokens=True)

    # NOTE(review): the tail budget is the answer length minus the optimized span and
    # does not subtract start_index_in_answer, so with a non-zero start index more
    # tokens are regenerated than the original had after the span — confirm intent.
    tail_gen_length = max(0, original_answer_token_len - actual_update_length)
    if 0 < tail_gen_length < 10:
        tail_gen_length = 10  # always give the DLLM a minimum tail budget

    # Built lazily on first use so that an immediate threshold exit never touches
    # (or validates) the generation configuration.
    tail_kwargs = None

    final_answer_str = original_answer

    # Optimization loop
    for step_idx in range(max_num_steps):
        if current_reward > reward_threshold:
            print(
                f"-- Optimization stopped (threshold met). Final Answer: {final_answer_str}, Current Reward: {current_reward}"
            )
            break

        optimizer.zero_grad()

        # Categorical distribution over tokens for the optimized span
        probs = torch.softmax(opt_logits, dim=-1) + 1e-8
        greedy_ids = torch.argmax(probs, dim=-1)  # (L,)

        # Policy gradient style objective (no gradients through reward)
        log_pi_xz = torch.log(probs[positions, greedy_ids] + 1e-10)
        loss = -current_reward * log_pi_xz.sum()
        print(f"-- Step {step_idx + 1}, Loss: {loss.item():.4f}")
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Read the tokens from the *updated* logits: the answer scored below must
            # reflect this step's update, and its reward is then paired with the same
            # tokens in the next step's loss.
            opt_ids = torch.argmax(opt_logits, dim=-1).tolist()
            opt_text = tokenizer.decode(opt_ids, skip_special_tokens=True)

            if tail_gen_length > 0:
                if tail_kwargs is None:
                    tail_kwargs = _prepare_tail_kwargs(
                        dllm_type,
                        base_gen_kwargs,
                        tail_steps=tail_steps,
                        tail_block_length=tail_block_length,
                        tail_gen_length=tail_gen_length,
                    )

                # Text prompt for the DLLM: original prompt + untouched prefix + optimized span
                current_prompt_text = tokenizer.decode(
                    prompt_token_list + before_opt_ids + opt_ids, skip_special_tokens=True
                )

                # Tail generation using the same DLLM
                tail_text, _ = dllm_original_generation(
                    dllm_type=dllm_type,
                    model=model,
                    tokenizer=tokenizer,
                    device=device,
                    input_text=current_prompt_text,
                    **tail_kwargs,
                )
                final_answer_str = prefix_text + opt_text + tail_text
            else:
                final_answer_str = prefix_text + opt_text

            current_reward = float(reward_model.get_reward(question, final_answer_str, input_text))
            print(f"-- Step {step_idx + 1} New Answer: {final_answer_str}, Current Reward: {current_reward:.4f}")
            reward_history.append(current_reward)

        # Cleanup
        torch.cuda.empty_cache()
    else:
        # Finished all steps without meeting the threshold at a loop check
        print(f"-- Max steps reached. Final answer: {final_answer_str}")

    optimized_answer_length = len(tokenizer.encode(final_answer_str, add_special_tokens=False))
    return (
        final_answer_str,
        reward_history,
        original_answer_token_len,
        optimized_answer_length,
        actual_update_length,
    )
