from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import numpy as np
import math
import random
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

import os
import datetime
import sys

sys.path.append("../integer_discrete_flows")

from utils.load_data import load_dataset


class ImageNet(Dataset):
    """Map-style dataset over a directory of downsampled-ImageNet ``.npz`` shards.

    Every entry of ``root`` is loaded eagerly with ``np.load``. Each shard must
    provide a ``"data"`` array (one flattened channels-first RGB image per row)
    and a ``"labels"`` array aligned with it. Samples are numbered globally,
    shard after shard, in sorted file-name order.

    Args:
        root: directory that contains only the shard files.
        transform: optional callable applied to the PIL image produced by
            ``__getitem__``.
    """

    def __init__(self, root, transform=None):
        self.transform = transform
        self.root_dir = root
        # Sorted so the global sample order is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.files = sorted(os.listdir(self.root_dir))
        self.datasets = []
        self.labels = []
        # offsets[k] is the global index of the first sample of shard k. One
        # small int array replaces two per-sample dicts, whose size grows with
        # the dataset and whose refcount updates defeat copy-on-write sharing
        # with forked DataLoader workers.
        offsets = []
        sample_idx = 0
        for file in self.files:
            print("> Loading {}".format(file))
            fname = os.path.join(self.root_dir, file)
            # The NpzFile keeps the archive open until it is closed; the arrays
            # read inside the block are fully materialised and stay valid.
            with np.load(fname) as data:
                images = self._to_chw(data["data"])
                labels = data["labels"]
            self.datasets.append(images)
            self.labels.append(labels)
            offsets.append(sample_idx)
            sample_idx += images.shape[0]

        self._offsets = np.asarray(offsets, dtype=np.int64)
        self.length = sample_idx

    @staticmethod
    def _to_chw(images):
        """Reshape a shard's image array to ``(N, 3, S, S)``.

        For 2-D input of shape ``(N, 3*S*S)`` the side ``S`` is inferred, so
        both 32x32 and 64x64 shards are handled; any other layout falls back
        to the 32x32 reshape this class originally hard-coded.
        """
        if images.ndim == 2:
            per_image = images.shape[1]
            side = int(round(math.sqrt(per_image / 3)))
            if side > 0 and 3 * side * side == per_image:
                return images.reshape(-1, 3, side, side)
        return images.reshape(-1, 3, 32, 32)

    def _locate(self, index):
        """Return ``(shard, position_in_shard)`` for a global sample index.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        position = index + self.length if index < 0 else index
        if not 0 <= position < self.length:
            raise IndexError(
                "index {} out of range for dataset of size {}".format(index, self.length))
        # side="right" picks the last shard starting at or before `position`,
        # which skips empty shards that share an offset with the next shard.
        shard = int(np.searchsorted(self._offsets, position, side="right")) - 1
        return shard, int(position - self._offsets[shard])

    def __getitem__(self, index):
        shard, local = self._locate(index)
        img = self.datasets[shard][local]
        label = self.labels[shard][local]

        # Stored images are channels-first; PIL expects height x width x channels.
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return self.length
    
    
class ToTensorNoNorm():
    """Convert an HWC image (PIL image or ndarray) to a CHW tensor, unscaled.

    Unlike ``torchvision.transforms.ToTensor`` the values are neither divided
    by 255 nor cast to float: the tensor keeps the input's dtype and values.
    """

    def __call__(self, X_i):
        # ``np.array(..., copy=False)`` raises ValueError under NumPy 2 whenever
        # a copy cannot be avoided. When it does avoid a copy, the array may be
        # read-only (e.g. for PIL images), which torch.from_numpy warns about.
        # An explicit copy works on NumPy 1 and 2, always yields a writable
        # buffer, and decouples the tensor from the caller's input.
        arr = np.array(X_i, copy=True)
        return torch.from_numpy(arr).permute(2, 0, 1)


def my_load_imagenet(resolution, args, **kwargs):
    """Build train/validation/test DataLoaders for downsampled ImageNet.

    Args:
        resolution: image side length; only 32 and 64 are accepted.
        args: argparse namespace. ``args.batch_size`` is read, and
            ``args.input_size`` is set to ``[3, resolution, resolution]``.
        **kwargs: extra ``DataLoader`` keyword arguments (e.g. ``pin_memory``,
            ``num_workers``). They override the per-loader ``num_workers``
            defaults used here (4 for training, 2 for validation).

    Returns:
        ``(train_loader, val_loader, test_loader, args)``. No separate test
        split is loaded: ``test_loader`` is the validation loader.

    Raises:
        ValueError: if ``resolution`` is not 32 or 64.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if resolution not in (32, 64):
        raise ValueError("resolution must be 32 or 64, got {}".format(resolution))

    args.input_size = [3, resolution, resolution]

    # NOTE(review): paths are relative to the current working directory.
    trainpath = '../../../../../datasets/imagenet/Imagenet{res}_train_npz'.format(res=resolution)
    valpath = '../../../../../datasets/imagenet/Imagenet{res}_val_npz'.format(res=resolution)

    data_transform = transforms.Compose([
        ToTensorNoNorm()
    ])

    print('Starting loading ImageNet')

    train_dataset = ImageNet(trainpath, transform=data_transform)
    valid_dataset = ImageNet(valpath, transform=data_transform)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=True, **{'num_workers': 4, **kwargs})
    # Evaluation needs no shuffling. With drop_last, a shuffled loader scored a
    # different random subset on every evaluation; a fixed order keeps the
    # evaluated subset identical across epochs.
    # NOTE(review): drop_last still omits the final partial batch; kept in case
    # downstream code assumes full batches.
    val_loader = torch.utils.data.DataLoader(
        dataset=valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, **{'num_workers': 2, **kwargs})
    test_loader = val_loader

    print('Number of data images', len(train_dataset) + len(valid_dataset))

    return train_loader, val_loader, test_loader, args


def _build_idf_parser():
    """Return the argparse parser holding every IDF / Einet training option."""
    parser = argparse.ArgumentParser(description='PyTorch Discrete Normalizing flows')

    parser.add_argument('-d', '--dataset', type=str, default='imagenet32',
                        choices=['cifar10', 'imagenet32', 'imagenet64'],
                        metavar='DATASET',
                        help='Dataset choice.')

    parser.add_argument('-nc', '--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')

    parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.')

    parser.add_argument('-li', '--log_interval', type=int, default=20, metavar='LOG_INTERVAL',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--evaluate_interval_epochs', type=int, default=5,
                        help='Evaluate per how many epochs')

    parser.add_argument('-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR',
                        help='output directory for model snapshots etc.')

    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument('-te', '--testing', action='store_true', dest='testing',
                    help='evaluate on test set after training')
    fp.add_argument('-va', '--validation', action='store_false', dest='testing',
                    help='only evaluate on validation set')
    parser.set_defaults(testing=True)

    # optimization settings
    parser.add_argument('-e', '--epochs', type=int, default=2000, metavar='EPOCHS',
                        help='number of epochs to train (default: 2000)')
    parser.add_argument('-es', '--early_stopping_epochs', type=int, default=300, metavar='EARLY_STOPPING',
                        help='number of early stopping epochs')

    parser.add_argument('-bs', '--batch_size', type=int, default=128, metavar='BATCH_SIZE',
                        help='input batch size for training (default: 128)')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, metavar='LEARNING_RATE',
                        help='learning rate')
    parser.add_argument('--warmup', type=int, default=10,
                        help='number of warmup epochs')

    parser.add_argument('--data_augmentation_level', type=int, default=2,
                        help='data augmentation level')

    parser.add_argument('--variable_type', type=str, default='discrete',
                        help='variable type of data distribution: discrete/continuous',
                        choices=['discrete', 'continuous'])
    parser.add_argument('--distribution_type', type=str, default='logistic',
                        choices=['logistic', 'normal', 'steplogistic'],
                        help='distribution type: logistic/normal')
    parser.add_argument('--n_flows', type=int, default=8,
                        help='number of flows per level')
    parser.add_argument('--n_levels', type=int, default=3,
                        help='number of levels')

    parser.add_argument('--n_bits', type=int, default=8,
                        help='')

    # ---------------- SETTINGS CONCERNING NETWORKS -------------
    parser.add_argument('--densenet_depth', type=int, default=2,
                        help='Depth of densenets')
    parser.add_argument('--n_channels', type=int, default=512,
                        help='number of channels in coupling and splitprior')

    # ---------------- SETTINGS CONCERNING COUPLING LAYERS -------------
    parser.add_argument('--coupling_type', type=str, default='shallow',
                        choices=['shallow', 'resnet', 'densenet', 'densenet++'],
                        help='Type of coupling layer')
    parser.add_argument('--splitfactor', default=0, type=int,
                        help='Split factor for coupling layers.')

    parser.add_argument('--split_quarter', dest='split_quarter', action='store_true',
                        help='Split coupling layer on quarter')
    parser.add_argument('--no_split_quarter', dest='split_quarter', action='store_false')
    parser.set_defaults(split_quarter=True)

    # ---------------- SETTINGS CONCERNING SPLITPRIORS -------------
    parser.add_argument('--splitprior_type', type=str, default='shallow',
                        choices=['none', 'shallow', 'resnet', 'densenet', 'densenet++'],
                        help='Type of splitprior. Use \'none\' for no splitprior')

    # ---------------- SETTINGS CONCERNING PRIORS -------------
    parser.add_argument('--n_mixtures', type=int, default=1,
                        help='number of mixtures')

    parser.add_argument('--hard_round', dest='hard_round', action='store_true',
                        help='Rounding of translation in discrete models. Weird '
                        'probabilistic implications, only for experimental phase')
    parser.add_argument('--no_hard_round', dest='hard_round', action='store_false')
    parser.set_defaults(hard_round=True)

    parser.add_argument('--round_approx', type=str, default='smooth',
                        choices=['smooth', 'stochastic'])

    parser.add_argument('--lr_decay', default=0.999, type=float,
                        help='learning rate decay')

    parser.add_argument('--temperature', default=1.0, type=float,
                        help='Temperature used for BackRound. It is used in '
                        'the SmoothRound module. '
                        '(default=1.0)')

    parser.add_argument('--gpu_idx', type=int, default=0,
                        help='gpu idx')

    # ---------------- EINET HYPERPARAMETERS -------------
    parser.add_argument('--num_repetitions', type=int, default=4,
                        help='number of repetitions')
    parser.add_argument('--num_sums', type=int, default=12,
                        help='number of sum nodes')
    parser.add_argument('--num_input_distributions', type=int, default=12,
                        help='number of input distributions')
    parser.add_argument('--online_em_stepsize', type=float, default=0.05,
                        help='em step size')

    parser.add_argument('--enable-large-model', default=False, action="store_true")

    parser.add_argument('--log_file_name', type=str, default='none')

    return parser


def _apply_large_model_preset(args):
    """Override hyper-parameters on ``args`` in place when ``--enable-large-model`` is set.

    Reads ``args.no_pc``, so it must run after that flag has been computed.
    """
    if not args.enable_large_model:
        return
    args.learning_rate = 0.0005
    args.densenet_depth = 12
    args.coupling_type = "densenet"
    args.splitprior_type = "densenet"
    if args.no_pc:
        args.n_mixtures = 5
    else:
        args.n_mixtures = 1
        args.batch_size = 128
        args.online_em_stepsize = 0.005


def prep_idf():
    """Parse CLI options, seed the RNGs and build the ImageNet32 data loaders.

    Seeds Python's ``random``, NumPy and torch with ``args.manual_seed``. A
    random seed is drawn only when ``--manual_seed`` was not given, and it is
    stored back on ``args`` so it can be logged.

    Returns:
        ``(train_loader, val_loader, test_loader, args)`` as returned by
        ``my_load_imagenet(32, args, ...)``.
    """
    args = _build_idf_parser().parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Honour --manual_seed; previously it was always overwritten by a random draw.
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 100000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    # True when all three Einet size options are 1 (presumably "no probabilistic
    # circuit" — confirm against the model code).
    args.no_pc = (args.num_repetitions == 1) and (args.num_sums == 1) and (args.num_input_distributions == 1)
    _apply_large_model_preset(args)

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    # NOTE(review): --dataset is not consulted; ImageNet32 is always loaded.
    train_loader, val_loader, test_loader, args = my_load_imagenet(32, args, **kwargs)

    return train_loader, val_loader, test_loader, args