from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import numpy as np
import math
import random

import os
import datetime
import sys

sys.path.append("../integer_discrete_flows")

from utils.load_data import load_dataset


def _build_parser():
    """Build the command-line parser for IDF training runs with Einet hyperparameters."""
    parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

    parser.add_argument('-d', '--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'imagenet32', 'imagenet64'],
                        metavar='DATASET',
                        help='Dataset choice.')

    parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')

    parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')

    parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--evaluate_interval_epochs', type=int, default=5,
                        help='Evaluate per how many epochs')

    parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR',
                        help='output directory for model snapshots etc.')

    # --testing / --validation both write to args.testing; the default is True.
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument('-te', '--testing', action='store_true', dest='testing',
                    help='evaluate on test set after training')
    fp.add_argument('-va', '--validation', action='store_false', dest='testing',
                    help='only evaluate on validation set')
    parser.set_defaults(testing=True)

    # optimization settings
    parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                        help='number of epochs to train (default: 2000)')
    parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                        help='number of early stopping epochs')

    parser.add_argument('-bs', '--batch_size', type=int, default=128, metavar='BATCH_SIZE',
                        help='input batch size for training (default: 128)')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                        help='learning rate')
    parser.add_argument('--warmup', type=int, default=10,
                        help='number of warmup epochs')

    parser.add_argument('--data_augmentation_level', type=int, default=2,
                        help='data augmentation level')

    parser.add_argument('--variable_type', type=str, default='discrete',
                        help='variable type of data distribution: discrete/continuous',
                        choices=['discrete', 'continuous'])
    parser.add_argument('--distribution_type', type=str, default='logistic',
                        choices=['logistic', 'normal', 'steplogistic'],
                        help='distribution type: logistic/normal/steplogistic')
    parser.add_argument('--n_flows', type=int, default=8,
                        help='number of flows per level')
    parser.add_argument('--n_levels', type=int, default=3,
                        help='number of levels')

    parser.add_argument('--n_bits', type=int, default=8,
                        help='')

    # ---------------- SETTINGS CONCERNING NETWORKS -------------
    parser.add_argument('--densenet_depth', type=int, default=2,
                        help='Depth of densenets')
    parser.add_argument('--n_channels', type=int, default=512,
                        help='number of channels in coupling and splitprior')

    # ---------------- SETTINGS CONCERNING COUPLING LAYERS -------------
    parser.add_argument('--coupling_type', type=str, default='shallow',
                        choices=['shallow', 'resnet', 'densenet', 'densenet++'],
                        help='Type of coupling layer')
    parser.add_argument('--splitfactor', default=0, type=int,
                        help='Split factor for coupling layers.')

    parser.add_argument('--split_quarter', dest='split_quarter', action='store_true',
                        help='Split coupling layer on quarter')
    parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false')
    parser.set_defaults(split_quarter=True)

    # ---------------- SETTINGS CONCERNING SPLITPRIORS -------------
    parser.add_argument('--splitprior_type', type=str, default='shallow',
                        choices=['none', 'shallow', 'resnet', 'densenet', 'densenet++'],
                        help='Type of splitprior. Use \'none\' for no splitprior')

    # ---------------- SETTINGS CONCERNING PRIORS -------------
    parser.add_argument('--n_mixtures', type=int, default=1,
                        help='number of mixtures')

    parser.add_argument('--hard_round', dest='hard_round', action='store_true',
                        help='Rounding of translation in discrete models. Weird '
                        'probabilistic implications, only for experimental phase')
    parser.add_argument('--no_hard_round', dest='hard_round', action='store_false')
    parser.set_defaults(hard_round=True)

    parser.add_argument('--round_approx', type=str, default='smooth',
                        choices=['smooth', 'stochastic'])

    parser.add_argument('--lr_decay', default=0.999, type=float,
                        help='Learning rate decay')

    parser.add_argument('--temperature', default=1.0, type=float,
                        help='Temperature used for BackRound. It is used in '
                        'the SmoothRound module. '
                        '(default=1.0)')

    parser.add_argument('--gpu_idx', type=int, default=0,
                        help='gpu idx')

    # ---------------- EINET HYPERPARAMETERS -------------
    parser.add_argument('--num_repetitions', type=int, default=4,
                        help='number of repetitions')
    parser.add_argument('--num_sums', type=int, default=12,
                        help='number of sum nodes')
    parser.add_argument('--num_input_distributions', type=int, default=12,
                        help='number of input distributions')
    parser.add_argument('--online_em_stepsize', type=float, default=0.05,
                        help='em step size')

    parser.add_argument('--enable-large-model', default=False, action='store_true')

    parser.add_argument('--log_file_name', type=str, default='none')

    return parser


def _seed_everything(seed):
    """Seed the global Python, PyTorch and NumPy random number generators with ``seed``."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


def _apply_large_model_overrides(args):
    """Overwrite network/optimisation settings on ``args`` in place with the large-model preset.

    Reads ``args.no_pc``, so that attribute must already be set.
    """
    args.learning_rate = 0.0005
    args.densenet_depth = 12
    args.coupling_type = "densenet"
    args.splitprior_type = "densenet"
    if args.no_pc:
        args.n_mixtures = 5
    else:
        args.n_mixtures = 1
        # NOTE(review): this silently overrides any --batch_size given on the
        # command line whenever the PC (Einet) prior is active.
        args.batch_size = 128
        args.online_em_stepsize = 0.005


def prep_idf():
    """Parse CLI arguments, seed RNGs, apply model presets and load the dataset.

    Arguments are read from ``sys.argv`` by argparse. An invalid command line
    makes argparse exit with ``SystemExit``.

    Returns:
        ``(train_loader, val_loader, test_loader, args)``, as returned by
        ``load_dataset``. ``args`` is the namespace as updated by
        ``load_dataset``.
    """
    args = _build_parser().parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Honour an explicit --manual_seed; only draw a random one when none was
    # given, so that runs can be reproduced.
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 100000)
    _seed_everything(args.manual_seed)

    # NOTE(review): this unconditionally overrides --enable-large-model, so the
    # flag currently has no effect. It is kept to preserve existing run behaviour.
    args.enable_large_model = True

    # A 1x1x1 Einet configuration is treated as "no probabilistic circuit".
    args.no_pc = (args.num_repetitions == 1) and (args.num_sums == 1) and (args.num_input_distributions == 1)
    if args.enable_large_model:
        _apply_large_model_overrides(args)

    # Extra DataLoader options are only used when running on CUDA.
    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    return train_loader, val_loader, test_loader, args