from benchmark_utils.stochastic_jax_solver import StochasticJaxSolver

from benchopt import safe_import_context

with safe_import_context() as import_ctx:
    from benchmark_utils.learning_rate_scheduler import update_lr
    from benchmark_utils.learning_rate_scheduler import init_lr_scheduler

    import jax
    import jax.numpy as jnp
    import jaxopt.prox as prox
def clip_vector_jax(grad_vector, max_norm=10):
    """Rescale ``grad_vector`` so that its Euclidean norm is at most ``max_norm``.

    Vectors already inside the ball are returned unchanged (scale factor 1).
    A small ``1e-8`` is added to the norm so that a zero vector does not
    trigger a division by zero.

    Parameters
    ----------
    grad_vector : jax.Array
        Vector to clip.
    max_norm : float, default=10
        Maximal allowed Euclidean norm.

    Returns
    -------
    jax.Array
        ``grad_vector`` multiplied by ``min(1, max_norm / (||grad_vector|| + 1e-8))``.
    """
    scale = max_norm / (jnp.linalg.norm(grad_vector) + 1e-8)
    return grad_vector * jnp.minimum(1, scale)

class Solver(StochasticJaxSolver):
    """ZOSOBA: a zeroth-order variant of SOBA for stochastic bilevel problems.

    The gradients, Hessian-vector and cross-derivative-vector products of the
    inner (``f_inner``) and outer (``f_outer``) objectives required by SOBA
    are replaced by Gaussian-smoothing estimates built only from function
    values (see :meth:`approximate_gradient`).
    """
    name = 'ZOSOBA'

    # any parameter defined here is accessible as a class attribute
    parameters = {
        # inner step size; the outer one is step_size / outer_ratio
        'step_size': [.1],
        'outer_ratio': [1.],
        # presumably consumed by the base class / samplers -- not read here
        'batch_size': [64],
        'large_batch_size': [256],
        # number k of Gaussian directions drawn per iteration
        'n_gaussian_vectors': [1],
        # smoothing radius of the finite differences
        'mu': [.1],
        # bound on the Euclidean norm of the linear-system variable v
        'rz': [10.],
        # seed of the JAX PRNG key stored in the carry
        'random_state': [1],
        **StochasticJaxSolver.parameters
    }

    def reset(self, carry):
        """Return ``carry`` unchanged (the ``query`` counter keeps accumulating)."""
        return carry

    def init(self):
        """Initialise the variables, the LR scheduler and the PRNG key.

        Returns
        -------
        dict
            Carry with keys ``inner_var``, ``outer_var``, ``v`` (zeros shaped
            like ``inner_var``), ``state_lr``, ``state_inner_sampler``,
            ``state_outer_sampler``, ``key`` and ``query`` (the counter of
            function evaluations, starting at 0).
        """
        self.inner_var = self.inner_var0.copy()
        self.outer_var = self.outer_var0.copy()
        v = jnp.zeros_like(self.inner_var)
        # Dimensions of the Gaussian directions drawn in approximate_gradient.
        self.outer_var_shape = self.outer_var.shape[0]
        self.inner_var_shape = self.inner_var.shape[0]

        # (inner, outer) initial step sizes; the exponents are forwarded to
        # init_lr_scheduler (presumably decay exponents -- confirm).
        step_sizes = jnp.array(
            [self.step_size, self.step_size / self.outer_ratio]
        )
        exponents = jnp.array([0.3, 0.3])
        state_lr = init_lr_scheduler(step_sizes, exponents)
        return dict(
            inner_var=self.inner_var, outer_var=self.outer_var, v=v,
            state_lr=state_lr,
            state_inner_sampler=self.state_inner_sampler,
            state_outer_sampler=self.state_outer_sampler,
            key=jax.random.PRNGKey(self.random_state),
            query=0
        )

    def approximate_gradient(self, inner_var, outer_var, v, key,
                             start_inner, start_outer, k):
        """Zeroth-order estimates of the quantities SOBA needs.

        ``k`` Gaussian directions ``U_o`` (outer) and ``U_i`` (inner) are drawn
        from independent subkeys. ``f_outer`` and ``f_inner`` are evaluated at
        the jointly perturbed points ``(inner_var +/- mu U_i,
        outer_var +/- mu U_o)``, and estimators are formed from the function
        differences.

        Parameters
        ----------
        inner_var : jax.Array
            Current inner variable, 1-D of length ``self.inner_var_shape``.
        outer_var : jax.Array
            Current outer variable (presumably 1-D of length
            ``self.outer_var_shape`` -- confirm).
        v : jax.Array
            Linear-system variable, same shape as ``inner_var``.
        key : jax.Array
            PRNG key. It is split; the returned key was never used for sampling.
        start_inner, start_outer
            Mini-batch descriptors forwarded to ``f_inner`` / ``f_outer``.
        k : int
            Number of directions (must be a static Python int because it sets
            array shapes). One call costs ``4 k + 1`` function evaluations.

        Returns
        -------
        grad_gy : jax.Array
            Estimate of ``grad_y g``.
        grad_f : jax.Array
            Estimate of ``d2_xy g . v + grad_x f``.
        grad_R : jax.Array
            Estimate of ``d2_yy g . v + grad_y f``.
        key : jax.Array
            Fresh PRNG key to store back in the carry.
        """
        # Independent subkeys: reusing one key for both draws would correlate
        # (or even equate) the outer and inner perturbations.
        key, key_outer, key_inner = jax.random.split(key, 3)
        U_o = jax.random.normal(key_outer, (k, self.outer_var_shape))
        U_i = jax.random.normal(key_inner, (k, self.inner_var_shape))
        outer_var_u = outer_var + self.mu * U_o
        outer_var_i = outer_var - self.mu * U_o
        inner_var_u = inner_var + self.mu * U_i
        inner_var_i = inner_var - self.mu * U_i

        # Unperturbed inner value, needed by the second-order difference.
        inner_value = self.f_inner(inner_var, outer_var, start_inner)

        def iter_fun(q, deltas):
            deltas_out, deltas_in, deltas_th = deltas
            outer_value_u = self.f_outer(
                inner_var_u[q], outer_var_u[q], start_outer
            )
            outer_value_i = self.f_outer(
                inner_var_i[q], outer_var_i[q], start_outer
            )
            inner_value_u = self.f_inner(
                inner_var_u[q], outer_var_u[q], start_inner
            )
            inner_value_i = self.f_inner(
                inner_var_i[q], outer_var_i[q], start_inner
            )
            deltas_out = deltas_out.at[q].set(outer_value_u - outer_value_i)
            deltas_in = deltas_in.at[q].set(inner_value_u - inner_value_i)
            deltas_th = deltas_th.at[q].set(
                inner_value_u + inner_value_i - 2 * inner_value
            )
            return deltas_out, deltas_in, deltas_th

        zeros = jnp.zeros((k, 1))
        deltas_out, deltas_in, deltas_th = jax.lax.fori_loop(
            0, k, iter_fun, (zeros, zeros, zeros)
        )

        # Central first differences and second differences, shape (k, 1).
        deltas_out = deltas_out / (2 * self.mu)
        deltas_in = deltas_in / (2 * self.mu)
        deltas_th = deltas_th / (self.mu ** 2)

        # First-order estimates: (1/k) sum_q delta_q u_q.
        es_estimator_fx = (U_o.T @ deltas_out / k).squeeze(1)
        es_estimator_fy = (U_i.T @ deltas_out / k).squeeze(1)
        es_estimator_gy = (U_i.T @ deltas_in / k).squeeze(1)

        # Second-order estimates: (1/k) sum_q delta_th_q (w_q w_q^T - I) with
        # w_q = (U_o[q], U_i[q]). The xy block of (w w^T - I) is
        # U_o U_i^T and the yy block is U_i U_i^T - I, applied to v. Both are
        # averaged over k like the first-order estimates.
        u_i_dot_v = U_i @ jnp.expand_dims(v, axis=1)  # (k, 1)
        es_estimator_gxyv = ((deltas_th * U_o).T @ u_i_dot_v).squeeze(1) / k
        es_estimator_gyyv = (
            ((deltas_th * U_i).T @ u_i_dot_v).squeeze(1)
            - jnp.sum(deltas_th) * v
        ) / k

        grad_gy = es_estimator_gy
        grad_R = es_estimator_gyyv + es_estimator_fy
        grad_f = es_estimator_gxyv + es_estimator_fx
        return grad_gy, grad_f, grad_R, key

    def get_step(self, inner_sampler, outer_sampler):
        """Build the function performing one ZOSOBA iteration.

        Parameters
        ----------
        inner_sampler, outer_sampler : callable
            Map a sampler state to a tuple ``(start, ..., new_state)``. Only the
            first and last entries are used for the update.

        Returns
        -------
        callable
            ``soba_one_iter(carry, _) -> (carry, extra)`` in ``jax.lax.scan``
            body format (presumably scanned by the base class -- confirm).
        """

        def soba_one_iter(carry, _):
            (inner_step_size, outer_step_size), carry['state_lr'] = update_lr(
                carry['state_lr']
            )

            # Step 1 - draw mini-batches and estimate all SOBA directions.
            start_inner, *_, carry['state_inner_sampler'] = inner_sampler(
                carry['state_inner_sampler']
            )
            start_outer, *outer_extra, carry['state_outer_sampler'] = (
                outer_sampler(carry['state_outer_sampler'])
            )

            grad_gy, grad_f, grad_R, carry['key'] = self.approximate_gradient(
                carry['inner_var'], carry['outer_var'], carry['v'],
                carry['key'], start_inner, start_outer,
                self.n_gaussian_vectors
            )
            # 4 evaluations per direction + 1 unperturbed inner evaluation.
            carry['query'] += 4 * self.n_gaussian_vectors + 1

            # Step 2 - clipped updates of y and v; v is kept in the ball
            # of radius rz.
            carry['inner_var'] -= inner_step_size * clip_vector_jax(grad_gy)
            carry['v'] -= inner_step_size * clip_vector_jax(grad_R)
            carry['v'] = clip_vector_jax(carry['v'], max_norm=self.rz)

            # Step 3 - clipped update of x followed by soft-thresholding.
            # NOTE(review): the threshold is fixed at 0.1 and not scaled by
            # outer_step_size -- confirm this is intended.
            carry['outer_var'] -= outer_step_size * clip_vector_jax(grad_f)
            carry['outer_var'] = prox.prox_lasso(carry['outer_var'], l1reg=0.1)

            # The original implementation returned the (rebound) `_`, i.e.
            # the extra outputs of outer_sampler, as the scan output; this is
            # kept as-is so the scan output structure does not change.
            return carry, outer_extra

        return soba_one_iter
