from benchmark_utils.stochastic_jax_solver import StochasticJaxSolver

from benchopt import safe_import_context

with safe_import_context() as import_ctx:
    import jax
    import jax.numpy as jnp
    from functools import partial

    from benchmark_utils.sgd_inner import sgd_inner_jax, szgd_inner_jax
    from benchmark_utils.hessian_approximation import shia_jax
    from benchmark_utils.learning_rate_scheduler import update_lr
    from benchmark_utils.learning_rate_scheduler import init_lr_scheduler
    import jaxopt.prox as prox

def clip_vector_jax(grad_vector, max_norm=10):
    """Rescale ``grad_vector`` so its norm does not exceed ``max_norm``.

    The input is multiplied by ``min(1, max_norm / (norm + 1e-8))`` where
    ``norm`` is ``jnp.linalg.norm(grad_vector)``; inputs whose norm is
    already below ``max_norm`` are therefore returned unscaled.

    Parameters
    ----------
    grad_vector : array
        Vector to clip.
    max_norm : float, default=10
        Largest norm allowed for the returned vector.

    Returns
    -------
    array
        The (possibly) rescaled vector, same shape as ``grad_vector``.
    """
    # The 1e-8 offset keeps the ratio finite when the input is all zeros.
    scale = jnp.minimum(
        1, max_norm / (jnp.linalg.norm(grad_vector) + 1e-8)
    )
    return grad_vector * scale

class Solver(StochasticJaxSolver):
    """Zeroth-order stochastic bilevel solver (``ZDSBA``).

    One outer iteration, as built by :meth:`get_step`:

    1. run ``n_inner_steps`` of ``szgd_inner_jax`` on the inner variable;
    2. draw ``n_gaussian_vectors`` Gaussian directions for the outer and
       the inner variable and evaluate ``f_outer`` / ``f_inner`` at the
       symmetrically perturbed points to form finite-difference
       estimators;
    3. run ``n_shia_steps`` clipped iterations on an auxiliary vector
       ``v`` using the finite-difference second-order estimate;
    4. take a clipped step on the outer variable followed by
       ``prox_lasso``.
    """

    name = 'ZDSBA'

    # any parameter defined here is accessible as a class attribute
    parameters = {
        'step_size': [.1],
        'outer_ratio': [1.],
        'n_inner_steps': [100],
        'batch_size': [64],
        'n_shia_steps': [100],
        'n_gaussian_vectors': [1],
        'mu': [.1],
        'random_state': [1],
        'large_batch_size': [256],
        **StochasticJaxSolver.parameters
    }

    def init(self):
        """Build the initial carry passed to the step function.

        Returns
        -------
        dict
            Holds the inner/outer variables (copies of ``inner_var0`` and
            ``outer_var0``), the learning-rate scheduler state, both
            sampler states, the PRNG key seeded with ``random_state``
            and the function-query counter (starting at 0).
        """
        # Init variables
        self.inner_var = self.inner_var0.copy()
        self.outer_var = self.outer_var0.copy()

        # Step sizes for, in order: the inner solver, the auxiliary
        # vector ``v`` and the outer variable.
        step_sizes = jnp.array(
            [self.step_size, self.step_size,
             self.step_size / self.outer_ratio]
        )
        exponents = jnp.zeros(3)
        state_lr = init_lr_scheduler(step_sizes, exponents)

        return dict(
            inner_var=self.inner_var, outer_var=self.outer_var,
            state_lr=state_lr,
            state_inner_sampler=self.state_inner_sampler,
            state_outer_sampler=self.state_outer_sampler,
            key=jax.random.PRNGKey(self.random_state),
            query=0
        )

    def reset(self, carry):
        """Return ``carry`` unchanged (the query counter is not reset)."""
        return carry

    def get_step(self, inner_sampler, outer_sampler):
        """Return the function performing one outer iteration.

        Parameters
        ----------
        inner_sampler, outer_sampler : callable
            Called with a sampler state; the first returned element is
            passed to ``f_inner`` / ``f_outer`` as their third argument
            and the last one is the updated sampler state.

        Returns
        -------
        callable
            ``stocbio_one_iter(carry, _)`` mapping a carry dict (see
            :meth:`init`) to ``(updated_carry, output)``.
        """
        n_dirs = self.n_gaussian_vectors
        mu = self.mu

        sgd_inner = partial(
            szgd_inner_jax, n_steps=self.n_inner_steps,
            sampler=inner_sampler, k=n_dirs, mu=mu
        )

        def stocbio_one_iter(carry, _):
            (inner_lr, hia_lr, outer_lr), carry['state_lr'] = update_lr(
                carry['state_lr']
            )

            # Derive independent subkeys for every random consumer of
            # this iteration. A PRNG key must not be used twice: sharing
            # one key between the inner solver, ``U_o`` and ``U_i``
            # correlates draws that the estimators below combine.
            carry['key'], key_inner, key_dirs_outer, key_dirs_inner = (
                jax.random.split(carry['key'], 4)
            )

            carry['inner_var'], carry['state_inner_sampler'] = sgd_inner(
                self.f_inner, carry['inner_var'], carry['outer_var'],
                carry['state_inner_sampler'], step_size=inner_lr,
                key=key_inner
            )
            # Query budget attributed to the inner solver
            # (presumably two evaluations per direction and per step).
            carry['query'] += 2 * n_dirs * self.n_inner_steps

            # Sample gaussian vectors and perform ES estimation
            U_o = jax.random.normal(
                key_dirs_outer, (n_dirs, carry['outer_var'].shape[0])
            )
            U_i = jax.random.normal(
                key_dirs_inner, (n_dirs, carry['inner_var'].shape[0])
            )
            # Row q holds the point perturbed along direction q
            # (``_u``: plus ``mu``, ``_i``: minus ``mu``).
            outer_var_u = carry['outer_var'] + mu * U_o
            outer_var_i = carry['outer_var'] - mu * U_o
            inner_var_u = carry['inner_var'] + mu * U_i
            inner_var_i = carry['inner_var'] - mu * U_i

            start_outer, *_, carry['state_outer_sampler'] = outer_sampler(
                carry['state_outer_sampler']
            )
            start_inner, *_, carry['state_inner_sampler'] = inner_sampler(
                carry['state_inner_sampler']
            )
            inner_value = self.f_inner(
                carry['inner_var'], carry['outer_var'], start_inner
            )

            def iter_fun(q, init_carry):
                """Fill row ``q`` of both finite-difference buffers."""
                deltas_out, deltas_th = init_carry

                # First-order (central) difference of the outer function.
                outer_value_u = self.f_outer(
                    inner_var_u[q], outer_var_u[q], start_outer
                )
                outer_value_i = self.f_outer(
                    inner_var_i[q], outer_var_i[q], start_outer
                )
                deltas_out = deltas_out.at[q].set(
                    outer_value_u - outer_value_i
                )

                # Second-order difference of the inner function.
                inner_value_u = self.f_inner(
                    inner_var_u[q], outer_var_u[q], start_inner
                )
                inner_value_i = self.f_inner(
                    inner_var_i[q], outer_var_i[q], start_inner
                )
                deltas_th = deltas_th.at[q].set(
                    inner_value_u + inner_value_i - 2 * inner_value
                )
                return deltas_out, deltas_th

            deltas_out, deltas_th = jax.lax.fori_loop(
                0, n_dirs, iter_fun,
                (jnp.zeros((n_dirs, 1)), jnp.zeros((n_dirs, 1)))
            )
            # 2 * n_dirs evaluations of f_outer, 2 * n_dirs of f_inner,
            # plus the unperturbed f_inner value.
            carry['query'] += 4 * n_dirs + 1

            deltas_out = deltas_out / (2 * mu)
            deltas_th = deltas_th / (mu ** 2)
            es_estimator_fx = (U_o.T.dot(deltas_out) / n_dirs).squeeze(1)
            es_estimator_fy = (U_i.T.dot(deltas_out) / n_dirs).squeeze(1)

            # NOTE(review): the second-order estimates below are summed
            # over the directions, whereas es_estimator_fx/fy are
            # averaged. The same scaling is used in the ``v`` iterations
            # and in the cross term -- confirm this is intended.
            def iter_fun_szhia(_step, v):
                """One clipped update of the auxiliary vector ``v``."""
                v_col = jnp.expand_dims(v, axis=1)
                es_estimator_gyyv = (
                    (deltas_th * U_i).T.dot(U_i @ v_col).squeeze(1)
                    - jnp.sum(deltas_th * v_col.T, axis=0)
                )
                v = v - hia_lr * clip_vector_jax(
                    es_estimator_gyyv + es_estimator_fy
                )
                # Keep ``v`` inside the ball of radius 10, using the
                # same guarded clipping as for the update directions.
                return clip_vector_jax(v)

            v = jax.lax.fori_loop(
                0, self.n_shia_steps, iter_fun_szhia,
                jnp.zeros_like(carry['inner_var'])
            )

            es_estimator_gxyv = (deltas_th * U_o).T.dot(
                U_i @ jnp.expand_dims(v, axis=1)
            ).squeeze(1)

            carry['outer_var'] -= outer_lr * clip_vector_jax(
                es_estimator_gxyv + es_estimator_fx
            )
            # Soft-thresholding with a fixed l1 strength of 0.1.
            carry['outer_var'] = prox.prox_lasso(
                carry['outer_var'], l1reg=0.1
            )

            # NOTE(review): ``_`` was rebound by the star-unpacking of
            # the sampler outputs above, so the second returned value is
            # the list of extra sampler outputs, not the incoming
            # argument. Callers are presumably expected to ignore it.
            return carry, _
        return stocbio_one_iter
