#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from collections import defaultdict

import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams

from sal.config import Config
from sal.models.reward_models import PRM

from .utils import Beam, build_conv, generate_k_steps, last

logger = logging.getLogger()
from sal.utils.score import aggregate_scores

import torch
from search.mapping.scaling_processor import ScalingProcessor


def _make_sampling_params(
    uid,
    config: Config,
    temperature_map: dict | None,
    bias_map: dict | None,
    is_last_iteration: bool,
) -> SamplingParams:
    """Build the SamplingParams (with its own ScalingProcessor) for one beam.

    Args:
        uid: Identifier of the prompt the beam belongs to; looked up in the
            maps as ``str(uid)``.
        config: Supplies ``temperature``, ``max_tokens`` and ``top_p``.
        temperature_map: Optional ``{str(uid): {"temperature": ...}}`` overrides.
            When empty or None, ``config.temperature`` is used unchanged.
        bias_map: Optional ``{str(uid): {"bias": ...}}`` overrides.
        is_last_iteration: When True no stop string is set; otherwise the
            generation stops at (and includes) the first ``"\\n\\n"``.

    Raises:
        ValueError: If a mapped temperature is not an ndarray, list or scalar.
    """
    # Temperature: global default unless a per-uid override is available.
    if not temperature_map:
        temperature = config.temperature
    else:
        temp_data = temperature_map.get(str(uid), {"temperature": config.temperature})
        temperature = temp_data["temperature"]
        if isinstance(temperature, list):
            temperature = np.array(temperature)
        elif not isinstance(temperature, np.ndarray):
            if not np.isscalar(temperature):
                raise ValueError(f"Unknown temperature type: {type(temperature)}")
            temperature = float(temperature)

    # Bias: lists/ndarrays become float tensors; anything else (presumably
    # already a tensor, or None) is passed through untouched.
    bias = None
    if bias_map:
        bias_data = bias_map.get(str(uid), None)
        if bias_data is not None and "bias" in bias_data:
            raw_bias = bias_data["bias"]
            if isinstance(raw_bias, list):
                raw_bias = np.array(raw_bias)
            if isinstance(raw_bias, np.ndarray):
                bias = torch.from_numpy(raw_bias).float()
            else:
                bias = raw_bias

    scaling_processor = ScalingProcessor(temp=temperature, bias=bias, add_bias_first=True)

    sampling_kwargs = {
        "logits_processors": [scaling_processor],
        # Base temperature stays at 1.0; the effective temperature is the one
        # handed to the scaling processor above.
        "temperature": 1.0,
        "max_tokens": config.max_tokens,
        "top_p": config.top_p,
        "n": 1,
    }
    if not is_last_iteration:
        # Intermediate iterations generate a single step, delimited by a blank line.
        sampling_kwargs["stop"] = ["\n\n"]
        sampling_kwargs["include_stop_str_in_output"] = True
    return SamplingParams(**sampling_kwargs)


def _beam_search(
    batch_of_prompts,
    batch_of_unique_ids,
    config: Config,
    llm: LLM,
    prm: PRM,
    temperature_map: dict | None,
    bias_map: dict | None,
) -> list[Beam]:
    """Run step-wise beam search guided by a process reward model.

    ``config.n`` beams are created per prompt. On every iteration the active
    beams are extended through ``generate_k_steps``, scored with ``prm.score``,
    and all but the top ``config.n // config.beam_width`` (by aggregated score)
    are pruned. Survivors are duplicated back up to ``config.n`` at the start
    of the next iteration. A beam is completed when its stop reason is
    ``"EOS"`` or ``"length"``, or when it generated an empty string.

    Args:
        batch_of_prompts: Prompts to search over.
        batch_of_unique_ids: One identifier per prompt, used to look up
            per-prompt temperature/bias overrides.
        config: Search configuration (``n``, ``beam_width``, ``num_iterations``,
            ``lookahead``, sampling settings, ...).
        llm: vLLM engine used for generation and for its tokenizer.
        prm: Reward model exposing ``score(prompts, completions)``.
        temperature_map: Optional per-uid temperature overrides.
        bias_map: Optional per-uid logit bias overrides.

    Returns:
        ``config.n`` beams: the completed ones (optionally sorted by score),
        deep-copied as needed to fill the list. If nothing completed, the
        last surviving active beams (or the initial beams) are used instead.

    Raises:
        ValueError: On an unsupported temperature type, or if the active beam
            list cannot be extended to ``config.n``.
    """
    # Initialize all beams and their corresponding unique_id
    beams: list[Beam] = []
    beam_uids: list = []
    for prompt, uid in zip(batch_of_prompts, batch_of_unique_ids, strict=False):
        for beam_index in range(config.n):
            beams.append(
                Beam(
                    prompt=prompt,
                    index=beam_index,
                    current_text="",
                    next_texts=None,
                    lookahead_texts=None,
                    pruned=False,
                    completed=False,  # Flag to track completion
                    stop_reasons=None,
                    history=[],
                    best_scores=[],
                    all_scores=[],
                    previous_text=None,
                    completion_tokens=0,
                )
            )
            beam_uids.append(uid)

    completed_beams: list[Beam] = []
    last_active_beams: list[Beam] | None = None

    active_beams: list[Beam] = list(beams)
    active_uids: list = list(beam_uids)

    for iteration in tqdm(range(config.num_iterations), desc="Beam search iterations"):
        is_last_iteration = iteration == config.num_iterations - 1

        # Drop beams pruned at the end of the previous iteration, keeping the
        # uid list aligned with the beam list (a no-op on the first iteration).
        alive = [not b.pruned for b in active_beams]
        active_uids = [u for u, keep in zip(active_uids, alive) if keep]
        active_beams = [b for b, keep in zip(active_beams, alive) if keep]

        # Duplicate active beams to ensure that we have config.n beams per iteration
        if len(active_beams) != config.n:
            repeats = (config.n // len(active_beams)) + 1
            logger.debug(
                f"Extending active_beams with {repeats} repetitions to reach size {config.n}"
            )
            active_uids = (active_uids * repeats)[: config.n]
            active_beams = [
                copy.deepcopy(b) for b in (active_beams * repeats)[: config.n]
            ]
            if len(active_beams) != config.n:
                raise ValueError(
                    f"Expected {config.n} active beams, but got {len(active_beams)}"
                )

        # One SamplingParams per active beam, each with its own per-prompt
        # logits processor.
        per_prompt_sampling_params: list[SamplingParams] = [
            _make_sampling_params(
                uid, config, temperature_map, bias_map, is_last_iteration
            )
            for uid in active_uids
        ]

        convs = [
            build_conv(b.prompt, b.current_text, config.system_prompt)
            for b in active_beams
        ]
        # The first iteration opens the assistant turn; later ones continue it.
        continue_final_message = iteration > 0
        add_generation_prompt = iteration == 0

        tokenizer = llm.get_tokenizer()
        if config.custom_chat_template is not None:
            tokenizer.chat_template = config.custom_chat_template
        templated_convs = tokenizer.apply_chat_template(
            convs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tokenize=False,
        )
        lookahead = 0 if is_last_iteration else config.lookahead
        gen_results = generate_k_steps(
            templated_convs, lookahead, llm, per_prompt_sampling_params, 1
        )

        prompts, completions = [], []
        for beam, gen_result in zip(active_beams, gen_results, strict=True):
            beam.next_texts = gen_result.next_texts
            beam.stop_reasons = gen_result.stop_reasons
            beam.lookahead_texts = gen_result.lookahead_texts
            beam.completion_tokens += gen_result.completion_tokens
            beam.current_text += beam.next_texts[0]
            beam.history.append(beam.next_texts[0])

            if (
                beam.stop_reasons[0] == "EOS"
                or beam.stop_reasons[0] == "length"
                or beam.next_texts[0] == ""
            ):
                beam.completed = True
                completed_beams.append(beam)
            prompts.append(beam.prompt)
            completions.append([beam.current_text])

        scores = prm.score(prompts, completions)

        agg_scores = [
            [aggregate_scores(s, config.agg_strategy) for s in score]
            for score in scores
        ]

        for beam, score in zip(active_beams, scores, strict=True):
            beam.all_scores = score[0]

        # Completed beams leave the active set; filter scores and uids in step.
        keep_mask = [not b.completed for b in active_beams]
        agg_scores = [s for s, keep in zip(agg_scores, keep_mask) if keep]
        active_uids = [u for u, keep in zip(active_uids, keep_mask) if keep]
        active_beams = [b for b, keep in zip(active_beams, keep_mask) if keep]

        # Early stopping if all beams are completed
        if len(active_beams) == 0:
            break

        # Filter duplicate active beams
        if config.filter_duplicates:
            # Keep the first occurrence of each distinct text, preserving order.
            first_seen: dict = {}
            for position, b in enumerate(active_beams):
                first_seen.setdefault(b.current_text, position)
            keep_indices = list(first_seen.values())
            active_beams = [active_beams[k] for k in keep_indices]
            active_uids = [active_uids[k] for k in keep_indices]
            agg_scores = [agg_scores[k] for k in keep_indices]

        # Get indices for top (config.n / config.beam_width) completions
        top_indices = set(
            np.argsort(np.array(agg_scores).flatten())[
                -(config.n // config.beam_width) :
            ].tolist()
        )

        # Pruning must happen inside the iteration so that the next iteration
        # only continues (and duplicates) the best-scoring beams.
        for idx, beam in enumerate(active_beams):
            if idx not in top_indices:
                beam.pruned = True

        # Remember the current survivors as a fallback result.
        last_active_beams = [b for b in active_beams if not b.pruned]

    # If no beam ever completed, fall back to the last surviving active beams
    # (or to the initial beams when there are none), capped at config.n.
    if len(completed_beams) == 0:
        fallback = last_active_beams if last_active_beams else beams
        completed_beams = fallback[: config.n]

    # Filter completed beams for those with top config.n scores
    if config.sort_completed:
        completed_beams = sorted(
            completed_beams,
            key=lambda b: aggregate_scores(b.all_scores, config.agg_strategy),
            reverse=True,
        )[: config.n]
    else:
        completed_beams = completed_beams[: config.n]

    if len(completed_beams) != config.n:
        # If we don't have enough completed_beams, duplicate until we reach config.n
        repeats = (config.n // len(completed_beams)) + 1
        logger.debug(
            f"Extending completed_beams with {repeats} repetitions to reach size {config.n}"
        )
        completed_beams = [
            copy.deepcopy(b) for b in (completed_beams * repeats)[: config.n]
        ]

    return completed_beams


def beam_search(
    examples,
    config: Config,
    llm: LLM,
    prm: PRM,
    temperature_map: dict | None = None,
    bias_map: dict | None = None,
):
    """Run beam search on every problem in ``examples``, one problem at a time.

    Args:
        examples: Mapping with a ``"problem"`` sequence and, optionally, a
            parallel ``"unique_id"`` sequence.
        config: Search configuration, forwarded to ``_beam_search``.
        llm: vLLM engine, forwarded to ``_beam_search``.
        prm: Reward model, forwarded to ``_beam_search``.
        temperature_map: Optional per-uid temperature overrides.
        bias_map: Optional per-uid bias overrides.

    Returns:
        A dict with keys ``"completions"``, ``"pred"``, ``"completion_tokens"``
        and ``"scores"``, each holding one entry per problem. ``"pred"`` is the
        completion whose aggregated score is highest.
    """
    problems = examples["problem"]
    # Without a unique_id column, fall back to positional indices.
    unique_ids = (
        examples["unique_id"]
        if "unique_id" in examples
        else list(range(len(problems)))
    )

    all_completions = []
    all_preds = []
    all_token_counts = []
    all_scores = []

    # Problems are searched individually so that they never compete with each
    # other for beams, which could leave some problems without any result.
    for problem, uid in zip(problems, unique_ids):
        final_beams = _beam_search(
            [problem],
            [uid],
            config,
            llm,
            prm,
            temperature_map or {},
            bias_map or {},
        )

        texts = [beam.current_text for beam in final_beams]
        beam_agg_scores = [
            aggregate_scores(beam.all_scores, config.agg_strategy)
            for beam in final_beams
        ]
        if len(beam_agg_scores) > 0:
            best_text = texts[int(np.argmax(beam_agg_scores))]
        else:
            best_text = ""

        all_completions.append(texts)
        all_scores.append([beam.all_scores for beam in final_beams])
        all_preds.append(best_text)
        all_token_counts.append([beam.completion_tokens for beam in final_beams])

    return {
        "completions": all_completions,
        "pred": all_preds,
        "completion_tokens": all_token_counts,
        "scores": all_scores,
    }
