import csv
import functools
import json

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoTokenizer, AutoModel
from datasets import load_dataset
from PIL import Image
from scipy.spatial.distance import cosine


def process_vision_info(messages: list[dict], base_dir: "str | None" = None) -> list[Image.Image]:
    """Load every image referenced in a chat-format conversation.

    Args:
        messages: Conversation turns. Each turn is a dict whose ``"content"`` is
            a list of entries; a non-list value is treated as a single entry.
            An entry counts as an image when it is a dict with
            ``type == "image"`` or with an ``"image"`` key. Its ``"image"``
            value is a path relative to ``base_dir``.
        base_dir: Directory that image paths are resolved against. When
            ``None``, the module-level ``data_path`` is used. It is looked up
            at call time, so it may be defined after this function.

    Returns:
        RGB PIL images in the order they appear. Entries without a path are
        skipped. Images that fail to load are reported on stdout and skipped
        (best-effort).
    """
    root = data_path if base_dir is None else base_dir
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            is_image_entry = isinstance(element, dict) and (
                element.get("type") == "image" or "image" in element
            )
            if not is_image_entry:
                continue
            image_path = element.get("image")
            if not image_path:
                continue

            try:
                # Close the source file once decoded: convert() returns a new
                # in-memory image, so the opened handle is no longer needed.
                with Image.open(f"{root}/{image_path}") as src:
                    img = src.convert("RGB")
            except Exception as e:
                # Deliberate best-effort: report and skip unreadable images.
                print(f"Failed to load {image_path}: {e}")
                continue
            image_inputs.append(img)

    return image_inputs

data_path = "../data"
# Local directory holding the fine-tuned checkpoint. It is passed to
# from_pretrained as a path, not as a Hugging Face Hub model id.
output_dir = 'medical-vqa'

test_annotations = f"{data_path}/annotations_test.jsonl"
with open(test_annotations, 'r', encoding='utf-8') as f:
    # JSON Lines: one object per line. Blank lines are skipped because
    # json.loads("") raises.
    dataset = [json.loads(line) for line in f if line.strip()]

# Load the fine-tuned vision-language model and its processor from output_dir.
# NOTE(review): an earlier comment said this loads "with PEFT adapter". That only
# holds if output_dir contains adapter files (and `peft` is installed) — confirm.
model = AutoModelForImageTextToText.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(output_dir)

def generate_response(messages, *, max_new_tokens=3, do_sample=True,
                      temperature=0.8, top_p=1.0):
    """Generate the model's reply to a single chat-format conversation.

    Steps:
    1. Render the conversation with the processor's chat template, appending
       a generation prompt.
    2. Load its images via ``process_vision_info``.
    3. Run ``model.generate`` and decode only the newly generated tokens.

    Args:
        messages: Chat-format conversation accepted by
            ``processor.apply_chat_template``.
        max_new_tokens: Upper bound on generated tokens (default 3, as before).
        do_sample: Whether to sample. Defaults to ``True`` to match the previous
            behaviour. Pass ``False`` for deterministic (greedy) decoding so
            evaluation runs are reproducible.
        temperature: Sampling temperature. Only forwarded when ``do_sample``.
        top_p: Nucleus-sampling probability mass. Only forwarded when
            ``do_sample``.

    Returns:
        The decoded reply text, with special tokens stripped.
    """
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs = process_vision_info(messages)
    inputs = processor(
        text=[prompt],
        # A text-only conversation yields no images. Pass None rather than an
        # empty list so the processor treats the request as text-only.
        images=image_inputs or None,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Stop on either the tokenizer's EOS or the chat template's end-of-turn marker.
    tokenizer = processor.tokenizer
    stop_token_ids = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>"),
    ]
    # Sampling-only knobs are omitted for greedy decoding; transformers warns
    # when they are set together with do_sample=False.
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        eos_token_id=stop_token_ids,
        disable_compile=True,
        **sampling_kwargs,
    )
    # generate() returns prompt + continuation per row; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

def get_cls_embedding(sentence: str,
                      tokenizer: AutoTokenizer,
                      model: AutoModel,
                      device: torch.device) -> torch.Tensor:
    """
    Encode one sentence and return the hidden state at its first ([CLS]) position.

    The sentence is tokenized with padding/truncation enabled. The tensors are
    moved to `device` and passed through `model` without gradient tracking.
    The first position of `last_hidden_state` is returned on the CPU, with
    the leading batch axis squeezed away.
    """
    encoded = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
    model_inputs = dict((name, tensor.to(device)) for name, tensor in encoded.items())

    with torch.no_grad():
        hidden = model(**model_inputs).last_hidden_state  # (batch_size, seq_len, hidden_size)

    first_token = hidden[:, 0]
    return first_token.squeeze(0).cpu()

@functools.lru_cache(maxsize=4)
def _load_similarity_encoder(model_name: str, device_type: str):
    """Load (once per model/device pair) a tokenizer and an eval-mode encoder."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name).to(torch.device(device_type))
    encoder.eval()
    return tokenizer, encoder


def sentence_similarity(sent1: str, sent2: str, *, model_name: str = "bert-base-uncased") -> float:
    """
    Compute the cosine similarity between two sentences' [CLS] embeddings.

    The tokenizer/encoder pair for `model_name` is loaded once per device and
    then reused. Previously it was reloaded on every call, which is very slow
    when scoring a file row by row.

    NOTE(review): raw BERT [CLS] vectors (not fine-tuned for similarity) are a
    weak sentence representation, and unrelated sentences can still score
    high. Check that any threshold applied to this score is actually
    discriminative.

    Args:
        sent1: First sentence.
        sent2: Second sentence.
        model_name: Hugging Face encoder to embed with (default keeps prior behaviour).

    Returns:
        ``1 - cosine_distance`` between the two embeddings.
    """
    # Choose device (GPU if available).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, encoder = _load_similarity_encoder(model_name, device.type)

    emb1 = get_cls_embedding(sent1, tokenizer, encoder, device)
    emb2 = get_cls_embedding(sent2, tokenizer, encoder, device)

    return float(1 - cosine(emb1.numpy(), emb2.numpy()))

# Score saved predictions. A row counts as correct when the embedding similarity
# between its ground truth and its answer reaches SIMILARITY_THRESHOLD.
output_path = "predictions.csv"
SIMILARITY_THRESHOLD = 0.95
correct = 0
total = 0
with open(output_path, 'r', newline='', encoding='utf-8') as csvfile:
    csv_reader = csv.reader(csvfile)
    for row in csv_reader:
        if not row:
            # csv.reader yields [] for blank lines; unpacking it would raise ValueError.
            continue
        # NOTE(review): assumes the CSV has no header row. A header would be
        # scored as if it were a prediction — confirm against the writer.
        video_id, ground_truth, answer = row
        sim_score = sentence_similarity(ground_truth, answer)
        if sim_score >= SIMILARITY_THRESHOLD:
            correct += 1
        total += 1

if total:
    print(f'The top1 accuracy is {correct/total}')
else:
    # Avoid ZeroDivisionError on an empty predictions file.
    print(f'No predictions found in {output_path}')
print(total)


