from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from qwen_vl_utils import process_vision_info
import torch
import os
from typing import Tuple, Dict

# Module-level side effect: sets CUDA_VISIBLE_DEVICES to "0" for this process
# as soon as the module is imported (presumably to pin all work to the first
# GPU -- adjust or remove when a different device selection is wanted).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def load_vlm(model_id: str) -> Tuple[Qwen2_5_VLForConditionalGeneration, AutoProcessor]:
    """Load a Qwen2.5-VL model and its processor with deterministic decoding defaults.

    Args:
        model_id: Checkpoint identifier forwarded to ``from_pretrained`` for
            both the model and the processor.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in bfloat16 with
        ``device_map="auto"``.
    """
    vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # attn_implementation="flash_attention_2",
        device_map="auto",
    )

    # Overwrite whatever sampling settings ship with the checkpoint: sampling
    # is switched off and the remaining knobs are reset to the values below so
    # the stored generation config does not conflict with greedy decoding.
    decoding_overrides = {
        "do_sample": False,
        "top_p": 1.0,
        "top_k": 0,
        "temperature": 1.0,
        "repetition_penalty": 1.0,
        "no_repeat_ngram_size": 0,
    }
    for option_name, option_value in decoding_overrides.items():
        setattr(vlm.generation_config, option_name, option_value)

    vlm_processor = AutoProcessor.from_pretrained(model_id)
    return vlm, vlm_processor


def load_llm(
    model_id: str = "Qwen/Qwen3-1.7B",
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a text-only causal language model together with its tokenizer.

    Args:
        model_id: Checkpoint identifier forwarded to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in bfloat16 with
        ``device_map="auto"``.
    """
    llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
    model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    llm = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    return llm, llm_tokenizer


def text_image_to_vlm_messages(image_path, text) -> list:
    """Build a single-turn multimodal chat containing one image and one text prompt.

    Args:
        image_path: Value stored under the ``"image"`` key of the image part.
        text: Value stored under the ``"text"`` key of the text part.

    Returns:
        A one-element list holding a ``"user"`` message whose content is the
        image part followed by the text part.
    """
    image_part = {"type": "image", "image": image_path}
    text_part = {"type": "text", "text": text}
    user_turn = {"role": "user", "content": [image_part, text_part]}
    return [user_turn]


def message_to_vlm_inputs(
    processor: AutoProcessor, messages: list, device: str = "cuda"
) -> Dict[str, torch.Tensor]:
    """Convert chat messages into model-ready inputs for the vision-language model.

    The chat template is rendered to a prompt string (with the generation
    prompt appended), the image/video entries are extracted from the messages
    with ``process_vision_info``, and both are passed through the processor.

    Args:
        processor: Processor providing ``apply_chat_template`` and the
            combined text/vision preprocessing call.
        messages: Chat messages in the multimodal format produced by
            ``text_image_to_vlm_messages``.
        device: Device the resulting tensors are moved to. Defaults to
            ``"cuda"``, matching the previous hard-coded behavior; pass e.g.
            ``model.device`` or ``"cpu"`` to target a different device.

    Returns:
        The processor output (PyTorch tensors) moved to ``device``.
    """
    prompt_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[prompt_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # The caller chooses where the tensors live instead of assuming CUDA.
    return inputs.to(device)


def get_vlm_response(
    model: Qwen2_5_VLForConditionalGeneration,
    processor: AutoProcessor,
    inputs: Dict[str, torch.Tensor],
) -> Tuple[str, torch.Tensor]:
    """Generate the vision-language model's reply without sampling.

    Args:
        model: Vision-language model used for generation.
        processor: Processor used to decode the generated token ids.
        inputs: Model inputs; must contain ``"input_ids"``, whose last
            dimension is used as the prompt length.

    Returns:
        ``(output_text, generated_seqs)`` where ``output_text`` is the decoded
        continuation of the first batch entry (prompt tokens removed) and
        ``generated_seqs`` is the full sequence tensor returned by
        ``generate`` (prompt plus continuation).
    """
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            return_dict_in_generate=True,
            output_scores=False,
        )

    full_sequences = generation.sequences
    # Everything past the prompt length is newly generated.
    prompt_length = inputs["input_ids"].shape[-1]
    decoded_batch = processor.batch_decode(
        full_sequences[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded_batch[0], full_sequences


def get_llm_response(
    chat_messages: list, text_tokenizer: AutoTokenizer, text_model: AutoModelForCausalLM
) -> str:
    """Generate the text-only model's reply to a chat without sampling.

    Args:
        chat_messages: Chat messages (``role``/``content`` dicts) rendered
            through the tokenizer's chat template.
        text_tokenizer: Tokenizer providing ``apply_chat_template`` and
            ``batch_decode``.
        text_model: Causal language model used for generation.

    Returns:
        The decoded continuation for the first batch entry, with the prompt
        tokens removed and special tokens skipped.
    """
    # Request a dict so the attention mask is handed to ``generate`` together
    # with the input ids, rather than relying on a bare tensor return value.
    text_inputs = text_tokenizer.apply_chat_template(
        chat_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=False,
    )
    text_inputs = text_inputs.to(text_model.device)
    prompt_length = text_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        text_gen = text_model.generate(
            **text_inputs, max_new_tokens=1024, do_sample=False, temperature=None
        )

    # Keep only the tokens produced after the prompt.
    gen_trimmed = text_gen[:, prompt_length:]
    text_answer = text_tokenizer.batch_decode(
        gen_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return text_answer


def llm_messages_append_text(
    initial_messages: list, vl_output_text: str, question: str
) -> list:
    """Build a text-only chat from the multimodal exchange plus a follow-up question.

    Args:
        initial_messages: Messages originally sent to the vision-language
            model. Only the first message is inspected. Its ``"content"`` may
            be a list of parts (the text of the first ``{"type": "text"}``
            part is used) or a plain string (used as-is).
        vl_output_text: The vision-language model's reply, inserted as the
            assistant turn.
        question: Follow-up user question appended as the final turn.

    Returns:
        A three-message chat: original user text, assistant reply, follow-up
        question. The first message's content is ``None`` when no text part
        is present in ``initial_messages[0]``.
    """
    first_content = initial_messages[0]["content"]
    if isinstance(first_content, str):
        # Plain-string content carries the user text directly; iterating it
        # part-by-part would walk over characters and find nothing.
        original_user_text = first_content
    else:
        original_user_text = next(
            (
                part.get("text")
                for part in first_content
                if isinstance(part, dict) and part.get("type") == "text"
            ),
            None,
        )

    return [
        {"role": "user", "content": original_user_text},
        {"role": "assistant", "content": vl_output_text},
        {"role": "user", "content": question},
    ]


def stage2_standard_generate(
    model: Qwen2_5_VLForConditionalGeneration, generated_seqs: torch.Tensor
) -> None:
    """
    Stage 2: re-run the full sequence (prompt + first response) through the
    standard ``generate`` API. Only a single new token is requested, so the
    call mainly exercises the prefill; the generated token is discarded and
    the length of the prefilled sequence is printed.

    Args:
        model: Model whose ``generate`` method is invoked.
        generated_seqs: Token ids of prompt plus first response, used as-is
            for ``input_ids``.
    """
    prefill_ids = generated_seqs
    # Every position is attended to: the mask is all ones, same shape as the ids.
    prefill_mask = torch.ones_like(prefill_ids, dtype=torch.long)

    with torch.inference_mode():
        model.generate(
            input_ids=prefill_ids,
            attention_mask=prefill_mask,
            max_new_tokens=1,
            do_sample=False,
            return_dict_in_generate=True,
        )

    prefill_length = int(prefill_ids.shape[-1])
    print("[Stage 2: standard generate] full_input_ids length:", prefill_length)


def run_demo(image_path: str, question: str, vlm_id: str, llm_id: str) -> str:
    """Describe an image with the VLM, then answer a question with the text-only LLM.

    The vision-language model is asked to briefly describe the image; its
    description is printed and then placed in a text-only chat as the
    assistant turn, followed by ``question`` as a new user turn.

    Args:
        image_path: Image reference handed to the vision-language model.
        question: Follow-up question put to the text-only model.
        vlm_id: Checkpoint identifier for the vision-language model.
        llm_id: Checkpoint identifier for the text-only model.

    Returns:
        The text-only model's answer.
    """
    vlm, processor = load_vlm(vlm_id)
    llm, tokenizer = load_llm(llm_id)

    # Step 1: image description from the vision-language model.
    describe_messages = text_image_to_vlm_messages(
        image_path, "Briefly describe the image."
    )
    describe_inputs = message_to_vlm_inputs(processor, describe_messages)
    description, description_seqs = get_vlm_response(vlm, processor, describe_inputs)
    # stage2_standard_generate(vlm, description_seqs)
    print(description)

    # Step 2: text-only follow-up conditioned on that description.
    followup_messages = llm_messages_append_text(
        describe_messages, description, question
    )
    return get_llm_response(followup_messages, tokenizer, llm)


if __name__ == "__main__":
    # Checkpoints and the example image used by the demo.
    vlm_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    llm_id = "Qwen/Qwen3-1.7B"
    image_path = "local/example/mmmu_example.png"

    # Multiple-choice question: the stem, the four options, then the
    # answer-format instructions, each on its own line.
    question = "Which of the following best explains the overall trend shown in the <image>?"
    options = (
        "Migrations to areas of Central Asia for resettlement",
        "The spread of pathogens across the Silk Road",
        "Invasions by Mongol tribes",
        "Large-scale famine due to crop failures",
    )
    answer_instructions = "Answer the preceding multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of options. Think step by step before answering."

    option_lines = [f"{letter}. {option}" for letter, option in zip("ABCD", options)]
    prompt = "\n".join([question, *option_lines, answer_instructions])

    llm_answer = run_demo(
        image_path=image_path, question=prompt, vlm_id=vlm_id, llm_id=llm_id
    )
    print(llm_answer)