from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import dgl
import torch
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

import numpy as np

from . import graph_cut

from ..pe_utils import posenc_stats
from ..pe_utils import posenc_config
from ..pe_utils import flip_pe_sign_utils


def extract_graph_pe(edge_index, weight, N, pe_types=('SVD',)):
    """Compute positional encodings for one weighted graph.

    Wraps ``edge_index``/``weight`` in a torch_geometric ``Data`` object with
    ``N`` nodes, runs ``posenc_stats.compute_posenc_stats`` on it (treating the
    graph as undirected and using the global graphgym ``cfg``), and
    concatenates the ``{pe}_pos_enc`` attribute of every requested encoding
    along dim 1.

    Args:
        edge_index: edge index, passed as ``Data(edge_index=...)``.
        weight: per-edge weights, passed as ``Data(edge_weight=...)``.
        N: number of nodes in the graph.
        pe_types: names of the encodings to compute, in concatenation order.
            Defaults to ``('SVD',)``.

    Returns:
        The requested encodings concatenated along dim 1.

    Raises:
        ValueError: if ``pe_types`` is empty.
    """
    pe_types = list(pe_types)
    if not pe_types:
        raise ValueError('pe_types must contain at least one encoding name')

    data = Data(edge_index=edge_index, num_nodes=N, edge_weight=weight)
    emb_data = posenc_stats.compute_posenc_stats(data, pe_types=pe_types, is_undirected=True, cfg=cfg)

    # compute_posenc_stats stores each encoding as an attribute named '<pe>_pos_enc'.
    return torch.cat([getattr(emb_data, f'{pe}_pos_enc') for pe in pe_types], dim=1)


# Step 1: Define a Multi-Graph Dataset
class DatasetGraphDataset(DGLDataset):
    """Multi-graph DGL dataset pairing distance graphs with target embeddings.

    For every name in ``data_names`` the dataset loads a distance matrix from
    ``{cdist_path}/{name}_clip_cdist_3000.tar`` and a target embedding from
    ``{visual_path}/visual-method-TSNE_dataset-{name}_selected_emb.tar``. From
    these it builds:

    * one DGL graph whose edges are all non-NaN entries of the distance matrix.
      It carries one weighted view (``edata['weight{k}']``) and one positional
      encoding (``ndata['pe{k}']``) per sigma in ``SIGMAS``;
    * a Gaussian-kernel similarity matrix of the distances (``self.dists``);
    * the (optionally calibrated and normalized) target embedding (``self.z``)
      and its pairwise Euclidean distances (``self.zdists``).

    Items are ``(graph, z, zdists, y)`` tuples.
    """

    # Kernel widths of the weighted graph views. Kept as mixed ints/floats
    # because their str() form is part of the cached PE file names.
    SIGMAS = (0.1, 0.5, 1, 2, 5)
    # Directory freshly computed PEs are written to when no pe_save_path is given.
    DEFAULT_PE_SAVE_PATH = '/home/****/autovisual/prepare_data/pe/pe_for_gat'

    def __init__(self, data_names, cdist_path, visual_path, precomputed_pe_path=None, z_cali_method='none', z_anchor=None, normalize_z=True,
                 z_mu=None, z_std=None, flip_sign_method='none', pe_save_path=None):
        """Store configuration, then let ``DGLDataset.__init__`` call ``process``.

        Args:
            data_names: dataset names to load, one graph per name.
            cdist_path: directory holding the ``*_clip_cdist_3000.tar`` files.
            visual_path: directory holding the ``*_selected_emb.tar`` files.
            precomputed_pe_path: directory of cached PEs. When None, PEs are
                computed and written to ``pe_save_path``.
            z_cali_method: one of ``'none'``, ``'svd'``, ``'orth'`` (see ``calibrate_z``).
            z_anchor: reference embedding used by the ``'orth'`` calibration.
            normalize_z: standardize targets with ``z_mu``/``z_std``.
            z_mu, z_std: normalization statistics. Each one left as None is
                computed over all loaded targets.
            flip_sign_method: forwarded to ``flip_pe_sign_utils.filp_pe_sign``.
            pe_save_path: output directory for computed PEs. Defaults to
                ``DEFAULT_PE_SAVE_PATH``.
        """
        self.data_names = data_names
        self.cdist_path = cdist_path
        self.visual_path = visual_path
        self.precomputed_pe_path = precomputed_pe_path
        self.pe_save_path = self.DEFAULT_PE_SAVE_PATH if pe_save_path is None else pe_save_path

        self.method = 'TSNE'
        self.z_anchor = z_anchor
        self.z_cali_method = z_cali_method

        self.graphs = []  # one multi-view DGL graph per dataset
        self.pes = []  # unused; kept for callers that may still reference it
        self.dists = []  # Gaussian-kernel similarity of each distance matrix
        self.y = []  # per-dataset labels
        self.z = []  # per-dataset target embeddings
        self.zdists = []  # pairwise distances of each target embedding
        self.normalize_z = normalize_z
        self.ds_block_sizes = []  # rows per dataset, used to split the normalized z

        self.z_mu = z_mu
        self.z_std = z_std

        self.flip_sign_method = flip_sign_method
        # DGLDataset.__init__ triggers process(), so all attributes must be set first.
        super().__init__(name='distance_graph')

    def process(self):
        """Load every dataset in ``self.data_names`` and fill the item lists."""
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)
        for dataset in self.data_names:
            print(dataset)
            clip_save = torch.load(f'{self.cdist_path}/{dataset}_clip_cdist_3000.tar')
            cdist, _features, y = clip_save[:3]

            self.graphs.append(self._build_multiview_graph(dataset, cdist))
            self.y.append(y)
            self.ds_block_sizes.append(y.shape[0])
            self.dists.append(self._distance_kernel(cdist))

            selected_emb, _hps = torch.load(
                f'{self.visual_path}/visual-method-{self.method}_dataset-{dataset}_selected_emb.tar')
            selected_emb = self.calibrate_z(selected_emb, z_anchor=self.z_anchor)
            self.z.append(selected_emb)

            z = torch.from_numpy(selected_emb).to(torch.float32)
            self.zdists.append(torch.cdist(z, z))

        self._finalize_targets()

    def _build_multiview_graph(self, dataset, cdist):
        """Build one dataset's graph with a weight view and a PE per sigma in ``SIGMAS``."""
        cdist_t = torch.from_numpy(cdist)
        # NaN distances denote missing pairs; they get no edge.
        src, dst = torch.where(~torch.isnan(cdist_t))
        n_nodes = cdist.shape[0]
        graph = dgl.graph((src, dst), num_nodes=n_nodes)

        for k, sigma in enumerate(self.SIGMAS):
            edge_ind, edge_weights = graph_cut.get_complet_graph(cdist_t, sigma=sigma)
            # NOTE(review): assumes get_complet_graph yields edges in the same
            # order as torch.where above — confirm.
            graph.edata[f'weight{k}'] = edge_weights

            pe_file = f'{dataset}_{sigma}_complete_pe_64freq_svd_u_torchf_only.tar'
            if self.precomputed_pe_path is None:
                graph_pe = extract_graph_pe(edge_ind, edge_weights, n_nodes)
                print(graph_pe.shape)
                torch.save(graph_pe, self.pe_save_path + '/' + pe_file)
            else:
                graph_pe = torch.load(self.precomputed_pe_path + '/' + pe_file)

            graph_pe = flip_pe_sign_utils.filp_pe_sign(graph_pe, method=self.flip_sign_method)
            graph.ndata[f'pe{k}'] = graph_pe
        return graph

    @staticmethod
    def _distance_kernel(cdist):
        """Return exp(-d^2 / (2 * b^2)), where b is the median of the distances."""
        cdist = torch.from_numpy(cdist).to(torch.float32)
        # cdist may contain NaNs (missing pairs). torch.median would return NaN
        # and poison every entry, so take the median of the defined distances only.
        bandwidth = torch.nanmedian(cdist)
        return torch.exp(-cdist ** 2 / (2 * bandwidth ** 2))

    def _finalize_targets(self):
        """Jointly standardize all targets (if enabled), then split them back per dataset."""
        if not self.z:
            return  # no datasets loaded; np.concatenate would fail on an empty list
        all_z = np.concatenate(self.z)

        # Fill each statistic independently so that supplying only one still works.
        if self.z_mu is None:
            self.z_mu = np.mean(all_z, axis=0)
        if self.z_std is None:
            self.z_std = np.std(all_z, axis=0)

        if self.normalize_z:
            all_z = (all_z - self.z_mu) / (self.z_std + 1e-9)

        split_indices = np.cumsum(self.ds_block_sizes)[:-1]
        self.z = np.split(all_z, split_indices, axis=0)

    def __getitem__(self, idx):
        """Return ``(graph, z, zdists, y)`` for dataset ``idx``."""
        return self.graphs[idx], self.z[idx], self.zdists[idx], self.y[idx]

    def __len__(self):
        """Return the number of datasets (graphs)."""
        return len(self.graphs)

    def calibrate_z(self, z, z_anchor, n_dim=2):
        """Calibrate an embedding according to ``self.z_cali_method``.

        * ``'svd'``: project ``z`` onto its leading ``n_dim`` singular directions
          (returns ``U[:, :n_dim] * S[:n_dim]``).
        * ``'orth'``: multiply ``z`` by the orthogonal matrix ``u @ vt`` from
          the SVD of the outer product of the column means of ``z`` and
          ``z_anchor``.
        * ``'none'``: return ``z`` unchanged.

        Raises:
            NotImplementedError: for any other method name.
        """
        if self.z_cali_method == 'svd':
            # Thin SVD: only the leading n_dim columns of U are used, so
            # don't materialize the full (N, N) U matrix.
            U, S, _ = np.linalg.svd(z, full_matrices=False)
            return U[:, :n_dim] * S[:n_dim]
        elif self.z_cali_method == 'orth':
            m = np.outer(np.mean(z, axis=0), np.mean(z_anchor, axis=0))
            u, _, vt = np.linalg.svd(m)
            return z @ (u @ vt)
        elif self.z_cali_method == 'none':
            return z
        else:
            raise NotImplementedError(f'Unknown z_cali_method: {self.z_cali_method!r}')


if __name__ == '__main__':
    # Smoke test: build the dataset for one split and pull a single batch.
    from functools import partial
    ds = DatasetGraphDataset(data_names=['mnist_group2'], cdist_path='../prepare_data/clip/features',
                             visual_path='../prepare_data/bo/res-2')

    get_loader = partial(torch.utils.data.DataLoader, batch_size=1,
                         shuffle=True,
                         num_workers=0,
                         # Transpose a batch of (graph, z, zdists, y) items into per-field tuples.
                         collate_fn=lambda x: list(zip(*x))
                         )
    train_loader = get_loader(ds)

    graphs, zs, zdists, ys = next(iter(train_loader))
    # Field 0 holds the graphs; field 1 is the z array, which dgl.batch cannot batch.
    batch_graph = dgl.batch(list(graphs))  # for testing

