from functools import partial
from typing import Literal

import flax.linen as nn
import jax
import tree_math as tm
from jax import numpy as jnp
from jax.numpy import exp, ones, stack, sum
from jax.random import PRNGKey
from lpips_j.lpips import VGGExtractor
from optax import l2_loss


def sse_loss(preds, y):
    """Return the sum of squared errors between ``preds`` and ``y``.

    The difference is squared element-wise and summed over every axis,
    yielding a scalar.
    """
    diff = preds - y
    return (diff * diff).sum()


def mse_recon_loss(model_fn, params, batch, rng):
    """Squared-error reconstruction loss for an autoencoder-style model.

    Args:
        model_fn: callable invoked as ``model_fn(params, imgs, rng)``; its
            output is compared element-wise with ``imgs``.
        params: forwarded unchanged to ``model_fn``.
        batch: a 2-tuple ``(imgs, labels)``; the labels are ignored.
        rng: forwarded unchanged to ``model_fn``.

    Returns:
        Scalar: the squared error averaged over axis 0 (the batch) and then
        summed over all remaining axes (the pixels).
    """
    imgs, _labels = batch
    squared_err = (model_fn(params, imgs, rng) - imgs) ** 2
    per_pixel_batch_mean = squared_err.mean(axis=0)
    return per_pixel_batch_mean.sum()


def elbo(model_fn, params, batch, rng):
    """Negative ELBO (the quantity to minimise) for a Gaussian-latent VAE.

    The original version was decorated with ``@jax.vmap``. That mapped over
    every argument, including the ``model_fn`` callable, so every call
    failed; the decorator has been removed.

    Args:
        model_fn: callable invoked as ``model_fn(params, imgs, rng)`` that
            returns ``(recon_imgs, mean, logvar)``.
        params: forwarded unchanged to ``model_fn``.
        batch: a 2-tuple ``(imgs, labels)``; the labels are ignored.
        rng: forwarded unchanged to ``model_fn``.

    Returns:
        Scalar: reconstruction loss + KL divergence. Both terms are averaged
        over axis 0 and summed over the remaining axes. This assumes axis 0
        is the batch for ``imgs``, ``mean`` and ``logvar`` — confirm
        against the model.
    """
    imgs, _ = batch
    recon_imgs, mean, logvar = model_fn(params, imgs, rng)

    # Squared-error reconstruction: mean over the batch, sum over pixels.
    recon_loss = jnp.square(recon_imgs - imgs).mean(axis=0).sum()

    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)). It is reduced exactly
    # like the reconstruction term so both are on the same per-example scale;
    # previously it was summed over the batch as well.
    kld_terms = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
    kld = -0.5 * kld_terms.mean(axis=0).sum()

    return recon_loss + kld


def cross_entropy_loss(preds, y, rho=1.0):
    """Categorical cross-entropy summed over all samples.

    Args:
        preds: logits of shape ``(n_samples, n_classes)``.
        y: one-hot labels of shape ``(n_samples, n_classes)``.
        rho: factor multiplied into the logits before the softmax.

    Returns:
        Scalar: the negative log-probability of the target class, summed
        over every sample.
    """
    log_probs = jax.nn.log_softmax(rho * preds, axis=-1)
    per_sample_loglik = jnp.sum(y * log_probs, axis=-1)
    return -jnp.sum(per_sample_loglik)


def cross_entropy_loss_per_datapoint(preds, y):
    """Categorical cross-entropy for each sample, without reducing over samples.

    Args:
        preds: logits of shape ``(n_samples, n_classes)``.
        y: one-hot labels of shape ``(n_samples, n_classes)``.

    Returns:
        Array of shape ``(n_samples,)``: the negative log-probability of each
        sample's target class.
    """
    log_probs = jax.nn.log_softmax(preds, axis=-1)
    return -(log_probs * y).sum(axis=-1)


def accuracy(params, model, batch_x, batch_y):
    """Count how many examples in the batch the model classifies correctly.

    Args:
        params: parameters passed to ``model.apply``.
        model: object exposing ``apply(params, x)`` that returns logits.
        batch_x: model inputs.
        batch_y: one-hot targets; the last axis indexes classes.

    Returns:
        Number of examples whose arg-max prediction equals the arg-max
        target (a count, not a fraction).
    """
    logits = model.apply(params, batch_x)
    predicted = jnp.argmax(logits, axis=-1)
    target = jnp.argmax(batch_y, axis=-1)
    return (predicted == target).sum()


def accuracy_preds(preds, batch_y):
    """Count correct predictions given precomputed scores/logits.

    Returns the number of rows where the arg-max of ``preds`` equals the
    arg-max of the one-hot ``batch_y`` (a count, not a fraction).
    """
    hits = jnp.argmax(preds, axis=-1) == jnp.argmax(batch_y, axis=-1)
    return hits.sum()


def nll(preds, y):
    """Negative log-likelihood of one-hot targets under softmax(preds).

    The class log-likelihoods are summed over the last axis, then over the
    second-to-last axis, and the result is averaged.

    - For 2-D inputs ``(n, n_classes)`` the final mean acts on a scalar, so
      this is the NLL summed over the ``n`` rows.
    - For 3-D inputs, e.g. ``(n_draws, n, n_classes)``, it is the per-draw
      summed NLL averaged over the leading axis. Presumably that axis holds
      posterior samples — confirm with callers.
    """
    log_probs = jax.nn.log_softmax(preds, axis=-1)
    per_point = jnp.sum(log_probs * y, axis=-1)
    summed = -per_point.sum(axis=-1)
    return jnp.mean(summed)


# ``sse_loss`` (sum of squared errors) is defined near the top of this module.


def gaussian_log_lik_loss(preds, y, rho=1.0):
    """Negative log-likelihood of ``y`` under an isotropic Gaussian N(preds, 1/rho).

    Every element of the residual ``preds - y`` is treated as an independent
    observation with precision ``rho``.

    Args:
        preds: predicted means, broadcast-compatible with ``y``.
        y: regression targets, e.g. ``(batch, n_outputs)``.
        rho: observation precision (inverse variance); must be positive.

    Returns:
        Scalar NLL summed over all observations.
    """
    residual = preds - y
    # The normalising constant is counted once per observation. The old code
    # used ``y.shape[-1]``, which under-counted it by the batch size while the
    # squared-error term still summed over the whole batch.
    n_obs = residual.size
    sse = jnp.sum(residual**2)
    return (
        0.5 * n_obs * jnp.log(2 * jnp.pi)
        - 0.5 * n_obs * jnp.log(rho)
        + 0.5 * rho * sse
    )


@partial(
    jax.jit,
    static_argnames=["alpha", "rho", "model", "D", "N", "likelihood", "extra_stats"],
)
def log_posterior_loss(
    params,
    alpha,
    rho,
    model,
    x_batch,
    y_batch,
    D: int,
    N: int,
    likelihood: Literal["classification", "regression"] = "classification",
    extra_stats: bool = False,
):
    """Minibatch MAP objective with an isotropic Gaussian prior on the parameters.

    Args:
        params: model parameters (a pytree) passed to ``model.apply``.
        alpha: prior precision; the prior is N(0, alpha^-1 I) over ``D`` params.
        rho: forwarded to the likelihood. It is the observation precision for
            ``"regression"`` and a logit multiplier for ``"classification"``.
        model: object exposing ``apply(params, x)``; static under ``jit``.
        x_batch: inputs; the leading axis is taken as the batch size ``B``.
        y_batch: targets (one-hot for classification).
        D: number of parameters, used in the prior's normalising constant.
        N: static argument that is not read anywhere in this function.
        likelihood: ``"classification"`` (cross-entropy) or ``"regression"``
            (Gaussian).
        extra_stats: if True, also report the batch sum of squared errors.

    Returns:
        ``(loss, stats)``:
        - ``loss = NLL / B - log_prior``.
        - ``stats`` has keys ``"log_likelihood"``, ``"log_prior"`` and
          ``"log_posterior"`` (all unscaled, computed on this batch), plus
          ``"sum_squared_error"`` when ``extra_stats`` is set.

    Raises:
        ValueError: if ``likelihood`` is not one of the supported names
            (raised at trace time, since ``likelihood`` is static).
    """
    if likelihood == "regression":
        negative_log_likelihood = gaussian_log_lik_loss
    elif likelihood == "classification":
        negative_log_likelihood = cross_entropy_loss
    else:
        raise ValueError(
            f"Likelihood {likelihood} not supported. Use either 'regression' or 'classification'."
        )

    batch_size = x_batch.shape[0]
    # tree_math Vector lets the whole parameter pytree be used in a dot
    # product; ``vparams @ vparams`` is the squared L2 norm of all params.
    vparams = tm.Vector(params)

    y_pred = model.apply(params, x_batch)
    neg_loglikelihood = negative_log_likelihood(y_pred, y_batch, rho)

    # log N(params | 0, alpha^-1 I) in D dimensions.
    logprior = (
        -D / 2 * jnp.log(2 * jnp.pi)
        - (1 / 2) * alpha * (vparams @ vparams)
        + D / 2 * jnp.log(alpha)
    )
    logposterior = -neg_loglikelihood + logprior

    # NOTE(review): the likelihood is averaged per example (1/B), but the
    # prior is added unscaled and the dataset size N is never used. A
    # consistently scaled objective would be NLL/B - logprior/N. Confirm
    # whether the current weighting is intentional before changing it; it
    # alters the optimum.
    scaled_neg_logposterior = (1 / batch_size) * neg_loglikelihood - logprior

    loss_dict = {
        "log_likelihood": -neg_loglikelihood,
        "log_prior": logprior,
        "log_posterior": logposterior,
    }
    if extra_stats:
        loss_dict["sum_squared_error"] = sse_loss(y_pred, y_batch)
    return scaled_neg_logposterior, loss_dict


class LPIPSFIX(nn.Module):
    """Perceptual distance computed from VGG feature maps.

    For each selected layer, half the mean squared difference between the
    two images' feature maps is taken. These per-layer values are then
    summed. No learned weighting or feature normalisation is applied.

    Attributes:
        conv_names: keys of the ``VGGExtractor`` output to compare.
            - The default keeps only the early conv layers; earlier layers
              matter more for perceptual similarity, and adding later layers
              gave worse results.
            - The previous hard-coded list contained ``"conv3_3"`` twice,
              which double-counted that layer; each layer now appears once.
    """

    conv_names: tuple = (
        "conv1_1",
        "conv1_2",
        "conv2_1",
        "conv2_2",
        "conv3_1",
        "conv3_2",
        "conv3_3",
    )

    def setup(self):
        self.vgg = VGGExtractor()

    def __call__(self, x, t):
        """Return the perceptual distance between image batches ``x`` and ``t``.

        Axes 1-3 of each feature map are averaged away, so this assumes
        4-D feature maps with the batch on axis 0 (confirm against
        ``VGGExtractor``). The result then holds one distance per example.
        """
        x_feats = self.vgg(x)
        t_feats = self.vgg(t)

        per_layer = [
            0.5 * jnp.square(x_feats[name] - t_feats[name]).mean(axis=(1, 2, 3))
            for name in self.conv_names
        ]
        return stack(per_layer, axis=1).sum(axis=1)


class PRCLoss:
    """Per-example loss: pixel reconstruction + perceptual term + beta * KL.

    Args:
        beta: weight applied to the KL term.
        image_shape: shape of the dummy input used once to initialise the
            perceptual network's parameters. It defaults to a single
            unbatched 28x28x1 image; confirm ``VGGExtractor`` accepts that
            layout.
        elbo: selects the mode.
            - When False, ``outputs`` is ``(x_hat, z_mu, z_logvar)`` and all
              three terms are used.
            - When True, ``outputs`` is just ``x_hat`` and only the pixel
              reconstruction term is returned.
            NOTE(review): the flag name reads inverted (the ``elbo=True``
            path has no KL term); confirm intent with callers before renaming.
    """

    def __init__(self, beta=1.0, image_shape=(28, 28, 1), elbo=False):
        # A tuple default avoids a shared mutable list across instances.
        self.beta = beta
        self.lpips_obj = LPIPSFIX()
        example = ones(tuple(image_shape))
        # Fixed key: the perceptual network's params are created once here
        # and reused for every call.
        self.lpips_params = self.lpips_obj.init(PRNGKey(0), example, example)
        self.elbo = elbo

    def __call__(self, outputs, batch):
        """Return the per-example loss for reconstructions of ``batch``.

        ``rec_loss`` uses optax ``l2_loss`` (half squared error), summed over
        the last three axes, presumably (H, W, C). The KL term is the
        closed-form KL(N(z_mu, exp(z_logvar)) || N(0, I)), summed over the
        last axis.
        """
        if self.elbo:
            x_hat = outputs
            kl_loss = 0.0
            prc_loss = 0.0
        else:
            x_hat, z_mu, z_logvar = outputs
            kl_loss = -0.5 * sum(1.0 + z_logvar - z_mu**2 - exp(z_logvar), axis=-1)
            prc_loss = jnp.squeeze(
                self.lpips_obj.apply(self.lpips_params, batch, x_hat)
            )

        rec_loss = l2_loss(x_hat, batch).sum(axis=(-1, -2, -3))
        return rec_loss + prc_loss + self.beta * kl_loss
