from rl.cql import CQL
from rl.policies.policy import CODACPolicy, RLTrainState, TrainState
from rl.utils.risk_utils import cvar, pow, wang, cvar_density, wang_density, get_tau, quantile_huber_loss
from rl.utils.replay_buffer import ReplayBuffer, QLearningBatch

import jax
import jax.numpy as jnp
import numpy as np
from tqdm import trange

import gymnasium as gym
from typing import Callable, Optional
from optax import adam
import optax
from functools import partial


class CODAC(CQL):
    """CQL-derived agent with a quantile-based critic and a risk-weighted actor.

    The critic update trains on quantile fractions with a quantile Huber loss
    plus a CQL-style conservative penalty; the actor update maximizes a
    density-weighted average of the critic's quantile outputs.
    NOTE(review): the name presumably stands for "Conservative Offline
    Distributional Actor-Critic" -- confirm against the project documentation.
    """
    # Policy container; assigned in ``build``.
    policy: CODACPolicy
    # Risk functions selectable through ``risk_type``; ``__init__`` stores the
    # chosen one as ``self.risk_measure``, and ``score_metric`` applies ``wang``
    # and ``pow`` directly. Note that ``pow`` is the function imported from
    # ``rl.utils.risk_utils``, which shadows the Python builtin in this module.
    risk_measures = {"cvar": cvar, "pow": pow, "power": pow, "wang": wang}
    # Quantile weighting functions used by the actor update. There is no entry
    # for "pow"/"power", so ``build_actor_update_fn`` fails for those risk types.
    densities = {"cvar": cvar_density, "wang": wang_density}
    # Number of quantile fractions sampled per observation in the actor update.
    actor_sampling_quantiles: int = 100

    def __init__(self,
                 env: gym.Env,
                 buffer: ReplayBuffer,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 opt_class: Callable = adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 risk_type: str = 'cvar',
                 risk_eta: float = 0.5,
                 n_quantiles: int = 16,
                 target_update_rate: float = 0.005,
                 target_entropy: float | str = 'auto',
                 learn_ent_coef: bool = False,
                 ent_coef: float = 0.2,
                 target_diff: float = 10,
                 auto_adjust_kl: bool = False,
                 seed: int = 42,
                 ):
        """Store the risk configuration and delegate everything else to the parent.

        Args:
            risk_type: Key into ``risk_measures`` (a ``KeyError`` is raised for
                an unknown name).
            risk_eta: Parameter handed to the risk measure / density functions.
            n_quantiles: Number of quantile fractions sampled per transition in
                the critic update.

        All remaining arguments are forwarded unchanged, by keyword, to the
        parent constructor.
        """
        # The risk settings are assigned before the parent constructor runs.
        # NOTE(review): presumably because the parent constructor triggers
        # ``build()``, which reads them -- confirm against ``CQL.__init__``.
        self.risk_type: str = risk_type
        self.risk_measure: Callable = self.risk_measures[risk_type]
        self.risk_eta = risk_eta
        self.n_quantiles = n_quantiles

        parent_kwargs = {
            "env": env,
            "buffer": buffer,
            "gamma": gamma,
            "batch_size": batch_size,
            "opt_class": opt_class,
            "learning_rate": learning_rate,
            "n_critics": n_critics,
            "target_update_rate": target_update_rate,
            "target_entropy": target_entropy,
            "learn_ent_coef": learn_ent_coef,
            "ent_coef": ent_coef,
            "target_diff": target_diff,
            "auto_adjust_kl": auto_adjust_kl,
            "seed": seed,
        }
        super().__init__(**parent_kwargs)

    def build(self):
        """Instantiate ``self.policy`` together with all of its jitted update functions."""
        observation_ph, action_ph = self.make_placeholder()
        policy_seed = next(self.hk_rng)

        # Update functions are built in the same order the original call listed them.
        update_fns = {
            "critic_update_fn": self.build_critic_update_fn(),
            "actor_update_fn": self.build_actor_update_fn(),
            "ent_coef_update_fn": self.build_ent_coef_update_fn(self.target_diff),
            "lag_update_fn": self.build_lag_coef_update_fn(self.target_diff),
        }

        self.policy = CODACPolicy(
            observation_ph,
            action_ph,
            opt_class=self.opt_class,
            learning_rate=self.learning_rate,
            seed=policy_seed,
            n_critics=self.n_critics,
            auto_adjust_kl=self.auto_adjust_kl,
            ent_coef=self.ent_coef,
            learn_ent_coef=self.learn_ent_coef,
            **update_fns,
        )

    def build_critic_update_fn(self):
        """Build the jitted critic update step.

        Returns:
            A function ``update_fn(critic_train_state, actor_train_state,
            ent_coef_state, lag_coef_state, batch, key)`` that performs one
            gradient step on the critic parameters, soft-updates the critic
            target parameters with ``target_update_rate`` and returns
            ``(new_critic_train_state, loss_info)``. ``loss_info`` holds the
            entries ``"td"``, ``"cql_loss"`` and ``"kl_potential"``.

        The loss is the sum of a quantile Huber TD loss and a conservative
        penalty weighted by the (gradient-stopped) Lagrange coefficient.

        Every random draw uses its own sub-key of ``key``; a PRNG key is never
        consumed twice, and the parent key is not used again after the split.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate
        n_quantiles = self.n_quantiles
        # Number of actions sampled per observation for the conservative term.
        n_action_samples = 10

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      ent_coef_state: TrainState,
                      lag_coef_state: TrainState,
                      batch: QLearningBatch, key):
            def loss_fn(param_critic):
                batch_size = batch.observations.shape[0]

                # One independent key per source of randomness.
                (tau_key, target_action_key, target_tau_key, cur_action_key,
                 next_action_key, rand_action_key, perm_key) = jax.random.split(key, 7)
                _, tau_hat, _ = get_tau(tau_key, shape=(batch_size, n_quantiles))

                # current_z shape = (Batch, n_quantiles, n_critics) per the original author's note
                current_z = critic_train_state.apply_fn({"params": param_critic},
                                                        batch.observations, batch.actions, tau_hat)
                ent_coef = jax.lax.stop_gradient(jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params)))
                cql_coef = jax.lax.stop_gradient(jnp.exp(lag_coef_state.apply_fn(lag_coef_state.params)).clip(0, 1e+6))

                ################ compute td_target ##################
                next_actions, next_log_probs = actor_train_state.apply_fn({"params": actor_train_state.target_params},
                                                                          batch.next_observations,
                                                                          rngs={"rng_stream": target_action_key})
                _, next_tau_hat, next_presum_tau = get_tau(target_tau_key, shape=(batch_size, n_quantiles))
                # Minimum over the critic axis -> (batch, n_quantiles)
                next_z = critic_train_state.apply_fn({"params": critic_train_state.target_params},
                                                     batch.next_observations, next_actions, next_tau_hat,
                                                     ).min(axis=-1)
                next_z = next_z - ent_coef * next_log_probs
                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_z)
                ##########################################################################
                td_loss_fn = partial(quantile_huber_loss, target=td_target, taus=tau_hat, weight=next_presum_tau)
                # Apply the loss separately to each critic (last axis of current_z).
                td_error = jax.vmap(td_loss_fn, in_axes=-1, out_axes=-1)(current_z)
                td_error = td_error.sum(axis=-1).mean()

                ############ compute KL-divergence via Donsker-Varadhan Representation ###
                ############# First We Compute log E_(X~Q)[exp(g(X)] #####################
                # Keep a single, randomly chosen quantile fraction per row for the penalty term.
                cql_tau = jax.random.permutation(perm_key, tau_hat, axis=-1)[..., :1]

                repeated_observation = jnp.repeat(batch.observations[:, None], axis=1,
                                                  repeats=n_action_samples)
                repeated_next_observation = jnp.repeat(batch.next_observations[:, None], axis=1,
                                                       repeats=n_action_samples)

                def act_fn(obs, k):
                    # Actions sampled from the current policy; no gradient flows into the actor.
                    return jax.lax.stop_gradient(actor_train_state.apply_fn({"params": actor_train_state.params},
                                                                            obs, rngs={"rng_stream": k}))

                # Map over the repeat axis (axis 1 of the observations, axis 0 of the keys).
                act_fn = jax.vmap(act_fn, in_axes=(1, 0), out_axes=(1, 1))
                current_action, current_log_pi = act_fn(
                    repeated_observation, jax.random.split(cur_action_key, n_action_samples))
                next_action, next_log_pi = act_fn(
                    repeated_next_observation, jax.random.split(next_action_key, n_action_samples))
                # Uniform random actions in [-1, 1).
                rand_action = 2 * jax.random.uniform(rand_action_key, shape=current_action.shape) - 1

                def _critic_fn(obs, action):
                    zfs = critic_train_state.apply_fn({"params": param_critic}, obs, action, cql_tau)
                    return zfs.reshape(zfs.shape[0], -1)

                qf_fn = jax.vmap(_critic_fn, in_axes=(1, 1), out_axes=1)
                # (Batch, N_repeat, N_critics) because we choose tau shape (N, 1).
                q_cur = qf_fn(repeated_observation, current_action)
                q_cur = q_cur - current_log_pi

                # Next-state policy actions are evaluated at the *current* observations.
                q_next = qf_fn(repeated_observation, next_action) - next_log_pi
                # log-density of the uniform proposal on [-1, 1]^d is d * log(0.5).
                q_rand = qf_fn(repeated_observation, rand_action) - jnp.log(0.5) * rand_action.shape[-1]
                cat_q = jnp.concatenate([q_cur, q_next, q_rand], axis=1)
                logsumexpQ = jax.nn.logsumexp(cat_q, axis=1)

                ############## D_{KL}[P, Q] = sup_g E_P[g(X)] - log E_(X~Q)[exp(g(X)] #######
                kl_potential = (logsumexpQ.mean(axis=0) - current_z.mean(axis=(0, 1)))
                ##################################################
                cql_loss = cql_coef * kl_potential.sum(axis=-1)
                loss = td_error + cql_loss
                return loss, {"td": td_error, "cql_loss": cql_loss, "kl_potential": kl_potential.mean()}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            # Already inside a jitted function, so no nested jit is needed for the soft update.
            critic_target_params = optax.incremental_update(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info

        return update_fn

    def build_actor_update_fn(self):
        """Build the jitted actor update step.

        Returns:
            A function ``update_fn(critic_train_state, actor_train_state,
            ent_coef_state, batch, key)`` that performs one gradient step on the
            actor parameters, soft-updates the actor target parameters with
            ``target_update_rate`` and returns ``(new_actor_train_state,
            loss_info)`` where ``loss_info`` holds ``"pi_loss"`` and
            ``"log_probs"``.

        Raises:
            KeyError: if ``self.risk_type`` has no entry in ``self.densities``
                (the table has no density for "pow"/"power").

        The risk density, ``risk_eta`` and the number of sampled quantile
        fractions are captured when this builder runs.
        """
        target_update_rate = self.target_update_rate
        risk_eta = self.risk_eta
        n_quantiles = self.actor_sampling_quantiles
        try:
            density = self.densities[self.risk_type]
        except KeyError as err:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                f"no quantile density registered for risk type {self.risk_type!r}; "
                f"available: {sorted(self.densities)}"
            ) from err

        # Local sampler (distinct from the module-level ``get_tau`` import):
        # draws random positive increments that sum to one along the last axis.
        @partial(jax.jit, static_argnums=(1,))
        def sample_fractions(key, shape):
            presum_tau = jax.random.uniform(key, shape) + 0.1
            presum_tau /= presum_tau.sum(axis=-1, keepdims=True)
            tau = jnp.cumsum(presum_tau, axis=-1)  # (N, T), note that they are tau1...tauN in the paper
            # Midpoints between consecutive fractions; the first midpoint is tau_1 / 2.
            tau_hat = 0.5 * (tau[..., 1:] + tau[..., :-1])
            tau_hat = jnp.concatenate([tau[..., 0:1] / 2, tau_hat], axis=-1)
            return tau, tau_hat, presum_tau

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      ent_coef_state: TrainState,
                      batch: QLearningBatch,
                      key: jax.Array,
                      ):
            def loss_fn(params):
                action_key, tau_key = jax.random.split(key, 2)
                actions, log_probs = actor_train_state.apply_fn({'params': params}, batch.observations,
                                                                rngs={"rng_stream": action_key})
                batch_size = actions.shape[0]
                # Shape is passed positionally so it matches ``static_argnums=(1,)``.
                _, taus_hat, presum_tau = sample_fractions(tau_key, (batch_size, n_quantiles))

                q_values = critic_train_state.apply_fn({'params': critic_train_state.params},
                                                       batch.observations, actions, taus_hat)
                weight = jax.lax.stop_gradient(density(taus_hat, risk_eta))

                def value(q_value):
                    # Density-weighted sum over the quantile axis.
                    return (q_value * weight * presum_tau).sum(axis=-1)

                # Apply ``value`` to each critic (last axis) independently.
                q_values = jax.vmap(value, in_axes=-1, out_axes=-1)(q_values)
                ent_coef = jax.lax.stop_gradient(jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params)))
                # Use the smallest critic value for each sample.
                loss = ent_coef * log_probs - q_values.min(axis=-1, keepdims=True)
                loss = loss.mean()
                return loss, {"pi_loss": loss, "log_probs": log_probs}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            # Already inside a jitted function, so no nested jit is needed for the soft update.
            actor_target_param = optax.incremental_update(
                state.params, actor_train_state.target_params, target_update_rate)
            state = state.replace(target_params=actor_target_param)
            return state, loss_info

        return update_fn

    def score_metric(self, array):
        """Compute the configured risk metric over a set of episode scores.

        Args:
            array: Array-like of scores. It is not modified; a sorted copy is
                used internally.

        Returns:
            ``(risk_value, sorted_scores)`` where ``sorted_scores`` is an
            ascending copy of the input.

        Raises:
            NotImplementedError: if ``self.risk_type`` is not one of
                ``'cvar'``, ``'wang'``, ``'power'`` or ``'pow'``.
        """
        # np.sort returns a sorted copy, so an ndarray passed by the caller is
        # left untouched (np.asarray + in-place sort would reorder it).
        array = np.sort(np.asarray(array))
        if self.risk_type == 'cvar':
            # Average of the lowest ``risk_eta`` fraction of scores; keep at
            # least one sample so a small sample count does not yield NaN.
            n_tail = max(1, int(len(array) * self.risk_eta))
            return array[:n_tail].mean(), array

        if self.risk_type == 'wang':
            distortion = wang
        elif self.risk_type in ('power', 'pow'):
            distortion = pow  # the risk_utils function, not the builtin
        else:
            raise NotImplementedError(f"risk measure {self.risk_type} is not implemented yet.")

        # Evaluate the empirical quantile function (linear interpolation over
        # the sorted scores) at the distorted fractions and average the result.
        linspace = np.linspace(0, 1, 10000)
        quantile_levels = np.linspace(0, 1, len(array))
        distorted = distortion(linspace, self.risk_eta)
        return np.interp(distorted, quantile_levels, array).mean(), array

    def train(self,
              epoch: int,
              len_epoch: int = 1000,
              eval_interval: int = 5,
              n_eval: int = 10,
              final_eval: int = 100,
              normalizer: Optional[Callable] = None,
              ) -> dict:
        """Run the training loop with periodic evaluation and a final evaluation.

        Args:
            epoch: Number of calls to ``epoch_learn``.
            len_epoch: Argument forwarded to each ``epoch_learn`` call.
            eval_interval: Evaluate after every epoch whose index is a multiple
                of this value.
            n_eval: Episodes per intermediate evaluation; the risk metric is
                only printed when this is at least 50.
            final_eval: Episodes used for the final evaluation.
            normalizer: Optional callable applied to scores for the
                "NORMALIZED" part of the printed output.

        Returns:
            Dict with ``risk_type``, ``risk_eta``, ``mean``, ``std`` and
            ``risk`` of the final evaluation.
        """

        def summarize(episode_scores):
            return np.mean(episode_scores), np.std(episode_scores)

        def print_score(mean, std):
            # Plain score line, with the normalized mean when a normalizer is given.
            if normalizer is None:
                print(f"SCORE {mean} +/- {std}")
            else:
                print(f"SCORE {mean} +/- {std}:  NORMALIZED {normalizer(mean) * 100:.2f}%")

        # Baseline evaluation before any learning step.
        mean, std = summarize(self.evaluate(n_eval, self.env))
        print("PRETRAIN SCORE")
        print_score(mean, std)

        for e in range(epoch):
            self.epoch_learn(len_epoch)
            if e % eval_interval != 0:
                continue

            scores = self.evaluate(n_eval, self.env)
            mean, std = summarize(scores)
            print(f"EPOCH {e}:::")
            if n_eval < 50:
                print_score(mean, std)
                continue
            risk, _ = self.score_metric(scores)
            print(f"SCORE {mean} +/- {std}:{self.risk_type}@{self.risk_eta}: {risk:.4f}\n")

        # Final evaluation always reports the risk metric.
        scores = self.evaluate(final_eval, self.env)
        mean, std = summarize(scores)
        risk, _ = self.score_metric(scores)

        if normalizer is None:
            print(f"SCORE {mean} +/- {std}:{self.risk_type}@{self.risk_eta}: {risk:.4f}")
        else:
            print(f"SCORE {mean} +/- {std}:{self.risk_type}@{self.risk_eta}: {risk:.4f}\n"
                  f" NORMALIZED {normalizer(mean) * 100:.2f}% {self.risk_type}@{self.risk_eta}: {normalizer(risk):.4f}")

        return {
            "risk_type": self.risk_type,
            "risk_eta": self.risk_eta,
            "mean": mean,
            "std": std,
            "risk": risk,
        }

    def train_airsim(self,
                     epoch: int,
                     len_epoch: int = 1000,
                     final_eval: int = 100,
                     ) -> dict:
        """Train for ``epoch`` epochs, then run a single AirSim evaluation.

        No intermediate evaluation is performed. The returned dict holds
        ``risk_type``, ``risk_eta``, ``mean``, ``std``, ``risk``,
        ``success_ratio`` and ``collision_ratio`` of the final evaluation.
        """
        for _ in range(epoch):
            self.epoch_learn(len_epoch)

        scores, success_ratio, collision_ratio = self.evaluate_airsim(final_eval, self.env)
        mean = np.mean(scores)
        std = np.std(scores)
        risk, _ = self.score_metric(scores)

        print(f"SCORE {mean} +/- {std}:{self.risk_type}@{self.risk_eta}: {risk:.4f} \n "
              f"SUCCESS RATE: {success_ratio * 100:.2f}%, COLLISION RATE {collision_ratio * 100:.2f}%")

        return {
            "risk_type": self.risk_type,
            "risk_eta": self.risk_eta,
            "mean": mean,
            "std": std,
            "risk": risk,
            "success_ratio": success_ratio,
            "collision_ratio": collision_ratio,
        }

    def evaluate_airsim(self, n_eval: int, env):
        """Roll out ``n_eval`` episodes and collect returns plus success/collision flags.

        Each episode is reset with a seed drawn from ``self.np_rng`` and actions
        come from ``self.predict(obs, deterministic=False)``.

        Returns:
            ``(scores, success_rate, collision_rate)``: an array of episode
            returns and the mean of the per-episode success and collision flags.
        """
        episode_returns, success_flags, collision_flags = [], [], []
        for _ in trange(n_eval):
            episode_seed = self.np_rng.integers(0, 2 ** 30, size=(1,)).item()
            obs, _ = env.reset(seed=episode_seed)

            total_reward = 0
            succeeded = False
            collided = False
            finished = False
            while not finished:
                action = self.predict(obs, deterministic=False)
                obs, reward, terminated, truncated, info = env.step(action)

                if info.get('is_success'):
                    succeeded = True
                if info.get('collision'):
                    # A collision overrides any success reported so far.
                    succeeded = False
                    collided = True

                total_reward += reward
                finished = terminated or truncated

            episode_returns.append(total_reward)
            success_flags.append(succeeded)
            collision_flags.append(collided)

        return np.asarray(episode_returns), np.mean(success_flags), np.mean(collision_flags)

