import torch
import pandas as pd
import os
# Select the PyTorch backend for Keras via the KERAS_BACKEND environment variable.
# NOTE(review): Keras reads this when it is imported, so it only has an effect if set
# before `import keras`; nothing in this file imports keras — presumably kept for
# modules that import this one. Confirm whether it is still needed.
os.environ["KERAS_BACKEND"] = "torch"


def get_eval_df(nmis_t, scs_t, d_names, prefix='pred_'):
    """Assemble per-dataset predicted scores into a DataFrame.

    Args:
        nmis_t: predicted NMI values, one per entry of ``d_names``.
        scs_t: predicted SC values, one per entry of ``d_names``.
        d_names: dataset names used as the join key.
        prefix: prepended to the ``nmi`` / ``sc`` column names.

    Returns:
        A DataFrame with columns ``d_name``, ``{prefix}nmi``, ``{prefix}sc``
        (in that order).
    """
    return pd.DataFrame({
        'd_name': d_names,
        f'{prefix}nmi': nmis_t,
        f'{prefix}sc': scs_t,
    })

def get_gt_by_name(d_name, gt_dict):
    """Return the ``(tsne, umap)`` ground-truth values stored for ``d_name``.

    Entries in ``gt_dict`` are keyed by the 1-tuple ``(d_name,)`` rather than
    the bare name. NOTE(review): presumably because names were batch-collated
    into tuples when the ground truth was saved — confirm.

    Raises:
        KeyError: if ``(d_name,)`` is not a key of ``gt_dict``.
    """
    record = gt_dict[(d_name,)]
    return record['tsne'], record['umap']


def get_gt_df(d_names, prefix='',
              gt_path='/mnt/data01/public/aad_data/clip_tsne_w_umap_ground_truth.tar'):
    """Build a DataFrame of ground-truth t-SNE / UMAP NMI and SC scores.

    Args:
        d_names: dataset names to look up; one output row per name, in order.
        prefix: prepended to every ground-truth column name.
        gt_path: torch-saved file holding a ``(nmi_gt_dict, sc_gt_dict)`` pair,
            each mapping ``(d_name,)`` to a ``{'tsne': ..., 'umap': ...}`` entry.

    Returns:
        A DataFrame with columns ``d_name``, ``{prefix}gt_tsne_nmi``,
        ``{prefix}gt_umap_nmi``, ``{prefix}gt_tsne_sc``, ``{prefix}gt_umap_sc``.
        The columns are present even when ``d_names`` is empty, so merging on
        ``d_name`` still works.

    Raises:
        KeyError: if a name is missing from the ground-truth dictionaries.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # gt_path at trusted files.
    nmi_gt_dict, sc_gt_dict = torch.load(gt_path, weights_only=False)

    tsne_nmi_col = f'{prefix}gt_tsne_nmi'
    umap_nmi_col = f'{prefix}gt_umap_nmi'
    tsne_sc_col = f'{prefix}gt_tsne_sc'
    umap_sc_col = f'{prefix}gt_umap_sc'
    # Declare every column up front (fixed order, present for empty input) and
    # append in place: linear time, instead of re-copying each list per row.
    columns = {
        'd_name': [],
        tsne_nmi_col: [],
        umap_nmi_col: [],
        tsne_sc_col: [],
        umap_sc_col: [],
    }
    for d_name in d_names:
        tsne_nmi, umap_nmi = get_gt_by_name(d_name, nmi_gt_dict)
        tsne_sc, umap_sc = get_gt_by_name(d_name, sc_gt_dict)
        columns['d_name'].append(d_name)
        columns[tsne_nmi_col].append(tsne_nmi)
        columns[umap_nmi_col].append(umap_nmi)
        columns[tsne_sc_col].append(tsne_sc)
        columns[umap_sc_col].append(umap_sc)
    return pd.DataFrame(columns)


def get_mean_performance(df, by='tsne'):
    """Summarise predicted scores relative to one ground-truth embedding.

    Derives four relative metrics comparing ``pred_nmi`` / ``pred_sc`` against
    ``gt_{by}_nmi`` / ``gt_{by}_sc``:

    * ``relative_{by}_nmi_precision`` = pred_nmi / gt_nmi
    * ``relative_{by}_nmi_gap``       = gt_nmi - pred_nmi
    * ``relative_{by}_sc_precision``  = pred_sc / gt_sc
    * ``relative_{by}_sc_gap``        = gt_sc - pred_sc

    It then returns the mean and standard deviation of every column except
    ``d_name``. The input DataFrame is not modified.

    Args:
        df: frame with ``d_name``, ``pred_nmi``, ``pred_sc``, ``gt_{by}_nmi`` and
            ``gt_{by}_sc`` columns (e.g. ``get_gt_df`` merged with ``get_eval_df``).
        by: ground-truth embedding to compare against, e.g. ``'tsne'`` or ``'umap'``.

    Returns:
        ``(mean_df, std_df)``: two Series indexed by column name.
    """
    gt_nmi = df[f'gt_{by}_nmi']
    gt_sc = df[f'gt_{by}_sc']
    # Name the derived columns after the embedding actually used, so results
    # against UMAP are not reported under "tsne" labels. ``assign`` returns a
    # copy, so the caller's frame is left untouched.
    scored = df.assign(**{
        f'relative_{by}_nmi_precision': df['pred_nmi'] / gt_nmi,
        f'relative_{by}_nmi_gap': gt_nmi - df['pred_nmi'],
        f'relative_{by}_sc_precision': df['pred_sc'] / gt_sc,
        f'relative_{by}_sc_gap': gt_sc - df['pred_sc'],
    })
    numeric = scored.drop('d_name', axis=1)
    return numeric.mean(), numeric.std()


if __name__ == '__main__':
    import sys

    # Usage: python <this_script> [RESULTS_FILE] [GT_EMBEDDING]
    # Other results files previously evaluated here (pass as RESULTS_FILE):
    #   ./ind_tsne_res.save, ./ae_res.save, ./pca_res.save, ./ind_umap_res.save
    res_path = sys.argv[1] if len(sys.argv) > 1 else './pumap2_res2.save'
    by = sys.argv[2] if len(sys.argv) > 2 else 'umap'

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # trusted result files.
    (train_nmis, train_scs, test_nmis, train_d,
     test_scs, test_d, _ind_tsne) = torch.load(res_path, weights_only=False)

    # Only the last entry of each name list is used.
    # NOTE(review): presumably one list per epoch/round — confirm.
    splits = (
        ('train', train_nmis, train_scs, train_d[-1]),
        ('test', test_nmis, test_scs, test_d[-1]),
    )
    for i, (split, nmis, scs, d_names) in enumerate(splits):
        if i:
            print('============')
        print(split)
        pred_df = get_eval_df(nmis, scs, d_names)
        gt_df = get_gt_df(d_names)
        merged = pd.merge(gt_df, pred_df, on='d_name')
        print(get_mean_performance(merged, by=by))

