import os
import click
import pickle
import torch
import dnnlib
from torch_utils import distributed as dist
from torch_utils import misc
from utils import get_sampler_settings
import monitor
import numpy as np

# torch.set_printoptions(precision=16)

#----------------------------------------------------------------------------

def _sequential_iterator(dataset_obj, seed, batch_size, data_loader_kwargs):
    """Return an iterator over an unshuffled DataLoader for this rank of `dataset_obj`."""
    sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed, shuffle=False)
    return iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=sampler, batch_size=batch_size, **data_loader_kwargs))


def monitor_manager(
    outdir              = '',
    network             = '',
    network_pkl         = '',
    dataset_kwargs      = None,
    data_loader_kwargs  = None,
    seed                = 0,
    device              = torch.device('cuda'),
    class_idx           = None,
    num_steps           = None,
    func                = None,
    test_idx            = 105,
):
    """Load data and a pretrained network, then run one `monitor` analysis.

    Args:
        outdir: Output directory forwarded to the selected monitor function.
        network: Short network name. The part before the first '_' is passed
            to `get_sampler_settings` as discretization, schedule and scaling.
            It is also forwarded as `network_name` to most monitor functions.
        network_pkl: Path/URL of the network pickle; its 'ema' entry is used.
        dataset_kwargs: Keyword arguments for `dnnlib.util.construct_class_by_name`
            that build the dataset. None means an empty dict.
        data_loader_kwargs: Extra keyword arguments for
            `torch.utils.data.DataLoader`. None means an empty dict.
        seed: Seed passed to `misc.InfiniteSampler` (sampling is unshuffled).
        device: Device that the images, labels and network are moved to.
        class_idx: Forwarded to the monitor functions that accept it.
        num_steps: Number of sampling steps passed to `get_sampler_settings`.
        func: Name of the analysis to run (see `supported_funcs` below).
        test_idx: Position, in the unshuffled sampler stream, of the single
            test image handed to 'all_trajs'. The default 105 is the
            original choice ("blue car in main text").

    Raises:
        ValueError: If `func` is not a supported analysis name. This is
            checked before any data or network loading so a typo fails fast.
    """
    supported_funcs = (
        'all_trajs', 'cos_collect', 'denoiser_std_collect', 'deviation_collect',
        'interpolation', 'interpolation_generate', 'normTraj_collect',
        'trajDistance_collect', 'traj_generate',
    )
    if func not in supported_funcs:
        raise ValueError(f'Unsupported function {func!r}; expected one of {supported_funcs}')

    # Avoid shared mutable defaults; an empty dict preserves the old default.
    dataset_kwargs = {} if dataset_kwargs is None else dataset_kwargs
    data_loader_kwargs = {} if data_loader_kwargs is None else data_loader_kwargs

    # Load dataset once; both loaders below read from the same object.
    dist.print0('Loading dataset...')
    dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset

    # Pick a good test image: skip the first `test_idx` single-image batches.
    test_iterator = _sequential_iterator(dataset_obj, seed, 1, data_loader_kwargs)
    for _ in range(test_idx):
        next(test_iterator)
    test_images, test_labels = next(test_iterator)
    # Drop the never-exhausted iterator so its worker processes can be
    # released before the large load below.
    del test_iterator
    test_images = test_images.to(device).to(torch.float32) / 127.5 - 1   # [1, 3, 32, 32] for 32x32 data
    test_labels = test_labels.to(device)

    # Load the whole dataset as a single batch. 50000 is presumably the
    # CIFAR-10 training-set size (the CLI only maps CIFAR-10 checkpoints).
    images_all, _ = next(_sequential_iterator(dataset_obj, seed, 50000, data_loader_kwargs))
    images_all = images_all.to(device).to(torch.float32) / 127.5 - 1

    # Load network.
    dist.print0(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
        net = pickle.load(f)['ema'].to(device)
    net.eval()

    # The network-name prefix selects discretization, schedule and scaling alike.
    disc = sche = scale = network.split('_')[0]
    t_steps, sigma, sigma_deriv, sigma_inv, s, s_deriv, solver = get_sampler_settings(net, num_steps, discretization=disc, schedule=sche, scaling=scale, device=device)
    # Sampler settings shared by every monitor call.
    sampler_kwargs = dict(t_steps=t_steps, sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s, s_deriv=s_deriv, solver=solver)

    # Dispatch to the requested analysis.
    if func == 'all_trajs':
        monitor.monitor_all_trajs(outdir=outdir, network_name=network, net=net, images_all=images_all, test_images=test_images, test_labels=test_labels, device=device, **sampler_kwargs)
    elif func == 'cos_collect':
        monitor.monitor_cos_collect(outdir=outdir, network_name=network, net=net, device=device, class_idx=class_idx, **sampler_kwargs)
    elif func == 'denoiser_std_collect':
        monitor.monitor_denoiser_std_collect(outdir=outdir, network_name=network, net=net, images_all=images_all, device=device, **sampler_kwargs)
    elif func == 'deviation_collect':
        monitor.monitor_deviation_collect(outdir=outdir, network_name=network, net=net, images_all=images_all, device=device, class_idx=class_idx, **sampler_kwargs)
    elif func == 'interpolation':
        monitor.monitor_interpolation(outdir=outdir, network_name=network, inter_size=9, net=net, images_all=images_all, device=device, class_idx=class_idx, **sampler_kwargs)
    elif func == 'interpolation_generate':
        monitor.monitor_interpolation_generate(outdir=outdir, network_name=network, inter_size=9, net=net, images_all=images_all, device=device, class_idx=class_idx, **sampler_kwargs)
    elif func == 'normTraj_collect':
        # NOTE(review): no network_name here, matching the original call —
        # confirm whether monitor_normTraj_collect accepts it.
        monitor.monitor_normTraj_collect(outdir=outdir, net=net, images_all=images_all, device=device, class_idx=class_idx, **sampler_kwargs)
    elif func == 'trajDistance_collect':
        monitor.monitor_trajDistance_collect(outdir=outdir, network_name=network, net=net, images_all=images_all, device=device, class_idx=class_idx, **sampler_kwargs)
    elif func == 'traj_generate':
        monitor.monitor_traj_generate(outdir=outdir, network_name=network, net=net, device=device, class_idx=class_idx, **sampler_kwargs)
    else:
        # Only reachable if `supported_funcs` and this chain drift apart.
        raise ValueError(f'Unsupported function {func!r}; expected one of {supported_funcs}')

#----------------------------------------------------------------------------

@click.command()
# Main options
@click.option('--data',          help='Path to the dataset', metavar='ZIP|DIR',                     type=str, required=True)
@click.option('--cond',          help='Train class-conditional model', metavar='BOOL',              type=bool, default=False, show_default=True)
# Performance-related.
@click.option('--cache',         help='Cache dataset in CPU memory', metavar='BOOL',                type=bool, default=True, show_default=True)
@click.option('--workers',       help='DataLoader worker processes', metavar='INT',                 type=click.IntRange(min=1), default=1, show_default=True)
# monitor setting
@click.option('--seed',          help='Random seed  [default: 1]', metavar='INT',                   type=int, default=1)
@click.option('--network',       help='Network pickle filename', metavar='STR',                     type=str, required=True)
@click.option('--outdir',        help='output path', metavar='PATH',                                type=str, default='./')
@click.option('--num_steps',     help='none', metavar='INT',                                        type=click.IntRange(min=1), default=30)
@click.option('--func',          help='Choose which func to run', metavar='STR',                    type=str, required=True)

def main(**kwargs):
    """Command-line entry point that resolves the checkpoint and runs `monitor_manager`.

    Steps:
        1. Map --network to a local checkpoint pickle.
        2. Validate the dataset.
        3. Seed the RNGs.
        4. Hand the collected config to `monitor_manager`.

    Raises:
        click.BadParameter: If --network is not one of the known names.
        click.ClickException: If the dataset cannot be opened, or if a
            class-conditional network is paired with an unlabeled dataset.
    """
    # Known short network names -> CIFAR-10 checkpoint pickles.
    network_pkls = {
        'edm_ddpmpp_uncond': './edms/edm-cifar10-32x32-uncond-vp.pkl',
        'edm_ddpmpp_cond':   './edms/edm-cifar10-32x32-cond-vp.pkl',
        'edm_ncsnpp_uncond': './edms/edm-cifar10-32x32-uncond-ve.pkl',
        'vp_uncond':         './edms/baseline-cifar10-32x32-uncond-vp.pkl',
    }

    opts = dnnlib.EasyDict(kwargs)

    # Validate the network before any side effects (output dir, process group).
    if opts.network not in network_pkls:
        raise click.BadParameter(f'Unsupported network {opts.network!r}; expected one of {sorted(network_pkls)}', param_hint='--network')
    opts.network_pkl = network_pkls[opts.network]
    # Conditioning is implied by the network name and overrides --cond.
    opts.cond = opts.network.endswith('_cond')

    os.makedirs(opts.outdir, exist_ok=True)
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Initialize config dict.
    c = dnnlib.EasyDict()
    c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, cache=opts.cache)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)

    # Validate dataset options.
    try:
        dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
        c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
        c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
        if opts.cond and not dataset_obj.has_labels:
            raise click.ClickException('--cond=True requires labels specified in dataset.json')
        del dataset_obj # conserve memory
    except IOError as err:
        raise click.ClickException(f'--data: {err}') from err

    # Random seed (offset by rank so each process draws different noise).
    if opts.seed is not None:
        c.seed = opts.seed + dist.get_rank()
        np.random.seed(c.seed)
        torch.manual_seed(c.seed)
    else:
        # Draw a seed on the device and share rank 0's value with all ranks.
        seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
        torch.distributed.broadcast(seed, src=0)
        c.seed = int(seed)

    # monitor
    c.num_steps = opts.num_steps
    c.network = opts.network
    c.outdir = opts.outdir
    c.network_pkl = opts.network_pkl
    c.func = opts.func
    monitor_manager(**c)

#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Script entry point: the click command parses its options from sys.argv.
    main()
