""" LCD (Language Conditioned Diffusion Implementaiton """
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from sb3_jax.common.policies import BasePolicy
from sb3_jax.common.norm_layers import BaseNormLayer
from sb3_jax.common.jax_layers import BaseFeaturesExtractor, FlattenExtractor
from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.utils import get_dummy_obs, get_dummy_act
from sb3_jax.du.policies import DiffusionBetaScheduler

from diffgro.utils.utils import print_b 
from diffgro.common.models.utils import apply_cond
from diffgro.common.models.helpers import MLP
from diffgro.common.models.diffusion import UNetDiffusion, Diffusion
from diffgro.diffgro.functions import calculate_grad


class HighActor(BasePolicy):
    """ High-Level actor for lcd policy

    Wraps a ``Diffusion`` module (UNet backbone, built through ``hk.transform``) that
    works on arrays of shape ``(batch, horizon, obs_dim)``. The network is conditioned
    on the ``"lang"`` embedding and, when ``domain`` is not ``'short'``, also on the
    ``"skill"`` embedding.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        horizon: int = 8,       # planning horizon
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        cf_weight: float = 1.0, # classifier-free guidance weight
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        seed: int = 1,
    ):
        """Store the hyper-parameters, build the beta schedule and initialise parameters.

        Args:
            observation_space: Space whose flattened size becomes ``obs_dim``/``out_dim``.
            action_space: Space whose size becomes ``act_dim``.
            net_arch: Stored on the instance; not read by the UNet built in this class.
            activation_fn: Activation name forwarded to ``UNetDiffusion``.
            domain: ``'short'`` conditions on ``"lang"`` only; any other value
                conditions on ``"lang"`` and ``"skill"``.
            horizon: Number of stacked steps in the denoised array.
            skill_dim: Width of the dummy language/skill embeddings used for init.
            emb_dim: Embedding width forwarded to ``UNetDiffusion``.
            n_denoise: Number of denoising steps.
            cf_weight: Forwarded to ``Diffusion`` as ``guidance_weight``.
            predict_epsilon: Selects which update formula ``_sample`` applies.
            beta_scheduler: Scheduler name forwarded to ``DiffusionBetaScheduler``.
            seed: Seed forwarded to ``BasePolicy``.
        """
        super(HighActor, self).__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        self.net_arch = net_arch
        self.activation_fn = activation_fn

        self.domain = domain
        print_b(f"Setting doamin as {self.domain}")
        # embedding
        self.horizon = horizon
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim

        # diffusion
        self.n_denoise = n_denoise
        self.cf_weight = cf_weight
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler
        self.ddpm_dict = DiffusionBetaScheduler(None, None, n_denoise, beta_scheduler).schedule()

        # misc
        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)
        self.out_dim = self.obs_dim  # the diffusion output has the observation width

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the constructor parameters reported by the parent class, unchanged."""
        data = super()._get_constructor_parameters()
        return data

    def _build_hact(self, batch_keys: List[str]) -> hk.Module:
        """Create the diffusion module; must be called inside an ``hk.transform`` scope.

        Args:
            batch_keys: Names of the conditioning entries forwarded to ``UNetDiffusion``.

        Returns:
            A ``Diffusion`` module wrapping a ``UNetDiffusion`` backbone.
        """
        # diffusion high-level actor
        # NOTE(review): dim_mults is fixed here and self.net_arch is not consulted.
        unet = UNetDiffusion(
            horizon=self.horizon,
            emb_dim=self.emb_dim,
            out_dim=self.out_dim,
            dim_mults=(1,4,8),
            attention=False,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn
        )
        return Diffusion(
            diffusion=unet,
            n_denoise=self.n_denoise,
            ddpm_dict=self.ddpm_dict,
            guidance_weight=self.cf_weight,
            predict_epsilon=self.predict_epsilon,
            denoise_type='ddpm',
        )

    def _build(self) -> None:
        """Transform the network with Haiku and initialise ``self.params`` from dummy inputs."""
        # dummy inputs
        # NOTE(review): dummy_act is not used below; the call is kept because its side
        # effects (if any) are not visible from this file.
        dummy_obs, dummy_act = get_dummy_obs(self.observation_space), get_dummy_act(self.action_space)
        dummy_obs_stack = jnp.repeat(dummy_obs, self.horizon, axis=0).reshape(1, self.horizon, -1) # stacked observation
        dummy_lang = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))  # language embedding
        dummy_skill = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))  # skill embedding
        dummy_t = jnp.array([[1.]])

        def fn_hact(x_t: jax.Array, batch_dict: Dict[str, jax.Array], t: jax.Array, denoise: bool, deterministic: bool):
            hact = self._build_hact(batch_keys=["lang"] if self.domain == 'short' else ["lang", "skill"])
            return hact(x_t, batch_dict, t, denoise, deterministic)
        # hk.transform yields (init, apply); apply is kept as self.hpi
        params, self.hpi = hk.transform(fn_hact)
        batch_dict = {"lang": dummy_lang, "skill": dummy_skill}
        self.params = params(next(self.rng), dummy_obs_stack, batch_dict, dummy_t, denoise=False, deterministic=False)

    @partial(jax.jit, static_argnums=(0,4,5))
    def _hpi(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, jax.Array],
        t: jax.Array,
        denoise: bool,
        deterministic: bool,
        params: hk.Params,
        rng=None
    ) -> Tuple[Tuple[jax.Array], Dict[str, jax.Array]]:
        """Jitted call of the transformed network.

        ``self``, ``denoise`` and ``deterministic`` are static arguments for ``jax.jit``.
        """
        # return: eps, info
        return self.hpi(params, rng, x_t, batch_dict, t, denoise, deterministic)

    def _predict(
        self,
        x_t: jax.Array,
        lang: jax.Array,
        t: int,
        skill: jax.Array,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Run the network once at denoising step ``t`` with the current parameters.

        Args:
            x_t: Current noisy array; its leading axis is used as the batch size.
            lang: Language embedding stored under ``"lang"``.
            t: Denoising step, broadcast to a ``(batch, 1)`` array.
            skill: Skill embedding stored under ``"skill"``.
            deterministic: Forwarded to the network.

        Returns:
            The network output and its info dictionary.
        """
        # return: eps, info
        batch_dict = {"lang": lang, "skill": skill}
        ts = jnp.full((x_t.shape[0], 1), t)
        eps, info = self._hpi(x_t, batch_dict, ts, False, deterministic, self.params, next(self.rng))
        return eps, info

    @partial(jax.jit, static_argnums=(0,4))
    def _sample(
        self,
        x_t: jax.Array,
        eps: jax.Array,
        t: int,
        deterministic: bool,
        rng=None,
    ) -> jax.Array:
        """Apply one reverse-diffusion update to ``x_t``.

        ``self`` and ``deterministic`` are static arguments for ``jax.jit``; ``t`` is
        used to index the arrays held in ``self.ddpm_dict``.

        Args:
            x_t: Current noisy array.
            eps: Network output for this step (possibly shifted by guidance).
            t: Denoising step index.
            deterministic: When True no noise is added and ``rng`` is not used.
            rng: PRNG key for the added noise.

        Returns:
            The updated array.
        """
        batch_size = x_t.shape[0]
        noise = jax.random.normal(rng, shape=(batch_size, self.horizon, self.out_dim)) if not deterministic else 0.

        if self.predict_epsilon:
            x_t = self.ddpm_dict.oneover_sqrta[t] * (x_t - self.ddpm_dict.ma_over_sqrtmab_inv[t] * eps) \
                    + self.ddpm_dict.sqrt_beta_t[t] * noise
        else:
            x_t = self.ddpm_dict.posterior_mean_coef1[t] * eps + self.ddpm_dict.posterior_mean_coef2[t] * x_t \
                    + jnp.exp(0.5 * self.ddpm_dict.posterior_log_beta[t]) * noise
        return x_t

    @staticmethod
    def _grad_scale_exponent(loss) -> int:
        """Return the power of ten that moves ``|loss|`` into ``[0.1, 1.0]``.

        A zero loss yields 0.

        Raises:
            ValueError: If ``loss`` is NaN or infinite. No power of ten can normalise
                such a value, and the rescaling loop would otherwise never terminate.
        """
        if loss < 0.0:
            loss = -loss
        if not jnp.isfinite(loss):
            raise ValueError(f"guide loss must be finite to scale the gradient, got {loss}")

        exponent = 0
        while not ((loss <= 1.0 and loss >= 0.1) or loss == 0.0):
            if loss > 1.0:
                loss /= 10
                exponent -= 1
            if loss < 0.1:
                loss *= 10
                exponent += 1
        return exponent

    def _denoise(
        self,
        cond: jax.Array,
        lang: jax.Array,
        skill: Optional[jax.Array] = None,
        delta: float = 0.1,
        guide_fn: Optional[Callable] = None,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Run the full reverse-diffusion loop, optionally steered by ``guide_fn``.

        Starts from Gaussian noise, re-applies ``cond`` via ``apply_cond`` after every
        update and iterates ``t`` from ``n_denoise`` down to 1.

        Args:
            cond: Conditioning array; its leading axis is used as the batch size.
            lang: Language embedding.
            skill: Skill embedding (may be None).
            delta: Step size applied to the scaled guidance gradient.
            guide_fn: Optional guidance function handed to ``calculate_grad``; used
                only for steps with ``t <= n_denoise - 2``.
            deterministic: Forwarded to the network and to ``_sample``.

        Returns:
            The denoised array and an empty info dictionary.

        Raises:
            ValueError: If the guidance loss is NaN or infinite.
        """
        batch_size = cond.shape[0]

        x_t = jax.random.normal(next(self.rng), shape=(batch_size, self.horizon, self.out_dim))
        x_t = apply_cond(x_t, cond)

        for t in range(self.n_denoise, 0, -1):
            eps, _ = self._predict(x_t, lang, t, skill, deterministic)

            if (guide_fn is not None) and (t <= self.n_denoise - 2):
                # calculate gradient
                grad, grad_info = calculate_grad(guide_fn, eps, self.obs_dim)
                # gradient scaling: multiply by the power of ten that normalises the loss.
                # A float base keeps the factor a Python float; a large Python int factor
                # may not be representable as a 32-bit integer when combined with a JAX array.
                exponent = self._grad_scale_exponent(grad_info['loss'])
                grad = grad * (10.0 ** exponent)
                eps = eps - delta * grad

            x_t = self._sample(x_t, eps, t, deterministic, next(self.rng))
            x_t = apply_cond(x_t, cond)
        return x_t, {}

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Replace ``self.params`` with the ``"hpi_params"`` entry of ``params``."""
        print_b("[lcd/hactor]: loading params")
        self.params = params["hpi_params"]


class LowActor(BasePolicy):
    """ Low-Level actor for lcd policy

    An ``MLP`` (built through ``hk.transform``) whose output width is ``act_dim``.
    It reads ``"obs_0"`` and ``"obs_1"`` from the batch dictionary and, when
    ``domain`` is not ``'short'``, ``"skill"`` as well.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        seed: int = 1,
    ):
        """Record the hyper-parameters and initialise the network parameters."""
        super(LowActor, self).__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        # network shape
        self.net_arch, self.activation_fn = net_arch, activation_fn

        # conditioning
        self.domain = domain
        self.skill_dim, self.emb_dim = skill_dim, emb_dim

        # sizes derived from the spaces
        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the parent's constructor parameters without additions."""
        return super()._get_constructor_parameters()

    def _build_lact(self, batch_keys: Dict[str, jax.Array]) -> hk.Module:
        """Create the MLP; must run inside an ``hk.transform`` scope."""
        mlp_config = dict(
            emb_dim=self.emb_dim,
            out_dim=self.act_dim,
            net_arch=self.net_arch,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn,
            squash_output=False,
        )
        return MLP(**mlp_config)

    def _build(self) -> None:
        """Transform the MLP with Haiku and initialise ``self.params`` from dummy inputs."""
        obs_sample = get_dummy_obs(self.observation_space)
        # The result is not used below; the call is kept so the same statements run.
        _act_sample = get_dummy_act(self.action_space)
        skill_sample = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))

        def fn_lact(batch_dict: Dict[str, jax.Array]):
            keys = ["obs_0", "obs_1"]
            if self.domain != 'short':
                keys.append("skill")
            return self._build_lact(batch_keys=keys)(batch_dict)

        transformed = hk.transform(fn_lact)
        self.lpi = transformed.apply
        init_batch = {"obs_0": obs_sample, "obs_1": obs_sample, "skill": skill_sample}
        self.params = transformed.init(next(self.rng), init_batch)

    @partial(jax.jit, static_argnums=(0,))
    def _lpi(self, batch_dict: Dict[str, jax.Array], params: hk.Params, rng=None) -> jax.Array:
        """Jitted call of the transformed MLP (``self`` is a static argument)."""
        output = self.lpi(params, rng, batch_dict)
        return output

    def _predict(self, obs_0: jax.Array, obs_1: jax.Array, skill: jax.Array) -> jax.Array:
        """Evaluate the MLP on the two observations and the skill embedding."""
        inputs = dict(obs_0=obs_0, obs_1=obs_1, skill=skill)
        step_key = next(self.rng)
        return self._lpi(inputs, self.params, step_key)

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Replace ``self.params`` with the ``"lpi_params"`` entry of ``params``."""
        print_b("[lcd/lactor]: loading params")
        self.params = params["lpi_params"]


class LCDPlannerPolicy(BasePolicy):
    """ policy class for lcd

    Owns a ``HighActor`` (``self.hact``) and a ``LowActor`` (``self.lact``), each
    with its own optimizer and optimizer state.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Dict[str, Any]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        horizon: int = 8,       # planning horizon
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        cf_weight: float = 1.0, # diffusion classifier-free guidance weight
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        # others
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable = optax.adamw,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        normalization_class: Type[BaseNormLayer] = None,
        normalization_kwargs: Optional[Dict[str, Any]] = None,
        seed: int = 1,
    ):
        """Store the hyper-parameters, assemble the per-actor kwargs and build both actors.

        Args:
            observation_space: Observation space shared by both actors.
            action_space: Action space shared by both actors.
            lr_schedule: Passed to ``optimizer_class`` as ``learning_rate``.
            net_arch: Mapping with keys ``'hact'`` and ``'lact'``; defaults to
                ``dict(hact=(1,4,8), lact=[128,128])``.
            activation_fn: Activation name forwarded to both actors.
            domain: Either ``'short'`` or ``'long'``.
            horizon, skill_dim, emb_dim: Embedding settings forwarded to the actors
                (``horizon`` only to the high-level actor).
            n_denoise, cf_weight, predict_epsilon, beta_scheduler: Diffusion settings
                forwarded to the high-level actor.
            squash_output, features_extractor_class, features_extractor_kwargs,
            optimizer_class, optimizer_kwargs, normalization_class,
            normalization_kwargs, seed: Forwarded to ``BasePolicy``.
            normalize_images: Accepted for interface compatibility; not read here.
        """
        super(LCDPlannerPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalization_class=normalization_class,
            normalization_kwargs=normalization_kwargs,
            squash_output=squash_output,
            seed=seed,
        )

        if net_arch is None:
            net_arch = dict(hact=(1,4,8), lact=[128,128])
        self.hact_arch, self.lact_arch = net_arch['hact'], net_arch['lact']
        self.activation_fn = activation_fn

        self.domain = domain
        assert self.domain in ['short', 'long'], 'Domain should be either short or long'
        self.horizon = horizon
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim

        self.n_denoise = n_denoise
        self.cf_weight = cf_weight
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler

        # constructor args shared by both actors
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "activation_fn": self.activation_fn,
            "domain": domain,
            "seed": seed,
        }

        # hactor kwargs
        self.hact_kwargs = self.net_args.copy()
        self.hact_kwargs.update({
            "net_arch": self.hact_arch,
            "horizon": horizon,
            "skill_dim": skill_dim,
            "emb_dim": emb_dim,
            "n_denoise": n_denoise,
            "cf_weight": cf_weight,
            "predict_epsilon": predict_epsilon,
            "beta_scheduler": beta_scheduler,
        })

        # lactor kwargs
        self.lact_kwargs = self.net_args.copy()
        self.lact_kwargs.update({
            "net_arch": self.lact_arch,
            "skill_dim": skill_dim,
            "emb_dim": emb_dim,
        })

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments needed to rebuild this policy.

        ``net_arch``, ``activation_fn`` and ``domain`` are included so that a policy
        recreated from this dictionary does not fall back to the constructor defaults.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                observation_space=self.observation_space,
                action_space=self.action_space,
                net_arch=dict(hact=self.hact_arch, lact=self.lact_arch),
                activation_fn=self.activation_fn,
                domain=self.domain,
                horizon=self.horizon,
                skill_dim=self.skill_dim,
                emb_dim=self.emb_dim,
                n_denoise=self.n_denoise,
                cf_weight=self.cf_weight,
                predict_epsilon=self.predict_epsilon,
                beta_scheduler=self.beta_scheduler,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                normalization_class=self.normalization_class,
                normalization_kwargs=self.normalization_kwargs,
            )
        )
        return data

    def _attach_optimizer(self, actor: BasePolicy, lr_schedule: Schedule) -> None:
        """Create an optimizer for ``actor`` and initialise its state from ``actor.params``."""
        # The kwargs default to None in the constructor; unpack an empty dict in that case.
        actor.optim = self.optimizer_class(learning_rate=lr_schedule, **(self.optimizer_kwargs or {}))
        actor.optim_state = actor.optim.init(actor.params)

    def _build(self, lr_schedule: Schedule) -> None:
        """Create the optional normalization layer, then both actors with their optimizers."""
        if self.normalization_class is not None:
            self.normalization_layer = self.normalization_class(
                self.observation_space.shape, **(self.normalization_kwargs or {})
            )

        # make high-level actor
        self.hact = self.make_hact()
        self._attach_optimizer(self.hact, lr_schedule)

        # make low-level actor
        self.lact = self.make_lact()
        self._attach_optimizer(self.lact, lr_schedule)

    def make_hact(self) -> HighActor:
        """Instantiate the high-level actor from ``self.hact_kwargs``."""
        return HighActor(**self.hact_kwargs)

    def make_lact(self) -> LowActor:
        """Instantiate the low-level actor from ``self.lact_kwargs``."""
        return LowActor(**self.lact_kwargs)

    def _predict_hact(
        self,
        cond: jax.Array, # observation
        lang: jax.Array,
        skill: Optional[jax.Array] = None,
        delta: float = 0.1,
        guide_fn: Optional[Callable] = None,
        deterministic: bool = True,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Preprocess ``cond`` and run the high-level actor's denoising loop.

        Args:
            cond: Observation used as the conditioning array.
            lang: Language embedding.
            skill: Skill embedding (may be None).
            delta: Guidance step size forwarded to ``HighActor._denoise``.
            guide_fn: Optional guidance function forwarded to ``HighActor._denoise``.
            deterministic: Forwarded to ``HighActor._denoise``.

        Returns:
            The denoised array and the info dictionary from ``HighActor._denoise``.
        """
        cond = self.preprocess(cond, training=False)
        return self.hact._denoise(cond, lang, skill, delta, guide_fn, deterministic)

    def _predict_lact(
        self,
        obs_0: jax.Array,
        obs_1: jax.Array,
        skill: jax.Array,
    ) -> jax.Array:
        """Preprocess ``obs_0`` and evaluate the low-level actor.

        Only ``obs_0`` goes through ``self.preprocess``; ``obs_1`` is passed on as given.
        """
        obs_0 = self.preprocess(obs_0, training=False)
        # NOTE(review): obs_1 is deliberately left unprocessed here (the preprocessing
        # call for it was disabled) — confirm callers supply it already preprocessed.
        return self.lact._predict(obs_0, obs_1, skill)

    def _predict(self,):
        """Not supported; use ``_predict_hact`` / ``_predict_lact`` instead."""
        raise NotImplementedError
