#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from vllm import LLM, SamplingParams

from sal.config import Config
from sal.models.reward_models import PRM
from sal.utils.score import aggregate_scores
from sal.models.reward_models import load_prm

import gc
import torch
import copy


def _load_llm(config: Config) -> LLM:
    """Build the vLLM engine for ``config.model_path``.

    The engine settings are picked from substrings of the model path:
    paths containing "32" or "14" get a fixed high-utilization eager-mode
    setup, paths containing "Qwen" get a 32768-token context, and everything
    else gets an 8192-token context with prefix caching enabled.
    """
    model_path = config.model_path
    shared_kwargs = {
        "model": model_path,
        "seed": config.seed,
        "tensor_parallel_size": 1,
        "device": "cuda:0",
    }

    # Each substring needs its own membership test: the former
    # `"32" or "14" in path` was always truthy, so the branches below
    # could never be reached.
    if "32" in model_path or "14" in model_path:
        return LLM(
            gpu_memory_utilization=0.95,
            max_model_len=20324,
            enforce_eager=True,
            **shared_kwargs,
        )

    # With more than one visible GPU the utilization is capped at 0.7
    # instead of taking the configured value.
    if torch.cuda.device_count() <= 1:
        utilization = config.gpu_memory_utilization
    else:
        utilization = 0.7

    if "Qwen" in model_path:
        return LLM(
            gpu_memory_utilization=utilization,
            max_model_len=32768,
            enforce_eager=False,
            **shared_kwargs,
        )

    return LLM(
        gpu_memory_utilization=utilization,
        max_model_len=8192,
        enable_prefix_caching=True,
        **shared_kwargs,
    )


def _continue_unfinished(llm, config, templated_convs, responses, unfinished_idxs, sampling_params):
    """Run a second generation pass for the responses listed in ``unfinished_idxs``.

    Each unfinished completion is fed back as prompt + partial text (with
    ``config.end_think_token`` appended when it is absent) and generated for
    up to 1000 more tokens. The continuation is merged into the original
    output objects in place, so ``responses`` reflects the stitched result.
    """
    base_outputs = []
    continued_convs = []
    for idx in unfinished_idxs:
        base_out = responses[idx].outputs[0]  # the original output object, mutated below
        base_outputs.append(base_out)

        partial_text = base_out.text
        if config.end_think_token not in partial_text:
            partial_text += config.end_think_token

        # Templated prompt followed by the assistant's partial reply.
        continued_convs.append(templated_convs[idx] + partial_text)

    # Same sampling settings, but with a fresh 1000-token budget for the continuation.
    extra_sampling = copy.deepcopy(sampling_params)
    if hasattr(extra_sampling, "max_tokens"):
        extra_sampling.max_tokens = 1000
    else:
        raise AttributeError("sampling_params has no max-token field")

    extra_responses = llm.generate(
        continued_convs,
        sampling_params=extra_sampling,
        use_tqdm=False,
    )

    for cont_resp, base_out in zip(extra_responses, base_outputs):
        cont_out = cont_resp.outputs[0]  # n=1, so there is exactly one output
        # NOTE(review): an end_think_token injected into the prompt above is
        # not part of the stitched text or token count — confirm this is intended.
        base_out.text += cont_out.text
        # Concatenate as lists: `+=` raises when the two sequences have
        # different types (e.g. tuple and list).
        base_out.token_ids = list(base_out.token_ids) + list(cont_out.token_ids)
        base_out.finish_reason = cont_out.finish_reason


def best_of_n(x, config: Config, llm: LLM, prm: PRM):
    """Sample ``config.n`` completions per problem and pick the best by PRM score.

    Args:
        x: Batch mapping; ``x["problem"]`` is iterated as the list of prompts.
        config: Run configuration (model path, sampling settings, system
            prompt, optional custom chat template, aggregation strategy, ...).
        llm: vLLM engine to generate with. When ``None`` an engine is built
            from ``config``.
        prm: Reward model used for scoring. When ``None`` it is loaded with
            ``load_prm(config)`` after generation.

    Returns:
        The same ``x`` with ``"completions"``, ``"scores"``, ``"pred"`` and
        ``"completion_tokens"`` filled in.

    Raises:
        ValueError: If the number of generated responses or per-prompt
            completions does not match ``config.n``.
        AttributeError: If the sampling parameters expose no ``max_tokens``.
    """
    if llm is None:
        llm = _load_llm(config)

    problems = x["problem"]
    convs = [
        [
            {"role": "system", "content": config.system_prompt},
            {"role": "user", "content": prompt},
        ]
        for prompt in problems
    ]

    tokenizer = llm.get_tokenizer()
    # TODO: set the augmented template from a file
    if config.custom_chat_template is not None:
        tokenizer.chat_template = config.custom_chat_template
    templated_convs = tokenizer.apply_chat_template(
        convs, tokenize=False, add_generation_prompt=True
    )

    # Duplicate each prompt config.n times so continuous batching can be used:
    # [p1, p2] becomes [p1, p1, p2, p2] for config.n=2.
    templated_convs = [c for conv in templated_convs for c in [conv] * config.n]

    sampling_params = SamplingParams(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        top_k=config.top_k,
        n=1,  # prompts are already duplicated, so one completion per entry
    )

    # First pass over every duplicated prompt.
    responses = llm.generate(
        templated_convs,
        sampling_params=sampling_params,
        use_tqdm=False,
    )

    expected = len(problems) * config.n
    if len(responses) != expected:
        raise ValueError(f"Generated {len(responses)} responses instead of {expected}")

    # Anything that did not finish with one of these reasons gets a continuation pass.
    unfinished_idxs = [
        idx
        for idx, r in enumerate(responses)
        if r.outputs[0].finish_reason not in {"stop", "eos_token"}  # adjust set if needed
    ]

    print('first round of generation finished')

    if unfinished_idxs:
        _continue_unfinished(
            llm, config, templated_convs, responses, unfinished_idxs, sampling_params
        )

    # Regroup the flat response list into config.n completions per problem.
    completions = []
    completion_tokens = []
    for start in range(0, expected, config.n):
        outputs = [o for r in responses[start : start + config.n] for o in r.outputs]
        completions.append([o.text for o in outputs])
        completion_tokens.append([len(o.token_ids) for o in outputs])

    print('generated completions')

    # Check we generated the correct number of completions for each prompt
    for c in completions:
        if len(c) != config.n:
            raise ValueError(f"Generated {len(c)} completions instead of {config.n}")

    if prm is None:
        # Drop the local engine reference before loading the reward model.
        # NOTE(review): this only releases GPU memory if the caller holds no
        # other reference to the engine.
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        prm = load_prm(config)

    scores = prm.score(problems, completions)
    agg_scores = [
        [aggregate_scores(s, config.agg_strategy) for s in score] for score in scores
    ]

    # Select the completion with the highest aggregated score.
    pred = [completion[np.argmax(s)] for completion, s in zip(completions, agg_scores)]

    x["completions"] = completions
    x["scores"] = scores
    x["pred"] = pred
    x["completion_tokens"] = completion_tokens

    return x