#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from vllm import LLM, SamplingParams

from sal.config import Config
from sal.models.reward_models import PRM
from sal.utils.score import aggregate_scores

import gc
import torch

from sal.models.reward_models import load_prm


def _load_llm(config: Config) -> LLM:
    """Build the vLLM engine for ``config.model_path``.

    Engine settings are picked by substring-matching the model path:

    * paths containing "32B" or "14B": memory fraction 0.95, a 20324-token
      context, eager mode;
    * other paths containing "Qwen": a 32768-token context, CUDA graphs
      enabled (``enforce_eager=False``);
    * everything else: an 8192-token context with prefix caching.

    For the last two cases the memory fraction is ``config.gpu_memory_utilization``
    on a single visible GPU and 0.7 when more than one GPU is visible.
    All engines run on ``cuda:0`` with tensor parallel size 1.
    """
    num_gpus = torch.cuda.device_count()
    shared_gpu_memory_utilization = (
        config.gpu_memory_utilization if num_gpus <= 1 else 0.7
    )

    # Each substring must be tested separately: `"32B" or "14B" in path`
    # evaluates to the truthy literal "32B" and would match every model.
    if "32B" in config.model_path or "14B" in config.model_path:
        return LLM(
            model=config.model_path,
            gpu_memory_utilization=0.95,
            max_model_len=20324,  # NOTE(review): unusual value (20480?) -- confirm
            enforce_eager=True,
            seed=config.seed,
            tensor_parallel_size=1,
            device="cuda:0",
        )
    if "Qwen" in config.model_path:
        return LLM(
            model=config.model_path,
            gpu_memory_utilization=shared_gpu_memory_utilization,
            max_model_len=32768,
            enforce_eager=False,
            seed=config.seed,
            tensor_parallel_size=1,
            device="cuda:0",
        )
    return LLM(
        model=config.model_path,
        gpu_memory_utilization=shared_gpu_memory_utilization,
        max_model_len=8192,
        enable_prefix_caching=True,
        seed=config.seed,
        tensor_parallel_size=1,
        device="cuda:0",
    )


def id_best_of_n(x, config: Config, llm: LLM, prm: PRM):
    """Budget-forced best-of-N sampling followed by PRM reranking.

    Every problem in ``x["problem"]`` is sampled ``config.n`` times.

    Generation starts with a token budget of ``config.initial_budget``. Any
    sample that stops because it hit the budget (``finish_reason ==
    "length"``) is continued in the next round, and the budget is multiplied
    by ``config.id_gamma`` after each round. A truncated sample is handled in
    one of two ways:

    * while the new budget is still ``<= config.max_tokens``, its trailing
      partial line (text after the last "\\n") is dropped and "Wait" is
      appended, so the model keeps reasoning;
    * once the budget exceeds ``config.max_tokens``, the text is kept as-is
      and ``config.end_think_token`` is appended if absent. A final pass
      with ``config.max_tokens + 1000`` tokens then completes these samples.

    The completions are scored with ``prm.score``. Per-step scores are
    reduced with ``aggregate_scores(..., config.agg_strategy)``, and the
    highest-scoring completion is picked per problem.

    Args:
        x: Batch mapping with a ``"problem"`` list of prompt strings
            (presumably a batched ``datasets`` row -- verify against caller).
            It is mutated in place.
        config: Run configuration.
        llm: vLLM engine, or ``None`` to build one with ``_load_llm``.
        prm: Process reward model, or ``None``. In that case the local
            ``llm`` reference is dropped, the CUDA cache is emptied, and a
            PRM is loaded with ``load_prm(config)``.

    Returns:
        ``x`` with four new keys:

        * ``"completions"``: ``n`` strings per problem;
        * ``"scores"``: raw PRM scores;
        * ``"pred"``: the best completion per problem;
        * ``"completion_tokens"``: tokens generated per completion.

    Raises:
        RuntimeError: if vLLM reports a finish reason other than
            "stop" or "length".
        ValueError: if a problem does not end up with ``config.n``
            completions.
    """
    if llm is None:
        llm = _load_llm(config)

    tokenizer = llm.get_tokenizer()
    # TODO: set the augmented template from a file
    if config.custom_chat_template is not None:
        tokenizer.chat_template = config.custom_chat_template

    num_problems = len(x["problem"])
    convs = [
        [
            {"role": "system", "content": config.system_prompt},
            {"role": "user", "content": prompt},
        ]
        for prompt in x["problem"]
    ]
    templated_convs = tokenizer.apply_chat_template(
        convs, tokenize=False, add_generation_prompt=True
    )

    # The chat template already renders BOS/special tokens as text, so
    # encoding must not add them a second time.
    encoded_convs = [
        tokenizer.encode(conv, add_special_tokens=False) for conv in templated_convs
    ]
    # Duplicate each prompt config.n times so vLLM can batch all samples
    # continuously: [p1, p2] -> [p1, p1, p2, p2] for config.n=2.
    prompts = [
        {"prompt_token_ids": list(token_ids)}
        for token_ids in encoded_convs
        for _ in range(config.n)
    ]

    completions = [["" for _ in range(config.n)] for _ in range(num_problems)]
    completion_tokens = [[0] * config.n for _ in range(num_problems)]
    # (problem index, sample index) of every still-unfinished sample, aligned
    # position-by-position with `prompts`.
    active_ids = [(i, j) for i in range(num_problems) for j in range(config.n)]

    sampling_params = SamplingParams(
        temperature=config.temperature,
        max_tokens=config.initial_budget,
        top_p=config.top_p,
        top_k=config.top_k,
        n=1,  # prompts are already duplicated, so one sample per prompt
    )

    reflection_prompt = "Wait"
    node_split_string = "\n"

    cur_budget = config.initial_budget
    # NOTE(review): with id_gamma <= 1 the budget never grows (or shrinks
    # toward 0), so the loop below only ends once every sample stops or a
    # prompt outgrows the model context -- confirm id_gamma > 1 upstream.
    budget_factor = config.id_gamma
    max_budget = config.max_tokens

    while cur_budget <= max_budget:
        responses = llm.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        # Grow the budget for the next round.
        cur_budget = int(cur_budget * budget_factor)
        sampling_params.max_tokens = cur_budget

        next_ids = []
        next_prompts = []
        # vLLM returns outputs in the same order as the input prompts.
        for (id_i, id_j), prompt, response in zip(active_ids, prompts, responses):
            output = response.outputs[0]

            if output.finish_reason == "length":
                next_ids.append((id_i, id_j))

                if cur_budget <= max_budget:
                    # Drop the trailing partial line and ask the model to
                    # reconsider. With no newline in the text, only the
                    # reflection prompt is kept.
                    lines = output.text.split(node_split_string)
                    string_to_add = node_split_string.join(
                        lines[:-1] + [reflection_prompt]
                    )
                elif config.end_think_token in output.text:
                    string_to_add = output.text
                else:
                    # Budget exhausted: close the thinking section so the
                    # final pass produces the answer.
                    string_to_add = output.text + config.end_think_token

                completions[id_i][id_j] += string_to_add
                new_token_ids = tokenizer.encode(
                    string_to_add, add_special_tokens=False
                )
                completion_tokens[id_i][id_j] += len(new_token_ids)
                next_prompts.append(
                    {"prompt_token_ids": prompt["prompt_token_ids"] + new_token_ids}
                )
            elif output.finish_reason == "stop":
                completions[id_i][id_j] += output.text
                completion_tokens[id_i][id_j] += len(output.token_ids)
            else:
                raise RuntimeError(
                    f"Unexpected finish_reason {output.finish_reason!r} "
                    f"for sample {(id_i, id_j)}"
                )

        active_ids = next_ids
        prompts = next_prompts

        if not active_ids:
            break

    # Samples still active here ran out of budget: finish them in one last
    # pass. The extra 1000 tokens give headroom beyond max_budget, presumably
    # for the answer that follows the end-think token.
    if active_ids:
        sampling_params.max_tokens = max_budget + 1000

        responses = llm.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        for (id_i, id_j), response in zip(active_ids, responses):
            output = response.outputs[0]
            completions[id_i][id_j] += output.text
            completion_tokens[id_i][id_j] += len(output.token_ids)

    # Sanity check: every problem must carry exactly config.n completions.
    for c in completions:
        if len(c) != config.n:
            raise ValueError(f"Generated {len(c)} completions instead of {config.n}")

    print("Generated completions for all prompts")

    if prm is None:
        # Free GPU memory before loading the PRM. This only releases the
        # engine if no caller still holds a reference to `llm`.
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        prm = load_prm(config)

    scores = prm.score(x["problem"], completions)
    agg_scores = [
        [aggregate_scores(s, config.agg_strategy) for s in score] for score in scores
    ]

    # Select the completion with the highest aggregated score per problem.
    pred = [completion[np.argmax(s)] for completion, s in zip(completions, agg_scores)]

    x["completions"] = completions
    x["scores"] = scores
    x["pred"] = pred
    x["completion_tokens"] = completion_tokens

    return x
