import torch
import gradio as gr
from PIL import Image
from llava.model.builder import load_pretrained_model
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token

def load_model(
    model_path="checkpoints/llava/clean-images/cc_sbu_align-Biden_base_Trump_target/poison_0-seed_0",
    model_base="liuhaotian/llava-v1.5-7b",
    model_name="llava_v1.5_lora",
    conv_mode="llava_v1",
):
    """Load the LoRA-fine-tuned LLaVA model together with its chat template.

    All arguments default to the checkpoint this demo was written for, so
    ``load_model()`` keeps working unchanged; pass other values to serve a
    different checkpoint (e.g. another poison ratio or seed).

    Args:
        model_path: Checkpoint directory handed to ``load_pretrained_model``.
        model_base: Base model identifier handed to ``load_pretrained_model``.
        model_name: Name handed to ``load_pretrained_model``; LLaVA's loader
            presumably inspects it (e.g. for "lora") to choose how to load the
            weights — TODO confirm against the installed llava version.
        conv_mode: Key into ``conv_templates`` selecting the prompt format.

    Returns:
        Tuple ``(tokenizer, model, image_processor, conv)`` where ``conv`` is a
        private copy of the selected conversation template.
    """
    # Fall back to CPU when no GPU is visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {model_path}")

    # The fourth return value (context length) is not used by this demo.
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path,
        model_base,
        model_name,
        load_8bit=False,
        load_4bit=False,
        device=device,
    )

    # Copy so callers can never mutate the shared template in conv_templates.
    conv = conv_templates[conv_mode].copy()
    return tokenizer, model, image_processor, conv

def generate_response(image, prompt, tokenizer, model, image_processor, conv,
                      *, temperature=0.7, max_new_tokens=512):
    """Generate a sampled LLaVA answer for one image and one text prompt.

    Args:
        image: PIL image supplied by the user.
        prompt: User instruction/question text.
        tokenizer: Tokenizer returned by ``load_model``.
        model: LLaVA model returned by ``load_model``.
        image_processor: Image processor returned by ``load_model``.
        conv: Conversation template; it is copied, so the caller's object is
            never modified.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded response with special tokens removed and surrounding
        whitespace stripped.
    """
    # Fresh copy with no history so repeated calls never accumulate turns.
    conv = conv.copy()
    conv.messages = []

    # Normalise grayscale / palette / RGBA uploads to 3-channel RGB before
    # LLaVA preprocessing, so they are handled like ordinary photos.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and move it to the model's device in half precision.
    image_tensor = process_images([image], image_processor, model.config)
    if isinstance(image_tensor, list):
        image_tensor = [img.to(model.device, dtype=torch.float16) for img in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Put the image placeholder ahead of the user text, wrapped in start/end
    # tokens only when the model config asks for them.
    if model.config.mm_use_im_start_end:
        image_tokens = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_tokens = DEFAULT_IMAGE_TOKEN

    conv.append_message(conv.roles[0], image_tokens + '\n' + prompt)
    conv.append_message(conv.roles[1], None)  # empty assistant turn to be completed
    full_prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    # Depending on the installed LLaVA version, generate() returns either the
    # prompt followed by the completion, or only the completion (when it
    # forwards inputs_embeds). Strip the prompt only if it really is echoed;
    # slicing unconditionally would silently drop the start of the answer.
    # input_ids contains the IMAGE_TOKEN_INDEX sentinel, which generation never
    # emits, so a completion-only output cannot match by accident.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]

    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return response

def gradio_interface(image, prompt):
    """Gradio callback: answer ``prompt`` about ``image`` using the loaded model.

    Reads the module-level ``tokenizer``, ``model``, ``image_processor`` and
    ``conv`` that the ``__main__`` section binds at startup. When no image was
    uploaded, returns a short instruction string instead of running the model.
    """
    if image is not None:
        return generate_response(image, prompt, tokenizer, model, image_processor, conv)
    return "Please upload an image."

if __name__ == "__main__":
    # Load the model once at startup. These names must stay module-level
    # globals because gradio_interface reads them directly.
    print("Initializing model...")
    # LLaVA helper called before loading (presumably skips redundant random
    # weight initialisation — see llava.utils).
    disable_torch_init()
    tokenizer, model, image_processor, conv = load_model()
    print("Model loaded successfully!")

    # One image input plus one free-text prompt, one text output. No
    # `examples` are defined, so `cache_examples` is intentionally not set.
    demo = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Textbox(
                label="Enter your prompt",
                value="",
                lines=2,
            ),
        ],
        outputs=gr.Textbox(label="Response", lines=5),
        title="LLaVA-LoRA Image Chat",
        description="Upload an image and enter a prompt to get a response from the fine-tuned LLaVA model.",
    )

    # `launch(enable_queue=True)` was removed in Gradio 4 (it raises TypeError);
    # calling `queue()` enables request queuing on both Gradio 3.x and 4.x.
    # NOTE(review): share=True publishes the demo through a public Gradio link.
    demo.queue().launch(share=True)