import logging
import os
import random

import click
import numpy as np
import torch

from utils.config import Config
from DFDC import DFDC
from Deep_RIM import Deep_RIM
from datasets.main import load_dataset


def compute_balance(n_clusters, y_pred, yt):
    """Return the minimum per-cluster balance of a clustering.

    For every cluster id ``k`` in ``range(n_clusters)`` the members are counted
    per distinct value of the protected attribute ``yt``. The cluster's balance
    is the smallest group count divided by the largest. The returned value is
    the minimum of that ratio over all non-empty clusters.

    Diagnostic information (number of distinct predicted clusters, per-cluster
    min/max counts, the balance list and the size list) is printed to stdout.

    :param n_clusters: Number of cluster ids to inspect (``0 .. n_clusters-1``).
    :param y_pred: Cluster assignment per sample (1-D array-like).
    :param yt: Protected-attribute value per sample, aligned with ``y_pred``.
        A trailing singleton axis (shape ``(n, 1)``) is accepted. Any discrete
        encoding works (e.g. ``{0, 1}`` or ``{1, 2}``).
    :return: The minimum balance over the non-empty clusters.
    :raises ValueError: If none of the clusters in ``range(n_clusters)`` has
        any members.
    """
    y_pred = np.asarray(y_pred)
    yt = np.asarray(yt)
    # Bin by the group values that actually occur. The previous bins were the
    # indices 0..nbins-1, which miscounts (and can divide by zero) for
    # encodings that do not start at 0.
    groups = np.unique(yt)
    balance = []
    size_list = []
    print(np.size(np.unique(y_pred)))
    for k in range(n_clusters):
        members = y_pred == k
        # ravel rather than squeeze: a one-member cluster must stay 1-D
        # (squeeze would yield a 0-d array), and (n, 1) labels get flattened.
        y_k = np.ravel(yt[members])
        hist = np.array([np.count_nonzero(y_k == g) for g in groups], dtype=float)
        size_list.append(hist.sum())
        if np.count_nonzero(members) > 0:
            # Every member's value is one of `groups`, so max(hist) >= 1 here.
            p_rule = np.min(hist) / np.max(hist)
            print("min v:", np.min(hist), " max v:", np.max(hist))
            balance.append(p_rule)
    print("balance list:", balance)
    print("size list:", size_list)
    if not balance:
        raise ValueError(
            "compute_balance: no non-empty cluster among range(%d)" % n_clusters)
    return np.amin(balance)

################################################################################
# Settings
################################################################################
@click.command()
@click.argument('dataset_name', type=click.Choice(['adult', 'bank', 'credit', 'har', 'mnist_invert', 'mnist_usps']))
@click.argument('net_name', type=click.Choice(['mnist_LeNet', 'credit_rim', 'bank_rim', 'adult_rim', 'har_rim', 'mnist_rim', 'mnist_mlp', 'DCC_LeNet']))
@click.argument('xp_path', type=click.Path(exists=True))
@click.argument('data_path', type=click.Path(exists=True))
@click.argument('s_weight', type=float, default=5.0)
@click.option('--load_config', type=click.Path(exists=True), default=None,
              help='Config JSON-file path (default: None).')
@click.option('--load_model', type=click.Path(exists=True), default=None,
              help='Model file path (default: None).')
@click.option('--objective', type=click.Choice(['deeprim', 'dfdc', 'recon']), default='dfdc',
              help='Specify clustering objective.')
@click.option('--device', type=str, default='cuda', help='Computation device to use ("cpu", "cuda", "cuda:2", etc.).')
@click.option('--seed', type=int, default=-1, help='Set seed. If -1, use randomization.')
@click.option('--optimizer_name', type=click.Choice(['adam', 'amsgrad']), default='adam',
              help='Name of the optimizer to use for network training.')
@click.option('--lr', type=float, default=0.001,
              help='Initial learning rate for network training. Default=0.001')
@click.option('--n_epochs', type=int, default=50, help='Number of epochs to train.')
# `multiple=True` options need an iterable default: click >= 8 rejects a scalar
# such as 0 when the decorator runs. Click 7 turned that 0 into (), so () is
# the same effective default.
@click.option('--lr_milestone', type=int, default=(), multiple=True,
              help='Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.')
@click.option('--batch_size', type=int, default=128, help='Batch size for mini-batch training.')
@click.option('--weight_decay', type=float, default=0,
              help='Weight decay (L2 penalty) hyperparameter for Deep clustering objective.')
@click.option('--pretrain', type=bool, default=False,
              help='Pretrain neural network parameters via autoencoder.')
@click.option('--ae_optimizer_name', type=click.Choice(['adam', 'amsgrad']), default='adam',
              help='Name of the optimizer to use for autoencoder pretraining.')
@click.option('--ae_lr', type=float, default=0.001,
              help='Initial learning rate for autoencoder pretraining. Default=0.001')
@click.option('--ae_n_epochs', type=int, default=100, help='Number of epochs to train autoencoder.')
@click.option('--ae_lr_milestone', type=int, default=(), multiple=True,
              help='Lr scheduler milestones at which lr is multiplied by 0.1. Can be multiple and must be increasing.')
@click.option('--ae_batch_size', type=int, default=128, help='Batch size for mini-batch autoencoder training.')
@click.option('--ae_weight_decay', type=float, default=0,
              help='Weight decay (L2 penalty) hyperparameter for autoencoder objective.')
@click.option('--n_jobs_dataloader', type=int, default=0,
              help='Number of workers for data loading. 0 means that the data will be loaded in the main process.')
#@click.option('--normal_class', type=int, default=0,
#              help='Specify the normal class of the dataset (all other classes are considered anomalous).')
def main(dataset_name, net_name, xp_path, data_path, s_weight, load_config, load_model, objective, device, seed,
         optimizer_name, lr, n_epochs, lr_milestone, batch_size, weight_decay, pretrain, ae_optimizer_name, ae_lr,
         ae_n_epochs, ae_lr_milestone, ae_batch_size, ae_weight_decay, n_jobs_dataloader):
    """
    Deep Fair Discriminative Clustering, a deep method for fair clustering.

    :arg DATASET_NAME: Name of the dataset to load.
    :arg NET_NAME: Name of the neural network to use.
    :arg XP_PATH: Export path for logging the experiment.
    :arg DATA_PATH: Root path of data.
    :arg S_WEIGHT: Weight passed to the clustering model (logged as the supervised weight; default 5.0).
    """

    # Get configuration. This must be the first statement: `locals()` should
    # contain only the command's parameters at this point.
    cfg = Config(locals().copy())

    # Set up logging (console via basicConfig, plus a log file in xp_path)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    log_file = os.path.join(xp_path, 'log.txt')
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Print arguments
    logger.info('Log file is %s.' % log_file)
    logger.info('Data path is %s.' % data_path)
    logger.info('Export path is %s.' % xp_path)

    logger.info('Dataset: %s' % dataset_name)
    logger.info('Network: %s' % net_name)

    # If specified, load experiment config from JSON-file
    if load_config:
        cfg.load_config(import_json=load_config)
        logger.info('Loaded configuration from %s.' % load_config)

    # Print configuration
    logger.info('Model objective: %s' % cfg.settings['objective'])
    logger.info('supervised weight: %.8f', s_weight)

    # Set seed
    if cfg.settings['seed'] != -1:
        random.seed(cfg.settings['seed'])
        np.random.seed(cfg.settings['seed'])
        torch.manual_seed(cfg.settings['seed'])
        logger.info('Set seed to %d.' % cfg.settings['seed'])

    # Default device to 'cpu' if cuda is not available
    if not torch.cuda.is_available():
        device = 'cpu'
    logger.info('Computation device: %s' % device)
    logger.info('Number of dataloader workers: %d' % n_jobs_dataloader)

    # Resolve the model class before loading data, so an unsupported objective
    # fails fast. 'recon' is accepted by --objective (and may come from a loaded
    # config) but has no model class here. Previously that left
    # `deep_clustering` unbound and crashed with UnboundLocalError.
    model_classes = {
        'dfdc': DFDC,  # deep fair discriminative clustering
        'deeprim': Deep_RIM,  # deep regularized information maximization
    }
    clustering_objective = cfg.settings['objective']
    if clustering_objective not in model_classes:
        raise click.BadParameter(
            'objective %r is not supported by this script (expected one of: %s).'
            % (clustering_objective, ', '.join(sorted(model_classes))),
            param_hint="'--objective'")

    # Load data
    dataset = load_dataset(dataset_name, data_path)
    deep_clustering = model_classes[clustering_objective](dataset_name, s_weight, clustering_objective)

    deep_clustering.set_network(net_name)
    if load_model:
        deep_clustering.load_model(model_path=load_model, load_ae=True)
        logger.info('Loading model from %s.' % load_model)

    logger.info('Pretraining: %s' % pretrain)
    if pretrain:
        # Log pretraining details
        logger.info('Pretraining optimizer: %s' % cfg.settings['ae_optimizer_name'])
        logger.info('Pretraining learning rate: %g' % cfg.settings['ae_lr'])
        logger.info('Pretraining epochs: %d' % cfg.settings['ae_n_epochs'])
        logger.info('Pretraining learning rate scheduler milestones: %s' % (cfg.settings['ae_lr_milestone'],))
        logger.info('Pretraining batch size: %d' % cfg.settings['ae_batch_size'])
        logger.info('Pretraining weight decay: %g' % cfg.settings['ae_weight_decay'])

        # Pretrain model on dataset (via autoencoder)
        deep_clustering.pretrain(dataset,
                                 optimizer_name=cfg.settings['ae_optimizer_name'],
                                 lr=cfg.settings['ae_lr'],
                                 n_epochs=cfg.settings['ae_n_epochs'],
                                 lr_milestones=cfg.settings['ae_lr_milestone'],
                                 batch_size=cfg.settings['ae_batch_size'],
                                 weight_decay=cfg.settings['ae_weight_decay'],
                                 device=device,
                                 n_jobs_dataloader=n_jobs_dataloader)

    # Optional export of the (pretrained) model; disabled by default.
    save_model = False
    if save_model:
        # Join with a separator: plain concatenation wrote "<xp_path>model.pt"
        # next to (not inside) the export directory.
        model_file = os.path.join(xp_path, 'model.pt')
        deep_clustering.save_model(export_model=model_file, save_ae=True)
        logger.info('save model to %s.' % model_file)

    # Optional list generation; disabled by default.
    generate_list = False
    if generate_list:
        deep_clustering.generate_list(xp_path, dataset,
                                      optimizer_name=cfg.settings['optimizer_name'],
                                      lr=cfg.settings['lr'],
                                      n_epochs=cfg.settings['n_epochs'],
                                      lr_milestones=cfg.settings['lr_milestone'],
                                      batch_size=cfg.settings['batch_size'],
                                      weight_decay=cfg.settings['weight_decay'],
                                      device=device,
                                      n_jobs_dataloader=n_jobs_dataloader)

    # Log training details
    logger.info('Training optimizer: %s' % cfg.settings['optimizer_name'])
    logger.info('Training learning rate: %g' % cfg.settings['lr'])
    logger.info('Training epochs: %d' % cfg.settings['n_epochs'])
    logger.info('Training learning rate scheduler milestones: %s' % (cfg.settings['lr_milestone'],))
    logger.info('Training batch size: %d' % cfg.settings['batch_size'])
    logger.info('Training weight decay: %g' % cfg.settings['weight_decay'])
    print(dataset_name, cfg.settings['optimizer_name'])

    # Train model on dataset
    deep_clustering.train(xp_path,
                          dataset,
                          optimizer_name=cfg.settings['optimizer_name'],
                          lr=cfg.settings['lr'],
                          n_epochs=cfg.settings['n_epochs'],
                          lr_milestones=cfg.settings['lr_milestone'],
                          batch_size=cfg.settings['batch_size'],
                          weight_decay=cfg.settings['weight_decay'],
                          device=device,
                          n_jobs_dataloader=n_jobs_dataloader)

    # Optional out-of-sample evaluation reporting cluster balance; disabled by
    # default. This was previously commented out via a bare string literal.
    # NOTE(review): assumes results['test_scores'] holds
    # (index, label, scores, protected-attribute) tuples, as the disabled code
    # did. TODO: confirm against the model classes.
    run_test = False
    if run_test:
        deep_clustering.test(dataset, device=device, n_jobs_dataloader=n_jobs_dataloader)
        _, labels, scores, psvs = zip(*deep_clustering.results['test_scores'])
        labels, scores, psvs = np.array(labels), np.array(scores), np.array(psvs)
        y_pred = np.argmax(scores, axis=1)
        n_cluster = len(np.unique(labels))
        balance_result = compute_balance(n_cluster, y_pred, psvs)
        print("balance result:", balance_result)
        logger.info('balance result: %s', str(balance_result))


if __name__ == '__main__':
    main()
