from ..agents.rlpd import SACLearner, decay_mask_fn, Temperature
import copy
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import flax
import gym
import jax
import numpy as np
import jax.numpy as jnp
import optax
from flax import struct
from flax.training.train_state import TrainState
from flax.core.frozen_dict import unfreeze, freeze
import flax.linen as nn

from ..utils.dataset_utils import DatasetDict
from ..models.distributions import TanhNormal
from ..models.model import (
    MLP_RLPD,
    Ensemble,
    MLPResNetV2,
    StateActionValue,
    QValue,
    DynamicsModel,
    LSTM_RLPD,
    AdjusterModel,
    subsample_ensemble,
)

from flax import struct
from flax.training.train_state import TrainState

# A single NumPy array, or an arbitrarily nested dict of them keyed by string.
DataType = Union[np.ndarray, Dict[str, "DataType"]]
# PRNG keys are annotated loosely; no concrete key type is enforced here.
PRNGKey = Any
# Alias for a Flax FrozenDict keyed by string, used to annotate parameter trees.
Params = flax.core.FrozenDict[str, Any]


class RLPhiLearner(SACLearner):
    """SACLearner subclass that carries several additional train states.

    ``create`` builds the fields declared below alongside the ones it forwards
    to the parent class (actor, critic, target critic, temperature). Each
    ``update_*`` method returns a copy of the learner with the relevant
    state(s) replaced rather than mutating in place.
    """

    # Ensemble of DynamicsModel networks; trained in ``update_dynamics``.
    dynamics: TrainState
    # AdjusterModel trained with a sigmoid cross-entropy loss in ``update_adjuster``.
    adjuster: TrainState
    # Ensemble of StateActionValue networks; trained in ``update_reward_model``.
    reward_model: TrainState
    # Ensemble of two-action QValue networks; trained in ``update_adjuster_Q``.
    adjuster_Q: TrainState
    # Target copy of ``adjuster_Q``; moved with optax.incremental_update.
    target_adjuster_Q: TrainState

    @classmethod
    def create(
        cls,
        seed: int,
        observation_space: gym.Space,
        action_space: gym.Space,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        temp_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        adjuster_hidden_dim: int = 128,
        use_lstm: bool = True,
        discount: float = 0.99,
        tau: float = 0.005,
        batch_size: int = 256,
        num_qs: int = 2,
        num_min_qs: Optional[int] = None,
        mb_ensemble_size: int = 5,
        rm_ensemble_size: int = 1,
        critic_dropout_rate: Optional[float] = None,
        critic_weight_decay: Optional[float] = None,
        critic_layer_norm: bool = False,
        target_entropy: Optional[float] = None,
        init_temperature: float = 1.0,
        backup_entropy: bool = True,
        use_pnorm: bool = False,
        use_critic_resnet: bool = False,
        preload_actor_params: Optional[dict] = None,
        preload_critic_params: Optional[dict] = None,
        preload_target_critic_params: Optional[dict] = None,
    ):
        """Build an ``RLPhiLearner`` with all of its train states initialised.

        The actor / critic / temperature part follows the Soft-Actor-Critic
        variant described in https://arxiv.org/abs/1812.05905. On top of that
        this method creates a dynamics ensemble, a reward-model ensemble, a
        supervised adjuster and an adjuster Q-ensemble with its target copy.

        Args:
            seed: Seed for the PRNG key from which all init keys are split.
            observation_space: Sampled once to obtain a dummy observation.
            action_space: Sampled once to obtain a dummy action; its last
                shape entry is the action dimension.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Learning rate for the critic; also used for the
                dynamics, reward-model and adjuster optimisers. The adjuster
                Q-network uses ``critic_lr * 1e-4``.
            temp_lr: Adam learning rate for the temperature.
            hidden_dims: Hidden layer sizes of every ``MLP_RLPD`` trunk.
            adjuster_hidden_dim: ``hidden_dim`` of the ``AdjusterModel`` and
                length of the dummy hidden state used to initialise it.
            use_lstm: Forwarded to ``AdjusterModel``.
            discount: Stored on the learner and used in the TD targets.
            tau: Stored on the learner; step size of the target updates.
            batch_size: Not used by this method.
            num_qs: Ensemble size of the critic and the adjuster Q-network.
            num_min_qs: Ensemble size of the target networks' module
                definition; falls back to ``num_qs`` when ``None`` (or 0).
            mb_ensemble_size: Number of dynamics models in the ensemble.
            rm_ensemble_size: Number of reward models in the ensemble.
            critic_dropout_rate: Dropout rate passed to the critic, dynamics,
                reward-model, adjuster and adjuster-Q networks.
            critic_weight_decay: When set, the critic uses AdamW with this
                decay (masked by ``decay_mask_fn``) instead of Adam.
            critic_layer_norm: Enables layer norm in the critic MLP only.
            target_entropy: Defaults to ``-action_dim / 2`` when ``None``.
            init_temperature: Initial value handed to ``Temperature``.
            backup_entropy: Stored on the learner; read by ``update_critic``.
            use_pnorm: Forwarded to every ``MLP_RLPD`` trunk.
            use_critic_resnet: Use ``MLPResNetV2`` (one block) as the critic
                trunk instead of ``MLP_RLPD``.
            preload_actor_params: Optional actor parameters to use instead of
                a fresh initialisation. A legacy top-level ``MLP_0`` entry is
                renamed to ``MLP_RLPD_0``.
            preload_critic_params: Optional critic parameters; the legacy
                ``MLP_0`` entry under ``VmapStateActionValue_0`` is renamed.
            preload_target_critic_params: Optional target-critic parameters,
                only consulted when ``preload_critic_params`` is given. When
                omitted, the target starts as a copy of the critic.

        Returns:
            A new learner instance.
        """

        def _rename_legacy_mlp(tree):
            # Checkpoints saved under the old module name use "MLP_0"; params
            # that already use the current name are left untouched.
            if "MLP_0" in tree:
                tree["MLP_RLPD_0"] = tree.pop("MLP_0")

        def _load_critic_params(preloaded):
            # unfreeze -> rename inside the vmapped critic scope -> freeze.
            tree = unfreeze(preloaded)
            _rename_legacy_mlp(tree["VmapStateActionValue_0"])
            return freeze(tree)

        action_dim = action_space.shape[-1]
        observations = observation_space.sample()
        actions = action_space.sample()

        if target_entropy is None:
            target_entropy = -action_dim / 2

        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, temp_key, \
            dynamics_key, reward_model_key, adjuster_key = jax.random.split(rng, 7)

        # ---- actor -------------------------------------------------------
        actor_base_cls = partial(
            MLP_RLPD, hidden_dims=hidden_dims, activate_final=True, use_pnorm=use_pnorm
        )
        actor_def = TanhNormal(actor_base_cls, action_dim)
        if preload_actor_params is not None:
            actor_tree = unfreeze(preload_actor_params)
            _rename_legacy_mlp(actor_tree)
            actor_params = freeze(actor_tree)
        else:
            actor_params = actor_def.init(actor_key, observations)["params"]
        actor = TrainState.create(
            apply_fn=actor_def.apply,
            params=actor_params,
            tx=optax.adam(learning_rate=actor_lr),
        )

        # ---- critic and target critic -----------------------------------
        if use_critic_resnet:
            critic_base_cls = partial(
                MLPResNetV2,
                num_blocks=1,
            )
        else:
            critic_base_cls = partial(
                MLP_RLPD,
                hidden_dims=hidden_dims,
                activate_final=True,
                dropout_rate=critic_dropout_rate,
                use_layer_norm=critic_layer_norm,
                use_pnorm=use_pnorm,
            )
        critic_cls = partial(StateActionValue, base_cls=critic_base_cls)
        # ensembled critic
        critic_def = Ensemble(critic_cls, num=num_qs)
        target_critic_def = Ensemble(critic_cls, num=num_min_qs or num_qs)

        if critic_weight_decay is not None:
            critic_tx = optax.adamw(
                learning_rate=critic_lr,
                weight_decay=critic_weight_decay,
                mask=decay_mask_fn,
            )
        else:
            critic_tx = optax.adam(learning_rate=critic_lr)

        if preload_critic_params is not None:
            critic_params = _load_critic_params(preload_critic_params)
            if preload_target_critic_params is not None:
                target_critic_params = _load_critic_params(
                    preload_target_critic_params
                )
            else:
                # No separate target checkpoint: start the target at the critic.
                target_critic_params = copy.deepcopy(critic_params)
        else:
            critic_params = critic_def.init(critic_key, observations, actions)["params"]
            target_critic_params = copy.deepcopy(critic_params)
        critic = TrainState.create(
            apply_fn=critic_def.apply,
            params=critic_params,
            tx=critic_tx,
        )
        # The target is never trained by gradient, so it gets a dummy optimiser.
        target_critic = TrainState.create(
            apply_fn=target_critic_def.apply,
            params=target_critic_params,
            tx=optax.GradientTransformation(lambda _: None, lambda _: None),
        )

        # ---- temperature -------------------------------------------------
        temp_def = Temperature(init_temperature)
        temp_params = temp_def.init(temp_key)["params"]
        temp = TrainState.create(
            apply_fn=temp_def.apply,
            params=temp_params,
            tx=optax.adam(learning_rate=temp_lr),
        )

        # ---- dynamics ensemble ------------------------------------------
        dynamics_base_cls = partial(
            MLP_RLPD,
            hidden_dims=hidden_dims,
            activate_final=True,
            dropout_rate=critic_dropout_rate,
            use_layer_norm=False,
            use_pnorm=use_pnorm,
        )
        dynamics_cls = partial(DynamicsModel, base_cls=dynamics_base_cls)
        dynamics_def = Ensemble(dynamics_cls, num=mb_ensemble_size)
        dynamics_params = dynamics_def.init(dynamics_key, observations, actions)["params"]
        dynamics = TrainState.create(
            apply_fn=dynamics_def.apply,
            params=dynamics_params,
            tx=optax.adam(learning_rate=critic_lr),
        )

        # ---- reward-model ensemble --------------------------------------
        reward_model_base_cls = partial(
            MLP_RLPD,
            hidden_dims=hidden_dims,
            activate_final=True,
            dropout_rate=critic_dropout_rate,
            use_layer_norm=False,
            use_pnorm=use_pnorm,
        )
        reward_model_cls = partial(StateActionValue, base_cls=reward_model_base_cls)
        reward_model_def = Ensemble(reward_model_cls, num=rm_ensemble_size)
        # Initialise with the dedicated key rather than sharing the dynamics key.
        reward_model_params = reward_model_def.init(
            reward_model_key, observations, actions
        )["params"]
        reward_model = TrainState.create(
            apply_fn=reward_model_def.apply,
            params=reward_model_params,
            tx=optax.adam(learning_rate=critic_lr),
        )

        # ---- adjuster: neural adjuster trained by supervised learning ----
        adjuster_base_cls = partial(
            LSTM_RLPD,
        )
        adjuster_cls = partial(
            AdjusterModel,
            dropout_rate=critic_dropout_rate,
            hidden_dim=adjuster_hidden_dim,
            use_lstm=use_lstm,
        )
        adjuster_def = adjuster_cls(base_cls=adjuster_base_cls)
        # Dummy hidden state and a single (obs, action, done=False) step for init.
        adjuster_params = adjuster_def.init(
            adjuster_key,
            jnp.empty(adjuster_hidden_dim),
            (observations, actions, [False]),
        )["params"]
        adjuster = TrainState.create(
            apply_fn=adjuster_def.apply,
            params=adjuster_params,
            tx=optax.adam(learning_rate=critic_lr),
        )

        # ---- adjuster_Q: Q-value model for the RL adjuster agent, which has
        # two actions: to intervene or not to intervene ---------------------
        adjuster_Q_base_cls = partial(
            MLP_RLPD,
            hidden_dims=hidden_dims,
            activate_final=True,
            dropout_rate=critic_dropout_rate,
            use_layer_norm=False,
            use_pnorm=use_pnorm,
        )
        adjuster_Q_cls = partial(QValue, base_cls=adjuster_Q_base_cls, num_actions=2)
        adjuster_Q_def = Ensemble(adjuster_Q_cls, num=num_qs)
        target_adjuster_Q_def = Ensemble(adjuster_Q_cls, num=num_min_qs or num_qs)
        # NOTE(review): this init reuses critic_key and passes `actions` as the
        # second positional argument, whereas update_adjuster_Q calls apply_fn
        # with a boolean in that position - confirm against QValue's signature.
        adjuster_Q_params = adjuster_Q_def.init(critic_key, observations, actions)["params"]
        target_adjuster_Q_params = copy.deepcopy(adjuster_Q_params)
        adjuster_Q = TrainState.create(
            apply_fn=adjuster_Q_def.apply,
            params=adjuster_Q_params,
            tx=optax.adam(learning_rate=critic_lr * 1e-4),
        )
        target_adjuster_Q = TrainState.create(
            apply_fn=target_adjuster_Q_def.apply,
            params=target_adjuster_Q_params,
            tx=optax.GradientTransformation(lambda _: None, lambda _: None),
        )

        return cls(
            rng=rng,
            actor=actor,
            critic=critic,
            target_critic=target_critic,
            temp=temp,
            dynamics=dynamics,
            reward_model=reward_model,
            adjuster=adjuster,
            adjuster_Q=adjuster_Q,
            target_adjuster_Q=target_adjuster_Q,
            target_entropy=target_entropy,
            tau=tau,
            discount=discount,
            num_qs=num_qs,
            num_min_qs=num_min_qs,
            backup_entropy=backup_entropy,
        )

    def update_dynamics(self, batch: DatasetDict):
        """Take one gradient step on the dynamics ensemble.

        The loss is the mean squared error between ``batch["next_observations"]``
        and the ensemble prediction for (``observations``, ``actions``).

        Returns:
            ``(learner, info)`` where ``learner`` has the updated dynamics
            state and an advanced ``rng``, and ``info`` holds
            ``dynamics_loss`` and ``state_uncertainty``.
        """
        dropout_key, next_rng = jax.random.split(self.rng)

        observations = batch["observations"]
        actions = batch["actions"]
        targets = batch["next_observations"]

        def dynamics_loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
            # The positional True is the training flag (dropout active).
            predictions = self.dynamics.apply_fn(
                {"params": params},
                observations,
                actions,
                True,
                rngs={"dropout": dropout_key},
            )
            mse = jnp.mean((targets - predictions) ** 2)
            # Spread of the predictions along axis 0, averaged over the rest.
            spread = jnp.mean(jnp.std(predictions, axis=0))
            metrics = {"dynamics_loss": mse, "state_uncertainty": spread}
            return mse, metrics

        grad_fn = jax.grad(dynamics_loss_fn, has_aux=True)
        grads, info = grad_fn(self.dynamics.params)
        updated_dynamics = self.dynamics.apply_gradients(grads=grads)

        return self.replace(dynamics=updated_dynamics, rng=next_rng), info
    
    def update_reward_model(self, batch: DatasetDict, pbrl_normalize_ratio):
        """Take one gradient step on the reward model from preference data.

        An empty ``batch`` is a no-op and returns ``(self, {})``. Otherwise the
        loss is a softmax cross-entropy that always labels index 0 of the last
        axis as the preferred one, plus ``pbrl_normalize_ratio`` times the mean
        absolute model output.

        Returns:
            ``(learner, info)`` with ``reward_model_loss``, ``avg_better`` and
            ``avg_worse`` in ``info``.
        """
        if not batch:
            return self, {}

        dropout_key, next_rng = jax.random.split(self.rng)

        def reward_model_loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
            observations = batch["obs"]
            actions = batch["action"]
            num_pairs = observations.shape[0]

            # Drop the leading ensemble axis (squeeze(0) requires it to be 1).
            scores = self.reward_model.apply_fn(
                {"params": params},
                observations,
                actions,
                True,
                rngs={"dropout": dropout_key},
            ).squeeze(0)

            # Averaging over axis 1 leaves (batch_size, 2) logits.
            logits = jnp.mean(scores, axis=1)
            labels = jnp.tile(jnp.array([1, 0]), (num_pairs, 1))
            preference_loss = optax.softmax_cross_entropy(logits, labels).mean()
            loss = preference_loss + pbrl_normalize_ratio * jnp.abs(scores).mean()

            metrics = {
                "reward_model_loss": loss,
                "avg_better": jnp.mean(scores[:, :, 0]),
                "avg_worse": jnp.mean(scores[:, :, 1]),
            }
            return loss, metrics

        grad_fn = jax.grad(reward_model_loss_fn, has_aux=True)
        grads, info = grad_fn(self.reward_model.params)
        updated_reward_model = self.reward_model.apply_gradients(grads=grads)

        return self.replace(reward_model=updated_reward_model, rng=next_rng), info
    
    def update_adjuster(self, batch: DatasetDict, hstate):
        """Take one supervised gradient step on the adjuster.

        The adjuster is trained with a sigmoid binary cross-entropy loss
        against ``batch["gt_intervenes"]``.

        Args:
            batch: Must contain ``observations``, ``actions`` and ``dones``;
                ``gt_intervenes`` is optional.
            hstate: Hidden state fed to the adjuster.

        Returns:
            Always a 3-tuple ``(learner, info, hstate)``. When the batch has no
            ``gt_intervenes`` labels the learner and ``hstate`` are returned
            unchanged with an empty ``info``; otherwise ``hstate`` is the
            hidden state produced by the adjuster's forward pass.
        """
        # No ground-truth intervene labels: nothing to train on. Return the
        # same arity as the normal path so callers can always unpack 3 values.
        if "gt_intervenes" not in batch:
            return self, {}, hstate
        key, rng = jax.random.split(self.rng)

        def adjuster_loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
            obs = batch["observations"]
            action = batch["actions"]
            gt_intervene = batch["gt_intervenes"]
            done = batch["dones"]

            h_state, pred = self.adjuster.apply_fn(
                {"params": params},
                hstate,
                (obs, action, done),
                True,
                rngs={"dropout": key},
            )
            # Column-shaped labels, used for BOTH the loss and the accuracy so
            # the two are computed against identically shaped targets.
            labels = gt_intervene.reshape(-1, 1)

            # compute cross entropy loss with logits inputs
            # TODO: mask out ground truth intervene with label -1
            loss = jnp.mean(optax.sigmoid_binary_cross_entropy(pred, labels))
            # use sigmoid to get binary prediction
            pred_label = jnp.round(jax.nn.sigmoid(pred))
            # Comparing against the flat labels would broadcast a column of
            # predictions against a row of labels instead of element-wise.
            accuracy = jnp.mean(pred_label == labels)
            info = {
                "adjuster_loss": loss,
                "avg_h": jnp.mean(h_state),
                "avg_out": jnp.mean(pred),
                "avg_pred_label": jnp.mean(pred_label),
                "avg_label": jnp.mean(gt_intervene),
                "accuracy": accuracy,
                # Smuggled out through the aux dict; popped again below.
                "h_state": h_state,
            }

            return loss, info

        grads, info = jax.grad(adjuster_loss_fn, has_aux=True)(self.adjuster.params)
        hstate = info.pop("h_state")
        adjuster = self.adjuster.apply_gradients(grads=grads)

        return self.replace(adjuster=adjuster, rng=rng), info, hstate
    
    def update_adjuster_Q(self, batch: DatasetDict, use_pbrl=False, pbrl_rew_ratio=0.01):
        """Take one TD step on the adjuster Q-ensemble and move its target.

        The bootstrap value is the minimum over the (sub-sampled) target
        ensemble followed by the maximum over the two adjuster actions. The
        regression uses a Huber loss.

        Args:
            batch: Must contain ``observations``, ``next_observations``,
                ``adjuster_actions`` (used as integer indices),
                ``adjuster_rewards`` and ``masks``. With ``use_pbrl`` it must
                also contain ``human_observation_list``, ``human_action_list``
                and ``has_human``.
            use_pbrl: Add a reward-model term to the TD target.
            pbrl_rew_ratio: Weight of that reward-model term.

        Returns:
            ``(learner, info)`` where ``learner`` has the updated
            ``adjuster_Q``, ``target_adjuster_Q`` and ``rng``.
        """
        key, rng = jax.random.split(self.rng)
        target_params = subsample_ensemble(
            key, self.target_adjuster_Q.params, self.num_min_qs, self.num_qs
        )

        key, rng = jax.random.split(rng)
        next_qs = self.target_adjuster_Q.apply_fn(
            {"params": target_params},
            batch["next_observations"],
            True,
            rngs={"dropout": key},
        )  # training=True

        # min over ensemble, max over actions
        next_q = next_qs.min(axis=0).max(axis=-1)

        if use_pbrl:
            # Draw a dedicated key so the reward model's dropout is independent
            # of the one just used for the target network.
            key, rng = jax.random.split(rng)
            traj_reward = self.reward_model.apply_fn(
                {"params": self.reward_model.params},
                batch["human_observation_list"],
                batch["human_action_list"],
                True,
                rngs={"dropout": key},
            ).squeeze(0).sum(axis=1)

            # batch human observation list is empty when adjuster rewards are 0, i.e., no intervention
            rm_output = jax.lax.select(
                batch["has_human"],
                traj_reward,
                jnp.zeros_like(traj_reward),
            )

            # normalize rm_output in batch; the +1 keeps the divisor away from 0
            rm_output = rm_output / (1 + jnp.abs(rm_output).max())
        else:
            rm_output = jnp.zeros_like(batch["adjuster_rewards"])

        target_q = batch["adjuster_rewards"] + pbrl_rew_ratio * rm_output + \
                    self.discount * batch["masks"] * next_q

        key, rng = jax.random.split(rng)

        def adjuster_Q_loss_fn(params: Params):
            # (num_qs, batch_size, 2)
            qs = self.adjuster_Q.apply_fn(
                {"params": params},
                batch["observations"],
                True,
                rngs={"dropout": key},
            )  # training=True

            # Pick, for every ensemble member and sample, the Q-value of the
            # action stored in the batch -> (num_qs, batch_size)
            current_q_values = qs[
                jnp.arange(qs.shape[0])[:, None],
                jnp.arange(qs.shape[1]),
                batch["adjuster_actions"],
            ]

            # Compute Huber loss (less sensitive to outliers)
            adjuster_Q_loss = jnp.mean(optax.huber_loss(current_q_values, target_q))

            return adjuster_Q_loss, {
                "adjuster_q/adjuster_Q_loss": adjuster_Q_loss,
                "adjuster_q/adjuster_q_mean": qs.mean(),
                "adjuster_q/adjuster_q_std": qs.std(),
                "adjuster_q/current_q_mean": current_q_values.mean(),
                "adjuster_q/current_q_std": current_q_values.std(),
                "adjuster_q/adjuster_q_0_mean": qs[:, :, 0].mean(),
                "adjuster_q/adjuster_q_1_mean": qs[:, :, 1].mean(),
                "adjuster_q/target_adjuster_q_mean": target_q.mean(),
                "adjuster_q/target_adjuster_q_std": target_q.std(),
                "adjuster_q/adjuster_a_mean": batch["adjuster_actions"].mean(),
            }

        grads, info = jax.grad(adjuster_Q_loss_fn, has_aux=True)(self.adjuster_Q.params)
        adjuster_Q = self.adjuster_Q.apply_gradients(grads=grads)

        # Move the target towards the freshly updated online parameters.
        target_adjuster_Q_params = optax.incremental_update(
            adjuster_Q.params, self.target_adjuster_Q.params, self.tau
        )
        target_adjuster_Q = self.target_adjuster_Q.replace(params=target_adjuster_Q_params)

        if use_pbrl:
            info["adjuster_q/avg_rm_output"] = rm_output.sum() / (1 + batch["has_human"].sum())
            info["adjuster_q/avg_adj_rew_train"] = batch["adjuster_rewards"].mean()
            info["adjuster_q/has_human_ratio"] = batch["has_human"].mean()

        return self.replace(adjuster_Q=adjuster_Q, target_adjuster_Q=target_adjuster_Q, rng=rng), info
        

    def update_critic(
        self,
        batch: DatasetDict,
        use_pbrl_reward=False,
        pbrl_reward_scale: float = 0.03,
    ) -> Tuple["RLPhiLearner", Dict[str, float]]:
        """Take one TD step on the critic ensemble and move the target critic.

        Args:
            batch: Must contain ``observations``, ``actions``, ``rewards``,
                ``masks`` and ``next_observations``.
            use_pbrl_reward: Add the reward model's output for the batch's
                (observation, action) pairs to the environment reward.
            pbrl_reward_scale: Weight of that reward-model term.

        Returns:
            ``(learner, info)`` where ``learner`` has the updated ``critic``,
            ``target_critic`` and ``rng``, and ``info`` holds ``critic_loss``
            and the mean Q-value ``q``.
        """

        dist = self.actor.apply_fn(
            {"params": self.actor.params}, batch["next_observations"]
        )

        rng = self.rng

        key, rng = jax.random.split(rng)
        next_actions = dist.sample(seed=key)

        # Used only for REDQ.
        key, rng = jax.random.split(rng)
        target_params = subsample_ensemble(
            key, self.target_critic.params, self.num_min_qs, self.num_qs
        )

        key, rng = jax.random.split(rng)
        next_qs = self.target_critic.apply_fn(
            {"params": target_params},
            batch["next_observations"],
            next_actions,
            True,
            rngs={"dropout": key},
        )  # training=True
        next_q = next_qs.min(axis=0)

        if use_pbrl_reward:
            # Draw a dedicated key so the reward model's dropout is independent
            # of the one just used for the target critic.
            key, rng = jax.random.split(rng)
            # squeeze(0) drops the leading ensemble axis, which must be 1.
            traj_reward = self.reward_model.apply_fn(
                {"params": self.reward_model.params},
                batch["observations"],
                batch["actions"],
                True,
                rngs={"dropout": key},
            ).squeeze(0)
        else:
            traj_reward = 0

        target_q = (
            batch["rewards"]
            + pbrl_reward_scale * traj_reward
            + self.discount * batch["masks"] * next_q
        )

        if self.backup_entropy:
            # Soft backup: subtract the temperature-weighted log-probability
            # of the sampled next action.
            next_log_probs = dist.log_prob(next_actions)
            target_q -= (
                self.discount
                * batch["masks"]
                * self.temp.apply_fn({"params": self.temp.params})
                * next_log_probs
            )

        key, rng = jax.random.split(rng)

        def critic_loss_fn(critic_params) -> Tuple[jnp.ndarray, Dict[str, float]]:
            qs = self.critic.apply_fn(
                {"params": critic_params},
                batch["observations"],
                batch["actions"],
                True,
                rngs={"dropout": key},
            )  # training=True
            critic_loss = ((qs - target_q) ** 2).mean()
            return critic_loss, {"critic_loss": critic_loss, "q": qs.mean()}

        grads, info = jax.grad(critic_loss_fn, has_aux=True)(self.critic.params)
        critic = self.critic.apply_gradients(grads=grads)

        # Move the target towards the freshly updated online parameters.
        target_critic_params = optax.incremental_update(
            critic.params, self.target_critic.params, self.tau
        )
        target_critic = self.target_critic.replace(params=target_critic_params)

        return self.replace(critic=critic, target_critic=target_critic, rng=rng), info

    
    @partial(jax.jit, static_argnames=["intervene_type", "utd_ratio", "use_pbrl", "pbrl_normalize_ratio", "pbrl_rew_ratio", "pbrl_in_learning_agent"])
    def update(self, batch: DatasetDict, pbrl_batch: DatasetDict, hstate, intervene_type, 
               utd_ratio: int, use_pbrl=False, pbrl_normalize_ratio=0, pbrl_rew_ratio=0,
               pbrl_in_learning_agent=False):
        """Run one full (jitted) training step.

        ``batch`` is cut into ``utd_ratio`` equal mini-batches along axis 0.
        The critic (and, for ``intervene_type == 'neural_rl'``, the adjuster
        Q-network) is updated once per mini-batch. The actor, temperature,
        dynamics and (for ``intervene_type == 'neural'``) adjuster are then
        updated once, on the LAST mini-batch. The reward model is updated on
        ``pbrl_batch`` when ``use_pbrl`` is set.

        Args:
            batch: Pytree of arrays whose leading dimension is divisible by
                ``utd_ratio``.
            pbrl_batch: Preference batch for the reward model.
            hstate: Hidden state passed to, and returned from, the adjuster.
            intervene_type: ``'neural_rl'``, ``'neural'`` or anything else
                (no adjuster update at all). Static under jit.
            utd_ratio: Number of critic updates per call; must be >= 1.
            use_pbrl: Enables the reward-model update and the reward-model
                term in the adjuster Q target.
            pbrl_normalize_ratio: Forwarded to ``update_reward_model``.
            pbrl_rew_ratio: Forwarded to ``update_adjuster_Q``.
            pbrl_in_learning_agent: Forwarded to ``update_critic`` as
                ``use_pbrl_reward``.

        Returns:
            ``(learner, info_dict, hstate)``. ``info_dict`` merges the metrics
            of all sub-updates (critic / adjuster-Q metrics are those of the
            last mini-batch); on a key clash the later dict wins.

        Raises:
            ValueError: If ``utd_ratio`` is smaller than 1.
        """
        # utd_ratio is static, so this check runs at trace time. Without it a
        # non-positive value would only surface later as an unbound
        # `mini_batch` / `critic_info`.
        if utd_ratio < 1:
            raise ValueError(f"utd_ratio must be >= 1, got {utd_ratio}")

        new_agent: RLPhiLearner = self
        adjuster_Q_info = {}
        for i in range(utd_ratio):

            def take_minibatch(x, index=i):
                # Equal-sized, contiguous chunk number `index` along axis 0.
                assert x.shape[0] % utd_ratio == 0, (
                    "leading batch dimension must be divisible by utd_ratio"
                )
                chunk = x.shape[0] // utd_ratio
                return x[chunk * index : chunk * (index + 1)]

            mini_batch = jax.tree_util.tree_map(take_minibatch, batch)
            new_agent, critic_info = new_agent.update_critic(mini_batch, pbrl_in_learning_agent)
            if intervene_type == 'neural_rl':
                new_agent, adjuster_Q_info = new_agent.update_adjuster_Q(mini_batch, use_pbrl, pbrl_rew_ratio)

        # The remaining updates run once, on the last mini-batch.
        # update_actor / update_temperature are not defined in this class
        # (presumably inherited from SACLearner).
        new_agent, actor_info = new_agent.update_actor(mini_batch)
        new_agent, temp_info = new_agent.update_temperature(actor_info["entropy"])
        new_agent, dynamics_info = new_agent.update_dynamics(mini_batch)
        if use_pbrl:
            new_agent, reward_model_info = new_agent.update_reward_model(pbrl_batch, pbrl_normalize_ratio)
        else:
            reward_model_info = {}
        if intervene_type == 'neural':
            new_agent, adjuster_info, hstate = new_agent.update_adjuster(mini_batch, hstate)
        else:
            adjuster_info = {}

        info_dict = {
            **actor_info,
            **critic_info,
            **temp_info,
            **dynamics_info,
            **reward_model_info,
            **adjuster_info,
            **adjuster_Q_info,
        }

        return new_agent, info_dict, hstate
