from typing import Any
import torch
from vllm import LLM
from vllm.engine.arg_utils import PoolerConfig
from vllm.model_executor.layers.pooler import PoolingType
from transformers import AutoTokenizer, AutoModelForCausalLM


class Id:
    """Pass-through stand-in for a model sub-module.

    Calling an instance returns its argument unchanged. It prints a marker
    first, so the console shows whether the replacement was ever invoked.
    """

    def __call__(self, hidden_states) -> Any:
        """Emit the "CALLED" marker and hand ``hidden_states`` back as-is."""
        passthrough = hidden_states
        print("CALLED")
        return passthrough


def main():
    """Compare vLLM "ALL"-pooled outputs with Hugging Face hidden states.

    Loads ``model_id`` twice. The first copy goes through vLLM as an embedding
    model with ``pooling_type="ALL"`` and ``normalize=False``. The second goes
    through ``transformers``. The same prompt token ids are fed to both, and
    the script prints the first prompt's outputs from each engine. When the
    two tensors have the same shape, it also prints their maximum absolute
    difference.
    """
    device = torch.device("cuda:0")
    model_id = "unsloth/Llama-3.2-3B-Instruct"
    llm = LLM(
        model_id,
        task="embed",
        tensor_parallel_size=1,
        override_pooler_config=PoolerConfig(
            pooling_type="ALL",
            normalize=False,
        ),
        # Presumably set so the monkey-patched module below runs on the eager
        # Python path rather than inside a captured graph -- confirm if changed.
        enforce_eager=True,
    )
    # Reach into the (private, version-dependent) engine internals once
    # instead of repeating the long attribute chain.
    vllm_model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
    print(vllm_model)
    # Intended to replace the final norm with a pass-through.
    # NOTE(review): depending on the model class, the final norm may live at
    # ``vllm_model.model.norm`` rather than ``vllm_model.norm``. If "CALLED" is
    # never printed, this assignment only adds an unused attribute. Check it
    # against the module tree printed above.
    vllm_model.norm = Id()
    # NOTE(review): no ``torch_dtype`` is passed here, so the reference model
    # may run at a different precision than vLLM. Expect small numeric gaps.
    llm_ref = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    prompts_text = [
        "This is a pretty long prompt with plenty of words",
        # "This is another set of words",
    ]

    # Tokenize once and feed the identical token ids to both engines.
    tokens_to_pad = [
        tokenizer(prompt, return_tensors="pt")["input_ids"].flatten()
        for prompt in prompts_text
    ]
    vllm_inputs = [
        dict(prompt_token_ids=input_ids.tolist())
        for input_ids in tokens_to_pad
    ]
    inputs = tokenizer.pad(dict(input_ids=tokens_to_pad), return_tensors="pt")
    inputs = inputs.to(device)

    outputs = llm.encode(vllm_inputs)
    outputs = [output.outputs.data for output in outputs]

    # Pure inference: skip autograd bookkeeping for the reference forward pass.
    with torch.inference_mode():
        outputs_ref = llm_ref.model(**inputs).last_hidden_state.cpu()

    print(outputs[0])
    print(outputs_ref[0])

    # With several prompts, the padded HF row can be longer than the
    # unpadded vLLM output, so only diff when the shapes line up.
    vllm_first = outputs[0].float().cpu()
    ref_first = outputs_ref[0].float()
    if vllm_first.shape == ref_first.shape:
        max_diff = (vllm_first - ref_first).abs().max().item()
        print(f"max abs diff: {max_diff}")
    else:
        print(
            f"shape mismatch: vLLM {tuple(vllm_first.shape)} "
            f"vs HF {tuple(ref_first.shape)}"
        )


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    main()
