from vllm import LLM, SamplingParams

class VLLM_Generator:
    """Own a vLLM engine and keep its weights in sync with a training model.

    The engine is built once in ``__init__``; ``move_model_to_vllm`` pushes the
    current weights of a (possibly LoRA-wrapped) model into it, and
    ``generate`` samples completions with the sampling parameters derived from
    the config.
    """

    def __init__(self, model, config):
        """Create the vLLM engine and the sampling parameters.

        Args:
            model: Object exposing ``name_or_path``; that value is handed to
                vLLM as the model to load.
            config: Settings object. Attributes read here are
                ``vllm_gpu_memory_utilization``, ``vllm_dtype``,
                ``vllm_device``, ``num_generations``,
                ``max_completion_length`` and ``temperature``. ``use_lora`` is
                read later by ``move_model_to_vllm``.
        """
        self.config = config
        self.llm = LLM(
            model=model.name_or_path,
            gpu_memory_utilization=config.vllm_gpu_memory_utilization,
            dtype=config.vllm_dtype,  # e.g. fp16
            trust_remote_code=True,
            device=config.vllm_device,
            # Once this is released by vLLM, we will be able to distribute the
            # model across multiple GPUs.
            # See https://github.com/vllm-project/vllm/pull/12071
            # tensor_parallel_size=torch.cuda.device_count(),
            # distributed_executor_backend="external_launcher",
        )
        self.sampling_params = SamplingParams(
            n=config.num_generations,
            max_tokens=config.max_completion_length,
            temperature=config.temperature,
            top_k=-1,  # sample from all tokens
            top_p=1.0,  # no nucleus
            repetition_penalty=1.0,  # HF default
            logprobs=None,
        )

    def _vllm_model(self):
        """Return the model object held by the vLLM engine's driver worker."""
        return self.llm.llm_engine.model_executor.driver_worker.model_runner.model

    @staticmethod
    def _vllm_param_name(name, adapter_prefix):
        """Map a wrapped-model parameter name to the name passed to vLLM.

        Returns ``None`` for parameters that must not be loaded: those whose
        name contains ``adapter_prefix`` and those under ``original_module``.
        """
        name = name.removeprefix("base_model.model.").replace(".base_layer", "")
        if adapter_prefix in name:
            return None
        # When a module is in modules_to_save, keep the saved copy (with its
        # prefix stripped) and discard the original module.
        if "original_module" in name:
            return None
        return name.replace("modules_to_save.default.", "")

    def move_model_to_vllm(self, model) -> None:
        """Load the current weights of ``model`` into the vLLM engine.

        When ``config.use_lora`` is true the adapters are merged, the resulting
        parameters are loaded one by one under their mapped names, and the
        adapters are unmerged again. Otherwise the whole ``state_dict`` is
        loaded in one call. The engine's prefix cache is reset afterwards.

        Args:
            model: The model to copy weights from. With ``use_lora`` it must
                provide ``merge_adapter``, ``unmerge_adapter``,
                ``named_parameters``, ``prefix`` and
                ``print_trainable_parameters``; otherwise only ``state_dict``.

        Raises:
            Any exception raised while loading weights is propagated; in the
            LoRA path the adapters are unmerged before it propagates.
        """
        # Resolve the engine-side model once instead of once per parameter.
        llm_model = self._vllm_model()

        if self.config.use_lora:
            model.merge_adapter()
            try:
                for name, param in model.named_parameters():
                    target_name = self._vllm_param_name(name, model.prefix)
                    if target_name is None:
                        continue
                    llm_model.load_weights([(target_name, param.data)])
            finally:
                # Never leave the training model with adapters merged, even
                # if a weight fails to load.
                model.unmerge_adapter()
            model.print_trainable_parameters()
        else:
            llm_model.load_weights(model.state_dict().items())

        self.llm.reset_prefix_cache()

    def generate(self, prompt_token_ids):
        """Sample completions for already-tokenized prompts.

        Args:
            prompt_token_ids: Token ids forwarded unchanged to
                ``LLM.generate`` as ``prompt_token_ids``.

        Returns:
            Whatever ``LLM.generate`` returns for these prompts, sampled with
            ``self.sampling_params`` and with the progress bar disabled.
        """
        return self.llm.generate(
            prompt_token_ids=prompt_token_ids,
            sampling_params=self.sampling_params,
            use_tqdm=False,
        )