import random
import argparse
import os.path as osp
import numpy as np
import warnings
import pickle

import torch
from torch_scatter import scatter
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_scatter import scatter

from evaluate_embedding import evaluate_embedding
from positional_embedding import laplacian_eigenvector_pe, random_walk_pe
# Suppress every warning category for the rest of the process; this is global
# and also silences warnings raised inside imported libraries.
warnings.filterwarnings('ignore')


def load_dataset(name, root='data'):
    """Load the TUDataset ``name`` stored under ``root`` and return it shuffled.

    Args:
        name: TUDataset name (e.g. the value of ``--dataset``).
        root: Directory passed to ``TUDataset`` as its root; defaults to the
            previously hard-coded ``'data'``.

    Returns:
        A shuffled ``TUDataset``. The shuffle order presumably comes from
        torch's global RNG — call ``setup_seed`` first for a reproducible order.
    """
    return TUDataset(root, name=name).shuffle()


def propagate_k_hops(x, edge_index, K, add_self_loops=False, edge_weight=None):
    """Apply ``K`` rounds of sum-aggregation message passing to node features.

    In each round, node ``i`` receives ``sum_j w_ij * x[j]`` over every edge
    ``(i, j)`` in ``edge_index`` (row ``i`` gathers from column ``j``), with
    ``w_ij = 1`` when no ``edge_weight`` is given.

    Args:
        x: Node feature tensor whose first dimension indexes nodes.
        edge_index: ``(2, E)`` integer tensor of ``(row, col)`` index pairs.
        K: Number of propagation rounds; ``0`` returns ``x`` unchanged.
        add_self_loops: If true, add a self-loop (weight 1) to every node
            that does not already have one before propagating.
        edge_weight: Optional ``(E,)`` per-edge weights, e.g. the weights
            returned by ``gcn_norm``. ``None`` means unweighted.

    Returns:
        Tensor with one row per node of ``x``. Nodes that receive no messages
        get zeros instead of being dropped from the output.
    """
    num_nodes = x.size(0)
    if add_self_loops:
        # Pass num_nodes so isolated trailing nodes also receive a self-loop
        # (otherwise the node count is inferred from the largest edge index).
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, num_nodes=num_nodes)
    row, col = edge_index
    for _ in range(K):
        msg = x[col]
        if edge_weight is not None:
            # Broadcast one weight per edge across all feature dimensions.
            msg = msg * edge_weight.view(-1, *([1] * (msg.dim() - 1)))
        # Size the output by num_nodes explicitly: inferring it from
        # row.max() + 1 would silently drop trailing nodes without edges.
        out = msg.new_zeros((num_nodes, *msg.shape[1:]))
        x = out.index_add_(0, row, msg)
    return x


def setup_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` so cuDNN is asked to use
    deterministic algorithms.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


def _node_features(data):
    """Build the input node-feature matrix for one batch.

    Uses ``data.x`` when present, otherwise a single column of ones per node.
    Laplacian-eigenvector and/or random-walk positional encodings are appended
    when ``args.lap_k`` / ``args.walk_length`` are non-zero. Reads the
    module-level ``args`` and ``device`` set in the ``__main__`` block.
    """
    pe = None
    if args.lap_k != 0:
        pe = laplacian_eigenvector_pe(data, k=args.lap_k).to(device)
    if args.walk_length != 0:
        rw_pe = random_walk_pe(data, walk_length=args.walk_length).to(device)
        pe = rw_pe if pe is None else torch.cat([pe, rw_pe], dim=-1)
    if pe is None:
        return torch.ones(data.batch.size(0), 1).to(device) if data.x is None else data.x
    if data.x is None:
        return pe.to(device)
    # Match the encodings to the existing features' device/dtype before concatenating.
    return torch.cat([data.x, pe.to(data.x.device, data.x.dtype)], dim=-1)


def get_embeddings(loader):
    """Compute one mean-pooled, K-hop-propagated embedding per graph in ``loader``.

    For each batch: build node features (see ``_node_features``), propagate
    them ``args.K`` hops over the graph, then average node rows per graph via
    ``data.batch``. The stacked embeddings and labels are also pickled to
    ``saved_embed/prop_{batch_size}_{K}.pkl``.

    Reads the module-level ``args`` and ``device`` set in the ``__main__`` block.

    Returns:
        ``(embeddings, labels)`` as NumPy arrays concatenated along axis 0.
    """
    import os

    ret = []
    y = []
    print('Total sample batches in evaluation: ', len(loader))
    for index, data in enumerate(loader):
        data = data.to(device)
        x = _node_features(data)
        if index % 1000 == 0:
            print(index)
        # NOTE(review): gcn_norm returns (edge_index, edge_weight) but the weights
        # are discarded, so only its self-loop insertion takes effect and the
        # propagation below is an unweighted sum. With gcn_norm's default
        # add_self_loops=True every node already has a self-loop, which makes
        # --self_loop a no-op. Confirm whether the normalised weights were meant
        # to be passed through (propagate_k_hops can accept them).
        edge_index, _ = gcn_norm(edge_index=data.edge_index, num_nodes=x.size(0))
        x = propagate_k_hops(x, edge_index, args.K, add_self_loops=args.self_loop)
        x = scatter(x, data.batch, dim=0, reduce="mean")

        ret.append(x.cpu().numpy())
        y.append(data.y.cpu().numpy())
    ret = np.concatenate(ret, 0)
    y = np.concatenate(y, 0)
    out_path = osp.join('saved_embed', f'prop_{args.batch_size}_{args.K}.pkl')
    # Create the output directory up front; otherwise open() raises only after
    # all embeddings have already been computed.
    os.makedirs(osp.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump((ret, y), f)
    return ret, y
    

def arg_parse(argv=None):
    """Parse the command-line options for this embedding/evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        ``argparse.Namespace`` with ``dataset``, ``gpu``, ``K``, ``seed``,
        ``self_loop``, ``batch_size``, ``lap_k`` and ``walk_length``.
    """
    parser = argparse.ArgumentParser(description='GcnInformax Arguments.')
    # Required: load_dataset cannot work without a dataset name, so fail fast here.
    parser.add_argument('--dataset', required=True, help='Dataset')
    parser.add_argument('--gpu', type=int, default=0,
                        help='CUDA device index (CPU is used if CUDA is unavailable)')
    parser.add_argument('--K', type=int, default=1, help='number of propagation hops')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--self_loop', action='store_true',
                        help='add remaining self-loops before propagation')
    parser.add_argument('--batch_size', type=int, default=128, help='graphs per batch')

    parser.add_argument('--lap_k', type=int, default=0,
                        help='k for Laplacian eigenvector PE (0 disables it)')
    parser.add_argument('--walk_length', type=int, default=0,
                        help='walk length for random-walk PE (0 disables it)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # `args` and `device` must stay module-level globals: get_embeddings reads them.
    args = arg_parse()
    setup_seed(args.seed)  # seed before loading, since load_dataset shuffles
    print(args)

    # Evaluate on at most the first 10000 graphs of the shuffled dataset.
    dataset = load_dataset(args.dataset)[:10000]
    print(len(dataset))

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    emb, y = get_embeddings(dataloader)
    num_graphs, num_features = emb.shape[0], emb.shape[1]
    print(f'num_features: {num_features}, num_graphs: {num_graphs}')

    # Repeat the evaluation five times on the same embeddings and report the
    # spread of the per-run mean accuracies (the per-run std is unused).
    accs = []
    for _ in range(5):
        acc_mean, _acc_std = evaluate_embedding(emb, y)
        accs.append(acc_mean)
        print(acc_mean)
    print(f'acc: {np.mean(accs):.4f} +- {np.std(accs):.4f}')
    print()

