import math
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def operator_norm_2(A_prev, k, A, B, C):
    """Scan step: advance the carried matrix and emit a squared norm.

    The signature follows the ``lax.scan`` step convention: the first
    argument is the carry and the second is the scanned element.

    Args:
        A_prev: Carry from the previous step; it is left-multiplied by ``A``.
        k: Scanned element. The body does not read it.
        A: Matrix applied to the carry.
        B: Right-hand factor of the measured product.
        C: Left-hand factor of the measured product.

    Returns:
        A pair ``(A_k, norm)`` where ``A_k = A @ A_prev`` is the new carry
        and ``norm`` is the square of ``jnp.linalg.norm(C @ A_k @ B)`` with
        the default ``ord`` (root-sum-of-squares over all entries, i.e. the
        Frobenius norm when the product is a matrix).
    """
    next_carry = A @ A_prev
    response = C @ next_carry @ B
    return next_carry, jnp.square(jnp.linalg.norm(response))


@jax.jit
def operator_norm_1(A_prev, k, A, B, C):
    """Scan step: advance the carried matrix and emit row-wise 1-norms.

    The signature follows the ``lax.scan`` step convention: the first
    argument is the carry and the second is the scanned element.

    Args:
        A_prev: Carry from the previous step; it is left-multiplied by ``A``.
        k: Scanned element. The body does not read it.
        A: Matrix applied to the carry.
        B: Right-hand factor of the measured product.
        C: Left-hand factor of the measured product.

    Returns:
        A pair ``(A_k, norm)`` where ``A_k = A @ A_prev`` is the new carry
        and ``norm`` holds, for each index along axis 0 of ``C @ A_k @ B``,
        the vector 1-norm (sum of absolute values) taken along axis 1.
    """
    next_carry = A @ A_prev
    response = C @ next_carry @ B
    # ord=1 with an integer axis is a vector norm, evaluated per row.
    per_row = jnp.linalg.norm(response, ord=1, axis=1)
    return next_carry, per_row


def get_bound_relu_nonet(state, N, K_u=1,
                         L_l=jnp.sqrt(2).item(), K_l=1, delta=0.5,
                         alpha=0, k_limit=1000, complex_ssms=False):
    """Compute a norm-based bound from the parameters held in ``state``.

    One scalar is collected per component (encoder kernel, decoder kernel,
    a constant 1 for ReLU, and one value per sequence block under
    ``state.params['encoder']``). With ``mu`` the product of all of them,
    the result is::

        (mu * K_u * L_l + K_l * sqrt(2 * log(4 / delta))) / sqrt(N)

    For each sequence block the matrices are built from its ``'seq'``
    parameters: ``A = diag(exp(-exp(nu_log)))``,
    ``B = B_re * exp(gamma_log)[..., None]``, ``C = C_re`` and ``D`` as a
    column. Only these real-valued entries are read. The first block met
    during iteration is measured with ``operator_norm_2``; every later
    block is measured with ``operator_norm_1``.

    Args:
        state: Object whose ``params`` attribute is a nested mapping with
            ``['encoder']['encoder']['kernel']``, ``['decoder']['kernel']``
            and, for every other key under ``['encoder']``, a ``['seq']``
            mapping holding ``nu_log``, ``gamma_log``, ``B_re``, ``C_re``
            and ``D``.
        N: The bound is divided by ``sqrt(N)`` (presumably the sample
            count -- confirm with the caller).
        K_u: Multiplies ``mu`` in the first term.
        L_l: Multiplies ``mu`` in the first term.
        K_l: Multiplies the ``delta``-dependent term.
        delta: Enters through ``sqrt(2 * log(4 / delta))``.
        alpha: Added to every sequence block's value before it is stored.
        k_limit: Number of scan steps used per sequence block.
        complex_ssms: Accepted but not read by this function.

    Returns:
        ``(bound, metrics)`` where ``metrics`` maps ``'encoder'``,
        ``'decoder'``, ``'relu'`` and each sequence-block key to the scalar
        that entered the product.
    """
    all_params = state.params
    encoder_params = all_params['encoder']

    metrics = {
        'encoder': jnp.linalg.norm(encoder_params['encoder']['kernel'], ord=2),
        'decoder': jnp.linalg.norm(all_params['decoder']['kernel'],
                                   ord=jnp.inf),
        'relu': 1,
    }

    # Everything under 'encoder' except the input projection itself is
    # treated as a sequence block.
    block_names = [name for name in encoder_params if name != 'encoder']

    for position, name in enumerate(block_names):
        seq = encoder_params[name]['seq']

        decay = jnp.exp(-jnp.exp(seq['nu_log']))
        A = jnp.diag(decay)
        input_scale = jnp.expand_dims(jnp.exp(seq['gamma_log']), axis=-1)
        B = seq['B_re'] * input_scale
        C = seq['C_re']
        D = jnp.expand_dims(seq['D'], 1)

        # Only the first block uses the squared-norm step function.
        is_first_block = position == 0
        step_fn = operator_norm_2 if is_first_block else operator_norm_1

        # NOTE(review): the carry starts at A and both step functions
        # multiply by A before measuring, so the lowest power that
        # contributes is A @ A -- confirm this offset is intended.
        _, traj = lax.scan(partial(step_fn, A=A, B=B, C=C),
                           init=A,
                           xs=jnp.arange(k_limit))

        if is_first_block:
            K = jnp.sqrt(jnp.linalg.norm(D, ord='fro')**2 + jnp.sum(traj))
        else:
            # NOTE(review): this is a global maximum over every entry of
            # the broadcast sum, with no summation over scan steps --
            # confirm against the intended definition.
            K = jnp.max(jnp.linalg.norm(D, ord=1, axis=1) + traj)

        metrics[name] = K + alpha

    mu = math.prod(metrics.values())
    numerator = mu * K_u * L_l + K_l * jnp.sqrt(2 * jnp.log(4 / delta))
    return numerator / jnp.sqrt(N), metrics


def get_bound_relu_nonet_cifar(state, N, K_u=1,
                               L_l=jnp.sqrt(2).item(), K_l=1, delta=0.5,
                               alpha=0, k_limit=1000, complex_ssms=False):
    """Compute a norm-based bound, with ``A = diag(1 / (1 + nu_log**2))``.

    One scalar is collected per component (encoder kernel, decoder kernel,
    a constant 1 for ReLU, and one value per sequence block under
    ``state.params['encoder']``). With ``mu`` the product of all of them,
    the result is::

        (mu * K_u * L_l + K_l * sqrt(2 * log(4 / delta))) / sqrt(N)

    For each sequence block the matrices are built from its ``'seq'``
    parameters: ``A = diag(1 / (1 + nu_log**2))``,
    ``B = B_re * exp(gamma_log)[..., None]``, ``C = C_re`` and ``D`` as a
    column. Only these real-valued entries are read. The first block met
    during iteration is measured with ``operator_norm_2``; every later
    block is measured with ``operator_norm_1``.

    Args:
        state: Object whose ``params`` attribute is a nested mapping with
            ``['encoder']['encoder']['kernel']``, ``['decoder']['kernel']``
            and, for every other key under ``['encoder']``, a ``['seq']``
            mapping holding ``nu_log``, ``gamma_log``, ``B_re``, ``C_re``
            and ``D``.
        N: The bound is divided by ``sqrt(N)`` (presumably the sample
            count -- confirm with the caller).
        K_u: Multiplies ``mu`` in the first term.
        L_l: Multiplies ``mu`` in the first term.
        K_l: Multiplies the ``delta``-dependent term.
        delta: Enters through ``sqrt(2 * log(4 / delta))``.
        alpha: Added to every sequence block's value before it is stored.
        k_limit: Number of scan steps used per sequence block.
        complex_ssms: Accepted but not read by this function.

    Returns:
        ``(bound, metrics)`` where ``metrics`` maps ``'encoder'``,
        ``'decoder'``, ``'relu'`` and each sequence-block key to the scalar
        that entered the product.
    """
    all_params = state.params
    encoder_params = all_params['encoder']

    metrics = {
        'encoder': jnp.linalg.norm(encoder_params['encoder']['kernel'], ord=2),
        'decoder': jnp.linalg.norm(all_params['decoder']['kernel'],
                                   ord=jnp.inf),
        'relu': 1,
    }

    # Everything under 'encoder' except the input projection itself is
    # treated as a sequence block.
    block_names = [name for name in encoder_params if name != 'encoder']

    for position, name in enumerate(block_names):
        seq = encoder_params[name]['seq']

        # Diagonal entries come from a rational map of nu_log rather than
        # a double exponential.
        diagonal = 1 / (1 + seq['nu_log']**2)
        A = jnp.diag(diagonal)
        input_scale = jnp.expand_dims(jnp.exp(seq['gamma_log']), axis=-1)
        B = seq['B_re'] * input_scale
        C = seq['C_re']
        D = jnp.expand_dims(seq['D'], 1)

        # Only the first block uses the squared-norm step function.
        is_first_block = position == 0
        step_fn = operator_norm_2 if is_first_block else operator_norm_1

        # NOTE(review): the carry starts at A and both step functions
        # multiply by A before measuring, so the lowest power that
        # contributes is A @ A -- confirm this offset is intended.
        _, traj = lax.scan(partial(step_fn, A=A, B=B, C=C),
                           init=A,
                           xs=jnp.arange(k_limit))

        if is_first_block:
            K = jnp.sqrt(jnp.linalg.norm(D, ord='fro')**2 + jnp.sum(traj))
        else:
            # NOTE(review): this is a global maximum over every entry of
            # the broadcast sum, with no summation over scan steps --
            # confirm against the intended definition.
            K = jnp.max(jnp.linalg.norm(D, ord=1, axis=1) + traj)

        metrics[name] = K + alpha

    mu = math.prod(metrics.values())
    numerator = mu * K_u * L_l + K_l * jnp.sqrt(2 * jnp.log(4 / delta))
    return numerator / jnp.sqrt(N), metrics
