import logging
import numpy as np
import torch
from vllm import LLM, SamplingParams

from sal.config import Config
from sal.models.reward_models import load_prm, PRM
from sal.utils.data import get_dataset, save_dataset
from sal.utils.parser import H4ArgumentParser
from sal.utils.score import score, aggregate_scores
from search.mapping.scaling_processor import ScalingProcessor
from utils.save_mapping import load_temperature_dict_npz

_DEFAULT_LOG_LEVEL = logging.INFO

# Configure the root handler once, then pin this module's logger to the same level
# so its INFO messages are emitted even if the root level is changed elsewhere.
logging.basicConfig(level=_DEFAULT_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
logger.setLevel(_DEFAULT_LOG_LEVEL)


def _resolve_temperature(temperature_map: dict, unique_id, default):
    """Return the temperature configured for ``unique_id``.

    When ``temperature_map`` is empty, ``default`` is returned unchanged. Otherwise
    the entry's ``"temperature"`` value is normalised: ``np.ndarray`` values are kept,
    lists become ``np.ndarray`` and scalars become ``float``. A ``unique_id`` missing
    from a non-empty map falls back to ``default`` (converted to ``float``) and logs
    a warning.

    Raises:
        ValueError: if the stored temperature is not an array, list or scalar.
    """
    if not temperature_map:
        return default

    key = str(unique_id)
    if key not in temperature_map:
        # NOTE(review): mixing a scalar fallback with per-problem arrays in one batch
        # presumably breaks a uniform column type downstream, so surface it.
        logger.warning(
            f"unique_id={unique_id} cannot find temperature, falling back to {default}"
        )
    temperature = temperature_map.get(key, {"temperature": default})["temperature"]

    if isinstance(temperature, np.ndarray):
        return temperature
    if isinstance(temperature, list):
        return np.array(temperature)
    if np.isscalar(temperature):
        return float(temperature)
    raise ValueError(f"Unknown temperature type: {type(temperature)}")


def _resolve_bias(bias_map: dict, unique_id):
    """Return the ``float32`` bias tensor for ``unique_id``, or ``None`` if absent.

    A warning is logged when ``bias_map`` is empty/``None``, has no entry for the id,
    or the entry lacks a ``"bias"`` key.
    """
    bias_data = bias_map.get(str(unique_id), None) if bias_map else None
    if bias_data is None or "bias" not in bias_data:
        logger.warning(f"unique_id={unique_id} cannot find bias, bias_data={bias_data}")
        return None
    # as_tensor accepts ndarrays, lists and tensors; from_numpy only accepts ndarrays.
    return torch.as_tensor(bias_data["bias"], dtype=torch.float32)


def generate_with_temperature_and_bias(x, config: Config, llm: LLM, prm: PRM, temperature_map: dict, bias_map: dict):
    """
    Use vLLM to generate answers for a batch of problems with a mapping function.

    Each problem gets its own ``ScalingProcessor`` built from the temperature and
    bias looked up by its ``unique_id``. All prompts of the batch are submitted to
    vLLM in one ``generate`` call, with one ``SamplingParams`` per prompt. The
    completions are then scored by the PRM, and the highest aggregated-score
    completion is chosen as the prediction.

    Args:
        x: Batch data containing 'problem' and 'unique_id' fields
        config: Configuration parameters (reads system_prompt, custom_chat_template,
            temperature, max_tokens, top_p, n2, agg_strategy)
        llm: vLLM model
        prm: Process Reward Model
        temperature_map: Mapping from str(unique_id) to {"temperature": ...};
            empty means use ``config.temperature`` for every problem
        bias_map: Mapping from str(unique_id) to {"bias": ...}; problems without
            an entry are generated with ``bias=None``

    Returns:
        ``x`` with added keys: completions, scores, agg_scores, completion_tokens,
        temperatures_used, biases_used, pred.

    Raises:
        ValueError: if a stored temperature has an unsupported type, or if vLLM
            returns a number of completions different from ``config.n2``.
    """
    tokenizer = llm.get_tokenizer()
    if config.custom_chat_template is not None:
        tokenizer.chat_template = config.custom_chat_template

    problems = x["problem"]
    unique_ids = x["unique_id"]

    convs = [
        [
            {"role": "system", "content": config.system_prompt},
            {"role": "user", "content": prompt},
        ]
        for prompt in problems
    ]
    templated_convs = tokenizer.apply_chat_template(
        convs, tokenize=False, add_generation_prompt=True
    )

    all_temperatures_used = []
    all_biases_used = []
    per_prompt_params = []
    for i, unique_id in enumerate(unique_ids):
        temperature = _resolve_temperature(temperature_map, unique_id, config.temperature)
        bias = _resolve_bias(bias_map, unique_id)
        all_temperatures_used.append(temperature)
        all_biases_used.append(bias)

        logger.debug(f"Problem {i+1}/{len(problems)}: unique_id={unique_id}, temperature={temperature}, bias={'set' if bias is not None else 'None'}")

        # Each prompt gets its own logits processor; vLLM's sampling temperature is
        # kept at 1.0 so the scaling is controlled by the processor.
        # NOTE(review): per-request logits_processors require an engine that supports
        # them (vLLM V0) — same requirement as before batching.
        per_prompt_params.append(
            SamplingParams(
                temperature=1.0,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                n=config.n2,  # generate n2 completions
                logits_processors=[
                    ScalingProcessor(temp=temperature, bias=bias, add_bias_first=True)
                ],
            )
        )

    # One engine call for the whole batch: a list of SamplingParams is paired
    # one-to-one with the prompts, and outputs come back in prompt order.
    responses = llm.generate(templated_convs, sampling_params=per_prompt_params, use_tqdm=False)

    all_completions = []
    all_completion_tokens = []
    for response in responses:
        completions = [output.text for output in response.outputs]
        if len(completions) != config.n2:
            raise ValueError(f"Generated {len(completions)} completions instead of {config.n2}")
        all_completions.append(completions)
        all_completion_tokens.append([len(output.token_ids) for output in response.outputs])

    # PRM judge the generated completions
    scores = prm.score(problems, all_completions)

    # aggregate per-step scores of each completion with the configured strategy
    agg_scores = [
        [aggregate_scores(step_scores, config.agg_strategy) for step_scores in problem_scores]
        for problem_scores in scores
    ]

    pred = [completions[np.argmax(s)] for completions, s in zip(all_completions, agg_scores)]

    x["completions"] = all_completions
    x["scores"] = scores
    x["agg_scores"] = agg_scores
    x["completion_tokens"] = all_completion_tokens
    x["temperatures_used"] = all_temperatures_used
    x["biases_used"] = all_biases_used
    x["pred"] = pred
    return x


def main():
    """Run calibrated generation end to end.

    Steps:
      1. Parse ``Config``.
      2. Load the optional temperature and bias mappings.
      3. Initialise vLLM and the PRM.
      4. Load (and optionally truncate) the dataset, adding an index-based
         ``unique_id`` column if one is missing.
      5. Generate with ``generate_with_temperature_and_bias``.
      6. Score and save the dataset.
    """
    parser = H4ArgumentParser(Config)
    config = parser.parse()

    # load temperature mapping (optional; fall back to config.temperature)
    temperature_file_path = getattr(config, "temperature_file_path", None)
    if temperature_file_path:
        logger.info(f"Loading temperature mapping from: {temperature_file_path}")
        temperature_map = load_temperature_dict_npz(temperature_file_path)
        logger.info(f"Loaded {len(temperature_map)} temperature entries")
    else:
        logger.info(f"No temperature mapping file provided, using fixed temperature for all (temperature = {config.temperature}).")
        temperature_map = {}  # Let downstream fallback to config.temperature

    # load bias mapping (optional, mirroring the temperature mapping; downstream
    # generation already handles problems without a bias by passing bias=None)
    bias_file_path = getattr(config, "bias_file_path", None)
    if bias_file_path:
        logger.info(f"Loading bias mapping from: {bias_file_path}")
        bias_map = load_temperature_dict_npz(bias_file_path)
        logger.info(f"Loaded {len(bias_map)} bias entries")
    else:
        logger.warning("No bias mapping file provided, generating without logit bias.")
        bias_map = {}

    # init LLM
    logger.info(f"Initializing LLM from: {config.model_path}")
    # device_count() is 0 when no GPU is visible; vLLM rejects tensor_parallel_size=0.
    tensor_parallel_size = max(1, torch.cuda.device_count())
    llm = LLM(
        model=config.model_path,
        gpu_memory_utilization=config.gpu_memory_utilization,
        enable_prefix_caching=True,
        seed=config.seed,
        tensor_parallel_size=tensor_parallel_size,
    )

    # PRM
    logger.info(f"Loading PRM from: {config.prm_path}")
    prm = load_prm(config)

    # dataset
    logger.info("Loading dataset...")
    dataset = get_dataset(config)

    # limit dataset size if specified
    num_samples = getattr(config, "num_samples", None)
    if num_samples is not None:
        logger.info(f"Limiting dataset to first {num_samples} samples")
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    if "unique_id" not in dataset.column_names:
        logger.warning("Dataset does not contain 'unique_id' column. Adding index-based unique_id.")
        dataset = dataset.add_column("unique_id", list(range(len(dataset))))

    logger.info("Starting calibrated generation...")
    dataset = dataset.map(
        generate_with_temperature_and_bias,
        batched=True,
        batch_size=config.calib_batch_size,
        fn_kwargs={"config": config, "llm": llm, "prm": prm, "temperature_map": temperature_map, "bias_map": bias_map},
        desc="Generating with calibrated temperatures",
        load_from_cache_file=False,
    )

    dataset = score(dataset, config)

    save_dataset(dataset, config)
    logger.info("Done 🔥!")


if __name__ == "__main__":
    main()
