import os
import os.path as osp

import numpy as np
import torch
from numpy import loadtxt
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

from model.DOHSC import pretrain, train_DO2HSC
from model.model import Graph_Representation_Learning
from Utils.arguments import arg_parse
# Star import stays last so any names it re-exports shadow exactly as before.
from Utils.split_data import *

def load_index_file(path):
    """Read a text file of sample indices and return them as a list of ints.

    Args:
        path: Location of a file readable by ``numpy.loadtxt``.

    Returns:
        A list of ints, also when the file holds only a single value.
    """
    # loadtxt collapses a one-value file to a 0-d array whose ``tolist()`` is
    # a bare number; atleast_1d keeps the result a list so that indexing the
    # dataset with it still selects a subset rather than one graph.
    return np.atleast_1d(loadtxt(path)).astype(dtype=int).tolist()


def main():
    """Train and evaluate DO2HSC on a TUDataset for ``args.repNum`` repetitions.

    Reads the hyper-parameters from ``arg_parse()``, loads (or creates via
    ``split_data``) the train/test index split for ``args.train_class``, runs
    ``train_DO2HSC`` once per repetition and prints the mean and standard
    deviation of the values it returns.
    """
    args = arg_parse()
    batch_size = 128
    lr = args.lr
    DS = args.DS
    nu = args.nu
    repNum = args.repNum
    percentage = args.percentage
    hidden_dim = args.hidden_dim
    num_gc_layers = args.num_gc_layers
    latent_dim = args.latent_dim
    epochs = args.epochs
    lam = args.lam
    lr_milestones = args.lr_milestones
    train_class = args.train_class

    # The dataset itself is stored next to this script.
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', DS)
    dataset = TUDataset(path, name=DS)

    # Convert the label tensor once instead of once per class.
    labels = dataset.data.y.tolist()
    for label in (0, 1, 2):
        print('class ' + str(label) + ':' + str(labels.count(label)))

    auclist = np.zeros([repNum, 1])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): the index files are looked up relative to the current
    # working directory, unlike the dataset path above; presumably this matches
    # where split_data writes them -- confirm before changing.
    idx_dir = osp.join('.', 'data', DS, DS)
    train_idx_path = osp.join(idx_dir, 'train_idx_' + str(train_class) + '.txt')
    test_idx_path = osp.join(idx_dir, 'test_idx_' + str(train_class) + '.txt')

    if not (osp.exists(test_idx_path) and osp.exists(train_idx_path)):
        print('Split Data')
        train_idx, test_idx = split_data(dataset, DS, train_class, percentage)
    else:
        train_idx = load_index_file(train_idx_path)
        test_idx = load_index_file(test_idx_path)

    train_dataset = dataset[train_idx]
    print('len(train_dataset)', len(train_dataset))
    test_dataset = dataset[test_idx]
    print('len(test_dataset)', len(test_dataset))

    # Guard against datasets that report zero node features.
    dataset_num_features = max(dataset.num_features, 1)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    print('================')
    print('lr: {}'.format(lr))
    print('num_features: {}'.format(dataset_num_features))
    print('hidden_dim: {}'.format(args.hidden_dim))
    print('num_gc_layers: {}'.format(args.num_gc_layers))
    print('================')

    for rep in range(repNum):
        # ``c`` is either drawn at random or obtained from pretraining.
        if not args.pretrain:
            c = torch.randn(latent_dim).to(device)
        else:
            c = pretrain(DS, train_loader, lr)
            print('Pretraining Process')
        # A fresh model is built for every repetition.
        model = Graph_Representation_Learning(
            hidden_dim, num_gc_layers, latent_dim, dataset_num_features,
            mode='train').to(device)
        test_auc = train_DO2HSC(
            model, train_loader, test_loader, c, nu, epochs, lam,
            train_class, lr_milestones, lr, device)
        auclist[rep] = test_auc

    AUCmean_std = np.around([np.mean(auclist), np.std(auclist)], decimals=4)
    print("Testing Statistic Results:" + str(AUCmean_std))


if __name__ == '__main__':
    main()
