import sys
import numpy as np
from typing import Callable, List, Any
from collections import defaultdict
from transformers import PreTrainedTokenizer

from search.decoding import residual_mppi_step, SearchResult
from search.utils import aggregate_score, compute_divergence
from utils.concurrency import run_batch

def _step_params(llm: Any, prompts: List[str], temperature: Any, stop: Any, max_tokens: Any) -> List[dict]:
    """Build one ``residual_mppi_step`` kwargs dict per prompt.

    When ``stop`` is not None the call is bounded by ``stop`` and ``max_tokens``
    is not passed; otherwise ``max_tokens`` is passed and ``stop`` is not.
    """
    if stop is not None:
        return [dict(llm=llm, prompt=p, temperature=temperature, stop=stop) for p in prompts]
    return [dict(llm=llm, prompt=p, temperature=temperature, max_tokens=max_tokens) for p in prompts]


def _mean_top_logprob(output_top_logprobs: Any) -> float:
    """Average ``entry[0][0]`` over the non-empty per-position top-logprob entries.

    NOTE(review): assumes each position's entry is a list of candidates whose
    first candidate's element 0 is a logprob (top-1 logprob) - confirm against
    the serving backend's ``output_top_logprobs`` format.
    Returns 0.0 when there is nothing to average (missing/None/empty list, or
    every entry empty) instead of NaN, so one bad sample cannot poison Q.
    """
    if not output_top_logprobs:
        return 0.0
    values = [top_logprobs[0][0] for top_logprobs in output_top_logprobs if top_logprobs]
    return float(np.mean(values)) if values else 0.0


def _score_states(states: List[Any], prompt: str, sample: Any, verifier: Any,
                  tokenizer: PreTrainedTokenizer, agg_strategy: Any, omega: float):
    """Score generated states; return ``(agg_scores, mean_logprobs, Q)``.

    ``Q = agg_scores + omega * mean_logprobs`` as a float array, one entry per state.
    """
    # The verifier sees everything generated after the original prompt.
    texts = [state.prompt[len(prompt):] + state.response_text for state in states]
    scores = verifier(sample, texts)

    step_scores = []
    for state, score in zip(states, scores):
        # Count only the step's own tokens: special tokens (e.g. BOS) would make
        # the tail slice reach into the previous context.
        n_tokens = len(tokenizer.encode(state.response_text, add_special_tokens=False))
        # Keep the scores of this step's tokens. `score[-0:]` would be the WHOLE
        # list, so an empty step falls back to the single final score.
        step_scores.append(score[-max(n_tokens, 1):])
    agg_scores = aggregate_score(step_scores, agg_strategy)

    mean_logprobs = [
        _mean_top_logprob(state.meta_info["result"].get("output_top_logprobs", []))
        for state in states
    ]
    q_values = (np.asarray(agg_scores, dtype=float)
                + omega * np.asarray(mean_logprobs, dtype=float))
    return agg_scores, mean_logprobs, q_values


def residual_mppi(llm: Any, prompt: str, sample: Any, verifier: Any, tokenizer: PreTrainedTokenizer, args: Any) -> SearchResult:
    """Generate a response step by step, picking each step by verifier score plus rollout lookahead.

    Each iteration:
      1. samples ``current_step_samples`` candidate next steps from the current text;
      2. scores every candidate (aggregated verifier score + omega * mean top logprob);
      3. extends every candidate ``current_rollout_samples`` times, scores the
         rollouts the same way and averages them per candidate;
      4. appends the candidate with the highest combined Q. Candidates that are
         already completed are ranked on their own Q, without the rollout term.

    If the divergence between candidate Q and rollout Q is at least
    ``args.residual_mppi_divergence_threshold``, the next iteration uses the full
    sample budget. Otherwise it uses a single candidate and a single rollout.

    Args:
        llm: model handle forwarded to ``residual_mppi_step``.
        prompt: initial prompt; the returned response excludes it.
        sample: task item forwarded to ``verifier``.
        verifier: called as ``verifier(sample, texts)``. It is assumed to return
            one score sequence per text whose tail aligns with that text's
            tokens (TODO confirm against the verifier implementation).
        tokenizer: used only to count tokens (special tokens excluded).
        args: config namespace. Fields read: n_samples,
            residual_mppi_rollout_samples, residual_mppi_stop, temperature,
            residual_mppi_execute_length, residual_mppi_rollout_temperature,
            residual_mppi_rollout_length, agg_strategy, residual_mppi_omega,
            residual_mppi_divergence_method,
            residual_mppi_divergence_threshold, max_tokens, max_search_steps.

    Returns:
        SearchResult with the single response, ``metrics`` (``step``,
        ``planning_steps``, mean ``divergence``) and per-step ``log_values``.
    """
    num_threads = 16
    stop = args.residual_mppi_stop
    # The length limits are only consulted when no stop condition is configured.
    execute_length = args.residual_mppi_execute_length if stop is None else None
    rollout_length = args.residual_mppi_rollout_length if stop is None else None

    planning_rollout_samples = args.residual_mppi_rollout_samples
    # Clamp to 1: with n_samples < rollout_samples the floor division is 0, and
    # np.argmax over zero candidates raises.
    planning_step_samples = max(1, args.n_samples // planning_rollout_samples)
    current_step_samples = planning_step_samples
    current_rollout_samples = planning_rollout_samples

    current_text = prompt
    step = 0
    planning_steps = 0
    log_values = defaultdict(list)
    divergences = []

    while True:
        step += 1
        if current_step_samples > 1 and current_rollout_samples > 1:
            planning_steps += 1

        # 1) Candidate next steps, all sampled from the same prefix.
        step_params = _step_params(llm, [current_text] * current_step_samples,
                                   args.temperature, stop, execute_length)
        output_states = run_batch(residual_mppi_step, step_params, num_threads=num_threads)
        agg_output_scores, output_selected_logprobs, output_Q = _score_states(
            output_states, prompt, sample, verifier, tokenizer,
            args.agg_strategy, args.residual_mppi_omega)

        # 2) Rollouts, built candidate-major so row i of the reshape below holds
        #    the rollouts of candidate i.
        rollout_prompts = [state.prompt + state.response_text
                           for state in output_states
                           for _ in range(current_rollout_samples)]
        rollout_params = _step_params(llm, rollout_prompts,
                                      args.residual_mppi_rollout_temperature, stop, rollout_length)
        rollout_states = run_batch(residual_mppi_step, rollout_params, num_threads=num_threads)
        agg_rollout_scores, rollout_selected_logprobs, rollout_Q_flat = _score_states(
            rollout_states, prompt, sample, verifier, tokenizer,
            args.agg_strategy, args.residual_mppi_omega)
        rollout_Q = rollout_Q_flat.reshape(len(output_states), current_rollout_samples).mean(axis=1)

        # 3) Blend the rollout estimate in only for candidates that are not
        #    finished yet; a completed candidate keeps its own Q.
        incomplete_mask = np.array([not state.completed for state in output_states])
        total_Q = np.where(incomplete_mask, (output_Q + rollout_Q) / 2, output_Q)
        selected_state = output_states[int(np.argmax(total_Q))]
        current_text += selected_state.response_text

        # 4) Adapt the next step's sample budget to how much the lookahead
        #    disagrees with the immediate scores.
        divergence = compute_divergence(output_Q, rollout_Q, method=args.residual_mppi_divergence_method)
        divergences.append(divergence)
        if divergence >= args.residual_mppi_divergence_threshold:
            current_step_samples = planning_step_samples
            current_rollout_samples = planning_rollout_samples
        else:
            current_step_samples = 1
            current_rollout_samples = 1

        log_values["total_Q"].append(total_Q.tolist())
        log_values["output_Q"].append(output_Q.tolist())
        log_values["rollout_Q"].append(rollout_Q.tolist())
        log_values["output_scores"].append(agg_output_scores)
        log_values["rollout_scores"].append(agg_rollout_scores)
        log_values["output_logprobs"].append(output_selected_logprobs)
        log_values["rollout_logprobs"].append(rollout_selected_logprobs)

        if selected_state.completed:
            break

        generated_tokens = len(tokenizer.encode(current_text[len(prompt):], add_special_tokens=False))
        if generated_tokens > args.max_tokens:
            break

        # NOTE(review): `step` is incremented before this check, so up to
        # max_search_steps + 1 search iterations run before the final rollout.
        if step > args.max_search_steps:
            remaining = args.max_tokens - generated_tokens
            if remaining > 0:
                # Finish the response with one unguided continuation.
                final_state = residual_mppi_step(llm=llm, prompt=current_text,
                                                 temperature=args.residual_mppi_rollout_temperature,
                                                 max_tokens=remaining)
                current_text += final_state.response_text
            break

    response = current_text[len(prompt):]
    return SearchResult(
        outputs=[response],
        prediction=response,
        metrics={
            "step": step,
            "planning_steps": planning_steps,
            "divergence": np.mean(divergences),
        },
        log_values=log_values
    )
