import os.path

from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import dgl
import torch
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

import numpy as np

# import graph_cut
from . import graph_cut

from ..pe_utils import posenc_stats
from ..pe_utils import posenc_config
from ..pe_utils import flip_pe_sign_utils


def extract_graph_pe(edge_index, weight, N, pe_types=None):
    """Compute positional encodings for one weighted graph.

    The graph is wrapped in a ``torch_geometric.data.Data`` object and handed to
    ``posenc_stats.compute_posenc_stats`` (with ``is_undirected=True`` and the
    global graphgym ``cfg``). For every requested encoding ``pe`` the result's
    ``<pe>_pos_enc`` attribute is collected.

    Args:
        edge_index: edge indices, stored as ``Data.edge_index``.
        weight: per-edge weights, stored as ``Data.edge_weight``.
        N: number of nodes.
        pe_types: names of the encodings to compute, in output order.
            Defaults to ``['SVD']``. Which names are supported is decided by
            ``posenc_stats`` (not visible here).

    Returns:
        The collected encodings concatenated along dim 1.

    Raises:
        ValueError: if ``pe_types`` is empty (there would be nothing to concatenate).
    """
    pe_types = ['SVD'] if pe_types is None else list(pe_types)
    if not pe_types:
        raise ValueError('pe_types must name at least one positional encoding')

    data = Data(edge_index=edge_index, num_nodes=N, edge_weight=weight)
    emb_data = posenc_stats.compute_posenc_stats(data, pe_types=pe_types, is_undirected=True, cfg=cfg)

    return torch.cat([getattr(emb_data, f'{pe}_pos_enc') for pe in pe_types], dim=1)


# Step 1: Define a Multi-Graph Dataset
class DatasetGraphDataset(DGLDataset):
    """DGL dataset yielding one distance graph per name in ``data_names``.

    Each item is assembled from files on disk:

    * ``<cdist_path>/<dir>/<d_name><cdist_subfix>.tar``: its first three entries
      are ``(cdist, y, _)``. Every non-NaN entry of ``cdist`` becomes an edge and
      ``y`` becomes ``ndata['y']``.
    * Unless ``is4gt`` is set, each ``sigma`` in ``SIGMAS`` contributes the edge
      weights returned by ``graph_cut.get_complet_graph`` (``edata['weight<k>']``)
      and a precomputed positional encoding (``ndata['pe<k>']``).
    * Outside test mode, TSNE and UMAP embeddings are attached as ``ndata['z']``
      and ``ndata['z_umap']``.

    ``<dir>`` is derived from ``d_name`` by :meth:`_dataset_dir`.
    """

    # Sigma values passed to graph_cut.get_complet_graph; each yields one weighted view.
    SIGMAS = (0.1, 0.5, 1, 2, 5)
    # Default output root used by precompute().
    DEFAULT_PE_SAVE_PATH = '/mnt/data01/public/aad_data/pe'

    def __init__(self, exp_name, data_names, cdist_path, visual_path, visual_path_umap,
                 precomputed_pe_path=None,
                 is_test=False, is4gt=False,
                 z_cali_method='none', z_anchor=None, normalize_z=True,
                 z_mu=None, z_std=None,
                 flip_sign_method='none',
                 pe_save_path=DEFAULT_PE_SAVE_PATH):
        """
        Args:
            exp_name: experiment name. ``'uci'`` maps every dataset to the ``uci``
                sub-directory.
            data_names: dataset names; one item per name.
            cdist_path: root directory of the distance-matrix files.
            visual_path: root directory of the TSNE embedding files.
            visual_path_umap: root directory of the UMAP embedding files.
            precomputed_pe_path: root directory of precomputed positional
                encodings. When ``None``, :meth:`precompute` runs during
                construction and writes to ``pe_save_path``.
            is_test: if true, items are built by :meth:`load_dgl_sample` (no embeddings).
            is4gt: if true, items carry no edge weights or positional encodings.
            z_cali_method: ``'none'``, ``'svd'`` or ``'orth'``; see :meth:`calibrate_z`.
            z_anchor: reference embedding for ``'orth'`` calibration.
            normalize_z: standardize each embedding column before storing it.
            z_mu, z_std: stored on the instance; not used by this class.
            flip_sign_method: forwarded to ``flip_pe_sign_utils.filp_pe_sign``.
            pe_save_path: output root for :meth:`precompute`.
        """
        self.exp_name = exp_name
        self.is4gt = is4gt
        self.data_names = data_names
        self.cdist_path = cdist_path
        self.visual_path = visual_path
        self.visual_path_umap = visual_path_umap
        self.precomputed_pe_path = precomputed_pe_path
        self.pe_save_path = pe_save_path
        self.is_test = is_test

        self.method = 'TSNE'
        self.z_anchor = z_anchor
        self.z_cali_method = z_cali_method
        self.normalize_z = normalize_z
        self.z_mu = z_mu
        self.z_std = z_std
        self.flip_sign_method = flip_sign_method

        # Kept for callers that may inspect them; not filled by this class.
        self.graphs = []
        self.pes = []
        self.dists = []
        self.y = []
        self.z = []
        self.zdists = []

        self.cdist_subfix = ''
        # Alternatives used in earlier runs: complete_pe_64freq_{nmf_only,
        # lap_nonorm_only2, lap_only, lap_with_s_only, randsvd_only}.
        self.pe_type_subfix = 'complete_pe_64freq_svd_u_torchf_only'

        # Initialise the DGLDataset base only after every attribute exists, so
        # any hook the base constructor invokes sees a fully configured object.
        super().__init__(name='distance_graph')

        if self.precomputed_pe_path is None:
            self.precompute()

    # ------------------------------------------------------------------ paths

    def _dataset_dir(self, d_name):
        """Return the sub-directory holding ``d_name``'s files.

        Drops the last two ``'-'``-separated fields of ``d_name``. Every dataset
        of the ``'uci'`` experiment lives under ``uci``.
        """
        if self.exp_name == 'uci':
            return 'uci'
        return '-'.join(d_name.split('-')[:-2])

    def _pe_file(self, root, dataset_name, d_name, sigma):
        """Path of the positional-encoding file for one (dataset, sigma) view."""
        return root + '/' + f'{dataset_name}/{d_name}_{sigma}_{self.pe_type_subfix}.tar'

    def _load_cdist_and_labels(self, d_name):
        """Load ``(cdist, y)`` from the first entries of ``d_name``'s cdist file."""
        clip_save = torch.load(
            self.cdist_path + '/' + f'{self._dataset_dir(d_name)}/{d_name}{self.cdist_subfix}.tar',
            map_location='cpu')
        cdist, y, _ = clip_save[:3]
        return cdist, y

    def _load_selected_emb(self, root, method, d_name):
        """Load the ``(selected_emb, hps)`` pair saved for ``method`` on ``d_name``."""
        return torch.load(
            root + '/' + f'{self._dataset_dir(d_name)}/visual-method-{method}_dataset-{d_name}_selected_emb.tar')

    def _prepare_z(self, selected_emb):
        """Convert an embedding array to float32 and optionally standardize its columns."""
        z = torch.from_numpy(selected_emb).to(torch.float32)
        if self.normalize_z:
            # Epsilon guards constant columns against division by zero.
            z = (z - torch.mean(z, dim=0)) / (torch.std(z, dim=0) + 1e-9)
        return z

    # ------------------------------------------------------------- precompute

    def precompute(self):
        """Compute and save positional encodings for every dataset/sigma view.

        Existing output files are left untouched. If all views of a dataset
        already exist, its distance matrix is not even loaded.
        """
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)
        for d_name in self.data_names:
            dataset_name = self._dataset_dir(d_name)
            missing = []
            for sigma in self.SIGMAS:
                f_name = self._pe_file(self.pe_save_path, dataset_name, d_name, sigma)
                if not os.path.exists(f_name):
                    missing.append((sigma, f_name))
            if not missing:
                continue

            cdist, _ = self._load_cdist_and_labels(d_name)
            cdist_t = torch.from_numpy(cdist)
            N = cdist.shape[0]
            for sigma, f_name in missing:
                edge_ind, edge_weights = graph_cut.get_complet_graph(cdist_t, sigma=sigma)
                graph_pe = extract_graph_pe(edge_ind, edge_weights, N)
                print(graph_pe.shape)
                os.makedirs(os.path.dirname(f_name), exist_ok=True)
                torch.save(graph_pe, f_name)

    def process(self):
        # Required by DGLDataset; all loading happens lazily in __getitem__.
        pass

    def get_block_size(self):
        """Return the number of rows of each dataset's selected ``self.method`` embedding."""
        return [self._load_selected_emb(self.visual_path, self.method, d_name)[0].shape[0]
                for d_name in self.data_names]

    # ---------------------------------------------------------------- loading

    def load_dgl_sample(self, d_name):
        """Build the DGL graph for ``d_name`` with labels and, unless ``is4gt``, per-sigma views.

        Edges are all index pairs whose ``cdist`` entry is not NaN. For view
        ``k`` the graph gets ``edata['weight<k>']`` (float32 weights from
        ``graph_cut.get_complet_graph``) and ``ndata['pe<k>']`` (the
        precomputed encoding loaded from ``precomputed_pe_path``, sign-adjusted
        by ``flip_pe_sign_utils.filp_pe_sign``).
        """
        dataset_name = self._dataset_dir(d_name)
        cdist, y = self._load_cdist_and_labels(d_name)
        cdist_t = torch.from_numpy(cdist)

        src, dst = torch.where(~torch.isnan(cdist_t))
        graph = dgl.graph((src, dst), num_nodes=cdist.shape[0])

        if not self.is4gt:
            for k, sigma in enumerate(self.SIGMAS):
                _, edge_weights = graph_cut.get_complet_graph(cdist_t, sigma=sigma)
                graph.edata[f'weight{k}'] = edge_weights.to(torch.float32)

                graph_pe = torch.load(self._pe_file(self.precomputed_pe_path, dataset_name, d_name, sigma))
                graph.ndata[f'pe{k}'] = flip_pe_sign_utils.filp_pe_sign(graph_pe, method=self.flip_sign_method)

        # as_tensor accepts lists, arrays and tensors (torch.tensor warns on tensors).
        graph.ndata['y'] = torch.as_tensor(y).long()
        return graph

    def load_dgl_sample_w_tsne(self, d_name):
        """:meth:`load_dgl_sample` plus the TSNE embedding as ``ndata['z']``."""
        graph = self.load_dgl_sample(d_name=d_name)
        selected_emb, _ = self._load_selected_emb(self.visual_path, 'TSNE', d_name)
        graph.ndata['z'] = self._prepare_z(selected_emb)
        return graph

    def load_dgl_sample_w_umap(self, d_name):
        """:meth:`load_dgl_sample_w_tsne` plus the UMAP embedding and its ``n_neighbors``.

        Adds ``ndata['z_umap']`` and ``ndata['umap_n_neighbor']``. The latter is
        the UMAP ``n_neighbors`` hyper-parameter repeated once per node, with
        shape ``(num_nodes, 1)``.
        """
        graph = self.load_dgl_sample_w_tsne(d_name=d_name)
        selected_emb, hps = self._load_selected_emb(self.visual_path_umap, 'UMAP', d_name)
        z = self._prepare_z(selected_emb)
        graph.ndata['z_umap'] = z
        graph.ndata['umap_n_neighbor'] = torch.tensor([hps['n_neighbors']]).repeat(z.shape[0], 1)
        return graph

    def load_sample(self, dataset):
        """Legacy loader returning per-view objects instead of a DGL graph.

        Not used by ``__getitem__``. It reads flat paths such as
        ``<cdist_path>/<dataset>_clip_cdist_3000.tar``.

        Returns:
            ``(graph_pes, graph_views, selected_emb, zdist, y)``, where
            ``graph_views[k]`` is whatever ``graph_cut.get_complet_graph`` returns
            for ``SIGMAS[k]`` and ``selected_emb`` is the calibrated embedding
            (not converted to a tensor).
        """
        clip_save = torch.load(self.cdist_path + '/' + f'{dataset}_clip_cdist_3000.tar')
        cdist, _, y = clip_save[:3]
        cdist_t = torch.from_numpy(cdist)

        graph_views = []
        graph_pes = []
        for sigma in self.SIGMAS:
            graph_views.append(graph_cut.get_complet_graph(cdist_t, sigma=sigma))
            graph_pe = torch.load(
                self.precomputed_pe_path + '/' + f'{dataset}_{sigma}_{self.pe_type_subfix}.tar')
            graph_pes.append(flip_pe_sign_utils.filp_pe_sign(graph_pe, method=self.flip_sign_method))

        selected_emb, _ = torch.load(
            self.visual_path + '/' + f'visual-method-{self.method}_dataset-{dataset}_selected_emb.tar')
        selected_emb = self.calibrate_z(selected_emb, z_anchor=self.z_anchor)

        z = torch.from_numpy(selected_emb).to(torch.float32)
        zdist = torch.cdist(z, z)
        y = torch.as_tensor(y).long()

        return graph_pes, graph_views, selected_emb, zdist, y

    # -------------------------------------------------------- dataset protocol

    def __getitem__(self, idx):
        """Return ``(name, graph)``; embeddings are attached only outside test mode."""
        d_name = self.data_names[idx]
        loader = self.load_dgl_sample if self.is_test else self.load_dgl_sample_w_umap
        return d_name, loader(d_name)

    def __len__(self):
        return len(self.data_names)

    def calibrate_z(self, z, z_anchor, n_dim=2):
        """Calibrate an embedding array according to ``self.z_cali_method``.

        * ``'svd'``: project ``z`` onto its top ``n_dim`` left singular
          vectors, scaled by the singular values.
        * ``'orth'``: rotate ``z`` by the orthogonal factor ``u @ vt`` of the
          SVD of ``outer(mean(z), mean(z_anchor))``.
        * ``'none'``: return ``z`` unchanged.

        Raises:
            ValueError: ``'orth'`` is requested without ``z_anchor``.
            NotImplementedError: for any other method name.
        """
        if self.z_cali_method == 'svd':
            # The reduced SVD yields the same leading singular vectors without
            # allocating a full N x N U matrix.
            u, s, _ = np.linalg.svd(z, full_matrices=False)
            return u[:, :n_dim] * s[:n_dim]
        if self.z_cali_method == 'orth':
            if z_anchor is None:
                raise ValueError("z_cali_method='orth' requires z_anchor")
            m = np.outer(np.mean(z, axis=0), np.mean(z_anchor, axis=0))
            u, _, vt = np.linalg.svd(m)
            return z @ (u @ vt)
        if self.z_cali_method == 'none':
            return z
        raise NotImplementedError(f'unknown z_cali_method: {self.z_cali_method!r}')

def main():
    """Smoke-test the dataset: build it, draw one loader batch, and batch one graph.

    All paths are placeholders. They assume the cdist, embedding and
    positional-encoding files already exist (NOTE(review): adjust to your layout).
    """
    from functools import partial

    # exp_name and visual_path_umap are required constructor arguments.
    # precomputed_pe_path must be set, otherwise loading concatenates None
    # into a path. This root is where precompute() writes by default.
    ds = DatasetGraphDataset(exp_name='debug',
                             data_names=['cifar10_1class_comb0_seed0'],
                             cdist_path='features',
                             visual_path='res-2',
                             visual_path_umap='res-2',
                             precomputed_pe_path='/mnt/data01/public/aad_data/pe')

    get_loader = partial(torch.utils.data.DataLoader, batch_size=1,
                         shuffle=True,
                         num_workers=0,
                         collate_fn=lambda x: list(zip(*x)))
    train_loader = get_loader(ds)

    names, graphs = next(iter(train_loader))
    # dgl.batch expects a list of graphs, not a single graph.
    batch_graph = dgl.batch([ds[0][1]])
    print(names, batch_graph)


if __name__ == '__main__':
    main()
