import os.path as osp
import argparse
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_sparse import spspmm, coalesce
from torch_geometric.utils import add_remaining_self_loops, remove_self_loops

from meta_model import  Meta


def _build_parser():
    """Create the command-line parser for this training script.

    Every option takes a single value with a fixed type and default; the
    default dataset is 'Cora' and the default device string is 'cuda:3'.
    """
    cli = argparse.ArgumentParser()
    # (flag, type converter, default) for each option, in registration order.
    option_specs = (
        ('--device', str, 'cuda:3'),
        ('--dropout', float, 0.6),
        ('--hidden', int, 64),
        ('--lr', float, 0.01),
        ('--epochs', int, 400),
        ('--sample_rate', float, 0.1),
        ('--dataset', str, 'Cora'),
        ('--edge_tau', float, 0.3),
    )
    for flag, converter, default_value in option_specs:
        cli.add_argument(flag, type=converter, default=default_value)
    return cli


parser = _build_parser()
args = parser.parse_args()
print(args)



# Load the requested Planetoid dataset from ./data/<name> next to this script.
# The CLI name is kept in its own variable so that `dataset` refers only to
# the Dataset object (later code reads dataset.num_features / num_classes).
dataset_name = args.dataset
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]

def getTwoHop(edge_index, num_nodes):
    """Return the union of ``edge_index`` with its two-hop connectivity.

    The adjacency pattern is multiplied by itself with ``spspmm`` (every
    edge weighted 1). Self loops produced by the product are dropped. The
    remaining pairs are concatenated with the original edges and coalesced,
    so each edge appears once.

    Args:
        edge_index: COO edge index; the edge count is read from dim 1, so
            presumably shape ``[2, E]`` (PyG convention).
        num_nodes: number of nodes in the graph.

    Returns:
        The coalesced edge index. Edge values are not kept.
    """
    ones = edge_index.new_ones((edge_index.size(1), ), dtype=torch.float)
    # A @ A: pair (i, j) is present iff j is reachable from i in two hops.
    two_hop_index, _ = spspmm(edge_index, ones, edge_index, ones,
                              num_nodes, num_nodes, num_nodes)
    # Only the structure is needed, so the product's values are discarded
    # rather than zeroed and carried along.
    two_hop_index, _ = remove_self_loops(two_hop_index)
    # Self loops already present in the input edge_index survive this merge.
    merged = torch.cat([edge_index, two_hop_index], dim=1)
    edge_index, _ = coalesce(merged, None, num_nodes, num_nodes)
    return edge_index

# Give every node a self loop (only where one is missing), then store the
# two-hop index on `data` as `two_hop_index` (presumably read by Meta).
data.edge_index = add_remaining_self_loops(
    data.edge_index, num_nodes=data.num_nodes)[0]
data.two_hop_index = getTwoHop(data.edge_index, data.num_nodes)


# Layer specs passed to Meta; each entry is (op name, op argument list).
# The builders return brand-new lists on every call, so no two configs
# share list objects.
def _input_layer():
    """Dropout followed by a GCN layer with args [num_features, hidden]."""
    return [
        ('dropout', [args.dropout]),
        ('gcn_conv', [dataset.num_features, args.hidden]),
    ]


def _classifier_head():
    """ReLU, dropout, GCN with args [hidden, num_classes], log-softmax."""
    return [
        ('relu', [True]),
        ('dropout', [args.dropout]),
        ('gcn_conv', [args.hidden, dataset.num_classes]),
        ('log_softmax', [True]),
    ]


label_gene_ext_config = _input_layer()
label_gene_cla_config = _classifier_head()
extractor_config = _input_layer() + [
    ('relu', [True]),
    ('dropout', [args.dropout]),
    ('gcn_conv', [args.hidden, args.hidden]),
]
classifier_config = _classifier_head()

meta_model = Meta(label_gene_ext_config, label_gene_cla_config,
                  extractor_config, classifier_config, args)

# Use the --device string only when CUDA is available; otherwise fall back
# to the CPU. Both the model and the graph data are moved there.
if torch.cuda.is_available():
    device = torch.device(args.device)
else:
    device = torch.device('cpu')
meta_model = meta_model.to(device)
data = data.to(device)


def train():
    """Run one training pass of the meta-model on the global ``data``.

    Puts the model in training mode and calls it on ``data``. Any parameter
    update presumably happens inside Meta's forward (not visible here).
    """
    model = meta_model
    model.train()
    model(data)
    


def test():
    """Evaluate the meta-model on the global ``data`` in eval mode.

    Returns whatever ``meta_model.model_eval(data)`` returns. The training
    loop unpacks it as (train_acc, val_acc, test_acc).
    """
    meta_model.eval()
    return meta_model.model_eval(data)


# Model selection: keep the test accuracy from the epoch with the highest
# validation accuracy seen so far. Note that the "Val" column printed each
# epoch is that best-so-far value, not the current epoch's.
train_acc = test_acc = best_val_acc = 0
log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'

for epoch in range(1, args.epochs + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc, test_acc = val_acc, tmp_test_acc
    print(log.format(epoch, train_acc, best_val_acc, test_acc))


# Append the selected test accuracy (taken at the best-validation epoch) to a
# per-configuration text file, opened relative to the current working
# directory.
model_name = meta_model.model_name
layers = 3  # hard-coded; used only as a label in the log file name
file_name = '{model}-{layers:d}L {tau:.1f}tau {sample_rate:.1f}sr {dataset}'.format(
    model=model_name,
    layers=layers,
    tau=args.edge_tau,
    sample_rate=args.sample_rate,
    dataset=args.dataset,
)

# The context manager closes the handle even if the write raises.
with open(file_name + ' testAcc.txt', 'a', encoding='utf-8') as test_acc_log:
    test_acc_log.write(str(test_acc) + '\n')


