import numpy as np
import copy
from typing import Callable, List, Any
from transformers import PreTrainedTokenizer

from search.decoding import simple_decoding, cot_decoding, ResponseState, SearchResult
from search.utils import aggregate_score
from utils.concurrency import run_batch

def _expand_beams(llm: Any, prompt: str, active_beams: List[Any], step_params: dict, args: Any) -> List[Any]:
    """Expand every active beam into candidate continuations for one search step.

    With ``"simple"`` decoding each beam's context is submitted ``args.beam_width``
    times to ``simple_decoding``. With ``"cot"`` decoding each context is submitted
    once to ``cot_decoding`` with ``cot_width=args.beam_width``, and the per-beam
    lists of states it returns are flattened. Either way the result is a flat list
    of decoded states grouped beam by beam (assuming ``run_batch`` preserves input
    order -- TODO confirm).

    Args:
        llm: Model handle forwarded to the decoding function.
        prompt: The original prompt shared by all beams; used to slice off the
            already-generated answer for ``cot_decoding``.
        active_beams: Beams to expand; each must expose ``prompt`` and
            ``response_text``.
        step_params: Sampling keyword arguments (temperature, max_tokens, stop).
        args: Run configuration; reads ``decoding_strategy`` and ``beam_width``.

    Raises:
        ValueError: If ``args.decoding_strategy`` is neither ``"simple"`` nor ``"cot"``.
    """
    contexts = [beam.prompt + beam.response_text for beam in active_beams]

    if args.decoding_strategy == "simple":
        batch_params = [
            dict(llm=llm, prompt=context, **step_params)
            for context in contexts
            for _ in range(args.beam_width)
        ]
        return run_batch(simple_decoding, batch_params, num_threads=16)

    if args.decoding_strategy == "cot":
        cot_params = [
            dict(
                llm=llm,
                prompt=context,
                generated_answer=context[len(prompt):],
                cot_width=args.beam_width,
                **step_params,
            )
            for context in contexts
        ]
        per_beam_states = run_batch(cot_decoding, cot_params, num_threads=16)
        return [state for states in per_beam_states for state in states]

    raise ValueError(
        f"Unknown decoding_strategy {args.decoding_strategy!r}; expected 'simple' or 'cot'"
    )


def beam_search(llm: Any, prompt: str, sample: Any, verifier: Callable[[str, List[str]], List[List[float]]], tokenizer: PreTrainedTokenizer, args: Any) -> SearchResult:
    """Run step-wise beam search guided by a verifier.

    ``num_beams = args.n_samples // args.beam_width`` beams start from ``prompt``.
    At each step every active beam is expanded (see ``_expand_beams``). Candidates
    are scored with ``verifier(sample, texts)``, and the ``num_beams`` candidates
    with the highest aggregated score are kept. Among the kept candidates, those
    marked ``completed`` or whose generated text already reaches
    ``args.max_tokens`` tokens are set aside as finished. The rest continue to the
    next step. Intermediate steps stop generation at ``args.step_token``; the final
    step removes that stop sequence.

    Args:
        llm: Model handle forwarded to the decoding functions.
        prompt: Prompt text shared by all beams. Generated text is everything a
            state holds beyond this prefix.
        sample: Opaque task sample passed as the first argument to ``verifier``.
        verifier: Called as ``verifier(sample, texts)``; per the type hint it
            returns one list of float scores per text.
        tokenizer: Used to count the tokens of a candidate's generated text.
        args: Run configuration. Reads ``n_samples``, ``beam_width``,
            ``temperature``, ``max_tokens``, ``step_token``, ``max_search_steps``,
            ``decoding_strategy`` and ``agg_strategy``.

    Returns:
        A ``SearchResult`` with:
        - ``outputs``: the generated text of every finished or still-active beam;
        - ``prediction``: the output with the highest aggregated score;
        - ``scores`` and ``agg_scores``: per output;
        - ``metrics``: mean rewards and the index of the last step executed.

    Raises:
        ValueError: If ``beam_width`` < 1, ``n_samples // beam_width`` < 1,
            ``max_search_steps`` < 1, the decoding strategy is unknown, or the
            search yields no candidates.
    """
    if args.beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {args.beam_width}")
    num_beams = args.n_samples // args.beam_width
    # num_beams == 0 would turn the `[-num_beams:]` slice below into `[0:]`, which
    # keeps *every* candidate instead of none, so reject it up front.
    if num_beams < 1:
        raise ValueError(
            f"n_samples ({args.n_samples}) must be >= beam_width ({args.beam_width})"
        )
    if args.max_search_steps < 1:
        raise ValueError(f"max_search_steps must be >= 1, got {args.max_search_steps}")

    base_params = dict(temperature=args.temperature, max_tokens=args.max_tokens)

    active_beams = [ResponseState(prompt=prompt) for _ in range(num_beams)]
    completed_beams = []

    # The loop runs at least once (validated above), so `step` is always bound
    # when the metrics are built.
    for step in range(args.max_search_steps):
        is_last_step = step == args.max_search_steps - 1
        # Intermediate steps stop at the step delimiter so candidates can be scored
        # and pruned. The last step drops the stop sequence so generation is not
        # cut at a step boundary.
        step_params = dict(base_params, stop=None if is_last_step else [args.step_token])

        output_states = _expand_beams(llm, prompt, active_beams, step_params, args)
        output_texts = [state.prompt[len(prompt):] + state.response_text for state in output_states]

        scores = verifier(sample, output_texts)
        agg_scores = aggregate_score(scores, args.agg_strategy)

        for i, state in enumerate(output_states):
            state.meta_info["scores"] = scores[i]
            state.meta_info["agg_scores"] = agg_scores[i]

        # argsort is ascending, so the tail holds the num_beams best candidates.
        top_indices = np.argsort(agg_scores)[-num_beams:]

        next_beams = []
        for idx in top_indices:
            state = output_states[idx]
            # Check `completed` first so the tokenizer only runs when it matters.
            if state.completed or len(tokenizer.encode(output_texts[idx])) >= args.max_tokens:
                completed_beams.append(state)
            else:
                next_beams.append(state)
        active_beams = next_beams

        if not active_beams:
            break

    # Beams still active after the final step are reported alongside finished ones.
    completed_beams.extend(active_beams)
    if not completed_beams:
        raise ValueError("beam search produced no candidate outputs")

    outputs = [beam.prompt[len(prompt):] + beam.response_text for beam in completed_beams]
    scores = [beam.meta_info["scores"] for beam in completed_beams]
    agg_scores = [beam.meta_info["agg_scores"] for beam in completed_beams]

    return SearchResult(
        outputs=outputs,
        prediction=outputs[np.argmax(agg_scores)],
        scores=scores,
        agg_scores=agg_scores,
        metrics={
            "rewards_mean": np.mean([np.mean(sc) for sc in scores]),
            "agg_rewards_mean": np.mean([np.mean(sc) for sc in agg_scores]),
            "step": step,
        },
    )
