
import click
import torch
import os
import numpy as np
import random

from frgp import FRGP

# self-defined library
from utils.tools import Log, create_argparser, args_to_dict, json_dump, json_load, new_dir
from configs import TRAINING_LOG, RESULTS_DIR, ARGUMENTS_DIR, DATA_DIR
################################################################################
# Command line arguments
################################################################################
@click.command()
@click.option('--dataset_name', type=click.Choice(['adult', 'compas', 'celeba', 'titanic', 'sp', 'credit']))
@click.option('--load_arguments', type=click.Path(), default=None,
              help='arguments JSON-file path (default: None).')
@click.option('--repeat', type=int, default=1, help='the repeat time for calculating the average metric.')
@click.option('--seed', type=int, default=-1,
                help='random seed for parameters initialization.')
@click.option('--stop_threshold', type=float, default=None,
                help='Stop error for Sinkhorn algorithm.')
@click.option('--beta', type=float, default=None,
                help='The coefficient of reconstruction error.')
@click.option('--rsd', type=str, default=None,
                help='save dir of the experimental resutls.')
@click.option('--_lambda', type=float, default=None,
                help='The coefficient of fairness loss term.')
@click.option('--balanced', type=bool, default=False,
                help='when == True, using balanced data split.')
@click.option('--fairway', type=str, default=None,
                help='[explicit, implicit]')
@click.option('--latent_dimension', type=int, default=None,
                help='')
@click.option('--threshold_ratio', type=float, default=0.9,
                help='the percentile of anomaly score of training set with ascending order')
@click.option('--fair_c_ratio', type=str, default=None,
                help='contamination ratio')
@click.option('--af', type=str, default=None,
                help='activation function [relu, leaky_relu, sigmoid, tanh, elu]')
@click.option('--device', type=str, default=None,
                help='[cuda | cpu]')
def main(dataset_name, load_arguments, repeat, stop_threshold, beta,
            rsd, seed, _lambda, balanced, fairway, latent_dimension,
            threshold_ratio, fair_c_ratio, af, device):
    """Train an FRGP model on the selected dataset, optionally several times.

    Base settings are read from the JSON file given by --load_arguments
    (resolved relative to the arguments directory). Options passed on the
    command line override the corresponding entries of that file.
    """
    # NOTE: click shows the docstring above as the command's --help text.

    # Both options are optional for click but the code below cannot run
    # without them, so fail early with a usage error instead of crashing
    # later with UnboundLocalError / TypeError.
    if load_arguments is None:
        raise click.UsageError('Missing option "--load_arguments".')
    if dataset_name is None:
        raise click.UsageError('Missing option "--dataset_name".')

    # training args
    default_args = json_load(os.path.join(ARGUMENTS_DIR, load_arguments))
    # NOTE(review): parse_args() reads sys.argv just like click does;
    # presumably create_argparser declares every flag used here - confirm.
    args = create_argparser(default_args).parse_args()

    # training logs
    training_log = os.path.join(TRAINING_LOG, args.log_dir)
    os.makedirs(training_log, exist_ok=True)
    mylogger = Log(training_log, log_name=['log'])
    # Use a distinct local name rather than shadowing the builtin ``print``.
    log_print = mylogger.print

    # update args: these always take the command-line value (or its default)
    args.dataset_name = dataset_name
    args.repeat = repeat
    args.seed = seed
    args.fair_c_ratio = fair_c_ratio
    args.threshold_ratio = threshold_ratio

    # ... while these override the loaded value only when explicitly given.
    optional_overrides = {
        'stop_threshold': stop_threshold,
        '_lambda': _lambda,
        'rsd': rsd,
        'beta': beta,
        'fairway': fairway,
        'latent_dimension': latent_dimension,
        'af': af,
    }
    for key, value in optional_overrides.items():
        if value is not None:
            setattr(args, key, value)

    # Resolve the device once and store it on ``args`` so that the value
    # logged below is the one the model actually receives. Precedence:
    # command line, then loaded arguments, then auto-detection.
    if device is not None:
        args.device = device
    elif getattr(args, 'device', None) is None:
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # training information
    log_print('Dataset: %s' % args.dataset_name)
    log_print('Network: %s' % args.net_name)
    log_print(f'Optimizer:{args.optimizer_name}')
    log_print(f'learning rate:[{args.lr}]')
    log_print(f'epochs:[{args.epochs}]')
    log_print(f'batch_size:[{args.batch_size}]')
    log_print(f'latent dimension: {args.latent_dimension}')
    log_print(f'log dir: {args.log_dir}')
    log_print(f'Fair way :{args.fairway}')
    log_print(f'Beta: [{args.beta}]')
    log_print(f'Threshold: [{args.threshold_ratio}]')
    if balanced:
        log_print('Balanced Data Split')
    else:
        log_print('Imbalanced Data Split.')
    if args.fairway == 'explicit':
        log_print(f'lambda: [{args._lambda}]')

    log_print(f'Computation device: {args.device}')
    log_print(f'Repeat time: [{args.repeat}]')

    # Per-dataset size; one is subtracted from every entry before it is
    # handed to the model.
    # NOTE(review): presumably the subtraction drops one column (e.g. the
    # sensitive attribute) - confirm, in particular for 'celeba' (3 -> 2).
    feature_counts = {
        'celeba': 3,
        'adult': 14,
        'compas': 8,
        'titanic': 7,
        'sp': 32,
        'credit': 23,
    }
    in_channels = feature_counts[dataset_name] - 1

    for i in range(1, args.repeat + 1):
        # NOTE(review): every repeat is reseeded with the same value, so
        # with a fixed seed the repeats presumably produce identical runs.
        if args.seed != -1:
            random.seed(args.seed)
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed(args.seed)
            torch.backends.cudnn.deterministic = True
            log_print('Set seed to [%d].' % args.seed)
        log_print(f'================== the {i}-th time ==================')
        model = FRGP(
                dataset_name=args.dataset_name,
                net_name=args.net_name,
                data_path=DATA_DIR,
                lr=args.lr,
                optimizer_name=args.optimizer_name,
                results_dir=args.rsd,
                epochs=args.epochs,
                batch_size=args.batch_size,
                device=args.device,
                print=log_print,
                in_channels=in_channels,
                _lambda=args._lambda,
                latent_dimension=args.latent_dimension,
                Fairway=args.fairway,
                stop_threshold=args.stop_threshold,
                entropy_reg_coe=args.entropy_reg_coe,
                beta=args.beta,
                balanced=balanced,
                threshold_ratio=args.threshold_ratio,
                fair_c_ratio=args.fair_c_ratio,
                af_name=args.af
                )
        model.train()

    # NOTE(review): attribute access, not a call - this only does something
    # if ``Log.ending`` is a property; confirm against utils.tools.Log.
    mylogger.ending


# Script entry point: run the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
    