"""

Example runs (generate_image_grid):
    python example.py --mode=generate_image_grid --output_path=test_output.png --model_path=/home/id4439/ambient-diffusion/training-runs/00020-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-002503.pkl

Example runs (dataset creation):
    python example.py --mode=dataset_creation --model_path=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl --output_path=edm_cifar10_outputs

Example runs (cond expectations): 
    python example.py --model_path=models/cdm_200k.pkl  --output_path=cond_expectations_cdm_200k.png --mode=cond_expectations
    python example.py --model_path=models/cdm_7k.pkl  --output_path=cond_expectations_cdm_7k.png  --mode=cond_expectations
    
    (p=0.4, d=0.1)
    python example.py --model_path=training-runs/00011-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-007508.pkl --output_path=cond_expectations.png  --mode=cond_expectations
    python example.py --model_path=training-runs/00011-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-007508.pkl --output_path=cond_expectations_masked_inputs.png --mask_input=True  --mode=cond_expectations

    (p=0.4, d=0.05)
    python example.py --model_path=training-runs/00013-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-007508.pkl --output_path=cond_expectations_smaller_delta.png  --mode=cond_expectations
    python example.py --model_path=training-runs/00013-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-007508.pkl --output_path=cond_expectations_smaller_delta_masked_inputs.png --mask_input=True --corruption_probability=0.4 --delta_probability=0.05  --mode=cond_expectations

    (p=0.4, d=0.3)
    python example.py --model_path=training-runs/00012-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-005005.pkl --output_path=cond_expectations_larger_delta.png  --mode=cond_expectations
    python example.py --model_path=training-runs/00012-cifar10-32x32-cond-ddpmpp-ambient-gpus3-batch516-fp32/network-snapshot-005005.pkl --output_path=cond_expectations_larger_delta_masked_inputs.png --mask_input=True --corruption_probability=0.4 --delta_probability=0.3  --mode=cond_expectations

"""

import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
import matplotlib.pyplot as plt
import click
import os
import json

#----------------------------------------------------------------------------

def stochastic_sampler(net, x_cur, class_labels, S_churn, S_min, S_max, S_noise, num_steps, t_cur, t_next, i):
    """Advance the sample by one stochastic sampling step from ``t_cur`` to ``t_next``.

    The noise level is first raised from ``t_cur`` to ``t_hat`` by adding fresh
    Gaussian noise (only while ``S_min <= t_cur <= S_max``), then an Euler step
    is taken from ``t_hat`` to ``t_next``. On every step except the last one
    (``i < num_steps - 1``) the Euler slope is averaged with the slope
    evaluated at the Euler estimate (2nd order correction).

    Args:
        net: Denoiser called as ``net(x, sigma, class_labels)``; must also
            expose ``round_sigma``.
        x_cur: Current sample.
        class_labels: Conditioning passed through to ``net`` unchanged.
        S_churn, S_min, S_max, S_noise: Stochasticity settings; the relative
            noise increase is ``min(S_churn / num_steps, sqrt(2) - 1)``.
        num_steps: Total number of sampling steps.
        t_cur, t_next: Current and target noise levels (``t_cur`` needs tensor
            semantics because ``.sqrt()`` is called on an expression of it).
        i: Zero-based index of this step.

    Returns:
        The sample at noise level ``t_next``.
    """
    # Relative amount by which the noise level is temporarily increased.
    gamma = 0
    if S_min <= t_cur <= S_max:
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
    t_hat = net.round_sigma(t_cur + gamma * t_cur)

    # Inject exactly the noise needed to move from t_cur up to t_hat.
    added_std = (t_hat ** 2 - t_cur ** 2).sqrt()
    x_hat = x_cur + added_std * S_noise * torch.randn_like(x_cur)

    def slope(x, t):
        # Update direction at (x, t), derived from the denoiser output.
        return (x - net(x, t, class_labels).to(torch.float64)) / t

    # Euler step.
    dt = t_next - t_hat
    d_cur = slope(x_hat, t_hat)
    x_next = x_hat + dt * d_cur

    # The last step is left as plain Euler.
    if i >= num_steps - 1:
        return x_next

    # 2nd order correction: average the slopes at both ends of the step.
    d_prime = slope(x_next, t_next)
    return x_hat + dt * (0.5 * d_cur + 0.5 * d_prime)


def det_sampler(net, x_cur, class_labels, num_steps, t_cur, t_next, i, second_order=False):
    """Advance the sample by one deterministic step from ``t_cur`` to ``t_next``.

    Takes an Euler step along the direction derived from the denoiser output.
    When ``second_order`` is set, every step except the last one
    (``i < num_steps - 1``) averages that direction with the one evaluated at
    the Euler estimate.

    Args:
        net: Denoiser called as ``net(x, sigma, class_labels)``.
        x_cur: Current sample.
        class_labels: Conditioning passed through to ``net`` unchanged.
        num_steps: Total number of sampling steps.
        t_cur, t_next: Current and target noise levels.
        i: Zero-based index of this step.
        second_order: Enable the 2nd order correction.

    Returns:
        The sample at noise level ``t_next``.
    """
    def slope(x, t):
        # Update direction at (x, t), derived from the denoiser output.
        return (x - net(x, t, class_labels).to(torch.float64)) / t

    # Euler step.
    dt = t_next - t_cur
    d_cur = slope(x_cur, t_cur)
    x_next = x_cur + dt * d_cur

    # Plain Euler on the final step or when the correction is disabled.
    if i >= num_steps - 1 or not second_order:
        return x_next

    # 2nd order correction: average the slopes at both ends of the step.
    d_prime = slope(x_next, t_next)
    return x_cur + dt * (0.5 * d_cur + 0.5 * d_prime)


# survival probability = (1 - corruption_probability) * (1 - delta_probability)
def ambient_sampler(net, x_cur, class_labels, num_steps, t_cur, t_next, i, second_order=False, survival_probability=0.54):
    """Take one Euler sampling step with a network that consumes masked inputs.

    A fixed (seed 42) Bernoulli keep-mask is drawn with one entry per
    channel/pixel of a single sample and shared by every sample in the batch.
    The network receives ``[mask * x, mask]`` concatenated along the channel
    axis, and the first three output channels are used as the denoised image.

    Args:
        net: Denoiser called as ``net(input, sigma, class_labels)``.
        x_cur: Current batch of samples, indexed as (batch, channels, ...).
        class_labels: Conditioning passed through to ``net`` unchanged.
        num_steps: Unused; kept so the signature parallels ``det_sampler``.
        t_cur, t_next: Current and target noise levels.
        i: Unused; kept so the signature parallels ``det_sampler``.
        second_order: Unused; no 2nd order correction is applied here.
        survival_probability: Probability that a mask entry is 1 (kept). The
            default 0.54 equals (1 - 0.4) * (1 - 0.1).

    Returns:
        The sample at noise level ``t_next``.
    """
    x_hat = x_cur
    t_hat = t_cur

    # Draw the mask from a private generator. RandomState(42) yields the same
    # stream as seeding the global generator with 42, but does not reset the
    # process-wide NumPy RNG on every sampling step.
    rng = np.random.RandomState(42)
    keep = rng.binomial(1, survival_probability, size=x_cur.shape[1:]).astype(np.float32)

    # Share the single mask across the batch. (np.repeat along axis 0 followed
    # by a reshape would interleave channels between samples instead.)
    corruption_mask = torch.tensor(keep, device=x_cur.device).unsqueeze(0).expand(x_cur.shape)
    net_input = torch.cat([corruption_mask * x_hat, corruption_mask], dim=1)

    # Euler step using only the image channels of the network output.
    denoised = net(net_input, t_hat, class_labels).to(torch.float64)[:, :3]
    d_cur = (x_hat - denoised) / t_hat
    x_next = x_hat + (t_next - t_hat) * d_cur

    return x_next


def visualize_conditional_expectations(network_pkl, dest_path, dataset_path="../sampling/datasets/cifar10-32x32.zip", device=torch.device('cuda'),
    use_labels=True, sigma_max=80.0, seed=0, P_mean=-1.2, P_std=1.2, num_steps=18, mask_input=False,
    corruption_probability=0.4, delta_probability=0.1):
    """Save a two-row image grid of noisy inputs and the network's predictions.

    One dataset item (index 1) is noised at ``num_steps`` different noise
    levels. For each level the network is fed the noisy image concatenated
    with the item's corruption mask, and the first three output channels are
    kept. The top row of the saved grid shows the network inputs, the bottom
    row the predictions.

    Args:
        network_pkl: Path or URL of a pickle whose ``'ema'`` entry is the network.
        dest_path: Output image path.
        dataset_path: Passed as ``path`` to ``training.dataset.ImageFolderDataset``.
        device: Device the network and tensors are placed on.
        use_labels: Forwarded to the dataset constructor.
        sigma_max: Unused.
        seed: Seed for ``torch.manual_seed``.
        P_mean, P_std: The noise level is ``exp(z * s + P_mean)`` with ``z`` a
            single standard normal draw and ``s`` sweeping linearly from
            ``-3 * P_std`` to ``3 * P_std``.
        num_steps: Number of noise levels (grid columns).
        mask_input: Multiply the noisy image by the corruption mask before
            feeding it to the network.
        corruption_probability, delta_probability: Forwarded to the dataset
            constructor.
    """
    torch.manual_seed(seed)

    # Load network.
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as pkl_file:
        net = pickle.load(pkl_file)['ema'].to(device)

    # Build the dataset and fetch the single example that gets visualized.
    dataset_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset', path=dataset_path, use_labels=use_labels,
        xflip=False, cache=True,
        corruption_probability=corruption_probability, delta_probability=delta_probability,
    )
    dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)
    raw_image, raw_label, _, raw_mask = dataset_obj[1]

    # Add a leading batch axis of size 1 to everything.
    # NOTE(review): the "/ 127.5 - 1" rescale presumably maps 0..255 pixels to [-1, 1] -- confirm.
    images = (torch.tensor(raw_image, device=device).to(torch.float32) / 127.5 - 1).unsqueeze(0)
    labels = torch.tensor(raw_label, device=device).unsqueeze(0)
    corruption_mask = torch.tensor(raw_mask, device=device).unsqueeze(0)

    # A single normal draw is reused for every level; only its scaling changes.
    rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
    shown_inputs, predictions = [], []
    for scaling in np.linspace(-3 * P_std, 3 * P_std, num_steps):
        sigma = (rnd_normal * scaling + P_mean).exp()
        noisy = images + torch.randn_like(images) * sigma
        if mask_input:
            noisy = noisy * corruption_mask

        net_out = net(torch.cat([noisy, corruption_mask], dim=1), sigma, labels)
        predictions.append(net_out.to(torch.float64)[:, :3].cpu())
        shown_inputs.append(noisy.cpu())

    # Row 0: network inputs, row 1: predictions; one column per noise level.
    rows, cols = 2, num_steps
    grid = torch.cat(shown_inputs + predictions, dim=0)
    print(f'Saving image grid to "{dest_path}"...')
    grid = (grid * 127.5 + 128).clip(0, 255).to(torch.uint8)
    grid = grid.reshape(rows, cols, *grid.shape[1:]).permute(0, 3, 1, 4, 2)
    grid = grid.reshape(rows * net.img_resolution, cols * net.img_resolution, -1)
    PIL.Image.fromarray(grid.cpu().numpy(), 'RGB').save(dest_path)
    print('Done.')


def generate_image_grid(
    network_pkl, dest_path,
    seed=0, gridw=8, gridh=8, device=torch.device('cuda'),
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    same_latents=False, second_order=False,
):
    """Sample ``gridw * gridh`` images with ``ambient_sampler`` and save them as one grid.

    Args:
        network_pkl: Path or URL of a pickle whose ``'ema'`` entry is the network.
        dest_path: Output image path.
        seed: Seed for ``torch.manual_seed``.
        gridw, gridh: Grid width and height, in images.
        device: Device used for the network and all tensors.
        num_steps: Number of sampling steps (>= 1).
        sigma_min, sigma_max: Requested noise range; clamped to the range the
            network reports via ``net.sigma_min`` / ``net.sigma_max``.
        rho: Exponent shaping the spacing of the noise levels.
        S_churn, S_min, S_max, S_noise: Unused by the ambient sampler; kept for
            interface compatibility.
        same_latents: Use one latent and one class label for the whole grid.
        second_order: Forwarded to ``ambient_sampler`` (which ignores it).
    """
    batch_size = gridw * gridh
    torch.manual_seed(seed)

    # Load network.
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        net = pickle.load(f)['ema'].to(device)

    # Pick latents and labels.
    print(f'Generating {batch_size} images...')
    if same_latents:
        latents = torch.randn([1, net.img_channels, net.img_resolution, net.img_resolution], device=device).repeat(batch_size, 1, 1, 1)
    else:
        # NOTE(review): this branch hard-codes 3 channels while the branch above
        # uses net.img_channels -- confirm which one matches the sampler's
        # expectations for networks that also take mask channels.
        latents = torch.randn([batch_size, 3, net.img_resolution, net.img_resolution], device=device)

    class_labels = None
    if net.label_dim:
        if same_latents:
            class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[1], device=device)]
            class_labels = class_labels.repeat([batch_size, 1])
        else:
            class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    # With a single step, (num_steps - 1) is zero and 0 / 0 would give a NaN
    # noise level; clamp the denominator so the one step starts at sigma_max.
    step_fraction = step_indices / max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_fraction * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1
        x_cur = x_next
        x_next = ambient_sampler(net, x_cur, class_labels, num_steps, t_cur, t_next, i, second_order=second_order)

    # Save image grid: tile the batch into gridh rows of gridw images each.
    print(f'Saving image grid to "{dest_path}"...')
    image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)
    image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
    image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, -1)
    image = image.cpu().numpy()
    PIL.Image.fromarray(image, 'RGB').save(dest_path)
    print('Done.')


def dataset_creation(
    network_pkl, dest_path,
    seed=0, total_images=50000, batch_size=256,
    folder_size=1000,
    device=torch.device('cuda'),
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, second_order=True,
):
    """Generate ``total_images`` samples and write them as PNGs plus ``dataset.json``.

    Images are sampled with ``det_sampler`` and stored under
    ``dest_path/<folder:05d>/img<index:08d>.png`` with exactly ``folder_size``
    images per folder (the last folder may hold fewer). ``dataset.json`` maps
    each relative image path to its class index under the ``"labels"`` key;
    for networks without labels (``net.label_dim`` falsy) that entry is
    ``None``.

    Args:
        network_pkl: Path or URL of a pickle whose ``'ema'`` entry is the network.
        dest_path: Output directory; created if missing.
        seed: Seed for ``torch.manual_seed``.
        total_images: Number of images to generate. The final batch is shrunk
            so that exactly this many images are written.
        batch_size: Maximum number of images sampled per batch.
        folder_size: Number of images per sub-folder (>= 1).
        device: Device used for the network and all tensors.
        num_steps: Number of sampling steps (>= 1).
        sigma_min, sigma_max: Requested noise range; clamped to the range the
            network reports via ``net.sigma_min`` / ``net.sigma_max``.
        rho: Exponent shaping the spacing of the noise levels.
        S_churn, S_min, S_max, S_noise: Unused by the deterministic sampler;
            kept for interface compatibility.
        second_order: Forwarded to ``det_sampler``.

    Raises:
        ValueError: If ``folder_size`` is smaller than 1.
    """
    if folder_size < 1:
        raise ValueError('folder_size must be a positive integer')
    torch.manual_seed(seed)

    # Load network.
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        net = pickle.load(f)['ema'].to(device)

    # Real filesystem errors (permissions, bad path) surface here instead of
    # being reported as "folder exists".
    os.makedirs(dest_path, exist_ok=True)

    # The noise schedule does not depend on the batch, so build it once.
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization. The denominator is clamped so that
    # num_steps == 1 does not produce 0 / 0 = NaN.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    step_fraction = step_indices / max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_fraction * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
    schedule = list(enumerate(zip(t_steps[:-1], t_steps[1:]))) # 0, ..., N-1

    # Without class labels there is nothing to index, so record null instead.
    json_data = {"labels": [] if net.label_dim else None}

    # Step through image indices so the trailing partial batch is not dropped.
    for batch_start in tqdm.tqdm(range(0, total_images, batch_size)):
        cur_batch_size = min(batch_size, total_images - batch_start)
        latents = torch.randn([cur_batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
        class_labels = None
        label_ids = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[cur_batch_size], device=device)]
            # One-hot rows -> integer class ids, transferred to the host once per batch.
            label_ids = class_labels.argmax(dim=1).cpu().tolist()

        # Main sampling loop.
        x_next = latents.to(torch.float64) * t_steps[0]
        for i, (t_cur, t_next) in tqdm.tqdm(schedule, unit='step'):
            x_cur = x_next
            x_next = det_sampler(net, x_cur, class_labels, num_steps, t_cur, t_next, i, second_order=second_order)

        print(f'Saving image grid to "{dest_path}"...')
        images = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
        for index_offset, image in enumerate(images.cpu().numpy()):
            image_index = batch_start + index_offset
            # Derive the folder from the global index so every folder holds
            # exactly folder_size images.
            folder_name = f"{image_index // folder_size:05d}"
            if image_index % folder_size == 0:
                os.makedirs(os.path.join(dest_path, folder_name), exist_ok=True)

            path_suffix = os.path.join(folder_name, f"img{image_index:08d}.png")
            PIL.Image.fromarray(image, 'RGB').save(os.path.join(dest_path, path_suffix))
            if label_ids is not None:
                json_data["labels"].append([path_suffix, label_ids[index_offset]])

    with open(os.path.join(dest_path, "dataset.json"), "w") as f:
        json.dump(json_data, f)

#----------------------------------------------------------------------------

#----------------------------------------------------------------------------


@click.command()

# Main options.
@click.option('--mode', help='Script mode', metavar='MODE', type=click.Choice(['generate_image_grid', 'cond_expectations', 'dataset_creation']), default='generate_image_grid', show_default=True)
@click.option('--model_path',        help='Path to pre-trained model (pkl file)', metavar='DIR', type=str, required=True)
@click.option('--output_path',        help='Path for output file.', metavar='DIR', type=str, required=True)
@click.option('--steps', 'num_steps',      help='Number of sampling steps', metavar='INT',                          type=click.IntRange(min=1), default=18, show_default=True)

# ambient diffusion params
@click.option('--mask_input', help='Whether to multiply the noisy inputs by the corruption mask (cond_expectations mode)', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--corruption_probability', help='Probability of corruption', metavar='FLOAT', type=float, default=0.4, show_default=True)
@click.option('--delta_probability', help='Probability of delta corruption', metavar='FLOAT', type=float, default=0.1, show_default=True)

# Dataset creation options.
@click.option('--total_images', 'total_images', help='Number of images to generate', metavar='INT', type=click.IntRange(min=1), default=50000, show_default=True)

def main(mode, model_path, output_path, num_steps, mask_input, corruption_probability, delta_probability, total_images):
    """Sample from or inspect a pickled diffusion network.

    Depending on --mode this saves an image grid of samples, writes a folder
    of generated images with a dataset.json, or saves a grid of denoiser
    predictions at several noise levels. --steps is forwarded to the
    generate_image_grid and dataset_creation modes only; the ambient diffusion
    options are forwarded to the cond_expectations mode only.
    """
    if mode == 'cond_expectations':
        visualize_conditional_expectations(model_path,   output_path,
                                        mask_input=mask_input, corruption_probability=corruption_probability, delta_probability=delta_probability)
    elif mode == 'generate_image_grid':
        generate_image_grid(model_path,   output_path,  num_steps=num_steps)
    elif mode == 'dataset_creation':
        # Create the output dir if it doesn't already exist. Only an existing
        # path is tolerated; other failures (permissions, missing parent)
        # propagate instead of being misreported as "already exists".
        try:
            os.mkdir(output_path)
        except FileExistsError:
            print("Folder already exists...")
        dataset_creation(model_path,   output_path,  num_steps=num_steps, total_images=total_images)
    else:
        # Not reachable through the CLI (click.Choice validates --mode); kept
        # as a guard for direct callers.
        raise ValueError('Unknown mode')



#----------------------------------------------------------------------------

# Run the click command only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
