import torch
import numpy as np
from sklearn.metrics import (
    f1_score, 
    roc_auc_score, 
    average_precision_score,
    cohen_kappa_score
)
from pyhealth.metrics import multilabel_metrics_fn, binary_metrics_fn
import torch.nn.functional as F

import warnings
from sklearn.exceptions import UndefinedMetricWarning


def f1(y_true_hot, y_pred, metrics='weighted'):
    """F1 score of ranked predictions truncated to each sample's label count.

    For every sample, the top-``m`` entries of its ranking are marked as
    predicted positives, where ``m`` is the number of positions equal to 1
    in that sample's row of ``y_true_hot``. The resulting multi-hot matrix
    is compared to ``y_true_hot`` with sklearn's ``f1_score``.

    Args:
        y_true_hot: multi-hot ground-truth matrix, one row per sample.
        y_pred: per-sample rankings of column indices, best first.
        metrics: forwarded to ``f1_score`` as ``average``.

    Returns:
        The F1 score; undefined ratios count as 0 (``zero_division=0``).
    """
    predicted_hot = np.zeros_like(y_true_hot)
    for row, truth_row in enumerate(y_true_hot):
        n_positive = np.sum(truth_row == 1)
        top_codes = y_pred[row][:n_positive]
        predicted_hot[row][top_codes] = 1
    return f1_score(
        y_true=y_true_hot,
        y_pred=predicted_hot,
        average=metrics,
        zero_division=0,
    )


def top_k_prec_recall(y_true_hot, y_pred, ks):
    """Compute precision@k and recall@k, averaged over samples.

    Args:
        y_true_hot: per-sample multi-hot label rows; a position equal to 1
            is a positive.
        y_pred: per-sample rankings of indices, best first (for example the
            output of a descending ``argsort``).
        ks: cut-offs at which to evaluate.

    Returns:
        Tuple ``(precision, recall)`` of float arrays with one entry per
        ``k``, each averaged over all samples.

    A sample with no positive labels contributes 0 to recall instead of
    raising ``ZeroDivisionError``.
    """
    precision = np.zeros((len(ks),))
    recall = np.zeros((len(ks),))
    for pred, true_hot in zip(y_pred, y_true_hot):
        targets = set(np.where(true_hot == 1)[0].tolist())
        for i, k in enumerate(ks):
            hits = len(set(pred[:k]) & targets)
            precision[i] += hits / k
            # Recall is normalised by the full number of positives, not by
            # min(k, #positives).
            if targets:
                recall[i] += hits / len(targets)
    return precision / len(y_true_hot), recall / len(y_true_hot)


def calculate_occurred(historical, y, preds, ks):
    """Split top-k recall into hits on "historical" and on new positions.

    For each ``k`` the top-``k`` entries of every ranking in ``preds`` are
    marked as predicted. Correct predictions (also set in ``y``) are divided
    into two groups:
      * r1: the position is also set in ``historical``;
      * r2: the position is not set in ``historical``.
    Each count is divided by the sample's number of positives in ``y`` and
    then averaged over samples.

    NOTE(review): ``historical`` is presumably a multi-hot matrix of codes
    seen in earlier visits, with the same shape as ``y``. Confirm against
    the callers.

    Args:
        historical: multi-hot matrix with the same shape as ``y``.
        y: multi-hot ground-truth matrix.
        preds: per-sample rankings of column indices, best first.
        ks: cut-offs at which to evaluate.

    Returns:
        Tuple ``(r1, r2)`` of float arrays with one entry per ``k``.

    Samples with no positives in ``y`` contribute 0 instead of producing
    0/0 = nan, which would otherwise make the whole mean nan.
    """
    r1 = np.zeros((len(ks),))
    r2 = np.zeros((len(ks),))
    # Normalise by the full number of positives per sample (not min(n, k)).
    n = np.sum(y, axis=-1)
    has_positive = n > 0

    def _mean_ratio(hit_counts):
        # Guarded per-sample division: rows without positives yield 0.
        ratios = np.divide(
            hit_counts,
            n,
            out=np.zeros(np.shape(hit_counts), dtype=float),
            where=has_positive,
        )
        return np.mean(ratios)

    not_historical = np.logical_not(historical)
    for i, k in enumerate(ks):
        pred_k = np.zeros_like(y)
        for row in range(len(pred_k)):
            pred_k[row][preds[row][:k]] = 1
        pred_occurred = np.logical_and(historical, pred_k)
        pred_not_occurred = np.logical_and(not_historical, pred_k)
        pred_occurred_true = np.logical_and(pred_occurred, y)
        pred_not_occurred_true = np.logical_and(pred_not_occurred, y)
        r1[i] = _mean_ratio(np.sum(pred_occurred_true, axis=-1))
        r2[i] = _mean_ratio(np.sum(pred_not_occurred_true, axis=-1))
    return r1, r2


def evaluate_codes(model, dataset, loss_fn, output_size):
    """Evaluate a code-ranking model and print loss, F1 and top-k recall.

    Args:
        model: torch module called as ``model(code_x)``; its squeezed output
            is treated as one score per code.
        dataset: indexable sequence of batches ``(code_x, visit_lens, y)``
            that also provides ``label()`` (all labels, assumed to be in
            batch order) and ``size()`` (total sample count).
        loss_fn: called as ``loss_fn(output, y)``; must return a
            single-element tensor.
        output_size: multiplier applied when accumulating each batch loss.

    Returns:
        ``(avg_loss, f1_value)``, where ``f1_value`` comes from ``f1``.
    """
    model.eval()
    total_loss = 0.0
    labels = dataset.label()
    preds = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for step in range(len(dataset)):
            code_x, visit_lens, y = dataset[step]
            output = model(code_x).squeeze()
            # Rank all codes by score, highest first.
            pred = torch.argsort(output, dim=-1, descending=True)
            preds.append(pred)
            loss = loss_fn(output, y)
            # NOTE(review): presumably loss_fn averages over
            # batch * output_size elements; this rescales it to a sum.
            total_loss += loss.item() * output_size * len(code_x)
            print('\r    Evaluating step %d / %d' % (step + 1, len(dataset)), end='')
    avg_loss = total_loss / dataset.size()
    preds = torch.vstack(preds).detach().cpu().numpy()
    f1_value = f1(labels, preds)
    prec, recall = top_k_prec_recall(labels, preds, ks=[10, 20, 30, 40])
    print('\r    Evaluation: loss: %.4f --- f1_score: %.4f --- top_k_recall: %.4f, %.4f, %.4f, %.4f'
          % (avg_loss, f1_value, recall[0], recall[1], recall[2], recall[3]))
    return avg_loss, f1_value


def evaluate_binary(model, dataset, loss_fn, output_size=1):
    """Evaluate a binary classifier and print AUPRC, AUROC, F1 and kappa.

    A sigmoid is applied to the model output to obtain probabilities. Hard
    labels for F1 and Cohen's kappa use a 0.5 probability threshold.

    Args:
        model: torch module called as ``model(code_x)``.
        dataset: indexable sequence of batches ``(code_x, visit_lens, y)``
            that also provides ``label()`` (all labels, assumed to be in
            batch order) and ``size()`` (total sample count).
        loss_fn: called as ``loss_fn(output, y)``; must return a
            single-element tensor.
        output_size: multiplier applied when accumulating each batch loss.

    Returns:
        ``(avg_loss, None)``.
    """
    model.eval()
    total_loss = 0.0
    labels = dataset.label()
    preds = []
    with torch.no_grad():
        for step in range(len(dataset)):
            code_x, visit_lens, y = dataset[step]
            output = model(code_x).squeeze()
            # Convert to probabilities. Flatten so a batch of one, which
            # squeeze() turns into a 0-d tensor, can still be torch.cat-ed.
            pred = torch.sigmoid(output).reshape(-1)
            preds.append(pred)
            loss = loss_fn(output, y)
            total_loss += loss.item() * output_size * len(code_x)
            print('\r    Evaluating step %d / %d' % (step + 1, len(dataset)), end='')

    avg_loss = total_loss / dataset.size()
    preds = torch.cat(preds).detach().cpu().numpy()
    binary_preds = (preds > 0.5).astype(int)

    auprc = average_precision_score(labels, preds)
    auroc = roc_auc_score(labels, preds)
    f1_value = f1_score(labels, binary_preds)
    kappa = cohen_kappa_score(labels, binary_preds)

    print('\r    Evaluation: loss: %.4f --- AUPRC: %.4f --- AUROC: %.4f --- F1: %.4f --- Kappa: %.4f'
          % (avg_loss, auprc, auroc, f1_value, kappa))
    return avg_loss, None


def evaluate_binary_old(model, dataset, loss_fn, output_size=1):
    """Evaluate a binary classifier with pyhealth's ``binary_metrics_fn``.

    A sigmoid is applied to the model output to obtain probabilities, which
    are passed to ``binary_metrics_fn`` for PR-AUC, ROC-AUC, F1 and Cohen's
    kappa.

    Args:
        model: torch module called as ``model(code_x)``.
        dataset: indexable sequence of batches ``(code_x, visit_lens, y)``
            that also provides ``label()`` and ``size()``.
        loss_fn: called as ``loss_fn(output, y)``; must return a
            single-element tensor.
        output_size: multiplier applied when accumulating each batch loss.

    Returns:
        ``(avg_loss, None)``.
    """
    model.eval()
    total_loss = 0.0
    labels = dataset.label()
    preds = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for step in range(len(dataset)):
            code_x, visit_lens, y = dataset[step]
            output = model(code_x).squeeze()
            pred = torch.sigmoid(output)
            # Add a trailing axis so vstack stacks samples, not batches.
            preds.append(pred.unsqueeze(-1))
            loss = loss_fn(output, y)
            total_loss += loss.item() * output_size * len(code_x)
            print('\r    Evaluating step %d / %d' % (step + 1, len(dataset)), end='')
    avg_loss = total_loss / dataset.size()
    preds = torch.vstack(preds).squeeze().detach().cpu().numpy()

    metrics = ['pr_auc', 'roc_auc', 'f1', 'cohen_kappa']
    eval_dict = binary_metrics_fn(labels, preds, metrics=metrics)
    print('\r    Evaluation: loss: %.4f --- AUPRC: %.4f --- AUROC: %.4f --- F1: %.4f --- Kappa: %.4f'
          % (avg_loss, eval_dict[metrics[0]], eval_dict[metrics[1]],
             eval_dict[metrics[2]], eval_dict[metrics[3]]))
    return avg_loss, None


def evaluate_hf(model, dataset, loss_fn, output_size=1, historical=None):
    """Evaluate a binary (e.g. heart-failure) predictor: loss, AUC and F1.

    Model outputs are treated as logits. A sigmoid turns them into
    probabilities, which are used both for AUC and, thresholded at 0.5, for
    the F1 hard labels. Previously the hard labels were taken from raw
    outputs > 0.5, which is inconsistent with the sigmoid applied for AUC.

    Args:
        model: torch module called as ``model(code_x)``.
        dataset: indexable sequence of batches ``(code_x, visit_lens, y)``
            that also provides ``label()`` and ``size()``.
        loss_fn: called as ``loss_fn(output, y)``; must return a
            single-element tensor.
        output_size: multiplier applied when accumulating each batch loss.
        historical: unused; kept for interface compatibility.

    Returns:
        ``(avg_loss, None)``.
    """
    model.eval()
    total_loss = 0.0
    labels = dataset.label()
    logits = []

    with torch.no_grad():
        for step in range(len(dataset)):
            code_x, visit_lens, y = dataset[step]
            output = model(code_x).squeeze()
            loss = loss_fn(output, y)
            total_loss += loss.item() * output_size * len(code_x)
            # Flatten so a batch of one (0-d after squeeze) concatenates.
            logits.append(output.detach().cpu().reshape(-1))
            print('\r    Evaluating step %d / %d' % (step + 1, len(dataset)), end='')

    avg_loss = total_loss / dataset.size()
    outputs = torch.sigmoid(torch.cat(logits)).numpy()
    preds = (outputs > 0.5).astype(int)

    auc = roc_auc_score(labels, outputs)
    f1_score_ = f1_score(labels, preds)
    print('\r    Evaluation: loss: %.4f --- auc: %.4f --- f1_score: %.4f' % (avg_loss, auc, f1_score_))
    return avg_loss, None


def evaluate_multi_label(model, dataset, loss_fn, output_size, threshold=0.5):
    """Evaluate a multi-label model: loss, sample-averaged AUPRC, F1 and Jaccard.

    Args:
        model: torch module called as ``model(code_x)``; a sigmoid is applied
            to its output to obtain per-label probabilities.
        dataset: indexable sequence of batches ``(code_x, visit_lens, y)``
            that also provides ``label()`` and ``size()``.
        loss_fn: called as ``loss_fn(output, y)``; must return a
            single-element tensor.
        output_size: multiplier applied when accumulating each batch loss.
        threshold: probability at or above which a label counts as
            predicted. This value is now forwarded to ``multi_label_metric``;
            previously it was silently ignored and 0.4 was used.

    Returns:
        ``(avg_loss, None)``.
    """
    model.eval()
    total_loss = 0.0
    labels = dataset.label()
    preds = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for step in range(len(dataset)):
            code_x, visit_lens, y = dataset[step]
            output = model(code_x).squeeze()
            pred = torch.sigmoid(output)
            preds.append(pred)
            loss = loss_fn(output, y)
            total_loss += loss.item() * output_size * len(code_x)
            print('\r    Evaluating step %d / %d' % (step + 1, len(dataset)), end='')
    avg_loss = total_loss / dataset.size()
    preds = torch.vstack(preds).detach().cpu().numpy()
    ja, prauc, f1_value = multi_label_metric(preds, labels, threshold=threshold)
    print('\r    Evaluation: loss: %.4f --- AUPRC: %.4f --- F1: %.4f --- Jaccard: %.4f'
          % (avg_loss, prauc, f1_value, ja))
    return avg_loss, None


def evaluate_multi_label_old(model, dataset, loss_fn, output_size, threshold=0.5):
    """Evaluate a multi-label model with pyhealth's ``multilabel_metrics_fn``.

    Prints micro PR-AUC, micro ROC-AUC, sample-averaged F1 and micro Jaccard.
    sklearn warnings about labels with no positives, and
    ``UndefinedMetricWarning``, are suppressed only while the metrics are
    computed. Previously the filters were installed globally and stayed
    active for the rest of the process.

    Args:
        model: torch module called as ``model(code_x)``; a sigmoid is applied
            to its output to obtain per-label probabilities.
        dataset: indexable sequence of batches ``(code_x, visit_lens, y)``
            that also provides ``label()`` and ``size()``.
        loss_fn: called as ``loss_fn(output, y)``; must return a
            single-element tensor.
        output_size: multiplier applied when accumulating each batch loss.
        threshold: forwarded to ``multilabel_metrics_fn``.

    Returns:
        ``(avg_loss, None)``.
    """
    model.eval()
    total_loss = 0.0
    labels = dataset.label()
    preds = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for step in range(len(dataset)):
            code_x, visit_lens, y = dataset[step]
            output = model(code_x).squeeze()
            pred = torch.sigmoid(output)
            preds.append(pred)
            loss = loss_fn(output, y)
            total_loss += loss.item() * output_size * len(code_x)
            print('\r    Evaluating step %d / %d' % (step + 1, len(dataset)), end='')
    avg_loss = total_loss / dataset.size()
    preds = torch.vstack(preds).detach().cpu().numpy()
    metrics = ['pr_auc_micro', 'roc_auc_micro', 'f1_samples', 'jaccard_micro']
    # Scope the warning filters to this computation only.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="No positive class found in y_true")
        warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
        eval_dict = multilabel_metrics_fn(labels, preds, metrics=metrics, threshold=threshold)
    print('\r    Evaluation: loss: %.4f --- AUPRC: %.4f --- AUROC: %.4f --- F1: %.4f --- Jaccard: %.4f'
          % (avg_loss, eval_dict[metrics[0]], eval_dict[metrics[1]],
             eval_dict[metrics[2]], eval_dict[metrics[3]]))
    return avg_loss, None


def multi_label_metric(pre, gt, threshold=0.4):
    """Compute sample-averaged Jaccard, PR-AUC and F1 for multi-label output.

    Args:
        pre: float matrix of scores in [0, 1], one row per sample.
        gt: binary ground-truth matrix with the same number of rows.
        threshold: a score >= ``threshold`` counts as a predicted positive.
            Used by Jaccard and F1; PR-AUC uses the raw scores.

    Returns:
        Tuple ``(jaccard, pr_auc, f1)``, each the mean over samples.
    """
    def jaccard(pre, gt):
        score = []
        for b in range(gt.shape[0]):
            target = np.where(gt[b] == 1)[0]
            predicted = np.where(pre[b] >= threshold)[0]
            inter = set(predicted) & set(target)
            union = set(predicted) | set(target)
            # No targets and no predictions: define the score as 0. The old
            # check `union == 0` compared a set to an int, so it was always
            # False and this case raised ZeroDivisionError.
            jaccard_score = 0 if len(union) == 0 else len(inter) / len(union)
            score.append(jaccard_score)
        return np.mean(score)

    def precision_auc(pre, gt):
        # Average precision per sample (row), then averaged over samples.
        all_micro = []
        for b in range(gt.shape[0]):
            all_micro.append(average_precision_score(gt[b], pre[b], average='macro'))
        return np.mean(all_micro)

    def prc_recall(pre, gt):
        # Per-sample precision and recall of the thresholded predictions;
        # an empty denominator yields 0.
        score_prc = []
        score_recall = []
        for b in range(gt.shape[0]):
            target = np.where(gt[b] == 1)[0]
            predicted = np.where(pre[b] >= threshold)[0]
            inter = set(predicted) & set(target)
            prc_score = 0 if len(predicted) == 0 else len(inter) / len(predicted)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score_prc.append(prc_score)
            score_recall.append(recall_score)
        return score_prc, score_recall

    def average_f1(prc, recall):
        # Harmonic mean per sample (0 when both terms are 0), then averaged.
        score = []
        for p, r in zip(prc, recall):
            if p + r == 0:
                score.append(0)
            else:
                score.append(2 * p * r / (p + r))
        return np.mean(score)

    ja = jaccard(pre, gt)
    prauc = precision_auc(pre, gt)
    prc_ls, recall_ls = prc_recall(pre, gt)
    f1_value = average_f1(prc_ls, recall_ls)

    return ja, prauc, f1_value
