"""
    Denoise the dataset with the diffusion denoiser: https://arxiv.org/abs/2102.09672
    Some of the code adapted from: https://github.com/ethz-privsec/diffusion_denoised_smoothing
"""
import os
import json
import argparse
from tqdm import tqdm

# torch modules
import torch
import torchvision.utils as vutils

# diffusion model
from networks.cifar10.dm import DiffusionDenoiser

# custom libs
from utils.datasets import load_dataset



# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
# True when torch reports an available CUDA device; evaluated once at import
# time and used below to decide whether the model and batches are moved to GPU.
_use_cuda = torch.cuda.is_available()



# ------------------------------------------------------------------------------
#   Run denoising: receive data, denoise it, and store
# ------------------------------------------------------------------------------
def _denoise_loader(model, loader, t, desc):
    """
        Denoise every batch of `loader` with `model` at timestep `t`.

        Returns a (data, labels) pair, each concatenated along dim 0 and
        kept on the CPU, or (None, None) when the loader yields no batch.
    """
    data_chunks, label_chunks = [], []

    # inference only: do not record autograd state for the stored outputs
    with torch.no_grad():
        for data, labels in tqdm(loader, desc=desc):
            if _use_cuda: data = data.cuda()

            # : forward
            denoised_data = model(data, t)              # no multistep
            denoised_data = (denoised_data + 1) / 2.    # back to [0, 1] range

            # : collect the chunks; concatenate once at the end, so that
            #   earlier batches are not re-copied on every iteration
            data_chunks.append(denoised_data.cpu())
            label_chunks.append(labels)

    if not data_chunks:
        return None, None
    return torch.cat(data_chunks, dim=0), torch.cat(label_chunks, dim=0)


def run_denoising(args):
    """
        Denoise the train/test splits of `args.dataset` and store them.

        Reads from `args`: dataset, batch_size, num_workers, pin_memory,
        sigma, and trial. The result is saved with `torch.save` under
        'datasets/denoised/<dataset>/'.

        Raises an IndexError when no timestep of the diffusion schedule
        reaches the requested noise level (2 * args.sigma).
    """
    # load datasets
    kwargs = {
            'num_workers': args.num_workers,
            'pin_memory' : args.pin_memory
        } if _use_cuda else {}
    train_loader, test_loader = load_dataset( \
            args.dataset, args.batch_size, False, False, kwargs)        # no normalization, no augmentation
    print (' : load the train/valid data [{}]'.format(args.dataset))


    # load the denoiser model
    model = DiffusionDenoiser()
    model.eval()
    if _use_cuda: model.cuda()
    print (' : load a network [{}]'.format(type(model).__name__))


    # set the timestamp t corresponding to noise level sigma
    #  (the smallest t whose ratio b / a reaches the target sigma)
    target_sigma = args.sigma * 2
    real_sigma = 0
    t = 0
    while real_sigma < target_sigma:
        t += 1
        try:
            a = model.diffusion.sqrt_alphas_cumprod[t]
            b = model.diffusion.sqrt_one_minus_alphas_cumprod[t]
        except IndexError as err:
            # ran past the end of the schedule: report why, not just where
            raise IndexError(
                'no diffusion timestep reaches the target sigma {} '
                '(last sigma reached: {})'.format(target_sigma, real_sigma)) from err
        real_sigma = b / a
    print (' : set the timestamp [{}], for the sigma [Real: {:.2f} / Target: {:.2f} (yours * 2)]'.format(t, real_sigma, target_sigma))


    # set the output location
    store_location = os.path.join('datasets', 'denoised', args.dataset)
    os.makedirs(store_location, exist_ok=True)
    print (' : set the location to store the denoised data [{}]'.format(store_location))


    # ----------------------------------------
    #   Denoise the testing data
    # ----------------------------------------
    test_data, test_labels = \
        _denoise_loader(model, test_loader, t, ' : [denoise-test ]')

    # store the denoised tensor
    denoised_dataset = {}
    denoised_dataset['test_data'  ] = test_data
    denoised_dataset['test_labels'] = test_labels


    # ----------------------------------------
    #   Denoise the training data
    # ----------------------------------------
    train_data, train_labels = \
        _denoise_loader(model, train_loader, t, ' : [denoise-train]')

    # store the denoised tensor
    denoised_dataset['train_data'  ] = train_data
    denoised_dataset['train_labels'] = train_labels


    # ----------------------------------------
    #   Store the denoised data
    #   Note  : we denoise the data in [0, 1]
    #   Format: {
    #               'train_data'  : pytorch tensor of (50k, 3, 32, 32),     <- FloatTensor
    #               'train_labels': pytorch tensor of (50k),                <- LongTensor
    #               'test_data'   : pytorch tensor of (10k, 3, 32, 32),
    #               'test_labels' : pytorch tensor of (10k)
    #           }
    #   (the sizes above are the CIFAR10 example)
    # ----------------------------------------

    # store under the location (a non-negative trial number is appended to the name)
    denoised_dataset_filename = \
        os.path.join(store_location, 'denoised_w_sigma_{}.pt'.format(args.sigma)) if args.trial < 0 else \
        os.path.join(store_location, 'denoised_w_sigma_{}_{}.pt'.format(args.sigma, args.trial))
    torch.save(denoised_dataset, denoised_dataset_filename)
    print (' : store the valid. data to [{}]'.format(denoised_dataset_filename))

    # done.


"""
    Main (to run the denoising with the clean data):
    > python denoise.py --dataset cifar10 --network DiffusionDenoiser --batch-size 32 --sigma 1.0
"""
if __name__ == '__main__':
    # command-line options as (flag, keyword-arguments) pairs, registered in
    # this order so the generated --help listing keeps its layout
    argument_specs = [
        # system parameters
        ('--num-workers', dict(type=int, default=4,
                               help='number of workers (default: 4)')),
        ('--pin-memory', dict(action='store_false',
                              help='the data loader copies tensors into CUDA pinned memory')),

        # dataset parameters
        ('--dataset', dict(type=str, default='cifar10',
                           help='dataset used to train: cifar10.')),

        # model parameters
        ('--network', dict(type=str, default='DiffusionDenoiser',
                           help='model name (default: DiffusionDenoiser).')),
        ('--classes', dict(type=int, default=10,
                           help='number of classes in the dataset (ex. 10 in CIFAR10).')),

        # denoiser parameters
        ('--batch-size', dict(type=int, default=64,
                              help='input batch size for training (default: 64)')),
        ('--sigma', dict(type=float, default=1.0,
                         help='sigma for the denoiser (default: 1.0)')),
        ('--trial', dict(type=int, default=-1,
                         help='current trial number, for multiple runs with the same sigma (default: -1)')),
    ]

    parser = argparse.ArgumentParser(
        description='Remove the poisoning artifacts with the denoisers')
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    # parse the command line and echo the resulting configuration as JSON
    args = parser.parse_args()
    print (json.dumps(vars(args), indent=2))

    # hand the parsed options over to the denoising routine
    run_denoising(args)
    # Fin.
