import os, torch, logging, argparse
import models
import torch.nn as nn
from utils import train, test, val
from data import load_data
import pickle
import numpy as np
import gc

def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators with `seed`.

    Applies the same integer to NumPy's global RNG, PyTorch's default
    generator, and the generator of the current CUDA device.
    """
    for seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_rng(seed)
# Output directory for experiment results.
OUT_PATH = "results/"
# exist_ok=True replaces the separate isdir() check: it avoids the
# check-then-create race when several runs start at the same time.
os.makedirs(OUT_PATH, exist_ok=True)

# Command-line interface. Options are registered from a table, in the order
# in which they should appear in --help.
parser = argparse.ArgumentParser()
for option, settings in (
    # model and optimisation hyperparameters
    ('--data', dict(type=str, default='cora', help='{cora, pubmed, citeseer}.')),
    ('--model', dict(type=str, default='DeepGCN', help='{SGC, DeepGCN, DeepGAT, GIN}')),
    ('--hid', dict(type=int, default=32, help='Number of hidden units.')),
    ('--lr', dict(type=float, default=0.005, help='Initial learning rate.')),
    ('--nhead', dict(type=int, default=1, help='Number of head attentions.')),
    ('--dropout', dict(type=float, default=0.6, help='Dropout rate.')),
    ('--epochs', dict(type=int, default=200, help='Number of epochs to train.')),
    ('--log', dict(type=str, default='debug', help='{info, debug}')),
    ('--wd', dict(type=float, default=5e-4, help='Weight decay (L2 loss on parameters).')),
    ('--nlayer', dict(type=int, default=8, help='Number of layers, works for Deep model.')),
    ('--residual', dict(type=int, default=0, help='Residual connection')),
    ('--seed', dict(type=int, default=0)),
    ('--runs', dict(type=int, default=5)),
    ('--device', dict(type=int, default=0)),
    # for normalization
    ('--norm_mode', dict(type=str, default='Label', help='{None, Center, Label, Sign}')),
    ('--norm_scale', dict(type=float, default=0.2, help='Row-normalization scale')),
    ('--use_layer_norm', dict(action='store_true')),
    ('--initial_norm', dict(action='store_true')),
    # for data
    # NOTE: args.no_fea_norm is True unless the flag is given (store_false),
    # i.e. the attribute holds "normalize features?" despite its name.
    ('--no_fea_norm', dict(action='store_false', default=True, help='not normalize feature')),
    ('--missing_rate', dict(type=int, default=0, help='missing rate, from 0 to 100')),
    ('--train_set', dict(type=float, default=0., help='this is for the train set')),
    ('--neg_weight', dict(type=float, default=1.0)),
):
    parser.add_argument(option, **settings)
args = parser.parse_args()
print(args)
# Logging: the level name comes from --log; records are printed as the bare message.
logging.basicConfig(
    level=getattr(logging, args.log.upper()),
    format='%(message)s',
)
# Use the CUDA device selected by --device when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")

# load data
# Request CUDA placement only when a GPU is actually present, so the data
# follows the same CPU fallback as the model's `device`.
# NOTE(review): load_data's handling of `cuda`/`device` is not visible here;
# presumably it moves the tensors to that GPU -- confirm against data.py.
data = load_data(
    args.data,
    normalize_feature=args.no_fea_norm,
    missing_rate=args.missing_rate,
    cuda=torch.cuda.is_available(),
    device=args.device,
    train_set=args.train_set,
)
# Input width is the size of dimension 1 of data.x.
nfeat = data.x.size(1)
# Number of classes = largest label + 1 (assumes 0-based integer labels -- TODO confirm).
nclass = int(data.y.max()) + 1

# Train `args.runs` independent models (seed, seed+1, ...) and record, for
# each run, the test accuracy at the epoch with the best validation accuracy.
all_test_acc = []
for i in range(args.runs):
    set_seed(args.seed+i)
    print(f'seed: {args.seed+i}')
    # Reset the trackers for every run. Previously they were initialised once
    # before the loop, so a later run could only "improve" on an earlier run's
    # best and would otherwise report that earlier run's test accuracy.
    best_acc = 0
    best_loss = 1e10
    best_test_acc = 0
    # The model class is looked up by name in the `models` module.
    net = getattr(models, args.model)(args, nfeat, nclass)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), args.lr, weight_decay=args.wd)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(args.epochs):
        gc.collect()
        # Extra diagnostics are currently switched off for every epoch.
        cal_erank = False # if epoch == args.epochs - 1 else False
        cal_metrics = False # if epoch == args.epochs - 1 else False
        train_loss, train_acc, metrics = train(net, optimizer, criterion, data, cal_erank=cal_erank, cal_metrics=cal_metrics)
        val_loss, val_acc = val(net, criterion, data)
        test_loss, test_acc = test(net, criterion, data)
        # if epoch % 20 == 0:
        #     print(f'epoch:{epoch}, train loss: {train_loss:.4f},train acc: {train_acc:.4f}, val loss: {val_loss:.4f},val acc: {val_acc:.4f}, test loss: {test_loss:.4f},test acc: {test_acc:.4f}')
        # if epoch == args.epochs - 1 and cal_metrics:
        #     with open(f"results/{args.data}-{args.norm_mode}-{args.hid}-{args.norm_scale}-{args.nlayer}.pkl", 'wb') as f:
        #         pickle.dump(metrics, f)
        # Model selection: keep the test accuracy of the epoch whose
        # validation accuracy is the best seen so far in this run.
        if best_acc < val_acc:
            best_acc = val_acc.item()
            best_test_acc = test_acc.item()
            # torch.save(net.state_dict(), OUT_PATH+'checkpoint-best-acc.pkl')
        # The lowest validation loss is tracked for reporting only.
        if best_loss > val_loss:
            best_loss = val_loss.item()
            # torch.save(net.state_dict(), OUT_PATH+'checkpoint-best-loss.pkl')
    # '.4f' (four decimals) replaces the former '4f', which only set a minimum width.
    print(f'best_val_acc:{best_acc:.4f}, best_val_loss:{best_loss:.4f}, best_test_acc:{best_test_acc:.4f}')
    # pick up the best model based on val_acc, then do test
    # val_loss, val_acc = val(net, criterion, data)
    # print(f'val: {val_loss:.4f},{val_acc:.4f}')
    # test_loss, test_acc = test(net, criterion, data)
    # print(f'test: {test_loss:.4f},{test_acc:.4f}')
    all_test_acc.append(best_test_acc)

# Summarise the per-run test accuracies as "mean, std" (numpy is already
# imported at the top of the file, so it is not re-imported here).
# ndarray.std() uses its default ddof=0, i.e. the population standard deviation.
all_test_acc = np.array(all_test_acc)
print(f'{all_test_acc.mean():.4f}, {all_test_acc.std():.4f}')
