import numpy as np
import argparse
import os
import time
import dgl
import torch
from data_loader import load_data
from utils import *
from model import *
from eval import from_predefined_split
from logistic_regression import LREvaluator

import matplotlib.pyplot as plt

def train(args, model, optimizer, features, edges, str_encodings):
    """Perform one optimisation step and return the loss as a Python float.

    The model is switched to training mode and called with ``features``,
    ``edges`` and ``str_encodings``; the value it returns is used directly
    as the loss to backpropagate. ``args`` is accepted for call-site
    compatibility but is not read here.
    """
    model.train()

    # The forward pass yields the training objective itself.
    objective = model(features, edges, str_encodings)

    # Drop gradients left from the previous step, backpropagate, then update.
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()

    return objective.item()



def _build_arg_parser():
    """Create the command-line parser holding every experiment option."""
    parser = argparse.ArgumentParser()
    # Experimental setting
    parser.add_argument('-dataset', type=str, default='cora',
                        choices=['cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel', 'actor', 'cornell',
                                 'texas', 'wisconsin', 'computers', 'photo', 'cs', 'physics', 'wikics'])
    parser.add_argument('-ntrials', default=10, type=int)
    parser.add_argument('-epochs', default=800, type=int, help='epochs')
    parser.add_argument('-lr', default=0.01, type=float, help='learning rate')  # {10^−4, 5 × 10^−4, 10^−3, 5 × 10^−3, 10^−2}
    parser.add_argument('-weight_decay', '-wd', default=1e-4, type=float, metavar='W', help='weight decay')  # {10^−6, 10^−5, 10^−4, 10^−3}
    parser.add_argument('-dropout', default=0.5, type=float, help='dropout')
    parser.add_argument('-nlayers', default=2, type=int, help='number of layers')
    parser.add_argument('-emb_dim', default=512, type=int, help='embedding dimension')  # {64, 128, 256, 512}
    parser.add_argument('-proj_dim', default=128, type=int, help='projection dimension')  # {64, 128, 256, 512}
    parser.add_argument('-loss_batch_size', default=0, type=int, help='batch_size for loss calculation')
    parser.add_argument('-eval_freq', default=5, type=int, help='frequency of evaluation')
    parser.add_argument('-k', default=1, type=int, help='max power')  # {1, 2, 3, 4}
    parser.add_argument('-sparse', default=0, type=int, help='if sparse.')
    parser.add_argument('-eps', default=0.5, type=float, help='fixed scalar weight.')  # {0.1, . . . , 0.5}
    parser.add_argument('-tau', default=1, type=float)
    parser.add_argument('-alpha', default=1e-2, type=float)  # {10^−3, 10^−2, 10^−1, 1}
    parser.add_argument('-beta', default=1e-2, type=float)  # {10^−3, 10^−2, 10^−1, 1}
    parser.add_argument('-gamma', default=1e-1, type=float)  # {10^−3, 10^−2, 10^−1, 1}
    parser.add_argument('-eta', default=1e-2, type=float)  # {10^−3, 10^−2, 10^−1, 1}
    parser.add_argument('-ncolors', default=15, type=int)  # {5, 10, 15, 20}
    parser.add_argument('-gpu', default=1, type=int, help='cuda device')
    return parser


def _resolve_device(gpu):
    """Return the torch device to use for the requested CUDA index ``gpu``.

    CPU is returned when CUDA is unavailable or ``gpu`` is negative. When the
    index is not among the visible CUDA devices (e.g. the default ``1`` on a
    single-GPU host) the function reports it and falls back to ``cuda:0``
    instead of failing later with an invalid device ordinal.
    """
    if gpu < 0 or not torch.cuda.is_available():
        return torch.device('cpu')

    visible = torch.cuda.device_count()
    if gpu >= visible:
        print('Requested cuda:{} but only {} CUDA device(s) are visible; '
              'falling back to cuda:0'.format(gpu, visible))
        gpu = 0
    return torch.device('cuda:%d' % gpu)


def main():
    """Parse the options, then train and periodically evaluate the model.

    For each of ``-ntrials`` trials a fresh ``HeteModel`` and Adam optimizer
    are created and trained for ``-epochs`` epochs. Every ``-eval_freq``
    epochs the current embedding is scored with ``LREvaluator`` on the
    predefined split selected for the trial, and the accuracy is printed.
    An ``-eval_freq`` of 0 disables the periodic evaluation.
    """
    args = _build_arg_parser().parse_args()
    print(args)

    # Seeded once for the whole run; trials are not re-seeded individually.
    setup_seed(42)

    device = _resolve_device(args.gpu)

    (g, features, edges, str_encodings, sum_adj,
     train_mask, val_mask, test_mask, labels, nnodes, nfeats) = load_data(args.dataset, args.k)
    st_dim = str_encodings.shape[1]

    # Device placement does not depend on the trial, so it is done once here
    # rather than at the start of every trial.
    g = g.to(device)
    features = features.to(device)
    edges = edges.to(device)
    str_encodings = str_encodings.to(device)
    sum_adj = sum_adj.to(device)

    for trial in range(args.ntrials):
        print('Trial:{}'.format(trial))

        model = HeteModel(args, g=g, se=str_encodings, sum_adj=sum_adj,
                          in_dim=nfeats, st_dim=st_dim, device=device).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)

        for epoch in range(1, args.epochs + 1):
            loss = train(args, model, optimizer, features, edges, str_encodings)
            print("[TRAIN] Epoch:{:04d} | loss:{:.4f}".format(epoch, loss))

            # eval_freq == 0 means "never evaluate" (and avoids a modulo by zero).
            if not args.eval_freq or epoch % args.eval_freq != 0:
                continue

            model.eval()
            # The embedding is only consumed by the evaluator, so no autograd
            # graph is needed for it.
            with torch.no_grad():
                embedding = model.get_embedding(features)

            # The masks hold one column per predefined split; trials cycle
            # through the columns (a single column always yields index 0).
            cur_split = trial % train_mask.shape[1]
            split = from_predefined_split(train_mask[:, cur_split],
                                          val_mask[:, cur_split],
                                          test_mask[:, cur_split], nnodes)

            result = LREvaluator()(embedding, labels, split)

            # NOTE(review): this prints the 'acc' reported by the evaluator for
            # the current checkpoint; no running best across epochs or trials
            # is tracked in this function.
            print(f'Best ACC={result["acc"]:.2f}')

    

# Script entry point: run the experiment only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
    