import jax
import jax.numpy as jnp
import numpy as np
import optax
from typing import Tuple
from jax.lax import stop_gradient
from jaxopt import LBFGS
from functools import partial

####################################################################
####################################################################
"""
Chunked Trust Region Projection 
"""
####################################################################
####################################################################

def make_chunked_compute_etas(policy_apply_fn, kl_cov_proj, nr_steps: int, nr_envs: int, mean_bound: float, cov_bound: float):
    """
    chunked_compute_etas(policy_buffer, inputs, chunk_size)
    states: (batch, ...)
    returns: etas (nr_pairs,) where nr_pairs = len(policy_buffer) - 1
    """

    vmapped_policy = jax.vmap(
        jax.vmap(policy_apply_fn, in_axes=(None, 0), out_axes=(0, None)),
        in_axes=(0, None), out_axes=(0, 0)
    )

    def kl_mean_proj(mean, old_mean, old_std, eps):
        maha = 0.5 * jnp.sum(((old_mean - mean)/ old_std) ** 2)

        def proj(_):
            omega = jnp.sqrt(maha / eps) - 1.0
            proj_mean = (mean + omega * old_mean) / (1.0 + omega)
            post = stop_gradient(0.5 * jnp.sum(((old_mean - proj_mean)/ old_std) ** 2))
            return proj_mean, omega, post

        def nop(_):
            return mean, 0.0, maha
        return jax.lax.cond(maha > eps, proj, nop, operand=None)

    vmapped_kl_mean = jax.vmap(kl_mean_proj, in_axes=(0,0,0,None), out_axes=(0,0,0))
    vmap_cov = jax.vmap(kl_cov_proj, in_axes=(0,0,None,None))

    @jax.jit
    def process_chunk_on_device(etas, write_index, params_stacked, states):
        mean, logstd = vmapped_policy(params_stacked, states) # mean: (chunk_len, batch, A), logstd: (chunk_len, A)
        std = jnp.exp(logstd) # (chunk_len, A)
        old = mean[:-1]     # (pair_count, batch, A)
        curr = mean[1:]     # (pair_count, batch, A)
        old_std = std[:-1]  # (pair_count, A)
        curr_std = std[1:]  # (pair_count, A)
        pair_count = curr.shape[0]
        batch = curr.shape[1]
        action_dim = curr.shape[2]

        def _no_pairs(args):
            return args[0], args[1]

        def _has_pairs(args):
            etas, write_index, curr, old, curr_std, old_std = args

            # flatten pair*batch for mean projection vmapping
            curr_flat = curr.reshape((pair_count * batch, action_dim))
            old_flat = old.reshape((pair_count * batch, action_dim))
            old_std_flat = jnp.repeat(old_std, repeats=batch, axis=0)  # (pair*batch, A)

            _, omega_mu_flat, _ = vmapped_kl_mean(curr_flat, old_flat, old_std_flat, mean_bound)
            omega_mu = jnp.reshape(omega_mu_flat, (pair_count, batch))  # (pair_count, batch)
            eta_mu = jax.lax.stop_gradient(omega_mu)

            cov_out = vmap_cov(curr_std, old_std, cov_bound, action_dim)
            omega_cov = cov_out[1] if isinstance(cov_out, (tuple, list)) else cov_out
            eta_cov = jax.lax.stop_gradient(omega_cov)  # (pair_count,)

            eta_pair = jnp.reshape(jnp.maximum(eta_mu, eta_cov[:, None]), (pair_count, nr_steps, nr_envs)) # (pair_count, nr_steps, nr_envs)
            etas = jax.lax.dynamic_update_slice(etas, eta_pair, (write_index, 0, 0))
            return etas, write_index + pair_count

        return jax.lax.cond(
            pair_count > 0,
            _has_pairs,
            _no_pairs,
            operand=(etas, write_index, curr, old, curr_std, old_std)
        )


    def chunked_compute_etas(policy_buffer, states, chunk_size: int):
        nr_policies = len(policy_buffer)  
        nr_pairs = max(0, nr_policies - 1)
        if nr_pairs == 0:
            return jnp.zeros((1, nr_steps, nr_envs), dtype=jnp.float32)

        etas = jnp.zeros((nr_policies, nr_steps, nr_envs), dtype=jnp.float32)
        write_index = 0
        models_consumed = 0

        for i in range(0, nr_policies, chunk_size):
            if models_consumed >= nr_policies:
                break
            avail = min(chunk_size, nr_policies - models_consumed)
            chunk_params = list(policy_buffer[i : i + avail])
            if i > 0:
                chunk_params = [policy_buffer[i - 1]] + chunk_params
            params_stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *chunk_params)
            etas, write_index = process_chunk_on_device(etas, write_index, params_stacked, states)
            models_consumed += avail

        return stop_gradient(etas)

    return chunked_compute_etas


####################################################################
####################################################################
"""
Chunked Reward Projection 
"""
####################################################################
####################################################################

def make_chunked_ensemble_rew_project(log_density_ratio_fn, nr_steps: int, nr_envs: int, epsilon: float, beta: float, entropy_coef: float, maximum_eta: bool):
    """
    Only load a fixed-size chunk of the buffer into memory. Avoid OOM issues and allows deepr discriminator chains

    Returns:
    chunked_project(buffer, inputs, etas, nr_projections, chunk_size=None) -> proj_reward (nr_steps, nr_envs)

    Args:
        buffer: python list/tuple of param pytrees (length B, dynamic)
        inputs: (states, actions) with states/actions shaped (N, ...) where N = nr_steps * nr_envs
        etas: jnp array shape (B, nr_steps, nr_envs)
        nr_projections: number of models to project through (<= B)
    """
    vmapped_chunk = jax.vmap(log_density_ratio_fn, in_axes=(0, None, None, None, None), out_axes=0)

    # @jax.jit
    @partial(jax.jit, static_argnames=('maximum_eta',))
    def process_chunk_on_device(proj, params_stacked, eta_chunk, states, actions, next_states, absorbing, maximum_eta: bool):
        """
        Args:
            proj: (nr_steps, nr_envs)
            params_stacked: pytrees with leading axis chunk_len
            eta_chunk: (chunk_len, nr_steps, nr_envs)
            states/actions: (N, ...) where N = nr_steps*nr_envs
        """
        out_flat = vmapped_chunk(params_stacked, states, actions, next_states, absorbing)
        chunk_len = out_flat.shape[0]
        out_chunk = out_flat.reshape((chunk_len, nr_steps, nr_envs))
        def body_fun(t, carry):
            ldr_t = out_chunk[t]
            eta_t = eta_chunk[t]
            if maximum_eta:
                step = epsilon / (1.0 + jnp.clip(jnp.max(eta_t), a_min=0.0))
            else:
                step = epsilon / (1.0 + jnp.clip(eta_t, a_min=0.0))
            return (1.0 - step) * carry + step * beta * ldr_t
        proj_final = jax.lax.fori_loop(0, chunk_len, body_fun, proj)
        return proj_final


    def chunked_project(buffer, inputs, etas, chunk_size: int):
        """
        Args:
            buffer: python list/tuple of param pytrees, length B
            inputs: (states, actions) where states/actions shape (N, ...)
            etas: jnp array shape (B, nr_steps, nr_envs)
        """
        states, actions, next_states, absorbing = inputs
        B = len(buffer)
        proj = jnp.zeros((nr_steps, nr_envs))
        models_consumed = 0

        for i in range(0, B, chunk_size):
            if models_consumed >= B:
                break

            available = min(chunk_size, B - models_consumed)
            if available <= 0:
                break
            chunk_params = buffer[i : i + available]
            params_stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *chunk_params)

            eta_chunk = etas[models_consumed : models_consumed + available]
            proj = process_chunk_on_device(proj, params_stacked, eta_chunk, states, actions, next_states, absorbing, maximum_eta=maximum_eta)
            models_consumed += available

        return entropy_coef * stop_gradient(proj) # scaling to reduce critic loss

    return chunked_project


####################################################################
####################################################################
"""
KL divergence projection layer from https://arxiv.org/pdf/2101.09207
"""
####################################################################
####################################################################


def dual(eta_omega, pred_std, target_std, target_logdet, eps, omega_offset):

    eta = jnp.where(eta_omega[0] > 0.0, eta_omega[0], 0.0)
    new_std = jnp.sqrt((eta + omega_offset) / jnp.clip((eta/(target_std**2)) + (1/pred_std)**2, a_min=1e-8))
    new_std = jnp.clip(jnp.nan_to_num(new_std), a_min=1e-8)
    new_logdet = -2.0 * jnp.sum(jnp.log(1/new_std))

    dual_val = eta * eps - 0.5 * eta * target_logdet
    dual_val += 0.5 * (eta + omega_offset) * new_logdet

    kl = 0.5 * jnp.sum(2.0 * (jnp.log(target_std) - jnp.log(new_std)) + (new_std/target_std)**2 - 1.0)
    grad_val = eps - kl

    return dual_val, jnp.array([grad_val])

@jax.custom_vjp
def kl_cov_proj(pred_std: jnp.ndarray, target_std: jnp.ndarray, eps: float, max_eval: int = 50, 
                        omega_offset: float = 1.0, eta_init: float = 0.0) -> Tuple[jnp.ndarray, float]:

    """
    pred_std: standard deviation of the policy's prediction:  (1, as_dim)
    target_std: standard deviation of the old_policy:  (1, as_dim)
    max_eval: number of iterations of L-BFGS
    omega_offset: offset to include entropy term in the original optimization objective (see pg: 17 in paper)
    eta_init: initial guess for L-BFGS
    """
    target_logdet = -2.0 * jnp.sum(jnp.log(jnp.clip(1/target_std, a_min=1e-8)))

    """ Optax L BFGS """
    def objective(eta_omega):
        val, grad = dual(eta_omega, pred_std, target_std, target_logdet, eps, omega_offset)
        return val, grad

    def opt_bfgs(init_params, fun, opt, max_iter, tol):

        value_and_grad_fun = objective

        def step(carry):
            params, state = carry
            value, grad = objective(params)
            updates, state = opt.update(
                grad, state, params, value=value, grad=grad, value_fn=fun
            )
            params = optax.apply_updates(params, updates)
            return params, state

        def continuing_criterion(carry):
            _, state = carry
            iter_num = optax.tree.get(state, 'count')
            grad = optax.tree.get(state, 'grad')
            err = optax.tree.norm(grad)
            return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))

        init_carry = (init_params, opt.init(init_params))
        final_params, final_state = jax.lax.while_loop(
            continuing_criterion, step, init_carry
        )
        return final_params, final_state


    # Initialize parameters
    init_params = jnp.array([eta_init])

    # Optax LBFGS optimizer setup
    lbfgs = optax.lbfgs()
    params, _ = opt_bfgs(init_params, lambda x: objective(x)[0], lbfgs, max_iter=max_eval, tol=1e-9)
    eta_opt = params[0]

    # """Jaxopt L BFGS """
    # def objective(eta_omega):
    #     val, grad = dual(eta_omega, pred_std, target_std, target_logdet, eps, omega_offset)
    #     return val, grad

    # # Initialize parameters
    # init_params = jnp.array([eta_init])

    # solver = LBFGS(fun=objective, maxiter=max_eval, value_and_grad=True, tol=1e-9)
    # opt_res = solver.run(init_params)

    # params = opt_res.params
    # eta_opt = params[0]

    """ Computing projected cov """
    projected_cov = (eta_opt + omega_offset) / jnp.clip(eta_opt/(target_std**2) + (1.0/pred_std)**2, a_min=1e-8)
    projected_cov = jnp.clip(jnp.nan_to_num(projected_cov), a_min=1e-16)
    return projected_cov, eta_opt



def kl_cov_proj_backward(d_proj: jnp.ndarray,
                                succ: bool,
                                omega_offset: float,
                                eta: float,
                                pred_std: jnp.ndarray,
                                target_std: jnp.ndarray,
                                projected_std: jnp.ndarray) -> jnp.ndarray:

    if not succ:
        raise RuntimeError("Optimization was not successful, cannot run backward")

    def last_eta_grad(pred_std, target_std, projected_std, eta, omega_offset):

        def eta_zero_case(_):
            return jnp.zeros_like(target_std)

        def eta_positive_case(eta):
            dQ_deta = (omega_offset * (1/target_std**2) - (1/pred_std**2)) / (eta + omega_offset)
            tmp = jnp.ones_like(target_std) - (1/target_std**2) * projected_std**2
            f2_dQ = projected_std**2 * tmp
            sum_val = jnp.sum(f2_dQ * dQ_deta)
            c = -1.0 / jnp.where(jnp.abs(sum_val) < 1e-8, jnp.sign(sum_val) * 1e-8, sum_val + 1e-12)
            return c * f2_dQ

        def eta_negative_case(_):
            return jnp.full_like(target_std, jnp.nan)

        return jax.lax.cond(
            eta == 0.0,
            eta_zero_case,
            lambda _: jax.lax.cond(
                eta > 0.0,
                eta_positive_case,
                eta_negative_case,
                operand=eta
            ),
            operand=eta
        )

    deta_dQ_pred = last_eta_grad(pred_std, target_std, projected_std, eta, omega_offset)
    eo = omega_offset + eta
    eo_squared = eo * eo
    dQ_deta = (omega_offset * (1/target_std**2) - (1/pred_std**2)) / eo_squared
    d_Q = - (projected_std**2) * d_proj * (projected_std**2)
    d_eta = jnp.sum(d_Q * dQ_deta)
    d_Q_pred = d_eta * deta_dQ_pred + d_Q / eo
    d_cov_pred = - (1/pred_std**2) * d_Q_pred * (1/pred_std**2)  
    d_cov_pred = jnp.clip(jnp.nan_to_num(d_cov_pred), a_min=1e-20)
    d_pred_std = 2.0 * pred_std * d_cov_pred

    return d_pred_std

def kl_cov_proj_fwd(pred_std, target_std, eps, max_eval, omega_offset, eta_init):
    projected_cov, eta = kl_cov_proj(pred_std, target_std, eps, max_eval, omega_offset, eta_init)
    residuals = (pred_std, target_std, projected_cov, eta, omega_offset)
    return (projected_cov, eta), residuals

def kl_cov_proj_bwd(residuals, cotangents):
    d_projected_cov, d_eta = cotangents
    pred_std, target_std, projected_cov, eta, omega_offset = residuals
    projected_std = jnp.sqrt(projected_cov)
    d_std_pred = kl_cov_proj_backward(d_proj=d_projected_cov, succ=True, omega_offset=omega_offset, eta=eta, pred_std=pred_std, target_std=target_std, projected_std=projected_std)
    return (d_std_pred, None, None, None, None, None)

kl_cov_proj.defvjp(kl_cov_proj_fwd, kl_cov_proj_bwd)

def kl_projection(mean, std, mean_other, std_other, eps_mean, eps_cov):
    """
    Project the pred policy back to satisfy reverse KL --> KL(pi_old || pi)  = (old_mean - mean)* inv_cov * (old_mean - mean) + cov_kl

    Args:
        mean, std: mean and std of the current policy
        mean_other, std_other: mean and std of the old policy
    """

    mean_part = 0.5 * jnp.sum(((mean_other - mean)/ std_other) ** 2)

    def mean_projection(mean, old_mean, maha, std_other, eps):
        """
        Projects the mean based on the Mahalanobis objective and trust region.
        Args:
            mean: current mean vectors
            old_mean: old mean vectors
            mean_part: Mahalanobis distance between the two mean vectors
            eps_mean: trust region bound

        Returns:
            projected mean that satisfies the trust region
            lagrangian multipliers
        """

        def true_fn(_):
            omega = jnp.sqrt(maha / eps) - 1.0
            proj_mean = (mean + omega * old_mean) / (1 + omega)
            # kl contribution of the mean after projection
            post_proj_kl_mean_part = stop_gradient(0.5 * jnp.sum(((mean_other - proj_mean)/ std_other) ** 2))
            return proj_mean, omega, post_proj_kl_mean_part

        def false_fn(_):
            # Skip if already in the trust region
            return mean, 0.0, maha

        proj_mean, omega, post_proj_kl_mean_part = jax.lax.cond(maha > eps, true_fn, false_fn, operand=None)
        return proj_mean, omega, post_proj_kl_mean_part


    proj_mean, eta_mu, post_proj_kl_mean_part = mean_projection(mean, mean_other, mean_part, std_other, eps_mean)

    # Skip if already in the trust region
    kl_cov_part = 0.5 * jnp.sum(2.0 * (jnp.log(std_other) - jnp.log(std)) + (std/std_other)**2 - 1.0)
    def do_cov_proj(_):
        proj_cov, eta_cov = kl_cov_proj(jnp.squeeze(std), jnp.squeeze(std_other), eps_cov)
        proj_std = jnp.expand_dims(jnp.sqrt(proj_cov), 0)
        post_proj_kl_cov_part = stop_gradient(0.5 * jnp.sum(2.0 * (jnp.log(std_other) - jnp.log(proj_std)) + (proj_std/std_other)**2 - 1.0))
        return proj_std, eta_cov, post_proj_kl_cov_part

    def skip_cov_proj(_):
        return std, 0.0, kl_cov_part

    proj_std, eta_cov, post_proj_kl_cov_part = jax.lax.cond(
        kl_cov_part > eps_cov,
        do_cov_proj,
        skip_cov_proj,
        operand=None
    )

    return proj_mean, proj_std, eta_mu, eta_cov, mean_part, post_proj_kl_mean_part, kl_cov_part, post_proj_kl_cov_part


def entropy_projection(action_logstd, beta, dim):
    """
    Projects std to satisfy an entropy inequality constraint.
    """
    entropy = jnp.sum(action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e))

    def true_fn(_):
        logalpha = (beta - entropy)/dim
        proj_logstd = logalpha + action_logstd
        return proj_logstd

    def false_fn(_):
        return action_logstd

    proj_logstd = jax.lax.cond(entropy < beta, true_fn, false_fn, operand=None)
    return proj_logstd