# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Generate random images using the techniques described in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import os
import re
import click
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
from torch_utils import distributed as dist
from scipy.io import savemat, loadmat

from piq import LPIPS
import torch.nn.functional as F

def fine_grained_integration(net, class_labels, x_hat, t_hat, t_next, d_cur, M):
    """Estimate the increment x(t_next) - x_hat on a finer time grid.

    The interval [t_hat, t_next] is split into ``M`` equal sub-intervals and one
    Heun (predictor/corrector) step of the ODE ``dx/dt = (x - net(x, t)) / t`` is
    taken on each. The returned value is the sum of the sub-step increments; the
    BIIA/IIA coefficient fit in ``edm_sampler`` uses it as its regression target.

    Cost: ``M=2`` performs 3 network evaluations and ``M=3`` performs 5.

    Args:
        net: Denoiser called as ``net(x, t, class_labels)``; its output is cast
            to float64.
        class_labels: Conditioning forwarded unchanged to ``net``.
        x_hat: Sample at noise level ``t_hat``.
        t_hat: Noise level at the start of the interval.
        t_next: Noise level at the end of the interval. Slopes are divided by it,
            so it must be non-zero.
        d_cur: Slope at ``(x_hat, t_hat)``, already evaluated by the caller and
            reused as the first predictor slope.
        M: Number of sub-intervals; only 2 and 3 are implemented.

    Returns:
        Tensor shaped like ``x_hat`` holding the fine-grained increment.

    Raises:
        ValueError: If ``M`` is not 2 or 3.
    """
    if M not in (2, 3):
        # Previously this only printed a message and then crashed with
        # UnboundLocalError on the return statement.
        raise ValueError(f"M={M} is not implemented; supported values are 2 and 3")

    def slope(x, t):
        # dx/dt at (x, t) for the denoiser-defined ODE.
        return (x - net(x, t, class_labels).to(torch.float64)) / t

    if M == 2:
        t_mid = (t_hat + t_next) / 2

        # Heun step t_hat -> t_mid.
        x_mid = x_hat + (t_mid - t_hat) * d_cur
        d_mid = slope(x_mid, t_mid)
        x_mid = x_hat + (t_mid - t_hat) * (0.5 * d_cur + 0.5 * d_mid)

        # Heun step t_mid -> t_next, starting from the corrected midpoint.
        d_mid_down = slope(x_mid, t_mid)
        x_next_fine = x_mid + (t_next - t_mid) * d_mid_down
        d_prime_fine = slope(x_next_fine, t_next)

        # Both halves have length (t_next - t_hat) / 2, so the summed increment is
        # the full interval times the plain average of the four slopes.
        return (t_next - t_hat) * (0.25 * d_cur + 0.25 * d_mid + 0.25 * d_mid_down + 0.25 * d_prime_fine)

    # M == 3: three equal sub-intervals t_hat > t_m1 > t_m2 > t_next.
    h = (t_hat - t_next) / 3
    t_m1 = t_hat - h
    t_m2 = t_m1 - h

    # Heun step t_hat -> t_m1.
    x_m1 = x_hat + (t_m1 - t_hat) * d_cur
    d_m1 = slope(x_m1, t_m1)
    x_m1 = x_hat + (t_m1 - t_hat) * (0.5 * d_cur + 0.5 * d_m1)

    # Heun step t_m1 -> t_m2.
    d_m1_up = slope(x_m1, t_m1)
    x_m2 = x_m1 + (t_m2 - t_m1) * d_m1_up
    d_m2 = slope(x_m2, t_m2)
    x_m2 = x_m1 + (t_m2 - t_m1) * (0.5 * d_m1_up + 0.5 * d_m2)

    # Heun step t_m2 -> t_next (only its slopes are needed for the increment).
    d_m2_up = slope(x_m2, t_m2)
    x_next_fg = x_m2 + (t_next - t_m2) * d_m2_up
    d_next_fg = slope(x_next_fg, t_next)

    return ((t_m1 - t_hat) * (0.5 * d_cur + 0.5 * d_m1)
            + (t_m2 - t_m1) * (0.5 * d_m1_up + 0.5 * d_m2)
            + (t_next - t_m2) * (0.5 * d_m2_up + 0.5 * d_next_fg))


def _solve_mixing_coefficients(basis, target):
    """Return least-squares weights ``c`` minimising ``||target - sum_k c[k] * basis[k]||^2``.

    Solves the normal equations ``G c = b`` with ``G[j][k] = <basis[j], basis[k]>``
    and ``b[j] = <basis[j], target>``, where ``<u, v>`` sums ``u * v`` over every
    element, so the whole batch is fitted jointly. Everything stays on the inputs'
    device. Returns a ``(len(basis), 1)`` tensor.
    """
    gram = torch.stack([torch.stack([torch.sum(u * v) for v in basis]) for u in basis])
    rhs = torch.stack([torch.sum(u * target) for u in basis]).unsqueeze(1)
    # linalg.solve is numerically preferable to multiplying by an explicit inverse.
    return torch.linalg.solve(gram, rhs)


def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like, IIA_strength=None,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, s_m="IIA", M=3,
):
    """EDM Heun sampler with optional BIIA / IIA correction of each step.

    With ``s_m`` other than "BIIA"/"IIA" (e.g. "none") this is the plain EDM
    sampler: an Euler step followed by a 2nd-order Heun correction on every step
    but the last. For "BIIA"/"IIA", each step ``i > 0`` instead sets
    ``x_next = x_hat + sum_k c_k * x_k`` using quantities already computed:

    * BIIA (2 terms): ``(t_next - t_hat) * heun_slope`` and
      ``(t_next - t_hat) * D_prev``, the previous step's Heun slope.
    * IIA (4 terms): ``x_hat - denoised``, ``denoised_next - denoised`` and the same
      two differences from the previous step.

    The coefficients ``c`` for step ``i`` are read from ``IIA_strength[i]``. When the
    key is missing, they are fitted by least squares against
    ``fine_grained_integration(..., M)`` (extra network evaluations) and stored in
    ``IIA_strength``, so a dict filled on one batch can be passed to later batches.
    Step 0 has no history (r = 1 in the paper), so it always takes a Heun step.

    Args:
        net: Denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            called as ``net(x, t, class_labels)``.
        latents: Initial noise; it is scaled by the largest noise level.
        class_labels: Conditioning forwarded to ``net``.
        randn_like: Noise source used for the stochastic "churn".
        IIA_strength: Dict mapping step index to a coefficient tuple. It is updated
            in place and returned. ``None`` (the default) starts a fresh dict.
        num_steps: Number of noise levels; must be at least 2.
        sigma_min, sigma_max, rho: EDM time-step discretization parameters.
        S_churn, S_min, S_max, S_noise: EDM stochasticity parameters.
        s_m: "IIA", "BIIA", or anything else for the plain Heun sampler.
        M: Number of fine-grained sub-steps used when fitting coefficients (2 or 3).

    Returns:
        ``(x, IIA_strength)``: the final float64 samples and the coefficient dict.

    Raises:
        ValueError: If ``num_steps < 2`` (the discretization would divide by zero).
    """
    # A fresh dict per call: a shared mutable default would leak coefficients
    # between unrelated calls.
    if IIA_strength is None:
        IIA_strength = {}
    if num_steps < 2:
        raise ValueError(f"num_steps must be at least 2, got {num_steps}")

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization, with t_N = 0 appended.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]

    D_prev = None       # BIIA: Heun slope of the previous step.
    Ddiff_prev = None   # IIA: denoised_next - denoised of the previous step.
    xDdiff_prev = None  # IIA: x_hat - denoised of the previous step.

    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # The last step lands on t_next = 0 and keeps the Euler result.
        if i == num_steps - 1:
            continue

        # Slope at the Euler prediction; evaluated once and shared by every mode.
        denoised_next = net(x_next, t_next, class_labels).to(torch.float64)
        d_prime = (x_next - denoised_next) / t_next
        heun_slope = 0.5 * d_cur + 0.5 * d_prime

        if s_m == "BIIA":
            if i == 0:
                IIA_strength.setdefault(i, (0.0, 0.0))
                x_next = x_hat + (t_next - t_hat) * heun_slope
            else:
                x1 = (t_next - t_hat) * heun_slope
                x2 = (t_next - t_hat) * D_prev
                if i not in IIA_strength:
                    fg_approx = fine_grained_integration(net, class_labels, x_hat, t_hat, t_next, d_cur, M)
                    sol = _solve_mixing_coefficients([x1, x2], fg_approx)
                    IIA_strength[i] = (sol[0], sol[1])
                x_next = x_hat + IIA_strength[i][0] * x1 + IIA_strength[i][1] * x2
            D_prev = heun_slope

        elif s_m == "IIA":
            x1 = x_hat - denoised
            x2 = denoised_next - denoised
            if i == 0:
                IIA_strength.setdefault(i, (0.0, 0.0, 0.0, 0.0))
                x_next = x_hat + (t_next - t_hat) * heun_slope
            else:
                x3 = Ddiff_prev
                x4 = xDdiff_prev
                if i not in IIA_strength:
                    fg_approx = fine_grained_integration(net, class_labels, x_hat, t_hat, t_next, d_cur, M)
                    sol = _solve_mixing_coefficients([x1, x2, x3, x4], fg_approx)
                    IIA_strength[i] = (sol[0], sol[1], sol[2], sol[3])
                c1, c2, c3, c4 = IIA_strength[i]
                x_next = x_hat + c1 * x1 + c2 * x2 + c3 * x3 + c4 * x4
            xDdiff_prev = x1
            Ddiff_prev = x2

        else:
            # Plain 2nd-order Heun correction.
            x_next = x_hat + (t_next - t_hat) * heun_slope

    return x_next, IIA_strength


#----------------------------------------------------------------------------
# Wrapper for torch.Generator that allows specifying a different random seed
# for each sample in a minibatch.

class StackedRandomGenerator:
    """Draw a minibatch of random tensors with one independent seed per sample.

    One ``torch.Generator`` is created for each entry of ``seeds`` (reduced modulo
    2**32). Row ``k`` of every stacked draw comes from the generator seeded with
    ``seeds[k]``.
    """

    def __init__(self, device, seeds):
        super().__init__()
        generators = []
        for seed in seeds:
            generator = torch.Generator(device)
            generator.manual_seed(int(seed) % (1 << 32))
            generators.append(generator)
        self.generators = generators

    def _draw_stacked(self, size, draw):
        # The leading dimension must equal the number of seeds; each generator
        # produces one slice of shape size[1:], and the slices are stacked on dim 0.
        assert size[0] == len(self.generators)
        sample_shape = size[1:]
        return torch.stack([draw(sample_shape, generator) for generator in self.generators])

    def randn(self, size, **kwargs):
        """Per-sample ``torch.randn`` stacked into a tensor of shape ``size``."""
        return self._draw_stacked(
            size, lambda shape, gen: torch.randn(shape, generator=gen, **kwargs))

    def randn_like(self, input):
        """Per-sample normal noise with the shape, dtype, layout and device of ``input``."""
        shape, dtype, layout, device = input.shape, input.dtype, input.layout, input.device
        return self.randn(shape, dtype=dtype, layout=layout, device=device)

    def randint(self, *args, size, **kwargs):
        """Per-sample ``torch.randint``; positional ``args`` are forwarded as the bounds."""
        return self._draw_stacked(
            size, lambda shape, gen: torch.randint(*args, size=shape, generator=gen, **kwargs))

#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    """Expand a comma-separated list of ints and inclusive ``a-b`` ranges.

    A value that is already a list is returned unchanged (click may pass the
    parsed default back in). Example: ``'1,2,5-7'`` -> ``[1, 2, 5, 6, 7]``.
    """
    if isinstance(s, list):
        return s
    span_pattern = re.compile(r'^(\d+)-(\d+)$')
    values = []
    for token in s.split(','):
        span = span_pattern.match(token)
        if span is None:
            values.append(int(token))
            continue
        first, last = (int(bound) for bound in span.groups())
        values.extend(range(first, last + 1))
    return values

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl',  help='Network pickle filename', metavar='PATH|URL',                      type=str, required=True)
@click.option('--outdir',                  help='Where to save the output images', metavar='DIR',                   type=str, required=True)
@click.option('--seeds',                   help='Random seeds (e.g. 1,2,5-10)', metavar='LIST',                     type=parse_int_list, default='0-63', show_default=True)
@click.option('--subdirs',                 help='Create subdirectory for every 1000 seeds',                         is_flag=True)
@click.option('--class', 'class_idx',      help='Class label  [default: random]', metavar='INT',                    type=click.IntRange(min=0), default=None)
@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT',                                type=click.IntRange(min=1), default=200, show_default=True)

@click.option('--steps', 'num_steps',      help='Number of sampling steps', metavar='INT',                          type=click.IntRange(min=1), default=18, show_default=True)
@click.option('--sampling_method', 's_m',  help='sampling_method', metavar='BIIA|IIA|none',                         type=click.Choice(['BIIA', 'IIA', 'none']))
@click.option('--M', 'M',                  help='Number of fine-grained step', metavar='INT',                       type=click.IntRange(min=2, max=3), default=3, show_default=True)
@click.option('--sigma_min',               help='Lowest noise level  [default: varies]', metavar='FLOAT',           type=click.FloatRange(min=0, min_open=True))
@click.option('--sigma_max',               help='Highest noise level  [default: varies]', metavar='FLOAT',          type=click.FloatRange(min=0, min_open=True))
@click.option('--rho',                     help='Time step exponent', metavar='FLOAT',                              type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
@click.option('--S_churn', 'S_churn',      help='Stochasticity strength', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_min', 'S_min',          help='Stoch. min noise level', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_max', 'S_max',          help='Stoch. max noise level', metavar='FLOAT',                          type=click.FloatRange(min=0), default='inf', show_default=True)
@click.option('--S_noise', 'S_noise',      help='Stoch. noise inflation', metavar='FLOAT',                          type=float, default=1, show_default=True)

@click.option('--solver',                  help='Ablate ODE solver', metavar='euler|heun',                          type=click.Choice(['euler', 'heun']))
@click.option('--disc', 'discretization',  help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
@click.option('--schedule',                help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear',           type=click.Choice(['vp', 've', 'linear']))
@click.option('--scaling',                 help='Ablate signal scaling s(t)', metavar='vp|none',                    type=click.Choice(['vp', 'none']))

def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
    """Generate images with the EDM sampler, optionally corrected by BIIA or IIA.

    * The EDM sampling code is modified to incorporate the BIIA and IIA
      techniques, with the aim of improving quality at small NFEs.
    * For simplicity, only the r=1 scenario of the paper is implemented.
    * If --sampling_method is omitted, the sampler's default (IIA) is used.
    """
    # Drop options the user did not set, so the sampler's own defaults apply.
    sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}

    # edm_sampler takes none of the ablation options. Reject them up front instead
    # of failing with a TypeError after the network has been loaded.
    ablation_flags = {'solver': '--solver', 'discretization': '--disc', 'schedule': '--schedule', 'scaling': '--scaling'}
    unsupported = [flag for key, flag in ablation_flags.items() if key in sampler_kwargs]
    if unsupported:
        raise click.UsageError(f'{", ".join(unsupported)}: ablation options are not supported by this sampler')

    dist.init()
    num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]

    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load network.
    # NOTE: pickle.load can execute arbitrary code; only load trusted network pickles.
    dist.print0(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
        net = pickle.load(f)['ema'].to(device)

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # --class only makes sense for a class-conditional network with enough classes;
    # otherwise the label assignment below would fail on None or index out of range.
    if class_idx is not None and not (net.label_dim and class_idx < net.label_dim):
        raise click.UsageError(f'--class={class_idx} is invalid for a network with label_dim={net.label_dim}')

    # Dictionary of BIIA/IIA coefficients keyed by step index. edm_sampler fills
    # in missing entries during the first mini-batch; the remaining mini-batches
    # reuse the stored coefficients directly.
    IIA_strength = {}

    # Loop over batches.
    dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
    for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):

        torch.distributed.barrier()
        batch_size = len(batch_seeds)
        if batch_size == 0:
            continue

        # Pick latents and labels.
        rnd = StackedRandomGenerator(device, batch_seeds)
        latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
        class_labels = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        # Generate images.
        images, IIA_strength = edm_sampler(net, latents, class_labels, randn_like=rnd.randn_like, IIA_strength=IIA_strength, **sampler_kwargs)

        # Save images.
        images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        for seed, image_np in zip(batch_seeds, images_np):
            image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir
            os.makedirs(image_dir, exist_ok=True)
            image_path = os.path.join(image_dir, f'{seed:06d}.png')
            if image_np.shape[2] == 1:
                PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
            else:
                PIL.Image.fromarray(image_np, 'RGB').save(image_path)

    # Done.
    torch.distributed.barrier()
    dist.print0('Done.')

#----------------------------------------------------------------------------

# Script entry point: click parses the command line and dispatches to main().
if __name__ == '__main__':
    main()

#----------------------------------------------------------------------------
