from typing import List, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from health_multimodal.image import get_image_inference
from health_multimodal.image.utils import ImageModelType
from health_multimodal.text import get_bert_inference
from health_multimodal.text.utils import BertEncoderType
from health_multimodal.vlp import ImageTextInferenceEngine
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
from pathlib import Path

from biovil.biovil_dataset import ChexpertDataset, Covidxcxr4Dataset, Chestxray14Dataset


def _embed_prompts(image_text_inference: ImageTextInferenceEngine, prompts: List[str]) -> torch.Tensor:
    """Embed *prompts* with the text encoder and L2-normalise each embedding.

    Returns:
        A ``(len(prompts), D)`` tensor whose rows are unit-norm text embeddings, in the
        same order as *prompts*.
    """
    embeddings = image_text_inference.text_inference_engine.get_embeddings_from_prompt(
        prompts, normalize=False)
    # assumes one embedding row per prompt — the reshape makes that layout explicit.
    # Normalise along the feature axis: normalising a (1, D) tensor along dim=0 would
    # divide every feature by its own magnitude and reduce the vector to its signs.
    return F.normalize(embeddings.reshape(len(prompts), -1), dim=-1, p=2)


def _embed_images(image_text_inference: ImageTextInferenceEngine,
                  dataset: ChexpertDataset | Covidxcxr4Dataset | Chestxray14Dataset,
                  num_samples: int) -> tuple[torch.Tensor, np.ndarray]:
    """Embed the first *num_samples* images of *dataset* and collect their labels.

    Returns:
        ``(image_embeddings, gt_array)``: a ``(num_samples, D)`` tensor of L2-normalised
        projected image embeddings and a ``(num_samples, K)`` array of ground-truth
        labels, where K is the length of each sample's ``"label"`` entry (1 for scalars).
    """
    embeddings = []
    labels = []
    for i in range(num_samples):
        if i % 100 == 0:
            print(f"Processing {i + 1}/{num_samples}")
        sample = dataset[i]
        embedding = image_text_inference.image_inference_engine.get_projected_global_embedding(
            sample["img_path"])
        embeddings.append(embedding.reshape(-1))
        labels.append(sample["label"])
    # Normalising makes the dot products below true cosine similarities; it is a no-op
    # if the engine already returns unit-norm embeddings.
    image_embeddings = F.normalize(torch.stack(embeddings), dim=-1, p=2)
    # Keep a 2-D (N, K) layout even when each label is a scalar (single-class datasets).
    gt_array = np.array(labels).reshape(num_samples, -1)
    return image_embeddings, gt_array


def _binary_metrics(gt: np.ndarray, scores: np.ndarray,
                    predictions: np.ndarray) -> tuple[float, float, float]:
    """Return ``(AUROC, F1, accuracy)`` for a single class.

    AUROC is computed from the continuous *scores*; F1 and accuracy from the hard
    0/1 *predictions*. AUROC is NaN when *gt* contains only one class, where it is
    undefined.
    """
    try:
        auc = roc_auc_score(gt, scores)
    except ValueError:
        # roc_auc_score raises when only one class is present in the ground truth.
        auc = float("nan")
    return auc, f1_score(gt, predictions), accuracy_score(gt, predictions)


def biovil_zero_shot_classification(dataset_name: str,
                                    dataset: ChexpertDataset | Covidxcxr4Dataset | Chestxray14Dataset,
                                    target_classes: List[str],
                                    prompt_dict: Dict[str, Dict[str, str]], prompt_types: List[str],
                                    image_text_inference: ImageTextInferenceEngine,
                                    root_dir: Path,
                                    *,
                                    max_samples: int | None = None):
    """Run BioViL zero-shot classification for each prompt type and save metrics to CSV.

    For every class, an image is predicted positive when its cosine similarity to
    that prompt type's class prompt exceeds its similarity to ``"No evidence of <cls>"``.
    The similarity margin is used as the continuous score for AUROC.

    Args:
        dataset_name: Name used for the output sub-directory and file name.
        dataset: Indexable dataset whose items are dicts with ``"img_path"`` and
            ``"label"``. Label column ``i`` must correspond to ``target_classes[i]``.
        target_classes: Class names, in label-column order.
        prompt_dict: ``prompt_dict[cls][prompt_type]`` is the positive prompt text.
            It must contain every entry of *prompt_types* for every class.
        prompt_types: Prompt types to evaluate; one CSV is written per type.
        image_text_inference: BioViL image/text inference engine.
        root_dir: Output root; results go to ``root_dir / dataset_name``.
        max_samples: Optional cap on the number of (leading) samples to evaluate.
            ``None`` evaluates the whole dataset.

    Raises:
        ValueError: If there are no samples to evaluate.
        KeyError: If *prompt_dict* lacks a class or prompt type.
    """
    num_samples = len(dataset) if max_samples is None else min(max_samples, len(dataset))
    if num_samples <= 0:
        raise ValueError(f"No samples to evaluate for dataset '{dataset_name}'.")

    output_dir = root_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        # Image embeddings and negative prompts do not depend on the prompt type,
        # so they are computed once and shared across all prompt types.
        image_embeddings, gt_array = _embed_images(image_text_inference, dataset, num_samples)
        negative_embeddings = _embed_prompts(
            image_text_inference, ["No evidence of " + cls for cls in target_classes])
        cos_similarity_false = image_embeddings @ negative_embeddings.t()  # (N, C)

        for prompt_type in prompt_types:
            print("---------------------------------")
            print("Start testing " + prompt_type)

            # One positive prompt per class, in target_classes order, so column i of
            # the similarity matrix lines up with column i of gt_array.
            positive_prompts = [prompt_dict[cls][prompt_type] for cls in target_classes]
            positive_embeddings = _embed_prompts(image_text_inference, positive_prompts)
            cos_similarity_true = image_embeddings @ positive_embeddings.t()  # (N, C)

            # margin > 0  <=>  the positive prompt is closer than the negative one.
            score_array = (cos_similarity_true - cos_similarity_false).cpu().numpy()
            prediction_array = (score_array > 0).astype(int)

            auc_list, f1_list, acc_list = [], [], []
            for i in range(len(target_classes)):
                auc, f1, acc = _binary_metrics(
                    gt_array[:, i], score_array[:, i], prediction_array[:, i])
                auc_list.append(auc)
                f1_list.append(f1)
                acc_list.append(acc)

            # Classes whose AUROC is undefined (NaN) are excluded from the mean AUROC.
            mean_auc = np.nanmean(auc_list)
            mean_f1 = np.mean(f1_list)
            mean_acc = np.mean(acc_list)
            print("Mean AUC:", mean_auc)
            print("Mean F1:", mean_f1)
            print("Mean Acc:", mean_acc)

            # The baseline prompt type is saved under the shorter name "baseline".
            file_prefix = "baseline" if prompt_type == "biovil_baseline" else prompt_type

            data = {"metric": ["AUC", "F1s", "Accs"]}
            for i, disease in enumerate(target_classes):
                data[disease] = [auc_list[i], f1_list[i], acc_list[i]]
            data["mean"] = [mean_auc, mean_f1, mean_acc]
            df = pd.DataFrame(data)

            output_filepath = output_dir / (file_prefix + "_result_" + dataset_name + "_official.csv")
            df.to_csv(output_filepath, index=False)

            print("Results saved to " + str(output_filepath))


if __name__ == "__main__":
    # Class names evaluated per dataset. The order presumably matches each dataset's
    # label vector (column i <-> class i) — verify against biovil_dataset.
    dataset_class_dict = {
        "chestxray14": [
            "atelectasis", "cardiomegaly", "pleural effusion", "infiltration",
            "lung mass", "lung nodule", "pneumonia", "pneumothorax",
            "consolidation", "edema", "emphysema", "fibrosis",
            "pleural thicken", "hernia",
        ],
        "chexpert": ["atelectasis", "cardiomegaly", "consolidation", "edema", "pleural effusion"],
        "covidx_cxr4": ["covid-19"],
    }

    # Prompt variants to evaluate; each one must be a row label in prompts.csv.
    prompt_types = [
        "disease_name",
        "disease_symptom",
        "disease_attribute",
        "disease_description_plain_english",
        "disease_description_medical_style",
        "disease_description_radiologist_style",
        "biovil_baseline",
    ]

    # Repository root is the parent of the directory containing this script.
    root_dir = Path(__file__).resolve().parent.parent
    data_dir = root_dir / 'data'
    output_dir = root_dir / 'experiments' / 'BioViL'

    # First CSV column is the row index; DataFrame.to_dict() yields
    # {column: {row_label: value}}, i.e. prompt_dict[column][row_label].
    prompt_dict = pd.read_csv(data_dir / "prompts.csv", index_col=0).to_dict()

    # Test-split datasets, built in the same order they are evaluated.
    dataset_dict = {
        "chestxray14": Chestxray14Dataset(data_dir / 'chestxray14_test.csv'),
        "chexpert": ChexpertDataset(data_dir / 'chexpert_test.csv'),
        "covidx_cxr4": Covidxcxr4Dataset(data_dir / 'covidx_cxr4_test.csv'),
    }

    # CXR-BERT text encoder + BioViL image encoder (text engine is loaded first).
    image_text_inference = ImageTextInferenceEngine(
        text_inference_engine=get_bert_inference(BertEncoderType.CXR_BERT),
        image_inference_engine=get_image_inference(ImageModelType.BIOVIL),
    )
    image_text_inference.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    for dataset_name, dataset in dataset_dict.items():
        biovil_zero_shot_classification(
            dataset_name=dataset_name,
            dataset=dataset,
            target_classes=dataset_class_dict[dataset_name],
            prompt_dict=prompt_dict,
            prompt_types=prompt_types,
            image_text_inference=image_text_inference,
            root_dir=output_dir,
        )
