from torch_geometric.graphgym.config import (cfg, dump_cfg,
                                             # set_agg_dir,
                                             set_cfg, load_cfg,
                                             makedirs_rm_exist)
from torch_geometric.data import Data

import dgl
import torch
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader

import numpy as np

import graph_cut

import posenc_stats
import posenc_config


def extract_graph_pe(data):
    """Compute positional encodings for a graph and concatenate them.

    The graph's edge list and its ``'weight'`` edge feature are repackaged
    as a PyTorch Geometric ``Data`` object, handed to
    ``posenc_stats.compute_posenc_stats`` (configured through the global
    ``cfg``), and every resulting ``<name>_pos_enc`` attribute is
    concatenated along ``dim=1``.

    Args:
        data: graph object exposing ``edges()``, ``num_nodes()`` and
            ``edata['weight']`` (a DGL graph in this file's usage).

    Returns:
        torch.Tensor: the requested encodings concatenated along ``dim=1``
        (presumably one row per node -- confirm against ``posenc_stats``).
    """
    # Encodings to compute. Only 'LapPE' is enabled; other names that have
    # been toggled here in the past: 'EquivStableLapPE', 'SignNet', 'RWSE',
    # 'HKdiagSE', 'ElstaticSE', 'RRWP', 'SVD', 'PPR', 'WLPE', 'GCKN',
    # 'RWDIFF'.
    pes = ['LapPE']

    # Stack the (src, dst) tensors directly instead of round-tripping
    # through a Python list of NumPy arrays, which PyTorch flags as an
    # extremely slow way to build a tensor.
    src, dst = data.edges()
    edge_index = torch.stack((src, dst), dim=0).to(torch.long)
    pyg_graph = Data(edge_index=edge_index, num_nodes=data.num_nodes(),
                     edge_weight=data.edata['weight'])

    emb_data = posenc_stats.compute_posenc_stats(
        pyg_graph, pe_types=pes, is_undirected=True, cfg=cfg)

    # Each encoding is stored on the returned object as '<name>_pos_enc'.
    encodings = [getattr(emb_data, f'{pe}_pos_enc') for pe in pes]
    return torch.cat(encodings, dim=1)


# Step 1: Define a Multi-Graph Dataset
class DatasetGraphDataset(DGLDataset):
    """Dataset holding one distance graph (plus targets) per named dataset.

    For every entry of ``data_names`` :meth:`process` loads a pairwise
    distance matrix and labels, builds a graph from the distances, obtains
    a positional encoding for it (computed or loaded from disk), and loads
    the target embedding ``z`` together with its pairwise distances.

    Args:
        data_names: iterable of dataset names used to build file names.
        cdist_path: directory containing ``<name>_clip_cdist_3000.tar``.
        visual_path: directory containing
            ``visual-method-TSNE_dataset-<name>_selected_emb.tar``.
        precomputed_pe_path: directory with cached positional encodings.
            When ``None`` the encodings are computed and saved.
        z_cali_method: ``'none'``, ``'svd'`` or ``'orth'``; see
            :meth:`calibrate_z`.
        z_anchor: reference embedding used by the ``'orth'`` calibration.
        normalize_z: whether to standardise ``z`` with ``z_mu``/``z_std``.
        z_mu: mean used for normalisation; computed from the loaded
            embeddings when ``None``.
        z_std: standard deviation used for normalisation; computed from
            the loaded embeddings when ``None``.
    """

    def __init__(self, data_names, cdist_path, visual_path, precomputed_pe_path=None,
                 z_cali_method='none', z_anchor=None, normalize_z=True,
                 z_mu=None, z_std=None):
        # All state is assigned BEFORE super().__init__(): the DGLDataset
        # constructor is expected to drive process(), which reads it
        # (see the DGL DGLDataset documentation).
        self.data_names = data_names
        self.cdist_path = cdist_path
        self.visual_path = visual_path
        self.precomputed_pe_path = precomputed_pe_path

        self.method = 'TSNE'  # embedding method name used in file names
        self.z_anchor = z_anchor
        self.z_cali_method = z_cali_method

        self.graphs = []  # one graph per dataset
        self.pes = []  # positional encoding of each graph
        self.dists = []  # Gaussian-kernel transform of each distance matrix
        self.y = []  # Labels for node classification
        self.z = []  # target embeddings (normalised at the end of process)
        self.zdists = []  # pairwise distances of the un-normalised z
        self.normalize_z = normalize_z
        self.ds_block_sizes = []  # number of samples in each dataset

        self.z_mu = z_mu
        self.z_std = z_std
        super().__init__(name='distance_graph')

    def process(self):
        """Load every dataset listed in ``data_names`` into memory."""
        set_cfg(cfg)
        posenc_config.set_cfg_posenc(cfg)
        for dataset in self.data_names:
            # ---- input side: distances, labels, graph, positional encoding
            clip_save = torch.load(self.cdist_path + '/' + f'{dataset}_clip_cdist_3000.tar')
            cdist, _, y = clip_save[:3]
            cdist = torch.from_numpy(cdist)

            complete_graph = graph_cut.get_complet_graph(cdist)

            pe_file = f'{dataset}_complete_pe_64freq_lap_only.tar'
            if self.precomputed_pe_path is None:
                graph_pe = extract_graph_pe(complete_graph)
                print(graph_pe.shape)
                # NOTE(review): the cache location is hard-coded and relative
                # to the working directory; it must already exist.
                torch.save(graph_pe, '../prepare_data/pe/pe_for_gat' + '/' + pe_file)
            else:
                graph_pe = torch.load(self.precomputed_pe_path + '/' + pe_file)

            self.graphs.append(complete_graph)
            self.pes.append(graph_pe)
            self.y.append(y)
            self.ds_block_sizes.append(y.shape[0])

            cdist = cdist.to(torch.float32)
            # Kernel bandwidth: median over all entries of the distance
            # matrix (multiplier currently 1.0).
            bandwidth = torch.median(cdist) * 1.0
            gaussian_kernel_cdist = torch.exp(-cdist ** 2 / (2 * bandwidth ** 2))
            self.dists.append(gaussian_kernel_cdist)

            # ---- output side: target embedding and its pairwise distances
            selected_emb, _ = torch.load(
                self.visual_path + '/' + f'visual-method-{self.method}_dataset-{dataset}_selected_emb.tar')
            selected_emb = self.calibrate_z(selected_emb, z_anchor=self.z_anchor)
            self.z.append(selected_emb)

            z = torch.from_numpy(selected_emb).to(torch.float32)
            self.zdists.append(torch.cdist(z, z))

        # Normalise over the union of all datasets, then split back.
        self.z = np.concatenate(self.z)

        # Fill in each statistic independently: supplying only one of
        # z_mu / z_std used to leave the other as None and crash below.
        if self.z_mu is None:
            self.z_mu = np.mean(self.z, axis=0)
        if self.z_std is None:
            self.z_std = np.std(self.z, axis=0)

        if self.normalize_z:
            # The epsilon guards against division by a zero std.
            self.z = (self.z - self.z_mu) / (self.z_std + 1e-9)

        # np.split wants cut positions, not block sizes.
        split_indices = np.cumsum(self.ds_block_sizes)[:-1]
        self.z = np.split(self.z, split_indices, axis=0)

    def __getitem__(self, idx):
        """Return ``(kernel_dists, pe, graph, z, z_dists, y)`` for dataset ``idx``."""
        return self.dists[idx], self.pes[idx], self.graphs[idx], self.z[idx], self.zdists[idx], self.y[idx]

    def __len__(self):
        """Return the number of loaded graphs."""
        return len(self.graphs)

    def calibrate_z(self, z, z_anchor, n_dim=2):
        """Align the embedding ``z`` according to ``self.z_cali_method``.

        Args:
            z: NumPy array holding one embedding per row.
            z_anchor: reference embedding; only used by ``'orth'``.
            n_dim: number of leading components kept by ``'svd'``.

        Returns:
            ``z`` itself for ``'none'``; ``U[:, :n_dim] * S[:n_dim]`` from
            the SVD of ``z`` for ``'svd'``; ``z @ r`` for ``'orth'``, where
            ``r = u @ vt`` comes from the SVD of the outer product of the
            column means of ``z`` and ``z_anchor``.

        Raises:
            ValueError: ``'orth'`` was requested without a ``z_anchor``.
            NotImplementedError: unknown calibration method.
        """
        if self.z_cali_method == 'none':
            return z
        if self.z_cali_method == 'svd':
            # The thin SVD yields the same leading columns without
            # materialising the full N x N matrix U.
            U, S, _ = np.linalg.svd(z, full_matrices=False)
            return U[:, :n_dim] * S[:n_dim]
        if self.z_cali_method == 'orth':
            if z_anchor is None:
                raise ValueError("z_cali_method='orth' requires a z_anchor")
            # reshape(-1, 1) / (1, -1) instead of a hard-coded width of 2.
            m = np.matmul(np.mean(z, axis=0).reshape(-1, 1),
                          np.mean(z_anchor, axis=0).reshape(1, -1))
            u, _, vt = np.linalg.svd(m)
            r = np.matmul(u, vt)
            return z @ r
        raise NotImplementedError(f'unknown z_cali_method: {self.z_cali_method!r}')


if __name__ == '__main__':
    # Smoke test: build the dataset for one named dataset, pull a single
    # batch through a DataLoader and batch its graph with DGL.
    from functools import partial
    ds = DatasetGraphDataset(data_names=['mnist_group2'], cdist_path='../prepare_data/clip/features',
                             visual_path='../prepare_data/bo/res-2')

    get_loader = partial(torch.utils.data.DataLoader, batch_size=1,
                         shuffle=True,
                         num_workers=0,
                         # Transpose a list of samples into per-field tuples.
                         collate_fn=lambda x: list(zip(*x))
                         )
    train_loader = get_loader(ds)

    it = iter(train_loader)
    a = next(it)
    # __getitem__ yields (dists, pe, graph, z, zdists, y): the graph is
    # field 2 (field 1 is the positional-encoding tensor), and dgl.batch
    # expects a list of graphs.
    batch_graph = dgl.batch([ds[0][2]])  # for testing

