import warnings
import logging

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet
import scipy

from models.base_model import SequentialModel
from compute.grad_wrt_to_inputs import compute_jac_norm
from compute.matrix_2norm import compute_matrix_2norm_power_method
from compute.conv_to_matrix import get_conv2d_matrix
from utils.functions import check_bottleneck_structure, check_basic_block_structure


def compute_lipschitz_bounds(
    model: SequentialModel,
    layers_to_look_at: list[int],
    train_dataloader: DataLoader,
    device: torch.device,
    verbose: bool = False,
) -> dict[str, tuple[float, float, float, float]]:
    """Computes the Lipschitz constant bounds for the Sequential model at specified layers.

    For each layer index ``k`` in `layers_to_look_at`, the sub-network
    ``x -> model.forward_up_to_k_layer(x, k)`` is characterised in two ways:

    * empirically, from the Jacobian norms returned by ``compute_jac_norm`` for every
      sample of `train_dataloader`: their maximum (a lower bound on the Lipschitz
      constant), their mean and their root-mean-square;
    * analytically, as the product of the per-layer Lipschitz upper bounds of the
      first ``k`` entries of ``model.layers``.

    Parameters
    ----------
    model
        Model object.
    layers_to_look_at
        List of layers to compute the Lipschitz constant for. Duplicate entries are
        processed only once.
    train_dataloader
        Train dataloader. Each item is unpacked as ``(inputs, targets)``; only the
        inputs are used.
    device
        Torch device to use for the computation.
    verbose, optional
        Log info on Lipschitz computation to the INFO log, by default False.

    Returns
    -------
        A dictionary, where each key is the string of the layer index and the value is a tuple of 4 values:
        lower bound, mean norm, rms norm, upper bound for the model at this layer.

    Raises
    ------
    ValueError
        If layers were requested but `train_dataloader` yielded no samples (the mean
        and rms norms would otherwise be 0/0 = NaN).
    """
    # De-duplicate while keeping order. A repeated layer would otherwise add its norms
    # to the running sums twice, while ``n_samples`` counts every sample only once.
    layers = list(dict.fromkeys(layers_to_look_at))

    # running supremum, sum and sum of squares of the per-sample Jacobian norms
    sup_norms = {str(layer): torch.tensor(0.0, device=device) for layer in layers}
    sum_norms = {str(layer): torch.tensor(0.0, device=device) for layer in layers}
    sum_sq_norms = {str(layer): torch.tensor(0.0, device=device) for layer in layers}
    n_samples = 0

    if verbose:
        logging.info("Computing the lower bound...")
    for i, (x_batch, _) in enumerate(train_dataloader):
        if verbose:
            logging.info("Processing batch %d/%d...", i + 1, len(train_dataloader))

        # only the inputs are needed for the Jacobian norms; targets are not moved to `device`
        x_batch = x_batch.to(device)
        n_samples += x_batch.shape[0]

        for layer in layers:
            key = str(layer)
            norms = compute_jac_norm(
                # the default argument binds the current `layer` explicitly
                lambda inp, k=layer: model.forward_up_to_k_layer(inp, k),
                x_batch,
            )
            # Detach so the accumulators never keep autograd history alive across
            # batches (in case compute_jac_norm returns graph-attached tensors).
            norms = norms.detach()

            # lower bound is the supremum over norms
            batch_sup = torch.max(norms)
            if batch_sup > sup_norms[key]:
                sup_norms[key] = batch_sup

            sum_norms[key] += torch.sum(norms)
            sum_sq_norms[key] += torch.sum(norms**2)

    if verbose:
        logging.info("Lower bound computed!")

    if layers and n_samples == 0:
        raise ValueError(
            "train_dataloader yielded no samples; cannot compute the Jacobian norm statistics."
        )

    # Upper bound. NOTE: indexing of `per_layer_lipschitz` works in the following way:
    # index 0 == input (identity, always 1-Lipschitz)
    # index 1 == first layer applied (Lip. w.r.t. the input)
    # index 2 == second layer applied (Lip. w.r.t. the output of layer 1)
    # ...
    if verbose:
        logging.info("Computing the upper bound...")
    per_layer_lipschitz = [1.0]
    for i, module in enumerate(model.layers):
        if verbose:
            logging.info("Processing layer %d/%d...", i + 1, len(model.layers))
        # collapse all nested layers (Sequential / residual blocks) to a single number
        bound = compute_final_upper_bound(
            compute_lipschitz_upper_bound_per_layer(module, model.layer_input_shapes[i])
        )
        # float() brings every bound to the host. Bounds computed on different devices
        # (CPU for convolutions, the model's device for weight-based ones) can then be
        # gathered into one tensor.
        per_layer_lipschitz.append(float(bound))
    per_layer_lipschitz = torch.tensor(per_layer_lipschitz)

    if verbose:
        logging.info("Upper bound computed!")

    # format results
    results = {}
    for layer in layers:
        key = str(layer)
        # upper bound is the product of per-layer Lipschitz constants up to `layer`
        upper_bound = torch.prod(per_layer_lipschitz[: layer + 1])
        results[key] = (
            # sup of norms (lower bound)
            float(sup_norms[key]),
            # mean of norms (mean bound)
            float(sum_norms[key] / n_samples),
            # rms of norms (rms bound)
            float(torch.sqrt(sum_sq_norms[key] / n_samples)),
            # upper bound
            float(upper_bound),
        )

    return results


def compute_lipschitz_upper_bound_per_layer(
    layer: torch.nn.Module,
    layer_input_shape: list[int],
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor | list | dict:
    """Returns the Lipschitz constant for each particular layer of the model.
    If a Sequence is given, outputs a list of Lipschitz constants.
    If a Bottleneck or BasicBlock are given, output a dictionary
    with a sequence of Lipschitz constants for the sequential and residual parts.

    Parameters
    ----------
    layer
        Layer object.
    layer_input_shape
        Shape(s) of the input for this layer. Must be present to compute Conv layers' Lip. constant.
    dtype
        dtype to use for Lipschitz computation, by default torch.float32

    Returns
    -------
        The Lipschitz constant for this particular layer (not considering previous layers.)
    """
    # supress warnings from scipy
    warnings.filterwarnings(action="ignore", module="scipy")

    if isinstance(layer, nn.Sequential):
        # recursive call for nested sequential layers
        return [
            compute_lipschitz_upper_bound_per_layer(layer[i], layer_input_shape[i], dtype)
            for i in range(len(layer))
        ]

    if isinstance(layer, nn.Linear):
        # for linear layers, jacobian is the weight matrix
        w = layer.state_dict()["weight"]
        return compute_matrix_2norm_power_method(w).type(dtype)

    if isinstance(layer, nn.Conv2d):
        # Here, we compute the lipschitz constant of the layer as a
        # largest singular value of the identical linear matrix multiplication.

        conv_kernel = layer.weight.cpu().detach().numpy()
        img_size = layer_input_shape[-1]
        K = get_conv2d_matrix(conv_kernel, layer.padding[0], layer.stride[0], img_size)
        return torch.tensor(scipy.sparse.linalg.norm(K, ord=2), dtype=dtype)

    if isinstance(layer, nn.BatchNorm2d):
        state = layer.state_dict()
        var = state["running_var"]
        gamma = state["weight"]
        eps = layer.eps

        return torch.max(gamma / torch.sqrt(var + eps)).type(dtype)

    if isinstance(layer, resnet.BasicBlock) or isinstance(layer, resnet.Bottleneck):
        # this is a ResNet residual block
        if isinstance(layer, resnet.BasicBlock):
            assert check_basic_block_structure(layer)
        else:
            assert check_bottleneck_structure(layer)

        # in comments starting with # //, we present a forward pass of the BasicBlock acc. to
        # https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html
        # for clarity

        # // identity = x
        lip_residual = torch.tensor(1.0, dtype=dtype)

        # // out = self.conv1(x)
        lip = [compute_lipschitz_upper_bound_per_layer(layer.conv1, layer_input_shape[0], dtype)]
        # // out = self.bn1(out)
        lip += [compute_lipschitz_upper_bound_per_layer(layer.bn1, [], dtype)]
        # // out = self.relu(out)

        # // out = self.conv2(out)
        lip += [compute_lipschitz_upper_bound_per_layer(layer.conv2, layer_input_shape[1], dtype)]
        # // out = self.bn2(out)
        lip += [compute_lipschitz_upper_bound_per_layer(layer.bn2, [], dtype)]

        # in case of "Bottleneck" module
        if isinstance(layer, resnet.Bottleneck):
            # // out = self.relu(out)

            # // out = self.conv3(out)
            lip += [
                compute_lipschitz_upper_bound_per_layer(layer.conv3, layer_input_shape[2], dtype)
            ]
            # // out = self.bn3(out)
            lip += [compute_lipschitz_upper_bound_per_layer(layer.bn3, [], dtype)]

        # // if self.downsample is not None:
        # //     identity = self.downsample(x)
        if layer.downsample is not None:
            # downsample is a sequential layer with a convolution and batchnorm
            lip_residual = compute_lipschitz_upper_bound_per_layer(
                layer.downsample, [layer_input_shape[0], []], dtype
            )

        # // out += identity
        # // out = self.relu(out)
        return {"residual": lip_residual, "sequential": lip}

    # by default, for other types of layers, output 1
    # (be careful to include all non 1-Lipschitz layers in the check before)
    return torch.tensor(1.0, dtype=dtype)


def compute_final_upper_bound(
    lipschitz_per_layer: float | list | dict | torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Computes the upper bound of the whole model, given all Lipschitz constant for each layer.

    Parameters
    ----------
    lipschitz_per_layer
        Output of the `compute_lipschitz_upper_bound_per_layer` function, can be a Tensor, float, list or a dict
    dtype, optional
        dtype to use for Lipschitz computation, by default torch.float32

    Returns
    -------
        A tensor with one number, representing the answer

    Raises
    ------
    TypeError
        This exception is raised when `lipschitz_per_layer` type does not match torch.Tensor, dict or list

    """
    if isinstance(lipschitz_per_layer, torch.Tensor):
        return lipschitz_per_layer.type(dtype)

    if isinstance(lipschitz_per_layer, float):
        return torch.tensor(lipschitz_per_layer, dtype=dtype)

    if isinstance(lipschitz_per_layer, dict):
        residual = compute_final_upper_bound(lipschitz_per_layer["residual"], dtype)
        sequential = compute_final_upper_bound(lipschitz_per_layer["sequential"], dtype)
        return residual + sequential

    if isinstance(lipschitz_per_layer, list):
        return torch.stack(
            [compute_final_upper_bound(i, dtype) for i in lipschitz_per_layer]
        ).prod()

    raise TypeError(
        f"Lipschitz per layer type {type(lipschitz_per_layer)} is unknown for this function."
    )
