import random
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
from sklearn.metrics import roc_auc_score, matthews_corrcoef, f1_score, accuracy_score
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from kad.kad_dataset import ChexpertDataset, Chestxray14Dataset, Covidxcxr4Dataset
from kad.models.clip_tqn import CLP_clinical, TQN_Model, ModelRes512


def get_sort_eachclass(metric_list, n_class=14, lower_pct=5, upper_pct=95):
    """Summarise bootstrap metric samples per class with a percentile interval and a mean.

    Args:
        metric_list: sequence of per-sample metric rows. Each row must have at least
            ``n_class`` entries; extra trailing entries (e.g. an appended mean) are ignored.
        n_class: number of leading columns to summarise.
        lower_pct: integer percentile used for the lower bound.
        upper_pct: integer percentile used for the upper bound.
            Each bound is the order statistic at index ``len(metric_list) * pct // 100``
            (clamped to the last sample). With 1000 samples this gives indices 50 and 950.

    Returns:
        ``(metric_5, metric_95, metric_mean)``: three lists of length ``n_class + 1``.
        The last element of each list is the mean of its ``n_class`` per-class values.

    Raises:
        ValueError: if ``metric_list`` is empty.
    """
    samples = np.asarray(metric_list, dtype=float)
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("metric_list must contain at least one sample")

    # Integer arithmetic avoids float rounding surprises in the index computation.
    lower_idx = min(n_samples * lower_pct // 100, n_samples - 1)
    upper_idx = min(n_samples * upper_pct // 100, n_samples - 1)
    # Column means do not depend on sort order, so compute them once up front.
    column_means = samples.mean(axis=0)

    metric_5, metric_95, metric_mean = [], [], []
    for i in range(n_class):
        column = np.sort(samples[:, i])
        metric_5.append(column[lower_idx])
        metric_95.append(column[upper_idx])
        metric_mean.append(column_means[i])

    # The mean is evaluated before the append, so it covers only the per-class values.
    for values in (metric_5, metric_95, metric_mean):
        values.append(np.mean(values))
    return metric_5, metric_95, metric_mean


def compute_AUCs(gt, pred, n_class=14):
    """Compute the ROC AUC of each class and append their mean.

    Args:
        gt: 2-D array of true binary labels, one column per class (indexed ``gt[:, i]``).
            The caller in this file passes numpy arrays.
        pred: 2-D array with the same layout as ``gt``. It may hold probability estimates
            of the positive class, confidence values, or binary decisions.
        n_class: number of leading columns to score.

    Returns:
        List of ``n_class + 1`` values: the AUC of each class, followed by their mean.

    Exceptions raised by ``roc_auc_score`` propagate unchanged.
    """
    aurocs = [roc_auc_score(gt[:, i], pred[:, i]) for i in range(n_class)]
    # Average over every class. Slicing off class 0 would bias the summary, and for a
    # single-class dataset it would average an empty list (NaN).
    aurocs.append(np.mean(aurocs))
    return aurocs


def compute_F1s_threshold(gt, pred, threshold, n_class=14):
    """Compute the binary F1 of each class after thresholding the predicted scores.

    A sample counts as positive for class ``i`` when ``pred[:, i] >= threshold[i]``.
    ``pred`` is not modified; the binarised decisions are built in a new array.

    Args:
        gt: 2-D array of true binary labels, one column per class.
        pred: 2-D array of scores with the same layout as ``gt``.
        threshold: per-class thresholds, indexed by class position.
        n_class: number of leading columns to score.

    Returns:
        List of ``n_class + 1`` values: the F1 of each class, followed by their mean.
    """
    pred = np.asarray(pred)
    f1s = []
    for i in range(n_class):
        scores = pred[:, i]
        decisions = (scores >= threshold[i]).astype(scores.dtype)
        # 'binary' scores only the positive label of this single class column.
        f1s.append(f1_score(gt[:, i], decisions, average='binary'))
    f1s.append(np.mean(f1s))
    return f1s


def compute_Accs_threshold(gt, pred, threshold, n_class=14):
    """Compute the accuracy of each class after thresholding the predicted scores.

    A sample counts as positive for class ``i`` when ``pred[:, i] >= threshold[i]``.
    ``pred`` is not modified; the binarised decisions are built in a new array.

    Args:
        gt: 2-D array of true binary labels, one column per class.
        pred: 2-D array of scores with the same layout as ``gt``.
        threshold: per-class thresholds, indexed by class position.
        n_class: number of leading columns to score.

    Returns:
        List of ``n_class + 1`` values: the accuracy of each class, followed by their mean.
    """
    pred = np.asarray(pred)
    accs = []
    for i in range(n_class):
        scores = pred[:, i]
        decisions = (scores >= threshold[i]).astype(scores.dtype)
        accs.append(accuracy_score(gt[:, i], decisions))
    accs.append(np.mean(accs))
    return accs


def compute_mccs_threshold(gt, pred, threshold, n_class=14):
    """Compute the Matthews correlation coefficient of each class at fixed thresholds.

    A sample counts as positive for class ``i`` when ``pred[:, i] >= threshold[i]``.
    ``pred`` is not modified; the binarised decisions are built in a new array.

    Args:
        gt: 2-D array of true binary labels, one column per class.
        pred: 2-D array of scores with the same layout as ``gt``.
        threshold: per-class thresholds, indexed by class position.
        n_class: number of leading columns to score.

    Returns:
        List of ``n_class + 1`` values: the MCC of each class, followed by their mean.
    """
    pred = np.asarray(pred)
    mccs = []
    for i in range(n_class):
        scores = pred[:, i]
        decisions = (scores >= threshold[i]).astype(scores.dtype)
        mccs.append(matthews_corrcoef(gt[:, i], decisions))
    mccs.append(np.mean(mccs))
    return mccs


def compute_mccs(gt, pred, n_class=14):
    """Pick the MCC-maximising threshold for each class and report the resulting MCCs.

    For class ``i``, the candidate thresholds are the observed scores in ``pred[:, i]``,
    tried in row order. A candidate replaces the current best only when its MCC is
    strictly greater. The initial best is MCC 0.0 at threshold 0.0, which is kept when
    no candidate beats it. A sample is positive when its score is ``>=`` the threshold.

    NOTE: the thresholds are tuned on the same ``(gt, pred)`` they are scored on, so the
    returned MCCs are optimistic. ``pred`` is not modified.

    Args:
        gt: 2-D array of true binary labels, one column per class.
        pred: 2-D array of scores with the same layout as ``gt``.
        n_class: number of leading columns to process.

    Returns:
        ``(mccs, thresholds)``. ``mccs`` holds ``n_class + 1`` values: the MCC of each
        class at its chosen threshold, followed by their mean. ``thresholds`` holds the
        ``n_class`` chosen thresholds.
    """
    pred = np.asarray(pred)
    best_thresholds = []
    mccs = []
    for i in range(n_class):
        labels = gt[:, i]
        scores = pred[:, i]
        best_threshold, best_mcc = 0.0, 0.0
        # dict.fromkeys removes duplicates but keeps first-seen order. A repeated score
        # yields an identical MCC and so could never win the strict ">" comparison, which
        # means the chosen threshold matches a full row-by-row scan.
        for candidate in dict.fromkeys(scores):
            mcc = matthews_corrcoef(labels, (scores >= candidate).astype(scores.dtype))
            if mcc > best_mcc:
                best_mcc, best_threshold = mcc, candidate
        best_thresholds.append(best_threshold)
        mccs.append(matthews_corrcoef(labels, (scores >= best_threshold).astype(scores.dtype)))

    mccs.append(np.mean(mccs))
    return mccs, best_thresholds


def get_text_features(model, text_list, tokenizer, device, max_length):
    """Tokenise ``text_list`` and return ``model.encode_text`` applied to the batch.

    Every prompt is padded to, and truncated at, ``max_length`` tokens, so the batch
    has a fixed sequence length. The tokenised batch is moved to ``device`` before
    encoding. Gradient tracking is left to the caller.
    """
    text_token = tokenizer(
        list(text_list),
        add_special_tokens=True,
        max_length=max_length,
        # 'max_length' padding replaces the deprecated pad_to_max_length=True.
        padding='max_length',
        # Without truncation=True, max_length does not cap overlong prompts.
        truncation=True,
        return_tensors='pt',
    ).to(device=device)
    return model.encode_text(text_token)


def _collect_predictions(data_loader, model, image_encoder, text_features, device):
    """Score every batch of ``data_loader``; return ``(gt, pred)`` as numpy arrays.

    ``gt`` stacks ``sample['label']`` (cast to float). ``pred`` stacks index 1 of the
    softmax over the last axis of the model output.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    labels, scores = [], []
    with torch.no_grad():
        for sample in data_loader:
            image = sample['image'].to(device)
            image_features, _ = image_encoder(image)
            logits = model(image_features, text_features)
            # presumably logits are (batch, n_class, 2) with index 1 = "positive" — TODO confirm
            probs = torch.softmax(logits, dim=-1)
            labels.append(sample['label'].float().cpu())
            scores.append(probs[:, :, 1].cpu())
    if not labels:
        raise ValueError("data_loader produced no batches; nothing to evaluate")
    # Concatenate once at the end rather than growing a tensor on every batch.
    return torch.cat(labels, 0).numpy(), torch.cat(scores, 0).numpy()


def kad_zero_shot_classification(dataset_name: str, dataset, target_classes: List[str],
                                 prompt_dict: Dict[str, Dict[str, str]], prompt_types: List[str],
                                 root_dir: Path, model: TQN_Model, image_encoder: ModelRes512,
                                 text_encoder: CLP_clinical, tokenizer: AutoTokenizer):
    """Zero-shot evaluate the model on ``dataset`` once per prompt type and save the metrics.

    For each prompt type:
      1. The prompts ``prompt_dict[cls][prompt_type]`` are encoded with ``text_encoder``.
      2. Every image is scored against them.
      3. Per-class AUC, F1 and accuracy are computed. F1 and accuracy use the per-class
         MCC-maximising thresholds from ``compute_mccs`` on the same predictions.

    Each prompt type's metrics are written to
    ``root_dir / dataset_name / f"{prompt}_result_{dataset_name}_official.csv"``.
    The ``"kad_baseline"`` prompt type is written under the name ``"baseline"``.

    NOTE: only ``dataset_name == "chexpert"`` is evaluated; for any other name this
    function returns immediately without doing anything.
    """
    if dataset_name != "chexpert":
        return
    # drop_last=False: every test sample must be scored, including the final partial batch.
    data_loader = DataLoader(dataset, batch_size=64, num_workers=8, pin_memory=True,
                             sampler=None, shuffle=False, collate_fn=None, drop_last=False)

    # fix the seed for reproducibility
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_class = len(target_classes)

    model.eval()
    image_encoder.eval()
    text_encoder.eval()

    for prompt_type in prompt_types:
        print("---------------------------------")
        print("Start testing " + prompt_type)

        text_list = [prompt_dict[cls][prompt_type] for cls in target_classes]
        with torch.no_grad():
            text_features = get_text_features(text_encoder, text_list, tokenizer, device,
                                              max_length=256)
        gt, pred = _collect_predictions(data_loader, model, image_encoder, text_features, device)

        # The metric helpers may binarise their ``pred`` argument in place, so each one
        # receives its own copy.
        auc_list = compute_AUCs(gt, pred.copy(), n_class=n_class)
        _mccs, threshold = compute_mccs(gt, pred.copy(), n_class=n_class)
        f1_list = compute_F1s_threshold(gt, pred.copy(), threshold, n_class=n_class)
        acc_list = compute_Accs_threshold(gt, pred.copy(), threshold, n_class=n_class)

        # Each list ends with a summary element. Averaging only the per-class entries
        # avoids counting that summary twice.
        mean_auc = np.mean(auc_list[:n_class])
        mean_f1 = np.mean(f1_list[:n_class])
        mean_acc = np.mean(acc_list[:n_class])
        print("Mean AUC:", mean_auc)
        print("Mean F1:", mean_f1)
        print("Mean Acc:", mean_acc)

        file_prompt_type = "baseline" if prompt_type == "kad_baseline" else prompt_type

        # Prepare data for CSV: one row per metric, one column per class, then the mean.
        data = {"metric": ["AUC", "F1s", "Accs"]}
        for i, disease in enumerate(target_classes):
            data[disease] = [auc_list[i], f1_list[i], acc_list[i]]
        data["mean"] = [mean_auc, mean_f1, mean_acc]
        df = pd.DataFrame(data)

        output_dir = root_dir / dataset_name
        output_dir.mkdir(parents=True, exist_ok=True)

        output_filepath = output_dir / (file_prompt_type + "_result_" + dataset_name + "_official.csv")
        df.to_csv(output_filepath, index=False)

        print("Results saved to " + str(output_filepath))

if __name__ == "__main__":
    dataset_class_dict = {
        "chestxray14": ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration",
                        "lung mass", "lung nodule", "pneumonia", "pneumothorax", "consolidation",
                        "edema", "emphysema", "fibrosis", "pleural thicken", "hernia"],
        "chexpert": ["atelectasis", "cardiomegaly", "consolidation", "edema", "pleural effusion"],
        "covidx_cxr4": ["covid-19"]}

    prompt_types = ["disease_name", "disease_symptom", "disease_attribute",
                    "disease_description_plain_english", "disease_description_medical_style",
                    "disease_description_radiologist_style", "kad_baseline"]

    cur_filepath = Path(__file__).resolve()
    cur_dir = cur_filepath.parent
    root_dir = cur_dir.parent
    data_dir = root_dir / 'data'
    # Per-dataset result CSVs are written beneath this directory.
    output_dir = root_dir / 'experiments' / 'KAD'

    # Load prompts. to_dict() on an index_col=0 frame yields {column: {index: value}}.
    # Lookups are therefore prompt_dict[class_name][prompt_type]: CSV columns are class
    # names and rows are prompt types.
    prompt_dict_filepath = data_dir / "prompts.csv"
    prompt_dict = pd.read_csv(prompt_dict_filepath, index_col=0).to_dict()

    # Load datasets
    chestxray14_dataset = Chestxray14Dataset(data_dir / 'chestxray14_test.csv')
    chexpert_dataset = ChexpertDataset(data_dir / 'chexpert_test.csv')
    covidx_cxr4_dataset = Covidxcxr4Dataset(data_dir / 'covidx_cxr4_test.csv')
    dataset_dict = {"chestxray14": chestxray14_dataset, "chexpert": chexpert_dataset,
                    "covidx_cxr4": covidx_cxr4_dataset}

    # Load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # set_default_dtype replaces the deprecated set_default_tensor_type('torch.FloatTensor').
    torch.set_default_dtype(torch.float32)

    image_encoder = ModelRes512(res_base_model='resnet50').to(device)

    tokenizer = AutoTokenizer.from_pretrained('xmcmic/Med-KEBERT', do_lower_case=True,
                                              local_files_only=False)
    text_encoder = CLP_clinical(bert_model_name='xmcmic/Med-KEBERT').to(device=device)

    models_dir = root_dir / 'models'
    knowledge_encoder_filepath = models_dir / "epoch_latest.pt"
    model_filepath = models_dir / "best_valid.pt"

    # Initialise the text encoder from the knowledge-encoder checkpoint, keeping only the
    # keys that CLP_clinical defines.
    checkpoint = torch.load(knowledge_encoder_filepath, map_location='cpu')
    state_dict = checkpoint["state_dict"]
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in text_encoder.state_dict()}
    text_encoder.load_state_dict(filtered_state_dict)

    model = TQN_Model().to(device)

    # best_valid.pt holds image_encoder / text_encoder / model states. Its text_encoder
    # weights are loaded after, and so overwrite, the ones loaded above.
    checkpoint = torch.load(model_filepath, map_location='cpu')
    image_state_dict = checkpoint['image_encoder']
    image_encoder.load_state_dict(image_state_dict)
    text_state_dict = checkpoint['text_encoder']
    filtered_state_dict = {k: v for k, v in text_state_dict.items() if
                           k in text_encoder.state_dict()}
    text_encoder.load_state_dict(filtered_state_dict)
    state_dict = checkpoint['model']
    model.load_state_dict(state_dict)

    for dataset_name, dataset in dataset_dict.items():
        classes = dataset_class_dict[dataset_name]
        # Results land in experiments/KAD/<dataset_name>/ rather than the repository root.
        kad_zero_shot_classification(dataset_name=dataset_name, dataset=dataset,
                                     target_classes=classes, prompt_dict=prompt_dict,
                                     prompt_types=prompt_types, root_dir=output_dir,
                                     model=model, image_encoder=image_encoder,
                                     text_encoder=text_encoder, tokenizer=tokenizer)
