#!/usr/bin/python3

import os
import torch
import argparse
import numpy as np
from utils.utils_node import *
from warnings import simplefilter
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Directory containing this script; used as the base for the default log dirs.
dir_path = os.path.dirname(__file__)

# Candidate shared dataset locations. main() switches args.dataset_path to the
# first of these that exists on the current machine.
nfs_dataset_path1 = '/nfs_dataset_path1/datasets/'
nfs_dataset_path2 = '/nfs_dataset_path2/datasets/'

# Process-wide runtime tweaks: silence FutureWarning noise and set the number
# of threads torch may use to 10.
simplefilter('ignore', FutureWarning)
torch.set_num_threads(10)


def main(args):
    """Run one optimisation job described by the parsed CLI ``args``.

    Steps, in order: call ``set_seed``, pick the dataset directory, load the
    dataset and its split indices, let ``add_args`` extend ``args`` from the
    dataset, build the model / Adam optimizer / plateau LR scheduler, and hand
    everything to ``ModelOptLoading.optimizing``.

    Args:
        args: ``argparse.Namespace`` produced by this script's parser. Its
            ``dataset_path`` attribute is overwritten when one of the NFS
            dataset directories exists.
    """
    set_seed(args)

    # Prefer an NFS dataset directory when one is present; the first existing
    # candidate wins, otherwise the value from the command line is kept.
    for candidate in (nfs_dataset_path1, nfs_dataset_path2):
        if os.path.exists(candidate):
            args.dataset_path = candidate
            break

    dataset, split_idx = load_dataset(args)
    args = add_args(args, dataset)

    model = load_model(args)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # Halve the learning rate after `lr_patience` epochs without improvement
    # of a metric that is being minimised.
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
    # confirm against the installed version before upgrading.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=args.lr_patience,
        verbose=True,
    )

    trainer = ModelOptLoading(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        args=args,
    )
    trainer.optimizing(dataset, split_idx)

    print('optmi')



def str2bool(value):
    """Convert a command-line token to a real ``bool``.

    ``argparse``'s ``type=bool`` simply calls ``bool(text)``, so every
    non-empty string -- including ``"False"`` and ``"0"`` -- becomes ``True``.
    This converter interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``,
            ``"on"``/``"off"`` (case-insensitive, surrounding spaces ignored).

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is accepted as False because, with the former
    # ``type=bool``, it was the only input that produced False.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def build_parser(logs_root=None):
    """Build the command-line parser for this script.

    Args:
        logs_root: Directory under which the default ``logs_perf`` and
            ``logs_stas`` directories are placed. ``None`` (the default) uses
            the module-level ``dir_path``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    if logs_root is None:
        logs_root = dir_path

    parser = argparse.ArgumentParser()
    ## datasets path and name
    parser.add_argument("--dataset_path", type=str, default='datasets')
    parser.add_argument("--dataset_name", type=str, default='cora',
                        choices=['cora', 'pubmed', 'citeseer', 'ogbn-proteins'])
    ## model parameters
    parser.add_argument("--model", type=str, default='GCN', choices=['GCN', 'GraphSage', 'GAT'])
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--num_layer", type=int, default=4)
    parser.add_argument("--embed_dim", type=int, default=128)
    parser.add_argument("--norm_type", type=str, default='motifnorm')
    # str2bool instead of bool: `--norm_affine False` must really yield False.
    parser.add_argument("--norm_affine", type=str2bool, default=True)
    parser.add_argument("--activation", type=str, default='relu', choices=['relu', 'None'])
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--skip_type", type=str, default='None',
                        choices=['None', 'Residual', 'Initial', 'Dense', 'Jumping'])
    parser.add_argument("--econv", action="store_true")

    ## optimization parameters and others
    parser.add_argument("--epochs", type=int, default=450)
    parser.add_argument("--epoch_slice", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--lr_min", type=float, default=1e-5)
    parser.add_argument("--lr_patience", type=int, default=10)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument("--logs_perf_dir", type=str, default=os.path.join(logs_root, 'logs_perf'),
                        help="logs' files of the loss and performance")
    parser.add_argument("--logs_stas_dir", type=str, default=os.path.join(logs_root, 'logs_stas'),
                        help="statistics' files of the avg and std")
    # NOTE(review): no `type` is given, so any value passed on the command
    # line arrives as a string (truthy unless empty) while the default is the
    # bool True -- confirm how downstream code consumes this before changing it.
    parser.add_argument("--node_weight", default=True)
    parser.add_argument("--state_dict", action="store_true")
    parser.add_argument("--breakout", action="store_true")
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
