import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use('Agg')
plt.set_loglevel("info")

import numpy as np
from pandas import DataFrame
from collections import OrderedDict
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.neighbors import KNeighborsClassifier
import wandb
import seaborn as sns

from gpl.utils.mmd.mmd_critic import select_prototypes_criticisms


def validation_metric_functon(val_results):
    """Return the classifier ROC-AUC over all validation batches.

    Args:
        val_results: mapping whose ``'clf_logits'`` and ``'y'`` entries are
            lists of per-batch arrays; each list is stacked along axis 0.

    Returns:
        The AUC from ``compute_acc_auc`` (the accuracy is discarded).
    """
    logits, labels = (
        np.concatenate(val_results[field], axis=0) for field in ('clf_logits', 'y')
    )
    _acc, auc = compute_acc_auc(logits, labels)
    return auc




def tsne_scatter_plot(coord, y):
    """Draw a scatter plot of 2-D t-SNE coordinates, coloured by label.

    Args:
        coord: 2-D array; columns 0 and 1 are used as the x / y coordinates.
        y: labels aligned with the rows of ``coord``; a 2-D label array is
            flattened first.

    Returns:
        The matplotlib figure containing the plot.
    """
    labels = y.reshape(-1) if y.ndim == 2 else y
    frame = DataFrame({
        'tsne-2d-one': coord[:, 0],
        'tsne-2d-two': coord[:, 1],
        'y': labels,
    })
    # one distinct colour per label value present
    palette = sns.color_palette("hls", len(np.unique(labels)))

    fig, ax = plt.subplots(1, 1, figsize=(16, 10))
    sns.scatterplot(
        data=frame,
        x="tsne-2d-one",
        y="tsne-2d-two",
        hue="y",
        palette=palette,
        legend="full",
        alpha=0.3,
        ax=ax,
    )
    return fig


def _as_numpy_array(values):
    """Convert a torch tensor, or a list of per-batch arrays, to one numpy array."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    if isinstance(values, list):
        return np.concatenate(values, axis=0)
    return values


def tsne_emb_figure(embs, y, must_in_idxs=None, sampled_num=2000):
    """Embed a random subset of ``embs`` with t-SNE and plot it coloured by ``y``.

    Args:
        embs: embeddings as a torch tensor, a numpy array, or a list of
            per-batch numpy arrays (concatenated along axis 0).
        y: labels aligned with ``embs``; accepted in the same forms as ``embs``.
        must_in_idxs: optional indices that are always included. They are
            appended after the random sample, so they occupy the last
            ``len(must_in_idxs)`` rows of ``coord``. Any of them that were also
            drawn randomly are removed from the random part, so no point is
            embedded twice.
        sampled_num: maximum number of randomly sampled points (default 2000).

    Returns:
        (fig, coord, tsne): the scatter-plot figure, the 2-D t-SNE coordinates
        of the selected points, and the fitted ``TSNE`` object.
    """
    embs = _as_numpy_array(embs)
    y = _as_numpy_array(y)

    idx = np.random.permutation(len(embs))[:sampled_num]
    if must_in_idxs is not None:
        # Integer dtype even for an empty list (np.concatenate would otherwise
        # promote idx to float and break fancy indexing).
        forced = np.asarray(must_in_idxs, dtype=idx.dtype).reshape(-1)
        idx = np.concatenate([idx[~np.isin(idx, forced)], forced])

    embs = embs[idx]
    y = y[idx]

    # NOTE: scikit-learn >= 1.5 renamed TSNE's `n_iter` to `max_iter`; kept as
    # `n_iter` here for compatibility with the installed version.
    tsne = TSNE(n_iter=1000)
    coord = tsne.fit_transform(embs, None)
    fig = tsne_scatter_plot(coord, y)
    return fig, coord, tsne

def tmp_centroid_dis(embs1, embs2):
    """Print and return how far ``embs1`` lies from its own and from ``embs2``'s centroid.

    A centroid is the mean over axis 0. Each distance is ``np.linalg.norm`` of
    ``embs1`` minus a centroid, i.e. the Euclidean norm of all deviations taken
    together (not a per-sample average).

    Args:
        embs1: array of embeddings, one row per sample.
        embs2: array of embeddings whose axis-0 mean is the second centroid.

    Returns:
        (dis11, dis12): norm of ``embs1`` around its own centroid and around
        ``embs2``'s centroid.
    """
    centroid1 = embs1.mean(axis=0)
    centroid2 = embs2.mean(axis=0)

    dis11 = np.linalg.norm(embs1 - centroid1)
    dis12 = np.linalg.norm(embs1 - centroid2)
    print('dis11: ', dis11)
    print('dis12: ', dis12)
    return dis11, dis12

def vis_Z_callback(test_results, **kwargs):
    """Print centroid distances between the class-0 and class-1 ``Z`` embeddings.

    Args:
        test_results: mapping whose ``'Z'`` and ``'y'`` entries are lists of
            per-batch arrays; both are concatenated along axis 0.
        **kwargs: accepted for callback-signature compatibility; unused.

    Returns:
        None; the distances are printed by ``tmp_centroid_dis``.
    """
    embs = np.concatenate(test_results['Z'])
    # Flatten so an (N, 1) label column yields a 1-D row mask that matches
    # the (N, D) embedding matrix.
    y = np.concatenate(test_results['y']).reshape(-1)

    embs0 = embs[y == 0]
    embs1 = embs[y == 1]
    tmp_centroid_dis(embs0, embs1)
    return None

   

def get_embeddings(model, dataloader, key):
    """Run ``model.get_embs`` over ``dataloader`` and collect embeddings and labels.

    The model is switched to eval mode (it is not switched back afterwards)
    and inference runs under ``torch.no_grad()``.

    Args:
        model: object exposing ``eval()`` and ``get_embs(batch)``, the latter
            returning a dict of tensors.
        dataloader: iterable of batches that carry a ``.y`` tensor; its
            ``sampler.indexes`` is returned as the sample order.
        key: a single result key (``str``) or a list/tuple of keys.

    Returns:
        (embs, y, indices): ``embs`` is a numpy array for a ``str`` key, or a
        dict mapping each key to a numpy array for a list/tuple of keys; ``y``
        is the concatenated labels as numpy; ``indices`` is
        ``np.array(dataloader.sampler.indexes)``.

    Raises:
        ValueError: if ``key`` is neither a string nor a list/tuple.
    """
    model.eval()
    if isinstance(key, str):
        keys = [key]
    elif isinstance(key, (list, tuple)):
        keys = list(key)
    else:
        raise ValueError(
            f'key must be a str or a list/tuple of str, got {type(key).__name__}')

    collected = {k: [] for k in keys}
    y_all = []
    with torch.no_grad():
        for batch in dataloader:
            results_dict = model.get_embs(batch)
            for k in keys:
                collected[k].append(results_dict[k])
            y_all.append(batch.y)

        indices = np.array(dataloader.sampler.indexes)

    embs_all = {k: torch.cat(parts, dim=0).cpu().numpy() for k, parts in collected.items()}
    y_all = torch.cat(y_all, dim=0).cpu().numpy()

    if isinstance(key, str):
        return embs_all[key], y_all, indices
    return embs_all, y_all, indices



def find_Kmedoids(train_embs, train_y, k=10):
    """Select ``k`` medoids of ``train_embs`` with scikit-learn-extra's KMedoids.

    Args:
        train_embs: array of training embeddings, one row per sample.
        train_y: labels aligned with ``train_embs``.
        k: number of clusters, i.e. of medoids returned.

    Returns:
        (medoids, labels): the medoid rows of ``train_embs`` and their labels.
    """
    from sklearn_extra.cluster import KMedoids

    chosen = KMedoids(n_clusters=k, random_state=0).fit(train_embs).medoid_indices_
    return train_embs[chosen], train_y[chosen]


def KNNEvaluate(proto_embs, proto_y, test_embs, test_y, n_neighbors=5):
    """Score a distance-weighted k-NN classifier fitted on prototypes.

    Args:
        proto_embs: prototype embeddings used as the k-NN training set.
        proto_y: labels of the prototypes.
        test_embs: embeddings to classify.
        test_y: 1-D true labels of ``test_embs``.
        n_neighbors: number of neighbours (default 5). It is clamped to the
            number of prototypes, because k-NN prediction raises when asked for
            more neighbours than fitted samples.

    Returns:
        (acc, auc): accuracy of the hard predictions and ROC-AUC of the
        predicted class probabilities.

    NOTE(review): in the multi-class branch ``roc_auc_score`` expects one
    probability column per test class, so the prototypes presumably need to
    cover every class present in ``test_y``.
    """
    n_neighbors = min(n_neighbors, len(proto_embs))
    classifier = KNeighborsClassifier(n_neighbors=n_neighbors, weights='distance')
    classifier.fit(proto_embs, proto_y)
    num_classes = len(np.unique(test_y))

    pred = classifier.predict_proba(test_embs)
    # predict_proba columns follow classifier.classes_ (the sorted prototype
    # labels), so map the argmax column back to a label before comparing.
    pred_hard = classifier.classes_[np.argmax(pred, axis=1)]

    acc = (pred_hard == test_y).mean()

    if pred.shape[1] == 1:
        auc = roc_auc_score(y_true=test_y, y_score=pred)
        print(f'pred shape is {pred.shape}')
    elif num_classes == 2:
        auc = roc_auc_score(y_true=test_y, y_score=pred[:, 1])
    elif num_classes > 2:
        auc = roc_auc_score(y_true=test_y, y_score=pred, multi_class='ovr')
    else:
        raise NotImplementedError

    return acc, auc


def LREmbEvaluate(train_embs, train_y, test_embs, test_y):
    """Linear-probe evaluation: fit logistic regression on train, score on test.

    Args:
        train_embs: training embeddings, one row per sample.
        train_y: training labels (1-D or an (N, 1) column; flattened).
        test_embs: test embeddings.
        test_y: test labels (1-D or an (N, 1) column; flattened).

    Returns:
        (acc, auc) on the test set. For two classes the AUC uses the
        probability of the second (larger) class; for more classes it uses
        one-vs-rest averaging.
    """
    train_y = np.ravel(train_y)
    test_y = np.ravel(test_y)
    num_classes = len(np.unique(train_y))

    reg = LogisticRegression().fit(train_embs, train_y)
    pred = reg.predict_proba(test_embs)

    # predict_proba columns follow reg.classes_, so map the argmax column to
    # the actual label (labels need not be 0..C-1).
    pred_hard = reg.classes_[np.argmax(pred, axis=1)]
    acc = (pred_hard == test_y).mean()

    if num_classes == 2:
        auc = roc_auc_score(y_true=test_y, y_score=pred[:, 1])
    elif num_classes > 2:
        auc = roc_auc_score(y_true=test_y, y_score=pred, multi_class='ovr')
    else:
        raise NotImplementedError

    return acc, auc




def embedding_evaluate(model, dataloaders):
    """Linear-probe the ``'subg_embs'`` embeddings: fit on train, score on test.

    Args:
        model: model passed to ``get_embeddings``.
        dataloaders: object with ``train_dataloader`` and ``test_dataloader``.

    Returns:
        (acc, auc, (train_embs, train_y)): test-set accuracy and AUC of a
        logistic regression fitted on the training embeddings, plus the
        training embeddings and labels.
    """
    train_embs, train_y, _ = get_embeddings(model, dataloaders.train_dataloader, key='subg_embs')
    test_embs, test_y, _ = get_embeddings(model, dataloaders.test_dataloader, key='subg_embs')

    # Score on the held-out test embeddings; scoring on the training
    # embeddings would only measure how well the probe fits its own data.
    acc, auc = LREmbEvaluate(train_embs, train_y, test_embs, test_y)

    return acc, auc, (train_embs, train_y)


def embedding_evaluate_callback(test_results, **kwargs):
    """Callback that runs ``embedding_evaluate`` and prints and returns its scores.

    Args:
        test_results: not used; accepted to match the callback signature.
        **kwargs: must provide ``'model'`` and ``'dataloaders'``.

    Returns:
        (acc, auc) as computed by ``embedding_evaluate``.
    """
    acc, auc, _probe_data = embedding_evaluate(kwargs['model'], kwargs['dataloaders'])
    print(f'embeding evaluation. acc: {acc:.6f}, auc: {auc:.6f}')
    return acc, auc


def compute_auc(preds, labels, binary=True):
    """ROC-AUC of classifier logits.

    Args:
        preds: logits. On the binary path they are scored as-is (AUC only
            depends on the ranking, so no sigmoid is needed). Otherwise they
            are softmax-normalised over axis 1 and scored one-vs-rest.
        labels: true labels.
        binary: only the exact value ``True`` selects the binary path.

    Returns:
        The score returned by ``roc_auc_score``.
    """
    if binary is not True:
        probabilities = torch.tensor(preds).softmax(dim=1).numpy()
        return roc_auc_score(y_true=labels, y_score=probabilities, multi_class='ovr')
    return roc_auc_score(y_true=labels, y_score=preds)

def compute_acc_auc(clf_logits, y):
    """Accuracy and ROC-AUC from classifier logits.

    Args:
        clf_logits: 2-D logits. A single column means binary classification
            with the decision threshold at logit 0. More columns mean
            single-label multi-class classification, predicted by argmax.
        y: true labels, either 1-D or an (N, 1) column.

    Returns:
        (acc, auc)
    """
    # Flatten labels up front: comparing (N,) predictions with an (N, 1)
    # label column would broadcast to (N, N) and give a meaningless accuracy.
    y = np.asarray(y).reshape(-1)

    if clf_logits.shape[1] == 1:  # binary classification
        preds = (clf_logits >= 0).flatten()
        acc = (preds == y).mean()
        auc = compute_auc(clf_logits, y, binary=True)
    else:  # multi-class (single-label) classification
        acc = (np.argmax(clf_logits, axis=1) == y).mean()
        auc = compute_auc(clf_logits, y, binary=False)
    return acc, auc


def compute_class_embs_cosine_similarity(batch_embs, batch_y):
    """Return 1 minus the mean cosine distance from the batch to its class centroids.

    For every class present in ``batch_y``, the centroid (mean embedding of
    that class) is computed. The cosine distance from *every* embedding in
    the batch to that centroid is then collected. The result is one minus the
    mean of all collected distances, i.e. a cosine *similarity*: larger values
    mean the embeddings lie closer (in angle) to the class centroids.

    NOTE(review): all batch embeddings, not only the members of each class,
    are compared with each centroid. If intra-class coherence was intended,
    ``class_embs`` should replace ``batch_embs`` in the distance call. Confirm
    before relying on this metric.

    Args:
        batch_embs: 2-D array of embeddings, one row per sample.
        batch_y: labels aligned with the rows of ``batch_embs``; an (N, 1)
            column is flattened first.

    Returns:
        The similarity score (numpy float).
    """
    # Flatten labels so an (N, 1) column yields a 1-D row mask that matches
    # the (N, D) embedding matrix.
    labels = np.asarray(batch_y).reshape(-1)

    cos_dists_list = []
    for label in np.unique(labels):
        class_embs = batch_embs[labels == label]
        centroid = np.mean(class_embs, axis=0)
        cos_dists_list.append(cosine_distances(batch_embs, centroid.reshape(1, -1)))

    mean_dist = np.concatenate(cos_dists_list, axis=0).mean()
    return 1 - mean_dist

def loss_acc_auc(results):
    """Mean loss plus accuracy and AUC for one epoch's accumulated results.

    Args:
        results: mapping with ``'loss'`` (per-batch loss values) and
            ``'clf_logits'`` / ``'y'`` (lists of per-batch arrays, stacked
            along axis 0).

    Returns:
        (loss, acc, auc)
    """
    mean_loss = np.mean(results['loss'])
    logits = np.concatenate(results['clf_logits'])
    labels = np.concatenate(results['y'])
    acc, auc = compute_acc_auc(logits, labels)
    return mean_loss, acc, auc

def explain_precision_at_k(results, k):
    """Mean precision@k of edge-importance scores against explanation labels.

    For each graph of each batch, the ``k`` edges with the highest
    ``edge_mask`` score are taken, and the sum of their explanation labels is
    divided by ``k``. The divisor is always ``k``, even when a graph has fewer
    than ``k`` edges. The mean over all graphs is returned.

    Args:
        results: mapping with per-batch lists under ``'edge_mask'`` (edge
            scores), ``'exp_labels'`` (per-edge explanation labels; summed, so
            presumably 0/1), ``'batch'`` (graph id of every node) and
            ``'edge_index'`` (rows 0 and 1 hold source / target node indices).
        k: number of top-scoring edges considered per graph.

    Returns:
        Mean precision@k over all graphs.
    """
    per_graph_precision = []
    batches = zip(results['edge_mask'], results['exp_labels'],
                  results['batch'], results['edge_index'])
    for att, exp_labels, batch, edge_index in batches:  # iterate over all batches
        att = att.flatten()
        exp_labels = exp_labels.flatten()
        batch = batch.flatten()
        src, dst = edge_index[0], edge_index[1]

        for graph_id in range(batch.max() + 1):
            in_graph = batch == graph_id
            # an edge belongs to the graph when both endpoints do
            edge_sel = in_graph[src] & in_graph[dst]
            top_k = np.argsort(-att[edge_sel])[:k]
            hits = exp_labels[edge_sel][top_k].sum().item()
            per_graph_precision.append(hits / k)

    return np.mean(per_graph_precision)


def explain_auc(results):
    att_all = np.concatenate( results['edge_mask'], axis=0)
    exp_labels_all = np.concatenate( results['exp_labels'], axis=0)

    att_all = att_all.flatten()
    exp_labels_all = exp_labels_all.flatten()

    if np.unique(exp_labels_all).shape[0] > 1:
        att_auroc = roc_auc_score(exp_labels_all, att_all)
    else:
        att_auroc = 0

    return att_auroc

def write_logs_to_tensorboard(writer, logging_dict, cur_epoch=None):
    """Write every entry of ``logging_dict`` as a scalar to a TensorBoard-style writer.

    Args:
        writer: object with ``add_scalar(tag, value, step)``.
        logging_dict: mapping of tag -> scalar value.
        cur_epoch: step to log at. When omitted, ``logging_dict['epoch']`` is
            used, so the function can be called with the epoch either inside
            the dict or as an explicit third argument.
    """
    if cur_epoch is None:
        cur_epoch = logging_dict['epoch']
    for key, val in logging_dict.items():
        writer.add_scalar(key, val, cur_epoch)

def write_logs_to_wandb(logging_dict):
    """Send ``logging_dict`` to Weights & Biases via ``wandb.log``."""
    wandb.log(logging_dict)
    return None


def _concat_subg_recon_embs(model, dataloader):
    """Return ``[subg_embs | embs_recon_graph]`` features and flattened labels for ``dataloader``."""
    embs_dict, y, _ = get_embeddings(model, dataloader, key=['subg_embs', 'embs_recon_graph'])
    embs = np.concatenate([embs_dict['subg_embs'], embs_dict['embs_recon_graph']], axis=1)
    return embs, y.reshape(-1)


def prototype_performance(**kwargs):
    """Compare MMD-critic prototypes with k-medoids as k-NN training sets.

    Features are the concatenation of ``'subg_embs'`` and
    ``'embs_recon_graph'``. Prototypes are selected from the training split
    with ``select_prototypes_criticisms``. For each budget in
    ``(20, 40, 60, 80, 100)``, ``KNNEvaluate`` scores on the test split are
    printed for two k-NN training sets: the first ``proto_num`` prototypes,
    and ``proto_num`` k-medoids of the training features. Summary lists are
    printed at the end.

    Args:
        **kwargs: must contain ``'dataloaders'`` (with ``train_dataloader``
            and ``test_dataloader``) and ``'model'``.

    Returns:
        None; results are only printed.
    """
    dataloaders = kwargs['dataloaders']
    model = kwargs['model']

    train_embs, train_y = _concat_subg_recon_embs(model, dataloaders.train_dataloader)
    test_embs, test_y = _concat_subg_recon_embs(model, dataloaders.test_dataloader)

    proto_nums = (20, 40, 60, 80, 100)
    prototypes_dataset = select_prototypes_criticisms(
        train_embs, train_y, num_prototypes=max(proto_nums))
    prototypes = prototypes_dataset.prototypes
    prototype_labels = prototypes_dataset.prototype_labels

    auc_list_ours, acc_list_ours = [], []
    auc_list_kmed, acc_list_kmed = [], []
    for proto_num in proto_nums:
        print(f'######### prot_num: {proto_num}')
        acc, auc = KNNEvaluate(prototypes[:proto_num], prototype_labels[:proto_num], test_embs, test_y)
        print(f'prototype acc: {acc:.4f}, auc: {auc:.4f}')
        auc_list_ours.append(auc)
        acc_list_ours.append(acc)

        # k-medoids baseline with the same budget
        medoids, medoid_labels = find_Kmedoids(train_embs, train_y, k=proto_num)
        acc, auc = KNNEvaluate(medoids, medoid_labels, test_embs, test_y)
        print(f'k-medoids acc: {acc:.4f}, auc: {auc:.4f}')
        auc_list_kmed.append(auc)
        acc_list_kmed.append(acc)

    print('ours auc: ', [f'{val:.4f}' for val in auc_list_ours])
    print('ours acc: ', [f'{val:.4f}' for val in acc_list_ours])
    print('kmed auc: ', [f'{val:.4f}' for val in auc_list_kmed])
    print('kmed acc: ', [f'{val:.4f}' for val in acc_list_kmed])
    

def compute_mmd():
    """Placeholder with no implementation; always returns ``None``."""
    return None

def compare_reconstruct_callback(**kwargs):
    """Print MMD distances between the training dataset and growing prototype sets.

    Prototypes are selected with ``select_prototypes_criticisms`` on the
    training ``'subg_embs'``. These are concatenated with
    ``'embs_recon_graph'`` when the trainer's ``framework.with_reconstruct``
    flag is set. The whole training dataset is then encoded with
    ``model.get_subg_encoder_embs`` in chunks of 256 items.
    ``compute_mmd_distance`` is evaluated for the first 20, 40, 60, 80 and
    100 prototypes.

    Args:
        **kwargs: must contain ``'dataloaders'``, ``'model'`` and
            ``'__trainer__'``.

    Returns:
        None; the distances are printed.
    """
    from torch_geometric.data import Batch
    from gpl.utils.mmd.mmd_critic import compute_mmd_distance

    dataloaders = kwargs['dataloaders']
    model = kwargs['model']
    train_dataloader = dataloaders.train_dataloader

    with_reconstruct = kwargs['__trainer__'].hparams_save.framework.with_reconstruct
    print('with_reconstruct: ', with_reconstruct)

    # prototype selection on the training embeddings
    embs_dict, train_y, train_emb_indices = get_embeddings(
        model, train_dataloader, key=['subg_embs', 'embs_recon_graph'])
    if with_reconstruct:
        train_embs = np.concatenate([embs_dict['subg_embs'], embs_dict['embs_recon_graph']], axis=1)
    else:
        train_embs = embs_dict['subg_embs']
    train_y = train_y.reshape(-1)

    max_prototypes = 100
    prototypes_dataset = select_prototypes_criticisms(train_embs, train_y, num_prototypes=max_prototypes)
    prototype_indices = prototypes_dataset.prototype_indices.numpy()
    # positions in the embedding order -> indices into the dataset
    proto_in_dataset_indices = np.asarray(train_emb_indices)[prototype_indices]

    # Encode the whole training dataset in its natural order, in chunks.
    # Inference only: skip building an autograd graph over the whole dataset.
    dataset = train_dataloader.dataset
    chunk_size = 256
    dataset_embs = []
    with torch.no_grad():
        for start_idx in range(0, len(dataset), chunk_size):
            end_idx = min(start_idx + chunk_size, len(dataset))
            batch = Batch.from_data_list([dataset[i] for i in range(start_idx, end_idx)])
            dataset_embs.append(model.get_subg_encoder_embs(batch))
    dataset_embs = torch.cat(dataset_embs, dim=0)

    mmd_distance = [
        compute_mmd_distance(dataset_embs, proto_in_dataset_indices[:num])
        for num in (20, 40, 60, 80, 100)
    ]
    print(mmd_distance)



def prediction_task_callback(**kwargs):
    """Log an epoch's loss, accuracy and AUC for the train/test (and val) splits.

    Metrics are printed through ``logger``. When the trainer is in training
    mode and not in debug mode, they are also sent to wandb and, if
    ``log2tensorboard`` is set, to ``kwargs['tb_writer']``.

    Expected kwargs: ``logger``, ``cur_epoch``, ``train_results``,
    ``test_results``, optionally ``val_results``, ``__trainer__`` (with
    ``training_mode``, ``debug`` and ``log2tensorboard``) and, for TensorBoard
    logging, ``tb_writer``.

    Returns:
        None.
    """
    val_results = kwargs.get('val_results', None)
    has_val_set = bool(val_results)

    logger = kwargs['logger']
    cur_epoch = kwargs['cur_epoch']

    # loss
    loss_train, acc_train, auc_train = loss_acc_auc(kwargs['train_results'])
    loss_test, acc_test, auc_test = loss_acc_auc(kwargs['test_results'])
    if has_val_set:
        # validation metrics must be computed from the validation results
        loss_val, acc_val, auc_val = loss_acc_auc(val_results)

    y_test = np.concatenate(kwargs['test_results']['y'])
    assert y_test.ndim == 1 or y_test.shape[1] == 1, 'only care about non-multi-label case now'

    logger.info('###############################')
    logger.info(f'[epoch {cur_epoch+1}]')
    logger.info(f'Loss_train: {loss_train:.4f}, Acc_train: {acc_train:.4f}, Auc_train: {auc_train:.4f}')
    logger.info(f'Loss_test : {loss_test:.4f}, Acc_test : {acc_test:.4f}, Auc_test : {auc_test:.4f}')
    logger.info('###############################')

    LOGGING_DICT = OrderedDict()
    LOGGING_DICT['overall_loss/train'] = loss_train
    LOGGING_DICT['overall_loss/test'] = loss_test
    LOGGING_DICT['acc/train'] = acc_train
    LOGGING_DICT['acc/test'] = acc_test
    LOGGING_DICT['auc/train'] = auc_train
    LOGGING_DICT['auc/test'] = auc_test
    if has_val_set:
        LOGGING_DICT['overall_loss/val'] = loss_val
        LOGGING_DICT['acc/val'] = acc_val
        LOGGING_DICT['auc/val'] = auc_val

    for key, val in LOGGING_DICT.items():
        logger.info(f"{key}: {val}")
    logger.info('###############################')

    trainer = kwargs['__trainer__']
    if trainer.training_mode and not trainer.debug:
        write_logs_to_wandb(LOGGING_DICT)
        if trainer.log2tensorboard:
            # write_logs_to_tensorboard takes its step from the dict's 'epoch'
            write_logs_to_tensorboard(kwargs['tb_writer'], {**LOGGING_DICT, 'epoch': cur_epoch})





def train_epoch_log_metrics(**kwargs):
    """Compute, print and store train/val metrics at the end of a training epoch.

    Expected kwargs: ``train_results`` and ``val_results`` (both asserted to
    be non-None), ``logger``, ``cur_epoch``, ``curr_best_epoch``,
    ``all_hparams``, ``model`` (providing ``get_r()``) and ``__trainer__``
    (providing ``tb_writer`` and ``debug``).

    Returns:
        None.
    """
    assert kwargs.get('train_results') is not None
    assert kwargs.get('val_results') is not None
    logger = kwargs['logger']
    cur_epoch = kwargs['cur_epoch']
    curr_best_epoch = kwargs['curr_best_epoch']
    all_hparams = kwargs['all_hparams']

    curr_r = kwargs['model'].get_r()

    metrics_train = compute_metrics(epoch_results=kwargs['train_results'],
                                    curr_r=curr_r, all_hparams=all_hparams)
    metrics_val = compute_metrics(epoch_results=kwargs['val_results'],
                                  curr_r=curr_r, all_hparams=all_hparams)

    metrics_log = {'epoch': cur_epoch, 'curr_best_epoch': curr_best_epoch}
    metrics_log.update({f'{name}/train': value for name, value in metrics_train.items()})
    metrics_log.update({f'{name}/val': value for name, value in metrics_val.items()})

    precision_k = all_hparams['evaluation']['precision_k']
    logger.info(f'[results of epoch {cur_epoch}]')
    for tag, split_metrics in (('[train]', metrics_train), ('[val  ]', metrics_val)):
        logger.info(f"{tag}: {split_metrics['loss']:.4f}, acc: {split_metrics['acc']:.4f}, auc: {split_metrics['auc']:.4f}, exp_precision@{precision_k}: {split_metrics['exp_precision_k']:.4f}, exp_auc: {split_metrics['exp_auc']:.4f}")

    trainer = kwargs['__trainer__']
    store_metrics(metrics_log=metrics_log, logger=logger,
                  tb_writer=trainer.tb_writer,
                  debug=trainer.debug)
    return None

def store_metrics(metrics_log, logger, tb_writer=None, debug=True, log2tensorboard=False):
    """Print every metric and, outside debug mode, push them to wandb / TensorBoard.

    Args:
        metrics_log: mapping of metric name -> value; must contain ``'epoch'``
            when ``log2tensorboard`` is enabled (it is used as the step).
        logger: object with an ``info`` method.
        tb_writer: TensorBoard-style writer, used only when ``log2tensorboard``.
        debug: when True, metrics are only printed.
        log2tensorboard: also write the metrics to ``tb_writer``.
    """
    logger.info('------------ detailed metrics ------------')
    for key, val in metrics_log.items():
        logger.info(f"{key}: {val}")

    if debug:
        return None
    write_logs_to_wandb(metrics_log)
    if log2tensorboard:
        # write_logs_to_tensorboard reads its step from metrics_log['epoch']
        write_logs_to_tensorboard(tb_writer, metrics_log)
    return None



def test_log_metrics(**kwargs):
    """Compute, print and store metrics for the test split.

    Expected kwargs: ``test_results`` (asserted non-None), ``logger``,
    ``epochs``, ``all_hparams`` and ``__trainer__`` (providing ``tb_writer``
    and ``debug``).

    Returns:
        None.
    """
    assert kwargs.get('test_results') is not None
    logger = kwargs['logger']
    epoch = kwargs['epochs']
    all_hparams = kwargs['all_hparams']

    metrics = compute_metrics(epoch_results=kwargs['test_results'],
                              train_val_test='test',
                              all_hparams=all_hparams)
    metrics_log = {'epoch': epoch, **{f'test/{name}': value for name, value in metrics.items()}}

    precision_k = all_hparams['evaluation']['precision_k']
    logger.info('[test results]')
    logger.info(f"[test ]: {metrics['loss']:.4f}, Acc: {metrics['acc']:.4f}, Auc: {metrics['auc']:.4f}, exp_precision@{precision_k}: {metrics['exp_precision_k']:.4f}, exp_auc: {metrics['exp_auc']:.4f}" )

    trainer = kwargs['__trainer__']
    store_metrics(metrics_log=metrics_log,
                  logger=kwargs['logger'],
                  tb_writer=trainer.tb_writer,
                  debug=trainer.debug)
    return None


def compute_metrics(epoch_results, train_val_test='train', **kwargs):
    """Aggregate one split's per-batch results into a flat metrics dict.

    Args:
        epoch_results: mapping of per-batch lists. Keys always read: ``'y'``,
            ``'loss'``, ``'clf_logits'``, ``'pred_loss'``, ``'edge_mask'``,
            ``'exp_labels'``, ``'batch'``, ``'edge_index'``,
            ``'q_theta_loss'`` and ``'node_mask'``. With the IB constraint,
            ``'ib_loss'`` is also read, plus ``'eib_loss'`` /
            ``'ib_upper_loss'`` when present.
        train_val_test: one of ``'train'``, ``'val'``, ``'test'``. For
            anything but ``'test'``, a non-None ``curr_r`` kwarg is required
            and recorded.
        **kwargs: must contain ``all_hparams``; may contain ``curr_r``.

    Returns:
        dict of metric name -> value.
    """
    y = np.concatenate(epoch_results['y'])
    assert y.ndim == 1 or y.shape[1] == 1, 'only care about non-multi-label case now'
    assert train_val_test in ['train', 'val', 'test']

    all_hparams = kwargs['all_hparams']

    # total loss, accuracy and AUC
    loss, acc, auc = loss_acc_auc(epoch_results)
    metrics = {'loss': loss, 'acc': acc, 'auc': auc}

    # prediction loss
    metrics['pred_loss'] = np.mean(epoch_results['pred_loss'])

    # explanation metrics
    k = all_hparams['evaluation']['precision_k']
    exp_precision_k = explain_precision_at_k(epoch_results, k)
    exp_auc = explain_auc(epoch_results)
    metrics['exp_precision_k'] = exp_precision_k
    metrics['exp_auc'] = exp_auc

    # information-bottleneck losses
    if all_hparams['framework']['with_ib_constraint'] is True:
        metrics['ib_loss'] = np.mean(epoch_results['ib_loss'])
        for optional_key in ('eib_loss', 'ib_upper_loss'):
            if optional_key in epoch_results.keys():
                metrics[optional_key] = np.mean(epoch_results[optional_key])

    metrics['q_theta_loss'] = np.mean(epoch_results['q_theta_loss'])

    node_weight = np.mean(np.concatenate(epoch_results['node_mask']))
    edge_weight = np.mean(np.concatenate(epoch_results['edge_mask']))
    metrics['node_mask_weight'] = node_weight
    metrics['edge_mask_weight'] = edge_weight

    if train_val_test != 'test':
        assert kwargs.get('curr_r', None) is not None
        metrics['curr_r'] = kwargs['curr_r']

    return metrics