# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import List, Optional, Tuple, Union
import warnings
import math

import numpy as np
import torch
import einops


def _expand_gaussian_param(
        value: Union[int, float, torch.Tensor, None],
        default: float,
        shape: Tuple[int, ...],
        device: torch.device,
        name: str,
) -> torch.Tensor:
    """Return ``value`` as a tensor of exactly ``shape``.

    ``None`` is replaced by ``default``. Python scalars are materialised as a
    floating-point tensor on ``device``. A tensor whose shape equals
    ``shape[-1:]`` is broadcast over the leading dimensions; any other tensor
    must already have shape ``shape``.
    """
    if value is None:
        value = default
    if isinstance(value, (int, float)):
        # float(...) matters: torch.full infers its dtype from the fill value,
        # so a Python int would produce an integer tensor that torch.normal
        # cannot sample into.
        return torch.full(shape, float(value), device=device)
    if isinstance(value, torch.Tensor):
        if value.shape == shape[-1:]:
            # expand() prepends the missing leading dimensions by itself.
            value = value.expand(shape)
        assert value.shape == shape, f"{name}.shape = {value.shape}, shape = {shape}"
        return value
    raise ValueError(f"{name} must be None, int, float, or torch.Tensor, got {type(value)}")


def gen_new_random_variables(
        shape: Tuple[int, ...],
        device: torch.device,
        new_mu: Union[int, float, torch.Tensor, None] = None,
        new_sigma: Union[int, float, torch.Tensor, None] = None,
        rng_generator: Optional[torch.Generator] = None,
):
    """Draw independent normal samples of the given shape.

    Args:
        shape: Shape of the sample tensor to draw.
        device: Device on which scalar/default parameters are allocated.
        new_mu: Mean. ``None`` means 0; a scalar is used for every element;
            a tensor must have shape ``shape`` or ``shape[-1:]`` (the latter
            is broadcast over the leading dimensions).
        new_sigma: Standard deviation, same conventions as ``new_mu`` with a
            default of 1.
        rng_generator: Optional generator forwarded to ``torch.normal``.

    Returns:
        ``(sample, mu, sigma)`` where ``mu`` and ``sigma`` are the parameter
        tensors expanded to ``shape``.

    Raises:
        ValueError: If a parameter is not ``None``, a number, or a tensor.
        AssertionError: If a tensor parameter cannot be matched to ``shape``.
    """
    new_mu = _expand_gaussian_param(new_mu, 0.0, shape, device, "new_mu")
    new_sigma = _expand_gaussian_param(new_sigma, 1.0, shape, device, "new_sigma")

    new_x = torch.normal(new_mu, new_sigma, generator=rng_generator)
    return new_x, new_mu, new_sigma


def append_new_random_variables(
        x: torch.Tensor,
        x_mu: torch.Tensor,
        x_var: torch.Tensor,
        x_mu_m: Optional[torch.Tensor],
        x_var_m: torch.Tensor,
        *,
        x_mu_rev: Optional[torch.Tensor] = None,
        x_var_rev: Optional[torch.Tensor] = None,
        new_mu: Union[int, float, torch.Tensor, None] = None,
        new_sigma: Union[int, float, torch.Tensor, None] = None,
        rng_generator: Optional[torch.Generator] = None,
):
    """Append freshly sampled independent normal variables along the last axis.

    New variables are drawn for every ``(batch, position)`` entry of ``x`` and
    concatenated to ``x``. Their means and variances are arranged into
    sliding windows of two adjacent positions (``unfold(1, 2, 1)``) and
    concatenated to the supplied window statistics.

    Args:
        x: Samples of shape ``(b, a, m)``.
        x_mu: Window means; concatenated along the last axis with a
            ``(b, a - 1, 2, k)`` tensor, so all other dimensions must match.
        x_var: Window covariances; concatenated along the last axis with a
            ``(b, a - 1, 2, 2, k)`` tensor.
        x_mu_m: Unused; accepted for signature symmetry with the other layers.
        x_var_m: Marginal window covariances, same layout as ``x_var``.
        x_mu_rev: Optional reverse-direction window means. When given,
            ``x_var_rev`` must be given too.
        x_var_rev: Optional reverse-direction window covariances.
        new_mu: Mean of the new variables (see ``gen_new_random_variables``).
        new_sigma: Standard deviation of the new variables.
        rng_generator: Optional generator used for sampling.

    Returns:
        ``(x, mu, var, mu_m, var_m, mu_rev, var_rev)`` with the new variables
        appended. ``mu_m`` is always ``None``; ``mu_rev``/``var_rev`` are
        ``None`` when ``x_mu_rev`` is ``None``.

    Raises:
        ValueError: If neither ``new_mu`` nor ``new_sigma`` is a tensor, so
            the number of variables to append (``k``) cannot be determined.
    """
    assert x.ndim == 3, f"x.ndim = {x.ndim}"

    # The number of appended variables is carried by the last axis of a tensor
    # parameter; scalars and None carry no size information.
    if isinstance(new_mu, torch.Tensor):
        num_new = new_mu.shape[-1]
    elif isinstance(new_sigma, torch.Tensor):
        num_new = new_sigma.shape[-1]
    else:
        raise ValueError(
            "Cannot infer the number of new variables: at least one of "
            "new_mu/new_sigma must be a torch.Tensor, got "
            f"{type(new_mu)} and {type(new_sigma)}"
        )

    new_x, new_mu, new_sigma = gen_new_random_variables(
        tuple(x.shape[:-1]) + (num_new,), x.device, new_mu, new_sigma, rng_generator
    )

    # (b, a, k) -> windows of two neighbouring positions: (b, a - 1, 2, k).
    new_mu = einops.rearrange(new_mu.unfold(1, 2, 1), 'b a n t -> b a t n')
    # (b, a - 1, k, 2) -> per-variable diagonal 2x2 blocks: (b, a - 1, 2, 2, k).
    new_sigma = new_sigma.unfold(1, 2, 1)
    new_var = torch.diag_embed(new_sigma ** 2, dim1=2, dim2=3)

    if x_mu_rev is not None:
        new_mu_rev = torch.cat([x_mu_rev, new_mu], dim=-1)
        new_var_rev = torch.cat([x_var_rev, new_var], dim=-1)
    else:
        new_mu_rev = None
        new_var_rev = None

    new_mu_m = None
    new_var_m = torch.cat([x_var_m, new_var], dim=-1)
    new_x = torch.cat([x, new_x], dim=-1)
    new_mu = torch.cat([x_mu, new_mu], dim=-1)
    new_var = torch.cat([x_var, new_var], dim=-1)

    return new_x, new_mu, new_var, new_mu_m, new_var_m, new_mu_rev, new_var_rev


def apply_tree_op(
        tree_op: torch.Tensor,
        x: torch.Tensor,
        x_mu: torch.Tensor,
        x_var: torch.Tensor,
        x_mu_m: Optional[torch.Tensor],
        x_var_m: torch.Tensor,
        *,
        x_mu_rev: Optional[torch.Tensor] = None,
        x_var_rev: Optional[torch.Tensor] = None,
        disentangle: Union[bool, float] = False,
):
    """Apply a linear operator per position and propagate window statistics.

    Each position's length-``n`` vector in ``x`` is multiplied by ``tree_op``.
    Means and covariances are tracked for sliding windows of two adjacent
    positions (``2 * n`` variables each), transformed by the block-diagonal
    operator built from the two positions' operators, and then reduced to
    statistics of consecutive output pairs.

    Args:
        tree_op: Either a 2-D ``(n, n)`` operator shared by all positions, or
            a tensor that is indexed as ``(b, a, n, n)``.
        x: Samples of shape ``(b, a, n)``.
        x_mu: Window means, rearranged here as ``'b a t n'``.
        x_var: Window variances, rearranged here (after ``diag_embed``) as
            ``'b a t1 t2 n1 n2'``.
        x_mu_m: Not read by this function.
        x_var_m: Marginal window variances, same layout as ``x_var``.
        x_mu_rev: Optional reverse-direction window means.
        x_var_rev: Optional reverse-direction window variances; read only
            when ``x_mu_rev`` is not ``None``.
        disentangle: ``False`` skips the disentangling step. ``True`` runs it
            with factor ``1/5``. A non-bool number runs it with that number
            as the factor (including ``0``).

    Returns:
        ``(y, y_cond_mu, y_cond_var, None, y_margin_var, y_cond_mu_rev,
        y_cond_var_rev)`` where ``y`` has shape ``(b, a * n)``, the ``*_mu``
        tensors have shape ``(b, a * n - 1, 2)`` and the ``*_var`` tensors
        have shape ``(b, a * n - 1, 2, 2)``. The reverse outputs are ``None``
        when ``x_mu_rev`` is ``None``.
    """
    b, a, n = x.shape
    if tree_op.ndim == 2:
        # Share one (n, n) operator across every batch entry and position.
        tree_op = tree_op.expand(b, a, n, n)
    
    y = torch.einsum('...ij,...j->...i', tree_op, x)
    
    # Block-diagonal operator for a window of two adjacent positions:
    # top-left block acts on the left position, bottom-right on the right one.
    tree_op2 = torch.zeros(b, a-1, 2 * n, 2 * n, device=x.device)
    tree_op2[:, :, :n, :n] = tree_op[:, :-1, :, :]
    tree_op2[:, :, n:, n:] = tree_op[:, 1:, :, :]
    
    # Mean: T @ mu. Covariance: T @ Sigma @ T^T, where Sigma is assembled from
    # x_var with the last axis placed on a diagonal (no cross terms between
    # different variables of the last axis).
    y_mu = torch.einsum('...ij,...j->...i', tree_op2, einops.rearrange(x_mu, 'b a t n -> b a (t n)'))
    y_var = torch.einsum('...ij,...jk,...lk->...il',
                         tree_op2,
                         einops.rearrange(torch.diag_embed(x_var, dim1=-2, dim2=-1), 'b a t1 t2 n1 n2 -> b a (t1 n1) (t2 n2)'),
                         tree_op2)
    
    if x_mu_rev is not None:
        y_mu_rev = torch.einsum('...ij,...j->...i', tree_op2, einops.rearrange(x_mu_rev, 'b a t n -> b a (t n)'))
        y_var_rev = torch.einsum('...ij,...jk,...lk->...il',
                                 tree_op2,
                                 einops.rearrange(torch.diag_embed(x_var_rev, dim1=-2, dim2=-1), 'b a t1 t2 n1 n2 -> b a (t1 n1) (t2 n2)'),
                                 tree_op2)
    else:
        y_mu_rev = None
        y_var_rev = None
    
    y_var_m = torch.einsum('...ij,...jk,...lk->...il',
                            tree_op2,
                            einops.rearrange(torch.diag_embed(x_var_m, dim1=-2, dim2=-1), 'b a t1 t2 n1 n2 -> b a (t1 n1) (t2 n2)'),
                            tree_op2)

    # Transformed samples grouped into the same two-position windows: (b, a - 1, 2 * n).
    y_unfolded = einops.rearrange(y.unfold(1, 2, 1), 'b a n t -> b a (t n)')
    
    # Runs for disentangle=True and for any non-bool int/float value.
    if (isinstance(disentangle, bool) and disentangle) or (not isinstance(disentangle, bool) and isinstance(disentangle, (float, int))):
        disentangle_add_factor = 1/5 if isinstance(disentangle, bool) else disentangle
        
        # 2x2 marginal covariance blocks inside each window. "center" straddles
        # the boundary between the two positions (indices n-1 and n); "left" and
        # "right" are the last pair of the left position and the first pair of
        # the right position.
        # NOTE(review): only batch entry 0 is used (0:1 slice); the resulting
        # disentangler is broadcast over the whole batch - confirm all batch
        # entries are meant to share the same covariance.
        center_marginal_var = y_var_m[0:1, :, n-1:n+1, n-1:n+1]
        left_marginal_var = y_var_m[0:1, :, n-2:n, n-2:n]
        right_marginal_var = y_var_m[0:1, :, n:n+2, n:n+2]
        
        # Target spectrum: average of the left/right spectra, with the larger
        # eigenvalue (index 1; eigvalsh returns ascending order) moved towards
        # 2 by disentangle_add_factor and the smaller one set so both sum to 2.
        left_eigenvalues = torch.linalg.eigvalsh(left_marginal_var)
        right_eigenvalues = torch.linalg.eigvalsh(right_marginal_var)
        avg_eigenvalues = (left_eigenvalues + right_eigenvalues) / 2
        avg_eigenvalues[..., 1] = avg_eigenvalues[..., 1] + disentangle_add_factor * (2 - avg_eigenvalues[..., 1])
        avg_eigenvalues[..., 0] = 2 - avg_eigenvalues[..., 1]
        
        # Symmetric 2x2 map V diag(sqrt(target / current)) V^T that rescales the
        # center block along its own eigenvectors.
        middle_eigenvalues, middle_eigenvectors = torch.linalg.eigh(center_marginal_var)
        disentangler = torch.einsum('...ij,...j,...kj->...ik', middle_eigenvectors, 
                                    torch.sqrt(avg_eigenvalues)/torch.sqrt(middle_eigenvalues), 
                                    middle_eigenvectors)
        
        # Embed the 2x2 map into the window: identity everywhere except the
        # two boundary variables.
        full_disentangler = torch.zeros(b, a-1, 2 * n, 2 * n, device=x.device)
        full_disentangler[:, :, n-1:n+1, n-1:n+1] = disentangler
        full_disentangler[:, :, :n-1, :n-1] = torch.eye(n-1, device=x.device)
        full_disentangler[:, :, n+1:, n+1:] = torch.eye(n-1, device=x.device)
        
        # Apply the same linear map to samples, means and covariances.
        y_mu = torch.einsum('...ij,...j->...i', full_disentangler, y_mu)
        y_var = torch.einsum('...ij,...jk,...lk->...il', full_disentangler, y_var, full_disentangler)
        
        if y_mu_rev is not None:
            y_mu_rev = torch.einsum('...ij,...j->...i', full_disentangler, y_mu_rev)
            y_var_rev = torch.einsum('...ij,...jk,...lk->...il', full_disentangler, y_var_rev, full_disentangler)
        
        y_var_m = torch.einsum('...ij,...jk,...lk->...il', full_disentangler, y_var_m, full_disentangler)
        y_unfolded = torch.einsum('...ij,...j->...i', full_disentangler, y_unfolded)

    # Flatten the overlapping windows back into one sequence of a * n values.
    # Window-local index i is taken from the first window when i < skip_end,
    # from the last window when i >= n + skip_end, and from every window
    # (stride n in the flat sequence) otherwise.
    skip_end = n // 2
    y = torch.zeros((b, a * n), device=x.device)
    for i in range(2 * n):
        if i < skip_end:
            y[:, i] = y_unfolded[:, 0, i]
        elif i >= n + skip_end:
            y[:, (a - 2) * n + i] = y_unfolded[:, -1, i]
        else:
            y[:, i:(a-1)*n+skip_end:n] = y_unfolded[:, :, i]

    # Forward pass: for each consecutive pair (i, i + 1) store the pair's mean
    # and 2x2 covariance, using the same flat-index scheme as above.
    y_cond_mu = torch.zeros((b, n * a - 1, 2), device=x.device)
    y_cond_var = torch.zeros((b, n * a - 1, 2, 2), device=x.device)

    for i in range(2 * n - 1):
        if i < skip_end:
            y_cond_mu[:, i, :] = y_mu[:, 0, :2]
            y_cond_var[:, i, :, :] = y_var[:, 0, :2, :2]
        elif i >= n + skip_end:
            y_cond_mu[:, (a - 2) * n + i, :] = y_mu[:, -1, :2]
            y_cond_var[:, (a - 2) * n + i, :, :] = y_var[:, -1, :2, :2]
        else:
            y_cond_mu[:, i:(a-1)*n+skip_end:n, :] = y_mu[:, :, :2]
            y_cond_var[:, i:(a-1)*n+skip_end:n, :, :] = y_var[:, :, :2, :2]

        if i < 2 * n - 2:
            # Gaussian conditioning on the observed leading variable
            # y_unfolded[..., i]: update the mean, take the Schur complement
            # of the covariance, and drop that variable.
            y_mu = y_mu[:, :, 1:] + y_var[:, :, 1:, 0] / y_var[:, :, 0, 0:1] * (y_unfolded[:, :, i:i+1] - y_mu[:, :, 0:1])
            y_var = y_var[:, :, 1:, 1:] - y_var[:, :, 1:, 0:1] * y_var[:, :, 0:1, 1:] / y_var[:, :, 0:1, 0:1]

    if y_mu_rev is not None:
        # Reverse pass: same scheme, but walking from the end of the window and
        # conditioning on the trailing variable y_unfolded[..., i + 1].
        y_cond_mu_rev = torch.zeros((b, n * a - 1, 2), device=x.device)
        y_cond_var_rev = torch.zeros((b, n * a - 1, 2, 2), device=x.device)

        for i in range(2 * n - 2, -1, -1):
            if i < skip_end:
                y_cond_mu_rev[:, i, :] = y_mu_rev[:, 0, -2:]
                y_cond_var_rev[:, i, :, :] = y_var_rev[:, 0, -2:, -2:]
            elif i >= n + skip_end:
                y_cond_mu_rev[:, (a - 2) * n + i, :] = y_mu_rev[:, -1, -2:]
                y_cond_var_rev[:, (a - 2) * n + i, :, :] = y_var_rev[:, -1, -2:, -2:]
            else:
                y_cond_mu_rev[:, i:(a-1)*n+skip_end:n, :] = y_mu_rev[:, :, -2:]
                y_cond_var_rev[:, i:(a-1)*n+skip_end:n, :, :] = y_var_rev[:, :, -2:, -2:]

            if i > 0:
                y_mu_rev = y_mu_rev[:, :, :-1] + y_var_rev[:, :, -1, :-1] / y_var_rev[:, :, -1, -1:] * (y_unfolded[:, :, i+1:i+2] - y_mu_rev[:, :, -1:])
                y_var_rev = y_var_rev[:, :, :-1, :-1] - y_var_rev[:, :, :-1, -1:] * y_var_rev[:, :, -1:, :-1] / y_var_rev[:, :, -1:, -1:]
    else:
        y_cond_mu_rev = None
        y_cond_var_rev = None

    # Marginal pass: no conditioning, the leading variable is simply dropped
    # after each step so [:2, :2] is always the next consecutive pair.
    y_margin_var = torch.zeros((b, n * a - 1, 2, 2), device=x.device)
    for i in range(2 * n - 1):
        if i < skip_end:
            y_margin_var[:, i, :, :] = y_var_m[:, 0, :2, :2]
        elif i >= n + skip_end:
            y_margin_var[:, (a - 2) * n + i, :, :] = y_var_m[:, -1, :2, :2]
        else:
            y_margin_var[:, i:(a-1)*n+skip_end:n, :, :] = y_var_m[:, :, :2, :2]
        if i < 2 * n - 2:
            y_var_m = y_var_m[:, :, 1:, 1:]

    # The fourth slot (marginal mean) is not computed here.
    return y, y_cond_mu, y_cond_var, None, y_margin_var, y_cond_mu_rev, y_cond_var_rev


def single_tree_layer(
        tree_op: torch.Tensor,
        x: torch.Tensor,
        x_mu: torch.Tensor,
        x_var: torch.Tensor,
        x_mu_m: Optional[torch.Tensor],
        x_var_m: torch.Tensor,
        *,
        x_mu_rev: Optional[torch.Tensor] = None,
        x_var_rev: Optional[torch.Tensor] = None,
        new_mu: Union[int, float, torch.Tensor, None] = None,
        new_sigma: Union[int, float, torch.Tensor, None] = None,
        disentangle: Union[bool, float] = False,
        rng_generator: Optional[torch.Generator] = None,
):
    """Run one tree layer: append new random variables, then apply ``tree_op``.

    The inputs are first extended by ``append_new_random_variables`` (using
    ``new_mu``, ``new_sigma`` and ``rng_generator``); the extended samples and
    statistics are then passed to ``apply_tree_op`` together with
    ``disentangle``.

    Returns:
        The 7-tuple produced by ``apply_tree_op``:
        ``(y, y_mu, y_var, y_mu_m, y_var_m, y_mu_rev, y_var_rev)``.
    """
    # Step 1: widen every position with freshly sampled variables.
    extended = append_new_random_variables(
        x, x_mu, x_var, x_mu_m, x_var_m,
        x_mu_rev=x_mu_rev,
        x_var_rev=x_var_rev,
        new_mu=new_mu,
        new_sigma=new_sigma,
        rng_generator=rng_generator,
    )
    ext_x, ext_mu, ext_var, ext_mu_m, ext_var_m, ext_mu_rev, ext_var_rev = extended

    # Step 2: mix the widened variables and propagate their statistics.
    transformed = apply_tree_op(
        tree_op, ext_x, ext_mu, ext_var, ext_mu_m, ext_var_m,
        x_mu_rev=ext_mu_rev,
        x_var_rev=ext_var_rev,
        disentangle=disentangle,
    )
    out, out_mu, out_var, out_mu_m, out_var_m, out_mu_rev, out_var_rev = transformed

    return out, out_mu, out_var, out_mu_m, out_var_m, out_mu_rev, out_var_rev


def _per_layer_gaussian_param(
        value: Union[int, float, torch.Tensor, None],
        default: float,
        size: int,
        device: Optional[torch.device],
) -> Union[torch.Tensor, object]:
    """Turn a per-layer ``None``/scalar parameter into a ``(size,)`` float tensor.

    Tensors (and any other object) are returned unchanged so that downstream
    validation still sees them.
    """
    if value is None:
        value = default
    if isinstance(value, (int, float)):
        return torch.full((size,), float(value), device=device)
    return value


def multi_tree_layers(
        tree_ops: Union[torch.Tensor, List[torch.Tensor]],
        *,
        m: Optional[int] = None,
        n: Optional[int] = None,
        b: Optional[int] = None,
        num_layers: Optional[int] = None,
        initial_mu: Union[int, float, torch.Tensor, None] = None,
        initial_sigma: Union[int, float, torch.Tensor, None] = None,
        new_mus: Union[int, float, torch.Tensor, List, None] = None,
        new_sigmas: Union[int, float, torch.Tensor, List, None] = None,
        device: Optional[torch.device] = None,
        rng_generator: Optional[torch.Generator] = None,
        cal_reverse: bool = True,
        disentangle: Union[bool, float] = True,
        return_more: bool = True,
):
    """Sample through a stack of tree layers and track Gaussian statistics.

    The first layer draws ``m`` initial variables plus ``n - m`` new variables
    per batch row, mixes them with ``tree_ops[0]``, and computes pairwise
    conditional / marginal statistics. Every further layer groups ``m``
    consecutive batch rows into the last axis (so the batch shrinks by a
    factor of ``m``) and calls ``single_tree_layer``.

    Args:
        tree_ops: One operator reused for every layer, or a list with one
            operator per layer.
        m: Number of variables carried into each layer. Inferred from the
            last axis of ``initial_mu`` / ``initial_sigma`` when omitted.
        n: Operator size. Inferred from the last axis of ``tree_ops``.
        b: Initial batch size. Defaults to ``m ** (num_layers - 1)``.
        num_layers: Number of layers. Inferred from ``len(tree_ops)`` for a
            list, otherwise from ``b`` as ``round(log_m(b)) + 1``.
        initial_mu: Mean of the initial variables (``None`` -> 0).
        initial_sigma: Standard deviation of the initial variables
            (``None`` -> 1).
        new_mus: Mean(s) of the variables added in each layer: a single value
            used for all layers or a per-layer list. ``None`` entries mean 0
            and scalar entries are broadcast to the ``n - m`` new variables.
        new_sigmas: Standard deviation(s) of the added variables, same
            conventions with a default of 1.
        device: Device for allocated tensors; defaults to the device of the
            first operator.
        rng_generator: Optional generator used for every random draw.
        cal_reverse: Also compute the reverse-direction conditionals.
        disentangle: Forwarded to ``single_tree_layer`` for layers after the
            first.
        return_more: If true, additionally return the pairwise statistics.

    Returns:
        ``(y, cond_mu, cond_sigma, cond_mu_rev, cond_sigma_rev)`` and, when
        ``return_more`` is true, additionally ``(y_cond_mu, y_cond_var,
        y_margin_mu, y_margin_var, y_cond_mu_rev, y_cond_var_rev)``. The
        reverse entries are ``None`` when ``cal_reverse`` is false.

    Raises:
        ValueError: If ``m``, ``n``, ``b`` or ``num_layers`` is missing and
            cannot be inferred.
    """
    if m is None:
        if isinstance(initial_mu, torch.Tensor):
            m = initial_mu.shape[-1]
        elif isinstance(initial_sigma, torch.Tensor):
            m = initial_sigma.shape[-1]
        else:
            raise ValueError("m must be provided or inferred from initial_mu/initial_sigma")

    if n is None:
        if isinstance(tree_ops, torch.Tensor):
            n = tree_ops.shape[-1]
        elif isinstance(tree_ops, list):
            n = tree_ops[0].shape[-1]
        else:
            raise ValueError("n must be provided or inferred from tree_ops")

    # A list of operators fixes the depth. Resolve it before ``b`` so that
    # ``b`` can in turn be derived from it.
    if num_layers is None and isinstance(tree_ops, list):
        num_layers = len(tree_ops)

    if b is None:
        if num_layers is None:
            raise ValueError("b must be provided or inferred from num_layers")
        b = m ** (num_layers - 1)

    if num_layers is None:
        if m > 1:
            num_layers = round(math.log(b, m)) + 1
        else:
            raise ValueError("num_layers must be provided or inferred from tree_ops/batch_size")

    if not isinstance(tree_ops, list):
        tree_ops = [tree_ops] * num_layers
    if not isinstance(new_mus, list):
        new_mus = [new_mus] * num_layers
    if not isinstance(new_sigmas, list):
        new_sigmas = [new_sigmas] * num_layers

    if device is None:
        device = tree_ops[0].device

    # Normalise every layer's parameters individually. Later layers need a
    # tensor to know how many variables to append, so None/scalars are turned
    # into (n - m,) tensors here.
    new_mus = [_per_layer_gaussian_param(v, 0.0, n - m, device) for v in new_mus]
    new_sigmas = [_per_layer_gaussian_param(v, 1.0, n - m, device) for v in new_sigmas]

    # Python ints would otherwise become integer-typed parameter tensors.
    if isinstance(initial_mu, int):
        initial_mu = float(initial_mu)
    if isinstance(initial_sigma, int):
        initial_sigma = float(initial_sigma)

    # The generator is forwarded here as well so that seeded runs are fully
    # reproducible.
    x_init, init_mu, init_sigma = gen_new_random_variables(
        (b, m), device, initial_mu, initial_sigma, rng_generator
    )
    x = x_init.reshape(b, 1, m)
    x_mu = init_mu.reshape(b, 1, m)
    x_sigma = init_sigma.reshape(b, 1, m)

    new_x, new_mu, new_sigma = gen_new_random_variables(
        shape=(b, 1, n - m), device=device, new_mu=new_mus[0], new_sigma=new_sigmas[0], rng_generator=rng_generator
    )
    x = torch.cat([x, new_x], dim=-1)
    x_mu = torch.cat([x_mu, new_mu], dim=-1)
    x_sigma = torch.cat([x_sigma, new_sigma], dim=-1)
    x_var = torch.diag_embed(x_sigma ** 2, dim1=-1, dim2=-2)

    # First layer: y = T x, mean T mu, covariance T Sigma T^T.
    y = torch.einsum('...ij,...j->...i', tree_ops[0], x)
    y_mu = torch.einsum('...ij,...j->...i', tree_ops[0], x_mu)
    y_var = torch.einsum('...ij,...jk,...lk->...il', tree_ops[0], x_var, tree_ops[0])

    if cal_reverse:
        y_mu_rev = y_mu.clone()
        y_var_rev = y_var.clone()
    else:
        y_mu_rev = None
        y_var_rev = None

    y_mu_m = y_mu.clone()
    y_var_m = y_var.clone()

    # Forward pairwise statistics: store the leading pair, then condition on
    # the observed leading variable (Schur complement) and drop it.
    y_cond_mu = torch.zeros((b, n - 1, 2), device=device)
    y_cond_var = torch.zeros((b, n - 1, 2, 2), device=device)
    for i in range(n - 1):
        y_cond_mu[:, i, :] = y_mu[:, 0, :2]
        y_cond_var[:, i, :, :] = y_var[:, 0, :2, :2]
        if i < n - 2:
            y_mu = y_mu[:, :, 1:] + y_var[:, :, 1:, 0] / y_var[:, :, 0, 0:1] * (y[:, :, i:i+1] - y_mu[:, :, 0:1])
            y_var = y_var[:, :, 1:, 1:] - y_var[:, :, 1:, 0:1] * y_var[:, :, 0:1, 1:] / y_var[:, :, 0:1, 0:1]

    if cal_reverse:
        # Reverse pairwise statistics: same recursion from the trailing end.
        y_cond_mu_rev = torch.zeros((b, n - 1, 2), device=device)
        y_cond_var_rev = torch.zeros((b, n - 1, 2, 2), device=device)
        for i in range(n - 2, -1, -1):
            y_cond_mu_rev[:, i, :] = y_mu_rev[:, -1, -2:]
            y_cond_var_rev[:, i, :, :] = y_var_rev[:, -1, -2:, -2:]
            if i > 0:
                y_mu_rev = y_mu_rev[:, :, :-1] + y_var_rev[:, :, -1, :-1] / y_var_rev[:, :, -1, -1:] * (y[:, :, i+1:i+2] - y_mu_rev[:, :, -1:])
                y_var_rev = y_var_rev[:, :, :-1, :-1] - y_var_rev[:, :, :-1, -1:] * y_var_rev[:, :, -1:, :-1] / y_var_rev[:, :, -1:, -1:]
    else:
        y_cond_mu_rev = None
        y_cond_var_rev = None

    # Unconditioned statistics of every consecutive pair.
    y_margin_mu = torch.zeros((b, n - 1, 2), device=device)
    y_margin_var = torch.zeros((b, n - 1, 2, 2), device=device)
    for i in range(n - 1):
        y_margin_mu[:, i, :] = y_mu_m[:, 0, i:i+2]
        y_margin_var[:, i, :, :] = y_var_m[:, 0, i:i+2, i:i+2]

    y = einops.rearrange(y, 'b a n -> b (a n)')

    for i in range(1, num_layers):
        # Group m consecutive batch rows into the last axis; the batch
        # dimension shrinks by a factor of m at every layer.
        y = einops.rearrange(y, '(b m) a -> b a m', m=m)
        y_cond_mu = einops.rearrange(y_cond_mu, '(b m) a t -> b a t m', m=m)
        y_cond_var = einops.rearrange(y_cond_var, '(b m) a t1 t2 -> b a t1 t2 m', m=m)
        y_margin_mu = None
        y_margin_var = einops.rearrange(y_margin_var, '(b m) a t1 t2 -> b a t1 t2 m', m=m)
        if cal_reverse:
            y_cond_mu_rev = einops.rearrange(y_cond_mu_rev, '(b m) a t -> b a t m', m=m)
            y_cond_var_rev = einops.rearrange(y_cond_var_rev, '(b m) a t1 t2 -> b a t1 t2 m', m=m)

        y, y_cond_mu, y_cond_var, y_margin_mu, y_margin_var, y_cond_mu_rev, y_cond_var_rev = single_tree_layer(
            tree_ops[i], y, y_cond_mu, y_cond_var, y_margin_mu, y_margin_var,
            x_mu_rev=y_cond_mu_rev, x_var_rev=y_cond_var_rev,
            new_mu=new_mus[i], new_sigma=new_sigmas[i],
            disentangle=disentangle, rng_generator=rng_generator
        )

    # Per-variable forward conditionals: entry j uses the first component of
    # pair j; the last variable is obtained by conditioning the second
    # component of the final pair on its observed first component.
    y_cond_mu_final = torch.zeros_like(y)
    y_cond_sigma_final = torch.zeros_like(y)
    y_cond_mu_final[:, :-1] = y_cond_mu[:, :, 0]
    y_cond_sigma_final[:, :-1] = torch.sqrt(y_cond_var[:, :, 0, 0])
    y_cond_mu_final[:, -1] = y_cond_mu[:, -1, 1] + y_cond_var[:, -1, 1, 0] / y_cond_var[:, -1, 0, 0] * (y[:, -2] - y_cond_mu[:, -1, 0])
    y_cond_sigma_final[:, -1] = torch.sqrt(y_cond_var[:, -1, 1, 1] - y_cond_var[:, -1, 1, 0] * y_cond_var[:, -1, 0, 1] / y_cond_var[:, -1, 0, 0])

    if cal_reverse:
        # Mirror image for the reverse direction: entry j + 1 uses the second
        # component of pair j; the first variable is conditioned on the second.
        y_cond_mu_rev_final = torch.zeros_like(y)
        y_cond_sigma_rev_final = torch.zeros_like(y)
        y_cond_mu_rev_final[:, 1:] = y_cond_mu_rev[:, :, 1]
        y_cond_sigma_rev_final[:, 1:] = torch.sqrt(y_cond_var_rev[:, :, 1, 1])
        y_cond_mu_rev_final[:, 0] = y_cond_mu_rev[:, 0, 0] + y_cond_var_rev[:, 0, 0, 1] / y_cond_var_rev[:, 0, 1, 1] * (y[:, 1] - y_cond_mu_rev[:, 0, 1])
        y_cond_sigma_rev_final[:, 0] = torch.sqrt(y_cond_var_rev[:, 0, 0, 0] - y_cond_var_rev[:, 0, 0, 1] * y_cond_var_rev[:, 0, 1, 0] / y_cond_var_rev[:, 0, 1, 1])
    else:
        y_cond_mu_rev_final = None
        y_cond_sigma_rev_final = None

    if return_more:
        return y, y_cond_mu_final, y_cond_sigma_final, y_cond_mu_rev_final, y_cond_sigma_rev_final, y_cond_mu, y_cond_var, y_margin_mu, y_margin_var, y_cond_mu_rev, y_cond_var_rev
    else:
        return y, y_cond_mu_final, y_cond_sigma_final, y_cond_mu_rev_final, y_cond_sigma_rev_final


def gen_streach_matrix2(lambda_1: float = 1., lambda_2: float = 0.5, lambda_3: float = 0.5, lambda_4: float = 0.5) -> torch.Tensor:
    """Build the 4x4 matrix ``Q * sqrt(D)`` from four eigenvalue-like weights.

    ``Q`` has the orthonormal columns ``[1,-1,-1,1]/2``, ``[1,1,-1,-1]/2``,
    ``[1,-1,1,-1]/2`` and ``[1,1,1,1]/2``; column ``j`` is scaled by
    ``sqrt(lambda_{j+1})``. Consequently ``A @ A.T == Q @ diag(D) @ Q.T``.

    Args:
        lambda_1: Weight of the first column.
        lambda_2: Weight of the second column.
        lambda_3: Weight of the third column.
        lambda_4: Weight of the fourth column.

    Returns:
        A ``(4, 4)`` tensor in torch's default floating-point dtype.
    """
    # Pin the dtype so integer arguments (e.g. all ones) do not change the
    # result dtype, and stay inside torch instead of routing through NumPy.
    dtype = torch.get_default_dtype()
    v4_n = torch.tensor([1, 1, 1, 1], dtype=dtype) / 2
    v3_n = torch.tensor([1, -1, 1, -1], dtype=dtype) / 2
    v2_n = torch.tensor([1, 1, -1, -1], dtype=dtype) / 2
    v1_n = torch.tensor([1, -1, -1, 1], dtype=dtype) / 2
    Q = torch.column_stack((v1_n, v2_n, v3_n, v4_n))
    D = torch.tensor([lambda_1, lambda_2, lambda_3, lambda_4], dtype=dtype)
    # Broadcasting over the last axis scales column j by sqrt(D[j]).
    A = Q * torch.sqrt(D)
    return A


def gen_streach_matrix_many(*lambdas) -> torch.Tensor:
    """Build ``H * sqrt(D / sum(D))`` for a power-of-two number of weights.

    ``H`` is the unnormalised ``±1`` Hadamard matrix obtained by repeated
    Kronecker products of ``[[1, 1], [1, -1]]``; column ``j`` is scaled by
    the square root of the normalised weight ``lambdas[j] / sum(lambdas)``.

    Args:
        *lambdas: The weights, either as separate positional arguments or as
            a single list, tuple, NumPy array or 1-D tensor.

    Returns:
        A float32 tensor of shape ``(len(lambdas), len(lambdas))``.

    Raises:
        ValueError: If the number of weights is not a power of two (this
            includes an empty sequence).
    """
    if len(lambdas) == 1 and isinstance(lambdas[0], (list, tuple, np.ndarray, torch.Tensor)):
        lambdas = lambdas[0]

    count = len(lambdas)
    # count & (count - 1) clears the lowest set bit; it is zero only for
    # powers of two. count == 0 is rejected explicitly.
    if count <= 0 or count & (count - 1):
        raise ValueError(f"The length of lambdas must be a power of 2, got {len(lambdas)}")
    nb_krons = count.bit_length() - 1

    hadamard_matrix = torch.tensor([[1, 1], [1, -1]], dtype=torch.float32)
    Q = torch.eye(1, dtype=torch.float32)
    for _ in range(nb_krons):
        Q = torch.kron(Q, hadamard_matrix)

    # as_tensor avoids the copy-construct warning for tensor inputs; detach()
    # keeps the result free of autograd history, and the out-of-place
    # division below never writes into the caller's data.
    D = torch.as_tensor(lambdas, dtype=torch.float32).detach()
    D = D / torch.sum(D)
    A = Q * torch.sqrt(D)
    return A


def log_prob_normal(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Element-wise log-density of ``x`` under a normal with mean ``mu`` and std ``sigma``."""
    log_norm_const = -0.5 * math.log(2 * math.pi)
    standardized = (x - mu) / sigma
    return log_norm_const - torch.log(sigma) - 0.5 * standardized**2


def kl_divergence(mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
    """Element-wise KL divergence ``KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))``."""
    log_std_ratio = torch.log(sigma2 / sigma1)
    second_moment = sigma1 ** 2 + (mu1 - mu2) ** 2
    scale = 2 * sigma2 ** 2
    return log_std_ratio + second_moment / scale - 0.5


def entropy_normal(sigma: torch.Tensor) -> torch.Tensor:
    """Element-wise differential entropy (in nats) of a normal with std ``sigma``."""
    unit_entropy = 0.5 * math.log(2 * math.pi * math.e)
    return unit_entropy + torch.log(sigma)


def mutual_info_normal(sigma: torch.Tensor, sigma_rev: torch.Tensor) -> torch.Tensor:
    """Combine per-variable normal entropies into one value per split point.

    Along the last axis, for every split between index ``k`` and ``k + 1``
    the result is (sum of ``entropy_normal(sigma)`` up to ``k``) + (sum of
    ``entropy_normal(sigma_rev)`` from ``k + 1`` on) - (sum of
    ``entropy_normal(sigma)`` over the whole axis). The output's last axis is
    one shorter than the input's.
    """
    fwd_entropy = entropy_normal(sigma)
    bwd_entropy = entropy_normal(sigma_rev)

    # Running sum from the left, without the full-length total.
    prefix_entropy = torch.cumsum(fwd_entropy, dim=-1)[..., :-1]

    # Running sum from the right, built as flip -> cumsum -> flip.
    reversed_running = torch.cumsum(bwd_entropy.flip(dims=(-1,)), dim=-1)
    suffix_entropy = reversed_running.flip(dims=(-1,))[..., 1:]

    joint_entropy = fwd_entropy.sum(-1, keepdim=True)
    return prefix_entropy + suffix_entropy - joint_entropy

