import torch

def generate_multiple_samples(model, tokenizer, prompt, n=20, max_new_tokens=128, temperature=0.7, top_p=0.9):
    """
    Draw `n` sampled completions for `prompt` and return them as strings.

    The prompt is tokenized once (with truncation enabled) and moved to
    `model.device`. `model.generate` is then called `n` times with sampling
    switched on, one call per completion.

    Args:
        model: Object exposing `.device` and `.generate(...)`.
        tokenizer: Callable tokenizer that also provides `.decode(...)` and
            `.eos_token_id` (used as the pad token id).
        prompt: Text to condition the generation on.
        n: Number of completions to sample.
        max_new_tokens: Forwarded to `model.generate`.
        temperature: Forwarded to `model.generate`.
        top_p: Forwarded to `model.generate`.

    Returns:
        A list with `n` strings. Each is the decoded continuation (prompt
        tokens removed, special tokens skipped), reduced to the text that
        follows the last occurrence of "6. " and stripped of whitespace.
    """
    target_device = model.device
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(target_device)
    prompt_len = encoded["input_ids"].shape[-1]

    # Sampling settings shared by every call to generate().
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
    }

    def _sample_once():
        # Inference only, so gradient tracking is disabled around generate().
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                **sampling_kwargs,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Only the first returned sequence is used; the prompt tokens are sliced off.
        continuation = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
        # Keep what follows the last "6. " marker (the whole text if it is absent).
        # NOTE(review): presumably the prompt ends a numbered list at item 6 - confirm.
        return continuation.rpartition("6. ")[2].strip()

    return [_sample_once() for _ in range(n)]

def get_generation_config(dataset_name, model_name):
    """
    Return generation config and stop phrases based on dataset and model type.

    Args:
        dataset_name: Dataset identifier, compared case-insensitively with
            "mbpp", "gsm8k" and "bbh". Any other value, including None,
            leaves the defaults untouched.
        model_name: Model identifier. A case-insensitive substring match on
            "gemma" or "llama" selects the temperature ("gemma" is checked
            first). None is treated as an empty name.

    Returns:
        A freshly built dict with the keys "max_new_tokens", "do_sample",
        "temperature" and "stop_phrases". "stop_phrases" is a list of
        strings without repeated entries.
    """
    # Normalise once; None falls back to the defaults like any unknown name.
    dataset_key = (dataset_name or "").lower()
    model_key = (model_name or "").lower()

    # Default values
    config = {
        "max_new_tokens": 64,
        "do_sample": False,
        "temperature": 0.0,
        "stop_phrases": ["Q:", "[END]", "Question:", "Answer:", "Problem:", "Explanation:", "####"]
    }

    # Customize for specific datasets
    if dataset_key == "mbpp":
        config["max_new_tokens"] = 150
        config["stop_phrases"] += ["```", "#", "assert", "unittest"]
    elif dataset_key == "gsm8k":
        config["max_new_tokens"] = 96
        # "Explanation:" and "####" are already part of the defaults, so only
        # the phrase that is new for this dataset is appended.
        config["stop_phrases"] += ["Let's solve"]
    elif dataset_key == "bbh":
        config["max_new_tokens"] = 64
        config["stop_phrases"] += ["(A)", "(B)", "(C)", "(D)", "(E)", "(F)"]

    # Customize for specific models if needed
    if "gemma" in model_key:
        # NOTE(review): "do_sample" stays False here; if the consumer is a
        # Hugging Face generate() call the temperature is presumably ignored
        # under greedy decoding - confirm whether sampling was intended.
        config["temperature"] = 0.2
    elif "llama" in model_key:
        config["temperature"] = 0.0

    return config
