import os
import warnings
from typing import Optional

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from standardrr import StandardRR


# Process-wide vLLM switches, applied at import time so they are already in
# the environment when the LLM engine below is constructed:
#   - worker processes are started with the "spawn" method;
#   - max_model_len may exceed the length vLLM derives from the model config.
os.environ.update(
    {
        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
        "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1",
    }
)

class Rank1_NoReason(StandardRR):
    """Relevance reranker that skips the model's free-form reasoning step.

    Each (query, passage) prompt ends with an already-closed ``<think>`` block,
    and generation is constrained to exactly one token, which must be either
    the `` true`` or the `` false`` token id (top-20 logprobs are requested).
    Converting those outputs into relevance scores is not done in this class;
    presumably ``StandardRR`` does it using ``true_token``/``false_token`` —
    TODO confirm.
    """

    def __init__(
        self,
        base_model_name_or_path: str,
        batch_size: int = 999999999999,
        context_size: int = 16000,
        max_output_tokens: int = 8192,
        fp_options: str = "float16",
        num_gpus: int = 1,
        device: str = "cuda",
        dataset_prompt: Optional[str] = None,
    ):
        """Load the tokenizer and vLLM engine and build the sampling params.

        Args:
            base_model_name_or_path: Model id or path, passed to both
                ``AutoTokenizer.from_pretrained`` and ``vllm.LLM``.
            batch_size: Accepted for interface compatibility; not used in this
                class (``_generate_model_outputs`` submits all prompts at once).
            context_size: ``max_model_len`` for the vLLM engine.
            max_output_tokens: Stored on the instance only; sampling here is
                fixed at a single output token.
            fp_options: vLLM ``dtype`` string (e.g. ``"float16"``).
            num_gpus: Tensor-parallel size; cast to ``int`` for vLLM.
            device: Stored on the instance only; not passed to vLLM here.
            dataset_prompt: Optional query template, stored on the instance
                (presumably passed back as ``prompt`` to ``return_prompt`` by
                the caller — verify).

        Raises:
            ValueError: If `` true`` and `` false`` begin with the same token
                id, which would leave the model a single allowed answer.
        """
        # NOTE(review): StandardRR.__init__ is not called here — confirm the
        # base class needs no initialization of its own.
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens
        self.num_gpus = num_gpus
        self.device = device
        self.dataset_prompt = dataset_prompt
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Overwrite the true/false tokens for the sampling_params. Only the
        # first sub-token of each word is used, so the two ids must differ;
        # check before paying for the (expensive) engine construction below.
        self.true_token = self._first_token_id(" true")
        self.false_token = self._first_token_id(" false")
        if self.true_token == self.false_token:
            raise ValueError(
                f"' true' and ' false' both start with token id {self.true_token}; "
                "this tokenizer cannot distinguish them with a single-token answer."
            )

        self.model = LLM(
            model=base_model_name_or_path,
            tensor_parallel_size=int(num_gpus),
            trust_remote_code=True,
            max_model_len=context_size,
            dtype=fp_options,
            gpu_memory_utilization=0.9,
            enforce_eager=True,
        )
        # Greedy decoding of exactly one token, restricted to the two answer
        # ids; the top-20 logprobs of that token are returned with the output.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
            skip_special_tokens=False,
        )

    def _first_token_id(self, word: str) -> int:
        """Return the first token id of ``word``, warning if it spans several tokens."""
        token_ids = self.tokenizer(word, add_special_tokens=False).input_ids
        if len(token_ids) != 1:
            warnings.warn(
                f"{word!r} encodes to {len(token_ids)} tokens {token_ids}; "
                "only the first is allowed during generation.",
                stacklevel=2,
            )
        return token_ids[0]

    def _generate_model_outputs(self, prompts):
        """Run vLLM generation over ``prompts`` in one call with the constrained sampling params."""
        return self.model.generate(prompts, self.sampling_params)

    def return_prompt(self, query, doc_content, prompt) -> str:
        """Build the relevance prompt for one (query, passage) pair.

        If ``prompt`` is truthy, every ``FILL_QUERY_HERE`` in it is replaced by
        ``query`` and the result is used as the query text. The prompt ends
        with a pre-closed ``<think>`` block so the model answers immediately.
        """
        if prompt:
            query = prompt.replace("FILL_QUERY_HERE", query)
        return (
            "Determine if the following passage is relevant to the query. "
            "Answer only with 'true' or 'false'.\n"
            f"Query: {query}\n"
            f"Passage: {doc_content}\n"
            "<think>\n"
            "Okay, I have finished thinking.\n"
            "</think>"
        )