import numpy as np
import random
import torch
import argparse
import time

from datetime import datetime
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from .utils.train import *
from .utils.loader import *
from .utils.utility import *
from .utils.model import *

### Make argument parser(hyper-parameters)
def get_args(argv=None):
    """Build the command-line parser and parse the experiment hyper-parameters.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            allows programmatic use (e.g. from tests or notebooks).

    Returns:
        argparse.Namespace holding the data, run-condition, experiment and
        ``TEST`` options declared below.
    """
    parser = argparse.ArgumentParser()
    ### Data
    parser.add_argument('--data', default='adni', help='Type of dataset')
    ### Condition
    parser.add_argument('--seed', type=int, default=100, help='Random seed')
    parser.add_argument('--device', type=int, default=0, help='Index of the GPU device')
    # NOTE(review): default 23 is not among the values listed in the help text — confirm.
    parser.add_argument('--feat', type=int, default=23, help='Feature(4/5/6/7)')
    parser.add_argument('--lab', type=int, default=0, help="Label(2)")
    parser.add_argument('--model', type=str, default='ours', help='Model')
    parser.add_argument('--layer', type=int, default=2, help='Number of layers')
    ### Experiment
    parser.add_argument('--split', type=int, default=5, help="Number of splits for k-fold")
    parser.add_argument('--epoch', type=int, default=2000, help='Number of epochs')
    parser.add_argument('--hid', type=int, default=16, help='Number of hidden units')
    parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
    parser.add_argument('--wd', type=float, default=5e-4, help='L2 loss on parameters')
    parser.add_argument('--dr', type=float, default=0.5, help='Dropout rate')
    ### Etc
    parser.add_argument('--TEST', action="store_true",
                        help='Skip training and only run evaluation (trainer.test)')
    args = parser.parse_args(argv)

    return args

### Control the randomness of all experiments
def set_randomness(seed):
    """Seed every random number generator the experiments rely on.

    PyTorch (CPU), NumPy's global RNG and Python's ``random`` module are
    always seeded; the CUDA generators are seeded only when CUDA is present.
    """
    # Same order as before: PyTorch, then NumPy, then Python's stdlib RNG.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)      # current GPU
    torch.cuda.manual_seed_all(seed)  # every visible GPU

### Main function
def _ansi(text, code):
    """Return ``text`` wrapped in ANSI colour ``code`` (e.g. 93 = bright yellow, 94 = bright blue) followed by a reset."""
    return f"\033[{code}m{text}\033[0m"


def main():
    """Run stratified k-fold cross-validation and print a result summary.

    A fresh model, optimizer and trainer are built for every fold through the
    project helpers ``select_model`` / ``select_optimizer`` / ``select_trainer``.
    The model is trained unless ``--TEST`` is given, and is then evaluated.
    The per-fold metrics returned by ``trainer.test`` are collected:
    accuracy, precision, specificity, sensitivity and F1. They are printed
    per fold and then summarised as mean and standard deviation.
    """
    args = get_args()
    set_randomness(args.seed)
    device = torch.device('cuda:' + str(args.device) if torch.cuda.is_available() else 'cpu')

    MODEL = args.model
    DATA = args.data
    TEST = args.TEST

    # Load dataset
    A, X, Y, L, EIGVAL, EIGVEC = load_dataset(args)

    # K-fold cross validation. StratifiedKFold preserves the class ratio of Y in
    # each fold; shuffling is off by default, so the folds are deterministic.
    kfold = StratifiedKFold(n_splits=args.split)
    idx_pairs = [
        (torch.LongTensor(idx_train), torch.LongTensor(idx_test))
        for idx_train, idx_test in kfold.split(A, Y)
    ]

    # Utilize GPUs for computation (tensors stay on the CPU for the 'svm' model)
    if torch.cuda.is_available() and MODEL != 'svm':
        A = A.to(device) # (sample, node, node)
        X = X.to(device) # (sample, node, feat)
        Y = Y.to(device) # (sample)
        L = L.to(device) # (node, node)
        EIGVAL = EIGVAL.to(device) # (sample, node)
        EIGVEC = EIGVEC.to(device) # (sample, node, node)

    avac, avpr, avsp, avse, avf1 = [], [], [], [], []
    for fold, idx_pair in enumerate(idx_pairs, start=1):
        print(f"=============================== Fold {fold} ===============================")

        # Build data loader
        dl_tr, dl_te = build_data_loader(args, idx_pair, A, X, Y, L, EIGVAL, EIGVEC)

        # Select the model to use (rebuilt from scratch for every fold)
        model = select_model(args, A, X, Y).to(device)
        optimizer = select_optimizer(args, model)
        trainer = select_trainer(args, model, optimizer, dl_tr, dl_te)

        # Train (unless only testing) and test
        if not TEST:
            trainer.train(fold)
        cf_acc, cf_pre, cf_spe, cf_sen, cf_f1s = trainer.test(fold)

        # Performance
        avac.append(cf_acc)
        avpr.append(cf_pre)
        avsp.append(cf_spe)
        avse.append(cf_sen)
        avf1.append(cf_f1s)

    cnt = Counter(Y.tolist())
    print("------------- Parameters -------------")
    if DATA == 'adni':
        print(_ansi(f"features: {args.feat}", 93))
        print(_ansi(f"labels: {args.lab}", 93))
    for name in ('model', 'epoch', 'hid', 'lr', 'wd'):
        print(_ansi(f"{name}: {getattr(args, name)}", 93))
    print("--------------- Result ---------------")
    if MODEL != 'svm':
        # `model` is the last fold's instance, built from the same arguments as every other fold.
        n_params = sum(p.numel() for p in model.parameters())
        print("==> Total parameters: {:.2f}M".format(n_params / 1000000.0))
    print(f"Label distribution:   {cnt}")
    print("---------- Confusion Matrix ----------")
    print(f"{args.split}-Fold accuracy:    {avac}")
    print(f"{args.split}-Fold precision:   {avpr}")
    print(f"{args.split}-Fold sensitivity: {avse}")
    print(f"{args.split}-Fold specificity: {avsp}")
    print(f"{args.split}-Fold f1 score:    {avf1}")
    print("------------- Mean, Std --------------")
    # Same metric order as the per-fold listing above, specificity included.
    summary = (avac, avpr, avse, avsp, avf1)
    print("(accuracy, precision, sensitivity, specificity, f1 score)")
    mean_str = ' '.join(f"{np.mean(values):.3f}" for values in summary)
    std_str = ' '.join(f"{np.std(values):.3f}" for values in summary)
    print(_ansi(f"Mean: {mean_str}", 94))
    print(_ansi(f"Std:  {std_str}", 94))

if __name__ == '__main__':
    # Local import keeps this script-only dependency out of the module header.
    from datetime import timedelta, timezone

    start_time = time.time()

    main()

    # Wall-clock duration of the whole run, printed as H:MM:SS.
    elapsed = int(time.time() - start_time)
    hour, rem = divmod(elapsed, 3600)
    minute, second = divmod(rem, 60)
    print(f"\nTime: {hour}:{minute:02d}:{second:02d}")

    # Finish timestamp in UTC+9, the offset the script intends to display
    # (presumably KST). An aware datetime replaces adding 9 to the naive
    # local hour, which could exceed 23 and was only right on UTC hosts.
    now = datetime.now(timezone(timedelta(hours=9)))
    print(f"▶ {now:%Y-%m-%d %H:%M:%S} ◀")