import sys
import os
import datetime as dt
from pytz import timezone
import csv
import time
import json

from loguru import logger
import torch

from utils.parser import create_parser
from utils.db_handler import DBHandler, db_config
from utils.others import fix_seed
from utils.data import get_data, get_preprocessed_data, preprocess_inductive, preprocess_adj
from utils.modeler import get_model

from utils.debug import get_memory_used

from train import train, validate_test

if __name__ == '__main__':
    # Start logging
    start_time = dt.datetime.now().astimezone(timezone('Asia/Seoul'))
    logger.info(f'Start @: {start_time}')
    logger.info('Now logging into a log file...')

    log_dict = dict()

    # Parse args
    logger.info('Parsing arguments...')
    args, argstr = create_parser()
    logger.info(argstr)

    if args.loguru_path == '':
        logger.add(f"logs/{start_time.strftime('%Y_%m_%d-%H_%M_%S')}.log")
    else:
        logger.add(args.loguru_path)

    # DB init
    if args.send_db:
        logger.info(f'Setup DB handler...')
        db_config['mongo_host'] = args.db_host
        db_config['mongo_db'] = args.db_name
        db_config['project'] = args.db_project
        db_config['ssh_username'] = args.db_user
        db_config['ssh_pwd'] = args.db_passwd
        db_handler = DBHandler()

    # Fix seed
    logger.info(f'Fix seed...')
    fix_seed(args.seed)
    
    # Device init
    if torch.cuda.is_available():
        assert len(args.gpus) == 1, 'Only single GPU is supported.'
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu) for gpu in args.gpus])
        devices = [int(i) for i in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
        device = f'cuda:{devices[0]}'
        logger.info(f'Using GPU: {device}')
    else:
        logger.error('CUDA is not available. Please check your environment.')
        sys.exit(1)
    

    # logging memory usage
    logger.info(f"Initial: {get_memory_used()/10**9} GB")


    # Get dataset
    if args.use_preprocessed and not args.inductive and ('gcn' in args.model.lower()) and \
        os.path.exists(f'{args.dataset_path}/{args.dataset}_train_data_gcn.pt'):
        train_data = torch.load(f'{args.dataset_path}/{args.dataset}_train_data_gcn.pt')
        match args.dataset:
            case 'products':
                in_channels, out_channels = 100, 47
            case 'papers':
                in_channels, out_channels = 128, 172
            case _:
                if 'kron' in args.dataset:
                    in_channels, out_channels = 128, 10
                elif 'igb' in args.dataset:
                    in_channels, out_channels = 1024, 19
                else:
                    assert ('igb' in args.dataset) or ('kron' in args.dataset), 'Only products/papers/igb are supported.'
    else:
        if not args.use_preprocessed or 'igb' in args.dataset:
            logger.info('Get unpreprocessed dataset...')
            data, in_channels, out_channels = get_data(args.dataset_path, name=args.dataset)
        else:
            logger.info('Get preprocessed dataset...')
            data, in_channels, out_channels = get_preprocessed_data(args.dataset_path, name=args.dataset)
        
        # Additional preprocessing for inductive setting
        if args.inductive:
            train_data, val_data, test_data = preprocess_inductive(data)
        else:
            train_data = data # for transductive setting
        del data

        # preprocess adjacency matrix if needed... (e.g., normalization)
        if 'gcn' in args.model.lower():
            train_data.adj_t = preprocess_adj('train', train_data.adj_t, add_self_loops=True,
                                              root=args.dataset_path, dataset=args.dataset)
            torch.save(train_data, f'{args.dataset_path}/{args.dataset}_train_data_gcn.pt')
            if args.inductive:
                val_data.adj_t = preprocess_adj('valid', val_data.adj_t, add_self_loops=True,
                                                root=args.dataset_path, dataset=args.dataset)
                test_data.adj_t = preprocess_adj('test', test_data, add_self_loops=True,
                                                 root=args.dataset_path, dataset=args.dataset)
            
    # Now partition
    import numpy as np
    if args.pre_partitioned and \
        os.path.exists(f"{args.dataset_path}/adj_t/{args.dataset}_{args.partition_method}_{args.n_partitions}_perm.npy") and \
        os.path.exists(f"{args.dataset_path}/adj_t/{args.dataset}_{args.partition_method}_{args.n_partitions}_ptr.npy"):

        from utils.partition import permute
        logger.info('Get pre-partitioned adj_t (SparseTensor)...')
        train_perm = torch.from_numpy(np.load(f"{args.dataset_path}/adj_t/{args.dataset}_{args.partition_method}_{args.n_partitions}_perm.npy"))
        train_ptr = torch.from_numpy(np.load(f"{args.dataset_path}/adj_t/{args.dataset}_{args.partition_method}_{args.n_partitions}_ptr.npy"))
        train_data = permute(train_data, train_perm, log=True)
    else:
        if args.inductive:
            logger.error('Inductive setting is not yet supported...')
            sys.exit(1)
        else:
            logger.info('Do not have pre-partitioned adj_t (SparseTensor)...')
            logger.info('Partitioning with transductive setting...')
            # select partitioning method
            from utils.partition import partition, permute
            train_perm, train_ptr = partition(args.partition_method, train_data.adj_t, args.n_partitions, log=True)
            np.save(f"{args.dataset_path}/adj_t/{args.dataset}_{args.partition_method}_{args.n_partitions}_perm.npy", train_perm.numpy())
            np.save(f"{args.dataset_path}/adj_t/{args.dataset}_{args.partition_method}_{args.n_partitions}_ptr.npy", train_ptr.numpy())
            train_data = permute(train_data, train_perm, log=True)

    # Generate loader
    if args.inductive:
        logger.error('Inductive setting is not yet supported...')
        sys.exit(1)
    else:
        logger.info('Generate loader...')
        from utils.loader import SubgraphLoader
        train_loader = SubgraphLoader(train_data, train_ptr,
                                      use_cache=args.use_cache, cache_size=args.cache_size,
                                      storage_offload=args.storage_offload, storage_path=args.storage_path,
                                      optimize_dataloader=args.optimize_dataloader)
        logger.info('Loader generated...')
        # print expansion ratio
        logger.info(f'Expansion ratio: {train_loader.get_expansion_ratio()}')
        train_loader.print_reusability()

    # loss function
    logger.info('Set loss function...')
    if train_data.y.dim() == 1:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()
    del train_data

    # logging memory usage
    logger.info(f"After dataloader: {get_memory_used()/10**9} GB")

    # Model init
    logger.info('Model init...')
    if args.inductive:
        logger.error('Inductive setting is not yet supported...')
        sys.exit(1)
    else:
        logger.info('Transductive setting...')
        train_model = get_model(args.model, in_channels, args.n_hidden, out_channels,
                        args.n_conv_layers, train_loader, args.dropout, device,
                        args.use_cache, args.layer_wise_cache, args.checkpointing_strategy,
                        args.storage_offload, args.storage_path, args.optimize_dataloader)
        logger.info(f'Model: \n{train_model}')

    # logging memory usage
    logger.info(f"After model: {get_memory_used()/10**9} GB")
    # logging GPU memory usage (into MB)
    logger.info(f"GPU Memory Usage: {torch.cuda.memory_allocated()/10**6} MB")

    # Optimizer init
    logger.info('Optimizer init...')
    optimizer = torch.optim.Adam([
        dict(params=train_model.reg_modules.parameters(), weight_decay=args.weight_decay),
        dict(params=train_model.nonreg_modules.parameters(), weight_decay=0.0)
    ], lr=args.lr)

    # logging memory usage
    logger.info(f"After optimizer: {get_memory_used()/10**9} GB")

    # Training
    logger.info('Training...')
    # present total epochs
    logger.info(f'Total epochs: {args.n_epochs}')

    val_accs  = []
    test_accs = []
    best_acc = 0.0
    max_val_acc = 0.0
    time_str = time.strftime('%Y_%m_%d__%H_%M_%S')
    args.time_str = time_str # save for future use
    ckpt_base_dir = f'{args.ckpt_path}{time_str}__{args.dataset}_{args.model}_{args.n_conv_layers}_{args.n_epochs}'
    if args.inductive:
        args.ckpt_name = f'{ckpt_base_dir}_induc/'
        os.mkdir(args.ckpt_name)
    else:
        args.ckpt_name = f'{ckpt_base_dir}_trans/'
        os.mkdir(args.ckpt_name)
    for epoch in range(args.n_epochs):

        train_model.reset_tensors(epoch)

        start_t = time.perf_counter()
        train_dict = train(epoch, train_model, train_loader, criterion, optimizer,
                           args.inductive, args.storage_offload)
        logger.info(f'Time (s): {time.perf_counter() - start_t:.2f}s', end=' | ')
        torch.cuda.empty_cache()
        if args.inductive:
            logger.error('Inductive setting is not yet supported...')
            sys.exit(1)
        val_accs.append(train_dict['val_acc'])
        test_accs.append(train_dict['test_acc'])
        if train_dict['val_acc'] >= max_val_acc:
            max_val_acc = train_dict['val_acc']
            if epoch != 0:
                os.remove(args.ckpt_name + f'{best_acc*100:.2f}.pth') # remove old ckpt
            best_acc = train_dict['test_acc']
            torch.save(train_dict['model'].state_dict(), args.ckpt_name + f'{best_acc*100:.2f}.pth')
        logger.info(f'Best Acc (%): {best_acc: .2%}')

    with open(args.ckpt_name + 'val_accs.csv', 'w', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(val_accs)

    # copy arguments to db_dict
    for key, value in vars(args).items():
        log_dict[key] = value
    # record accuracies
    log_dict['best_acc'] = best_acc
    log_dict['val_accs'] = val_accs
    log_dict['test_accs'] = test_accs
    log_dict['max_test_acc'] = max(test_accs)
    # pop confidential information
    log_dict.pop('db_host')
    log_dict.pop('db_name')
    log_dict.pop('db_project')
    log_dict.pop('db_user')
    log_dict.pop('db_passwd')

    if args.save_json:
        # save to json
        os.makedirs(args.json_path, exist_ok=True)
        with open(args.json_path + f'/{time_str}.json', 'w', encoding='utf-8') as file:
            json.dump(log_dict, file)
        print('[GriNNder] Saved Information to JSON...')

    if args.send_db:
        # insert to db
        db_handler.insert_item_one(log_dict)
        print('[GriNNder] Inserted Information to DB...')
        # close db
        db_handler.close_connection()