# https://github.com/ngushchin/LightSB/src/light_sb.py

import jax
import math
import flax
import optax
import jax.numpy as jnp
import numpy as np
from fn import logsumexp


def init_r_by_samples(params, samples):
    """Return a copy of ``params`` whose ``"r"`` leaves are replaced by ``samples``.

    ``samples`` is converted to a JAX array with the replaced leaf's dtype, so
    the parameter keeps its declared precision instead of silently adopting
    the dtype (or numpy array type) of the samples.  All other leaves are
    returned unchanged.  Leaves whose last path entry is not a dict key
    (e.g. list/tuple elements) are never treated as matches.

    Args:
        params: Parameter pytree, e.g. as returned by ``Lsb.init``.
        samples: Array-like that replaces each ``"r"`` leaf.

    Returns:
        A new pytree with the same structure as ``params``.
    """
    def replace(path, leaf):
        # getattr: SequenceKey / GetAttrKey path entries have no ``.key``.
        if path and getattr(path[-1], "key", None) == "r":
            return jnp.asarray(samples, dtype=leaf.dtype)
        return leaf

    return jax.tree_util.tree_map_with_path(replace, params)

def init_orthogonal_S(params):
    """Return a copy of ``params`` with every ``"S_rotation_matrix"`` leaf orthogonalised.

    Each matching leaf is replaced by the ``Q`` factor of its (batched) QR
    decomposition; for the ``(K, dim, dim)`` rotation parameter of ``Lsb``
    this gives ``Q_k^T Q_k = I`` for every component ``k``.  All other leaves
    are returned unchanged.  Leaves whose last path entry is not a dict key
    (e.g. list/tuple elements) are never treated as matches.

    Args:
        params: Parameter pytree, e.g. as returned by ``Lsb.init``.

    Returns:
        A new pytree with the same structure as ``params``.
    """
    def orthogonalize(path, leaf):
        # getattr: SequenceKey / GetAttrKey path entries have no ``.key``.
        if path and getattr(path[-1], "key", None) == "S_rotation_matrix":
            return jax.lax.linalg.qr(leaf)[0]
        return leaf

    return jax.tree_util.tree_map_with_path(orthogonalize, params)

def set_epsilon(params, new_epsilon):
    """Return a copy of ``params`` with every ``"eps"`` leaf set to ``new_epsilon``.

    The new value always keeps the replaced leaf's shape and dtype, even when
    ``new_epsilon`` is a strongly-typed numpy/JAX scalar of another precision
    (multiplying by ``ones`` would instead promote to the wider dtype).
    All other leaves are returned unchanged.  Leaves whose last path entry is
    not a dict key (e.g. list/tuple elements) are never treated as matches.

    Args:
        params: Parameter pytree, e.g. as returned by ``Lsb.init``.
        new_epsilon: Scalar value for ``eps``.

    Returns:
        A new pytree with the same structure as ``params``.
    """
    def replace(path, leaf):
        # getattr: SequenceKey / GetAttrKey path entries have no ``.key``.
        if path and getattr(path[-1], "key", None) == "eps":
            return jnp.full(jnp.shape(leaf), new_epsilon, dtype=leaf.dtype)
        return leaf

    return jax.tree_util.tree_map_with_path(replace, params)

def orthogonal_S_rotation_matrix() -> optax.GradientTransformation:
    """Optax transformation that keeps ``"S_rotation_matrix"`` parameters orthogonal.

    For every ``"S_rotation_matrix"`` leaf ``p`` with incoming update ``u`` the
    update is rewritten to ``qr(p + u).Q - p``.  Applying it (``p + update``)
    therefore yields, up to rounding, the Q factor of ``p + u``, i.e. an
    orthogonal matrix (a QR retraction).  All other updates pass through
    unchanged.  Place it last in a chain (as ``sb_opt`` does): any
    transformation applied after it would alter the rewritten update.

    Returns:
        A stateless ``optax.GradientTransformation``.  Its ``update`` raises
        ``ValueError`` when called without ``params``.
    """

    def init_fn(_):
        # Stateless: nothing is carried between steps.
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        # The retraction needs the current parameter values; fail clearly
        # instead of with an opaque pytree-structure mismatch.
        if params is None:
            raise ValueError("orthogonal_S_rotation_matrix requires `params` to be passed to `update`.")

        def retract(path, u, p):
            # getattr: SequenceKey / GetAttrKey path entries have no ``.key``.
            if path and getattr(path[-1], "key", None) == "S_rotation_matrix":
                return jax.lax.linalg.qr(u + p)[0] - p
            return u

        return jax.tree_util.tree_map_with_path(retract, updates, params), state

    return optax.GradientTransformation(init_fn, update_fn)

def sb_opt(x: optax.GradientTransformation):
    """Wrap optimiser ``x`` so rotation matrices stay orthogonal after each step.

    The QR retraction from ``orthogonal_S_rotation_matrix`` is chained after
    ``x`` so it sees (and rewrites) ``x``'s final updates.
    """
    retraction = orthogonal_S_rotation_matrix()
    return optax.chain(x, retraction)


class Lsb(flax.linen.Module):
    """LightSB (Light Schrodinger Bridge) model with a Gaussian-mixture potential.

    For starting points ``x_0`` of shape ``(batch, dim)`` the model defines the
    conditional distribution (see :meth:`log_prob`)

        p(x_1 | x_0) = sum_k w_k(x_0) N(x_1 | r_k + S_k x_0, eps S_k),
        w(x_0) = softmax(logits(x_0)),
        logits_k(x_0) = (x_0^T S_k x_0 + 2 r_k^T x_0) / (2 eps) + log_alpha_k,

    with ``log_alpha = log_alpha_raw / eps``.

    Fields:
        dim: Data dimension.
        n_potential: Number of mixture components ``K``.
        epsilon: Initial value of the ``eps`` parameter, the noise scale of the
            process (:meth:`sample_euler_maruyama` adds noise of variance
            ``eps * dt`` per step).
        is_diagonal: If True every ``S_k`` is diagonal and :meth:`get_S` returns
            the ``(K, dim)`` diagonals; otherwise
            ``S_k = R_k diag(exp(S_log_diagonal_matrix[k])) R_k^T`` with
            ``R_k = S_rotation_matrix[k]`` and :meth:`get_S` returns
            ``(K, dim, dim)`` matrices.
        S_diagonal_init: Initial value of every diagonal entry of ``S``.
        dtype: Dtype of all parameters.

    Parameters created in ``setup``: ``eps`` (scalar, wrapped in
    ``stop_gradient``), ``log_alpha_raw`` ``(K,)``, ``r`` ``(K, dim)``,
    ``S_log_diagonal_matrix`` ``(K, dim)`` and, only when ``is_diagonal`` is
    False, ``S_rotation_matrix`` ``(K, dim, dim)``.  The non-diagonal formulas
    (e.g. the log-determinants and inverses in :meth:`get_drift`) assume each
    ``R_k`` is orthogonal; ``init_orthogonal_S`` and ``sb_opt`` keep it so.
    """

    dim: int = 2
    n_potential: int = 5
    epsilon: float = 1.0
    is_diagonal: bool = True
    S_diagonal_init: float = 0.1
    dtype: jax.typing.DTypeLike = jnp.float64

    def setup(self):
        """Create the initialisers (they depend on the fields) and the parameters."""
        self.eps_init = flax.linen.initializers.constant(self.epsilon)
        # log_alpha_raw = eps * log(1 / K): after division by eps the mixture
        # weights start uniform (alpha_k = 1 / K).
        self.log_alpha_raw_init = flax.linen.initializers.constant(-self.epsilon * math.log(self.n_potential))
        self.r_init = flax.linen.initializers.normal(1.0 * float(np.sqrt(self.epsilon)))
        self.S_log_diagonal_matrix_init = flax.linen.initializers.constant(math.log(self.S_diagonal_init))
        self.S_rotation_matrix_init = flax.linen.initializers.normal(1.0)

        self.eps = self.setup_eps()
        self.log_alpha_raw = self.setup_log_alpha_raw()
        self.r = self.setup_r()
        self.S_log_diagonal_matrix = self.setup_S_log_diagonal_matrix()
        if not self.is_diagonal:
            self.S_rotation_matrix = self.setup_S_rotation_matrix()

    def setup_eps(self):
        """Scalar ``eps`` parameter; gradients through it are stopped."""
        return jax.lax.stop_gradient(self.param("eps", self.eps_init, (), self.dtype))

    def setup_log_alpha_raw(self):
        """Unnormalised log mixture weights (scaled by ``eps``), shape ``(K,)``."""
        return self.param("log_alpha_raw", self.log_alpha_raw_init, (self.n_potential,), self.dtype)

    def setup_r(self):
        """Component offsets ``r_k``, shape ``(K, dim)``."""
        return self.param("r", self.r_init, (self.n_potential, self.dim), self.dtype)

    def setup_S_log_diagonal_matrix(self):
        """Log-eigenvalues of each ``S_k``, shape ``(K, dim)``."""
        return self.param("S_log_diagonal_matrix", self.S_log_diagonal_matrix_init, (self.n_potential, self.dim), self.dtype)

    def setup_S_rotation_matrix(self):
        """Eigenbases ``R_k`` of each ``S_k``, shape ``(K, dim, dim)``."""
        return self.param("S_rotation_matrix", self.S_rotation_matrix_init, (self.n_potential, self.dim, self.dim), self.dtype)

    def get_log_alpha(self):
        """Return ``log_alpha = log_alpha_raw / eps``, shape ``(K,)``."""
        return self.log_alpha_raw / self.eps

    def get_S(self):
        """Return the per-component matrices ``S_k``.

        Returns:
            ``(K, dim)`` diagonals when ``is_diagonal``; otherwise the full
            ``(K, dim, dim)`` matrices ``R_k diag(exp(d_k)) R_k^T``.
        """
        S_log_diagonal_matrix = self.S_log_diagonal_matrix
        if self.is_diagonal:
            return jax.lax.exp(S_log_diagonal_matrix)
        S_rotation_matrix = self.S_rotation_matrix
        # Scale the columns of R_k by the eigenvalues, then contract with R_k^T (batched over k).
        return jax.lax.dot_general((S_rotation_matrix * jnp.expand_dims(jax.lax.exp(S_log_diagonal_matrix), 1)), S_rotation_matrix, (((2,), (2,)), ((0,), (0,))))

    def _apply_S(self, x, S):
        """Return ``S_k x_b`` for every row ``x_b`` of ``x`` and every component ``k``.

        Args:
            x: Array of shape ``(batch, dim)``.
            S: Output of :meth:`get_S`.

        Returns:
            Array of shape ``(batch, K, dim)``.
        """
        x_ = jnp.expand_dims(x, 1)
        S_ = jnp.expand_dims(S, 0)
        if self.is_diagonal:
            return S_ * x_
        x_tiled = jnp.tile(x_, (1, self.n_potential, 1))
        S_tiled = jnp.tile(S_, (x.shape[0], 1, 1, 1))
        # Batched matrix-vector product over (batch, K).
        return jax.lax.dot_general(S_tiled, x_tiled, (((3,), (2,)), ((0, 1), (0, 1))))

    def _mixture_logits(self, x, S_x):
        """Return ``(x^T S_k x + 2 r_k^T x) / (2 eps) + log_alpha_k``.

        Args:
            x: Array of shape ``(batch, dim)``.
            S_x: ``self._apply_S(x, S)`` for the same ``x``.

        Returns:
            Array of shape ``(batch, K)``.
        """
        x_ = jnp.expand_dims(x, 1)
        x_S_x = jnp.sum(x_ * S_x, -1)
        x_r = jnp.sum(x_ * jnp.expand_dims(self.r, 0), -1)
        return (x_S_x + 2 * x_r) / (2 * self.eps) + jnp.expand_dims(self.get_log_alpha(), 0)

    def __call__(self, x, key):
        """Draw one sample ``x_1 ~ p(x_1 | x_0 = x)`` for every row of ``x``.

        Args:
            x: Starting points, shape ``(batch, dim)``.
            key: PRNG key; split into one key for the component choice and one
                for the Gaussian noise.

        Returns:
            Array of shape ``(batch, dim)``.
        """
        S = self.get_S()
        eps = self.eps

        S_x = self._apply_S(x, S)
        r_x = jnp.expand_dims(self.r, 0) + S_x  # component means r_k + S_k x
        logits = self._mixture_logits(x, S_x)

        key1, key2 = jax.random.split(key, 2)

        mix = jax.random.categorical(key1, logits)

        S_ = jnp.expand_dims(S, 0)
        if self.is_diagonal:
            comp = r_x + jax.lax.sqrt(eps * S_) * jax.random.normal(key2, shape=r_x.shape)
        else:
            # Correlated noise sqrt(eps) L z with S_k = L L^T.
            L_tiled = jnp.tile(jax.lax.linalg.cholesky(S_), (r_x.shape[0], 1, 1, 1))
            noise = jax.random.normal(key2, shape=r_x.shape)
            comp = r_x + jax.lax.sqrt(eps) * jax.lax.dot_general(L_tiled, noise, (((3,), (2,)), ((0, 1), (0, 1))))

        # For each row keep only the sample of its chosen component.
        return comp[jnp.arange(mix.shape[0]), mix]

    def sample_comp(self, eps, n, key):
        """Draw ``n`` samples from every component ``N(r_k, eps * S_k)``.

        Args:
            eps: Noise scale to use (the module's own ``eps`` is not read).
            n: Number of samples per component.
            key: PRNG key.

        Returns:
            Array of shape ``(n, K, dim)``.
        """
        S = self.get_S()
        r = self.r

        nmix, dim = r.shape
        r_ = jnp.expand_dims(r, 0)
        noise = jax.random.normal(key, shape=(n, nmix, dim))

        if self.is_diagonal:
            return r_ + jax.lax.sqrt(eps * jnp.expand_dims(S, 0)) * noise

        L_tiled = jnp.tile(jax.lax.linalg.cholesky(jnp.expand_dims(S, 0)), (n, 1, 1, 1))
        return r_ + jax.lax.sqrt(eps) * jax.lax.dot_general(L_tiled, noise, (((3,), (2,)), ((0, 1), (0, 1))))

    def sample_comp_each(self, eps, x, n, key):
        """Draw ``n`` samples from every component ``N(r_k + S_k x_b, eps * S_k)`` for each row ``x_b``.

        Args:
            eps: Noise scale to use (the module's own ``eps`` is not read).
            x: Starting points, shape ``(batch, dim)``.
            n: Number of samples per (row, component) pair.
            key: PRNG key.

        Returns:
            Array of shape ``(batch, n, K, dim)``.
        """
        S = self.get_S()
        r_x = jnp.expand_dims(self.r, 0) + self._apply_S(x, S)

        nb, nmix, dim = r_x.shape
        S_ = jnp.expand_dims(S, 0)
        noise = jax.random.normal(key, shape=(nb, n, nmix, dim))

        if self.is_diagonal:
            return jnp.expand_dims(r_x, 1) + jax.lax.sqrt(eps * jnp.expand_dims(S_, 1)) * noise

        L_tiled = jnp.tile(jax.lax.linalg.cholesky(S_), (nb, n, 1, 1, 1))
        return jnp.expand_dims(r_x, 1) + jax.lax.sqrt(eps) * jax.lax.dot_general(L_tiled, noise, (((4,), (3,)), ((0, 1, 2), (0, 1, 2))))

    def get_drift(self, x, t):
        """Return the drift of the learned SDE at points ``x`` and times ``t``.

        Computes ``-x / (1 - t) + eps * grad_x sum_b logsumexp_k(exp_arg[b, k])``,
        where ``exp_arg`` combines ``log_alpha`` with Gaussian log-determinant
        and quadratic terms built from ``S_k`` and
        ``A_k(t) = t / (eps (1 - t)) I + (eps S_k)^{-1}``.

        Args:
            x: Points, shape ``(batch, dim)``.
            t: Times, shape ``(batch,)``; the formula divides by ``1 - t``, so
                ``t = 1`` is singular.

        Returns:
            Array of shape ``(batch, dim)``.
        """
        eps = self.eps
        r = self.r

        S_log_diagonal_matrix = self.S_log_diagonal_matrix

        S_diagonal = jax.lax.exp(S_log_diagonal_matrix)
        # Eigenvalues of A_k(t) in the eigenbasis of S_k, shape (batch, K, dim).
        A_diagonal = jax.lax.expand_dims(t / (eps * (1 - t)), (-1, -2)) + jnp.expand_dims(1 / (eps * S_diagonal), 0)

        S_log_det = jnp.sum(S_log_diagonal_matrix, -1)
        A_log_det = jnp.sum(jax.lax.log(A_diagonal), -1)

        log_alpha = self.get_log_alpha()

        def f(x):
            if self.is_diagonal:
                S_inv = 1 / S_diagonal
                A_inv = 1 / A_diagonal

                c = jnp.expand_dims(jnp.expand_dims(1 / (eps * (1 - t)), -1) * x, 1) + jnp.expand_dims(r / (eps * S_diagonal), 0)

                exp_arg = (
                    jnp.expand_dims(log_alpha, 0) - 0.5 * jnp.expand_dims(S_log_det, 0) - 0.5 * A_log_det
                    - 0.5 * jnp.expand_dims(jnp.sum((r * S_inv * r), -1), 0) / eps + 0.5 * jnp.sum((c * A_inv * c), -1)
                )
            else:
                S_diagonal_ = jnp.expand_dims(S_diagonal, 1)
                A_diagonal_ = jnp.expand_dims(A_diagonal, 2)

                S_rotation_matrix = self.S_rotation_matrix
                S_rotation_matrix_ = jnp.expand_dims(S_rotation_matrix, 0)
                S_rotation_matrix_tiled = jnp.tile(S_rotation_matrix_, (x.shape[0], 1, 1, 1))

                # Inverses through the shared eigenbasis: R diag(1 / lambda) R^T.
                S_inv = jax.lax.dot_general((S_rotation_matrix * (1 / S_diagonal_)), S_rotation_matrix, (((2,), (2,)), ((0,), (0,))))
                A_inv = jax.lax.dot_general((S_rotation_matrix_ * (1 / A_diagonal_)), S_rotation_matrix_tiled, (((3,), (3,)), ((0, 1), (0, 1))))

                c = jnp.expand_dims(jnp.expand_dims(1 / (eps * (1 - t)), -1) * x, 1) + jnp.expand_dims(jax.lax.dot_general(S_inv, r, (((2,), (1,)), ((0,), (0,)))) / eps, 0)

                r_S_inv_r = jnp.sum(jax.lax.dot_general(r, S_inv, (((1,), (1,)), ((0,), (0,)))) * r, -1)
                c_A_inv_c = jnp.sum(jax.lax.dot_general(c, A_inv, (((2,), (2,)), ((0, 1), (0, 1)))) * c, -1)

                exp_arg = (
                    jnp.expand_dims(log_alpha, 0) - 0.5 * jnp.expand_dims(S_log_det, 0) - 0.5 * A_log_det - 0.5 * jnp.expand_dims(r_S_inv_r, 0) / eps + 0.5 * c_A_inv_c
                )

            # Summing over the batch lets one jax.grad call give every row's gradient.
            return jnp.sum(logsumexp(exp_arg, -1))

        return - x / (1 - jnp.expand_dims(t, -1)) + eps * jax.grad(f)(x)

    def sample_at_time_moment(self, x, t, key):
        """Return ``t * x_1 + (1 - t) * x + sqrt(eps * t * (1 - t)) * z``.

        ``x_1 = self(x, key1)`` is a model sample and ``z`` is standard normal
        noise of ``x``'s shape.  NOTE(review): ``t`` must broadcast against
        ``x`` (e.g. a scalar or ``(batch, 1)``); confirm what callers pass.
        """
        key1, key2 = jax.random.split(key)
        return t * self(x, key1) + (1 - t) * x + jax.lax.sqrt(self.eps * t * (1 - t)) * jax.random.normal(key2, x.shape)

    def get_log_potential(self, x):
        """Return ``log sum_k alpha_k N(x | r_k, eps * S_k)`` for every row of ``x``, shape ``(batch,)``."""
        S = self.get_S()
        r = self.r
        log_alpha = self.get_log_alpha()
        eps = self.eps
        batch_size = x.shape[0]

        x_tiled = jnp.tile(jnp.expand_dims(x, 1), (1, self.n_potential, 1))
        r_tiled = jnp.tile(jnp.expand_dims(r, 0), (batch_size, 1, 1))

        if self.is_diagonal:
            S_eps_tiled = jnp.tile(jnp.expand_dims((eps * S), 0), (batch_size, 1, 1))
            log_prob_x = jnp.sum(jax.scipy.stats.norm.logpdf(x_tiled, r_tiled, jax.lax.sqrt(S_eps_tiled)), -1)
        else:
            S_eps_tiled = jnp.tile(jnp.expand_dims((eps * S), 0), (batch_size, 1, 1, 1))
            log_prob_x = jax.scipy.stats.multivariate_normal.logpdf(x_tiled, r_tiled, S_eps_tiled)

        return logsumexp(log_prob_x + log_alpha, -1)

    def get_log_C(self, x):
        """Return the log normaliser ``logsumexp_k logits_k(x)``, shape ``(batch,)``."""
        return logsumexp(self.get_logits(x), -1)

    def get_logits(self, x):
        """Return the unnormalised mixture logits for every row of ``x``, shape ``(batch, K)``."""
        return self._mixture_logits(x, self._apply_S(x, self.get_S()))

    def log_prob(self, x_1, x_0):
        """Return ``log p(x_1 | x_0)`` row-wise, shape ``(batch,)``.

        ``p(x_1 | x_0) = sum_k softmax(logits(x_0))_k N(x_1 | r_k + S_k x_0, eps S_k)``.

        Args:
            x_1: End points, shape ``(batch, dim)``.
            x_0: Start points, shape ``(batch, dim)``.
        """
        eps = self.eps
        S = self.get_S()
        S_ = jnp.expand_dims(S, 0)

        S_x = self._apply_S(x_0, S)
        r_x = jnp.expand_dims(self.r, 0) + S_x

        x_1_tiled = jnp.tile(jnp.expand_dims(x_1, 1), (1, self.n_potential, 1))
        if self.is_diagonal:
            S_eps_tiled = jnp.tile(eps * S_, (x_1.shape[0], 1, 1))
            log_prob_x_1 = jnp.sum(jax.scipy.stats.norm.logpdf(x_1_tiled, r_x, jax.lax.sqrt(S_eps_tiled)), -1)
        else:
            S_eps_tiled = jnp.tile(eps * S_, (x_1.shape[0], 1, 1, 1))
            log_prob_x_1 = jax.scipy.stats.multivariate_normal.logpdf(x_1_tiled, r_x, S_eps_tiled)

        logits = self._mixture_logits(x_0, S_x)
        return logsumexp(jax.nn.log_softmax(logits) + log_prob_x_1, -1)

    def sample_euler_maruyama(self, x, n_steps, key):
        """Simulate ``dx = drift dt + sqrt(eps) dW`` on ``[0, 1)`` with Euler-Maruyama.

        Args:
            x: Starting points, shape ``(batch, dim)``.
            n_steps: Number of steps (step size ``1 / n_steps``); must be a
                static Python int because it sizes ``jnp.linspace``.
            key: PRNG key.

        Returns:
            Trajectories of shape ``(batch, n_steps + 1, dim)``, including ``x``.
        """
        dt = 1 / n_steps
        # Touching self.eps here (outside lax.scan) also runs ``setup`` before tracing the step.
        sqrt_eps_dt = jax.lax.sqrt(self.eps * dt)

        def step(carry, t):
            x_t, key = carry
            key, subkey = jax.random.split(key)
            drift = self.get_drift(x_t, t * jnp.ones(x_t.shape[0]))
            x_next = x_t + drift * dt + jax.random.normal(subkey, x_t.shape) * sqrt_eps_dt
            return (x_next, key), x_next

        times = jnp.linspace(0, 1, n_steps, endpoint=False)
        _, path = jax.lax.scan(step, (x, key), times)  # (n_steps, batch, dim)
        trajectory = jnp.concatenate((jnp.expand_dims(x, 0), path), 0)
        return jnp.transpose(trajectory, (1, 0, 2))
