from spaghettini import quick_register

import torch

from torch.nn.functional import cross_entropy


@quick_register
def sequence_cross_entropy(preds, ys):
    """Cross entropy averaged over every position of a batch of sequences.

    Dimensions 1 and 2 of ``preds`` are swapped so that the original
    dimension 1 becomes the last one; that last dimension is then used as
    the class axis and everything else is flattened into rows.

    Args:
        preds: Logits tensor with at least three dimensions, with the class
            scores along dimension 1 (presumably ``(batch, num_classes,
            seq_len)`` -- confirm against the caller).
        ys: Integer class targets with one entry per flattened row of
            ``preds`` (presumably ``(batch, seq_len)``).

    Returns:
        Scalar tensor holding the mean cross entropy over all positions.
    """
    preds = preds.transpose(1, 2)
    flattened_preds = preds.reshape(-1, preds.shape[-1])
    # ``reshape`` (unlike ``view``) also accepts non-contiguous targets, e.g.
    # ones obtained by transposing or slicing, copying only when needed.
    flattened_ys = ys.reshape(-1)

    return cross_entropy(input=flattened_preds, target=flattened_ys)


@quick_register
def l1_fp_regularization(preds, ys, l1_coeff=0., **kwargs):
    """L1 penalty on the fixed points logged by the model.

    The absolute values of ``kwargs["model_logs"]["result"]`` are summed over
    every dimension except the first, averaged over the first dimension and
    scaled by ``l1_coeff``.

    Args:
        preds: Unused.
        ys: Unused.
        l1_coeff: Multiplier applied to the averaged L1 norm.
        **kwargs: Must contain ``model_logs`` with a ``"result"`` tensor.

    Returns:
        Scalar tensor with the weighted L1 term.
    """
    fps = kwargs["model_logs"]["result"]

    # Every axis but the leading one is reduced by the sum.
    non_leading_dims = list(range(1, fps.dim()))
    per_example_l1 = fps.abs().sum(dim=non_leading_dims)

    return l1_coeff * per_example_l1.mean()


@quick_register
def cross_example_cosine_sim_regularization(preds, ys, **kwargs):
    """Mean pairwise cosine similarity between the fixed points of a batch.

    Each example's fixed point is flattened to a vector and L2-normalized.
    The identity matrix is subtracted from the resulting ``(bs, bs)`` cosine
    similarity matrix to cancel each example's similarity with itself, and
    the mean is taken over all ``bs * bs`` entries.

    Args:
        preds: Unused.
        ys: Targets; only ``ys.shape[0]`` (the batch size) is read.
        **kwargs: Must contain ``model_logs`` (with a ``"result"`` tensor whose
            leading dimension is the batch) and ``coeff``.

    Returns:
        Scalar tensor: ``coeff`` times the mean similarity.
    """
    bs = ys.shape[0]

    # Get the fixed points predicted by the model.
    fps = kwargs["model_logs"]["result"]

    # Flatten with ``reshape`` so non-contiguous fixed points (e.g. transposed
    # activations) are accepted as well; ``view`` would raise on those.
    flat_fps = fps.reshape(bs, -1)

    # L2-normalize each row; the epsilon keeps all-zero rows finite.
    norm_fps = flat_fps / torch.sqrt(torch.sum(flat_fps ** 2, dim=-1, keepdim=True) + 1e-16)

    # Pairwise cosine similarities, minus the self-similarity on the diagonal.
    cosine_sims = torch.sum(norm_fps[:, None, ...] * norm_fps[None, ...], dim=-1)
    identity = torch.eye(cosine_sims.shape[0], dtype=cosine_sims.dtype, device=cosine_sims.device)
    mean_cosine_sims = (cosine_sims - identity).mean()

    return kwargs["coeff"] * mean_cosine_sims


@quick_register
def alignment_based_path_independence_penalty(preds, ys, use_cosine=False, **kwargs):
    """Similarity between fixed points reached from different random z0 inits.

    Every input is repeated ``input_repeats`` times, the system is run once
    with ``system.model.z0_init_method`` temporarily set to ``"normal"``, and
    the pairwise similarities (dot products, or cosine similarities when
    ``use_cosine`` is true) between the resulting fixed points are averaged.
    A fixed point's similarity with itself never contributes.

    Args:
        preds: Unused.
        ys: Targets; only ``ys.shape[0]`` (the batch size) is read.
        use_cosine: L2-normalize the flattened fixed points before taking dot
            products.
        **kwargs: Must contain
            ``system`` -- callable returning ``(outputs, model_logs)`` and
            exposing ``model.z0_init_method``;
            ``xs`` -- the input batch;
            ``input_repeats`` -- integer number of copies made of each input;
            ``mask_cross_example_similarities`` -- when truthy, only pairs
            that stem from the same input are averaged;
            ``coeff`` -- multiplier applied to the penalty.

    Returns:
        Scalar tensor: ``coeff`` times the averaged similarity.
    """
    bs = ys.shape[0]

    # Read every required kwarg up front so that a missing key raises before
    # the system is touched or a forward pass is paid for.
    system = kwargs["system"]
    xs = kwargs["xs"]
    repeats = kwargs["input_repeats"]
    mask_cross_example = kwargs["mask_cross_example_similarities"]
    coeff = kwargs["coeff"]

    # Interleave the inputs so copies of the same example are adjacent.
    expanded_xs = torch.repeat_interleave(input=xs, repeats=repeats, dim=0)

    # Temporarily switch to initializing z0s using the normal samples; the
    # ``finally`` guarantees the previous setting is reinstated even if the
    # forward pass raises.
    original_z0_init_method = system.model.z0_init_method
    system.model.z0_init_method = "normal"
    try:
        _, expanded_model_logs = system(expanded_xs)
    finally:
        system.model.z0_init_method = original_z0_init_method

    # One flattened fixed point per (example, repeat) row.
    fps = expanded_model_logs["result"].reshape(bs * repeats, -1)
    if use_cosine:
        # The epsilon keeps all-zero rows finite.
        fps = fps / torch.sqrt(torch.sum(fps ** 2, dim=-1, keepdim=True) + 1e-16)

    # Gram matrix of pairwise dot products; a matmul avoids materializing the
    # (N, N, D) tensor a broadcasted product would need.
    similarities = fps @ fps.t()

    # Exclude self-similarities explicitly. Subtracting the identity would
    # only cancel them for unit-norm rows, not for raw dot products.
    num_rows = similarities.shape[0]
    row_ids = torch.arange(num_rows, device=similarities.device)
    pair_mask = row_ids[:, None] != row_ids[None, :]

    if mask_cross_example:
        # Keep only pairs whose rows were produced from the same example.
        example_ids = torch.arange(bs, device=similarities.device).repeat_interleave(repeats)
        pair_mask = pair_mask & (example_ids[:, None] == example_ids[None, :])
        # Number of ordered same-example pairs, diagonal excluded.
        normalizer = bs * repeats * (repeats - 1)
    else:
        # Average over every entry of the (bs*repeats, bs*repeats) matrix; the
        # masked diagonal contributes zero.
        normalizer = (bs * repeats) ** 2

    # With no valid pair (e.g. a single repeat) the sum is zero, so clamping
    # the normalizer yields a zero penalty instead of a 0/0 NaN.
    penalty = torch.sum(similarities * pair_mask.type_as(similarities)) / max(normalizer, 1)

    return coeff * penalty
