import copy
from dataset.mmimdb import MMIMDBDatasetSup, genres_
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from plots.llava_conversation import conv_templates
import numpy as np
import tqdm
import re
import pickle
import torch
from torchmetrics.classification import F1Score
from sklearn.preprocessing import MultiLabelBinarizer

"""
    Evaluate how well LLaVA-NeXT recognizes movie genres from the plot + poster
    (macro and weighted multilabel F1 on the MM-IMDb test split).
"""


def parse_answer(answer, genres=None):
    """Extract the known genres named in a numbered-list model answer.

    Each line of ``answer`` is scanned for the first numbered item such as
    ``"1. Drama"`` or ``"2.Crime"`` (whitespace after the period is optional).
    The word following the number consists of letters and hyphens, so e.g.
    ``"Sci-Fi"`` is kept whole. That word is lowercased and kept only if it
    belongs to the allowed genre set. Lines without a numbered item, and items
    naming an unknown genre, are ignored.

    Args:
        answer: raw text generated by the model.
        genres: iterable of allowed genre names. Defaults to ``genres_``
            (imported from ``dataset.mmimdb``). The lowercased match is
            compared against it, so entries must be lowercase to ever match.

    Returns:
        Sorted list of unique lowercase genre names (empty if none found).
    """
    allowed = set(genres_ if genres is None else genres)
    found = set()
    for line in answer.split("\n"):
        # Raw string: avoids the invalid "\." / "\ " escape-sequence warnings.
        # "\s*" accepts "2.Crime" as well as "2. Crime".
        match = re.search(r"[0-9]+\.\s*([A-Za-z\-]+)", line)
        if match is None:
            continue
        genre = match.group(1).lower()
        if genre in allowed:
            found.add(genre)
    return sorted(found)


def evaluate_llava_next(model, vis_preproc, tokenizer, template, dataset, conv_template="llava_llama_3", device="cuda",
                        output_path="llava-next_immdb_eval.pkl"):
    """Generate and parse LLaVA-NeXT genre predictions for every item of ``dataset``.

    For each item, the image and plot are packed into a single-turn
    conversation: the image token followed by ``template`` filled with the plot.
    The model generates an answer with sampling disabled, and the answer is
    parsed into a list of genres with ``parse_answer``. Raw answers, parsed
    predictions and ground truth are pickled to ``output_path``.

    Args:
        model: LLaVA model from ``load_pretrained_model``. Its ``config`` is
            passed to ``process_images``.
        vis_preproc: image processor from ``load_pretrained_model``.
        tokenizer: tokenizer from ``load_pretrained_model``.
        template: prompt string with one ``{}`` placeholder for the plot text.
        dataset: object with ``__len__`` and ``get_raw_item(i)`` returning a
            3-tuple unpacked as ``(image, plot, genres)``. The image is
            presumably a PIL image, since ``.size`` is used as its size;
            TODO confirm.
        conv_template: key into ``conv_templates`` selecting the chat format.
        device: device the image tensors and input ids are moved to.
        output_path: file the results dict (``y_pred``, ``y_true``,
            ``y_pred_raw``) is pickled to. The default keeps the historical
            filename.

    Returns:
        ``(y_pred, y_true)``: one entry per dataset item, each a list of
        lowercase genre names (predicted and ground truth, respectively).
    """
    y_true, y_pred, y_pred_raw = [], [], []
    # "idx" rather than "id" to avoid shadowing the builtin.
    for idx in tqdm.tqdm(range(len(dataset))):
        img, caption, genres = dataset.get_raw_item(idx)

        # Image: preprocess with the model's config, then cast to fp16 on the target device.
        image = process_images([img], vis_preproc, model.config)
        image = [_img.to(dtype=torch.float16, device=device) for _img in image]

        # Text: image placeholder token + filled-in template as the user turn.
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + template.format(caption)
        # Deep copy so appending messages never mutates the shared entry in conv_templates.
        conv = copy.deepcopy(conv_templates[conv_template])
        conv.append_message(conv.roles[0], prompt)
        # Empty assistant turn: presumably leaves the prompt open for the model's reply
        # (LLaVA conversation convention) -- verify against conv_templates.
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt_question, tokenizer,
                                          IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

        # Deterministic generation (no sampling), batch of one.
        output_ids = model.generate(input_ids, images=image, image_sizes=[img.size], do_sample=False,
                                    temperature=0, max_new_tokens=256)
        raw_answer = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

        y_pred_raw.append(raw_answer)
        y_pred.append(parse_answer(raw_answer))
        y_true.append([g.lower() for g in genres])

    # Keep the raw answers too, so parsing/scoring can be redone without re-running generation.
    with open(output_path, "wb") as f:
        pickle.dump({"y_pred": y_pred, "y_true": y_true, "y_pred_raw": y_pred_raw}, f)
    return y_pred, y_true


if __name__ == "__main__":
    # Load the LLaVA-NeXT (Llama-3 8B) checkpoint.
    device = "cuda"
    tokenizer, model, vis_preproc, _ = load_pretrained_model(
        "lmms-lab/llama3-llava-next-8b", None, "llava_llama3", device=device, attn_implementation=None
    )

    # Prompt template: the candidate genre list is inserted now; the "{}" slot
    # is filled with each movie's plot inside evaluate_llava_next.
    genre_list = ", ".join(genres_)
    template = (
        "From the following plot: \"{}\" and this poster image, give me all the movie genres it belongs to "
        "among the following list: " + genre_list + ". Give me the answer as a list."
    )
    print(template)

    dataset = MMIMDBDatasetSup("/fastdata/mmimdb", "/fastdata/mmimdb", "test")
    y_pred, y_true = evaluate_llava_next(model, vis_preproc, tokenizer, template, dataset, device=device)

    # Turn the genre lists into binary indicator matrices over the fixed genre vocabulary.
    binarizer = MultiLabelBinarizer()
    binarizer.fit([genres_])
    pred_matrix = torch.as_tensor(binarizer.transform(y_pred))
    true_matrix = torch.as_tensor(binarizer.transform(y_true))

    # Report multilabel F1: macro first, then support-weighted.
    for average in ("macro", "weighted"):
        metric = F1Score(task="multilabel", num_labels=len(genres_), average=average)
        print(metric(pred_matrix, true_matrix))




