import sys
import numpy as np
import sglang as sgl
from typing import Callable, List
from collections import defaultdict
from transformers import PreTrainedTokenizer
from sglang.lang.interpreter import ProgramState

from search.decoding import residual_mppi_step
from search.utils import aggregate_score, compute_divergence


def _step_tail(scores, length):
    """Return the per-token scores that belong to the last ``length`` tokens.

    In Python ``scores[-0:]`` is the *whole* sequence, so an empty step
    (``length == 0``) would silently be scored on the entire response. In that
    case fall back to the score of the final token instead.
    """
    if length <= 0:
        return scores[-1:]
    return scores[-length:]


def _generation_kwargs(args, temperature, max_tokens_attr):
    """Build the per-request keyword arguments for ``residual_mppi_step``.

    If ``args.residual_mppi_stop`` is set, generation is bounded by that stop
    condition. Otherwise it is bounded by ``getattr(args, max_tokens_attr)``.
    The length attribute is looked up lazily, so it is only required when no
    stop condition is configured.
    """
    if args.residual_mppi_stop is not None:
        return dict(temperature=temperature, stop=args.residual_mppi_stop)
    return dict(temperature=temperature, max_tokens=getattr(args, max_tokens_attr))


def _score_states(states, prompt, prefix, sample, verifier, tokenizer, args):
    """Compute Q-values for generated states that all extend ``prefix``.

    The verifier is called on each state's text after ``prompt``. Only the
    per-token scores of the text generated after ``prefix`` are kept, and they
    are reduced with ``aggregate_score(..., args.agg_strategy)``. To that is
    added ``args.residual_mppi_omega`` times the mean top-1 log-probability of
    the tokens each state generated.

    Returns:
        A tuple ``(q_values, agg_scores, mean_logprobs)``:
        - ``q_values``: a float numpy array, one entry per state.
        - ``agg_scores``: the raw return of ``aggregate_score``.
        - ``mean_logprobs``: a list with one mean per state.
    """
    texts = [state.text() for state in states]
    scores = verifier(sample, [text[len(prompt):] for text in texts])
    # NOTE(review): tokenizer.encode uses its default add_special_tokens, which
    # for some tokenizers prepends a BOS token and makes this length one too
    # large. Confirm it matches the verifier's per-token score alignment.
    step_lengths = [len(tokenizer.encode(text[len(prefix):])) for text in texts]
    step_scores = [_step_tail(score, length) for score, length in zip(scores, step_lengths)]
    agg_scores = aggregate_score(step_scores, args.agg_strategy)

    # NOTE(review): top_logprobs[0][0] is the logprob of the top-1 candidate at
    # each position. It equals the sampled token's logprob only when the sample
    # was the argmax. Confirm this is the intended quantity.
    mean_logprobs = [
        np.mean([top_logprobs[0][0] for top_logprobs in state.get_meta_info("result")["output_top_logprobs"]])
        for state in states
    ]
    # Out-of-place add: an in-place `+=` would raise if agg_scores had an integer dtype.
    q_values = np.array(agg_scores) + args.residual_mppi_omega * np.array(mean_logprobs)
    return q_values, agg_scores, mean_logprobs


@sgl.function
def residual_mppi(s: ProgramState, prompt, sample, verifier, tokenizer: PreTrainedTokenizer, args):
    """Step-wise search that picks each next step by verifier score plus rollouts.

    Each iteration does the following:

    1. Sample ``current_step_samples`` (M) candidate next steps from the
       current text.
    2. Score each candidate with ``_score_states``.
    3. Roll out ``current_rollout_samples`` (R) continuations from every
       candidate and average their Q per candidate.
    4. For candidates that are not yet ``completed``, rank on the mean of
       their own Q and their rollout Q. Completed candidates rank on their
       own Q.
    5. Append the argmax candidate to ``s``.

    The divergence between candidate and rollout Q (``compute_divergence``)
    controls the next step. If it reaches
    ``args.residual_mppi_divergence_threshold``, the next step uses the full M/R
    budget; otherwise it uses a single candidate with a single rollout.

    The loop stops when any of these holds:
    - The selected candidate is ``completed``.
    - More than ``args.max_tokens`` tokens have been generated.
    - ``step`` exceeds ``args.max_search_steps``. In this case the response is
      finished with one plain generation capped at the remaining token budget.

    On exit, sets ``s["prediction"]``, ``s["outputs"]``, ``s["log_values"]``
    and ``s["metric"]``.
    """
    s += prompt
    step = 0
    planning_steps = 0

    # Clamp to 1: when n_samples < residual_mppi_rollout_samples the floor
    # division is 0. That would give an empty candidate batch and make
    # np.argmax raise.
    residual_mppi_step_samples = max(1, args.n_samples // args.residual_mppi_rollout_samples)
    current_step_samples = residual_mppi_step_samples
    current_rollout_samples = args.residual_mppi_rollout_samples

    log_values = defaultdict(list)
    metric = defaultdict(list)
    while True:
        step += 1
        if current_step_samples > 1 and current_rollout_samples > 1:
            planning_steps += 1
        # `s` is not modified until the selected step is appended, so the
        # prefix can be read once per iteration.
        prefix = s.text()

        # Sample M candidate next steps from the current prefix.
        step_kwargs = _generation_kwargs(args, args.temperature, "residual_mppi_execute_length")
        output_states = residual_mppi_step.run_batch(
            [dict(prompt=prefix, **step_kwargs) for _ in range(current_step_samples)],
            num_threads=16,
        )
        output_Q, agg_output_scores, output_selected_logprobs = _score_states(
            output_states, prompt, prefix, sample, verifier, tokenizer, args
        )

        # Roll out R continuations from every candidate, ordered candidate-major
        # so that the reshape below groups each candidate's rollouts in one row.
        rollout_kwargs = _generation_kwargs(args, args.residual_mppi_rollout_temperature, "residual_mppi_rollout_length")
        rollout_states = residual_mppi_step.run_batch(
            [dict(prompt=state.text(), **rollout_kwargs) for state in output_states for _ in range(current_rollout_samples)],
            num_threads=16,
        )
        rollout_Q_flat, agg_rollout_scores, rollout_selected_logprobs = _score_states(
            rollout_states, prompt, prefix, sample, verifier, tokenizer, args
        )
        rollout_Q = rollout_Q_flat.reshape(current_step_samples, current_rollout_samples).mean(axis=1)  # (M,)

        # Candidates that already finished are ranked on their own Q only.
        unfinished = np.array([not state["completed"] for state in output_states])
        total_Q = np.where(unfinished, (output_Q + rollout_Q) / 2, output_Q)
        selected_idx = np.argmax(total_Q)
        selected_state = output_states[selected_idx]
        s += selected_state.text()[len(prefix):]

        # Spend the full sampling budget on the next step only when candidate
        # and rollout values disagree enough.
        divergence = compute_divergence(output_Q, rollout_Q, method=args.residual_mppi_divergence_method)
        metric["divergence"].append(divergence)
        if divergence >= args.residual_mppi_divergence_threshold:
            current_step_samples = residual_mppi_step_samples
            current_rollout_samples = args.residual_mppi_rollout_samples
        else:
            current_step_samples = 1
            current_rollout_samples = 1

        log_values["total_Q"].append(total_Q.tolist())
        log_values["output_Q"].append(output_Q.tolist())
        log_values["rollout_Q"].append(rollout_Q.tolist())
        log_values["output_scores"].append(agg_output_scores)
        log_values["rollout_scores"].append(agg_rollout_scores)
        log_values["output_logprobs"].append(output_selected_logprobs)
        log_values["rollout_logprobs"].append(rollout_selected_logprobs)

        if selected_state["completed"]:
            break

        generated_tokens = len(tokenizer.encode(s.text()[len(prompt):]))
        if generated_tokens > args.max_tokens:
            break

        if step > args.max_search_steps:
            # Search budget exhausted: finish with one plain generation
            # capped at the remaining token budget.
            final_prompt = s.text()
            final_state = residual_mppi_step.run(
                prompt=final_prompt,
                temperature=args.residual_mppi_rollout_temperature,
                max_tokens=args.max_tokens - generated_tokens,
            )
            s += final_state.text()[len(final_prompt):]
            break

    response = s.text()[len(prompt):]
    s["prediction"] = response
    s["outputs"] = [response]
    s["log_values"] = log_values
    s["metric"] = {
        "step": step,
        "planning_steps": planning_steps,
        "divergence": np.mean(metric["divergence"]),
    }
