"""Utility functions for the experiment."""

from copy import deepcopy
from math import sqrt
from os import path
from typing import Iterable, Tuple

from backpack.utils.convert_parameters import vector_to_parameter_list
from curvlinops import JacobianLinearOperator, TransposedJacobianLinearOperator
from numpy import eye, stack
from scipy.linalg import eigvalsh
from scipy.sparse import identity
from scipy.sparse.linalg import aslinearoperator, eigsh, svds
from torch import Tensor, cat, manual_seed, no_grad, rand, randn
from torch.nn import Module
from torch.random import fork_rng
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose, Lambda, Normalize, ToTensor
from tqdm import tqdm

from datasets import synthetic_1d_regression
from models import ShallowReLU, deep_relu

# Absolute path of the directory that contains this file. The CIFAR data sets
# are downloaded to/read from its ``datasets`` sub-directory.
HEREDIR = path.dirname(path.abspath(__file__))

# make WideResNet implementation from git sub-module importable
# (the sub-module is expected in ``WideResNet-pytorch``, located in the parent
# directory of ``HEREDIR``)
WRN_DIR = path.join(path.dirname(HEREDIR), "WideResNet-pytorch")
import sys

# Extend the module search path only once, even if this file is imported repeatedly.
if WRN_DIR not in sys.path:
    sys.path.append(WRN_DIR)

# NOTE: This import must stay below the ``sys.path`` manipulation above.
from wideresnet import WideResNet

# Forwarded as ``check_deterministic`` to every Jacobian linear operator that is
# constructed in this file. It is currently turned off (``False``).
# WARNING: Only turn off if you know what you are doing.
# NOTE(review): presumably this skips curvlinops' determinism checks of the
# model/data at construction time -- confirm against the curvlinops docs.
CHECK_DETERMINISTIC = False


def model_and_data(
    data_name: str, model_name: str, width: int, init_seed: int
) -> Tuple[Module, Iterable[Tuple[Tensor, Tensor]]]:
    """Set up neural network and data set.

    Supported combinations of ``(data_name, model_name)``:

    - ``("synthetic_1d_regression", "shallow_relu")``
    - ``("synthetic_1d_regression", "deep_relu_<depth>")``
    - ``("cifar100", "wideresnet")``
    - ``("cifar10", "wideresnet")``
    - ``("cifar10_subset_<num_points>", "wideresnet")``

    The network is initialized inside a forked random number generator, so the
    global PyTorch random state is not affected by the seeding.

    Args:
        data_name: Name of the data set.
        model_name: Name of the neural network.
        width: Width of the neural network (hidden width for the ReLU networks,
            widening factor for the WideResNet).
        init_seed: Seed for initialization of the neural network.

    Returns:
        Neural network and data set. The data set is a list with a single
        ``(X, y)`` batch for the synthetic problem, and a data loader over
        tensors held in RAM for the CIFAR problems.

    Raises:
        AssertionError: If the combination of data set and model is not supported
            or a name does not follow the expected pattern.
        ValueError: If the requested CIFAR-10 sub-set size is smaller than one or
            exceeds the data set size.
    """
    # NOTE(review): ``less_deep_relu`` passes this check but is rejected by the
    # name parsing of the ``deep_relu_<depth>`` branch below -- confirm whether
    # the entry is still needed.
    assert (
        (data_name, model_name)
        in {
            ("synthetic_1d_regression", "shallow_relu"),
            ("synthetic_1d_regression", "less_deep_relu"),
            ("cifar100", "wideresnet"),
            ("cifar10", "wideresnet"),
        }
        or ("cifar10_subset" in data_name and model_name == "wideresnet")
        or (data_name == "synthetic_1d_regression" and "deep_relu" in model_name)
    ), f"Unsupported combination: data_name={data_name!r}, model_name={model_name!r}."

    seed_offset = 1469672034
    seed = seed_offset + init_seed

    # channel-wise CIFAR-10 statistics, shared by the full data set and its sub-sets
    # mean from https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151?permalink_comment_id=2627261#gistcomment-2627261
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    # std from https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151?permalink_comment_id=2851662#gistcomment-2851662
    cifar10_std = (0.2470, 0.2435, 0.2616)

    if (data_name, model_name) == ("synthetic_1d_regression", "shallow_relu"):
        X, y = _unit_norm_synthetic_data()
        in_features, out_features = X.shape[1], y.shape[1]
        nu = 1.0
        with fork_rng():
            manual_seed(seed)
            model = ShallowReLU(in_features, width, out_features=out_features, nu=nu)

        return model, [(X, y)]

    if data_name == "synthetic_1d_regression" and "deep_relu" in model_name:
        # the depth is encoded in the model name: ``deep_relu_<depth>``
        prefix1, prefix2, depth = model_name.split("_")
        assert prefix1 == "deep" and prefix2 == "relu"
        depth = int(depth)

        X, y = _unit_norm_synthetic_data()
        in_features, out_features = X.shape[1], y.shape[1]
        with fork_rng():
            manual_seed(seed)
            model = deep_relu(in_features, out_features, width, depth)

        return model, [(X, y)]

    if (data_name, model_name) == ("cifar100", "wideresnet"):
        dataset = _cifar_train_set(
            CIFAR100,
            # from https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151?permalink_comment_id=2627261#gistcomment-2627261
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        )
        # the batch size shrinks as the network gets wider
        batch_size = max(1, 1000 // width)
        # load data set into RAM to avoid multiple SLURM jobs reading from the same file
        dataloader = ram_dataloader(
            dataset, batch_size=batch_size, shuffle=False, drop_last=False
        )
        return _seeded_wideresnet(100, width, seed), dataloader

    if (data_name, model_name) == ("cifar10", "wideresnet"):
        dataset = _cifar_train_set(CIFAR10, mean=cifar10_mean, std=cifar10_std)
        # the batch size shrinks as the network gets wider
        batch_size = max(1, 1000 // width)
        # load data set into RAM to avoid multiple SLURM jobs reading from the same file
        dataloader = ram_dataloader(
            dataset, batch_size=batch_size, shuffle=False, drop_last=False
        )
        return _seeded_wideresnet(10, width, seed), dataloader

    if "cifar10_subset" in data_name and model_name == "wideresnet":
        # the sub-set size is encoded in the data name: ``cifar10_subset_<num_points>``
        prefix1, prefix2, num_points = data_name.split("_")
        assert prefix1 == "cifar10" and prefix2 == "subset"
        num_points = int(num_points)

        # the class index becomes a float target of shape ``[1]``, matching the
        # single output of the network created below
        as_float = Lambda(lambda y: Tensor([float(y)]))
        dataset = _cifar_train_set(
            CIFAR10, mean=cifar10_mean, std=cifar10_std, target_transform=as_float
        )

        # use only a sub-set
        active = _evenly_spaced_indices(len(dataset), num_points)
        dataset = Subset(dataset, active)

        # the batch size shrinks as the network gets wider
        batch_size = max(1, 2000 // width)
        # load data set into RAM to avoid multiple SLURM jobs reading from the same file
        dataloader = ram_dataloader(
            dataset, batch_size=batch_size, shuffle=False, drop_last=False
        )
        return _seeded_wideresnet(1, width, seed), dataloader

    # Only reachable if assertions are disabled (``python -O``): fail loudly
    # instead of silently returning ``None``.
    raise ValueError(
        f"Unsupported combination: data_name={data_name!r}, model_name={model_name!r}."
    )


def _unit_norm_synthetic_data() -> Tuple[Tensor, Tensor]:
    """Create the synthetic 1d regression problem with inputs of unit L2 norm."""
    X, y = synthetic_1d_regression(num_data=16)
    # satisfy theoretical assumptions
    X /= X.norm(dim=1, keepdim=True)
    return X, y


def _cifar_train_set(dataset_cls, mean, std, target_transform=None) -> Dataset:
    """Create the normalized training split of a torchvision CIFAR data set."""
    normalization = Normalize(mean=mean, std=std)
    return dataset_cls(
        root=path.join(HEREDIR, "datasets"),
        train=True,
        download=True,
        transform=Compose([ToTensor(), normalization]),
        target_transform=target_transform,
    )


def _seeded_wideresnet(num_classes: int, width: int, seed: int) -> Module:
    """Create a WideResNet in evaluation mode without touching the global RNG."""
    with fork_rng():
        manual_seed(seed)
        model = WideResNet(
            num_classes=num_classes,
            depth=10,  # same as ``catastrophic_forgetting`` experiment
            widen_factor=width,
            dropRate=0.0,  # same as ``catastrophic_forgetting`` experiment
        )
        # make forward pass deterministic
        return model.eval()


def _evenly_spaced_indices(num_total: int, num_points: int) -> list:
    """Return exactly ``num_points`` distinct, evenly spaced indices below ``num_total``."""
    if not 1 <= num_points <= num_total:
        raise ValueError(
            f"Cannot select {num_points} points from a data set of size {num_total}."
        )
    stride = num_total // num_points
    # ``range(0, num_total, stride)`` may contain more than ``num_points`` entries
    # if ``num_total`` is not a multiple of ``num_points``; stop after exactly
    # ``num_points`` strides (``stride * num_points <= num_total`` always holds).
    return list(range(0, stride * num_points, stride))


def ram_dataloader(
    dataset: Dataset, batch_size: int, shuffle: bool, drop_last: bool
) -> DataLoader:
    """Materialize a data set as tensors in RAM and wrap it in a data loader.

    The data set is read exactly once, in order and without dropping data. The
    mini-batches are then concatenated into one input and one label tensor.

    Args:
        dataset: Dataset which will be loaded into RAM. Its items must be
            ``(input, label)`` pairs.
        batch_size: Batch size used while reading and by the returned loader.
        shuffle: Whether the returned loader uses shuffling.
        drop_last: Whether the returned loader discards the last batch if it is smaller
            than the specified batch size.

    Returns:
        A dataloader whose dataset lives in RAM.
    """
    # one sequential pass over everything, independent of ``shuffle``/``drop_last``
    sequential_pass = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, drop_last=False
    )
    # transpose ``[(X_1, y_1), (X_2, y_2), ...]`` to ``(X_1, X_2, ...), (y_1, y_2, ...)``
    input_chunks, label_chunks = zip(*sequential_pass)
    in_memory = TensorDataset(cat(input_chunks), cat(label_chunks))

    return DataLoader(
        in_memory, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )


def rand_ball(radius: float, dim: int) -> Tensor:
    """Sample a random vector uniformly from an L2-ball.

    Let ``S`` be a random vector on the ``dim``-dimensional sphere. Multiplying ``S`` by
    ``U ** (1 / dim)``, where ``U`` has the uniform distribution on the unit interval
    ``(0,1)``, creates the uniform distribution in the unit ``dim``-dimensional ball.

    Useful references:
    - https://blogs.sas.com/content/iml/2016/04/06/generate-points-uniformly-in-ball.html
    - https://www.sciencedirect.com/science/article/pii/S0047259X10001211

    Args:
        radius: radius of the ball.
        dim: dimension of the ball.

    Returns:
        A random vector uniformly sampled from the ``dim``-dimensional ball with
        radius ``radius``. Its L2 norm never exceeds ``radius``.
    """
    direction = rand_sphere(1, dim)
    # A single scalar ``U`` rescales the whole direction. Drawing one value per
    # coordinate, or dividing by ``U ** (1 / dim)`` instead of multiplying, would
    # distort the direction and push the sample outside of the ball.
    length = rand(1).pow(1 / dim)
    return radius * length * direction


def rand_sphere(radius: float, dim: int) -> Tensor:
    """Draw a point uniformly at random from the surface of an L2-sphere.

    A standard normal vector has a rotation-invariant distribution. Normalizing it
    to unit length, ``S = Y / ||Y||``, therefore yields the uniform distribution on
    the unit ``dim``-sphere, which is then scaled to the requested radius.

    Useful references:
    - https://blogs.sas.com/content/iml/2016/04/06/generate-points-uniformly-in-ball.html
    - https://www.sciencedirect.com/science/article/pii/S0047259X10001211

    Args:
        radius: radius of the sphere.
        dim: dimension of the sphere.

    Returns:
        A random vector with ``dim`` entries and L2 norm ``radius``, uniformly
        sampled from the sphere.
    """
    gaussian = randn(dim)
    unit_direction = gaussian / gaussian.norm()
    return radius * unit_direction


def gram_matrix_smallest_eigval(
    model: Module,
    data: Iterable[Tuple[Tensor, Tensor]],
    epsilon: float = 0.0,
    tol: float = 0.0,
) -> float:
    """Compute the smallest eigenvalue of the (optionally damped) Gram matrix.

    With ``J`` the Jacobian of the network w.r.t. its trainable parameters, the
    Gram matrix is formed as the product of the Jacobian operator with its
    transpose, ``J J^T``. If ``epsilon`` is non-zero, ``epsilon * I`` is added
    before the eigenvalue is computed.

    Args:
        model: Neural network.
        data: Data set.
        epsilon: Small constant added to the Gram matrix diagonal to improve
            convergence of the eigensolver. Default: ``0``.
        tol: Relative tolerance for the iterative eigensolver (not used when the
            matrix is decomposed explicitly). Default: ``0``.

    Returns:
        Minimal eigenvalue of the (damped) Gram matrix. With non-zero damping, the
        result is bounded from below by ``epsilon``.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    operator_kwargs = {
        "progressbar": True,
        "check_deterministic": CHECK_DETERMINISTIC,
    }
    jacobian = JacobianLinearOperator(model, trainable, data, **operator_kwargs)
    jacobian_t = TransposedJacobianLinearOperator(
        model, trainable, data, **operator_kwargs
    )
    gram = jacobian @ jacobian_t
    gram_dim = gram.shape[0]

    damped = epsilon != 0.0
    if damped:
        gram = gram + aslinearoperator(epsilon * identity(gram_dim, dtype=gram.dtype))

    if gram_dim <= 10_000:
        # Small enough: materialize the matrix via products with the standard
        # basis vectors and use a dense eigenvalue solver.
        print("Computing and eigen-decomposing Gram matrix explicitly.")
        basis = tqdm(eye(gram_dim), desc="Computing explicit Gram matrix")
        gram_dense = stack([gram @ basis_vector for basis_vector in basis])
        # ``eigvalsh`` returns the eigenvalues in ascending order
        smallest = eigvalsh(gram_dense)[0]
    else:
        # NOTE ARPACK is slow at computing smallest eigenvalues
        # (see https://docs.scipy.org/doc/scipy/tutorial/arpack.html#examples).
        #
        # Known tricks to speed up convergence:
        #
        # 1) Increase `tol`.
        #
        # 2) Shift-invert mode (`sigma != None`): This needs a matrix decomposition,
        #    which may be slow with the default solver. A faster-converging solver
        #    can be supplied through ``OpInv``
        #    (see https://gist.github.com/denis-bz/dd17d36a5365378e1c2ee79ecdc419b4#notes-on-some-solvers).
        #
        # 3) Compute ``λ_max(A)``, then ``λ_min(A) = λ_max(A) - λ_max(A - λ_max(A) * I)``
        #    (proposed in https://stackoverflow.com/a/60042695). The result is wrong
        #    if ``λ_max(A) - λ_min(A)`` cannot be represented accurately in the
        #    floating point precision in use.
        #
        # Currently, a larger `tol` can be passed in, and shift-invert mode is used
        # whenever the damping is non-zero.
        print("Using ARPACK (eigsh).")

        if damped:
            print("Using shift-invert mode")
            mode_kwargs = {"which": "LM", "sigma": 0.0}
        else:
            mode_kwargs = {"which": "SM"}

        smallest = eigsh(
            gram, k=1, return_eigenvectors=False, tol=tol, **mode_kwargs
        )[0]

    # the damped matrix cannot have eigenvalues below ``epsilon``; clip solver noise
    return max(epsilon, smallest) if damped else smallest


def close_to_linear_condition(
    model: Module,
    data: Iterable[Tuple[Tensor, Tensor]],
    gram_min_eigval: float,
    perturbation_seed: int,
) -> float:
    """Evaluate the condition whether a model is close to linear.

    The trainable parameters are shifted by a random vector whose length is
    ``3 * ||f(X) - y|| / sqrt(|λ_min|)``. The Jacobian of the shifted model is then
    compared to the Jacobian of the original model in spectral norm.

    Args:
        model: Neural network.
        data: Data set
        gram_min_eigval: Minimum eigenvalue of the model Jacobian's Gram matrix. Its
            absolute value is used, so that a slightly negative estimate (numerical
            noise) does not cause a failure.
        perturbation_seed: Random seed used to perturb the model's parameters.

    Returns:
        The number ``3 * ||J_shifted - J||_2 / sqrt(|λ_min|)``. If it lies between 0
        and 1/2, this indicates that the neural network is sufficiently close to
        linear such that NGD converges fast (strictly speaking we would have to
        evaluate this number for 'all' possible parameter perturbations), but a
        finite number of samples should give a reasonable proxy to this property.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    jac_init = JacobianLinearOperator(
        model, params, data, progressbar=True, check_deterministic=CHECK_DETERMINISTIC
    )

    # Use the same normalization for the radius and for the returned value.
    # Taking the absolute value only in the return statement would still crash
    # on a negative eigenvalue when the radius is computed.
    sqrt_min_eigval = sqrt(abs(gram_min_eigval))

    dev = params[0].device
    # compute radius of ball
    with no_grad():
        predictions, labels = zip(
            *[(model(X.to(dev)).flatten(), y.to(dev).flatten()) for X, y in data]
        )
        predictions = cat(predictions)
        labels = cat(labels)
        assert predictions.shape == labels.shape
        radius = 3 * (predictions - labels).norm().item() / sqrt_min_eigval

    # sample a parameter shift and generate model with shifted parameters; the
    # forked generator keeps the global random state untouched
    with fork_rng():
        offset = 240236572
        manual_seed(offset + perturbation_seed)
        num_params = sum(p.numel() for p in params)
        # the shift is drawn from the boundary of the ball (``rand_sphere``),
        # not from its interior (``rand_ball``)
        shift_sampled = rand_sphere(radius, num_params)
        shift_sampled = vector_to_parameter_list(shift_sampled, params)

    # the original model must stay untouched, hence shift a copy
    model_shifted = deepcopy(model)
    params_shifted = [p for p in model_shifted.parameters() if p.requires_grad]
    for shift, param in zip(shift_sampled, params_shifted):
        param.data += shift.to(param.device)
    jac_shifted = JacobianLinearOperator(
        model_shifted,
        params_shifted,
        data,
        progressbar=True,
        check_deterministic=CHECK_DETERMINISTIC,
    )

    jac_diff = jac_shifted - jac_init

    if any(dim <= 10_000 for dim in jac_diff.shape):
        # compute spectral norm as ``sqrt(λ_max(A^T A)) = sqrt(λ_max(A A^T))`` (see
        # https://math.stackexchange.com/a/586835) with ``A`` the Jacobian difference
        print("Computing matrix and spectral norm explicitly.")

        # pick the product that yields the smaller square matrix
        if jac_diff.shape[0] < jac_diff.shape[1]:
            jac_diff_outer = jac_diff @ jac_diff.T
        else:
            jac_diff_outer = jac_diff.T @ jac_diff

        # expand into explicit matrix
        jac_diff_outer_mat = stack(
            [
                jac_diff_outer @ v
                for v in tqdm(
                    eye(jac_diff_outer.shape[0]), desc="Computing explicit matrix"
                )
            ]
        )
        # ``eigvalsh`` returns the eigenvalues in ascending order
        max_eigval = eigvalsh(jac_diff_outer_mat)[-1]
        # The product is positive semi-definite; round-off can still produce a
        # tiny negative value when the Jacobians (almost) coincide.
        spectral_norm = sqrt(max(max_eigval, 0.0))

    else:  # use sparse SVD
        print("Using sparse SVD (svds).")
        # compute spectral norm (top singular value, see
        # https://math.stackexchange.com/a/586835) of Jacobian difference
        spectral_norm = svds(jac_diff, k=1, which="LM", return_singular_vectors=False)[
            0
        ]

    return 3 * spectral_norm / sqrt_min_eigval
