import argparse
import time
from sklearn.model_selection import StratifiedShuffleSplit
import torch
import json
import numpy as np
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from torch_geometric.loader import DataLoader
import hashlib

import utils
import model
from config import *
import lossfunc

# ---------------------------------------------------------------------------
# Command-line configuration
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--data', default=MCF7, help='Dataset used')
parser.add_argument('--lr', type=float, default=5e-3, help='Learning rate')
parser.add_argument('--batchsize', type=int, default=512, help='Training batch size')
parser.add_argument('--nepoch', type=int, default=100, help='Number of training epochs')
parser.add_argument('--hdim', type=int, default=128, help='Hidden feature dim')
parser.add_argument('--width', type=int, default=1, help='Width of GCN')
parser.add_argument('--depth', type=int, default=4, help='Depth of GCN')
parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
parser.add_argument('--normalize', type=int, default=1, help='Whether batch normalize')
parser.add_argument('--decay', type=float, default=0, help='Weight decay')
parser.add_argument('--seed', type=int, default=2025, help='Random seed')
parser.add_argument('--device', type=int, default=0, help='CUDA device index; a negative value forces CPU')
parser.add_argument('--patience', type=int, default=50, help='Patience')
args = parser.parse_args()

# --seed is used as a modulus below: 0 would divide by zero and a negative
# value would produce a negative RNG seed, so reject both up front.
if args.seed <= 0:
    parser.error('--seed must be a positive integer')

data = args.data
lr = args.lr
batchsize = args.batchsize
nepoch = args.nepoch
hdim = args.hdim
width = args.width
depth = args.depth
dropout = args.dropout
normalize = args.normalize
decay = args.decay
seed = args.seed
# Index 0 is a valid CUDA device (the first GPU), so only negative indices
# select the CPU; CPU is also used whenever CUDA is unavailable.
device = torch.device("cuda:" + str(args.device)) if (torch.cuda.is_available() and args.device >= 0) else torch.device("cpu")

patience = args.patience

# Derive a reproducible, dataset-specific seed in [0, seed) from the SHA-1
# hash of the dataset name, so different datasets get different RNG streams.
seed = int(hashlib.sha1(data.encode("utf-8")).hexdigest(), 16) % seed
utils.set_seed(seed)
print(seed)

print("Model info:")
print(json.dumps(args.__dict__, indent='\t'))

# Load the graph dataset together with its split indices.
graphs, info, train_index, val_index, test_index = utils.load_data(data)

trainset = graphs[train_index]
valset = graphs[val_index]
testset = graphs[test_index]

label_count = info["label_count"]
featuredim = graphs.x.shape[1]
# Number of edge types (edge types presumably are 0-based integer ids — confirm
# in utils.load_data). Use a tensor reduction rather than builtin max(), which
# would iterate every edge element in Python, and hand the model a plain int.
netype = int(graphs.edge_type.max()) + 1
nclass = len(label_count)

gad = model.GADGNN(featuredim, hdim, nclass, netype, width, depth, dropout, normalize, device).to(device)
optimizer = optim.Adam(gad.parameters(), lr=lr, weight_decay=decay)

# Only the training loader shuffles; evaluation order is kept fixed.
train_loader = DataLoader(trainset, batch_size=batchsize, shuffle=True)
val_loader = DataLoader(valset, batch_size=batchsize, shuffle=False)
test_loader = DataLoader(testset, batch_size=batchsize, shuffle=False)

# Model-selection state, updated by the training loop.
bestperformance = 0
bestepoch = 0
bestmodel = deepcopy(gad)

print("Starts training...")

patiencecount = 0
# Loss is built from the per-class label counts (presumably for class
# re-weighting — see lossfunc.RFCELoss).
criterion = lossfunc.RFCELoss(torch.LongTensor(label_count))
for epoch in range(nepoch):
    epoch_start = time.time()
    gad.train()
    epoch_loss = 0.0

    for train_batch in train_loader:
        optimizer.zero_grad()
        outputs = gad(train_batch.to(device))

        loss = criterion(outputs, train_batch.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_end = time.time()
    # Average over the batches actually produced (including a trailing partial
    # batch); guard against an empty loader to avoid dividing by zero.
    nbatches = max(len(train_loader), 1)
    print('Epoch: {}, loss: {}, time cost: {}'.format(epoch, epoch_loss / nbatches, epoch_end - epoch_start))

    # ---- Validation: inference only, so no autograd graph is recorded. ----
    gad.eval()
    preds = []
    truths = []
    with torch.no_grad():
        for val_batch in val_loader:
            outputs = gad(val_batch.to(device))
            preds.append(nn.functional.softmax(outputs, dim=1))
            truths.append(val_batch.y)
    preds = torch.cat(preds, dim=0)
    truths = torch.cat(truths, dim=0)

    AUROC, AUPRC, RECK, MF1 = utils.compute_metrics(preds, truths)
    print("Val auc: {}, prc: {}, reck: {}, f1: {}".format(AUROC, AUPRC, RECK, MF1))

    # Model selection uses the unweighted sum of the four validation metrics;
    # `<=` means ties favour the later epoch.
    score = AUROC + AUPRC + RECK + MF1
    if bestperformance <= score:
        bestperformance = score
        bestepoch = epoch
        bestmodel = deepcopy(gad)
        patiencecount = 0
    else:
        patiencecount += 1

    # Early stopping after `patience` consecutive epochs without improvement.
    if patiencecount >= patience:
        break

print("Best epoch: {}".format(bestepoch))

# The initial `bestmodel` is copied from `gad` before gad.eval() is ever
# called, i.e. in training mode; force eval mode so dropout / batch-norm use
# inference behaviour regardless of whether any epoch updated it.
bestmodel.eval()
preds = []
truths = []
# Inference only: disable autograd so no graph is kept for the test set.
with torch.no_grad():
    for test_batch in test_loader:
        outputs = bestmodel(test_batch.to(device))
        preds.append(nn.functional.softmax(outputs, dim=1))
        truths.append(test_batch.y)
preds = torch.cat(preds, dim=0)
truths = torch.cat(truths, dim=0)
AUROC, AUPRC, RECK, MF1 = utils.compute_metrics(preds, truths)
print("Test auc: {}, prc: {}, reck: {}, f1: {}".format(AUROC, AUPRC, RECK, MF1))