from sklearn.metrics import roc_curve, roc_auc_score
from itertools import chain
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from scipy import stats
from posthoc import ResultAnalyzer


def smooth(x, n=3):
    """Return the length-``n`` moving average of ``x``, with exactly ``len(x)`` samples.

    Values outside ``x`` are treated as zero (the zero padding of
    ``np.convolve``), so samples near both ends are pulled toward 0.

    Args:
        x: 1-D sequence of numbers.
        n: Window size of the moving average.

    Returns:
        A float ndarray with the same length as ``x``. Empty input yields an
        empty array.
    """
    window = np.ones(n) / n
    length = len(x)
    if length == 0:
        # np.convolve raises on empty input; an empty score vector smooths to empty.
        return np.zeros(0)
    if length >= n:
        return np.convolve(x, window, mode='same')
    # mode='same' returns max(len(x), n) samples, which would make a sequence
    # shorter than the window come back longer than it went in (and misalign it
    # with its ground truth). Take the centred slice of the full convolution.
    full = np.convolve(x, window, mode='full')
    start = (n - 1) // 2
    return full[start:start + length]

def flip(x):
    """Min-max normalise ``x`` to [0, 1] and invert it.

    The smallest value maps to 1 and the largest to 0.

    Args:
        x: Sequence or ndarray of numbers.

    Returns:
        A float ndarray of the same shape as ``x``. A constant input has a zero
        range, so numpy yields NaN (with a RuntimeWarning).
    """
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    # The whole shifted vector must be divided by the range, not just min(x).
    return 1 - (x - lo) / (hi - lo)

def flatten(x):
    """Concatenate a sequence of sequences into a single 1-D numpy array."""
    items = [value for seq in x for value in seq]
    return np.array(items)

def roc_threshold(threshold, smoother=lambda x:x):
    """Build a metric returning the pooled ROC curve of scores vs. a ground-truth threshold.

    Args:
        threshold: Positions whose ground-truth value is ``<= threshold`` are
            the positive class.
        smoother: Applied to each sample's score vector before pooling.

    Returns:
        A callable ``(preds, gts) -> (fpr, tpr)``. It smooths each prediction,
        flattens all samples into one vector (NaN scores become 0), and passes
        them to ``sklearn.metrics.roc_curve``.
    """
    def roc_(preds, gts):
        pooled_scores = np.nan_to_num(flatten([smoother(p) for p in preds]), nan=0)
        pooled_labels = flatten(gts) <= threshold
        false_pos, true_pos, _ = roc_curve(pooled_labels, pooled_scores)
        return false_pos, true_pos
    return roc_

def plot_method_roc(analyzer:ResultAnalyzer, figsize=(15, 5), metric=roc_threshold(3.7), subtitle_map=None):
    """Plot one ROC panel per chain ('alpha', 'beta', 'epitope') comparing attribution methods.

    Args:
        analyzer: Provides ``analyzer.keys()`` (method names),
            ``analyzer[method, chain]`` (per-sample scores) and
            ``analyzer.benchmark[chain]`` (matching ground truth).
        figsize: Matplotlib figure size.
        metric: Callable ``(preds, gts) -> (fpr, tpr)``.
        subtitle_map: Optional ``{chain: title}`` overrides for panel titles;
            chains not in it are titled with the chain name.

    Returns:
        The created matplotlib figure. Only methods from the fixed order below
        that are present in ``analyzer`` are drawn.
    """
    if subtitle_map is None:
        subtitle_map = {}
    fig = plt.figure(figsize=figsize)
    axes = fig.subplots(1, 3, sharey=True)
    method_order = [
        'QCAI', 'AttnLRP', 'TokenTM', 'AttCAT',
        'Rollout', 'GradCAM', 'LRP', 'RawAttn'
    ]
    # Membership set built once instead of re-listing keys per method and chain.
    available = set(analyzer.keys())
    for ax, _chain in zip(axes, ['alpha', 'beta', 'epitope']):
        for _method in method_order:
            if _method not in available:
                continue
            fpr, tpr = metric(analyzer[_method, _chain], analyzer.benchmark[_chain])
            if _method == 'QCAI':
                # QCAI keeps the default (solid) style so it stands out.
                ax.plot(fpr, tpr, label=_method)
            else:
                ax.plot(fpr, tpr, label=_method, linestyle='--')
        ax.set_title(subtitle_map.get(_chain, _chain))
        ax.set_xlabel('False Positive Rate (FPR)')
        if _chain == 'alpha':
            ax.set_ylabel('True Positive Rate (TPR)')
        # Chance-level diagonal.
        ax.plot(np.arange(0, 1.1, 0.1), np.arange(0, 1.1, 0.1), color='black', linestyle=':', alpha=0.75)
        if _chain == 'epitope':
            ax.legend(bbox_to_anchor=(0.2, -0.15), ncol=4, fancybox=True)
        ax.grid(linestyle=':')
    return fig
def roc_auc_threholds(thresholds, smoother=lambda x:x):
    """Build a metric computing pooled ROC-AUC at several ground-truth thresholds.

    Args:
        thresholds: Iterable of cut-offs; for each one, positions whose
            ground-truth value is ``<=`` the cut-off are positives.
        smoother: Applied to each sample's score vector before pooling.

    Returns:
        A callable ``(preds, gts) -> (thresholds, aucs)``. Here ``aucs[j]`` is
        ``roc_auc_score`` of all pooled positions at ``thresholds[j]``; NaN
        scores are treated as 0.
    """
    def roc_auc_(preds, gts):
        preds = [smoother(i) for i in preds]
        f_preds = np.nan_to_num(flatten(preds), nan=0)
        # Flatten the ground truth once; only the binarisation cut-off varies.
        f_gts = flatten(gts)
        _auc_rocs = [roc_auc_score(f_gts <= thres, f_preds) for thres in thresholds]
        return thresholds, _auc_rocs
    return roc_auc_
def plot_roc_auc(analyzer:ResultAnalyzer, figsize=(18, 3), metric=roc_auc_threholds(np.arange(3, 6.5, 0.5)), xlabel='Distance Thresholds (A)', ylabel='ROC-AUC', subtitle_map=None):
    """Plot a metric-vs-threshold curve per chain ('alpha', 'beta', 'epitope') for each method.

    Args:
        analyzer: Provides ``analyzer.keys()`` (method names),
            ``analyzer[method, chain]`` (per-sample scores) and
            ``analyzer.benchmark[chain]`` (matching ground truth).
        figsize: Matplotlib figure size.
        metric: Callable ``(preds, gts) -> (x, y)``; by default ROC-AUC at
            thresholds 3.0 .. 6.0 in steps of 0.5.
        xlabel: Label for every panel's x-axis.
        ylabel: Label for the y-axis of the 'alpha' panel.
        subtitle_map: Optional ``{chain: title}`` overrides for panel titles.

    Returns:
        The created matplotlib figure.
    """
    if subtitle_map is None:
        subtitle_map = {}
    fig = plt.figure(figsize=figsize)
    axes = fig.subplots(1, 3)
    method_order = [
        'QCAI', 'AttnLRP', 'TokenTM', 'AttCAT',
        'Rollout', 'GradCAM', 'LRP', 'RawAttn'
    ]
    # Membership set built once instead of re-listing keys per method and chain.
    available = set(analyzer.keys())
    for ax, _chain in zip(axes, ['alpha', 'beta', 'epitope']):
        for _method in method_order:
            if _method not in available:
                continue
            x, y = metric(analyzer[_method, _chain], analyzer.benchmark[_chain])
            if _method == 'QCAI':
                ax.plot(x, y, label=_method, marker='.')
            else:
                ax.plot(x, y, label=_method, linestyle='--')
        ax.set_title(subtitle_map.get(_chain, _chain))
        ax.set_xlabel(xlabel=xlabel)
        if _chain == 'alpha':
            ax.set_ylabel(ylabel=ylabel)
        if _chain == 'epitope':
            ax.legend(bbox_to_anchor=(0.2, -0.15), ncol=4, fancybox=True)
        ax.grid(linestyle=':')
    return fig

def cal_roc_auc(analyzer:ResultAnalyzer, metric=roc_auc_threholds(np.arange(3, 6.5, 0.5))):
    """Evaluate ``metric`` for every method of ``analyzer`` on each chain.

    Returns:
        A nested dict ``{chain: {method: (x, y)}}`` for chains 'epitope',
        'alpha' and 'beta', where ``(x, y)`` is the pair returned by ``metric``.
    """
    results = {}
    for chain_name in ('epitope', 'alpha', 'beta'):
        per_method = results.setdefault(chain_name, {})
        for method in analyzer.keys():
            xs, ys = metric(analyzer[method, chain_name], analyzer.benchmark[chain_name])
            per_method[method] = (xs, ys)
    return results

def hits_rate_threholds(thresholds=np.arange(0.1, 0.8, 0.1), smoother=lambda x:x):
    """Build a metric returning the mean hit rate at several fractions.

    For a fraction ``frac``, a sample's hit rate is the share of its top
    ``ceil(frac * len(pred))`` scored positions whose ground-truth value lies
    strictly below the ``1 - frac`` quantile of that sample's ground truth.

    NOTE(review): the top ``frac`` of predictions is compared against the
    ``1 - frac`` ground-truth quantile; the two fractions coincide only at 0.5.
    Confirm this asymmetry is intended.

    Returns:
        A callable ``(preds, gts) -> (thresholds, rates)``, where ``rates[j]`` is
        the mean per-sample hit rate at ``thresholds[j]``. Each ``gt`` must be
        an ndarray (it is fancy-indexed).
    """
    def hits_rate_(preds, gts):
        smoothed = [smoother(p) for p in preds]
        rates = []
        for frac in thresholds:
            per_sample = []
            for score, dist in zip(smoothed, gts):
                top_k = int(np.ceil(frac * len(score)))
                ranked = np.argsort(score)[::-1]
                cutoff = np.quantile(dist, 1 - frac)
                per_sample.append(np.mean(dist[ranked[:top_k]] < cutoff))
            rates.append(np.mean(per_sample))
        return thresholds, rates
    return hits_rate_

def cal_hits_rate(preds, gts, thresholds=np.arange(0.1, 0.8, 0.1), smoother=lambda x:x, score_filter=None):
    """Compute per-sample hit rates at several fractions.

    For a fraction ``frac``, a sample's hit rate is the share of its top
    ``ceil(frac * len(pred))`` scored positions whose ground-truth value lies
    strictly below the ``1 - frac`` quantile of that sample's ground truth.
    (NOTE(review): the two fractions coincide only at 0.5 — confirm intended.)

    Args:
        preds: Per-sample score vectors; ``smoother`` is applied to each.
        gts: Per-sample ground-truth ndarrays (fancy-indexed).
        thresholds: Fractions to evaluate.
        smoother: Applied to each score vector first.
        score_filter: Optional per-sample truthy mask; falsy entries are skipped.

    Returns:
        ``(thresholds, hits_rate)``, where ``hits_rate[j]`` is the list of
        per-sample hit rates at ``thresholds[j]``.
    """
    def _hit_rate(score, dist, frac):
        top = np.argsort(score)[::-1][:int(np.ceil(frac * len(score)))]
        return np.mean(dist[top] < np.quantile(dist, 1 - frac))

    smoothed = [smoother(p) for p in preds]
    hits_rate = []
    for frac in thresholds:
        hits_rate.append([
            _hit_rate(score, dist, frac)
            for idx, (score, dist) in enumerate(zip(smoothed, gts))
            if score_filter is None or score_filter[idx]
        ])
    return thresholds, hits_rate

def table_mean_hits_rate(analyzer, thresholds=np.arange(0.25, 0.8, 0.05)):
    """Tabulate the mean hit rate per (chain, method) at each fraction.

    For each chain ('epitope', 'alpha', 'beta') and each method in
    ``analyzer``, the scores are smoothed with ``smooth`` (default window).
    Per-sample hit rates come from ``cal_hits_rate`` and are then averaged.

    Returns:
        A DataFrame of percentage strings (e.g. ``"73.21%"``). Rows are a
        (chain, method) MultiIndex; columns are labelled ``'mHR.25'``,
        ``'mHR.30'``, ... (one per threshold).
    """
    cmHR = {}
    for _chain in ['epitope', 'alpha', 'beta']:
        mHR = {}
        for _method in analyzer.keys():
            thresholds, hits_rate = cal_hits_rate([smooth(i) for i in analyzer[_method, _chain]], analyzer.benchmark[_chain], thresholds=thresholds)
            mHR[_method] = [np.mean(i) for i in hits_rate]
        # '0.25'[1:] -> '.25', giving labels like 'mHR.25'.
        cmHR[_chain] = pd.DataFrame(mHR, index=['mHR'+f'{i:.2f}'[1:] for i in thresholds])
    df = pd.concat(list(cmHR.values()), axis=1)
    df.columns = pd.MultiIndex.from_tuples([(_c, _m) for _c in cmHR for _m in cmHR[_c].columns], names=('chain', 'method'))
    # DataFrame.applymap is deprecated since pandas 2.1; column-wise Series.map
    # gives the same element-wise result on every pandas version.
    df = df.apply(lambda col: col.map(lambda x: f"{x * 100:.2f}%")).T
    return df

def table_mean_hits_rate_diff(analyzer, thresholds=np.arange(0.25, 0.8, 0.05), score_filter=None):
    """Tabulate mean +/- std of per-sample hit rates per (chain, method).

    Methods are taken in a fixed order and skipped when absent from
    ``analyzer``. Scores are smoothed with ``smooth`` before ``cal_hits_rate``;
    ``score_filter`` is forwarded to it.

    Returns:
        A DataFrame whose rows are labelled ``'mHR.25'``, ``'mHR.30'``, ...
        and whose columns are a (chain, method) MultiIndex. Cells are strings
        like ``"72.3(+/-10.1)%"``.
    """
    preferred_order = [
        'QCAI', 'AttnLRP', 'TokenTM', 'AttCAT',
        'Rollout', 'GradCAM', 'LRP', 'RawAttn',
        'QCAI Max.', 'QCAI Avg.'
    ]
    chain_tables = {}
    for chain_name in ('epitope', 'alpha', 'beta'):
        formatted = {}
        for method in preferred_order:
            if method in analyzer.keys():
                smoothed = [smooth(s) for s in analyzer[method, chain_name]]
                thresholds, per_threshold = cal_hits_rate(
                    smoothed,
                    analyzer.benchmark[chain_name],
                    thresholds=thresholds,
                    score_filter=score_filter,
                )
                formatted[method] = [f"{np.mean(i) * 100:.1f}(+/-{np.std(i)*100:.1f})%" for i in per_threshold]
        row_labels = ['mHR'+f'{i:.2f}'[1:] for i in thresholds]
        chain_tables[chain_name] = pd.DataFrame(formatted, index=row_labels)
    table = pd.concat(list(chain_tables.values()), axis=1)
    table.columns = pd.MultiIndex.from_tuples(
        [(c, m) for c in chain_tables for m in chain_tables[c].columns],
        names=('chain', 'method'),
    )
    return table

def cal_positive_num(analyzer:ResultAnalyzer, cthres=5):
    """Mean number of positions per sample with ground-truth distance strictly below ``cthres``.

    Returns:
        ``{'alpha': ..., 'beta': ..., 'epitope': ...}``, computed from
        ``analyzer.benchmark.distances[chain]``.
    """
    counts = {}
    for chain_name in ('alpha', 'beta', 'epitope'):
        per_sample = [np.sum(dist < cthres) for dist in analyzer.benchmark.distances[chain_name]]
        counts[chain_name] = np.mean(per_sample)
    return counts

def cal_pvalues(analyzer:ResultAnalyzer, thresholds=(3.4, 4)):
    """One-sample t-test of per-sample ROC-AUC against chance level (0.5).

    For every chain ('alpha', 'beta', 'epitope') and every method (fixed order,
    only those present in ``analyzer``), the per-sample scores are prepared by
    replacing NaN with 0 and applying ``smooth(n=3)``. For each threshold, the
    per-sample ROC-AUC then treats ground-truth values ``<=`` the threshold as
    positives. Samples where the AUC is undefined (no positives, or only one
    class) are recorded as NaN and excluded from the test.

    Args:
        analyzer: Provides ``keys()``, ``analyzer[method, chain]`` and
            ``analyzer.benchmark[chain]``.
        thresholds: Ground-truth cut-offs to evaluate.

    Returns:
        ``{(chain, method): {threshold: pvalue}}``, using
        ``scipy.stats.ttest_1samp`` with its default two-sided alternative.
        A p-value is NaN if the test itself fails.
    """
    method_order = [
        'QCAI', 'AttnLRP', 'TokenTM', 'AttCAT',
        'Rollout', 'GradCAM', 'LRP', 'RawAttn',
        'QCAI Max.', 'QCAI Avg.'
    ]
    available = set(analyzer.keys())

    def _metric(preds, gts):
        preds = [smooth(np.nan_to_num(i, nan=0), n=3) for i in preds]
        _auc_rocs = {}
        for thres in thresholds:
            _vs = []
            for gt, pred in zip(gts, preds):
                positives = gt <= thres
                if np.sum(positives) == 0:
                    # No positives: ROC-AUC is undefined for this sample.
                    _vs.append(np.nan)
                    continue
                try:
                    _vs.append(roc_auc_score(positives, pred))
                except ValueError:
                    # roc_auc_score raises ValueError e.g. when only one class
                    # is present (every position is a positive).
                    print(np.sum(positives), gt)
                    _vs.append(np.nan)
            _auc_rocs[thres] = _vs
        return _auc_rocs

    m_data = {}
    for _chain in ['alpha', 'beta', 'epitope']:
        m_data[_chain] = {
            _method: _metric(analyzer[_method, _chain], analyzer.benchmark[_chain])
            for _method in method_order
            if _method in available
        }
    pvalues = {}
    for _c, _cv in m_data.items():
        for _m, _mv in _cv.items():
            pvalues[(_c, _m)] = {}
            for _t, _tv in _mv.items():
                # Drop undefined AUCs by value (np.isnan), not by object identity.
                d = np.asarray(_tv, dtype=float)
                d = d[~np.isnan(d)]
                try:
                    pvalues[(_c, _m)][_t] = stats.ttest_1samp(d, popmean=0.5).pvalue
                except Exception:
                    # Best-effort: a failed test yields NaN rather than aborting the table.
                    pvalues[(_c, _m)][_t] = np.nan
    return pvalues

def cal_perturbation(analyzer:ResultAnalyzer, node, model, dataloader, proba, ks=range(3, 17, 1)):
    """Measure how overwriting the top-k attributed tokens changes model output probabilities.

    First runs ``model`` over ``dataloader`` unmodified to get baseline
    probabilities. Then, for every method in ``analyzer`` and every ``k`` in
    ``ks``, it re-runs the model after setting the ``k`` highest-scored
    positions of the alpha, beta and peptide ``input_ids`` to token id 1, and
    compares the result with the baseline.

    Args:
        analyzer: Provides ``analyzer.keys()`` (method names) and
            ``analyzer[method, chain]`` per-sample attribution scores for chains
            'alpha', 'beta' and 'epitope'. Samples are looked up by a running
            offset over the batches. NOTE(review): this assumes ``dataloader``
            yields samples in the same order as ``analyzer`` stores them (no
            shuffling) — confirm.
        node: Object whose ``model_hook.clean()`` is called before every forward pass.
        model: Called with ``input_ids``/``attention_mask`` tuples ordered
            (alpha, beta, peptide), plus ``labels`` and ``mhc``.
        dataloader: Yields ``(peptide, alpha, beta, binder, mhc)``; the first
            three are mappings with 'input_ids' and 'attention_mask'.
        proba: Callable ``(out, groundtruths)`` yielding up to three tensors.
            Entry i is stored under chain index i (0 Alpha, 1 Beta, 2 Epitope).
        ks: Numbers of top-scored positions to overwrite.

    Returns:
        ``(AOPCs_df, LOdds_df, full_AOPCs, full_LOdds)``:

        - ``AOPCs_df`` and ``LOdds_df``: DataFrames indexed by (chain, method)
          with one column per k. They hold the mean of
          (baseline - perturbed probability) and of log(perturbed / baseline).
        - ``full_AOPCs`` and ``full_LOdds``: the per-sample arrays, as dicts
          ``{(chain, method): {k: array}}``.
    """
    bar = tqdm(dataloader, total = len(dataloader))
    model = model.eval()
    # Baseline pass over unmodified inputs.
    # NOTE(review): no torch.no_grad() here; possibly intentional if the hooks
    # on `node` need autograd — confirm.
    probas = [[], [], []]
    for peptide, alpha, beta, binder, mhc in bar:
        peptide_input, peptide_mask = peptide['input_ids'], peptide["attention_mask"]
        alpha_input, alpha_mask = alpha['input_ids'], alpha["attention_mask"]
        beta_input, beta_mask = beta['input_ids'], beta["attention_mask"]
        labels = binder

        # Clear hook state before each forward pass (clean() is defined elsewhere).
        node.model_hook.clean()
        out = model(input_ids=(alpha_input,beta_input,peptide_input),
                    attention_mask=(alpha_mask,beta_mask,peptide_mask),
                    labels=labels,
                    mhc=mhc,
                    output_hidden_states=True,
                    output_attentions=True)
        groundtruths=(alpha_input,beta_input,peptide_input)
        for i, p in enumerate(proba(out, groundtruths)):
            probas[i].append(p.detach().cpu().numpy())
    # One tensor per chain index, concatenated over all batches.
    baseline_probas = [torch.tensor(np.concatenate(p)) for p in probas]

    # Pad a score vector with its own minimum at both ends. The two extra
    # positions (presumably special tokens at sequence start/end — TODO confirm)
    # then rank no higher than the lowest real score in topk.
    def consmooth(x):
        return [min(x), *x, min(x)]
    # `probas` is reused: from here on it maps method -> k -> per-chain arrays.
    probas = {}
    for method in tqdm(analyzer.keys()):
        probas[method] = {}
        for k in ks:
            probas[method][k] = [[], [], []]
            # Running offset into the analyzer's per-sample score lists.
            _b_s = 0
            for peptide, alpha, beta, binder, mhc in dataloader:
                peptide_input, peptide_mask = peptide['input_ids'], peptide["attention_mask"]
                # Score matrix of shape (batch, seq_len): consmooth'd attributions,
                # zero-padded on the right to the tokenised length.
                # NOTE(review): the 0 padding outranks negative attributions, so
                # padded positions may be selected when scores are negative.
                idx = torch.tensor(np.stack([np.pad(consmooth(x), (0, peptide_input.size(-1)-2-len(x))) for x in analyzer[method, 'epitope'][_b_s:_b_s+len(peptide_input)]]))
                # Overwrite the top-k positions with token id 1. This writes in
                # place into the batch tensor yielded by the dataloader.
                # NOTE(review): the meaning of id 1 (e.g. mask/unk) is not visible
                # here; if the dataloader can yield shared tensors, later passes
                # would see already-modified ids — confirm.
                peptide_input[torch.arange(idx.size(0)).reshape(-1, 1), idx.topk(k, dim=-1).indices] = 1
                alpha_input, alpha_mask = alpha['input_ids'], alpha["attention_mask"]
                # Same procedure for alpha; the slice length uses the peptide batch size.
                idx = torch.tensor(np.stack([np.pad(consmooth(x), (0, alpha_input.size(-1)-2-len(x))) for x in analyzer[method, 'alpha'][_b_s:_b_s+len(peptide_input)]]))
                alpha_input[torch.arange(idx.size(0)).reshape(-1, 1), idx.topk(k, dim=-1).indices] = 1
                beta_input, beta_mask = beta['input_ids'], beta["attention_mask"]
                idx = torch.tensor(np.stack([np.pad(consmooth(x), (0, beta_input.size(-1)-2-len(x))) for x in analyzer[method, 'beta'][_b_s:_b_s+len(peptide_input)]]))
                beta_input[torch.arange(idx.size(0)).reshape(-1, 1), idx.topk(k, dim=-1).indices] = 1
                labels = binder
                _b_s += len(peptide_input)
                node.model_hook.clean()
                out = model(input_ids=(alpha_input,beta_input,peptide_input),
                            attention_mask=(alpha_mask,beta_mask,peptide_mask),
                            labels=labels,
                            mhc=mhc,
                            output_hidden_states=True,
                            output_attentions=True)
                # NOTE(review): here groundtruths are the *perturbed* ids, whereas
                # the baseline pass used the unmodified ids — confirm this is what
                # `proba` expects.
                groundtruths=(alpha_input,beta_input,peptide_input)
                for i, p in enumerate(proba(out, groundtruths)):
                    probas[method][k][i].append(p.detach().cpu().numpy())
            probas[method][k] = [np.concatenate(p) for p in probas[method][k]]

    # Compare perturbed against baseline probabilities, per chain index.
    LOdds, AOPCs = {}, {}
    full_LOdds, full_AOPCs = {}, {}
    for method in probas:
        for i, c in enumerate(['Alpha', 'Beta', 'Epitope']):
            LOdds[(c, method)] = {}
            AOPCs[(c, method)] = {}
            full_LOdds[(c, method)] = {}
            full_AOPCs[(c, method)] = {}
            # `prob` iterates the k values used above.
            for prob in probas[method]:
                LOdds[(c, method)][prob] = np.mean(np.log(probas[method][prob][i] / baseline_probas[i].detach().numpy()))
                AOPCs[(c, method)][prob] = np.mean(baseline_probas[i].detach().numpy() - probas[method][prob][i])
                full_LOdds[(c, method)][prob] = np.log(probas[method][prob][i] / baseline_probas[i].detach().numpy())
                full_AOPCs[(c, method)][prob] = baseline_probas[i].detach().numpy() - probas[method][prob][i]
    # Transpose so rows are (chain, method) and columns are k.
    LOdds_df = pd.DataFrame(LOdds).T.sort_index(axis=0)
    AOPCs_df = pd.DataFrame(AOPCs).T.sort_index(axis=0)
    method_order = [
        'QCAI', 'AttnLRP', 'TokenTM', 'AttCAT',
        'Rollout', 'GradCAM', 'LRP', 'RawAttn',
        'QCAI Max.', 'QCAI Avg.'
    ]
    # Order the method level consistently with the other tables in this module.
    LOdds_df = LOdds_df.reindex(method_order, level=1)
    AOPCs_df = AOPCs_df.reindex(method_order, level=1)
    return AOPCs_df, LOdds_df, full_AOPCs, full_LOdds