import torch
from typing import List, Dict
from vllm import SamplingParams
# Report how many CUDA devices torch can see, then pick the default device string.
print(f"{torch.cuda.device_count()} GPUs available")
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
import random
import numpy as np
import torch

# Set a global seed for the RNGs of the other libraries used alongside vLLM
# (Python's `random`, NumPy and torch) so runs are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
# Seed torch's default (CPU) generator as well; torch.cuda.manual_seed_all
# below only touches the CUDA generators.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
def base_model_chat_template(messages, prompt=None, retro=False, tokenizer=None):
    """Render ``messages`` into a plain-text prompt for a base (non-chat) model.

    The function has two modes, selected by ``prompt``:

    * Template mode (``prompt`` is truthy): ``prompt`` is filled in with
      ``str.format`` and returned as a string. ``retro`` and ``tokenizer`` are
      not consulted in this mode.

      - ``messages[0]`` is a dict (inference): ``question`` is its
        ``'content'`` value.
      - ``messages`` holds exactly two strings (feedback): ``question`` and
        ``initial_response`` are the stripped strings.
      - otherwise (refine): ``question``, ``initial_response`` and
        ``feedback`` are the first three stripped strings.

    * Chat mode (``prompt`` is falsy): each message dict is rendered as
      ``"User: <content>"`` or ``"Assistant: <content>"`` (other roles are
      skipped), one turn per line. Unless ``retro`` is true, a trailing
      ``"Assistant:"`` cue is appended for the model to complete. With
      ``retro`` true the text ends exactly where the last turn ends, so a
      final assistant turn can be continued.

    Args:
        messages: Sequence of message dicts (``'role'`` / ``'content'``) or,
            in template mode, of plain strings.
        prompt: Optional ``str.format`` template with the fields listed above.
        retro: When true, do not append the trailing ``"Assistant:"`` cue.
        tokenizer: Optional object with an ``encode`` method; when truthy (chat
            mode only) the rendered text is passed through
            ``tokenizer.encode(text, return_tensors="pt")`` and that result is
            returned instead of the string.

    Returns:
        The rendered prompt string, or the tokenizer's encoding of it.

    Raises:
        IndexError: In template mode when ``messages`` is empty or holds fewer
            strings than the selected case needs.
    """
    if prompt:
        first = messages[0]
        if isinstance(first, dict):
            # Inference case: a single chat-style message carries the question.
            return prompt.format(question=first['content'])
        fields = {
            "question": messages[0].strip(),
            "initial_response": messages[1].strip(),
        }
        if len(messages) != 2:
            # Refine case: a third string carries the feedback.
            fields["feedback"] = messages[2].strip()
        return prompt.format(**fields)

    # Chat case: collect one text segment per rendered turn.
    segments = []
    for turn in messages:
        role = turn["role"]
        if role == "user":
            segments.append(f"User: {turn['content']}\n")
        elif role == "assistant":
            segments.append(f"Assistant: {turn['content']}")
    if not retro:
        segments.append("Assistant:")

    # Assistant segments carry no trailing newline so that a final one can be
    # continued in retro mode. When something follows, add the missing line
    # break so the next turn (or the generation cue) is not glued onto the
    # assistant text.
    last = len(segments) - 1
    current = "".join(
        seg if idx == last or seg.endswith("\n") else seg + "\n"
        for idx, seg in enumerate(segments)
    )

    if tokenizer:
        current = tokenizer.encode(current, return_tensors="pt")
    return current

def model_inference_batch_vllm(
    model,
    tokenizer,
    messages_batch: List[List[Dict[str, str]]],
    temperature: float = 0,
    max_new_tokens: int = 3000,
    retro=False,
    prompt=None
) -> List[str]:
    """Render every conversation to a prompt and generate one completion each.

    The model is treated as a base model when the string ``"Instruct"`` does
    not occur in ``model.llm_engine.model_config.model``; its prompts are built
    with ``base_model_chat_template``. Otherwise the tokenizer's chat template
    is applied: with ``continue_final_message=True`` when ``retro`` is set,
    else with ``add_generation_prompt=True``.

    Args:
        model: Object exposing ``llm_engine.model_config.model`` and
            ``generate(prompts, sampling_params)`` (presumably a ``vllm.LLM``).
        tokenizer: Object exposing ``apply_chat_template``; only used for
            "Instruct" models.
        messages_batch: One message list per request.
        temperature: Sampling temperature forwarded to ``SamplingParams``.
        max_new_tokens: Forwarded to ``SamplingParams`` as ``max_tokens``.
        retro: Continue the final message instead of opening a new turn.
        prompt: Optional format template; only forwarded to
            ``base_model_chat_template`` (ignored for "Instruct" models).

    Returns:
        The text of the first completion of every request, in the order
        ``model.generate`` returns them. An empty batch yields ``[]``.
    """
    if not messages_batch:
        # Nothing to render or generate; also keeps the preview below safe.
        return []

    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens, top_k=50, stop=["Problem:"], seed=72)

    # The model name does not change between requests, so inspect it once.
    is_base_model = "Instruct" not in model.llm_engine.model_config.model

    processed_messages = []
    for messages in messages_batch:
        if is_base_model:
            text = base_model_chat_template(messages, prompt=prompt, retro=retro, tokenizer=None)
        elif retro:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                continue_final_message=True
            )
        else:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        processed_messages.append(text)

    # Debug preview of the first two rendered prompts. The second one is only
    # shown when it exists, so a single-request batch no longer raises
    # IndexError.
    print(processed_messages[0])
    if len(processed_messages) > 1:
        print("--------------------------------")
        print(processed_messages[1])

    outputs = model.generate(processed_messages, sampling_params)
    return [output.outputs[0].text for output in outputs]
        
def model_inference_batch_vllm_aio(
    model,
    tokenizer,
    messages_batch: List[List[Dict[str, str]]],
    temperature: float = 0,
    max_new_tokens: int = 3000,
    retro=False,
    batch_size: int = 999999,
    prompt=None
) -> List[str]:
    """Render every conversation to a prompt and generate completions in chunks.

    The model is treated as a base model when the string ``"Instruct"`` does
    not occur in ``model.llm_engine.model_config.model``; its prompts are built
    with ``base_model_chat_template``. Otherwise the tokenizer's chat template
    is applied: with ``continue_final_message=True`` when ``retro`` is set,
    else with ``add_generation_prompt=True``.

    Args:
        model: Object exposing ``llm_engine.model_config.model`` and
            ``generate(prompts, sampling_params)`` (presumably a ``vllm.LLM``).
        tokenizer: Object exposing ``apply_chat_template``; only used for
            "Instruct" models.
        messages_batch: One message list per request.
        temperature: Sampling temperature forwarded to ``SamplingParams``.
        max_new_tokens: Forwarded to ``SamplingParams`` as ``max_tokens``.
        retro: Continue the final message instead of opening a new turn.
        batch_size: Maximum number of prompts per ``model.generate`` call.
            ``None`` or a non-positive value means a single call for the whole
            batch.
        prompt: Optional format template; only forwarded to
            ``base_model_chat_template`` (ignored for "Instruct" models).

    Returns:
        The text of the first completion of every request, chunk after chunk,
        in the order ``model.generate`` returns them. An empty batch yields
        ``[]``.
    """
    if not messages_batch:
        return []

    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens, top_k=50, stop=["Problem:"], seed=72)

    # The model name does not change between requests, so inspect it once.
    is_base_model = "Instruct" not in model.llm_engine.model_config.model

    # Prepare batch inputs
    processed_messages = []
    for messages in messages_batch:
        if is_base_model:
            text = base_model_chat_template(messages, prompt=prompt, retro=retro, tokenizer=None)
        elif retro:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                continue_final_message=True
            )
        else:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        processed_messages.append(text)

    # batch_size used to be accepted but ignored, so any value was legal.
    # Keep such calls working by mapping unusable sizes to "one big batch".
    if batch_size is None or batch_size < 1:
        batch_size = len(processed_messages)

    result = []
    for start in range(0, len(processed_messages), batch_size):
        # Slicing clamps at the end of the list, so the last chunk may be short.
        chunk = processed_messages[start:start + batch_size]
        outputs = model.generate(chunk, sampling_params)
        result.extend(output.outputs[0].text for output in outputs)

    return result

def base_model_inference_aio(
    model,
    messages_batch: List[str],
    temperature: float = 0,
    max_new_tokens: int = 3000,
    batch_size: int = 999999999,
    stop_tokens: List[str] = ["Problem:"]
) -> List[str]:
    """Generate one completion per ready-made prompt string, in chunks.

    Args:
        model: Object exposing ``generate(prompts, sampling_params)``
            (presumably a ``vllm.LLM``).
        messages_batch: Prompt strings, passed to the model unchanged.
        temperature: Sampling temperature forwarded to ``SamplingParams``.
        max_new_tokens: Forwarded to ``SamplingParams`` as ``max_tokens``.
        batch_size: Maximum number of prompts per ``model.generate`` call;
            must be at least 1.
        stop_tokens: Forwarded to ``SamplingParams`` as ``stop``.

    Returns:
        The text of the first completion of every prompt, chunk after chunk,
        in the order ``model.generate`` returns them. An empty batch yields
        ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A zero step makes range() fail with an unrelated-looking message and a
    # negative step silently yields no chunks at all, so reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if not messages_batch:
        return []

    # Pass a private copy of list-like stop sequences so the shared default
    # list can never be modified through the sampling parameters.
    if stop_tokens is None or isinstance(stop_tokens, str):
        stop = stop_tokens
    else:
        stop = list(stop_tokens)
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens, top_k=50, stop=stop, seed=72)

    result = []
    # Process messages in batches
    for batch_start in range(0, len(messages_batch), batch_size):
        # Slicing clamps at the end of the list, so the last chunk may be short.
        batch_subset = messages_batch[batch_start:batch_start + batch_size]
        outputs = model.generate(batch_subset, sampling_params)
        result.extend(output.outputs[0].text for output in outputs)

    return result
