import os
import pdb
import time
import torch
import random
import numpy as np

import networkx as nx
import metispy as metis

import torch_geometric
from torch_geometric.data import Data
import torch_geometric.datasets as datasets
import torch_geometric.transforms as T
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from ogb.nodeproppred import PygNodePropPredDataset

from utils import LargestConnectedComponents

import sys
sys.path.insert(1, os.path.join(sys.path[0], '../..'))
from misc.utils import *

# --- Reproducibility --------------------------------------------------------
# One seed shared by every RNG this script touches.
seed = 1234

# --- Output naming ----------------------------------------------------------
# `mode` is appended verbatim to the dataset directory name when saving.
mode = ''
# `dist` is spliced into the dataset name by generate_data().
dist = 'overlap' # default for overlapping nodes
# Root directory handed to the dataset loaders and to torch_save().
data_path = '../../../data'

# --- Partitioning -----------------------------------------------------------
# Community counts to generate; the driver loop at the bottom iterates these.
comms = [4]
# Number of clients sampled out of each community.
n_clien_per_comm = 5
# True: extract client subgraphs through a dense adjacency matrix;
# False: scan the edge list instead.
to_dense = True

# Seed the three independent generators (torch, NumPy, stdlib random).
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

def generate_data(dataset, n_comms):
    """Run the whole pipeline for one dataset: load, split, partition, save.

    Args:
        dataset: Dataset name understood by ``get_data``; it is expected to
            contain the textual form of the module-level ``ratio_train``.
        n_comms: Number of communities the graph is partitioned into.

    The name used for the output directories is derived from ``dataset`` by
    replacing every occurrence of ``ratio_train`` with
    ``'{dist}-{n_clien_per_comm}_{ratio_train}'``.
    """
    started_at = time.time()

    graph = get_data(dataset)

    # Tag the output name with the distribution type and clients-per-community.
    ratio_tag = f'{ratio_train}'
    tagged_name = dataset.replace(ratio_tag, f'{dist}-{n_clien_per_comm}_{ratio_train}')

    graph = split_train(graph, tagged_name, n_comms)
    split_joint(n_comms, graph, tagged_name)

    print(f'done ({time.time()-started_at:.2f})')

def _attach_empty_masks(data):
    """Give ``data`` all-False boolean train/val/test masks, one entry per node."""
    for mask_name in ('train_mask', 'val_mask', 'test_mask'):
        setattr(data, mask_name, torch.zeros(data.num_nodes, dtype=torch.bool))


def get_data(dataset):
    """Load the graph identified by ``dataset`` and return its first data object.

    Supported names (``<r>`` is the module-level ``ratio_train``):

    * ``'Cora'``, ``'CiteSeer'``, ``'PubMed'``: Planetoid with feature
      normalisation.
    * ``'<Cora|CiteSeer|PubMed>_CC_total_<r>'``: Planetoid with
      ``LargestConnectedComponents`` applied before feature normalisation.
    * ``'<Computers|Photo>_CC_total_<r>'``: Amazon, same transforms; empty
      train/val/test masks are attached.
    * ``'ogbn-arxiv_CC_total_<r>'``: OGB, made undirected and passed through
      ``LargestConnectedComponents``; empty masks attached, ``y`` flattened.
    * ``'ogbn-proteins_CC_total_<r>'``: OGB, converted to a sparse adjacency
      carrying ``edge_attr``; ``x`` is set to the mean over ``dim=1`` of that
      adjacency and the adjacency values are then dropped; empty masks
      attached.

    Args:
        dataset: One of the names listed above.

    Returns:
        The first element of the loaded dataset.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
            Previously this fell through every branch and failed with an
            ``UnboundLocalError`` at the ``return``.
    """
    st = time.time()

    if dataset in ['Cora', 'CiteSeer', 'PubMed']:
        data = datasets.Planetoid(data_path, dataset, transform=T.NormalizeFeatures())
        data = data[0]
    elif dataset in [f'Cora_CC_total_{ratio_train}', f'CiteSeer_CC_total_{ratio_train}', f'PubMed_CC_total_{ratio_train}']:
        # Strip the suffix to recover the name the loader expects.
        dataset = dataset.replace(f'_CC_total_{ratio_train}', '')
        data = datasets.Planetoid(data_path, dataset, transform=T.Compose([LargestConnectedComponents(), T.NormalizeFeatures()]))
        data = data[0]
    elif dataset in [f'Computers_CC_total_{ratio_train}', f'Photo_CC_total_{ratio_train}']:
        dataset = dataset.replace(f'_CC_total_{ratio_train}', '')
        data = datasets.Amazon(data_path, dataset, transform=T.Compose([LargestConnectedComponents(), T.NormalizeFeatures()]))
        data = data[0]
        # Masks are created here so split_train() can fill them in place.
        _attach_empty_masks(data)
    elif dataset in [f'ogbn-arxiv_CC_total_{ratio_train}']:
        dataset = dataset.replace(f'_CC_total_{ratio_train}', '')
        data = PygNodePropPredDataset(dataset, root=data_path, transform=T.Compose([T.ToUndirected(), LargestConnectedComponents()]))
        data = data[0]
        _attach_empty_masks(data)
        data.y = data.y.view(-1)
    elif dataset in [f'ogbn-proteins_CC_total_{ratio_train}']:
        dataset = dataset.replace(f'_CC_total_{ratio_train}', '')
        data = PygNodePropPredDataset(dataset, root=data_path, transform=T.Compose([T.ToSparseTensor(attr='edge_attr', remove_edge_index=False)]))
        data = data[0]
        # Derive node features from the edge attributes, then drop the values.
        data.x = data.adj_t.mean(dim=1)
        data.adj_t.set_value_(None)
        _attach_empty_masks(data)
    else:
        raise ValueError(f'Unsupported dataset: {dataset!r}')

    print(f'{dataset} have been loaded ({time.time()-st:.2f} sec)')
    return data

def split_train(data, dataset, n_comms):
    """Randomly assign every node to train/test/val and save the full graph.

    ``ratio_train`` of the nodes become training nodes, half of the remainder
    become test nodes, and whatever is left becomes validation nodes. The
    masks on ``data`` are overwritten in place.

    Args:
        data: Graph object carrying ``train_mask``, ``val_mask`` and
            ``test_mask`` tensors.
        dataset: Name used as the output directory prefix.
        n_comms: Number of communities; with ``n_clien_per_comm`` it
            determines the client-count sub-directory.

    Returns:
        The same ``data`` object with updated masks.
    """
    started_at = time.time()

    total_nodes = data.num_nodes
    n_train = round(total_nodes * ratio_train)
    n_test = round(total_nodes * ((1-ratio_train)/2))
    test_end = n_train + n_test

    # A single permutation is sliced into three consecutive, disjoint ranges.
    shuffled = torch.randperm(total_nodes)
    train_indices = shuffled[:n_train]
    test_indices = shuffled[n_train:test_end]
    val_indices = shuffled[test_end:]

    # Clear all masks first, then mark the chosen nodes.
    for mask in (data.train_mask, data.val_mask, data.test_mask):
        mask.fill_(False)
    for mask, chosen in (
        (data.train_mask, train_indices),
        (data.val_mask, val_indices),
        (data.test_mask, test_indices),
    ):
        mask[chosen] = True

    # The same full graph is written under all three split names.
    for split_name in ('train', 'test', 'val'):
        torch_save(data_path, f'{dataset}{mode}/{n_comms*n_clien_per_comm}/{split_name}.pt', {
            'data': data
        })

    print(f'splition done, n_train:{n_train}, n_test:{n_test}, n_val:{len(val_indices)} ({time.time()-started_at:.2f} sec)')
    return data

def _induced_edge_index(data, adj, client_indices):
    """Return the ``(2, E)`` edge index of the subgraph induced by ``client_indices``.

    Node ids in the result are positions within ``client_indices``. When
    ``adj`` is given the edges are read from that dense adjacency matrix;
    otherwise ``data.edge_index`` is scanned once.
    """
    if adj is not None:
        node_idx = torch.tensor(client_indices, dtype=torch.long)
        edge_index, _ = dense_to_sparse(adj[node_idx][:, node_idx])
        return edge_index.contiguous()

    # Global id -> local position; O(1) lookups instead of list `in`/`.index`.
    local_of = {node: pos for pos, node in enumerate(client_indices)}
    src, dst = data.edge_index.tolist()
    pairs = [
        [local_of[u], local_of[v]]
        for u, v in zip(src, dst)
        if u in local_of and v in local_of
    ]
    # The explicit reshape keeps a (2, 0) shape when the client has no edges.
    return torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2).t().contiguous()


def split_joint(n_comms, data, dataset):
    """Partition the graph into communities and save overlapping client subgraphs.

    The graph is split into ``n_comms`` communities with METIS (or treated as
    one community when ``n_comms == 1``). For every community,
    ``n_clien_per_comm`` clients are created; each client receives a random
    half of the community's nodes (sampled independently per client, so
    clients of one community may share nodes) plus the edges among them.
    Each client is written with ``torch_save`` as
    ``overlapping_partition_<global id>.pt``.

    Args:
        n_comms: Number of communities.
        data: Graph with ``x``, ``y``, ``edge_index`` and the three masks.
        dataset: Name used as the output directory prefix.

    Raises:
        AssertionError: If the partitioner yields a different number of
            communities than requested, or a client has no training node.
    """
    st = time.time()

    if n_comms == 1:
        n_cuts, membership = 0, [0] * data.num_nodes
    else:
        G = torch_geometric.utils.to_networkx(data)
        n_cuts, membership = metis.part_graph(G, n_comms)
    n_partitions = len(set(membership))
    assert n_partitions == n_comms, (
        f'expected {n_comms} partitions, got {n_partitions}'
    )
    print(f'graph partition done, metis, n_partitions: {n_partitions}, n_lost_edges:{n_cuts} ({time.time()-st:.2f})')

    membership = np.asarray(membership)

    adj = None
    if to_dense:
        # max_num_nodes keeps the matrix num_nodes x num_nodes even when the
        # highest-numbered nodes have no edges.
        adj = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]

    for comm_id in range(n_comms):
        # Original community (the same for every client of this community).
        comm_indices = np.where(membership == comm_id)[0].tolist()

        for client_id in range(n_clien_per_comm):
            global_id = comm_id*n_clien_per_comm+client_id

            # Sampling: each client keeps a random half of the community.
            client_indices = random.sample(comm_indices, len(comm_indices) // 2)
            client_num_nodes = len(client_indices)

            client_edge_index = _induced_edge_index(data, adj, client_indices)
            client_num_edges = client_edge_index.size(1)

            node_idx = torch.tensor(client_indices, dtype=torch.long)
            client_train_mask = data.train_mask[node_idx]

            client_data = Data(
                x = data.x[node_idx],
                y = data.y[node_idx],
                edge_index = client_edge_index,
                train_mask = client_train_mask,
                val_mask = data.val_mask[node_idx],
                test_mask = data.test_mask[node_idx]
            )

            assert torch.sum(client_train_mask).item() > 0, (
                f'client {global_id} has no training nodes'
            )

            # NOTE(review): the stored 'client_id' is the index within the
            # community, while the file name and the log line use the global
            # id. Kept as-is for compatibility; confirm what readers expect.
            torch_save(data_path,f'{dataset}{mode}/{n_comms*n_clien_per_comm}/overlapping_partition_{global_id}.pt', {
                'client_data': client_data,
                'client_id': client_id
            })
            print(f'client_id:{global_id}, iid, n_train_node:{client_num_nodes}, n_train_edge:{client_num_edges} ({time.time()-st:.2f})')
            st = time.time()


# Fraction of nodes used for training; read as a module-level global by
# get_data(), split_train() and generate_data().
ratio_train = 0.2

# Generate one dataset per configured community count.
for n_comms in comms:
    generate_data(
        n_comms=n_comms,
        dataset=f'Cora_CC_total_{ratio_train}',
    )
