import jax.random
from rl.td3_plusbc import TD3PlusBC
from rl.policies.policy import DeterministicTQCPolicy, RLTrainState
from rl.utils.replay_buffer import ReplayBuffer, QLearningBatch
import gymnasium as gym
from rl.utils.risk_utils import *
from tqdm import trange


class QuantileTD3PlusBC(TD3PlusBC):
    """TD3+BC variant with a quantile-valued critic and a risk-distorted actor objective.

    The critic is trained with one of two update rules selected by the ``tqc``
    constructor flag (``build_critic_update_fn`` when true,
    ``build_classic_critic_update_fn`` otherwise). The actor minimizes the
    negated critic value evaluated at risk-distorted quantile fractions, plus
    a behavior-cloning term.
    """
    # Policy object created in ``build``.
    policy: DeterministicTQCPolicy
    # Lookup from the ``risk_type`` string to the distortion function applied to
    # sampled quantile fractions; "pow" and "power" are aliases of each other.
    # NOTE(review): the functions presumably come from the star import of
    # ``rl.utils.risk_utils`` -- confirm.
    risk_measures = {"cvar": cvar, "pow": pow, "power": pow, "wang": wang}
    # Density functions keyed by risk type; only "cvar" and "wang" have entries.
    densities = {"cvar": cvar_density, "wang": wang_density}
    # Number of quantile fractions drawn per batch element in the actor update.
    actor_sampling_quantiles: int = 100

    def __init__(self,
                 env: gym.Env,
                 buffer: ReplayBuffer,
                 normalizer: Optional[Callable] = False,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 opt_class: Callable = optax.adam,
                 learning_rate: float = 3e-4,
                 q_learning_scale: float = 2.5,
                 risk_type: str = 'cvar',
                 risk_eta: float = 0.5,
                 n_quantiles: int = 16,
                 drop_per_net: int = 3,
                 policy_delay: int = 2,
                 n_critics: int = 3,
                 bc_weight: float = 1.,
                 smooth: bool = False,
                 fourier_feature_critic: bool = False,
                 seed: int = 42,
                 tqc: bool = False,
                 ):
        """Store the risk/quantile hyper-parameters, then initialise ``TD3PlusBC``.

        Args:
            env: Environment forwarded to the parent constructor.
            buffer: Replay buffer forwarded to the parent constructor.
            normalizer: Forwarded unchanged to the parent constructor.
                NOTE(review): the default is ``False`` although the annotation
                says ``Optional[Callable]``; kept as-is because the parent's
                handling is not visible here.
            gamma: Discount factor used in the TD target.
            batch_size: Forwarded to the parent constructor.
            opt_class: Optimizer factory handed to the policy.
            learning_rate: Learning rate handed to the policy.
            q_learning_scale: Weight of the (normalized) risk term in the actor loss.
            risk_type: Key into ``risk_measures`` ("cvar", "pow", "power", "wang").
            risk_eta: Second argument passed to the selected risk measure.
            n_quantiles: Number of quantile fractions sampled per transition
                in the critic updates.
            drop_per_net: Multiplied by ``n_critics`` to obtain the number of
                largest target atoms discarded by the TQC critic update.
            policy_delay: Forwarded to the parent constructor.
            n_critics: Number of critics handed to the policy.
            bc_weight: Weight of the behavior-cloning term in the actor loss.
            smooth: Forwarded to ``DeterministicTQCPolicy`` as ``smooth``.
            fourier_feature_critic: Forwarded to ``DeterministicTQCPolicy`` as ``ff_feature``.
            seed: Forwarded to the parent constructor.
            tqc: Selects ``build_critic_update_fn`` (True) or
                ``build_classic_critic_update_fn`` (False) in ``build``.

        Raises:
            KeyError: If ``risk_type`` is not a key of ``risk_measures``.
        """
        if risk_type not in self.risk_measures:
            # Same exception type as the plain dict lookup, but with a message
            # that lists the accepted values.
            raise KeyError(f"unknown risk_type {risk_type!r}; "
                           f"expected one of {sorted(self.risk_measures)}")

        # These attributes are assigned before ``super().__init__`` on purpose:
        # presumably the parent constructor triggers ``build``, which reads
        # them -- confirm against ``TD3PlusBC``.
        self.risk_type: str = risk_type
        self.risk_measure: Callable = self.risk_measures[risk_type]
        self.risk_eta = risk_eta
        self.n_quantiles = n_quantiles
        self.drop_per_net = drop_per_net
        self.smooth = smooth
        self.bc_weight = bc_weight
        self.fourier_feature = fourier_feature_critic
        self.tqc = tqc

        super().__init__(env,
                         buffer=buffer,
                         normalizer=normalizer,
                         gamma=gamma,
                         batch_size=batch_size,
                         opt_class=opt_class,
                         learning_rate=learning_rate,
                         q_learning_scale=q_learning_scale,
                         n_critics=n_critics,
                         policy_delay=policy_delay,
                         seed=seed,
                         )
        # Not every supported risk type has a density ("pow"/"power" do not),
        # so fall back to None instead of failing with a KeyError after the
        # whole agent has already been constructed.
        self.density = self.densities.get(self.risk_type)

    def build(self, ):
        """Instantiate ``self.policy`` with the critic/actor update functions.

        The critic update is the TQC variant when ``self.tqc`` is truthy and
        the classic variant otherwise.
        """
        obs_placeholder, act_placeholder = self.make_placeholder()

        # Draw the seed first so the RNG is consumed at the same point as before.
        policy_seed = next(self.hk_rng)

        if self.tqc:
            critic_update = self.build_critic_update_fn()
        else:
            critic_update = self.build_classic_critic_update_fn()
        actor_update = self.build_actor_update_fn()

        self.policy = DeterministicTQCPolicy(
            obs_placeholder,
            act_placeholder,
            opt_class=self.opt_class,
            learning_rate=self.learning_rate,
            seed=policy_seed,
            n_critics=self.n_critics,
            critic_update_fn=critic_update,
            actor_update_fn=actor_update,
            smooth=self.smooth,
            ff_feature=self.fourier_feature,
        )

    def build_classic_critic_update_fn(self, ) -> Callable:
        """Build the jitted critic update used when ``tqc`` is disabled.

        The returned ``update_fn(critic_train_state, actor_train_state, batch, key)``
        samples quantile fractions, builds a TD target from the target critic
        evaluated at a noise-perturbed target-actor action, minimizes
        ``quantile_huber_loss`` against that target, applies the gradients and
        moves the target parameters with ``optax.incremental_update``.

        Returns:
            A callable returning ``(new_critic_train_state, metrics)`` where
            ``metrics`` has the keys ``"q_loss"`` and ``"ord"``.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      batch: QLearningBatch, key):
            @partial(jax.jit, static_argnums=(1,))
            def get_tau(key, shape):
                # The +0.1 offset keeps every increment strictly positive.
                presum_tau = jax.random.uniform(key, shape) + 0.1
                # Normalize so the increments of each row sum to one.
                presum_tau /= presum_tau.sum(axis=-1, keepdims=True)
                tau = jnp.cumsum(presum_tau, axis=-1)  # (N, T), note that they are tau1...tauN in the paper
                # Midpoints of consecutive fractions; the first one is the
                # midpoint between 0 and tau[..., 0].
                tau_hat = 0.5 * (tau[..., 1:] + tau[..., :-1])
                tau_hat = jnp.concatenate([tau[..., 0:1] / 2, tau_hat], axis=-1)
                return tau, tau_hat, presum_tau

            def loss_fn(param_critic):
                # One independent sub-key per random draw. The parent key must
                # not be consumed again after it has been split.
                tau_key, next_tau_key, noise_key = jax.random.split(key, 3)
                batch_size = batch.observations.shape[0]
                _, taus, _ = get_tau(tau_key, shape=(batch_size, self.n_quantiles))
                _, next_taus, next_presum_tau = get_tau(next_tau_key, shape=(batch_size, self.n_quantiles))
                q_value = critic_train_state.apply_fn({'params': param_critic}, batch.observations,
                                                      batch.actions, taus)

                next_action = actor_train_state.apply_fn({'params': actor_train_state.target_params},
                                                         batch.next_observations)
                # Clipped Gaussian noise on the target action, then clip to [-1, 1].
                noise = 0.2 * jax.random.normal(noise_key, shape=next_action.shape)
                noise = noise.clip(-0.5, 0.5)
                next_action = (next_action + noise).clip(-1, 1)
                next_q_value = critic_train_state.apply_fn({'params': critic_train_state.target_params},
                                                           batch.next_observations, next_action, next_taus)

                # Target critic at risk-distorted fractions, averaged over axis -2.
                # NOTE(review): assumes the critic output is
                # (batch, n_quantiles, n_critics) -- confirm.
                neg_risks = critic_train_state.apply_fn({'params': critic_train_state.target_params},
                                                        batch.next_observations, next_action,
                                                        self.risk_measure(next_taus, self.risk_eta)
                                                        ).mean(axis=-2)
                # Per sample, pick the last-axis entry with the smallest averaged value.
                index = jnp.argmin(neg_risks, axis=-1, keepdims=True)
                index = index[..., None]
                diff = jnp.diff(next_q_value, axis=-2)
                # Only the gathered axis (always size 1) is removed; a bare
                # squeeze() would also drop a size-1 batch or quantile axis.
                next_q_value = jnp.take_along_axis(next_q_value, indices=index, axis=-1).squeeze(axis=-1)
                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_q_value)

                # quantile_huber_loss is mapped over the last axis of q_value.
                loss = jax.vmap(quantile_huber_loss, in_axes=(-1, None, None, None), out_axes=-1)(q_value, td_target,
                                                                                                  taus, next_presum_tau)
                qr_loss = loss.sum(axis=-1).mean()

                # "ord": percentage of negative adjacent differences of the
                # target values along axis -2.
                return qr_loss, {"q_loss": qr_loss,
                                 "ord": 100 * (diff < 0).astype(jnp.float32).sum() / diff.size,
                                 }

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            critic_target_params = jax.jit(optax.incremental_update, static_argnums=(2,))(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info

        return update_fn

    def build_critic_update_fn(self, ):
        """Build the jitted TQC-style critic update used when ``tqc`` is enabled.

        The returned ``update_fn(critic_train_state, actor_train_state, batch, key)``
        samples sorted quantile fractions, evaluates the target critic at a
        noise-perturbed target-actor action, flattens and sorts the target
        values per sample, discards the ``drop_per_net * n_critics`` largest
        ones, and regresses the online critic onto the resulting TD target with
        ``quanitle_regression_loss``.

        Returns:
            A callable returning ``(new_critic_train_state, metrics)`` where
            ``metrics`` has the keys ``"q_loss"`` and ``"ord"``.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate
        drops = self.drop_per_net * self.n_critics

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      batch: QLearningBatch, key):
            def loss_fn(param_critic):
                # Separate sub-keys for the fraction samples and the action
                # noise. Using the same key for both draws would make them
                # statistically dependent.
                tau_key, noise_key = jax.random.split(key)
                taus_sample = jax.random.uniform(tau_key, shape=(batch.observations.shape[0], self.n_quantiles, 2))
                taus = taus_sample[..., 0].sort(axis=-1)
                next_taus = taus_sample[..., 1].sort(axis=-1)
                q_value = critic_train_state.apply_fn({'params': param_critic}, batch.observations,
                                                      batch.actions, taus)

                next_action = actor_train_state.apply_fn({'params': actor_train_state.target_params},
                                                         batch.next_observations)
                # Clipped Gaussian noise on the target action, then clip to [-1, 1].
                noise = 0.2 * jax.random.normal(noise_key, shape=next_action.shape)
                noise = noise.clip(-0.5, 0.5)
                next_action = (next_action + noise).clip(-1, 1)
                next_q_value = critic_train_state.apply_fn({'params': critic_train_state.target_params},
                                                           batch.next_observations, next_action, next_taus)

                diff = jnp.diff(next_q_value, axis=-2)
                # Pool all target values of a sample and sort them ascending.
                next_q_value = next_q_value.reshape(next_q_value.shape[0], -1).sort(axis=-1)
                if drops > 0:
                    # Drop the largest pooled values. The guard matters because
                    # a [: -0] slice would be empty.
                    next_q_value = next_q_value[:, :-drops]

                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_q_value)

                # The loss is mapped over the last axis of q_value.
                loss = jax.vmap(quanitle_regression_loss, in_axes=(None, -1, None), out_axes=-1)(td_target, q_value,
                                                                                                 taus)
                qr_loss = loss.sum(axis=-1).mean()

                # "ord": percentage of negative adjacent differences of the
                # target values along axis -2.
                return qr_loss, {"q_loss": qr_loss,
                                 "ord": 100 * (diff < 0).astype(jnp.float32).sum() / diff.size,
                                 }

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            critic_target_params = jax.jit(optax.incremental_update, static_argnums=(2,))(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info

        return update_fn

    def build_actor_update_fn(self):
        """Build the jitted actor update.

        The returned ``update_fn(critic_train_state, actor_train_state, batch, key)``
        minimizes ``q_learning_scale * normalized_risk + bc_weight * bc_loss``,
        applies the gradients and moves the actor target parameters with
        ``optax.incremental_update``. It returns
        ``(new_actor_train_state, {"risk": ..., "bc_loss": ...})``.
        """
        polyak_rate = self.target_update_rate
        q_weight = self.q_learning_scale
        distort = self.risk_measure
        eta = self.risk_eta

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      batch: QLearningBatch,
                      key
                      ):
            # Use the last of 32 sub-keys derived from the incoming key.
            sample_key = jax.random.split(key, 32)[-1]

            def loss_fn(params):
                n_samples = batch.actions.shape[0]
                uniform_taus = jax.random.uniform(
                    sample_key, shape=(n_samples, self.actor_sampling_quantiles))
                actions = actor_train_state.apply_fn({'params': params}, batch.observations)

                # Critic values at the risk-distorted quantile fractions.
                distorted_taus = distort(uniform_taus, eta)
                critic_variables = {'params': critic_train_state.params, }
                risk_q_values = critic_train_state.apply_fn(
                    critic_variables, batch.observations, actions, distorted_taus)

                # Average over axis -2, take the minimum over the last axis, negate.
                averaged = risk_q_values.mean(axis=-2)
                risk = -(averaged.min(axis=-1))
                # The normalizer is treated as a constant for differentiation.
                scale = jax.lax.stop_gradient(jnp.abs(risk).mean())

                squared_error = (actions - batch.actions) ** 2
                bc_loss = squared_error.mean()
                risk_term = q_weight * (risk / scale).mean()
                loss = risk_term + self.bc_weight * bc_loss

                metrics = {"risk": risk.mean(), "bc_loss": bc_loss}
                return loss, metrics

            grad_fn = jax.grad(loss_fn, has_aux=True)
            grads, loss_info = grad_fn(actor_train_state.params)
            new_state = actor_train_state.apply_gradients(grads=grads)
            new_target = optax.incremental_update(new_state.params,
                                                  actor_train_state.target_params, polyak_rate)
            new_state = new_state.replace(target_params=new_target)

            return new_state, loss_info

        return update_fn

    def __str__(self):
        return "IQN_TD3PlusBC"

    def config(self):
        cfg = {"gamma": self.gamma,
               "risk": self.risk_type,
               "risk_eta": self.risk_eta,
               "lr": self.learning_rate,
               "opt_class": self.opt_class.__name__,
               "batch_size": self.batch_size,
               "q_learning_scale": self.q_learning_scale,
               "n_quantiles": self.n_quantiles,
               "drop_per_net": self.drop_per_net,
               "n_critics": self.drop_per_net,
               }
        return cfg

    def score_metric(self, array):
        array = np.asarray(array)
        array.sort()
        if self.risk_type == 'cvar':
            return array[:int(len(array) * self.risk_eta)].mean(), array
        empirical_quantile_fn = partial(np.interp, fp=array, xp=np.linspace(0, 1, len(array)))
        linspace = np.linspace(0, 1, 10000)
        if self.risk_type == 'wang':
            return empirical_quantile_fn(wang(linspace, self.risk_eta)).mean(), array
        elif self.risk_type == 'power' or self.risk_type == 'pow':
            return empirical_quantile_fn(pow(linspace, self.risk_eta)).mean(), array
        else:
            raise NotImplementedError(f"risk measure {self.risk_type} is not implemented yet.")

    def train(self,
              epoch: int,
              len_epoch: int = 1000,
              eval_interval: int = 5,
              n_eval: int = 10,
              final_eval: int = 100,
              normalizer: Optional[Callable] = None,
              eval_pretrain: bool = False
              ) -> dict:
        """Run the training loop with periodic evaluation and a final evaluation.

        Args:
            epoch: Number of epochs; each calls ``self.epoch_learn(len_epoch)``.
            len_epoch: Argument forwarded to ``self.epoch_learn``.
            eval_interval: Evaluate on epochs whose index is a multiple of this.
            n_eval: Episodes per periodic evaluation; the risk metric is only
                printed when this is at least 50.
            final_eval: Episodes for the evaluation after the last epoch.
            normalizer: Optional callable applied to the mean and the risk
                value for printing only.
            eval_pretrain: If true, evaluate and print once before training.

        Returns:
            Dict with ``risk_type``, ``risk_eta``, ``mean``, ``std`` and
            ``risk`` of the final evaluation.
        """

        def report_plain(avg, spread):
            # Mean/std line, with the normalized mean appended when available.
            line = f"SCORE {avg} +/- {spread}"
            if normalizer is not None:
                line += f":  NORMALIZED {normalizer(avg) * 100:.2f}%"
            print(line)

        def report_with_risk(avg, spread, risk_value, trailing_newline):
            # Mean/std plus the risk metric; the un-normalized periodic report
            # carries an extra trailing newline, the final one does not.
            tag = f"{self.risk_type}@{self.risk_eta}"
            line = f"SCORE {avg} +/- {spread}:{tag}: {risk_value:.4f}"
            if normalizer is not None:
                line += (f"\n NORMALIZED {normalizer(avg) * 100:.2f}% "
                         f"{tag}: {normalizer(risk_value):.4f}")
            elif trailing_newline:
                line += "\n"
            print(line)

        if eval_pretrain:
            scores = self.evaluate(n_eval, self.env)
            mean, std = np.mean(scores), np.std(scores)
            print("PRETRAIN SCORE")
            report_plain(mean, std)

        for epoch_idx in range(epoch):
            self.epoch_learn(len_epoch)
            if epoch_idx % eval_interval != 0:
                continue
            scores = self.evaluate(n_eval, self.env)
            mean, std = np.mean(scores), np.std(scores)
            print(f"EPOCH {epoch_idx}:::")
            if n_eval >= 50:
                risk, _ = self.score_metric(scores)
                report_with_risk(mean, std, risk, trailing_newline=True)
            else:
                report_plain(mean, std)

        scores = self.evaluate(final_eval, self.env)
        mean, std = np.mean(scores), np.std(scores)
        risk, _ = self.score_metric(scores)
        report_with_risk(mean, std, risk, trailing_newline=False)

        return {"risk_type": self.risk_type, "risk_eta": self.risk_eta,
                "mean": mean, "std": std, "risk": risk}

    def train_airsim(self,
                     epoch: int,
                     len_epoch: int = 1000,
                     final_eval: int = 100,
                     ) -> dict:
        """Train for ``epoch`` epochs, then run one AirSim evaluation.

        Args:
            epoch: Number of calls to ``self.epoch_learn(len_epoch)``.
            len_epoch: Argument forwarded to ``self.epoch_learn``.
            final_eval: Number of evaluation episodes after training.

        Returns:
            Dict with ``risk_type``, ``risk_eta``, ``mean``, ``std``, ``risk``,
            ``success_ratio`` and ``collision_ratio``.
        """
        for _ in range(epoch):
            self.epoch_learn(len_epoch)

        evaluation = self.evaluate_airsim(final_eval, self.env)
        scores, success_ratio, collision_ratio = evaluation
        mean = np.mean(scores)
        std = np.std(scores)
        risk, _ = self.score_metric(scores)

        score_line = f"SCORE {mean} +/- {std}:{self.risk_type}@{self.risk_eta}: {risk:.4f} \n "
        rate_line = f"SUCCESS RATE: {success_ratio * 100:.2f}%, COLLISION RATE {collision_ratio * 100:.2f}%"
        print(score_line + rate_line)

        return {
            "risk_type": self.risk_type,
            "risk_eta": self.risk_eta,
            "mean": mean,
            "std": std,
            "risk": risk,
            "success_ratio": success_ratio,
            "collision_ratio": collision_ratio,
        }

    def evaluate_airsim(self, n_eval: int, env):
        """Roll out ``n_eval`` episodes and collect returns and outcome flags.

        Each episode is reset with a fresh seed drawn from ``self.np_rng`` and
        actions come from ``self.predict(obs, deterministic=False)``. An
        episode counts as a success when ``info['is_success']`` was truthy;
        a truthy ``info['collision']`` marks it as a collision and clears the
        success flag at that step.

        Args:
            n_eval: Number of episodes to run.
            env: Environment providing ``reset(seed=...)`` and ``step(action)``.

        Returns:
            ``(scores, success_rate, collision_rate)``: an array of episode
            returns and the mean of the per-episode success / collision flags.
        """
        episode_returns, success_flags, collision_flags = [], [], []

        for _ in trange(n_eval):
            episode_seed = self.np_rng.integers(0, 2 ** 30, size=(1,)).item()
            obs, _ = env.reset(seed=episode_seed)

            finished = False
            episode_return = 0
            reached_goal = False
            crashed = False
            while not finished:
                action = self.predict(obs, deterministic=False)
                obs, reward, terminated, truncated, info = env.step(action)

                if 'is_success' in info and info['is_success']:
                    reached_goal = True
                if 'collision' in info and info['collision']:
                    # A collision overrides any success reported so far.
                    reached_goal = False
                    crashed = True

                episode_return += reward
                finished = terminated or truncated

            episode_returns.append(episode_return)
            success_flags.append(reached_goal)
            collision_flags.append(crashed)

        success_rate = np.mean(success_flags)
        collision_rate = np.mean(collision_flags)
        return np.asarray(episode_returns), success_rate, collision_rate

