import numpy as np
import torch
from vllm import LLM, SamplingParams

from sal.config import Config
from sal.models.reward_models import PRM
from sal.utils.score import aggregate_scores
import torch.nn.functional as F


def _encode_calibration_batch(tokenizer, texts):
    """Tokenize *texts* and right-pad them into an ``(input_ids, labels)`` pair.

    ``labels`` holds the same token ids as ``input_ids`` but uses -100 on
    padded positions so ``F.cross_entropy(..., ignore_index=-100)`` skips them.
    """
    # Some tokenizers ship without a pad token. Any valid id works as filler
    # because padded positions are excluded from the loss via the -100 labels.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(tokenizer, "eos_token_id", None)
    if pad_id is None:
        pad_id = 0

    # `return_tensors="pt"` yields a [1, L] tensor. pad_sequence pads along the
    # first axis and requires all trailing dims to match, so hand it the 1-D
    # row; otherwise completions of different lengths cannot be batched.
    sequences = [
        tokenizer(text, return_tensors="pt")["input_ids"][0] for text in texts
    ]
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    labels = pad_sequence(sequences, batch_first=True, padding_value=-100)
    return input_ids, labels


def _search_temperature(logits, labels, candidates, default=1.0):
    """Return the candidate temperature minimising the mean next-token NLL.

    Args:
        logits: tensor indexed as [batch, position, vocab].
        labels: token ids indexed as [batch, position]; -100 marks padding.
        candidates: iterable of temperatures to try, in order.
        default: value returned when no candidate yields a finite, lower loss
            (e.g. when there are no next-token targets to score).

    Ties are resolved in favour of the earliest candidate.
    """
    if logits.size(1) < 2:
        # Fewer than two positions leaves no (context, next-token) pair.
        return default

    # In a causal LM the logits at position t score token t+1, so drop the
    # last logit and the first label to line predictions up with targets.
    # `reshape` (not `view`) because the slices are non-contiguous.
    shifted_logits = logits[:, :-1, :]
    flat_logits = shifted_logits.reshape(-1, shifted_logits.size(-1))
    flat_labels = labels[:, 1:].reshape(-1).to(flat_logits.device)

    best_temp = default
    best_loss = float("inf")
    for temp in candidates:
        loss = F.cross_entropy(
            flat_logits / temp,
            flat_labels,
            ignore_index=-100,
            reduction="mean",
        ).item()
        # A NaN loss (every target ignored) never compares lower, so the
        # default survives in that case.
        if loss < best_loss:
            best_loss = loss
            best_temp = float(temp)  # plain float, not a NumPy scalar
    return best_temp


def adaptive_temperature(x, config: Config, llm: LLM, prm: PRM):
    """Sample with a per-problem temperature calibrated on PRM-selected drafts.

    For every entry of ``x["problem"]`` the procedure is:

    1. Sample ``config.n1`` completions at ``config.temperature``.
    2. Score them with ``prm`` and keep the ``config.k`` highest-scoring ones
       (after ``aggregate_scores`` with ``config.agg_strategy``).
    3. Run the kept completions through the model and grid-search 20
       temperatures in [0.3, 2.0] for the one minimising their next-token
       negative log-likelihood.
    4. Sample ``config.n2`` completions at that temperature, score them, and
       pick the best-scoring one as the prediction.

    Args:
        x: batch mapping with a ``"problem"`` sequence of prompt strings.
        config: supplies ``system_prompt``, ``custom_chat_template``,
            ``temperature``, ``top_p``, ``max_tokens``, ``n1``, ``n2``, ``k``
            and ``agg_strategy``.
        llm: generator used for sampling and for the calibration forward pass.
        prm: reward model; ``prm.score([prompt], [completions])[0]`` is read
            as one score entry per completion.

    Returns:
        ``x`` itself, with these keys added (one entry per problem):
        ``"calibration_temperatures"``, ``"final_completions"``,
        ``"completion_tokens"``, ``"scores"`` and ``"pred"``.

    Raises:
        ValueError: if ``config.n1``, ``config.n2`` or ``config.k`` is < 1.
    """
    tokenizer = llm.get_tokenizer()

    if config.custom_chat_template is not None:
        tokenizer.chat_template = config.custom_chat_template

    n1 = config.n1  # draft completions per problem
    n2 = config.n2  # final completions per problem
    k = config.k  # size of the calibration set
    if min(n1, n2, k) < 1:
        # k == 0 would silently select *every* draft (`[-0:]` is the whole
        # array); n1/n2 == 0 would only fail later with an opaque error.
        raise ValueError(
            f"config.n1, config.n2 and config.k must all be >= 1 "
            f"(got n1={n1}, n2={n2}, k={k})"
        )

    problems = x["problem"]
    num_problems = len(problems)

    temperatures = []
    final_completions = []
    final_completion_tokens = []
    final_scores = []
    predictions = []

    if num_problems > 0:
        convs = [
            [
                {"role": "system", "content": config.system_prompt},
                {"role": "user", "content": prompt},
            ]
            for prompt in problems
        ]
        templated_convs = tokenizer.apply_chat_template(
            convs, tokenize=False, add_generation_prompt=True
        )

        # Step 1: draft n1 completions per prompt in one batched call; each
        # prompt is repeated n1 times consecutively.
        draft_prompts = [conv for conv in templated_convs for _ in range(n1)]
        draft_params = SamplingParams(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            top_p=config.top_p,
            n=1,
        )
        responses = llm.generate(
            draft_prompts, sampling_params=draft_params, use_tqdm=False
        )
        assert len(responses) == num_problems * n1

        candidate_temps = np.linspace(0.3, 2.0, 20)

        for i, prompt in enumerate(problems):
            drafts = [
                output.text
                for r in responses[i * n1: (i + 1) * n1]
                for output in r.outputs
            ]
            draft_scores = prm.score([prompt], [drafts])[0]
            draft_agg = [
                aggregate_scores(s, config.agg_strategy) for s in draft_scores
            ]

            # Step 2: the k best drafts form the calibration set (argsort is
            # ascending, so the tail holds the highest scores).
            topk_idx = np.argsort(draft_agg)[-k:]
            calib_texts = [drafts[j] for j in topk_idx]

            # Step 3: forward pass over the calibration texts.
            # NOTE(review): the loss is computed on the completion text alone,
            # without the prompt as context — confirm this is intended.
            input_ids, labels = _encode_calibration_batch(tokenizer, calib_texts)
            if input_ids.size(1) < 2:
                # Nothing to score; skip the forward pass and keep 1.0.
                best_temp = 1.0
            else:
                with torch.no_grad():
                    # NOTE(review): this relies on `llm` exposing an HF-style
                    # `.model` (returning an object with `.logits`) and a
                    # `.device`; confirm the LLM object passed by callers
                    # actually provides both.
                    logits = llm.model(input_ids.to(llm.device)).logits

                # Step 4: grid-search the temperature minimising the NLL.
                best_temp = _search_temperature(logits, labels, candidate_temps)

            temperatures.append(best_temp)

            # Step 5: sample n2 completions at the calibrated temperature.
            final_params = SamplingParams(
                temperature=best_temp,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                n=1,
            )
            final_resp = llm.generate(
                [templated_convs[i]] * n2,
                sampling_params=final_params,
                use_tqdm=False,
            )
            final_outputs = [output for r in final_resp for output in r.outputs]
            final_comps = [output.text for output in final_outputs]
            final_tok_lens = [len(output.token_ids) for output in final_outputs]

            final_completions.append(final_comps)
            final_completion_tokens.append(final_tok_lens)

            score = prm.score([prompt], [final_comps])[0]
            agg = [aggregate_scores(s, config.agg_strategy) for s in score]
            final_scores.append(score)
            predictions.append(final_comps[np.argmax(agg)])

    x["calibration_temperatures"] = temperatures
    x["final_completions"] = final_completions
    x["completion_tokens"] = final_completion_tokens
    x["scores"] = final_scores
    x["pred"] = predictions

    return x
