import sys
import os
sys.path.append(os.getcwd())
import torch
import torch.nn.functional as F
import numpy as np
import process
import utils_gcnii
from torch_geometric.utils import sparse as sparseConvert
from src import wgnn_network as GN
from src import wgnn_graphops as GO
# Choose the compute device first so the GPU banner is printed only when a GPU
# actually exists; querying 'cuda:0' unconditionally raises on CPU-only hosts
# even though the script is otherwise written to fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device != 'cpu':
    print(torch.cuda.get_device_properties(device))

import argparse

# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="wgnn_fully_supervised")
parser.add_argument(
    "--dataset",
    default='cora',
    type=str,
    help='dataset name',
)

parser.add_argument(
    "--omega",
    default=1,
    type=int,
    help='1 if use omegaGCN, 0 otherwise',
)

parser.add_argument(
    "--attspat",
    default=1,
    type=int,
    help='1 if use attention for spatial operation, 0 otherwise',
)

parser.add_argument(
    "--attHeads",
    default=1,
    type=int,
    help='number of attention heads',
)

# NOTE(review): --numOmega is parsed but not read anywhere in the visible code;
# `nomega` below is fixed at 1. Wire it up only after confirming how GN.wgnn
# interprets `num_omega` / `omega_perchannel` (both currently receive nomega).
parser.add_argument(
    "--numOmega",
    default=1,
    type=int,
    help='number of omega to learn',
)

# ---------------------------------------------------------------------------
# Fixed architecture / training hyper-parameters
# ---------------------------------------------------------------------------
nlayers = 2
nomega = 1
nheads = 1  # NOTE(review): not referenced in the visible code; heads come from --attHeads
ncheckpoints = 1
args = parser.parse_args()
n_channels = 64
nopen = n_channels
nhid = n_channels
nNclose = n_channels
nlayer = nlayers
datastr = args.dataset
dropout = 0.5

# Learning rates and weight decays. `lr` is the optimizer-wide default, used by
# parameter groups that do not set their own `lr`.
lr = 0.001
attLR = 0.001
lrGCN = 0.001
lrOmega = 0.001
wd = 1e-4
wdGCN = 1e-4
wdOmega = 1e-4
attWD = 1e-4


def train_step(model, optimizer, features, labels, adj, idx_train):
    """Run a single optimisation step and report training loss/accuracy.

    The graph is rebuilt from ``adj`` on every call, with a unit weight on each
    edge.

    Args:
        model: network called as ``model(x, G, omega=..., attention=...,
            checkpoints=...)``; the first element of its returned pair is used
            as the per-node output.
        optimizer: optimiser stepped once after back-propagation.
        features: node features, passed to the model unchanged.
        labels: per-node class labels; ``labels.shape[0]`` is used as the
            node count.
        adj: edge index; row 0 and row 1 give the two endpoint arrays and
            ``adj.shape[1]`` the number of edges (presumably (2, E) as produced
            by ``dense_to_sparse`` in ``train``).
        idx_train: indices of the nodes that contribute to loss and accuracy.

    Returns:
        ``(loss, accuracy)`` on ``idx_train`` as Python scalars.
    """
    model.train()
    optimizer.zero_grad()

    # Rebuild the graph with all-ones edge weights.
    src_nodes, dst_nodes = adj[0, :], adj[1, :]
    node_count = labels.shape[0]
    unit_weights = torch.ones(adj.shape[1]).to(device)
    graph = GO.graph(src_nodes, dst_nodes, node_count,
                     W=unit_weights, pos=None, faces=None).to(device)

    out, _ = model(features, graph, omega=args.omega,
                   attention=args.attspat, checkpoints=ncheckpoints)

    train_pred = out[idx_train]
    train_target = labels[idx_train].to(device)
    acc_train = utils_gcnii.accuracy(train_pred, train_target)
    loss_train = F.nll_loss(train_pred, train_target)

    loss_train.backward()
    optimizer.step()
    return loss_train.item(), acc_train.item()


def eval_test_step(model, features, labels, adj, idx_test):
    """Evaluate the model (no gradients) on the nodes in ``idx_test``.

    Works for any index set; ``train`` calls it for evaluation. The graph is
    rebuilt from ``adj`` with unit edge weights, exactly as in training.

    Args:
        model: network called as ``model(x, G, omega=..., attention=...,
            checkpoints=...)``; the first element of its returned pair is used
            as the per-node output.
        features: node features, passed to the model unchanged.
        labels: per-node class labels; ``labels.shape[0]`` is used as the
            node count.
        adj: edge index; row 0 / row 1 are the endpoint arrays and
            ``adj.shape[1]`` the number of edges.
        idx_test: indices of the nodes to score.

    Returns:
        ``(loss, accuracy)`` on ``idx_test`` as Python scalars.
    """
    model.eval()
    with torch.no_grad():
        node_count = labels.shape[0]
        unit_weights = torch.ones(adj.shape[1]).to(device)
        graph = GO.graph(adj[0, :], adj[1, :], node_count,
                         W=unit_weights, pos=None, faces=None).to(device)

        out, _ = model(features, graph, omega=args.omega,
                       attention=args.attspat, checkpoints=ncheckpoints)

        eval_pred = out[idx_test]
        eval_target = labels[idx_test].to(device)
        loss_test = F.nll_loss(eval_pred, eval_target)
        acc_test = utils_gcnii.accuracy(eval_pred, eval_target)
        return loss_test.item(), acc_test.item()


def train(datastr, splitstr, num_output, *, max_epochs=10000, patience=1000):
    """Train a fresh WGNN on one data split and return its test accuracy in %.

    Model selection is done on the validation split. Each time validation
    accuracy improves, the test split is scored, and the test accuracy from the
    best-validation epoch is returned. The previous implementation selected on
    test accuracy directly, which leaks the test set into model selection.
    Training stops early after ``patience`` consecutive epochs without a
    validation improvement.

    Args:
        datastr: dataset name forwarded to ``process.full_load_data``.
        splitstr: path of the split file forwarded to ``process.full_load_data``.
        num_output: number of output channels (classes) of the model.
        max_epochs: hard upper bound on training epochs.
        patience: epochs without validation improvement before stopping.

    Returns:
        Test accuracy (as returned by ``eval_test_step``) multiplied by 100.
    """
    slurm = False
    (adj, features, labels, idx_train, idx_val, idx_test,
     num_features, _num_labels) = process.full_load_data(datastr, splitstr, slurm=slurm)

    # Turn the adjacency into an edge list. The edge weights are not used:
    # train_step / eval_test_step rebuild the graph with unit weights.
    adj = adj.to_dense()
    edge_index, _edge_weight = sparseConvert.dense_to_sparse(adj)
    del adj  # free the dense matrix before building the model
    edge_index = edge_index.to(device)
    # Transpose and add a leading dimension: (N, F) -> (1, F, N) for 2-D features.
    features = features.to(device).t().unsqueeze(0)
    idx_train = idx_train.to(device)
    idx_val = idx_val.to(device)
    idx_test = idx_test.to(device)
    labels = labels.to(device)

    numAttHeads = args.attHeads if args.attspat else 1
    model = GN.wgnn(num_features, nopen, nhid, nlayer,
                    num_output=num_output,
                    dropOut=dropout,
                    numAttHeads=numAttHeads, num_omega=nomega, omega_perchannel=nomega)
    model.reset_parameters()
    model = model.to(device)
    # Per-group learning rates / weight decays; groups without an explicit
    # `lr` use the optimizer-wide default `lr`.
    optimizer = torch.optim.Adam([
        dict(params=model.KN1, lr=lrGCN, weight_decay=wdGCN),
        dict(params=model.K1Nopen, weight_decay=wd),
        dict(params=model.KNclose, weight_decay=wd),
        dict(params=model.att_src, lr=attLR, weight_decay=attWD),
        dict(params=model.att_dst, lr=attLR, weight_decay=attWD),
        dict(params=model.omega, lr=lrOmega, weight_decay=wdOmega),
    ], lr=lr)

    best_val_acc = float('-inf')
    best_test_acc = 0
    bad_counter = 0
    for _epoch in range(max_epochs):
        train_step(model, optimizer, features, labels, edge_index, idx_train)
        _loss_val, acc_val = eval_test_step(model, features, labels, edge_index, idx_val)

        if acc_val > best_val_acc:
            best_val_acc = acc_val
            # Score the test split only at epochs chosen by validation.
            _loss_test, best_test_acc = eval_test_step(
                model, features, labels, edge_index, idx_test)
            bad_counter = 0
        else:
            bad_counter += 1

        if bad_counter >= patience:
            break

    return best_test_acc * 100


# Number of output classes per dataset; any dataset not listed uses 5.
# Computed once, since it does not depend on the split index.
NUM_CLASSES_BY_DATASET = {"cora": 7, "citeseer": 6, "pubmed": 3, "chameleon": 5}
num_output = NUM_CLASSES_BY_DATASET.get(datastr, 5)

# Train once per split (10 splits, 60/20/20) and report mean / std test accuracy.
acc_list = []
for i in range(10):
    splitstr = f'../splits/{datastr}_split_0.6_0.2_{i}.npz'
    acc_list.append(train(datastr, splitstr, num_output))
    print(i, ": {:.2f}".format(acc_list[-1]))

mean_test_acc = np.mean(acc_list)
std_test_acc = np.std(acc_list)
print("Test acc.:{:.3f}, std:{:.3f}".format(mean_test_acc, std_test_acc), flush=True)
