import os


import sys
sys.path.append('/mnt/data01/****/****')
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.spatial.distance import cdist as scipy_cdist
from sklearn.utils.extmath import randomized_svd

from functools import partial
import os
import time

import torch
import dgl

# import read_data
from gnn.Datasets import LargeCompleteMVGraphDatasets
from gnn.Datasets import graph_cut
from gnn.pe_utils import posenc_config
from gnn.pe_utils import flip_pe_sign_utils
from gnn.networks.AutoVisualMultiGINComplete import AutoVisualNet


from gnn.useful_utils import graph_sampler
from gnn.useful_utils import visualization_metric


def get_model_ckp(ckp_name):
    """Load a training checkpoint stored as a ``(net, save)`` pair.

    Args:
        ckp_name: Path of the checkpoint file written with ``torch.save``.

    Returns:
        The ``(net, save)`` tuple stored in the file. ``net`` is used by this
        script as a ready-to-call model object; ``save`` is presumably the
        metadata dict that holds e.g. ``'train_names'`` / ``'test_names'``
        (verify against the training script).
    """
    # The file contains whole pickled objects rather than a bare state_dict,
    # so full unpickling has to be requested explicitly: from PyTorch 2.6 on
    # torch.load defaults to weights_only=True and rejects such checkpoints.
    # NOTE(security): full unpickling can execute arbitrary code -- only load
    # checkpoints that come from a trusted source.
    net, save = torch.load(ckp_name, weights_only=False)
    return net, save


def get_loader(data_names, exp_name, cdist_path, pe_path, b_size=1):
    """Create a DGL ``GraphDataLoader`` over the requested datasets.

    Args:
        data_names: Dataset names forwarded to ``DatasetGraphDataset``.
        exp_name: Experiment name forwarded to the dataset.
        cdist_path: Forwarded as the dataset's ``cdist_path``.
        pe_path: Forwarded as the dataset's ``precomputed_pe_path``.
        b_size: Only compared against 1. With ``1`` the loader yields one
            shuffled graph per batch in the current process; any other value
            switches to ``GraphBatchSampler`` batching (fixed
            ``batch_size=30000``) with 10 worker processes.

    Returns:
        The configured ``dgl.dataloading.GraphDataLoader``.
    """
    ds = LargeCompleteMVGraphDatasets.DatasetGraphDataset(
        data_names=data_names,
        exp_name=exp_name,
        cdist_path=cdist_path,
        visual_path='/mnt/data01/public/aad_data/bo',
        visual_path_umap='/mnt/data01/public/aad_data/bo/umap',
        is_test=False,
        normalize_z=True,
        z_cali_method='none',
        precomputed_pe_path=pe_path,
        z_mu=None,
        z_std=None,
        flip_sign_method='pos',
    )

    if b_size != 1:
        # Random order over datasets, grouped into batches by the project's
        # block-size aware sampler.
        random_order = torch.utils.data.RandomSampler(ds)
        batch_sampler = graph_sampler.GraphBatchSampler(
            random_order,
            batch_size=30000,
            block_list=ds.ds_block_sizes,
        )
        return dgl.dataloading.GraphDataLoader(
            ds,
            batch_sampler=batch_sampler,
            num_workers=10,
        )

    # Single-graph batches, loaded in-process.
    return dgl.dataloading.GraphDataLoader(
        ds,
        batch_size=1,
        shuffle=True,
        num_workers=0,
    )


def get_loader_by_exp_name(data_names, exp_names):
    """Build the loader for one of the known experiments.

    Args:
        data_names: Dataset names forwarded to ``get_loader``.
        exp_names: Experiment name; ``'clip'``, ``'gene'`` and ``'uci'`` are
            recognised.

    Returns:
        The loader returned by ``get_loader`` for a recognised experiment,
        or ``None`` when ``exp_names`` matches none of them.
    """
    # (experiment name, cdist_path, pe_path)
    known_experiments = (
        ('clip',
         '/mnt/data01/public/aad_data',
         '/mnt/data01/public/aad_data/pe'),
        ('gene',
         '/mnt/data01/public/aad_data/gene_filtered',
         '/mnt/data01/public/aad_data/pe/gene_filtered'),
        ('uci',
         '/mnt/data01/public/aad_data/uci',
         '/mnt/data01/public/aad_data/pe'),
    )
    for known_name, cdist_path, pe_path in known_experiments:
        if exp_names == known_name:
            return get_loader(data_names,
                              exp_name=exp_names,
                              cdist_path=cdist_path,
                              pe_path=pe_path)
    return None


def get_d_names_from_save(save):
    """Return ``(train_names, test_names)`` as stored in the ``save`` mapping."""
    return save['train_names'], save['test_names']


def plot_zs(z, zhat, y, d_name, exp_name):
    """Overlay two 2-D embeddings in one scatter plot and save it as a PNG.

    Args:
        z: Tensor plotted with ``'+'`` markers (label ``'z'``); columns 0 and
            1 are used.
        zhat: Tensor plotted with ``'o'`` markers (label ``'z_hat'``); columns
            0 and 1 are used.
        y: Per-point values used to colour both scatters.
        d_name: Sequence whose first element is the dataset name.
        exp_name: Experiment tag used in the output file name.

    The image is written to the hard-coded ``plot_res`` directory as
    ``{exp_name}-{d_name}.png``.
    """
    def scatter_z(points, colours, ax, label, marker, alpha=1.0):
        points = points.cpu().numpy()
        return ax.scatter(points[:, 0], points[:, 1], c=colours,
                          label=label, marker=marker, alpha=alpha,
                          cmap='viridis')

    d_name = d_name[0]
    fig = plt.figure(figsize=(10, 8))
    try:
        fig.set(tight_layout=True)
        ax = fig.add_subplot()

        scatter_z(z, y, ax, label='z', marker='+', alpha=0.5)
        predicted = scatter_z(zhat, y, ax, label='z_hat', marker='o', alpha=1.0)

        fig.colorbar(predicted, ax=ax)
        ax.set_title(f'Visualization of {d_name} CLIP Feature')
        ax.legend()
        fig.savefig(f'/mnt/data01/****/****/gnn/transfer_exp/plot_res/{exp_name}-{d_name}.png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, calling plot_zs once per dataset leaks figures.
        plt.close(fig)


def evaluate_model_performance(d_loader, net, n_method=2):
    """Score the embeddings ``net`` produces for every graph in ``d_loader``.

    The network is switched to eval mode (and left there) and run under
    ``torch.no_grad()``. For each ``(d_name, n_view_graphs)`` item an
    edge-only copy of the graph is created on ``net.device`` and passed to
    ``net`` together with the original graph.

    Args:
        d_loader: Iterable of ``(d_name, n_view_graphs)`` pairs; the graph
            must carry node labels in ``ndata['y']``.
        net: Model exposing ``device`` and callable as
            ``net(edge_only_graph, n_view_graphs)``.
        n_method: When greater than 1, output columns ``[:2]`` are scored as
            the t-SNE result and columns ``[2:]`` as the UMAP result.
            Otherwise columns ``[:2]`` are scored as the UMAP result and the
            t-SNE scores are reported as ``0.0``.

    Returns:
        ``(nmis_t, scs_t, d_names)``: two stacked tensors with one
        ``[tsne, umap]`` row per dataset, and the list of dataset names
        (first element of each ``d_name``).
    """
    net.eval()
    device = net.device

    nmi_rows, sc_rows, seen_names = [], [], []
    with torch.no_grad():
        for d_name, n_view_graphs in d_loader:
            # Same connectivity as the input graph, but without any features.
            edge_only_graph = dgl.graph(
                n_view_graphs.edges(),
                num_nodes=n_view_graphs.num_nodes(),
            ).to(device)

            z_hat = net(edge_only_graph, n_view_graphs)
            y = n_view_graphs.ndata['y']

            leading_pair = z_hat[:, :2].cpu().numpy()
            if n_method > 1:
                tsne_nmi, tsne_sc = visualization_metric.get_nmi_sc(leading_pair, y.tolist())
                trailing = z_hat[:, 2:].cpu().numpy()
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(trailing, y.tolist())
            else:
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(leading_pair, y.tolist())
                tsne_nmi, tsne_sc = 0.0, 0.0

            sc_rows.append(torch.Tensor([tsne_sc, umap_sc]))
            nmi_rows.append(torch.Tensor([tsne_nmi, umap_nmi]))
            seen_names.append(d_name[0])

    # The network is intentionally left in eval mode for the caller.
    return torch.stack(nmi_rows), torch.stack(sc_rows), seen_names


def get_gt_by_name(d_name, gt_dict):
    """Look up the ``(tsne, umap)`` ground-truth values of one dataset.

    ``gt_dict`` is keyed by 1-tuples, so ``d_name`` is wrapped before the
    lookup.
    """
    entry = gt_dict[(d_name,)]
    return entry['tsne'], entry['umap']


def get_gt_df(gt_save, d_names, prefix=''):
    """Collect ground-truth NMI / SC values into a DataFrame.

    Args:
        gt_save: Pair ``(nmi_gt_dict, sc_gt_dict)``. Each dict is keyed by the
            1-tuple ``(d_name,)`` (same convention as ``get_gt_by_name``) and
            maps to a dict with ``'tsne'`` and ``'umap'`` entries.
        d_names: Dataset names to look up, in output row order.
        prefix: Optional prefix for the four value columns.

    Returns:
        DataFrame with the columns ``d_name``, ``{prefix}gt_tsne_nmi``,
        ``{prefix}gt_umap_nmi``, ``{prefix}gt_tsne_sc`` and
        ``{prefix}gt_umap_sc``. The columns are present even when ``d_names``
        is empty, so a later merge on ``d_name`` still works.

    Raises:
        KeyError: If a dataset name is missing from either dict.
    """
    nmi_gt_dict, sc_gt_dict = gt_save

    tsne_nmi_col = f'{prefix}gt_tsne_nmi'
    umap_nmi_col = f'{prefix}gt_umap_nmi'
    tsne_sc_col = f'{prefix}gt_tsne_sc'
    umap_sc_col = f'{prefix}gt_umap_sc'

    # Declare every column up front: this fixes the column order, keeps the
    # schema for empty input, and lets rows be appended in O(1).
    columns = {
        'd_name': [],
        tsne_nmi_col: [],
        umap_nmi_col: [],
        tsne_sc_col: [],
        umap_sc_col: [],
    }
    for d_name in d_names:
        key = (d_name,)
        nmi_entry = nmi_gt_dict[key]
        sc_entry = sc_gt_dict[key]
        columns['d_name'].append(d_name)
        columns[tsne_nmi_col].append(nmi_entry['tsne'])
        columns[umap_nmi_col].append(nmi_entry['umap'])
        columns[tsne_sc_col].append(sc_entry['tsne'])
        columns[umap_sc_col].append(sc_entry['umap'])
    return pd.DataFrame(columns)


def get_eval_df(nmis_t, scs_t, d_names, prefix='pred_'):
    """Arrange per-dataset evaluation scores into a DataFrame.

    Args:
        nmis_t: 2-D array/tensor; column 0 holds the t-SNE NMI and column 1
            the UMAP NMI of each dataset.
        scs_t: 2-D array/tensor with the same layout for the SC scores.
        d_names: Dataset names, one per row.
        prefix: Prefix for the four score columns.

    Returns:
        DataFrame with the columns ``d_name``, ``{prefix}tsne_nmi``,
        ``{prefix}umap_nmi``, ``{prefix}tsne_sc`` and ``{prefix}umap_sc``.
    """
    columns = {
        'd_name': d_names,
        f'{prefix}tsne_nmi': nmis_t[:, 0],
        f'{prefix}umap_nmi': nmis_t[:, 1],
        f'{prefix}tsne_sc': scs_t[:, 0],
        f'{prefix}umap_sc': scs_t[:, 1],
    }
    return pd.DataFrame(columns)


def get_gt_pred_df_from_d(gt_save, d_names, d_loader, net, to_save_name, n_method):
    """Evaluate ``net``, join its scores with the ground truth and save a CSV.

    Args:
        gt_save: ``(nmi_gt_dict, sc_gt_dict)`` pair passed to ``get_gt_df``.
        d_names: Dataset names for which ground-truth rows are built.
        d_loader: Loader passed to ``evaluate_model_performance``.
        net: Model to evaluate; it is moved to CUDA first.
        to_save_name: Destination path of the CSV file.
        n_method: Forwarded to ``evaluate_model_performance``.

    Returns:
        The merged DataFrame (inner join on ``d_name``), which is also
        written to ``to_save_name``.
    """
    gt_df = get_gt_df(gt_save, d_names)

    net = net.cuda()
    nmis_t, scs_t, evaluated_names = evaluate_model_performance(d_loader, net, n_method=n_method)
    pred_df = get_eval_df(nmis_t, scs_t, evaluated_names)

    merged = pd.merge(gt_df, pred_df, on='d_name')
    merged.to_csv(to_save_name)
    return merged


def get_mean_performance(df):
    """Add relative precision / gap columns to ``df`` and summarise them.

    For each of the t-SNE / UMAP x NMI / SC combinations two columns are
    added to ``df`` in place:

    * ``relative_<method>_<metric>_precision`` = prediction / ground truth
    * ``relative_<method>_<metric>_gap`` = ground truth - prediction

    Args:
        df: DataFrame with a ``d_name`` column plus ``gt_*`` and ``pred_*``
            columns for every method/metric combination.

    Returns:
        ``(mean, std)``: column-wise mean and standard deviation over all
        columns except ``d_name``.
    """
    # Compute everything before touching df, so a missing input column
    # leaves the frame unmodified.
    derived = {}
    for metric in ('nmi', 'sc'):
        for method in ('tsne', 'umap'):
            predicted = df[f'pred_{method}_{metric}']
            reference = df[f'gt_{method}_{metric}']
            derived[f'relative_{method}_{metric}_precision'] = predicted / reference
            derived[f'relative_{method}_{metric}_gap'] = reference - predicted

    for column_name, values in derived.items():
        df[column_name] = values

    scores_only = df.drop('d_name', axis=1)
    return scores_only.mean(), scores_only.std()



from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)


def forward_from_x(x, net, precomputed_pes=None):
    """Build the multi-view graph for raw features ``x`` and run ``net`` on it.

    Steps: reset the global graphgym ``cfg`` and register the positional
    encoding options; compute the pairwise distance matrix of ``x``; for each
    sigma in ``[0.1, 0.5, 1, 2, 5]`` build the edge weights and the
    sign-flipped positional encoding; assemble one DGL graph carrying
    ``weight{j}`` edge data and ``pe{j}`` node data; move it to ``net.device``
    and run the network. Preprocessing and forward times are printed.

    Args:
        x: 2-D feature tensor, one row per point.
        net: Model exposing ``device`` and callable as
            ``net(edge_only_graph, n_view_graphs)``.
        precomputed_pes: Currently unused.

    Returns:
        The network output ``z_hat``.
    """
    set_cfg(cfg)
    posenc_config.set_cfg_posenc(cfg)

    prep_start = time.perf_counter()
    cdist = torch.cdist(x, x).to(torch.float32)
    N = cdist.shape[0]
    print('N', N)

    view_pes = []
    view_weights = []
    for sigma in (0.1, 0.5, 1, 2, 5):
        edge_ind, edge_weights = graph_cut.get_complet_graph(cdist, sigma=sigma)
        raw_pe = LargeCompleteMVGraphDatasets.extract_graph_pe(edge_ind, edge_weights, N)
        view_pes.append(flip_pe_sign_utils.filp_pe_sign(raw_pe, method='pos'))
        view_weights.append(edge_weights)
    d_time = time.perf_counter() - prep_start
    print(d_time)

    # One edge for every non-NaN entry of the distance matrix.
    src, dst = torch.where(~torch.isnan(cdist))

    n_view_graphs = dgl.graph((src, dst), num_nodes=N)
    for j, (weights, pe) in enumerate(zip(view_weights, view_pes)):
        n_view_graphs.edata[f'weight{j}'] = weights.to(torch.float32)
        n_view_graphs.ndata[f'pe{j}'] = pe
    edge_only_graph = dgl.graph(
        n_view_graphs.edges(),
        num_nodes=n_view_graphs.num_nodes(),
    ).to(net.device)
    n_view_graphs = n_view_graphs.to(net.device)

    forward_start = time.perf_counter()
    z_hat = net(edge_only_graph, n_view_graphs)
    t = time.perf_counter() - forward_start
    print(d_time, t)
    return z_hat


def solve_p(Z1, Z2):
    """Solve the orthogonal Procrustes problem that aligns ``Z2`` with ``Z1``.

    Finds the orthogonal matrix ``P`` minimising ``||Z2 @ P - Z1||_F``; the
    aligned points are obtained as ``Z2 @ P``.

    Args:
        Z1: Target points, shape ``(n, m)``.
        Z2: Points to be aligned, shape ``(n, m)``.

    Returns:
        The ``(m, m)`` orthogonal matrix ``P = U @ V^T``, where
        ``Z2^T @ Z1 = U @ diag(S) @ V^T``.
    """
    M = Z2.T @ Z1  # m x m
    # torch.linalg.svd returns V^H as its third output. The deprecated
    # torch.svd returns V itself, so multiplying U by its third output
    # yields U @ V, which is not the Procrustes solution.
    U, _, Vh = torch.linalg.svd(M)
    return U @ Vh


def knn_numpy(X, x, k=5):
    """
    Return the indices of the k rows of X closest to x (L2 distance).

    X : array of shape (n_samples, n_features)
    x : array of shape (n_features,)
    k : number of neighbors

    The result has shape (k,) and is ordered by increasing distance; only the
    indices are returned, not the distances.
    """
    offsets = X - x
    # Full sort of the per-row Euclidean norms, then keep the k nearest.
    ranking = np.argsort(np.linalg.norm(offsets, axis=1))
    return ranking[:k]

if __name__ == '__main__':
    # Script entry point: for every checkpoint listed in save_f_names, embed a
    # class subset of the features stored in cifar10_features_clip.tar chunk
    # by chunk (each chunk is prefixed with a fixed set of anchor points) and
    # print the scores returned by visualization_metric.get_nmi_sc for the
    # first two output dimensions.
    root = '/mnt/data01/****/****/gnn/res/'
    # Earlier checkpoints, kept for reference:
    # ckp_name = root + '500ds-GNN-clip_w_umap-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-59.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-bce-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-119.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-l2-subnet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-149.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-l2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-154.tar'
    # ckp_name = root + '500ds-GNN-clip_w_umap-bothl2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-219.tar'

    # Checkpoint file names (relative to root) to evaluate; only the
    # uncommented entries are run.
    save_f_names = [
        # '500ds-GNN-clip_w_umap-l2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-154.tar'
        # '500ds-GNN-clip_w_umap-umaponly-klkl-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-89.tar',
        #             '500ds-GNN-clip_w_umap-umaponly-knn_umapkl-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-164.tar',
                    # '500ds-GNN-clip_w_umap-umaponly-bce-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-14.tar',
                    # '500ds-GNN-clip_w_umap-umaponly-knn_bce2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-74.tar',
                    # '500ds-GNN-clip_w_umap-umaponly-l2-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-219.tar',
                    # '500ds-GNN-clip_w_umap-umaponly-knn_l2_dist-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-274.tar',
                    # '500ds-GNN-clip_w_umap-umaponly-knn_ce_ori-uninet-regin2-8gt-sigma_mv-run4-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-119.tar',
        # '500ds-GNN-clip_w_umap-umaponly-umap_knn_l2-new-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar'
        # '500ds-GNN-clip_w_umap-umaponly-umap_knn_l2-new-seqgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar'  # actually is paragt
        # '500ds-GNN-clip_w_umap-umaponly-l2-newparallelgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-84.tar'
        '500ds-GNN-clip_w_umap-umaponly-knn-l2-newparallelgt-regin2-8gt-sigma_mv-300epoch-kl_t_64dim_svdu_complete_G-500ds-epoch-94.tar'
                    ]
    # NOTE(review): indexed in parallel with save_f_names below, but the
    # looked-up loss_type is never used afterwards.
    loss_types = ['kl-student-t', 'kl-umap', 'bce', 'bce-knn', 'l2-umap', 'l2-dist', 'umap-orig', 'knn-l2', 'knn-l2']

    # The file holds a (features, labels) pair of numpy arrays (both are
    # converted with torch.from_numpy right after loading).
    input_x, input_y = torch.load(f"/home/****/autovisual/prepare_data/data/cifar10_features_clip.tar", weights_only=False)
    input_x = torch.from_numpy(input_x)
    input_y = torch.from_numpy(input_y)

    # Anchor candidates: the first 1000 samples whose label is 0.
    selected_indices = np.isin(input_y, (0, ))  # logical mask like [True, False, True, ...]
    anchor_x = input_x[selected_indices][:1000]
    anchor_y = input_y[selected_indices][:1000]

    # Alternative anchor choices that were tried (random points around one
    # sample, or the nearest neighbours of one sample):
    # anchor_x = torch.randn(1000, 512)*1 + input_x[selected_indices][0]
    # anchor_y = input_y[selected_indices][:1000]
    # knn_ind = knn_numpy(input_x, anchor_x[0], 1000)
    # anchor_x = input_x[knn_ind]
    # anchor_y = input_y[knn_ind]

    # Restrict the evaluated data to the samples with labels 0-3.
    selected_indices = np.isin(input_y, (0, 1, 2, 3,))  # logical mask like [True, False, True, ...]
    input_x = input_x[selected_indices]
    input_y = input_y[selected_indices]

    for ind, save_f_name in enumerate(save_f_names):
        ckp_name = root + save_f_name
        loss_type = loss_types[ind]

        net, save = get_model_ckp(ckp_name)

        # NOTE(review): `device` is not referenced by the active code below;
        # forward_from_x uses net.device instead.
        device = 'cuda:0'
        print(net.device)
        with torch.no_grad():
            # split input x
            # torch.chunk splits along dim 0 into 20 chunks. The original
            # note said "2500 points each chunk"; the actual size depends on
            # how many samples remain after the label filter -- TODO confirm.
            blocked_x = torch.chunk(input_x, 20)
            blocked_y = torch.chunk(input_y, 20)

            zhats = []
            ys = []
            # anchor_x = blocked_x[0][:100]
            # anchor_y = blocked_y[0][:100]
            # Filled from the first chunk's anchor outputs; currently only
            # consumed by the commented-out normalisation / alignment code.
            anchor_z = None
            mean_z = None
            std_z = None

            # _new_x = torch.cat(blocked_x[:3])
            # _y = torch.cat(blocked_y[:3])

            for i in range(0, len(blocked_x)):
                # Keep at most 10% of the first chunk's size as anchors.
                # Re-slicing on every iteration is a no-op after the first.
                anchor_x = anchor_x[:int(blocked_x[0].shape[0] * 0.1)]
                anchor_y = anchor_y[:int(blocked_x[0].shape[0]*0.1)]
                # The anchors are placed in front of the chunk, so the first
                # anchor_y.shape[0] output rows belong to the anchors.
                _new_x = torch.cat([anchor_x, blocked_x[i]])
                _y = torch.cat([anchor_y, blocked_y[i]])

                _zhat = forward_from_x(_new_x, net)
                if anchor_z is None:
                    anchor_z = _zhat[:anchor_y.shape[0]]
                    mean_z = anchor_z.mean(dim=0)
                    std_z = anchor_z.std(dim=0)
                # Disabled experiments: normalise with the anchor statistics,
                # or rotate each chunk onto the first chunk's anchors.
                # _zhat = (_zhat - mean_z) / (std_z + 1e-9)
                # p = solve_p(anchor_z, _zhat[:anchor_y.shape[0]])
                # _zhat = _zhat  @ p
                # print(p)
                # _zhat = (_zhat - _zhat.mean(dim=0)) / (_zhat.std(dim=0)+1e-9)

                # Score and collect only the chunk's own points (anchor rows
                # are dropped), using the first two output dimensions.
                print(_zhat[anchor_y.shape[0]:, :2].shape, _y[anchor_y.shape[0]:].shape)
                umap_nmi, umap_sc = visualization_metric.get_nmi_sc(_zhat[anchor_y.shape[0]:, :2].cpu().numpy(), _y[anchor_y.shape[0]:].numpy().tolist())
                print('nmi, sc', umap_nmi, umap_sc)
                zhats.append(_zhat[anchor_y.shape[0]:])
                ys.append(_y[anchor_y.shape[0]:])
                # if i == 3:
                #     break
            # Scores over the concatenation of all chunks.
            zhats = torch.cat(zhats)
            ys = torch.cat(ys)
            umap_nmi, umap_sc = visualization_metric.get_nmi_sc(zhats[:, :2].cpu().numpy(), ys.numpy().tolist())
            # NOTE(review): the output path is the same for every checkpoint,
            # so with several entries in save_f_names later results overwrite
            # earlier ones.
            torch.save((zhats, ys, umap_nmi, umap_sc), './cifar10-large2-zhat.tar')
            print('========')
            print('total nmi, sc', umap_nmi, umap_sc)
            print('========')


            # zhats = []
            # ys = []
            # for i in range(1, len(blocked_x)):
            #     # _new_x = torch.cat([blocked_x[0], blocked_x[i]])
            #     _new_x = blocked_x[i]
            #     _y = blocked_y[i]
            #     _zhat = forward_from_x(_new_x, net)
            #     _zhat = (_zhat - _zhat.mean(dim=0)) / (_zhat.std(dim=0) + 1e-9)
            #     umap_nmi, umap_sc = visualization_metric.get_nmi_sc(_zhat[:, :2].cpu().numpy(), _y.numpy().tolist())
            #     print('nmi, sc', umap_nmi, umap_sc)
            #     zhats.append(_zhat)
            #     ys.append(_y)
            #     if i == 3:
            #         break
            #
            # zhats = torch.cat(zhats)
            # ys = torch.cat(ys)
            # umap_nmi, umap_sc = visualization_metric.get_nmi_sc(zhats[:, :2].cpu().numpy(), ys.numpy().tolist())
            # print('total nmi, sc', umap_nmi, umap_sc)
            # i = 0
            # for d_name, n_view_graphs in test_d_loader:
            #     edge_only_graph = dgl.graph(n_view_graphs.edges(), num_nodes=n_view_graphs.num_nodes()).to(device)
            #     n_view_graphs = n_view_graphs.to(device)
            #     # edge_only_graph.edata["_edge_weight"] = edge_only_graph.edata["edge_weight"].to(device)
            #     # n_view_graphs.edata["_edge_weight"] = n_view_graphs.edata["_edge_weight"].to(device)
            #     for j in range(5):
            #         n_view_graphs.ndata[f'pe{j}'] = n_view_graphs.ndata[f'pe{j}'].to(device)
            #         n_view_graphs.edata[f'weight{j}'] = n_view_graphs.edata[f'weight{j}'].to(device)
            #     time_start = time.perf_counter()
            #     z_hat = net(edge_only_graph, n_view_graphs)
            #     t = time.perf_counter() - time_start
            #     used_times.append(t)
            #     print(t)
            #     i = i + 1
            #     if i == 10:
            #         break

        # used_times = np.array(used_times)
        # torch.save(used_times, './autodv_w_pe_used_times.save')
        # print(used_times.mean(), used_times.std())
