import json
from pathlib import Path
from typing import List, Dict, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, accuracy_score
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader

from medklip.models.model_MedKLIP import MedKLIP
from medklip.models.tokenization_bert import BertTokenizer
from medklip.medklip_dataset import ChexpertDataset, Covidxcxr4Dataset, Chestxray14Dataset


def compute_AUCs(gt, pred, n_class):
    """Return the ROC AUC of each of the first ``n_class`` columns.

    Args:
        gt: 2-D torch tensor (any device) of binary ground-truth labels,
            one column per class.
        pred: 2-D torch tensor (any device) of scores with the same layout as
            ``gt``; probabilities, confidences or hard decisions all work.
        n_class: Number of leading columns to score.

    Returns:
        List of ``n_class`` AUROC values in column order.
    """
    labels = gt.cpu().numpy()
    scores = pred.cpu().numpy()
    return [roc_auc_score(labels[:, col], scores[:, col]) for col in range(n_class)]


def get_tokenizer(tokenizer, target_text):
    """Tokenize every entry of ``target_text`` into fixed-length PyTorch tensors.

    ``target_text`` is materialized with ``list()`` and passed to ``tokenizer``.
    Each sequence is padded or truncated to exactly 64 tokens, and the
    tokenizer's output is returned unchanged.
    """
    encode_options = dict(padding='max_length', truncation=True, max_length=64,
                          return_tensors="pt")
    return tokenizer(list(target_text), **encode_options)


# MedKLIP observation vocabulary. Position i is used as the model's output column i
# below. The entries of the disease-book JSON are assumed to follow the same order;
# TODO confirm against the JSON file.
_MEDKLIP_OBSERVATIONS = (
    'normal', 'clear', 'sharp', 'sharply', 'unremarkable', 'intact', 'stable',
    'free', 'pleural effusion', 'opacity', 'pneumothorax', 'edema',
    'atelectasis', 'tube', 'consolidation', 'process', 'abnormality',
    'enlarge', 'tip', 'low', 'pneumonia', 'line', 'congestion', 'catheter',
    'cardiomegaly', 'fracture', 'air', 'tortuous', 'lead', 'disease',
    'calcification', 'prominence', 'device', 'engorgement', 'picc', 'clip',
    'elevation', 'expand', 'lung nodule', 'wire', 'fluid', 'degenerative',
    'pacemaker', 'pleural thicken', 'marking', 'scar', 'hyperinflate',
    'blunt', 'loss', 'widen', 'collapse', 'density', 'emphysema', 'aerate',
    'lung mass', 'crowd', 'infiltration', 'obscure', 'deformity', 'hernia',
    'drainage', 'distention', 'shift', 'stent', 'pressure', 'lesion',
    'finding', 'borderline', 'hardware', 'dilation', 'chf', 'redistribution',
    'aspiration', 'fibrosis', 'excluded_obs',
)

# Default description appended when "covid-19" (not in the vocabulary) is evaluated.
_COVID19_DESCRIPTION = (
    "It is a contagious disease caused by a virus. Ground-glass opacities, "
    "consolidation, thickening, pleural effusions commonly appear in infection."
)


def _build_disease_book(json_book, target_classes, prompt_dict, prompt_type):
    """Return (class_names, disease_book) with target descriptions swapped for the prompt.

    Raises:
        ValueError: If a target class is not in the vocabulary (from ``list.index``).
    """
    class_names = list(_MEDKLIP_OBSERVATIONS)
    # Fresh list each call, so ``json_book`` itself is never modified.
    disease_book = list(json_book.values())
    if "covid-19" in target_classes:
        class_names.append("covid-19")
        disease_book.append(_COVID19_DESCRIPTION)
    for cls in target_classes:
        disease_book[class_names.index(cls)] = prompt_dict[cls][prompt_type]
    return class_names, disease_book


def _predict(model, dataloader, label_columns, class_indices, num_model_classes, device):
    """Run ``model`` over ``dataloader``; return CPU tensors (labels, positive-class probs).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    labels, probs = [], []
    model.eval()
    with torch.no_grad():
        for sample in dataloader:
            labels.append(sample['label'][:, label_columns].float())
            logits = model(sample['image'].to(device, non_blocking=True))
            # Softmax over a trailing pair of logits per class, reshaped to
            # (batch, num_model_classes, 2); index 1 is taken as the positive class.
            pair_probs = F.softmax(logits.reshape(-1, 2), dim=1).reshape(
                -1, num_model_classes, 2)
            probs.append(pair_probs[:, class_indices, 1].cpu())
    if not labels:
        raise ValueError("dataset is empty; cannot compute metrics")
    # Concatenate once instead of growing a tensor inside the loop (quadratic).
    return torch.cat(labels, 0), torch.cat(probs, 0)


def _best_f1_and_accuracy(gt_col, pred_col):
    """Return (max F1 over thresholds, accuracy at the max-F1 threshold) for one class."""
    precision, recall, thresholds = precision_recall_curve(gt_col, pred_col)
    numerator = 2 * recall * precision
    denom = recall + precision
    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
    best = np.argmax(f1_scores)
    # NOTE(review): precision_recall_curve counts ``score >= threshold`` as positive,
    # but accuracy here uses ``>``. This is kept unchanged, presumably to reproduce
    # the reference MedKLIP ("official") numbers; confirm.
    accuracy = accuracy_score(gt_col, pred_col > thresholds[best])
    return f1_scores[best], accuracy


def medklip_zero_shot_classification(dataset_name: str, dataset: Union[
    ChexpertDataset, Covidxcxr4Dataset, Chestxray14Dataset], target_classes: List[str],
                                     prompt_dict: Dict[str, Dict[str, str]],
                                     prompt_types: List[str], disease_book_filepath: Path,
                                     model_filepath: Path, root_dir: Path):
    """Evaluate MedKLIP zero-shot classification on ``dataset`` for each prompt type.

    For every prompt type:

    - the disease-book descriptions of ``target_classes`` are replaced by
      ``prompt_dict[cls][prompt_type]``;
    - a MedKLIP model is built from the resulting descriptions and the checkpoint
      weights are loaded into it;
    - per-class AUC, best F1 and accuracy (at the best-F1 threshold) are printed;
    - the metrics are written to
      ``root_dir / dataset_name / "<name>_result_<dataset_name>_official.csv"``,
      where ``<name>`` is ``"baseline"`` for ``"medklip_baseline"`` and the
      prompt type otherwise.

    Args:
        dataset_name: Name used for the output sub-directory and file name.
        dataset: Test dataset whose samples provide ``'image'`` and ``'label'``.
            The first ``len(target_classes)`` label columns are evaluated.
        target_classes: Class names to evaluate. Each must be a MedKLIP
            observation or ``"covid-19"``.
        prompt_dict: Mapping ``prompt_dict[class_name][prompt_type]`` to the
            description text.
        prompt_types: Prompt variants to evaluate, one result file each.
        disease_book_filepath: JSON file with the default observation descriptions.
        model_filepath: Checkpoint whose ``'model'`` entry holds the state dict.
        root_dir: Directory under which per-dataset result folders are created.

    Raises:
        ValueError: If a target class is unknown or ``dataset`` is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(disease_book_filepath, 'r', encoding='utf-8') as f:
        json_book = json.load(f)

    # None of the following depends on the prompt type, so load it only once.
    tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    # AUC/F1/accuracy are order-invariant, so the test set needs no shuffling.
    test_dataloader = DataLoader(dataset, batch_size=64, num_workers=4, pin_memory=True,
                                 sampler=None, shuffle=False, collate_fn=None,
                                 drop_last=False)
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects
    # checkpoints containing arbitrary pickled objects; confirm this checkpoint loads.
    state_dict = torch.load(str(model_filepath), map_location='cpu')['model']
    label_columns = list(range(len(target_classes)))

    for prompt_type in prompt_types:
        print("---------------------------------")
        print("Start testing " + prompt_type)

        class_names, disease_book = _build_disease_book(
            json_book, target_classes, prompt_dict, prompt_type)
        # Model output column of each target class.
        class_indices = [class_names.index(cls) for cls in target_classes]

        # The model is constructed from the tokenized disease book, so it is rebuilt
        # for every prompt type.
        disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)
        model = MedKLIP(disease_book_tokenizer)
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.to(device)
        model.load_state_dict(state_dict)

        gt, pred = _predict(model, test_dataloader, label_columns, class_indices,
                            len(class_names), device)

        auc_list = compute_AUCs(gt, pred, len(target_classes))
        f1_list, acc_list = [], []
        for col in range(len(target_classes)):
            best_f1, accuracy = _best_f1_and_accuracy(gt[:, col].cpu().numpy(),
                                                      pred[:, col].cpu().numpy())
            f1_list.append(best_f1)
            acc_list.append(accuracy)

        mean_auc = np.mean(auc_list)
        mean_f1 = np.mean(f1_list)
        mean_acc = np.mean(acc_list)
        print("Mean AUC:", mean_auc)
        print("Mean F1:", mean_f1)
        print("Mean Acc:", mean_acc)

        # Prepare data for CSV: one column per class plus the mean.
        data = {"metric": ["AUC", "F1s", "Accs"]}
        for disease, auc, f1, acc in zip(target_classes, auc_list, f1_list, acc_list):
            data[disease] = [auc, f1, acc]
        data["mean"] = [mean_auc, mean_f1, mean_acc]
        df = pd.DataFrame(data)

        output_dir = root_dir / dataset_name
        output_dir.mkdir(parents=True, exist_ok=True)

        output_name = "baseline" if prompt_type == "medklip_baseline" else prompt_type
        output_filepath = output_dir / f"{output_name}_result_{dataset_name}_official.csv"
        df.to_csv(output_filepath, index=False)

        print("Results saved to " + str(output_filepath))


if __name__ == "__main__":
    # Classes evaluated per dataset; each name is looked up in the MedKLIP
    # observation list inside medklip_zero_shot_classification.
    dataset_class_dict = {
        "chestxray14": ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration",
                        "lung mass", "lung nodule", "pneumonia", "pneumothorax",
                        "consolidation", "edema", "emphysema", "fibrosis",
                        "pleural thicken", "hernia"],
        "chexpert": ["atelectasis", "cardiomegaly", "consolidation", "edema",
                     "pleural effusion"],
        "covidx_cxr4": ["covid-19"],
    }

    # Prompt variants; each must be a row of prompts.csv.
    prompt_types = [
        "disease_name",
        "disease_symptom",
        "disease_attribute",
        "disease_description_plain_english",
        "disease_description_medical_style",
        "disease_description_radiologist_style",
        "medklip_baseline",
    ]

    # Project root is the parent of the directory containing this script.
    project_root = Path(__file__).resolve().parents[1]
    data_dir = project_root / 'data'
    results_dir = project_root / 'experiments' / 'MedKLIP'
    checkpoint_path = project_root / 'models' / 'checkpoint_final.pth'

    # Columns are class names, rows are prompt types: prompts[cls][prompt_type].
    prompts = pd.read_csv(data_dir / "prompts.csv", index_col=0).to_dict()

    # Datasets are constructed eagerly, in this order.
    datasets = {
        "chestxray14": Chestxray14Dataset(data_dir / 'chestxray14_test.csv'),
        "chexpert": ChexpertDataset(data_dir / 'chexpert_test.csv'),
        "covidx_cxr4": Covidxcxr4Dataset(data_dir / 'covidx_cxr4_test.csv'),
    }

    observation_book_path = data_dir / "observation explanation.json"

    for name, test_set in datasets.items():
        medklip_zero_shot_classification(dataset_name=name, dataset=test_set,
                                         target_classes=dataset_class_dict[name],
                                         prompt_dict=prompts,
                                         prompt_types=prompt_types,
                                         disease_book_filepath=observation_book_path,
                                         model_filepath=checkpoint_path,
                                         root_dir=results_dir)
