from mbrl.env import reward_fns
from typing import KeysView
from mbrl.third_party.unrolled_actor_soft_critic.agent import Agent
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from torch import nn

import mbrl
from mbrl.models.one_dim_tr_model import OneDTransitionRewardModel


class AdjointParameters(object):
    """
    Immutable mapping from parameters to a gradient. Supports elementwise arithmetic.

    This is a utility class used to make computing gradient easier -- it allows for
    easy extension to RK4, for example.

    Internally, `l` holds the parameters in the order they were supplied and `d`
    maps each parameter tensor to its gradient tensor (all zeros on construction).
    Every arithmetic method returns a new AdjointParameters and leaves the operands
    untouched.
    """
    def __init__(self, parameters):
        """Create a container holding a zero gradient for every parameter.

        Args:
            parameters (iterable of torch.Tensor): the parameters to track.
        """
        self.l = list(parameters)
        self.d = {p: torch.zeros_like(p) for p in self.l}

    def parameters(self):
        """Return the tracked parameters as a tuple, in construction order."""
        return tuple(self.l)

    def items(self):
        """Return a view of the (parameter, gradient) pairs."""
        return self.d.items()

    def grad(self, grad):
        """Build a new container over the same parameters from a sequence of gradients.

        Args:
            grad (iterable): gradients paired positionally with `parameters()`.
                A `None` entry (e.g. from `torch.autograd.grad(..., allow_unused=True)`)
                leaves that parameter's gradient at zero.

        Returns:
            AdjointParameters: a fresh container; the supplied tensors are stored
            as-is (not copied).
        """
        result = AdjointParameters(self.d.keys())
        for k, v in zip(self.l, grad):
            if v is not None:
                assert k.shape == v.shape
                assert not v.isnan().any()
                result.d[k] = v
        return result

    def __add__(self, other):
        """Return the elementwise sum of two containers over the same parameters."""
        assert isinstance(other, AdjointParameters)
        result = AdjointParameters(self.d.keys())
        assert len(self.d) == len(other.d)
        for k, me in self.d.items():
            assert not other.d[k].isnan().any()
            # `result.d[k]` starts as zeros, so the in-place add stores the sum.
            result.d[k] += me + other.d[k]
        return result

    def __mul__(self, other):
        """Return a new container with every gradient multiplied by the scalar `other`."""
        assert isinstance(other, (float, np.float32, np.float64, int))
        result = AdjointParameters(self.d.keys())
        for k, me in self.d.items():
            result.d[k] += me * other
        return result

    def __truediv__(self, other):
        """Return a new container with every gradient divided by the scalar `other`."""
        assert isinstance(other, (float, np.float32, np.float64, int))
        result = AdjointParameters(self.d.keys())
        for k, me in self.d.items():
            result.d[k] += me / other
        return result

    def __rmul__(self, other):
        """Support `scalar * params`; scalar multiplication is commutative."""
        return self.__mul__(other)

    def nancheck(self):
        """Report whether any stored gradient contains a NaN.

        Prints the per-parameter NaN flags when at least one is set.

        Returns:
            bool: True if any gradient contains a NaN, False otherwise.
        """
        nans = [a.isnan().any() for a in self.d.values()]
        if any(nans):
            print(f"NaNs found: {nans}")
            return True
        # Always hand back a real bool rather than falling through to None.
        return False

class SuccessorAugmentedDynamics(nn.Module):
    def __init__(self, dynamics, agent):
        """
        Augmented dynamics for successor-predicting models: the state paired with its adjoint values.

        Intended for dynamics models that directly predict the next state while keeping the
        autograd graph between consecutive states intact:

            x_{t+1} = dynamics(x_t, policy(x_t))

        Background: section 2 of the Neural ODE paper (https://arxiv.org/pdf/1806.07366.pdf).

        Args:
            dynamics (OneDTransitionRewardModel): the underlying dynamics function
            agent (Agent): the agent

        Both arguments are only stored; the methods of this class do not read them back.
        """

        super().__init__()
        self.dynamics = dynamics
        self.agent = agent

    def augmented_dynamics(self, xy, xy_successor, at, params : AdjointParameters):
        """Back-propagate the successor adjoint through one dynamics step.

        Args:
            xy (torch.Tensor): current state; must require grad and be part of the
                graph that produced `xy_successor`.
            xy_successor (torch.Tensor): successor state computed from `xy`.
            at (torch.Tensor): adjoint of the successor state; must not require grad.
            params (AdjointParameters): container naming the parameters to differentiate.

        Returns:
            (torch.Tensor, AdjointParameters): the detached gradient with respect to `xy`
            and the per-parameter gradients, both computed with `-at` as the upstream
            gradient.
        """
        assert isinstance(params, AdjointParameters)
        assert xy.requires_grad
        assert not at.requires_grad

        # One vector-Jacobian product covers the state and every parameter at once.
        # The negated adjoint follows Algorithm 1 of the Neural ODE paper.
        wrt = (xy,) + params.parameters()
        state_grad, *param_grads = torch.autograd.grad(
            (xy_successor,),
            wrt,
            grad_outputs=(-at),
            allow_unused=True,
            retain_graph=True,
        )

        return state_grad.detach(), params.grad(param_grads)

    def forward(self, curr_state, succ_state, succ_adjoint, succ_adjoint_params, timestep=-1.0):
        """Step the adjoint from the successor state back to the current state (one Euler step).

        `timestep` defaults to -1.0: negative because the step runs backwards in time, unit
        magnitude because OneDTransitionRewardModel predicts a state _delta_ that already
        accounts for the step size.

        Args:
            curr_state (torch.tensor((batch, len(obs)))): the current state
            succ_state (torch.tensor((batch, len(obs)))): the successor state
            succ_adjoint (torch.tensor((batch, len(obs))): the adjoint of the successor state
            succ_adjoint_params (AdjointParameters): parameter adjoint accumulated up to the successor state
            timestep (float): multiplier applied to the parameter-gradient increment
        Returns:
            (curr_adjoint, curr_adjoint_params): the adjoint at the current state and the
            updated parameter adjoint.
        """
        curr_adjoint, delta_adjoint_params = self.augmented_dynamics(
            curr_state, succ_state, succ_adjoint, succ_adjoint_params
        )

        # The parameter gradient is summed over the batch; dividing by the batch size turns
        # it into a mean.
        # NOTE(review): only the parameter increment is multiplied by `timestep`; the state
        # adjoint is returned exactly as computed with the negated upstream gradient --
        # confirm this sign convention is intended.
        batch_size = curr_state.shape[0]
        increment = timestep * delta_adjoint_params / batch_size

        return curr_adjoint, succ_adjoint_params + increment

    @staticmethod
    def base_case(agent, state, action, log_prob, *args, discount=1.0, **kwargs):
        """Seed the backward pass with the gradient of the terminal value estimate.

        The terminal loss is `-discount * mean(min(Q1, Q2) - alpha * log_prob)`, where the
        two Q values come from `agent.critic_target(state, action)`.

        Args:
            agent (Agent): provides `actor`, `critic_target` and `alpha`
            state ([type]): xt
            action ([type]): action from xt
            log_prob ([type]): log_prob corresponding to action
            *args: extra positional values, accepted and ignored
            discount ([type]): current discount factor
            **kwargs: extra keyword values, accepted and ignored

        Returns:
            (state, adjoint, adj_params): the state unchanged, the gradient of the terminal
            loss with respect to it, and the gradient with respect to the actor parameters.
        """
        adj_params = AdjointParameters(agent.actor.parameters())

        # Value estimate at the final state: clipped double-Q minus the entropy term.
        q1, q2 = agent.critic_target(state, action)
        value = torch.min(q1, q2) - agent.alpha.detach()*log_prob
        ending_loss = -torch.mean(value, dim=0) * discount

        # Differentiate once with respect to the state and every actor parameter.
        grads = torch.autograd.grad(
            (ending_loss,),
            (state,) + adj_params.parameters(),
            grad_outputs=torch.ones_like(ending_loss),
            allow_unused=True,
            retain_graph=True,
        )
        adjoint, param_grads = grads[0], grads[1:]

        return state, adjoint, adj_params.grad(param_grads)

    @staticmethod
    def reward_case(agent, state, log_prob, reward, discount=1.0):
        """Gradient contribution of a single step's reward and entropy term.

        The per-step loss is `-discount * mean(reward - alpha * log_prob)`.

        Args:
            agent (Agent): provides `actor` and `alpha`
            state ([type]): xt
            log_prob ([type]): log_prob corresponding to action
            reward ([type]): reward tensor associated with this step
            discount ([type]): current discount factor

        Returns:
            (adjoint_delta, AdjointParameters): gradient of the per-step loss with respect
            to `state` and with respect to the actor parameters.
        """
        adj_params = AdjointParameters(agent.actor.parameters())
        step_loss = -discount * (reward - agent.alpha.detach()*log_prob).mean()

        grads = torch.autograd.grad(
            (step_loss,),
            (state,) + adj_params.parameters(),
            grad_outputs=torch.ones_like(step_loss),
            allow_unused=True,
            retain_graph=True,
        )
        adjoint_delta, param_grads = grads[0], grads[1:]
        return adjoint_delta, adj_params.grad(param_grads)


#
# Adjoint implementation
#
def adjoint_forward(dynamics : OneDTransitionRewardModel, agent : Agent, starting_state, timestep=1.0, rollout_length=4, rng=np.random.default_rng()):
    """
    Roll the actor and the dynamics model forward, recording every step.

    Helper for the adjoint computation: the recorded tensors keep their autograd
    graph, so each entry can be differentiated back to `starting_state`.

    Args:
        dynamics (OneDTransitionRewardModel): delta-predicting state transition model
        agent (Agent): the agent; `agent.actor(obs)` must return a distribution
            supporting `rsample()` and `log_prob()`
        starting_state (torch.tensor): [NxK] torch array with N starting states, each K
            elements long. Must require grad.
        timestep (float, optional): accepted for signature compatibility; not used by
            this function.
        rollout_length (int, optional): Number of steps to unroll.
        rng: forwarded unchanged to `dynamics.sample`. The default generator is created
            once, when the function is defined, and shared by every call that omits it.

    Returns List[Tuple[
        observ. :  (torch.tensor([N, K])) -- state,
        action :   (torch.tensor([N, J])) -- actions taken corresponding to each state
        log_prob : (torch.tensor([N, 1])) -- log-probability of the sampled action, summed over the last axis
        reward :   (torch.tensor([N, 1])) -- reward associated with the state
    ]] for (rollout_length + 1) states.

    Entry k holds the kth observation (k=0 is `starting_state`), the action sampled from
    it with its log-probability, and the reward the model predicts for that transition.
    This matches the reference implementation `unroll_actor_direct`.
    """
    assert torch.is_tensor(starting_state)
    assert starting_state.requires_grad

    trajectory = []
    current_obs = starting_state

    for _ in range(rollout_length + 1):
        batch = mbrl.types.TransitionBatch(current_obs, None, None, None, None)

        # Reparameterised sample so gradients flow through the action.
        action_dist = agent.actor(current_obs)
        batch.act = action_dist.rsample()

        next_obs, reward = dynamics.sample(batch, deterministic=True, rng=rng)
        # The dynamics model has to predict the reward as well.
        assert reward is not None

        log_prob = action_dist.log_prob(batch.act).sum(-1, keepdim=True)

        trajectory.append((batch.obs, batch.act, log_prob, reward))
        current_obs = next_obs

    return trajectory

def adjoint_params(dynamics_model : OneDTransitionRewardModel, agent : Agent, starting_state, timestep=1.0, rollout_length=4, rng=np.random.default_rng(), discount=1.0, compensate=True):
    """Compute d(Reward)/d(Starting State) and d(Reward)/d(Agent Parameters) with the adjoint method.

    Notation: dr is d(reward), dx0 is d(starting_state), dp is d(agent params). The quantity
    accumulated is

        dr/dp = dr/dxt*dxt/dp + dr/dx_{t-1}*dx_{t-1}/dp + ... + dr/dx0*dx0/dp

    where dr/dxt comes from the value estimate in `agent` (see
    `SuccessorAugmentedDynamics.base_case`) and the remaining terms come from stepping the
    adjoint backwards through the recorded trajectory.

    Args:
        (Refer to adjoint_forward for dynamics_model, agent, starting_state, timestep,
        rollout_length and rng.)
        discount (float): per-step discount factor; step i is weighted by discount**i
        compensate (bool) : if true, divide the gradients by the sum of discounted steps; this compensates for the unrolling length

    Returns: {
        "adjoint" (torch.tensor([NxK])) : d(r)/d(x0)
        "params" (AdjointParameters)    : d(r)/d(p)
        "log_prob" (list of torch.tensor) : log-probabilities recorded at every step of the rollout
    }
    """
    # Forward simulation from a fresh leaf copy of the starting state.
    x0 = starting_state.detach().clone()
    x0.requires_grad = True
    traj = adjoint_forward(dynamics_model, agent, x0, timestep, rollout_length=rollout_length, rng=rng)
    log_prob_history = [step[2] for step in traj]

    # Discount weight for every recorded step, optionally normalised to sum to one.
    discounts = [discount**i for i in range(len(traj))]
    if compensate:
        normaliser = sum(discounts)
        discounts = [d / normaliser for d in discounts]

    # Callable that performs one backward adjoint step.
    augmented_dynamics = SuccessorAugmentedDynamics(dynamics_model, agent.actor)

    # Base case: the last recorded step, seeded from the value estimate.
    final_step = traj.pop()
    final_discount = discounts.pop()
    state_succ, adjoint_succ, adj_params = SuccessorAugmentedDynamics.base_case(agent, *final_step, discount=final_discount)

    # Walk the remaining steps from last to first.
    for (state_curr, _action_curr, log_prob_curr, reward_curr), discount_curr in zip(reversed(traj), reversed(discounts)):
        # Pull the successor adjoint back through the dynamics...
        adjoint_curr, adj_params = augmented_dynamics(state_curr, state_succ, adjoint_succ, adj_params)

        # ...then add this step's own reward and entropy contribution.
        # This assumes the reward is purely a function of the state.
        adjoint_curr_delta, adj_params_delta = SuccessorAugmentedDynamics.reward_case(agent, state_curr, log_prob_curr, reward_curr, discount=discount_curr)
        adjoint_curr = adjoint_curr + adjoint_curr_delta
        adj_params = adj_params + adj_params_delta

        state_succ, adjoint_succ = state_curr, adjoint_curr

    return {
        "adjoint": adjoint_succ,
        "params": adj_params,
        "log_prob": log_prob_history,
    }

#
# RK4-based implementation
#

class RK4AugmentedDynamics(nn.Module):
    def __init__(self, physics, policy_model):
        r"""
        Represents the *augmented* dynamics; the concatenation of state with the associated adjoint values.
        This is designed to be used with continuous dynamics models. i.e. those that predict the time-derivative
        of the state, and require integration to generate a trajectory.

            \dot x_{t} = dynamics(x_t, policy(x_t))

        Given that (1) most MBRL settings involve Euler integration, and (2) PyTorch has lightweight call graphs,
        we can use the much simpler SuccessorAugmentedDynamics instead of this.

        This needs some additional work before it functions with the ensemble dynamics used by UASC.

        Args:
            physics (state, action => time-derivative): the underlying dynamics function
            policy_model (state => action): the actor function
        """
        # The docstring is a raw string so that the backslash in "\dot" is not parsed as an
        # (invalid) escape sequence.

        super().__init__()
        self.physics = physics
        self.policy_model = policy_model

    def augmented_dynamics(self, x, params):
        """Evaluate the time-derivative of the augmented state.

        Args:
            x (torch.Tensor): augmented state with an even number of columns; the first
                half is the state and the second half is the adjoint.
            params (AdjointParameters): container naming the parameters to differentiate.

        Returns:
            (torch.Tensor, AdjointParameters): the detached state derivative concatenated
            with the detached adjoint derivative along dim 1, and the per-parameter
            gradients.
        """
        assert isinstance(params, AdjointParameters)
        dim = x.shape[1]
        assert dim % 2 == 0
        # Re-root the state as a fresh leaf so the local gradient can be taken w.r.t. it.
        xy = x[:,:dim//2].detach().clone()
        xy.requires_grad = True
        at = x[:,dim//2:]

        # Calculate the local dynamics gradient, using the negated adjoint as the upstream gradient.
        f = self.physics(xy, self.policy_model(xy))
        adjoint, *param_grads = torch.autograd.grad((f,), (xy,) + params.parameters(), grad_outputs=(-at), allow_unused=True, retain_graph=True)

        return torch.cat((f.detach(), adjoint.detach()), dim=1), params.grad(param_grads)

    def forward(self, timestep, curr_state, curr_adjoint_params):
        """Advance the augmented state and the parameter adjoint by one RK4 step.

        Args:
            timestep: integration step; multiplies AdjointParameters, so it must be a
                scalar type that AdjointParameters arithmetic accepts.
            curr_state (torch.Tensor): augmented state (state and adjoint halves).
            curr_adjoint_params (AdjointParameters): parameter adjoint accumulated so far.

        Returns:
            (new_state, new_adjoint): the augmented state after the step and the updated
            parameter adjoint.
        """
        # RK4 integration, take a step forward:
        # The per-stage parameter gradients are _not_ multiplied by the timestep; the
        # combined increment is scaled by it once at the end. For this derivation,
        # consult the NeuralODE paper.
        k1, k1_adjoint = self.augmented_dynamics(curr_state, curr_adjoint_params)
        k1 = timestep * k1

        k2, k2_adjoint = self.augmented_dynamics(curr_state + k1/2, curr_adjoint_params + k1_adjoint/2)
        k2 = timestep * k2

        k3, k3_adjoint = self.augmented_dynamics(curr_state + k2/2, curr_adjoint_params + k2_adjoint/2)
        k3 = timestep * k3

        k4, k4_adjoint = self.augmented_dynamics(curr_state + k3, curr_adjoint_params + k3_adjoint)
        k4 = timestep * k4

        new_state = (curr_state + 1/6 * (k1 + 2*k2 + 2*k3 + k4))
        # Divide by the batch size so the parameter increment is a mean over the batch.
        new_adjoint = curr_adjoint_params + 1/6 * (k1_adjoint + 2*k2_adjoint + 2*k3_adjoint + k4_adjoint) * timestep / (curr_state.shape[0])

        return new_state, new_adjoint
