# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
import os
import re

import click
import tqdm
import pickle
import torch
from torch.utils.data import DataLoader, Subset

import dnnlib
import training.dataset
from torch_utils import distributed as dist


def enumerate_edm_loss(
    net, y, batch_unique_labels, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
):
    """Estimate denoising losses along the time steps used in EDM sampler.

    For each noise level ``sigma`` of the EDM sampler's time-step schedule,
    ``y`` is perturbed with ``randn_like(y) * sigma`` and passed through
    ``net``; the loss is the squared error between the network output and
    ``y``, summed over every non-batch dimension.

    Args:
        net: denoiser exposing ``sigma_min``, ``sigma_max`` and
            ``round_sigma``, called as
            ``net(x, sigma, batch_unique_labels, class_labels)``.
        y: clean batch; dimension 0 is the batch dimension.
        batch_unique_labels: forwarded unchanged to ``net``.
        class_labels: forwarded unchanged to ``net`` (may be None).
        randn_like: noise source, e.g. ``StackedRandomGenerator.randn_like``
            for per-sample reproducible noise.
        num_steps: number of noise levels. With a single step only
            ``sigma_max`` is used.
        sigma_min, sigma_max: requested noise range, clamped to the range
            supported by ``net``.
        rho: time-step exponent of the EDM schedule.

    Returns:
        Tensor of shape ``[batch, num_steps]``; column ``i`` holds the loss at
        the ``i``-th noise level, ordered from ``sigma_max`` down to
        ``sigma_min``.
    """
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization. Unlike the sampler, no trailing t_N = 0 step is
    # appended: only the num_steps levels between sigma_max and sigma_min are used.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=y.device)
    # Guard num_steps == 1: (num_steps - 1) == 0 would otherwise give 0/0 = NaN.
    denom = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = net.round_sigma(t_steps)

    batch = y.shape[0]
    losses = []
    for sigma in sigmas:  # 0, ..., N-1
        n = randn_like(y) * sigma
        D_yn = net(y + n, sigma, batch_unique_labels, class_labels)
        # Sum over all non-batch dims; reshape also handles 1-D/2-D inputs
        # (torch.sum with an empty dim list would reduce the batch dim too).
        loss = (D_yn - y).square().reshape(batch, -1).sum(dim=1)
        losses.append(loss)

    return torch.stack(losses, dim=1)


class StackedRandomGenerator:
    """
    Wrapper for torch.Generator that allows specifying a different random seed
    for each sample in a minibatch.

    Sample ``i`` of a batch is drawn from its own ``torch.Generator`` seeded
    with ``seeds[i] % 2**32``, so the random values for a sample depend only
    on its own seed, not on which other samples share the batch.
    """
    def __init__(self, device, seeds):
        super().__init__()
        # One generator per sample; seeds are reduced to the 32-bit range.
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]

    def randn(self, size, **kwargs):
        """Draw standard-normal values of shape ``size``.

        ``size[0]`` must equal the number of seeds; row ``i`` comes from
        generator ``i``. Extra keyword arguments go to ``torch.randn``.
        """
        assert size[0] == len(self.generators), f"{size[0]} != {len(self.generators)}"
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def randn_like(self, input):
        """Per-sample ``randn`` matching the shape, dtype, layout and device of ``input``."""
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        """Draw integers of shape ``size`` via ``torch.randint(*args, ...)``.

        ``size[0]`` must equal the number of seeds; row ``i`` comes from
        generator ``i``.
        """
        # Same check and message as randn, for consistent diagnostics.
        assert size[0] == len(self.generators), f"{size[0]} != {len(self.generators)}"
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])


_RANGE_RE = re.compile(r'^(\d+)\s*-\s*(\d+)$')


def parse_int_list(s):
    """Parse a comma separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

    Whitespace around items and around the range dash is ignored, and empty
    items (e.g. from a trailing comma) are skipped. A list is returned
    unchanged, so already-parsed values pass through. A malformed item raises
    ValueError, which click reports as a bad parameter value.
    """
    if isinstance(s, list):
        return s
    ranges = []
    for p in s.split(','):
        p = p.strip()
        if not p:
            continue  # tolerate '1,2,' and '1,,2'
        m = _RANGE_RE.match(p)
        if m:
            # Ranges are inclusive of both endpoints.
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges


@click.command()
@click.option('--network', 'network_pkl',      help='Network pickle filename', metavar='PATH|URL',             type=str, required=True)
@click.option('--generated', 'generated_path', help='Path to the generated images', metavar='PATH',            type=str, required=True)
@click.option('--outpath',                     help='Path to save the output losses', metavar='DIR',           type=str, required=True)
@click.option('--seeds',                       help='Random seeds (e.g. 1,2,5-10)', metavar='LIST',            type=parse_int_list, default=None, show_default=True)
@click.option('--class', 'class_idx',          help='Class label  [default: random]', metavar='INT',           type=click.IntRange(min=0), default=None)
@click.option('--batch', 'batch_size',         help='Maximum batch size', metavar='INT',                       type=click.IntRange(min=1), default=64, show_default=True)
@click.option('--unique_labels',               help='Generate with unique labels',                             type=parse_int_list, default=None, show_default=True)
# Sampler arguments.
@click.option('--steps', 'num_steps',          help='Number of sampling steps', metavar='INT',                 type=click.IntRange(min=1), default=18, show_default=True)
@click.option('--sigma_min',                   help='Lowest noise level  [default: varies]', metavar='FLOAT',  type=click.FloatRange(min=0, min_open=True))
@click.option('--sigma_max',                   help='Highest noise level  [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
@click.option('--rho',                         help='Time step exponent', metavar='FLOAT',                     type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
def main(network_pkl, generated_path, outpath, seeds, class_idx, batch_size, unique_labels, device=torch.device('cuda'), **sampler_kwargs):
    """Enumerate EDM denoising losses of generated images.

    Loads the EMA network, evaluates the denoising loss of every image in the
    generated folder at each noise level of the EDM sampler schedule, and saves
    the losses (one row per image, one column per noise level) to the output
    path. Rows already present in the output file are reused, so an
    interrupted run resumes where it stopped.
    """
    dist.init()

    dataset = training.dataset.ImageFolderDataset(path=generated_path)
    total_n = len(dataset)
    indices = list(range(total_n))
    total_batch_size = batch_size * dist.get_world_size()

    # Seeds default to 1..N; explicit seeds / unique labels must cover every image.
    if seeds is not None:
        assert len(seeds) == total_n
    else:
        seeds = list(range(1, total_n + 1))
    if unique_labels is not None:
        assert len(seeds) == len(unique_labels)

    # Pad every per-image list to a multiple of the global batch size so each
    # rank sees the same number of full batches. Padding cycles through the
    # list: a plain `xs + xs[:pad]` is too short when pad > total_n (small
    # dataset, large batch x world size). The padded rows are dropped when
    # the output is truncated to total_n at the end.
    pad = (-total_n) % total_batch_size
    if pad:
        seeds = seeds + [seeds[i % total_n] for i in range(pad)]
        if unique_labels is not None:
            unique_labels = unique_labels + [unique_labels[i % total_n] for i in range(pad)]
        indices = indices + [indices[i % total_n] for i in range(pad)]

    assert len(seeds) % total_batch_size == 0
    if unique_labels is not None:
        assert len(unique_labels) % total_batch_size == 0
    assert len(indices) % total_batch_size == 0

    # Split into per-device batches and deal them round-robin:
    # rank r processes global batches r, r + W, r + 2W, ... (W = world size).
    all_seeds = torch.as_tensor(seeds).split(batch_size)
    if unique_labels is not None:
        all_unique_labels = torch.as_tensor(unique_labels).split(batch_size)
    else:
        # Zero placeholders keep the zip below aligned when no labels were given.
        all_unique_labels = torch.zeros(len(seeds)).split(batch_size)
    all_indices = torch.as_tensor(indices).split(batch_size)
    rank_seeds = all_seeds[dist.get_rank():: dist.get_world_size()]
    rank_unique_labels = all_unique_labels[dist.get_rank():: dist.get_world_size()]
    rank_indices = all_indices[dist.get_rank():: dist.get_world_size()]

    dataset = Subset(dataset, torch.cat(rank_indices))
    dataloader = DataLoader(dataset, batch_size, num_workers=4, pin_memory=True)
    assert len(dataloader) == len(rank_seeds), f'{len(dataloader)} != {len(rank_seeds)}'

    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load network.
    # NOTE: pickle.load can execute arbitrary code; only load trusted network pickles.
    dist.print0(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
        net = pickle.load(f)['ema'].to(device)

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Validate --class up front (on every rank, after the barriers) instead of
    # failing inside the loop with a TypeError/IndexError.
    if class_idx is not None and not (net.label_dim and class_idx < net.label_dim):
        raise click.UsageError(f'--class {class_idx} is invalid for a network with label_dim={net.label_dim}')

    # Sampler options left unset on the command line fall back to the defaults
    # of enumerate_edm_loss. Loop-invariant, so filter once.
    sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}

    # Get the number of cached losses.
    if os.path.exists(outpath):
        cache_n = torch.load(outpath).shape[0]
    else:
        cache_n = 0

    # Initialize buffer. Only rank 0 accumulates and writes losses.
    if dist.get_rank() == 0:
        batch_losses = []
    now_n = 0
    buffer_size = 8192

    # Loop over batches.
    dist.print0(f'Enumerating the denoising losses of {total_n} images to "{outpath}"...')
    with tqdm.tqdm(total=total_n, disable=(dist.get_rank() != 0)) as pbar:
        for (y, _, _), batch_seeds, batch_unique_labels in zip(dataloader, rank_seeds, rank_unique_labels):
            torch.distributed.barrier()
            # Skip this global step only if all of its rows are already cached.
            if now_n + batch_size * dist.get_world_size() > cache_n:
                y = y.float().to(device) / 127.5 - 1    # Normalize to [-1, 1]

                # Pick labels.
                rnd = StackedRandomGenerator(device, batch_seeds)
                class_labels = None
                if net.label_dim:
                    class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
                if class_idx is not None:
                    class_labels[:, :] = 0
                    class_labels[:, class_idx] = 1
                batch_unique_labels = batch_unique_labels.to(device)

                # Enumerate losses.
                losses = enumerate_edm_loss(net, y, batch_unique_labels, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)

                # Gather losses and save to buffer.
                # NOTE(review): row order relies on dist.gather concatenating
                # per-rank results in rank order, matching the round-robin split.
                losses = dist.gather(losses).cpu()
                if dist.get_rank() == 0:
                    batch_losses.append(losses)
                now_n += len(losses)

                # Flush buffer once buffer_size new rows accumulated or all images are done.
                # Every rank advances cache_n so the skip test above stays in sync.
                if now_n - cache_n >= buffer_size or now_n >= total_n:
                    if dist.get_rank() == 0:
                        batch_losses = torch.cat(batch_losses, dim=0)
                        if cache_n > 0:
                            cache = torch.load(outpath)
                            # Keep only the rows that precede the newly computed ones.
                            cache = cache[:now_n - batch_losses.shape[0]]
                            cache = torch.cat([cache, batch_losses], dim=0)
                        else:
                            cache = batch_losses
                        torch.save(cache, outpath)
                        batch_losses = []
                    cache_n = now_n
            else:
                now_n += batch_size * dist.get_world_size()
            # Do not count padded rows in the progress bar.
            if now_n >= total_n:
                pbar.update(batch_size * dist.get_world_size() - (now_n - total_n))
            else:
                pbar.update(batch_size * dist.get_world_size())
            pbar.set_description(f"now_n: {now_n}, cache_n: {cache_n}")

    # Save losses, dropping the rows that came from padding.
    if dist.get_rank() == 0:
        cache = torch.load(outpath)
        torch.save(cache[:total_n], outpath)

    # Done.
    torch.distributed.barrier()
    dist.print0('Done.')


if __name__ == "__main__":
    # Script entry point: click parses the command-line options declared on
    # main's decorators and passes them in as keyword arguments.
    main()
