from dataset.mmimdb import MMIMDBDatasetSup, genres_
from lavis.models import load_model_and_preprocess
import numpy as np
import tqdm
import pickle
import torch
from torchmetrics.classification import F1Score
from sklearn.preprocessing import MultiLabelBinarizer

""" 
    Evaluate BLIP-2's ability to recognize a movie's genres from its plot and poster.
"""


def preproc_answer(answer):
    """Turn a raw model answer into an array of recognized genre names.

    The answer string is split on ``", "``, de-duplicated and sorted
    alphabetically; only the pieces that appear in ``genres_`` are kept.

    Args:
        answer: Text generated by the model, expected to be a
            ``", "``-separated list of genres.

    Returns:
        A NumPy array holding the retained genre strings, in sorted order.
    """
    # A set drops repeated genres; sorting makes the order deterministic.
    candidates = sorted(set(answer.split(", ")))
    # Boolean mask: True where the candidate is one of the known genres.
    is_known = [candidate in genres_ for candidate in candidates]
    return np.array(candidates)[is_known]


def evaluate_blip2(model, vis_preproc, template, dataset, device="cuda",
                   output_path="blip2_immdb_eval.pkl"):
    """Query the model on every dataset sample and collect its genre predictions.

    For each index, the raw sample is fetched with ``dataset.get_raw_item``,
    the image is run through ``vis_preproc["eval"]``, the caption is inserted
    into ``template`` and the resulting image/prompt pair is passed to
    ``model.generate``. The first generated answer is cleaned with
    ``preproc_answer``.

    Args:
        model: Object exposing ``generate({"image": ..., "prompt": ...})`` that
            returns an indexable collection of answers.
        vis_preproc: Mapping whose ``"eval"`` entry preprocesses a raw image
            into an object supporting ``.unsqueeze(0).to(device)``
            (presumably a torch tensor -- confirm against the preprocessor).
        template: Prompt string with one ``{}`` placeholder for the caption.
        dataset: Sized object providing ``get_raw_item(index)`` that returns
            ``(img, caption, genres)``.
        device: Device the preprocessed image is moved to.
        output_path: Where the pickled results are written. Pass ``None`` to
            skip writing. Defaults to the previously hard-coded file name.

    Returns:
        ``(y_pred, y_true)``: per-sample cleaned predictions and per-sample
        lower-cased ground-truth genres.
    """
    y_true, y_pred, y_pred_raw = [], [], []
    # ``index`` rather than ``id`` so the builtin ``id`` is not shadowed.
    for index in tqdm.tqdm(range(len(dataset))):
        img, caption, genres = dataset.get_raw_item(index)
        # Preprocess the image and add a leading batch dimension of size 1.
        image = vis_preproc["eval"](img).unsqueeze(0).to(device)
        prompt = template.format(caption)
        # Only the first generated answer is used.
        answer = model.generate({"image": image, "prompt": prompt})[0]
        # Keep the untouched answer so the cleaning step can be audited later.
        y_pred_raw.append(answer)
        y_pred.append(preproc_answer(answer))
        y_true.append([g.lower() for g in genres])
    if output_path is not None:
        with open(output_path, "wb") as f:
            pickle.dump({"y_pred": y_pred, "y_true": y_true, "y_pred_raw": y_pred_raw}, f)
    return y_pred, y_true


# Script entry point: load BLIP-2, run it over the MM-IMDb test split and
# report the macro and weighted F1 scores of the predicted genres.
if __name__ == "__main__":
    device = "cuda"
    model, vis_preproc, _ = load_model_and_preprocess("blip2_t5", "pretrain_flant5xl", is_eval=True,
                                                      device=device)
    # The "%s" is filled with the genre list now; the "{}" is left for the
    # per-sample plot, which evaluate_blip2 inserts with str.format.
    template = ("From the following plot: {} and this poster image, give me all the movie genres it belongs to "
                "among the following list: %s."%", ".join(genres_))
    print(template)
    dataset = MMIMDBDatasetSup("/fastdata/mmimdb", "/fastdata/mmimdb", "test")
    y_pred, y_true = evaluate_blip2(model, vis_preproc, template, dataset, device)
    # Convert predictions to binary multi-labels and get the metrics
    mlb = MultiLabelBinarizer()
    mlb.fit([genres_])
    y_pred, y_true = mlb.transform(y_pred), mlb.transform(y_true)
    # Convert once and reuse the tensors for both metrics.
    preds, target = torch.as_tensor(y_pred), torch.as_tensor(y_true)
    f1_score = F1Score(task="multilabel", num_labels=len(genres_), average="macro")
    f1_weighted_score = F1Score(task="multilabel", num_labels=len(genres_), average="weighted")
    # Keep the metric values: previously they were computed and discarded,
    # so the script never reported any result.
    macro_f1 = f1_score(preds, target)
    weighted_f1 = f1_weighted_score(preds, target)
    print(f"Macro F1: {macro_f1.item():.4f}")
    print(f"Weighted F1: {weighted_f1.item():.4f}")




