import logging
import numpy as np
import torch
from vllm import LLM

from sal.config import Config
from sal.models.reward_models import load_prm, PRM
from sal.utils.data import get_dataset, save_dataset
from sal.utils.parser import H4ArgumentParser
from sal.utils.score import score, aggregate_scores
from sal.search.beam_search import beam_search as beam_search_with_scaling  # Use modified beam_search
from utils.save_mapping import load_temperature_dict_npz

# Module-level logger, pinned to INFO independently of the root logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Install a default root handler at INFO; basicConfig does nothing if the
# root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)


def _get_temp_and_bias_for_batch(unique_ids, config, temperature_map, bias_map):
    """Look up the per-example temperature and bias for a batch of prompts.

    Args:
        unique_ids: Iterable of example identifiers. Each one is converted
            with ``str()`` before being used as a lookup key.
        config: Object exposing ``temperature``, the fallback value used when
            no per-example temperature is available.
        temperature_map: Mapping ``str(uid) -> {"temperature": value}``. When
            it is empty or ``None``, ``config.temperature`` is used as-is for
            every example.
        bias_map: Mapping ``str(uid) -> {"bias": value}``. May be empty or
            ``None``; every example then gets a ``None`` bias (with a warning).

    Returns:
        A ``(temps, biases)`` pair of lists aligned with ``unique_ids``.
        Each temperature is a ``float`` or ``np.ndarray`` (or the raw
        ``config.temperature`` when no map is given). Each bias is a float32
        ``torch.Tensor``, or ``None`` when the example has no bias entry.

    Raises:
        KeyError: If a temperature entry exists but has no ``"temperature"`` key.
        ValueError: If a mapped temperature is not an ndarray, list or scalar.
    """
    temps, biases = [], []
    for uid in unique_ids:
        key = str(uid)

        # temperature
        if not temperature_map:
            temperature = config.temperature
        else:
            # Examples missing from the map fall back to the global temperature.
            temp_data = temperature_map.get(key, {"temperature": config.temperature})
            temperature = temp_data["temperature"]
            if isinstance(temperature, np.ndarray):
                pass
            elif isinstance(temperature, list):
                temperature = np.array(temperature)
            elif np.isscalar(temperature):
                temperature = float(temperature)
            else:
                raise ValueError(f"Unknown temperature type: {type(temperature)}")

        # bias
        # Guard a missing/empty map the same way the temperature map is guarded.
        bias_data = bias_map.get(key) if bias_map else None
        if bias_data is not None and "bias" in bias_data:
            raw_bias = bias_data["bias"]
            # Accept scalars, lists, tuples and ndarrays alike: torch.from_numpy
            # only takes ndarrays, so normalise through numpy first.
            if not isinstance(raw_bias, torch.Tensor):
                raw_bias = np.asarray(raw_bias)
            bias = torch.as_tensor(raw_bias, dtype=torch.float32)
        else:
            logger.warning(f"unique_id={uid} cannot find bias, bias_data={bias_data}")
            bias = None

        temps.append(temperature)
        biases.append(bias)
    return temps, biases


def generate_with_temperature_and_bias_beam(x, config: Config, llm: LLM, prm: PRM, temperature_map: dict, bias_map: dict):
    """
    Run beam search on one batch and attach the outputs to the batch.

    Reads the ``problem`` and ``unique_id`` columns of ``x``, forwards them
    with the temperature/bias maps to the beam-search routine, and writes the
    following columns back onto ``x`` before returning it: ``completions``,
    ``scores``, ``agg_scores``, ``completion_tokens``, ``temperatures_used``,
    ``biases_used`` and ``pred``. The output fields are kept the same as in
    generate_with_temperature_and_bias.py.
    """
    batch_inputs = {"problem": x["problem"], "unique_id": x["unique_id"]}

    # Per-example temperature/bias looked up from the maps, kept for the record.
    used_temps, used_biases = _get_temp_and_bias_for_batch(
        batch_inputs["unique_id"], config, temperature_map, bias_map
    )

    # The search itself receives the maps and applies them per prompt.
    search_out = beam_search_with_scaling(
        batch_inputs,
        config=config,
        llm=llm,
        prm=prm,
        temperature_map=temperature_map,
        bias_map=bias_map,
    )

    # One list of beam scores per example; aggregate every beam's scores so
    # the output carries an agg_scores column alongside the raw scores.
    per_example_scores = search_out["scores"]
    aggregated = []
    for beam_scores in per_example_scores:
        aggregated.append(
            [aggregate_scores(beam, config.agg_strategy) for beam in beam_scores]
        )

    x["completions"] = search_out["completions"]
    x["scores"] = per_example_scores
    x["agg_scores"] = aggregated
    x["completion_tokens"] = search_out["completion_tokens"]
    x["temperatures_used"] = used_temps
    x["biases_used"] = used_biases
    x["pred"] = search_out["pred"]
    return x


def main():
    """Script entry point for calibrated beam-search generation.

    In order: parse the ``Config``, load the optional temperature map and the
    bias map, build the vLLM engine and the PRM, load (and optionally
    truncate) the dataset, make sure it has a ``unique_id`` column, run
    ``generate_with_temperature_and_bias_beam`` over it in batches, then hand
    the result to ``score`` and ``save_dataset``.
    """
    config = H4ArgumentParser(Config).parse()

    # Temperature mapping is optional: with no file configured, an empty map
    # is used and the fixed config.temperature applies to every example.
    temperature_path = getattr(config, "temperature_file_path", None)
    if not temperature_path:
        logger.info(f"No temperature mapping file provided, using fixed temperature for all (temperature = {config.temperature}).")
        temperature_map = {}
    else:
        logger.info(f"Loading temperature mapping from: {config.temperature_file_path}")
        temperature_map = load_temperature_dict_npz(temperature_path)
        logger.info(f"Loaded {len(temperature_map)} temperature entries")

    # Bias mapping is loaded unconditionally through the same npz loader.
    logger.info(f"Loading bias mapping from: {config.bias_file_path}")
    bias_map = load_temperature_dict_npz(config.bias_file_path)
    logger.info(f"Loaded {len(bias_map)} bias entries")

    # LLM engine
    logger.info(f"Initializing LLM from: {config.model_path}")
    engine_kwargs = {
        "model": config.model_path,
        "gpu_memory_utilization": config.gpu_memory_utilization,
        "enable_prefix_caching": True,
        "seed": config.seed,
        # Tensor-parallel degree = number of CUDA devices torch reports.
        "tensor_parallel_size": torch.cuda.device_count(),
    }
    llm = LLM(**engine_kwargs)

    # Process reward model
    logger.info(f"Loading PRM from: {config.prm_path}")
    prm = load_prm(config)

    # Dataset
    logger.info("Loading dataset...")
    dataset = get_dataset(config)

    # Optionally keep only the first num_samples rows.
    sample_cap = getattr(config, "num_samples", None)
    if sample_cap is not None:
        logger.info(f"Limiting dataset to first {config.num_samples} samples")
        keep = min(sample_cap, len(dataset))
        dataset = dataset.select(range(keep))

    # The generation step keys the maps by unique_id, so synthesize one from
    # the row index when the dataset does not provide it.
    if "unique_id" not in dataset.column_names:
        logger.warning("Dataset does not contain 'unique_id' column. Adding index-based unique_id.")
        row_ids = list(range(len(dataset)))
        dataset = dataset.add_column("unique_id", row_ids)

    logger.info("Starting calibrated beam-search generation...")
    shared_kwargs = {
        "config": config,
        "llm": llm,
        "prm": prm,
        "temperature_map": temperature_map,
        "bias_map": bias_map,
    }
    dataset = dataset.map(
        generate_with_temperature_and_bias_beam,
        batched=True,
        batch_size=config.calib_batch_size,
        fn_kwargs=shared_kwargs,
        desc="Generating with calibrated delta and temperatures (beam search)",
        load_from_cache_file=False,
    )

    dataset = score(dataset, config)
    save_dataset(dataset, config)
    logger.info("Done 🔥!")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()