#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from collections import defaultdict

import numpy as np
from tqdm import tqdm
from vllm import LLM, SamplingParams

from sal.config import Config
from sal.models.reward_models import PRM

from .utils import Beam, build_conv, generate_k_steps, last

logger = logging.getLogger()
from sal.utils.score import aggregate_scores


def _has_terminated(beam: Beam) -> bool:
    """Return True when the beam's latest generation step ended the completion.

    A beam counts as terminated when the first stop reason of its latest step
    is ``"EOS"`` or ``"length"``, or when that step produced no text at all.
    """
    return beam.stop_reasons[0] in ("EOS", "length") or beam.next_texts[0] == ""


def _render_conversations(beams: list[Beam], config: Config, llm: LLM, *, first_turn: bool):
    """Apply the chat template to the current prompt/partial answer of each beam.

    Args:
        beams: Beams whose ``prompt`` and ``current_text`` are templated.
        config: Supplies ``system_prompt`` and the optional ``custom_chat_template``.
        llm: Engine whose tokenizer owns the chat template.
        first_turn: When True a generation prompt is appended; otherwise the
            final message is continued so generation resumes mid-answer.

    Returns:
        Whatever ``tokenizer.apply_chat_template(..., tokenize=False)`` returns
        for the batch of conversations.
    """
    convs = [
        build_conv(beam.prompt, beam.current_text, config.system_prompt)
        for beam in beams
    ]
    tokenizer = llm.get_tokenizer()
    if config.custom_chat_template is not None:
        tokenizer.chat_template = config.custom_chat_template
    return tokenizer.apply_chat_template(
        convs,
        add_generation_prompt=first_turn,
        continue_final_message=not first_turn,
        tokenize=False,
    )


def _generate_step(
    beams: list[Beam],
    config: Config,
    llm: LLM,
    sampling_params: SamplingParams,
    lookahead: int,
    *,
    first_turn: bool,
) -> None:
    """Generate one step for every beam and fold the result into the beam in place.

    Each beam's ``next_texts``, ``stop_reasons`` and ``lookahead_texts`` are
    overwritten with the new step, ``completion_tokens`` is accumulated, and
    the first generated text is appended to ``current_text`` and ``history``.
    """
    templated_convs = _render_conversations(beams, config, llm, first_turn=first_turn)
    gen_results = generate_k_steps(templated_convs, lookahead, llm, sampling_params, 1)
    for beam, gen_result in zip(beams, gen_results, strict=True):
        beam.next_texts = gen_result.next_texts
        beam.stop_reasons = gen_result.stop_reasons
        beam.lookahead_texts = gen_result.lookahead_texts
        beam.completion_tokens += gen_result.completion_tokens
        beam.current_text += beam.next_texts[0]
        beam.history.append(beam.next_texts[0])


def _collect_finished(
    beams: list[Beam], completed_beams: list[Beam], target: int
) -> tuple[list[Beam], bool]:
    """Mark terminated beams as completed and move them to ``completed_beams``.

    Args:
        beams: Beams to inspect, in order.
        completed_beams: Accumulator of finished beams; mutated in place.
        target: Number of completed beams at which the search may stop.

    Returns:
        ``(finished, reached_target)`` where ``finished`` lists the beams that
        were just marked completed and ``reached_target`` tells whether
        ``completed_beams`` hit ``target`` while collecting them.
    """
    finished: list[Beam] = []
    reached_target = False
    for beam in beams:
        if _has_terminated(beam):
            beam.completed = True
            completed_beams.append(beam)
            finished.append(beam)
            if len(completed_beams) >= target:
                reached_target = True
    return finished, reached_target


def _replicate_beams(beams: list[Beam], target: int, label: str) -> list[Beam]:
    """Return ``target`` deep copies of ``beams``, cycling through them in order.

    For ``[a, b]`` and ``target == 5`` the result is copies of ``a, b, a, b, a``.
    ``beams`` must be non-empty; ``label`` only names the list in the debug log.
    """
    repeats = (target // len(beams)) + 1
    logger.debug(
        "Extending %s with %s repetitions to reach size %s", label, repeats, target
    )
    return [copy.deepcopy(beam) for beam in (beams * repeats)[:target]]


def _beam_search(batch_of_prompts, config: Config, llm: LLM, prm: PRM) -> list[Beam]:
    """Run PRM-guided beam search and return ``config.n`` completed beams.

    Every prompt starts with ``config.n`` empty beams. Each of the
    ``config.num_iterations`` iterations then:

    1. (from the second iteration on) extends every surviving beam by
       ``config.g - 1`` extra generation steps, retires beams that terminated
       during that extension, and replicates the survivors back up to
       ``config.n`` beams;
    2. branches: generates one more step per beam (the last iteration drops
       the ``"\\n\\n"`` stop string so generation runs to its natural end);
    3. scores all beams with ``prm``, retires the terminated ones, optionally
       drops beams with duplicate text, and prunes all but the
       ``config.n // config.beam_width`` best-scoring beams.

    The loop stops early once no active beam is left or ``config.n`` beams
    have completed.

    NOTE(review): completed beams are counted and truncated to ``config.n``
    across the whole batch rather than per prompt - confirm callers pass a
    single prompt per batch.

    Args:
        batch_of_prompts: Iterable of prompts to search over.
        config: Search settings (sampling parameters, ``n``, ``beam_width``,
            ``g``, ``lookahead``, ``num_iterations``, ``filter_duplicates``,
            ``sort_completed``, ``agg_strategy``, chat-template options).
        llm: vLLM engine used for generation.
        prm: Reward model; ``prm.score(prompts, completions)`` is indexed here
            as one result per prompt whose first entry is stored on the beam.

    Returns:
        Exactly ``config.n`` completed beams. When fewer beams completed they
        are deep-copied cyclically to fill the list.

    Raises:
        ValueError: If the search ends without a single completed beam while
            ``config.n`` is non-zero, or if replication yields an unexpected
            number of active beams.
    """
    sampling_params = SamplingParams(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        stop=["\n\n"],
        include_stop_str_in_output=True,
        n=1,
    )

    beams: list[Beam] = [
        Beam(
            prompt=prompt,
            index=index,
            current_text="",
            next_texts=None,
            lookahead_texts=None,
            pruned=False,
            completed=False,
            stop_reasons=None,
            history=[],
            best_scores=[],
            all_scores=[],
            previous_text=None,
            completion_tokens=0,
        )
        for prompt in batch_of_prompts
        for index in range(config.n)
    ]

    completed_beams: list[Beam] = []
    active_beams: list[Beam] = beams
    survivors_per_round = config.n // config.beam_width

    for i in tqdm(range(config.num_iterations), desc="Beam search iterations"):
        is_last_iteration = i == config.num_iterations - 1
        # No lookahead on the final iteration.
        lookahead = 0 if is_last_iteration else config.lookahead
        early_exit = False

        # Drop the beams pruned at the end of the previous iteration.
        active_beams = [beam for beam in active_beams if not beam.pruned]

        if i > 0:
            # Extension: grow every surviving beam by g - 1 extra steps.
            # NOTE(review): a beam that terminates in an early extension step
            # is still extended by the later ones, and only the stop reason of
            # its last step is inspected below - confirm this is intended.
            for _ in range(config.g - 1):
                _generate_step(
                    active_beams,
                    config,
                    llm,
                    sampling_params,
                    lookahead,
                    first_turn=False,
                )

            finished, early_exit = _collect_finished(
                active_beams, completed_beams, config.n
            )
            # Only query the reward model when a beam actually finished;
            # most iterations finish none and need no scoring call here.
            if finished:
                exit_scores = prm.score(
                    [beam.prompt for beam in finished],
                    [[beam.current_text] for beam in finished],
                )
                for beam, score in zip(finished, exit_scores, strict=True):
                    beam.all_scores = score[0]

            active_beams = [beam for beam in active_beams if not beam.completed]

        if not active_beams or early_exit:
            break

        if i > 0 and len(active_beams) != config.n:
            # Replicate the survivors so branching starts from config.n beams.
            active_beams = _replicate_beams(active_beams, config.n, "active_beams")
            if len(active_beams) != config.n:
                raise ValueError(
                    f"Expected {config.n} active beams, but got {len(active_beams)}"
                )

        # Branching.
        if is_last_iteration:
            # Last iteration: no stop string, so each beam generates to the end.
            sampling_params = SamplingParams(
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                n=1,
            )

        _generate_step(
            active_beams, config, llm, sampling_params, lookahead, first_turn=(i == 0)
        )
        _, early_exit = _collect_finished(active_beams, completed_beams, config.n)

        # Score every beam, including the ones that just completed, so that
        # completed beams carry their final scores.
        scores = prm.score(
            [beam.prompt for beam in active_beams],
            [[beam.current_text] for beam in active_beams],
        )
        agg_scores = [
            [aggregate_scores(s, config.agg_strategy) for s in score]
            for score in scores
        ]
        for beam, score in zip(active_beams, scores, strict=True):
            beam.all_scores = score[0]

        # Keep beams and aggregated scores aligned while dropping completed beams.
        agg_scores = [
            agg for beam, agg in zip(active_beams, agg_scores) if not beam.completed
        ]
        active_beams = [beam for beam in active_beams if not beam.completed]

        if not active_beams or early_exit:
            break

        if config.filter_duplicates:
            # Keep the first beam for each distinct text, preserving order.
            first_position: dict = {}
            for position, beam in enumerate(active_beams):
                first_position.setdefault(beam.current_text, position)
            keep_positions = list(first_position.values())
            active_beams = [active_beams[p] for p in keep_positions]
            agg_scores = [agg_scores[p] for p in keep_positions]

        # Indices of the top (config.n // config.beam_width) beams by score.
        top_indices = np.argsort(np.array(agg_scores).flatten())[
            -survivors_per_round:
        ]
        keep = set(top_indices.tolist())
        for idx, beam in enumerate(active_beams):
            if idx not in keep:
                beam.pruned = True

    # Keep at most config.n completed beams, best first if requested.
    if config.sort_completed:
        completed_beams = sorted(
            completed_beams,
            key=lambda b: aggregate_scores(b.all_scores, config.agg_strategy),
            reverse=True,
        )[: config.n]
    else:
        completed_beams = completed_beams[: config.n]

    if len(completed_beams) != config.n:
        if not completed_beams:
            # Nothing to replicate; fail with a clear message instead of a
            # division by zero when sizing the replication.
            raise ValueError(
                "Beam search finished without any completed beams; "
                f"cannot return {config.n} beams"
            )
        # Not enough completed beams: replicate them until there are config.n.
        completed_beams = _replicate_beams(
            completed_beams, config.n, "completed_beams"
        )

    return completed_beams


def beam_search(examples, config: Config, llm: LLM, prm: PRM):
    """Run beam search over ``examples["problem"]`` and collect per-problem results.

    The beams returned by ``_beam_search`` are grouped by their prompt. For
    each problem, in input order, the result holds the texts of its beams,
    their raw scores, their completion-token counts, and the text whose
    aggregated score (per ``config.agg_strategy``) is highest.

    Args:
        examples: Mapping with a ``"problem"`` entry listing the prompts.
        config: Search configuration, forwarded to ``_beam_search``.
        llm: vLLM engine, forwarded to ``_beam_search``.
        prm: Reward model, forwarded to ``_beam_search``.

    Returns:
        Dict with keys ``"completions"``, ``"pred"``, ``"completion_tokens"``
        and ``"scores"``, each a list with one entry per problem.
    """
    problems = examples["problem"]

    # Bucket the returned beams under the prompt they were generated for.
    beams_by_prompt = defaultdict(list)
    for beam in _beam_search(problems, config, llm, prm):
        beams_by_prompt[beam.prompt].append(beam)

    output = {"completions": [], "pred": [], "completion_tokens": [], "scores": []}

    for problem in problems:
        group = beams_by_prompt[problem]
        texts = [b.current_text for b in group]
        aggregated = [
            aggregate_scores(b.all_scores, config.agg_strategy) for b in group
        ]
        # The prediction is the completion with the highest aggregated score.
        best_text = texts[np.argmax(aggregated)]

        output["completions"].append(texts)
        output["scores"].append([b.all_scores for b in group])
        output["pred"].append(best_text)
        output["completion_tokens"].append([b.completion_tokens for b in group])

    return output