import os.path

from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import dgl
import torch
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

import numpy as np

# import graph_cut
from . import graph_cut

from ..pe_utils import posenc_stats
from ..pe_utils import posenc_config
from ..pe_utils import flip_pe_sign_utils


def extract_graph_pe(edge_index, weight, N):
    """Compute the positional encoding of a weighted graph.

    The graph is wrapped in a ``torch_geometric`` ``Data`` object and handed to
    ``posenc_stats.compute_posenc_stats`` together with the global GraphGym
    ``cfg``. For every requested encoding type the attribute
    ``<type>_pos_enc`` is read back from the result and all of them are
    concatenated along dimension 1.

    Args:
        edge_index: Edge indices forwarded as ``Data.edge_index``
            (presumably a ``(2, E)`` long tensor -- confirm against caller).
        weight: Edge weights forwarded as ``Data.edge_weight``.
        N: Number of nodes of the graph.

    Returns:
        Tensor holding the concatenated encodings (with the single active
        type this is just the ``SVD_pos_enc`` tensor).
    """
    # Only the SVD encoding is active. Types tried in earlier experiments:
    # LapPE, NMFPE, EquivStableLapPE, SignNet, RWSE, HKdiagSE, ElstaticSE,
    # RRWP, PPR, WLPE, GCKN, RWDIFF.
    pe_types = ['SVD']

    pyg_graph = Data(edge_index=edge_index, num_nodes=N, edge_weight=weight)
    enriched = posenc_stats.compute_posenc_stats(
        pyg_graph, pe_types=pe_types, is_undirected=True, cfg=cfg)

    per_type = [getattr(enriched, f'{name}_pos_enc') for name in pe_types]
    return torch.cat(per_type, dim=1)


# Step 1: Define a Multi-Graph Dataset
class DatasetGraphDataset(DGLDataset):
    """Dataset whose items are ``(name, graph)`` pairs, one graph per entry of ``data_names``.

    For every name a distance matrix is loaded from ``cdist_path``; a DGL graph
    is built with one edge per non-NaN entry of that matrix. For each sigma in
    ``_SIGMAS`` (index ``k``) the graph receives edge weights ``weight{k}``
    from ``graph_cut.get_complet_graph`` and node features ``pe{k}`` loaded
    from precomputed positional-encoding files. The node features ``z``
    (embedding loaded from ``visual_path``) and ``y`` (labels stored next to
    the distance matrix) are attached as well.

    Args:
        data_names: Names of the datasets to expose.
        cdist_path: Root directory of the saved distance matrices.
        visual_path: Root directory of the saved ``*_selected_emb.tar`` files.
        precomputed_pe_path: Root directory of precomputed positional
            encodings. When ``None`` the encodings are computed on
            construction (missing files only) and stored under
            ``_DEFAULT_PE_ROOT``, which is then also used for reading.
        z_cali_method: ``'none'``, ``'svd'`` or ``'orth'``; see ``calibrate_z``.
        z_anchor: Reference embedding for the ``'orth'`` calibration.
        normalize_z, z_mu, z_std: Stored on the instance; not used by any
            method of this class.
        flip_sign_method: Forwarded to ``flip_pe_sign_utils.filp_pe_sign``.
    """

    # Directory ``precompute`` writes to; also the read location when no
    # ``precomputed_pe_path`` was given.
    _DEFAULT_PE_ROOT = '/mnt/data01/public/aad_data/pe'
    # Values passed as ``sigma`` to ``graph_cut.get_complet_graph``; one graph
    # view per value. The ints stay ints because they are part of file names.
    _SIGMAS = (0.1, 0.5, 1, 2, 5)

    def __init__(self, data_names, cdist_path, visual_path, precomputed_pe_path=None, z_cali_method='none', z_anchor=None, normalize_z=True,
                 z_mu=None, z_std=None, flip_sign_method='none'):

        self.data_names = data_names
        self.cdist_path = cdist_path
        self.visual_path = visual_path
        self.precomputed_pe_path = precomputed_pe_path

        self.method = 'TSNE'
        self.z_anchor = z_anchor
        self.z_cali_method = z_cali_method

        # Containers kept for compatibility; samples are loaded lazily in
        # ``__getitem__`` and nothing in this class fills these lists.
        self.graphs = []
        self.pes = []
        self.dists = []
        self.y = []
        self.z = []
        self.zdists = []
        self.normalize_z = normalize_z

        self.z_mu = z_mu
        self.z_std = z_std
        super().__init__(name='distance_graph')

        self.flip_sign_method = flip_sign_method

        # Suffix appended to the distance-matrix file name.
        self.cdist_subfix = ''

        # Suffix identifying the positional-encoding variant on disk.
        self.pe_type_subfix = 'complete_pe_64freq_svd_u_torchf_only'

        self.ds_block_sizes = self.get_block_size()
        if self.precomputed_pe_path is None:
            self.precompute()

    @staticmethod
    def _group_name(d_name):
        """Return ``d_name`` without its last two ``'-'``-separated fields."""
        return '-'.join(d_name.split('-')[:-2])

    def _pe_root(self):
        """Return the directory positional encodings are read from."""
        if self.precomputed_pe_path is None:
            # Nothing was supplied, so the encodings are the ones written by
            # ``precompute``.
            return self._DEFAULT_PE_ROOT
        return self.precomputed_pe_path

    def _pe_file(self, d_name, sigma):
        """Return the path of the encoding file for ``d_name`` and ``sigma``."""
        return self._pe_root() + '/' + f'{self._group_name(d_name)}/{d_name}_{sigma}_{self.pe_type_subfix}.tar'

    def precompute(self):
        """Compute and store every positional-encoding file that is missing.

        Resets the global GraphGym ``cfg`` and registers the positional
        encoding options on it. Does nothing further when a
        ``precomputed_pe_path`` was supplied. Existing files are never
        overwritten.
        """
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)
        if self.precomputed_pe_path is not None:
            return

        for d_name in self.data_names:
            todo = []
            for sigma in self._SIGMAS:
                f_name = self._pe_file(d_name, sigma)
                if not os.path.exists(f_name):
                    todo.append((sigma, f_name))
            if not todo:
                # Everything is cached; skip loading the distance matrix.
                continue

            # NOTE(review): this reads '<cdist_path>/<d_name>...' whereas
            # ``load_dgl_smaple`` reads '<cdist_path>/<group>/<d_name>...' --
            # confirm which layout is intended.
            clip_save = torch.load(self.cdist_path + f'/{d_name}{self.cdist_subfix}.tar')
            cdist = torch.from_numpy(clip_save[0])
            N = cdist.shape[0]

            for sigma, f_name in todo:
                edge_ind, edge_weights = graph_cut.get_complet_graph(cdist, sigma=sigma)
                graph_pe = extract_graph_pe(edge_ind, edge_weights, N)
                print(graph_pe.shape)
                # torch.save does not create missing parent directories.
                os.makedirs(os.path.dirname(f_name), exist_ok=True)
                torch.save(graph_pe, f_name)

    def process(self):
        """Required by ``DGLDataset``; samples are loaded lazily instead."""
        pass

    def get_block_size(self):
        """Return the number of rows of each dataset's saved embedding.

        Returns:
            List with one ``selected_emb.shape[0]`` per entry of
            ``data_names``, in the same order.
        """
        block_sizes = []
        for d_name in self.data_names:
            # NOTE(review): the sub-directory here is the first '-' field,
            # while ``load_dgl_smaple`` drops the last two fields when reading
            # the same file pattern. They agree only for three-field names --
            # confirm.
            dataset_name = d_name.split('-')[0]
            selected_emb, hps = torch.load(
                self.visual_path + '/' + f'{dataset_name}/visual-method-{self.method}_dataset-{d_name}_selected_emb.tar')
            block_sizes.append(selected_emb.shape[0])
        return block_sizes

    def load_dgl_smaple(self, d_name):
        """Load one dataset as a DGL graph carrying all views.

        Args:
            d_name: Dataset name (an element of ``data_names``).

        Returns:
            ``dgl`` graph with edge data ``weight0..weight4`` and node data
            ``pe0..pe4``, ``z`` (float32) and ``y`` (int64).
        """
        dataset_name = self._group_name(d_name)
        clip_save = torch.load(self.cdist_path + '/' + f'{dataset_name}/{d_name}{self.cdist_subfix}.tar')
        cdist, y, _ = clip_save[:3]
        cdist = torch.from_numpy(cdist)

        # One edge per entry of the distance matrix that is not NaN.
        src, dst = torch.where(~torch.isnan(cdist))
        graph = dgl.graph((src, dst), num_nodes=cdist.shape[0])

        # for each dataset, span 5 views
        for k, sigma in enumerate(self._SIGMAS):
            _, edge_weights = graph_cut.get_complet_graph(cdist, sigma=sigma)
            graph.edata[f'weight{k}'] = edge_weights.to(torch.float32)

            graph_pe = torch.load(self._pe_file(d_name, sigma))
            graph_pe = flip_pe_sign_utils.filp_pe_sign(graph_pe, method=self.flip_sign_method)
            graph.ndata[f'pe{k}'] = graph_pe

        selected_emb, hps = torch.load(
            self.visual_path + '/' + f'{dataset_name}/visual-method-{self.method}_dataset-{d_name}_selected_emb.tar')

        graph.ndata['z'] = torch.from_numpy(selected_emb).to(torch.float32)
        graph.ndata['y'] = torch.tensor(y).long()

        return graph

    # Correctly spelled alias; the original name is kept for existing callers.
    load_dgl_sample = load_dgl_smaple

    def load_sample(self, dataset):
        """Load one dataset in the older tuple format (flat file layout).

        Not used by ``__getitem__``. Reads encodings directly from
        ``precomputed_pe_path``, which therefore must be set.

        Returns:
            ``(graph_pes, graph_views, selected_emb, zdist, y)`` where the
            first two are lists with one entry per sigma, ``selected_emb`` is
            the calibrated embedding and ``zdist`` its pairwise distances.
        """
        clip_save = torch.load(self.cdist_path + '/' + f'{dataset}_clip_cdist_3000.tar')
        cdist, x, y = clip_save[:3]
        cdist = torch.from_numpy(cdist)
        graph_views = []
        graph_pes = []
        for sigma in self._SIGMAS:
            complete_graph = graph_cut.get_complet_graph(cdist, sigma=sigma)

            graph_pe = torch.load(
                self.precomputed_pe_path + '/' + f'{dataset}_{sigma}_{self.pe_type_subfix}.tar')
            graph_pe = flip_pe_sign_utils.filp_pe_sign(graph_pe, method=self.flip_sign_method)

            graph_pes.append(graph_pe)
            graph_views.append(complete_graph)

        selected_emb, hps = torch.load(
            self.visual_path + '/' + f'visual-method-{self.method}_dataset-{dataset}_selected_emb.tar')
        selected_emb = self.calibrate_z(selected_emb, z_anchor=self.z_anchor)

        z = torch.from_numpy(selected_emb).to(torch.float32)
        zdist = torch.cdist(z, z)
        y = y.long()

        return graph_pes, graph_views, selected_emb, zdist, y

    def __getitem__(self, idx):
        """Return ``(name, graph)`` for the ``idx``-th dataset."""
        dataset = self.data_names[idx]
        return dataset, self.load_dgl_smaple(dataset)

    def __len__(self):
        """Return the number of datasets."""
        return len(self.data_names)

    def calibrate_z(self, z, z_anchor, n_dim=2):
        """Calibrate an embedding according to ``self.z_cali_method``.

        Args:
            z: Embedding as a 2-D numpy array (rows are points).
            z_anchor: Reference embedding; required for ``'orth'`` only.
            n_dim: Number of leading singular directions kept by ``'svd'``.

        Returns:
            ``z`` unchanged for ``'none'``; ``U[:, :n_dim] * S[:n_dim]`` of
            the SVD of ``z`` for ``'svd'``; ``z @ R`` for ``'orth'``, where
            ``R = U @ Vt`` comes from the SVD of the outer product of the
            column means of ``z`` and ``z_anchor``.

        Raises:
            ValueError: ``'orth'`` was requested without ``z_anchor``.
            NotImplementedError: Unknown calibration method.
        """
        if self.z_cali_method == 'svd':
            U, S, Vt = np.linalg.svd(z)
            return np.dot(U[:, :n_dim], np.diag(S[:n_dim]))
        elif self.z_cali_method == 'orth':
            if z_anchor is None:
                raise ValueError("z_cali_method='orth' requires z_anchor")
            # Outer product of the two mean vectors; -1 lets the embedding
            # width be inferred instead of being fixed to 2.
            m = np.matmul(np.mean(z, axis=0).reshape(-1, 1), np.mean(z_anchor, axis=0).reshape(1, -1))
            u, s, vt = np.linalg.svd(m)
            r = np.matmul(u, vt)
            return z @ r
        elif self.z_cali_method == 'none':
            return z
        else:
            raise NotImplementedError

def main():
    """Smoke test: build the dataset, fetch one batch and batch its graphs.

    Requires the distance, embedding and positional-encoding files of the
    hard-coded dataset name to be present on disk.
    """
    from functools import partial

    ds = DatasetGraphDataset(data_names=['cifar10_1class_comb0_seed0'], cdist_path='features',
                             visual_path='res-2')

    # Each item is (name, graph); the collate function transposes the batch
    # into (names, graphs) without trying to stack the graphs.
    get_loader = partial(torch.utils.data.DataLoader, batch_size=1,
                         shuffle=True,
                         num_workers=0,
                         collate_fn=lambda x: list(zip(*x))
                         )
    train_loader = get_loader(ds)

    _names, graphs = next(iter(train_loader))
    # dgl.batch expects a list of graphs, not a single graph; reuse the
    # graphs just loaded instead of reading the sample from disk again.
    batch_graph = dgl.batch(list(graphs))  # for testing

# Run the smoke test only when executed as a script, never on import.
# NOTE: this module uses relative imports, so it has to be started as a
# package module (``python -m <package>.<module>``), not as a plain file.
if __name__ == '__main__':
    main()
