from benchmark_utils.stochastic_jax_solver import StochasticJaxSolver

from benchopt import safe_import_context

with safe_import_context() as import_ctx:
    import jax
    import jax.numpy as jnp
    from functools import partial

    from benchmark_utils.sgd_inner import sgd_inner_jax, szgd_inner_jax
    from benchmark_utils.hessian_approximation import shia_jax
    from benchmark_utils.learning_rate_scheduler import update_lr
    from benchmark_utils.learning_rate_scheduler import init_lr_scheduler
    import jaxopt.prox as prox

def clip_vector_jax(grad_vector, max_norm=100):
    """Rescale ``grad_vector`` so its Euclidean norm is capped at ``max_norm``.

    The vector is multiplied by ``min(1, max_norm / (||grad_vector|| + 1e-8))``.
    Vectors whose norm is below ``max_norm`` keep a factor of 1. The ``1e-8``
    term keeps the division finite for an all-zero vector.
    """
    scale = max_norm / (jnp.linalg.norm(grad_vector) + 1e-8)
    return grad_vector * jnp.minimum(1, scale)

class Solver(StochasticJaxSolver):
    """HOZOG bilevel solver with a zeroth-order hypergradient estimate.

    One outer iteration (built by ``get_step``) does the following:

    1. Update the inner variable with ``szgd_inner_jax``, started from the
       current inner iterate.
    2. Draw ``n_gaussian_vectors`` standard Gaussian directions ``u_i`` in
       the outer space. For each one, rerun ``szgd_inner_jax`` from the same
       starting inner iterate at ``outer_var + mu * u_i``.
    3. Estimate the hypergradient as ``mean_i(u_i * (F_i - F_0) / mu)``.
       Here ``F_0`` and ``F_i`` are ``f_outer`` evaluated on one outer
       mini-batch at the unperturbed and perturbed points, each with its
       inner solution.
    4. Step along ``-outer_lr * estimate``, then apply
       ``jaxopt.prox.prox_lasso`` with ``l1reg=0.1``.

    ``carry['query']`` accumulates a count of function evaluations.
    """
    name = 'HOZOG'

    # any parameter defined here is accessible as a class attribute
    parameters = {
        'step_size': [.1],          # inner step size; outer is / outer_ratio
        'outer_ratio': [1.],
        'n_inner_steps': [100],     # n_steps of each inner solve
        'batch_size': [64],
        'n_shia_steps': [10],       # not referenced in this class
        'n_gaussian_vectors': [1],  # directions per estimate (inner k too)
        'mu': [.1],                 # perturbation radius (inner mu too)
        'random_state': [1],        # seed of the PRNG key kept in the carry
        'large_batch_size': [256],  # not referenced in this class
        **StochasticJaxSolver.parameters
    }

    def init(self):
        """Build the initial carry consumed by the step from ``get_step``.

        Returns
        -------
        dict
            Contains the inner and outer variables, the learning-rate
            scheduler state, both sampler states, a PRNG key seeded with
            ``random_state``, and the query counter (starting at 0).
        """
        # Init variables
        self.inner_var = self.inner_var0.copy()
        self.outer_var = self.outer_var0.copy()

        # Order is (inner, hia, outer), matching the unpacking of
        # ``update_lr`` in ``get_step``. The hia step size is unused here.
        step_sizes = jnp.array(
            [self.step_size, self.step_size,
             self.step_size / self.outer_ratio]
        )
        # Presumably zero decay exponents mean constant step sizes;
        # confirm against ``init_lr_scheduler``.
        exponents = jnp.zeros(3)
        state_lr = init_lr_scheduler(step_sizes, exponents)

        return dict(
            inner_var=self.inner_var, outer_var=self.outer_var,
            state_lr=state_lr,
            state_inner_sampler=self.state_inner_sampler,
            state_outer_sampler=self.state_outer_sampler,
            key=jax.random.PRNGKey(self.random_state),
            query=0
        )

    def reset(self, carry):
        """Return ``carry`` unchanged.

        NOTE: ``carry['query']`` is not reset here, so the query count keeps
        accumulating across calls.
        """
        return carry

    def get_step(self, inner_sampler, outer_sampler):
        """Return the function that performs one HOZOG outer iteration.

        Parameters
        ----------
        inner_sampler, outer_sampler : callable
            Mini-batch samplers. As used here, ``sampler(state)`` returns a
            tuple. Its first element is the batch argument passed to
            ``f_outer``, and its last element is the new sampler state.

        Returns
        -------
        hozog_one_iter : callable
            ``(carry, x) -> (carry, x)``, usable as a ``jax.lax.scan`` body.
            ``carry`` is the dict produced by ``init``.
        """
        n_dirs = self.n_gaussian_vectors
        inner_solver = partial(
            szgd_inner_jax, n_steps=self.n_inner_steps, sampler=inner_sampler,
            k=n_dirs, mu=self.mu
        )
        # Query cost charged per ``inner_solver`` call: n_inner_steps steps,
        # each using n_dirs two-point function differences. The same cost is
        # charged for the baseline solve and for every perturbed solve.
        inner_solve_cost = 2 * n_dirs * self.n_inner_steps

        def hozog_one_iter(carry, x):
            (inner_lr, _hia_lr, outer_lr), carry['state_lr'] = update_lr(
                carry['state_lr']
            )

            # JAX keys must never be reused. Give each random consumer its
            # own subkey so draws stay independent within and across
            # iterations: the baseline inner solve, the outer directions and
            # the perturbed inner solves.
            carry['key'], key_inner, key_dirs, key_pert = jax.random.split(
                carry['key'], 4
            )

            # JAX arrays are immutable, so holding this reference is enough
            # to restart the perturbed solves from the same inner iterate.
            inner_start = carry['inner_var']
            carry['inner_var'], carry['state_inner_sampler'] = inner_solver(
                self.f_inner, inner_start, carry['outer_var'],
                carry['state_inner_sampler'], step_size=inner_lr,
                key=key_inner
            )
            carry['query'] += inner_solve_cost

            # One Gaussian direction per row, perturbing the outer variable.
            U_o = jax.random.normal(
                key_dirs, (n_dirs, carry['outer_var'].shape[0])
            )
            outer_var_u = carry['outer_var'] + self.mu * U_o

            start_outer, *_, carry['state_outer_sampler'] = outer_sampler(
                carry['state_outer_sampler']
            )
            # The inner batch drawn here is not used. The call only advances
            # the inner sampler state that the perturbed solves below read.
            _unused_start_inner, *_, carry['state_inner_sampler'] = (
                inner_sampler(carry['state_inner_sampler'])
            )

            outer_value = self.f_outer(
                carry['inner_var'], carry['outer_var'], start_outer
            )
            carry['query'] += 1

            def perturbed_delta(q, deltas):
                # Re-solve the inner problem at the q-th perturbed outer point
                # and record the change in the outer objective.
                inner_var_q, _unused_state = inner_solver(
                    self.f_inner, inner_start, outer_var_u[q],
                    carry['state_inner_sampler'], step_size=inner_lr,
                    key=key_pert
                )
                outer_value_q = self.f_outer(
                    inner_var_q, outer_var_u[q], start_outer
                )
                return deltas.at[q].set(outer_value_q - outer_value)

            deltas = jax.lax.fori_loop(
                0, n_dirs, perturbed_delta, jnp.zeros((n_dirs, 1))
            )
            # Each perturbed point costs one inner solve plus one f_outer call.
            carry['query'] += n_dirs * (inner_solve_cost + 1)

            # Gaussian-smoothing estimate: mean_i u_i * delta_i / mu.
            deltas = deltas / self.mu
            es_estimator_fx = (U_o.T.dot(deltas) / n_dirs).squeeze(1)

            carry['outer_var'] = (
                carry['outer_var'] - outer_lr * es_estimator_fx
            )
            # NOTE(review): l1reg is a fixed 0.1 and is not scaled by outer_lr,
            # unlike a textbook proximal-gradient step (``scaling=outer_lr``).
            # Kept unchanged; confirm this is intended.
            carry['outer_var'] = prox.prox_lasso(carry['outer_var'], l1reg=0.1)

            # Echo the scan input unchanged. Do not return ``_``: the
            # star-unpacking above rebinds it to sampler outputs.
            return carry, x
        return hozog_one_iter
