from enum import Enum
import jax
import jax.numpy as jnp
import numpy as np
import optax
from src.surr_loss import minmax_pgd_residual_fn, minmax_pgd_surrogate_loss
from functools import partial
import chex
from typing import Callable, Tuple

class Alg(str, Enum):
    """Identifiers of the inner-loop algorithms selectable by name.

    Members are also ``str`` instances, so a member compares equal to its
    string value (e.g. ``Alg.GN == "GN"``).
    """

    # Gradient descent on the surrogate loss.
    SURR = "Surr-GD"
    # Gauss-Newton.
    GN = "GN"
    # Damped Gauss-Newton.
    DGN = "DGN"
    # Levenberg-Marquardt.
    LM = "LM"
    # Adaptive Levenberg-Marquardt.
    ALM = "ALM"



@chex.dataclass
class InnerLoopData:
    """Diagnostics collected while running an inner optimisation loop.

    Fields:
        surrogate_loss: array of recorded loss values, one slot per inner
            iteration plus one for the value after the last step.
        max_singular_value: running maximum of the values passed to
            ``update_svs`` (starts at ``-inf``).
        min_singular_value: running minimum of the values passed to
            ``update_svs`` (starts at ``+inf``).
    """

    surrogate_loss: jnp.ndarray
    max_singular_value: jnp.ndarray = -jnp.inf
    min_singular_value: jnp.ndarray = jnp.inf

    def update_svs(self, svs):
        """Fold new singular values into the running min / max.

        Args:
            svs: a single array, or a list/tuple of arrays. ``None`` (either
                the whole argument or individual entries) and empty arrays
                are ignored, so the call is a no-op when no singular values
                were computed. Entries may have different shapes.
        """
        if svs is None:
            return

        entries = svs if isinstance(svs, (list, tuple)) else [svs]
        for entry in entries:
            if entry is None:
                continue
            values = jnp.asarray(entry)
            # Reducing an empty array raises; its size is static, so this
            # check is safe under tracing as well.
            if values.size == 0:
                continue
            self.max_singular_value = jnp.maximum(jnp.max(values), self.max_singular_value)
            self.min_singular_value = jnp.minimum(jnp.min(values), self.min_singular_value)

    @property
    def loss_beg(self) -> jnp.ndarray:
        """First recorded loss value."""
        return self.surrogate_loss[0]

    @property
    def loss_end(self) -> jnp.ndarray:
        """Last recorded loss value."""
        return self.surrogate_loss[-1]

    @property
    def min_loss(self) -> jnp.ndarray:
        """Smallest recorded loss value."""
        # jnp.min reduces on-device instead of iterating element by element.
        return jnp.min(self.surrogate_loss)

    @property
    def max_loss(self) -> jnp.ndarray:
        """Largest recorded loss value."""
        return jnp.max(self.surrogate_loss)


def surrogate_inner_loop(params: chex.ArrayTree, surr_loss: Callable, iter: int,
                         optimizer: optax.GradientTransformation, opt_state: optax.OptState) -> Tuple[chex.ArrayTree, chex.ArrayTree, InnerLoopData]:
    """Minimise ``surr_loss`` with ``iter`` optimizer steps.

    Args:
        params: starting parameters; differentiated as the first argument of
            ``surr_loss``.
        surr_loss: callable mapping parameters to a scalar loss.
        iter: number of optimizer steps to run.
        optimizer: optax gradient transformation used for every step.
        opt_state: optimizer state to start from.

    Returns:
        ``(params, opt_state, data)`` after the last step. ``data.surrogate_loss``
        has ``iter + 1`` entries: entry ``i < iter`` is the loss evaluated
        before step ``i``; the final entry is the loss at the returned
        parameters.
    """
    loss_and_grad = jax.value_and_grad(surr_loss, argnums=0)

    def _step(step_idx: int, carry: Tuple[chex.ArrayTree, chex.ArrayTree, InnerLoopData]):
        cur_params, cur_state, history = carry

        # Evaluate the loss at the current point and record it before moving.
        loss_now, grads = loss_and_grad(cur_params)
        history.surrogate_loss = history.surrogate_loss.at[step_idx].set(loss_now)

        step, next_state = optimizer.update(grads, cur_state)
        next_params = optax.apply_updates(cur_params, step)
        return (next_params, next_state, history)

    initial_carry = (params, opt_state, InnerLoopData(surrogate_loss=jnp.zeros(iter + 1)))
    params, opt_state, data = jax.lax.fori_loop(0, iter, _step, initial_carry)

    # The last slot holds the loss reached after the final step.
    final_loss = surr_loss(params)
    data.surrogate_loss = data.surrogate_loss.at[iter].set(final_loss)
    return params, opt_state, data

@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def _gn_direction_and_loss(residual, jac_residual, pinv_fn=None, lm_reg=0., adaptive_lm=0., compute_svs=False) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute a Gauss-Newton style direction for one Jacobian block.

    ``pinv_fn``, ``lm_reg``, ``adaptive_lm`` and ``compute_svs`` are static
    jit arguments (positions 2-5), so they must be hashable and each distinct
    value triggers a separate compilation.

    Args:
        residual: residual vector ``r``.
        jac_residual: Jacobian ``J`` of the residual w.r.t. one parameter
            block (used as ``J.T @ r`` and ``J.T @ J``).
        pinv_fn: optional callable mapping ``J`` to a (pseudo-)inverse; when
            given, the direction is ``pinv_fn(J) @ r`` and the damping
            arguments are not used. A truthy non-callable value selects
            ``jnp.linalg.pinv`` for this path.
        lm_reg: damping ``lambda``; when positive, ``lambda * I`` is added to
            ``J.T @ J`` before inversion.
        adaptive_lm: when positive, replaces ``lm_reg`` by
            ``sqrt(|| adaptive_lm * J.T @ r ||)``.
        compute_svs: when truthy, also return the squared singular values of
            ``J``.

    Returns:
        ``(direction, loss, svs)`` where ``loss = 0.5 * sum(r**2)`` and ``svs``
        is ``None`` unless ``compute_svs`` is truthy.
    """
    residual = jnp.asarray(residual)
    jac_residual = jnp.asarray(jac_residual)
    surr_loss_val = 0.5 * jnp.sum(residual**2)

    svs = None
    if compute_svs:
        # Squared singular values of J.
        svs = jnp.linalg.svd(jac_residual, full_matrices=False, compute_uv=False) ** 2

    # If a custom inverse function was passed, actually use it.
    if pinv_fn:
        pinv = pinv_fn if callable(pinv_fn) else jnp.linalg.pinv
        direction = pinv(jac_residual) @ residual
        return direction, surr_loss_val, svs

    grad_residual = jac_residual.T @ residual
    approx_hessian = jac_residual.T @ jac_residual

    if adaptive_lm > 0:
        lm_reg = jnp.sqrt(jnp.linalg.norm(adaptive_lm * grad_residual))

    # Only a positive lm_reg damps the system; works for both a static Python
    # float and the traced value produced by the adaptive rule.
    damping = jnp.where(lm_reg > 0, lm_reg, 0.0)
    identity = jnp.eye(approx_hessian.shape[0], dtype=approx_hessian.dtype)
    approx_hessian = approx_hessian + damping * identity

    direction = jnp.linalg.pinv(approx_hessian) @ grad_residual
    return direction, surr_loss_val, svs


def damped_gauss_newton(params, residual_fn, iter,
                        optimizer: optax.GradientTransformation, opt_state: optax.OptState, jac_fn=None,
                        jac_pinv_fn=None, **gn_kwargs):
    """Run ``iter`` Gauss-Newton steps, applied through ``optimizer``.

    Each step computes one direction per entry of the residual Jacobian and
    hands these to ``optimizer.update`` in place of gradients.

    Args:
        params: parameters to optimise. The code iterates over the Jacobian
            entries and applies the updates as a list, so this is presumably
            a list of arrays — TODO confirm against callers.
        residual_fn: callable mapping ``params`` to the residual.
        iter: number of steps.
        optimizer: optax transformation that turns directions into updates.
        opt_state: state for ``optimizer``.
        jac_fn: optional callable returning the Jacobian entries for
            ``params``; defaults to ``jax.jacobian(residual_fn)``.
        jac_pinv_fn: forwarded to ``_gn_direction_and_loss`` as ``pinv_fn``.
        **gn_kwargs: further keyword arguments for ``_gn_direction_and_loss``.

    Returns:
        ``(params, opt_state, data)``. ``data.surrogate_loss`` has
        ``iter + 1`` entries: the loss before each step, then the loss at the
        returned parameters.
    """
    data = InnerLoopData(surrogate_loss=jnp.zeros(iter + 1))

    def _update(i, x):
        params, opt_state, _data = x
        residuals = residual_fn(params)

        if jac_fn:
            jac_residuals = jac_fn(params)
        else:
            jac_residuals = jax.jacobian(residual_fn)(params)

        gn_directions_losses_svs = [
            _gn_direction_and_loss(residuals, j, pinv_fn=jac_pinv_fn, **gn_kwargs)
            for j in jac_residuals
        ]
        gn_directions, losses, svs = zip(*gn_directions_losses_svs)

        # Every entry of `losses` is computed from the same residual, so one
        # of them is the loss; summing would scale it by the number of
        # Jacobian entries and disagree with the final entry written below.
        _data.surrogate_loss = _data.surrogate_loss.at[i].set(losses[0])

        # Singular values are None unless they were requested; flatten what
        # is there so entries of different sizes can be reduced together.
        computed_svs = [jnp.ravel(s) for s in svs if s is not None]
        if computed_svs:
            _data.update_svs(jnp.concatenate(computed_svs))

        # Pass a list so the directions have the same container type as the
        # updates applied to params just below.
        updates, opt_state = optimizer.update(list(gn_directions), opt_state)
        params = optax.apply_updates(params, list(updates))
        return params, opt_state, _data

    params, opt_state, data = jax.lax.fori_loop(0, iter, _update, (params, opt_state, data))

    # Record the loss reached after the final step.
    residuals = jnp.asarray(residual_fn(params))
    _surr_loss_val = 0.5 * jnp.sum(residuals**2)
    data.surrogate_loss = data.surrogate_loss.at[iter].set(_surr_loss_val)
    return params, opt_state, data

def gauss_newton(params, residual_fn, iter, **kwargs):
    """Run ``iter`` Gauss-Newton steps with a unit step size.

    Delegates to ``damped_gauss_newton`` using plain SGD with learning rate 1.
    Only ``compute_svs`` is read from ``kwargs`` (``None`` when absent).

    Returns:
        ``(params, None, data)``; the optimizer state is not returned.
    """
    unit_step_sgd = optax.sgd(learning_rate=1.)
    new_params, _unused_state, data = damped_gauss_newton(
        params,
        residual_fn,
        iter,
        unit_step_sgd,
        unit_step_sgd.init(params),
        compute_svs=kwargs.get('compute_svs', None),
    )
    return new_params, None, data

def levenberg_marquardt(params, residual_fn, iter, **kwargs):
    """Run ``iter`` Levenberg-Marquardt steps with a fixed damping.

    Delegates to ``damped_gauss_newton`` using plain SGD with learning rate 1.
    ``kwargs`` must contain ``lm_reg`` (a ``KeyError`` is raised otherwise)
    and may contain ``compute_svs`` (``None`` when absent).

    Returns:
        ``(params, None, data)``; the optimizer state is not returned.
    """
    unit_step_sgd = optax.sgd(learning_rate=1.)
    initial_state = unit_step_sgd.init(params)
    solver_options = {
        'lm_reg': kwargs['lm_reg'],
        'compute_svs': kwargs.get('compute_svs', None),
    }
    new_params, _unused_state, data = damped_gauss_newton(
        params, residual_fn, iter, unit_step_sgd, initial_state, **solver_options
    )
    return new_params, None, data

def adaptive_levenberg_marquardt(params, residual_fn, iter, **kwargs):
    """Run ``iter`` Levenberg-Marquardt steps with adaptive damping.

    Delegates to ``damped_gauss_newton`` using plain SGD with learning rate 1.
    ``kwargs`` must contain ``adaptive_lm`` (a ``KeyError`` is raised
    otherwise) and may contain ``compute_svs`` (``None`` when absent), which
    is forwarded in the same way as ``gauss_newton`` and
    ``levenberg_marquardt`` do.

    Returns:
        ``(params, None, data)``; the optimizer state is not returned.
    """
    optimizer = optax.sgd(learning_rate=1.)
    opt_state = optimizer.init(params)
    adaptive_lm = kwargs['adaptive_lm']
    compute_svs = kwargs.get('compute_svs', None)
    params, _, data = damped_gauss_newton(params, residual_fn, iter, optimizer, opt_state,
                                          adaptive_lm=adaptive_lm, compute_svs=compute_svs)
    return params, None, data

# return correct update depending on algorithm

def update_fn_min_max(alg, loss, surr_step, apply_fn, inner_iter, optimizer, **kwargs):
    """Build the parameter-update function for the selected algorithm.

    Args:
        alg: an ``Alg`` member (or a value comparing equal to one).
        loss, surr_step, apply_fn: forwarded to ``minmax_pgd_surrogate_loss``
            / ``minmax_pgd_residual_fn`` together with the current params.
        inner_iter: number of inner iterations per update.
        optimizer: optax optimizer; used by ``Alg.SURR`` and ``Alg.DGN`` only.
        **kwargs: forwarded to the Gauss-Newton style solvers.

    Returns:
        A callable ``update_params(params, opt_state)`` returning
        ``(params, opt_state, data)``. For every algorithm except
        ``Alg.SURR`` the ``opt_state`` argument defaults to ``None``.

    Raises:
        ValueError: if ``alg`` matches none of the supported algorithms.
    """
    if alg == Alg.SURR:

        def update_params(params, opt_state):
            surrogate_loss = minmax_pgd_surrogate_loss(loss, params, surr_step, apply_fn)
            return surrogate_inner_loop(params, surrogate_loss, inner_iter, optimizer, opt_state)

        return update_params

    if alg == Alg.DGN:

        def update_params(params, opt_state=None):
            residual_fn = minmax_pgd_residual_fn(loss, params, surr_step, apply_fn)
            return damped_gauss_newton(params, residual_fn, inner_iter, optimizer, opt_state, **kwargs)

        return update_params

    # The remaining algorithms share one call shape: they build their own
    # optimizer and ignore opt_state. Matching is done with == rather than a
    # dict lookup so plain strings equal to a member's value keep working.
    residual_solvers = (
        (Alg.GN, gauss_newton),
        (Alg.LM, levenberg_marquardt),
        (Alg.ALM, adaptive_levenberg_marquardt),
    )
    solver = next((fn for member, fn in residual_solvers if alg == member), None)
    if solver is None:
        supported = [member.value for member in Alg]
        raise ValueError(f"Unsupported algorithm {alg!r}; expected one of {supported}")

    def update_params(params, opt_state=None):
        residual_fn = minmax_pgd_residual_fn(loss, params, surr_step, apply_fn)
        return solver(params, residual_fn, inner_iter, **kwargs)

    return update_params


