import os 
import argparse

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from model import Model, preprocess_feature,  CorrReg
from utils_data import load_dataset, set_seed
import yaml

import warnings
warnings.filterwarnings('ignore')

# Command-line interface.
# NOTE: arguments are parsed at module import time, so importing this module
# parses sys.argv; the resulting `args` is used by the __main__ guard.
parser = argparse.ArgumentParser(description='InfoMLP')

parser.add_argument('--dataname', type=str, default='cora', help='Name of dataset.')
parser.add_argument('--gpu', type=int, default=0, help='GPU index. Default: 0')
# choices= makes argparse reject unsupported settings immediately, instead of
# main() failing later with a KeyError when it maps the setting to a mode name.
parser.add_argument('--setting', type=int, default=0, choices=[0, 1],
                    help='Setting index, 0: transductive; 1: inductive')

args = parser.parse_args()

# --gpu -1, or no CUDA available, falls back to the CPU.
if args.gpu != -1 and torch.cuda.is_available():
    args.device = 'cuda:{}'.format(args.gpu)
else:
    args.device = 'cpu'

def _accuracy(logits, label, idx):
    """Return the fraction of rows in ``idx`` whose argmax over ``logits`` equals ``label``."""
    return torch.sum(logits[idx].argmax(dim=1) == label[idx]).item() / len(idx)


def _load_config(mode, dataname):
    """Load the hyperparameter mapping from ``configs/{mode}/{dataname}.yml``."""
    config_path = f'configs/{mode}/{dataname}.yml'
    with open(config_path, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Acceptable for repo-local config files; switch to yaml.safe_load if
        # these files can ever come from an untrusted source.
        return yaml.load(f, Loader=yaml.Loader)


def _load_aug_features(graph, feat, mode, dataname):
    """Return the graph-augmented feature stack, computing and caching it on first use.

    The stack is saved to ``aug_features/{mode}/{dataname}.npy`` the first time
    and loaded from that file on later runs. The tensor is returned where it was
    produced; callers move what they need to the target device.
    """
    cache_dir = f'aug_features/{mode}'
    cache_path = f'{cache_dir}/{dataname}.npy'
    os.makedirs(cache_dir, exist_ok=True)
    if os.path.exists(cache_path):
        return torch.from_numpy(np.load(cache_path))
    aug = preprocess_feature(graph, feat)
    np.save(cache_path, aug.cpu().numpy())
    return aug


def main(args):
    """Train the model and print the test accuracy picked by validation accuracy.

    Hyperparameters are read from ``configs/{mode}/{dataname}.yml``.

    Args:
        args: parsed CLI namespace; reads ``dataname``, ``device`` and
            ``setting`` (0 = transductive, 1 = inductive).

    Raises:
        KeyError: if ``args.setting`` is not 0 or 1.
        ValueError: if the configured ``order`` is outside ``[1, K]``, where K is
            the number of stacked augmented feature matrices.
    """
    dataname = args.dataname
    device = args.device

    mode = {0: 'transductive', 1: 'inductive'}[args.setting]
    config = _load_config(mode, dataname)

    # Hyperparameters; explicit casts guard against YAML values that load as strings.
    epochs = int(config['epochs'])
    lr = float(config['lr'])
    wd = float(config['wd'])
    hid_dim = int(config['hid_dim'])
    num_layer = int(config['num_layer'])
    dropout = float(config['dropout'])
    order = int(config['order'])
    use_bn = config['use_bn']
    alpha = float(config['alpha'])
    beta = float(config['beta'])

    graph, feat, label, num_class, train_idx, val_idx, test_idx = load_dataset(dataname, mode)
    graph = graph.remove_self_loop().add_self_loop()
    in_dim = feat.shape[1]
    print(graph)

    # Augmented features are built from graph/feat before anything is moved to `device`.
    aug_feats = _load_aug_features(graph, feat, mode, dataname)
    num_orders = aug_feats.shape[0]
    if not 1 <= order <= num_orders:
        # order == 0 would divide by zero; order > num_orders would average
        # fewer matrices while still dividing by `order`.
        raise ValueError(f'order must be in [1, {num_orders}], got {order}')
    # Mean of the first `order` augmented feature matrices (K in the paper).
    # Only this slice is moved to the device, not the whole stack.
    sfeat = (aug_feats[:order].sum(0) / order).to(device)

    feat = feat.to(device)
    label = label.to(device)

    model = Model(in_dim, hid_dim, num_class, num_layer, dropout, use_bn).to(device)
    print(model)

    loss_ce = torch.nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Identity target handed to CorrReg; kept float64 (same dtype np.eye yields).
    iden = torch.eye(hid_dim, dtype=torch.float64, device=device)

    acc_best = 0   # best validation accuracy seen so far
    acc_eval = 0   # test accuracy at the currently selected epoch
    patience = 0   # epochs since the selection last changed (reported only; no early stop)

    for epoch in range(epochs):
        patience += 1

        # Training step.
        model.train()
        logits, h1, h2 = model(feat, sfeat)
        loss_reg = CorrReg(h1, h2, iden, alpha, beta)
        loss_sup = loss_ce(logits[train_idx], label[train_idx])
        loss = loss_sup + loss_reg
        # Taken from the train-mode forward pass, before the parameter update.
        acc_train = _accuracy(logits, label, train_idx)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Evaluation pass.
        model.eval()
        with torch.no_grad():
            logits, _, _ = model(feat, sfeat)
        acc_val = _accuracy(logits, label, val_idx)
        acc_test = _accuracy(logits, label, test_idx)

        print("In epoch {}, Train Acc: {:.4f} | Train Loss: {:.4f}, Val Acc: {:.4f}, "
              "Test Acc: {:.4f}, Patience: {}".format(
                  epoch, acc_train, loss.item(), acc_val, acc_test, patience))

        if acc_val > acc_best:
            acc_eval = acc_test
            acc_best = acc_val
            patience = 0
        elif acc_val == acc_best and acc_test > acc_eval:
            # NOTE(review): breaking validation ties by test accuracy uses the
            # test set for model selection, biasing the reported number upward.
            acc_eval = acc_test
            patience = 0

    print("Evaluation Acc: {:.4f}".format(acc_eval))

if __name__ == '__main__':
    # Script entry point: train using the CLI arguments parsed above.
    main(args=args)

