import torch

# from torch.linalg import svd
from sklearn.decomposition import PCA
import torch.nn.functional as F
from torch.linalg import svd

import numpy as np

from torch_geometric.datasets import TUDataset

import torch_geometric.transforms as T
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io.tu import read_file, cat
from torch_geometric.utils import coalesce, one_hot, remove_self_loops, add_self_loops, degree, to_dense_adj

import glob

import os
import os.path as osp

import tqdm
from sklearn import preprocessing

import time

from .align import align_graph_feat
from .utils import (
    get_adj_feature, 
    get_betweenness_centrality, 
    get_closeness_centrality, 
    get_eigenvector_centrality, 
    get_gaussian_feature, 
    compute_landing_probs, 

    sign_flip, 
    compute_kernel_mtx, 
    compute_kernel_mtx_chunked, 
    
)
# Root directory passed to `TUDataset` in `read_tu_dataset`; when node labels
# are present, the node-label class count is read from the dataset stored here.
DATA_ROOT = '/root/autodl-tmp/TUD_org'
# NOTE(review): not referenced in the visible part of this module; presumably
# the output directory for processed `.pt` files — confirm.
PT_PATH = '/root/autodl-tmp/pt'




def read_tu_dataset(
        folder, prefix, dim=None, mode=None, A_trunc_dim=32,
):
    """Read a TU-format graph dataset from raw ``.txt`` files and build node features.

    Graphs listed in the ``graph_indicator`` file that own no nodes are dropped,
    and the remaining graphs are renumbered contiguously.

    Node features depend on which raw files exist:

    * ``node_attributes`` and/or ``node_labels`` present: the NaN-cleaned
      attributes, the one-hot node labels and per-graph betweenness centrality;
    * neither present: node degree, betweenness centrality and an
      ``A_trunc_dim``-wide adjacency feature from ``get_adj_feature``.

    Args:
        folder: directory containing the ``{prefix}_*.txt`` files.
        prefix: dataset name; also used to open ``TUDataset(DATA_ROOT, prefix)``
            to obtain the number of node-label classes.
        dim: unused here; kept for interface compatibility.
        mode: optional split suffix; files are read as ``{prefix}_{mode}_*.txt``.
            When ``None``, node labels are shifted so their minimum is 0.
        A_trunc_dim: width of the adjacency feature for unattributed graphs.

    Returns:
        ``(node_attributes, edge_slice, graph_edge_index, edge_attr, y, batch,
        graph_num_nodes, node_slice, edge_index)`` where ``edge_attr`` is always
        ``None``, ``graph_edge_index`` holds per-graph (local) node ids,
        ``edge_index`` holds global node ids, ``node_slice``/``edge_slice`` hold
        ``num_graphs + 1`` cumulative offsets, and ``y`` is ``None`` when the
        dataset has neither graph labels nor graph attributes.
    """
    dataset_name = prefix
    # Only used for `num_node_labels` (presumably so one-hot widths agree across
    # splits read from `folder`).
    dataset = TUDataset(
        root = DATA_ROOT,
        name = dataset_name,
        use_node_attr=False
    )
    prefix = prefix if mode is None else prefix + '_' + mode

    files = glob.glob(osp.join(folder, f'{prefix}_*.txt'))
    # "<prefix>_node_labels.txt" -> "node_labels"
    names = [f.split(os.sep)[-1][len(prefix) + 1:-4] for f in files]
    # Raw ids are shifted by -1 to become 0-based node / graph ids.
    edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
    org_batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1
    num_nodes = org_batch.size(0)

    edge_attr = None
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    # coalesce also sorts edges by source node; the contiguous per-graph edge
    # slicing below relies on that ordering.
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes)

    # Drop graphs that own no nodes and renumber the rest 0..num_graphs-1.
    # NOTE(review): assumes nodes in graph_indicator are grouped by graph.
    org_graph_num_nodes = torch.bincount(org_batch)
    graph_num_nodes = org_graph_num_nodes[org_graph_num_nodes > 0]
    num_graphs = graph_num_nodes.numel()
    batch = torch.repeat_interleave(torch.arange(num_graphs), graph_num_nodes)
    total_nodes = batch.size(0)

    # Cumulative per-graph offsets into the node and (source-sorted) edge arrays.
    # `minlength` keeps edge_slice at num_graphs + 1 entries even when trailing
    # graphs (or the whole dataset) have no edges.
    zero = torch.zeros(1, dtype=torch.long)
    node_slice = torch.cat([zero, graph_num_nodes.cumsum(0)])
    row = edge_index[0]
    edge_slice = torch.cat([zero, torch.bincount(batch[row], minlength=num_graphs).cumsum(0)])
    # Edge indices should start at zero for every graph.
    graph_edge_index = edge_index - node_slice[batch[row]].unsqueeze(0)

    node_label = torch.empty((total_nodes, 0))
    if 'node_labels' in names:
        n_node_classes = dataset.num_node_labels
        del dataset
        node_label = read_file(folder, prefix, 'node_labels', torch.long)
        if node_label.dim() == 1:
            node_label = node_label.unsqueeze(-1)
        if mode is None:
            print("None mode, map the node index to 0-based")
            node_label = node_label - node_label.min(dim=0)[0]
        # NOTE(review): split files (mode is not None) are used as-is, so their
        # labels must already be valid indices for `one_hot` — confirm.
        node_labels = [one_hot(x, num_classes=n_node_classes) for x in node_label.unbind(dim=-1)]
        node_label = node_labels[0] if len(node_labels) == 1 else torch.cat(node_labels, dim=-1)

    node_attributes = torch.empty((total_nodes, 0))
    if 'node_attributes' in names:
        node_attributes = read_file(folder, prefix, 'node_attributes')
        if node_attributes.dim() == 1:
            node_attributes = node_attributes.unsqueeze(-1)
        node_attributes = torch.nan_to_num(node_attributes, nan=0.0)

    has_node_info = 'node_attributes' in names or 'node_labels' in names
    btwn_cen = torch.empty((total_nodes, 1))  # every row is filled below
    if has_node_info:
        for i in range(num_graphs):
            node_start, node_end = node_slice[i], node_slice[i + 1]
            edge_start, edge_end = edge_slice[i], edge_slice[i + 1]
            graph = Data(x=torch.empty(node_end - node_start, 1),
                         edge_index=graph_edge_index[:, edge_start:edge_end])
            btwn_cen[node_start:node_end, :] = get_betweenness_centrality(graph, graph_num_nodes[i])
        # `cat` skips empty tensors, so absent attributes / labels simply drop out.
        node_attributes = cat([node_attributes, node_label, btwn_cen]).to(torch.float)
    else:
        # No node information at all: derive purely structural features.
        adj_feats = torch.empty((total_nodes, A_trunc_dim))
        deg_cen = torch.empty((total_nodes, 1))
        for i in range(num_graphs):
            node_start, node_end = node_slice[i], node_slice[i + 1]
            edge_start, edge_end = edge_slice[i], edge_slice[i + 1]
            graph_edges = graph_edge_index[:, edge_start:edge_end]
            adj_feat_i = get_adj_feature(graph_edges, graph_num_nodes[i], dim=A_trunc_dim)
            graph = Data(x=torch.empty(node_end - node_start, 1), edge_index=graph_edges)
            btwn_cen[node_start:node_end, :] = get_betweenness_centrality(graph, graph_num_nodes[i])
            deg_cen[node_start:node_end, :] = degree(
                graph_edges[0], num_nodes=graph_num_nodes[i], dtype=torch.float
            ).unsqueeze(1)
            adj_feats[node_start:node_end] = adj_feat_i
        node_attributes = cat([deg_cen, btwn_cen, adj_feats]).to(torch.float)

    y = None
    if 'graph_labels' in names:  # Classification problem.
        y = read_file(folder, prefix, 'graph_labels', torch.long)
        _, y = y.unique(sorted=True, return_inverse=True)
    elif 'graph_attributes' in names:  # Regression problem.
        y = read_file(folder, prefix, 'graph_attributes')
    # Keep targets only for graphs that survived the empty-graph filter above.
    if y is not None:
        y = y[org_graph_num_nodes > 0]

    return node_attributes, \
        edge_slice, graph_edge_index, edge_attr, \
        y, batch, graph_num_nodes, node_slice, edge_index





def transform_feat(name,
                   node_attributes,
                   graph_edge_index, batch,
                   dim, use_decomp, scales, pad_multi_n,
                   mode=None, src_dir=None, align_feat=False):
    """Standardize node features and apply the multi-scale Gaussian-kernel transform.

    Features are standardized with a ``StandardScaler``. The scaler is fitted on
    the train split's features when ``mode == 'test'``, and on
    ``node_attributes`` itself for any other mode. Only
    ``use_decomp == "all_graphs"`` is implemented:

    * ``mode == 'train'``: ``get_gaussian_feature`` runs with ``return_uv=True``.
      Its U/S/V factors, mean distance and features are saved to ``src_dir``.
    * ``mode == 'test'`` and not ``align_feat``: ``get_gaussian_feature`` runs
      with the saved train mean distance and the standardized train features.
    * ``mode == 'test'`` and ``align_feat``: for each scale, a kernel matrix
      against the train features is projected with the saved ``V`` and
      ``S^{-1/2}``.
    * any other mode (pre-training): ``get_gaussian_feature`` with no saved
      statistics.

    Args:
        name: dataset name; in test mode it is also used to re-read the train split.
        node_attributes: 2-D node feature tensor ``[N, F]``, or ``None``.
        graph_edge_index: unused here; kept for interface compatibility.
        batch: graph id per node. Only its length is used, and only when
            ``node_attributes`` is ``None``.
        dim: output dimension per scale.
        use_decomp: decomposition mode; only ``"all_graphs"`` is implemented.
        scales: kernel scales.
        pad_multi_n: unused here; kept for interface compatibility.
        mode: ``'train'``, ``'test'`` or another value (e.g. ``None``).
        src_dir: directory holding the train split (``raw/``) and the saved
            decomposition files. Required for ``'train'`` and ``'test'``.
        align_feat: in test mode, project onto the saved train basis instead of
            recomputing the Gaussian features.

    Returns:
        Per-node features with one row per node, the per-scale features
        concatenated along columns. The result is
        ``zeros([N, len(scales) * dim])`` when ``node_attributes`` is ``None``.

    Raises:
        NotImplementedError: for ``use_decomp`` values ``"single"`` and ``"none"``.
        ValueError: for an unknown ``use_decomp``, or when ``src_dir`` is missing
            in train/test mode.
    """
    # Fall back to CPU so the align path still runs on hosts without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    start_time = time.time()
    print(f'Transforming features for {name} dataset in {mode} mode')
    if node_attributes is None:
        return torch.zeros((batch.size(0), len(scales) * dim))
    # Validate up front. Previously these cases did all the scaling work and
    # then silently returned None.
    if use_decomp in ("single", "none"):
        raise NotImplementedError(f"use_decomp={use_decomp!r} is not implemented")
    if use_decomp != "all_graphs":
        raise ValueError(f"unknown use_decomp: {use_decomp!r}")
    if mode in ('train', 'test') and src_dir is None:
        raise ValueError(f"src_dir must be provided in {mode} mode")

    scaler = preprocessing.StandardScaler()
    node_np = node_attributes.to(torch.float).numpy()
    if mode == 'test':
        # Standardize with the train split's statistics so both splits share one scale.
        src_np = read_tu_dataset(osp.join(src_dir, 'raw'), name, dim=dim, mode='train')[0]
        src_np = src_np.to(torch.float).numpy()
        scaler.fit(src_np)
        src_feat = torch.tensor(scaler.transform(src_np), dtype=torch.float)
    else:
        src_feat = None
        scaler.fit(node_np)
    node_attributes = torch.tensor(scaler.transform(node_np), dtype=torch.float)

    if mode == 'train':
        x, Us, Ss, Vs, dist_mean = get_gaussian_feature(
            node_attributes, scales, dim, src_feat=src_feat,
            return_uv=True, dist_mean=None)  # x: [len(scales), N, dim]
        # Persist U, S, V, mean distance and features of the train set for test mode.
        torch.save(Us, osp.join(src_dir, 'Us.pt'))
        torch.save(Ss, osp.join(src_dir, 'Ss.pt'))
        torch.save(Vs, osp.join(src_dir, 'Vs.pt'))
        torch.save(dist_mean, osp.join(src_dir, 'dist_mean.pt'))
        torch.save(x, osp.join(src_dir, 'features.pt'))
        print(f"save U, S, V, mean distance and features of train set to {src_dir}")

        x = x.transpose(0, 1).reshape(node_attributes.size(0), -1)
        x = sign_flip(x)

    elif mode == 'test' and not align_feat:
        dist_mean = torch.load(osp.join(src_dir, 'dist_mean.pt'), weights_only=True)
        x = get_gaussian_feature(
            node_attributes,
            scales=scales,
            decom_dim=dim,
            dist_mean=dist_mean,
            src_feat=src_feat,
            return_uv=False
        )   # [len(scales), N, dim]
        x = x.transpose(0, 1).reshape(-1, len(scales) * dim)   # [N, len(scales) * dim]
        x = sign_flip(x)
        # Keep only the trailing rows corresponding to the test nodes.
        x = x[-node_attributes.shape[0]:, :]

    elif mode == 'test':
        # align_feat: project onto the saved train basis V with S^{-1/2} scaling.
        S0 = torch.load(osp.join(src_dir, 'Ss.pt'), weights_only=True)
        V0 = torch.load(osp.join(src_dir, 'Vs.pt'), weights_only=True)
        dist_mean = torch.load(osp.join(src_dir, 'dist_mean.pt'), weights_only=True)

        trans_x = []
        diag_S = torch.diagonal(S0, dim1=-2, dim2=-1)
        diag_S = torch.pow(diag_S, -0.5)
        S0 = torch.diag_embed(diag_S)
        for i in range(len(scales)):
            V0[i] = sign_flip(V0[i].transpose(0, 1)).transpose(0, 1)
            if node_attributes.shape[0] > 100_000:
                # Large graphs: build the kernel matrix in chunks to bound memory.
                x = compute_kernel_mtx_chunked(
                    node_attributes.to(device), src_feat.to(device),
                    scale=scales[i], dist_mean=dist_mean,
                    chunk_size=1000  # adjust according to available memory
                )
            else:
                x = compute_kernel_mtx(
                    node_attributes, src_feat,
                    scale=scales[i], dist_mean=dist_mean,
                    mode='normal', dist_mtx=None
                )
            trans_x.append(torch.mm(x.to(device), torch.mm(V0[i].t().to(device), S0[i].to(device))))
        x = torch.stack(trans_x, dim=0)
        del trans_x, S0, V0, src_feat
        x = x.transpose(0, 1).reshape(node_attributes.size(0), -1).cpu()
        del node_attributes

    else:   # pre-training dataset: no saved statistics are used
        x = get_gaussian_feature(
            node_attributes, scales, dim,
            src_feat=src_feat,
            dist_mean=None,
            return_uv=False)
        x = x.transpose(0, 1).reshape(node_attributes.size(0), -1)
        x = sign_flip(x)

    end_time = time.time()
    print(f'Transforming features for {name} dataset in {mode} mode cost {end_time - start_time} seconds')
    return x



def read_and_process(root, name, mode, use_decomp, dim, pad_multi_n, scales, src_dir=None, align_feat=False,
                     A_trunc_dim=3, use_feature='btwn_cen+eigen_cen+adj_feats'):
    """Load a TU dataset split, transform its node features and collate it.

    Args:
        root: dataset root. For ``'train'``/``'test'`` the raw files live in
            ``root/raw``; otherwise they live in ``root/name/raw``.
        name: dataset name (file prefix).
        mode: split name forwarded to ``read_tu_dataset`` and ``transform_feat``.
        use_decomp, dim, pad_multi_n, scales, src_dir, align_feat: forwarded to
            ``transform_feat``.
        A_trunc_dim, use_feature: currently unused.
            NOTE(review): ``A_trunc_dim`` is not forwarded, so ``read_tu_dataset``
            uses its own default (32) — confirm this is intended.

    Returns:
        ``(data, slices, sizes)``. ``data`` is a single collated ``Data`` object.
        ``slices`` maps ``'x'``, ``'edge_index'`` and ``'y'`` to per-graph
        offsets. ``sizes`` reports the feature widths.
    """
    if mode in ['train', 'test']:
        raw_dir = osp.join(root, 'raw')
    else:
        raw_dir = osp.join(root, name, 'raw')

    (feats,
     edge_slice, graph_edge_index, edge_attr,
     y, batch, _graph_num_nodes, node_slice, _edge_index) = read_tu_dataset(raw_dir, name, dim, mode=mode)
    # `delete_graph` can be applied at this point to subsample graphs.
    print(f'feature dimention: {feats.shape[1]}')

    feats = transform_feat(
        name,
        feats, graph_edge_index, batch,
        dim, use_decomp, scales, pad_multi_n,
        mode=mode, src_dir=src_dir, align_feat=align_feat,
    )

    # edge_attr is always None as produced by read_tu_dataset.
    data = Data(x=feats, edge_index=graph_edge_index, edge_attr=edge_attr, y=y)
    slices = {
        'x': node_slice,
        'edge_index': edge_slice,
        # One target per graph: offsets 0, 1, ..., num_graphs.
        'y': torch.arange(0, batch[-1] + 2, dtype=torch.long),
    }
    sizes = {
        'num_node_attributes': feats.size(-1),
        'num_edge_attributes': 0,
        'num_edge_labels': 0,
    }
    return data, slices, sizes

def delete_graph(node_attr, batch, edge_index, y, remain_ratio=0.5, rng=None):
    """Randomly keep a fraction of the graphs in a batched graph collection.

    Args:
        node_attr: per-node tensor indexed along dim 0 by global node id.
        batch: graph id of each node.
        edge_index: ``[2, E]`` edge list in global node ids. It is assumed to be
            sorted by source node (e.g. by ``coalesce``) so that each graph's
            edges are contiguous.
        y: per-graph targets indexed by graph id.
        remain_ratio: fraction of graphs to keep. ``int(remain_ratio * n_graphs)``
            graphs are sampled without replacement.
        rng: optional object with a numpy-style ``permutation(n)`` method (e.g.
            ``np.random.default_rng(seed)``). Defaults to the global
            ``np.random`` state.

    Returns:
        ``(node_attr, edge_slice, graph_edge_index, y, batch, node_slice)`` for
        the kept graphs. Graphs are relabelled ``0..k-1`` in their original
        order, and edge indices are made local to each graph. ``node_slice`` and
        ``edge_slice`` hold ``k + 1`` cumulative offsets.
    """
    rng = np.random if rng is None else rng
    device = batch.device

    n_graphs = torch.max(batch).item() + 1
    selected = np.sort(rng.permutation(n_graphs)[:int(remain_ratio * n_graphs)])
    sel = torch.as_tensor(selected, dtype=batch.dtype, device=device)
    node_mask = torch.isin(batch, sel)
    idx = torch.nonzero(node_mask, as_tuple=True)[0]  # nodes of the selected graphs

    node_attr = node_attr[idx]
    y = y[sel]
    # Relabel kept graphs contiguously; counting unique ids also handles k == 0.
    kept_graphs, batch = torch.unique(batch[node_mask], sorted=True, return_inverse=True)
    n_remained = kept_graphs.numel()

    # Keep an edge only if both endpoints survive. For intra-graph edges this is
    # the same as either endpoint surviving, and it guarantees that no endpoint
    # is remapped to -1 below.
    src_row, dst_row = edge_index
    edge_mask = node_mask[src_row] & node_mask[dst_row]
    new_id = torch.full((node_mask.numel(),), -1, dtype=torch.long, device=device)
    new_id[idx] = torch.arange(idx.numel(), device=device)
    edge_index = new_id[edge_index[:, edge_mask]]

    # The relabelling is monotone, so edges stay sorted by source and group by graph.
    row = edge_index[0]
    zero = torch.zeros(1, dtype=torch.long, device=device)
    edge_slice = torch.cat([zero, torch.bincount(batch[row], minlength=n_remained).cumsum(0)])
    node_slice = torch.cat([zero, torch.bincount(batch, minlength=n_remained).cumsum(0)])
    # Edge indices should start at zero for every graph.
    graph_edge_index = edge_index - node_slice[batch[row]].unsqueeze(0)

    return node_attr, edge_slice, graph_edge_index, y, batch, node_slice

    





    
    




