# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import contextlib
import functools
import os
import re
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_pdf import PdfPages

import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = {}

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding ``value``, optionally broadcast to ``shape``.

    The cache key covers the raw bytes, shape and dtype of ``value`` together with the
    requested ``shape``, ``dtype``, ``device`` and ``memory_format``. Repeated calls with
    the same arguments therefore return the very same tensor object, which must not be
    modified in place.

    Defaults: ``dtype`` -> ``torch.get_default_dtype()``, ``device`` -> CPU,
    ``memory_format`` -> ``torch.contiguous_format``.
    """
    arr = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    cache_key = (arr.shape, arr.dtype, arr.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    # Copy first so the tensor never aliases the caller's numpy buffer.
    result = torch.as_tensor(arr.copy(), dtype=dtype, device=device)
    if shape is not None:
        result = torch.broadcast_tensors(result, torch.empty(shape))[0]
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num # 1.8.0a0
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback for PyTorch builds lacking ``torch.nan_to_num`` (only ``nan=0`` supported).

        NaNs become 0 and values are clamped to [neginf, posinf], whose defaults are the
        finite min/max of ``input.dtype``.
        """
        assert isinstance(input, torch.Tensor)
        upper = torch.finfo(input.dtype).max if posinf is None else posinf
        lower = torch.finfo(input.dtype).min if neginf is None else neginf
        assert nan == 0
        # nansum over a singleton leading axis replaces NaN with 0 and keeps other values.
        zeroed = input.unsqueeze(0).nansum(0)
        return torch.clamp(zeroed, min=lower, max=upper, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

# Prefer the traceable torch._assert (PyTorch >= 1.8); older releases expose torch.Assert.
if hasattr(torch, '_assert'):
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
else:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    """Check that ``tensor.shape`` matches ``ref_shape`` and raise AssertionError if not.

    ``ref_shape`` holds one entry per dimension. ``None`` accepts any size. When either
    side of a comparison is a tensor (as happens under ``torch.jit.trace``), the check is
    emitted through ``symbolic_assert`` so that it is recorded in the trace.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for dim, (actual, expected) in enumerate(zip(tensor.shape, ref_shape)):
        if expected is None:
            continue
        if isinstance(expected, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(actual), expected), f'Wrong size for dimension {dim}')
        elif isinstance(actual, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(actual, torch.as_tensor(expected)), f'Wrong size for dimension {dim}: expected {expected}')
        elif actual != expected:
            raise AssertionError(f'Wrong size for dimension {dim}: got {actual}, expected {expected}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Decorate ``fn`` so that each call runs inside ``record_function(fn.__name__)``.

    ``functools.wraps`` copies the function metadata (``__name__``, ``__qualname__``,
    ``__doc__``, ``__module__``) onto the wrapper and exposes the original function as
    ``__wrapped__``.
    """
    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    """Endless index sampler that splits the stream across ``num_replicas`` workers.

    Global step ``k`` is emitted by the replica whose ``rank == k % num_replicas``.
    With ``shuffle`` enabled, the index order starts as a seeded permutation. After every
    step, the current position is swapped with a random position at most ``window - 1``
    slots behind it (wrapping around), where
    ``window = round(len(dataset) * window_size)``. No swaps happen when ``window < 2``.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        perm = np.arange(len(self.dataset))
        rng = None
        window = 0
        if self.shuffle:
            rng = np.random.RandomState(self.seed)
            rng.shuffle(perm)
            window = int(np.rint(perm.size * self.window_size))

        size = perm.size
        step = 0
        while True:
            pos = step % size
            if step % self.num_replicas == self.rank:
                yield perm[pos]
            if window >= 2:
                # Swap the slot just visited with one up to window-1 slots behind it.
                other = (pos - rng.randint(window)) % size
                perm[pos], perm[other] = perm[other], perm[pos]
            step += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    """Return a list of all parameters of ``module`` followed by all of its buffers."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]

def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for all parameters of ``module``, then its buffers."""
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]

@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameter and buffer values from ``src_module`` into ``dst_module`` in place.

    Tensors are matched by their qualified names. Destination tensors without a source
    counterpart are left untouched, unless ``require_all`` is set, in which case an
    AssertionError is raised.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    source = dict(named_params_and_buffers(src_module))
    for name, dst_tensor in named_params_and_buffers(dst_module):
        present = name in source
        assert present or (not require_all)
        if present:
            dst_tensor.copy_(source[name])

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that disables DDP gradient sync for the body when ``sync`` is False.

    Only ``DistributedDataParallel`` modules are affected, via ``module.no_sync()``.
    Any other module runs the body unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    is_ddp = isinstance(module, torch.nn.parallel.DistributedDataParallel)
    ctx = module.no_sync() if (is_ddp and not sync) else contextlib.nullcontext()
    with ctx:
        yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    """Assert that every parameter/buffer of ``module`` equals rank 0's copy.

    Each tensor is broadcast from rank 0 and compared elementwise; the assertion message
    is the tensor's full name (``<ModuleClass>.<name>``). Names fully matching
    ``ignore_regex`` are skipped. Floating-point tensors pass through ``nan_to_num``
    first, so NaNs compare equal. Requires an initialized ``torch.distributed`` group.
    """
    assert isinstance(module, torch.nn.Module)
    prefix = type(module).__name__ + '.'
    for name, tensor in named_params_and_buffers(module):
        fullname = prefix + name
        skip = ignore_regex is not None and re.fullmatch(ignore_regex, fullname)
        if skip:
            continue
        local = tensor.detach()
        if local.is_floating_point():
            local = nan_to_num(local)
        reference = local.clone()
        torch.distributed.broadcast(tensor=reference, src=0)
        assert (local == reference).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run ``module(*inputs)`` once and print a table summarizing its submodules.

    Forward pre/post hooks are attached to every submodule to record the tensors each one
    returns. Each row lists a submodule's parameter and buffer counts, with every tensor
    counted only the first time it is seen, plus the shape and dtype of its outputs. A
    final row gives the totals.

    Args:
        module: ``torch.nn.Module`` to summarize (TorchScript modules are rejected).
        inputs: tuple or list of positional arguments passed to ``module``.
        max_nesting: a submodule call is recorded only if its call depth is at most this
            value, where the top-level call has depth 0.
        skip_redundant: drop rows for submodules that contribute no previously unseen
            parameters, buffers, or outputs.

    Returns:
        The value returned by ``module(*inputs)``.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module. Always detach the hooks, even if the forward pass raises; otherwise
    # they would stay registered on the caller's module and fire on every later call.
    try:
        outputs = module(*inputs)
    finally:
        for hook in hooks:
            hook.remove()

    # Identify unique outputs, parameters, and buffers (first occurrence wins).
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        # Additional outputs of the same submodule get their own ':<idx>' rows.
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table, left-aligning each column to its widest cell.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------
def count_parameters(model):
    """Count the learnable and the total number of scalar parameters of a module.

    Args:
        model: ``torch.nn.Module`` whose ``parameters()`` are counted.

    Returns:
        ``(num_learned_params, num_params)`` as Python ints. The first counts only
        parameters with ``requires_grad=True``; the second counts all parameters.
    """
    # numel() is exact for every rank. np.prod(p.size()) returns the float 1.0 for 0-d
    # parameters, because the product of an empty shape is a float in numpy.
    num_learned_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_params = sum(p.numel() for p in model.parameters())
    return num_learned_params, num_params

#----------------------------------------------------------------------------
def save_single_tensor_image(tensors, image_id, folder_path="output_images", image_name="input", cmap='viridis'):
    """
    Save a single tensor image with the specified ID from the batch using a colormap.

    The image is written to ``<folder_path>/image_<image_name>_<image_id>.png`` with axes
    hidden and no padding. ``folder_path`` is created if it does not exist. The function
    draws on its own figure, so any figure the caller has open is left untouched.

    Args:
        tensors (torch.Tensor): Batch of 2-D images, shape (bs, H, W) or (bs, 1, H, W).
        image_id (int): The index of the image in the batch to be saved.
        folder_path (str): Folder to save the image in.
        image_name (str): Name inserted into the output file name.
        cmap (str): Matplotlib colormap applied to the image.
    """
    # Drop a singleton channel dimension: (bs, 1, H, W) -> (bs, H, W).
    if tensors.ndim == 4 and tensors.shape[1] == 1:
        tensors = tensors.squeeze(1)

    # detach() so tensors that require grad can still be converted to NumPy.
    img = tensors[image_id].detach().cpu().numpy()

    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.imshow(img, cmap=cmap)
        ax.axis('off')  # Hide axes
        out_path = os.path.join(folder_path, f"image_{image_name}_{image_id}.png")
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)


#----------------------------------------------------------------------------
def save_tensors_to_pdf(tensor_list, tensor_names, folder_path="output_images", pdf_file_path="output_images.pdf", cmap='viridis', rows_per_page=10):
    """
    Save a list of tensors with their respective names in a PDF file as a grid of images.

    Each page shows up to ``rows_per_page`` samples (rows). Each tensor is one column,
    titled with its name on the first row of the page.

    Args:
        tensor_list (list): Non-empty list of PyTorch tensors. Each tensor should be of shape
            (batch_size, H, W) or (batch_size, 1, H, W); all share the batch size of the first.
        tensor_names (list): Names for each tensor in tensor_list (at least one per tensor).
        folder_path (str): Folder the PDF is written to; created if it does not exist.
        pdf_file_path (str): File name of the PDF, relative to ``folder_path``.
        cmap (str): Colormap to apply to the images.
        rows_per_page (int): Number of rows (samples) to display per page.
    """
    # Normalize each tensor once to a host array of shape (batch, H, W). detach() lets
    # tensors that require grad be converted.
    images = []
    for tensor in tensor_list:
        if tensor.ndim == 4 and tensor.shape[1] == 1:
            tensor = tensor.squeeze(1)
        images.append(tensor.detach().cpu().numpy())

    num_tensors = len(images)
    batch_size = images[0].shape[0]  # Assume all tensors have the same batch size

    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    with PdfPages(os.path.join(folder_path, pdf_file_path)) as pdf:
        for start_idx in range(0, batch_size, rows_per_page):
            end_idx = min(start_idx + rows_per_page, batch_size)
            num_rows = end_idx - start_idx

            # squeeze=False always returns a 2-D array of Axes, so indexing works even
            # for a single row, a single column, or a 1x1 grid.
            fig, axs = plt.subplots(num_rows, num_tensors, figsize=(num_tensors * 4, num_rows * 4), squeeze=False)
            try:
                for row_idx, sample_idx in enumerate(range(start_idx, end_idx)):
                    for col_idx, batch in enumerate(images):
                        ax = axs[row_idx, col_idx]
                        ax.imshow(batch[sample_idx], cmap=cmap)
                        ax.axis('off')
                        # Column titles only on the first row of each page.
                        if row_idx == 0:
                            ax.set_title(tensor_names[col_idx])

                # Lay out after the titles exist so tight_layout can account for them.
                fig.tight_layout(pad=3.0)
                pdf.savefig(fig)
            finally:
                plt.close(fig)


#----------------------------------------------------------------------------
def compute_relative_error(predicted, true):
    """
    Compute the relative L2 error between predicted and true tensors for each sample in a batch.

    For each sample ``i`` the result is ``||predicted[i] - true[i]||_2 / ||true[i]||_2``,
    with the norm taken over all non-batch dimensions. A sample whose ground truth is all
    zeros yields inf (or nan if the prediction also matches exactly).

    Args:
    predicted (torch.Tensor): The predicted tensor of shape (batch_size, ...), rank >= 1.
    true (torch.Tensor): The ground truth tensor of the same shape.

    Returns:
    torch.Tensor: A tensor of shape (batch_size,) containing the relative error for each sample.
    """
    # Ensure the tensors are of the same shape
    assert predicted.shape == true.shape, "Predicted and True tensors must have the same shape."

    # Reduce over every non-batch dimension. For 4-D input this is exactly dims (1, 2, 3).
    dims = tuple(range(1, true.ndim))
    if not dims:
        # 1-D input: each sample is a scalar, so its L2 norm is its absolute value.
        error_norm = (predicted - true).abs()
        true_norm = true.abs()
    else:
        error_norm = torch.norm(predicted - true, p=2, dim=dims)
        true_norm = torch.norm(true, p=2, dim=dims)

    return error_norm / true_norm


#----------------------------------------------------------------------------

def chmod_recursive(path, mode=0o777):
    """Set permission bits ``mode`` on ``path`` and on every file and directory below it.

    The tree is walked bottom-up (``topdown=False``). Each directory is therefore listed,
    and its contents updated, before its own permissions change. A mode without
    read/execute bits cannot cut the traversal short, and each entry is chmod-ed exactly
    once. ``os.chmod`` follows symlinks, so for a symlinked file or directory its target
    is modified. Errors while listing a directory are ignored (``os.walk`` default);
    errors from ``os.chmod`` propagate.
    """
    for root, dirs, files in os.walk(path, topdown=False):
        # os.walk does not descend into symlinked directories (followlinks=False), so they
        # never appear as ``root``; handle them here together with the files.
        linked_dirs = [d for d in dirs if os.path.islink(os.path.join(root, d))]
        for name in files + linked_dirs:
            os.chmod(os.path.join(root, name), mode)
        # Change the directory itself last, after everything inside it is done.
        os.chmod(root, mode)

