from benchmark_utils.stochastic_jax_solver import StochasticJaxSolver

from benchopt import safe_import_context

with safe_import_context() as import_ctx:
    from benchmark_utils.learning_rate_scheduler import update_lr
    from benchmark_utils.learning_rate_scheduler import init_lr_scheduler
    import random
    import jax
    import jax.numpy as jnp

def clip_vector_jax(grad_vector, max_norm=10):
    """Rescale ``grad_vector`` so its Euclidean norm does not exceed ``max_norm``.

    The scale factor is capped at 1, so vectors already inside the ball are
    left as they are; the ``1e-8`` in the denominator keeps the division
    finite for an all-zero vector.
    """
    scale = max_norm / (jnp.linalg.norm(grad_vector) + 1e-8)
    return grad_vector * jnp.minimum(1, scale)

class Solver(StochasticJaxSolver):
    """SRZOBA: zeroth-order stochastic bilevel solver.

    Every quantity of the SOBA update (inner gradient, hypergradient
    direction and linear-system residual) is estimated from function values
    only, using symmetric Gaussian-smoothing finite differences over a joint
    perturbation of the inner and outer variables.  After the first step
    following ``init``/``reset``, the estimates are combined with a
    STORM-style recursive correction that re-evaluates the estimator at the
    previous iterate with the same mini-batches and Gaussian directions.
    """

    name = 'SRZOBA'

    # any parameter defined here is accessible as a class attribute
    parameters = {
        'step_size': [.1],          # step size for inner_var and v
        'outer_ratio': [1.],        # outer step = step_size / outer_ratio
        'batch_size': [64],
        'n_gaussian_vectors': [1],  # Gaussian directions per estimate (k)
        'mu': [.1],                 # finite-difference smoothing radius
        'random_state': [1],        # seed of the JAX PRNG key
        'large_batch_size': [256],
        'p': [0.5],
        'eta': [1.],                # scheduled rate driving the momentum
        'large_k': [5],
        # NOTE(review): 'batch_size', 'large_batch_size', 'p' and 'large_k'
        # are not read in this class; presumably used by the base solver.
        **StochasticJaxSolver.parameters
    }

    def init(self):
        """Build the initial carry consumed by the step function.

        Returns
        -------
        dict
            ``inner_var`` / ``outer_var``: current iterates; ``v``: the
            linear-system variable (initialised to ones); ``memory_*``:
            arrays of shape ``(2, *var.shape)`` whose row 0 holds the
            previous iterate and row 1 the momentum (recursive estimate);
            ``iter_num``: steps since the last ``reset``; ``state_lr``:
            learning-rate scheduler state; the sampler states; ``key``: JAX
            PRNG key; ``query``: running count of function evaluations.
        """
        self.inner_var = self.inner_var0.copy()
        self.outer_var = self.outer_var0.copy()
        v = jnp.ones_like(self.inner_var)
        # Sizes used to draw Gaussian directions (assumes 1-D variables).
        self.outer_var_shape = self.outer_var.shape[0]
        self.inner_var_shape = self.inner_var.shape[0]

        # Scheduled rates: (inner step, outer step, momentum eta), with one
        # decay exponent each (exact schedule defined by init_lr_scheduler).
        step_sizes = jnp.array(
            [self.step_size, self.step_size / self.outer_ratio, self.eta]
        )
        exponents = jnp.array([1 / 3, 1 / 3, 2 / 3])
        state_lr = init_lr_scheduler(step_sizes, exponents)

        # Row 0: previous iterate; row 1: momentum (recursive estimate).
        memory_inner = jnp.zeros((2, *self.inner_var.shape))
        memory_outer = jnp.zeros((2, *self.outer_var.shape))
        memory_v = jnp.zeros((2, *self.inner_var.shape))

        return dict(
            inner_var=self.inner_var, outer_var=self.outer_var, v=v,
            memory_inner=memory_inner, memory_outer=memory_outer,
            memory_v=memory_v,
            iter_num=0,
            state_lr=state_lr,
            state_inner_sampler=self.state_inner_sampler,
            state_outer_sampler=self.state_outer_sampler,
            state_inner_sampler1=self.state_inner_sampler1,
            state_outer_sampler1=self.state_outer_sampler1,
            key=jax.random.PRNGKey(self.random_state),
            query=0,
        )

    def approximate_gradient(self, inner_var, outer_var, v, key,
                             start_inner, start_outer, k):
        """Zeroth-order estimates of the quantities needed by SOBA.

        A joint Gaussian direction ``w = (u_o, u_i)`` is drawn ``k`` times.
        ``f_outer`` and ``f_inner`` are evaluated at both ``+mu * w`` and
        ``-mu * w``, plus ``f_inner`` once at the centre, for a total of
        ``4 * k + 1`` evaluations.

        Parameters
        ----------
        inner_var, outer_var : array
            Inner (y) and outer (x) points (assumed 1-D).
        v : array
            Linear-system variable, same shape as ``inner_var``.
        key : jax PRNG key
            Calls sharing ``key`` draw the same directions; the recursive
            correction in ``get_step`` relies on this.
        start_inner, start_outer :
            Mini-batch descriptors forwarded to ``f_inner`` / ``f_outer``.
        k : int
            Number of Gaussian directions.

        Returns
        -------
        grad_gy : array
            Estimate of grad_y g.
        grad_f : array
            Estimate of grad_x f + d2_xy g v.
        grad_R : array
            Estimate of grad_y f + d2_yy g v.
        key : jax PRNG key
            Fresh key, not used for any draw made here.
        """
        # Independent subkeys: drawing U_o and U_i from the same key (and
        # handing that key back to the caller) reuses randomness.
        key, key_outer, key_inner = jax.random.split(key, 3)
        U_o = jax.random.normal(key_outer, (k, self.outer_var_shape))
        U_i = jax.random.normal(key_inner, (k, self.inner_var_shape))
        outer_var_u = outer_var + self.mu * U_o
        outer_var_i = outer_var - self.mu * U_o
        inner_var_u = inner_var + self.mu * U_i
        inner_var_i = inner_var - self.mu * U_i

        inner_value = self.f_inner(inner_var, outer_var, start_inner)

        def iter_fun(q, deltas):
            deltas_out, deltas_in, deltas_th = deltas
            outer_value_u = self.f_outer(
                inner_var_u[q], outer_var_u[q], start_outer
            )
            outer_value_i = self.f_outer(
                inner_var_i[q], outer_var_i[q], start_outer
            )
            inner_value_u = self.f_inner(
                inner_var_u[q], outer_var_u[q], start_inner
            )
            inner_value_i = self.f_inner(
                inner_var_i[q], outer_var_i[q], start_inner
            )
            deltas_out = deltas_out.at[q].set(outer_value_u - outer_value_i)
            deltas_in = deltas_in.at[q].set(inner_value_u - inner_value_i)
            # Second-order central difference: ~ mu**2 * w^T H w.
            deltas_th = deltas_th.at[q].set(
                inner_value_u + inner_value_i - 2 * inner_value
            )
            return deltas_out, deltas_in, deltas_th

        zeros = jnp.zeros((k, 1))
        deltas_out, deltas_in, deltas_th = jax.lax.fori_loop(
            0, k, iter_fun, (zeros, zeros, zeros)
        )

        deltas_out = deltas_out / (2 * self.mu)
        deltas_in = deltas_in / (2 * self.mu)
        # For w ~ N(0, I): E[w^T H w (w w^T - I)] = 2 H, hence 2 * mu**2.
        deltas_th = deltas_th / (2 * self.mu ** 2)

        es_estimator_fx = (U_o.T.dot(deltas_out) / k).squeeze(1)
        es_estimator_fy = (U_i.T.dot(deltas_out) / k).squeeze(1)
        es_estimator_gy = (U_i.T.dot(deltas_in) / k).squeeze(1)

        # (k, 1): u_i^T v for every direction.
        proj_v = U_i @ v[:, None]
        # Cross block of (w w^T - I) is u_o u_i^T (the identity has no
        # off-diagonal block); the inner block is u_i u_i^T - I.
        es_estimator_gxyv = ((deltas_th * U_o).T.dot(proj_v) / k).squeeze(1)
        es_estimator_gyyv = (
            ((deltas_th * U_i).T.dot(proj_v) / k).squeeze(1)
            - jnp.mean(deltas_th) * v
        )

        grad_gy = es_estimator_gy
        grad_R = es_estimator_gyyv + es_estimator_fy
        grad_f = es_estimator_gxyv + es_estimator_fx
        return grad_gy, grad_f, grad_R, key

    def reset(self, carry):
        """Restart the recursion: the next step uses a plain estimate."""
        carry['iter_num'] = 0
        return carry

    def get_step(self, inner_sampler, outer_sampler, inner_sampler1,
                 outer_sampler1):
        """Build the function that performs one SRZOBA iteration.

        Parameters
        ----------
        inner_sampler, outer_sampler : callable
            Map a sampler state to ``(start, ..., new_state)``; ``start`` is
            forwarded to ``f_inner`` / ``f_outer``.
        inner_sampler1, outer_sampler1 : callable
            Accepted for interface compatibility; not used here.

        Returns
        -------
        callable
            ``soba_one_iter(carry, _) -> (carry, _)`` (a scan-style step
            function; presumably driven by ``jax.lax.scan`` in the base).
        """
        k = self.n_gaussian_vectors
        # Function evaluations per call of `approximate_gradient`.
        queries_per_estimate = 4 * k + 1

        def draw_batches(carry):
            # Advance both samplers (storing their new states in `carry`)
            # and return this step's mini-batch descriptors.
            start_inner, *_, carry['state_inner_sampler'] = inner_sampler(
                carry['state_inner_sampler']
            )
            start_outer, *_, carry['state_outer_sampler'] = outer_sampler(
                carry['state_outer_sampler']
            )
            return start_inner, start_outer

        def soba_one_iter(carry, _):
            (inner_step_size, outer_step_size, eta), carry['state_lr'] = (
                update_lr(carry['state_lr'], 1000)
            )
            # Weight kept on (previous momentum - old estimate).
            # NOTE(review): 99 and the 1000 above are hand-tuned constants.
            decay = 1 - 99 * eta

            def full_gradient(carry):
                # Restart: plain estimate at the current point.
                start_inner, start_outer = draw_batches(carry)
                grad_gy, grad_f, grad_R, key = self.approximate_gradient(
                    carry['inner_var'], carry['outer_var'], carry['v'],
                    carry['key'], start_inner, start_outer, k=k,
                )
                carry['memory_outer'] = carry['memory_outer'].at[1].set(grad_f)
                carry['memory_inner'] = carry['memory_inner'].at[1].set(
                    grad_gy
                )
                carry['memory_v'] = carry['memory_v'].at[1].set(grad_R)
                carry['key'] = key
                carry['query'] += queries_per_estimate
                return carry

            def stochastic_gradient(carry):
                # m_t = g(z_t) + decay * (m_{t-1} - g(z_{t-1})), where both
                # g's share mini-batches and PRNG key, so their difference
                # only reflects the move z_{t-1} -> z_t.
                start_inner, start_outer = draw_batches(carry)
                grad_gy, grad_f, grad_R, _ = self.approximate_gradient(
                    carry['inner_var'], carry['outer_var'], carry['v'],
                    carry['key'], start_inner, start_outer, k=k,
                )
                grad_gy_old, grad_f_old, grad_R_old, key = (
                    self.approximate_gradient(
                        carry['memory_inner'][0], carry['memory_outer'][0],
                        carry['memory_v'][0], carry['key'],
                        start_inner, start_outer, k=k,
                    )
                )
                carry['memory_outer'] = carry['memory_outer'].at[1].set(
                    grad_f + decay * (carry['memory_outer'][1] - grad_f_old)
                )
                carry['memory_inner'] = carry['memory_inner'].at[1].set(
                    grad_gy + decay * (carry['memory_inner'][1] - grad_gy_old)
                )
                carry['memory_v'] = carry['memory_v'].at[1].set(
                    grad_R + decay * (carry['memory_v'][1] - grad_R_old)
                )
                carry['key'] = key
                carry['query'] += 2 * queries_per_estimate
                return carry

            carry = jax.lax.cond(
                carry['iter_num'] == 0, full_gradient, stochastic_gradient,
                carry,
            )
            carry['iter_num'] += 1

            # Record the current point as the "previous iterate" only now:
            # the correction above must read row 0 from the LAST step.
            # Saving it before the estimates made g(z_t) - g(z_{t-1}) == 0.
            carry['memory_outer'] = carry['memory_outer'].at[0].set(
                carry['outer_var']
            )
            carry['memory_inner'] = carry['memory_inner'].at[0].set(
                carry['inner_var']
            )
            carry['memory_v'] = carry['memory_v'].at[0].set(carry['v'])

            carry['inner_var'] -= inner_step_size * carry['memory_inner'][1]
            carry['v'] -= inner_step_size * carry['memory_v'][1]
            carry['v'] = clip_vector_jax(carry['v'], max_norm=10)
            carry['outer_var'] -= outer_step_size * carry['memory_outer'][1]
            # Project the outer variable onto the box [-5, 5].
            carry['outer_var'] = jnp.clip(carry['outer_var'], min=-5, max=5)

            return carry, _

        return soba_one_iter
