# -*- coding: utf-8 -*-
# +
import argparse
import numpy as np
import dgl
import torch
import logging
import torch.optim as optim
import torch.nn.functional as F

from pathlib import Path
from models import Model
from dataloader import load_data
from utils import get_logger, get_evaluator, set_seed, get_training_config, check_writable
from train_and_eval import run_transductive, run_transductive_sage, run_transductive_mini_batch

# -

def get_args(argv=None):
    """Build the CLI parser for teacher training and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of parsing the real
            command line (``sys.argv[1:]``); passing a list makes the
            function usable from tests or notebooks.

    Returns:
        argparse.Namespace holding every option defined below.

    Raises:
        SystemExit: via ``parser.error`` when ``--num_exp`` is not a
            positive integer (otherwise ``main`` would silently do nothing).
    """
    parser = argparse.ArgumentParser(description='PyTorch DGL implementation')
    parser.add_argument('--device', type=int, default=2, help='CUDA device')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--log_level', type=int, default=20, help='Logger levels for run {10: DEBUG, 20: INFO, 30: WARNING}')
    parser.add_argument('--console_log', action='store_true',
                        help='Set to True to display log info in console')
    parser.add_argument('--output_path', type=str, default='outputs/', help='Path to save outputs')
    parser.add_argument('--num_exp', type=int, default=1, help='Repeat how many experiments')

    # Dataset
    parser.add_argument('--dataset', type=str, default='ogbn-arxiv', help='Dataset')
    parser.add_argument('--data_path', type=str, default='./data/', help='Path to data')

    # Model and optimization
    parser.add_argument('--model_config_path', type=str, default='./train.conf.yaml', help='Path to model configuration')
    parser.add_argument('--teacher', type=str, default='SAGE', help='Teacher model')
    parser.add_argument('--num_layers', type=int, default=3, help='Model number of layers')
    parser.add_argument('--hidden_dim', type=int, default=256, help='Model hidden layer dimensions')
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--dropout_ratio', type=float, default=0.5)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--norm_type', type=str, default='batch', help='One of [none, batch, layer]')
    parser.add_argument('--max_epoch', type=int, default=500, help='Maximum number of training epochs')
    parser.add_argument('--patience', type=int, default=50,
                        help='Early stop if the score on validation set does not improve for how many epochs')
    parser.add_argument('--eval_interval', type=int, default=1, help='Evaluate once per how many epochs')
    args = parser.parse_args(argv)

    # main() only dispatches for num_exp == 1 or > 1; reject anything else
    # up front instead of exiting without running or reporting anything.
    if args.num_exp < 1:
        parser.error('--num_exp must be a positive integer')
    return args


def run(args):
    """Train and evaluate one teacher model, then persist its results.

    Steps: seed the RNGs, pick the device, prepare the per-seed output
    directory and logger, load the dataset, merge CLI and YAML
    configuration, train/evaluate with the runner matching
    ``conf['model_name']``, and write the teacher output, the model
    weights and the test score to disk.

    Args:
        args: Parsed command-line namespace (see ``get_args``). As a side
            effect, ``args.feat_dim`` and ``args.label_dim`` are set.

    Returns:
        The test score of the best-validation model, as returned by the
        selected ``run_transductive*`` helper.
    """
    # Set seed, device, and logger
    set_seed(args.seed)
    device = torch.device('cuda:' + str(args.device) if torch.cuda.is_available() else 'cpu')
    output_dir = Path.cwd().joinpath(args.output_path, 'transductive', args.dataset, args.teacher, f'seed_{args.seed}')
    check_writable(output_dir, overwrite=False)
    logger = get_logger(output_dir.joinpath('log'), args.console_log, args.log_level)
    logger.info(f'output_dir: {output_dir}')

    # Load data
    g, labels, idx_train, idx_val, idx_test = load_data(args.dataset, args.data_path)
    logger.info(f'Total {g.number_of_nodes()} nodes.')
    logger.info(f'Total {g.number_of_edges()} edges.')

    feats = g.ndata['feat']
    args.feat_dim = feats.shape[1]
    # NOTE(review): max+1 equals the class count only if labels are
    # 0-based contiguous ids — confirm against load_data.
    args.label_dim = labels.int().max().item() + 1

    # Model config: values from the YAML file override CLI values with the
    # same key (keyword arguments win in dict(mapping, **kwargs)).
    conf = {}
    if args.model_config_path is not None:
        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
    conf = dict(args.__dict__, **conf)
    conf['device'] = device
    logger.info(f'conf: {conf}')

    model = Model(conf)
    optimizer = optim.Adam(model.parameters(), lr=conf['learning_rate'], weight_decay=conf['weight_decay'])
    criterion = torch.nn.NLLLoss()
    evaluator = get_evaluator(conf['dataset'])
    indices = (idx_train, idx_val, idx_test)

    # Run. The mini-batch MLP runner returns no full-graph output, so `out`
    # stays None on that path; saving below keys off this value rather
    # than re-testing the model name.
    out = None
    if 'SAGE' in conf['model_name']:
        out, score_val, score_test = run_transductive_sage(conf, model, g, feats, labels, indices, criterion, evaluator, optimizer, logger)
    elif 'MLP' in conf['model_name']:
        score_val, score_test = run_transductive_mini_batch(conf, model.encoder, feats, labels, indices, criterion, evaluator, optimizer, logger)
    else:
        out, score_val, score_test = run_transductive(conf, model, g, feats, labels, indices, criterion, evaluator, optimizer, logger)

    logger.info(f"Model: {conf['teacher']}. Dataset: {conf['dataset']}")
    logger.info(f"Best valid model on test set: score_val: {score_val :.4f}, score_test: {score_test :.4f}")
    logger.info(f"num_layers: {conf['num_layers']}. hidden_dim: {conf['hidden_dim']}. dropout_ratio {conf['dropout_ratio']}" )
    logger.info(f"# params {sum(p.numel() for p in model.parameters())}")

    # Saving results
    # Teacher output (only when the runner produced one); np.savez appends
    # the .npz extension.
    if out is not None:
        np.savez(output_dir.joinpath('out'), out.detach().cpu().numpy())

    # Model weights
    torch.save(model.state_dict(), output_dir.joinpath('model'))

    # Test score, appended to a file shared by all seeds of this
    # dataset/teacher combination (it lives in the parent directory).
    with open(output_dir.parent.joinpath('test_results'), 'a', encoding='utf-8') as f:
        f.write(f"{score_test :.4f}\n")

    return score_test


# +
def repeat_run(args, run_fn=None):
    """Repeat the experiment ``args.num_exp`` times and summarise the scores.

    Repetition ``i`` overwrites ``args.seed`` with ``i``, so the seeds are
    always 0..num_exp-1 and any ``--seed`` given on the command line is not
    used here. This is kept for compatibility with existing ``seed_<i>``
    output directories. After the call, ``args.seed`` holds the last seed.

    Prints the mean and the population standard deviation (numpy's default
    ``ddof=0``) of the test scores.

    Args:
        args: Parsed command-line namespace; must provide ``num_exp``.
        run_fn: Callable taking ``args`` and returning a test score.
            Defaults to ``run``. Injectable so the loop can be exercised
            without the full training stack.

    Returns:
        np.ndarray of per-seed test scores. The original returned None;
        callers that ignore the result are unaffected.
    """
    if run_fn is None:
        run_fn = run
    scores = []
    for seed in range(args.num_exp):
        args.seed = seed
        scores.append(run_fn(args))
    score_test = np.array(scores)
    print(f'{score_test.mean() : .4f}  {score_test.std() : .4f}')
    return score_test
    
def main():
    """Script entry point: parse options, then run one or several experiments.

    A single experiment prints its own test score. Several experiments
    delegate to ``repeat_run``, which prints the mean and std across seeds.
    Other values of ``num_exp`` do nothing.
    """
    args = get_args()
    n_runs = args.num_exp
    if n_runs > 1:
        repeat_run(args)
        return
    if n_runs == 1:
        final_score = run(args)
        print(f'score: {final_score: .4f}')


# -

# Run only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()



