import os.path

from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import dgl
import torch
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

import numpy as np

# import graph_cut
from . import graph_cut

from ..pe_utils import posenc_stats
from ..pe_utils import posenc_config


def extract_graph_pe(data):
    """Compute positional encodings for a graph and concatenate them.

    The input graph is converted to a PyTorch Geometric ``Data`` object
    (edge index, node count, and the ``'weight'`` edge feature passed as
    ``edge_weight``). That object is handed to
    ``posenc_stats.compute_posenc_stats`` together with the global GraphGym
    ``cfg``; every ``<PE>_pos_enc`` attribute it produces for the enabled PE
    types is then concatenated along dim 1.

    Args:
        data: Graph object exposing ``edges()``, ``num_nodes()`` and
            ``edata['weight']`` (a DGL graph in this file's usage).

    Returns:
        torch.Tensor: Encodings of all enabled PE types, concatenated along
        dim 1 in the order of the ``pes`` list.
    """
    # Enabled positional-encoding types. Names that were toggled here during
    # earlier experiments: NMFPE, EquivStableLapPE, SignNet, RWSE, HKdiagSE,
    # ElstaticSE, RRWP, SVD, PPR, WLPE, GCKN, RWDIFF.
    # NOTE(review): whether posenc_stats supports each of them is not visible
    # from this file -- verify before enabling.
    pes = ['LapPE']

    # Stack the (src, dst) endpoint tensors directly. Building a tensor from a
    # list of NumPy arrays (the previous approach) is PyTorch's slow path and
    # required a second edges() call plus a tensor -> NumPy -> tensor round trip.
    src, dst = data.edges()
    edge_index = torch.stack((src, dst)).to(torch.long)

    pyg_graph = Data(edge_index=edge_index, num_nodes=data.num_nodes(),
                     edge_weight=data.edata['weight'])

    emb_data = posenc_stats.compute_posenc_stats(pyg_graph, pe_types=pes, is_undirected=True, cfg=cfg)

    return torch.cat([getattr(emb_data, f'{pe}_pos_enc') for pe in pes], dim=1)


# Step 1: Define a Multi-Graph Dataset
class DatasetGraphDataset(DGLDataset):
    """Lazily loaded multi-view graph dataset, one sample per dataset name.

    For every name in ``data_names`` a sample consists of one graph per value
    in ``SIGMAS`` (built by ``graph_cut.get_complet_graph`` from the stored
    pairwise-distance matrix) together with the positional encoding cached on
    disk for that graph. Samples are materialised on access by
    ``load_sample``; nothing but file names is kept in memory.

    When no ``precomputed_pe_path`` is supplied, the constructor computes any
    missing encodings and caches them under ``DEFAULT_PE_DIR``.

    Args:
        data_names: Iterable of dataset names; each must have a
            ``<name>_clip_cdist_3000.tar`` file in ``cdist_path``.
        cdist_path: Directory holding the distance files.
        visual_path: Directory of visual embeddings. Stored only; loading of
            visual embeddings is currently disabled in ``load_sample``.
        precomputed_pe_path: Directory with cached positional encodings, or
            ``None`` to compute/cache them under ``DEFAULT_PE_DIR``.
        z_cali_method: One of ``'none'``, ``'svd'``, ``'orth'``; see
            ``calibrate_z``.
        z_anchor: Reference embedding used by the ``'orth'`` calibration.
        normalize_z: Stored only; not used by the lazy loading path.
        z_mu: Stored only; not used by the lazy loading path.
        z_std: Stored only; not used by the lazy loading path.
    """

    # One graph view (and one cached encoding file) is produced per value.
    SIGMAS = (0.1, 0.5, 1, 2, 5)
    # Cache directory used when the caller does not supply one.
    DEFAULT_PE_DIR = 'pe_data/pe_for_gat'

    def __init__(self, data_names, cdist_path, visual_path, precomputed_pe_path=None, z_cali_method='none', z_anchor=None, normalize_z=True,
                 z_mu=None, z_std=None):

        self.data_names = data_names
        self.cdist_path = cdist_path
        self.visual_path = visual_path
        self.precomputed_pe_path = precomputed_pe_path

        self.method = 'TSNE'
        self.z_anchor = z_anchor
        self.z_cali_method = z_cali_method

        # Containers from the former eager-loading implementation. They are
        # kept so existing attribute access keeps working, but the lazy
        # loading path never fills them.
        self.graphs = []
        self.pes = []
        self.dists = []
        self.y = []
        self.z = []
        self.zdists = []
        self.normalize_z = normalize_z
        self.ds_block_sizes = []

        self.z_mu = z_mu
        self.z_std = z_std

        # Suffix identifying the positional-encoding variant in cache file
        # names. Assigned before super().__init__() because the base class
        # constructor may already invoke process().
        self.pe_type_subfix = 'complete_pe_64freq_lap_with_s_only'

        super().__init__(name='distance_graph')

        if self.precomputed_pe_path is None:
            self.precompute()

    def _pe_file(self, dataset, sigma):
        """Return the cache file path of the encoding for (dataset, sigma)."""
        return os.path.join(self.precomputed_pe_path,
                            f'{dataset}_{sigma}_{self.pe_type_subfix}.tar')

    def _load_cdist(self, dataset):
        """Load the distance file of ``dataset`` and return ``(cdist, y)``.

        The stored object's first three entries are ``(cdist, x, y)``; ``x`` is
        not needed here.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # these files must come from a trusted source.
        clip_save = torch.load(
            os.path.join(self.cdist_path, f'{dataset}_clip_cdist_3000.tar'),
            weights_only=False)
        cdist, _, y = clip_save[:3]
        return cdist, y

    def precompute(self):
        """Compute and cache every positional encoding that is not on disk yet.

        Encodings are written to ``self.precomputed_pe_path``, which is set to
        ``DEFAULT_PE_DIR`` first if it is ``None``. Existing files are left
        untouched, so the method can be re-run to fill in gaps.
        """
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)

        # Resolve the cache directory ONCE, before looping. Checking
        # `precomputed_pe_path is None` inside the loop (as was done before)
        # turned false after the first iteration and silently skipped every
        # remaining sigma and dataset.
        if self.precomputed_pe_path is None:
            self.precomputed_pe_path = self.DEFAULT_PE_DIR
        os.makedirs(self.precomputed_pe_path, exist_ok=True)

        for dataset in self.data_names:
            print(dataset)
            missing_sigmas = [sigma for sigma in self.SIGMAS
                              if not os.path.exists(self._pe_file(dataset, sigma))]
            if not missing_sigmas:
                # Everything is cached; skip loading the distance matrix.
                continue

            cdist, _ = self._load_cdist(dataset)
            cdist = torch.from_numpy(cdist)
            for sigma in missing_sigmas:
                complete_graph = graph_cut.get_complet_graph(cdist, sigma=sigma)
                graph_pe = extract_graph_pe(complete_graph)
                print(graph_pe.shape)
                torch.save(graph_pe, self._pe_file(dataset, sigma))

    def process(self):
        """Initialise the GraphGym config with the positional-encoding options.

        The earlier implementation loaded all graphs, encodings and targets
        eagerly in this hook; that work now happens per sample in
        ``load_sample``.
        """
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)

    def load_sample(self, dataset):
        """Build the sample for one dataset name.

        Args:
            dataset: Dataset name (an element of ``data_names``).

        Returns:
            tuple: ``(graph_pes, graph_views, selected_emb, zdist, y)`` where
            ``graph_pes`` and ``graph_views`` are lists with one entry per
            value in ``SIGMAS``, ``y`` is the third entry of the distance
            file, and ``selected_emb`` / ``zdist`` are always ``None`` because
            loading of visual embeddings is disabled.
        """
        cdist, y = self._load_cdist(dataset)
        cdist = torch.from_numpy(cdist)

        graph_views = []
        graph_pes = []
        for sigma in self.SIGMAS:
            graph_views.append(graph_cut.get_complet_graph(cdist, sigma=sigma))
            graph_pes.append(torch.load(self._pe_file(dataset, sigma), weights_only=True))

        # Visual embeddings (and their pairwise distances) are deliberately
        # not loaded; the slots are kept so the tuple layout stays stable.
        selected_emb = None
        zdist = None

        return graph_pes, graph_views, selected_emb, zdist, y

    def __getitem__(self, idx):
        """Return the lazily loaded sample of ``data_names[idx]``."""
        return self.load_sample(self.data_names[idx])

    def __len__(self):
        """Return the number of dataset names."""
        return len(self.data_names)

    def calibrate_z(self, z, z_anchor, n_dim=2):
        """Calibrate an embedding according to ``self.z_cali_method``.

        * ``'none'`` -- return ``z`` unchanged.
        * ``'svd'``  -- return the first ``n_dim`` left singular vectors of
          ``z`` scaled by their singular values.
        * ``'orth'`` -- multiply ``z`` by the orthogonal matrix ``u @ vt``
          taken from the SVD of the outer product of the column means of
          ``z`` and ``z_anchor`` (``n_dim`` is not used). ``z`` and
          ``z_anchor`` must have the same number of columns.

        Args:
            z: 2-D NumPy array, one embedding per row.
            z_anchor: Reference embedding; required for ``'orth'``.
            n_dim: Number of components kept by ``'svd'``.

        Returns:
            numpy.ndarray: The calibrated embedding.

        Raises:
            ValueError: If ``'orth'`` is selected and ``z_anchor`` is ``None``.
            NotImplementedError: For an unknown ``z_cali_method``.
        """
        if self.z_cali_method == 'none':
            return z

        if self.z_cali_method == 'svd':
            # The thin SVD is enough: only the leading n_dim columns are used,
            # so the full N x N left factor never has to be materialised.
            u, s, _ = np.linalg.svd(z, full_matrices=False)
            return u[:, :n_dim] * s[:n_dim]

        if self.z_cali_method == 'orth':
            if z_anchor is None:
                raise ValueError("z_cali_method='orth' requires z_anchor")
            # Outer product of the two mean vectors; works for any embedding
            # width instead of the previously hard-coded 2.
            m = np.outer(np.mean(z, axis=0), np.mean(z_anchor, axis=0))
            u, _, vt = np.linalg.svd(m)
            return z @ (u @ vt)

        raise NotImplementedError(
            f'unknown z_cali_method: {self.z_cali_method!r}')

def main():
    """Smoke test: build the dataset, pull one shuffled batch, batch the views.

    Constructs a ``DatasetGraphDataset`` for a single dataset name, fetches
    one batch through a ``DataLoader``, and merges the graph views of the
    first sample with ``dgl.batch``.
    """
    dataset = DatasetGraphDataset(
        data_names=['cifar10_1class_comb0_seed0'],
        cdist_path='features',
        visual_path='res-2',
    )

    def transpose_samples(samples):
        # Turn a list of per-sample tuples into a list of per-field tuples.
        return list(zip(*samples))

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=transpose_samples,
    )

    # Fetch a single batch to exercise the loading path end to end.
    first_batch = next(iter(loader))
    # Element 1 of a sample is its list of graph views.
    batched_views = dgl.batch(dataset[0][1])  # for testing

# Script entry point: run the smoke test in main().
# NOTE(review): this module uses package-relative imports
# (``from . import graph_cut``, ``from ..pe_utils import ...``), so it
# presumably has to be launched as a module inside its package
# (``python -m <package>.<module>``) rather than as a bare file -- confirm.
if __name__ == '__main__':
    main()
