from token import OP
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np

from ..argument_utility import (
    ActionScalerArg,
    EncoderArg,
    DropQFuncArg,
    RewardScalerArg,
    ScalerArg,
    UseGPUArg,
    check_encoder,
    check_drop_q_func,
    check_use_gpu,
)
from ..constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ..dataset import TransitionMiniBatch
from ..gpu import Device
from ..models.encoders import EncoderFactory
from ..models.optimizers import AdamFactory, OptimizerFactory
from ..models.q_functions import DropQFunctionFactory
from .base import DropAlgoBase
from .torch.drop_impl import DropBaseImpl, DropImpl, DropDiscreteImpl


class _DropBase(DropAlgoBase):
    """Shared front-end logic for the DROP algorithms.

    The class stores the optimizer / encoder configuration for the impl,
    keeps NumPy copies of the per-drop embeddings and exposes the
    training-step and inference helpers that delegate to ``self._impl``.
    Concrete subclasses provide ``_create_impl`` and ``get_action_type``.
    """

    _drop_learning_rate: float
    _drop_optim_factory: OptimizerFactory
    _drop_encoder_factory: EncoderFactory
    _actor_learning_rate: float
    _actor_optim_factory: OptimizerFactory
    _actor_encoder_factory: EncoderFactory
    _critic_learning_rate: float
    _critic_optim_factory: OptimizerFactory
    _critic_encoder_factory: EncoderFactory
    _q_func_factory: DropQFunctionFactory
    _embedding_learning_rate: float
    _embedding_optim_factory: OptimizerFactory
    _energy_learning_rate: float
    _energy_optim_factory: OptimizerFactory
    _energy_encoder_factory: EncoderFactory
    _energy_update_steps: int
    _tau: float
    _n_critics: int
    _use_gpu: Optional[Device]
    _impl: Optional[DropBaseImpl]

    def __init__(
        self,
        *,
        drop_num: int = 1,
        drop_dim: int = 2,
        drop_size: float = 0,
        drop_seed: int = 1,
        drop_type: Optional[str] = None,
        drop_learning_rate: float = 1e-3,
        drop_optim_factory: OptimizerFactory = AdamFactory(),
        drop_encoder_factory: EncoderArg = "default",
        actor_learning_rate: float = 1e-3,
        actor_optim_factory: OptimizerFactory = AdamFactory(),
        actor_encoder_factory: EncoderArg = "default",
        critic_learning_rate: float = 3e-4,
        critic_optim_factory: OptimizerFactory = AdamFactory(),
        critic_encoder_factory: EncoderArg = "default",
        q_func_factory: DropQFuncArg = "mean",
        embedding_learning_rate: float = 3e-3,
        embedding_optim_factory: OptimizerFactory = AdamFactory(),
        energy_learning_rate: float = 3e-3,
        energy_optim_factory: OptimizerFactory = AdamFactory(),
        energy_encoder_factory: EncoderArg = "default",
        energy_update_steps: int = 1,
        batch_size: int = 100,
        n_frames: int = 1,
        n_steps: int = 1,
        gamma: float = 0.99,
        tau: float = 0.005,
        n_critics: int = 1,
        use_gpu: UseGPUArg = False,
        scaler: ScalerArg = None,
        action_scaler: ActionScalerArg = None,
        reward_scaler: RewardScalerArg = None,
        impl: Optional[DropBaseImpl] = None,
        **kwargs: Any
    ):
        """Store the configuration; the impl itself is built later.

        Args:
            drop_num, drop_dim, drop_size, drop_seed, drop_type, batch_size,
            n_frames, n_steps, gamma, scaler, action_scaler, reward_scaler:
                forwarded unchanged to ``DropAlgoBase.__init__``.
            drop_*, actor_*, critic_*, embedding_*, energy_*: learning
                rates, optimizer factories and encoder factories kept on
                the instance. Encoder arguments are normalised with
                ``check_encoder``.
            q_func_factory: normalised with ``check_drop_q_func``.
            energy_update_steps, tau, n_critics: stored as given.
            use_gpu: normalised with ``check_use_gpu``.
            impl: optional pre-built implementation object.
            **kwargs: forwarded to the base class as a single ``kwargs``
                dictionary.
        """
        super().__init__(
            drop_num=drop_num,
            drop_dim=drop_dim,
            drop_size=drop_size,
            drop_seed=drop_seed,
            drop_type=drop_type,
            batch_size=batch_size,
            n_frames=n_frames,
            n_steps=n_steps,
            gamma=gamma,
            scaler=scaler,
            action_scaler=action_scaler,
            reward_scaler=reward_scaler,
            kwargs=kwargs,
        )
        self._drop_learning_rate = drop_learning_rate
        self._drop_optim_factory = drop_optim_factory
        assert drop_encoder_factory is not None
        self._drop_encoder_factory = check_encoder(drop_encoder_factory)
        self._actor_learning_rate = actor_learning_rate
        self._actor_optim_factory = actor_optim_factory
        self._actor_encoder_factory = check_encoder(actor_encoder_factory)
        self._critic_learning_rate = critic_learning_rate
        self._critic_optim_factory = critic_optim_factory
        self._critic_encoder_factory = check_encoder(critic_encoder_factory)
        self._q_func_factory = check_drop_q_func(q_func_factory)
        self._embedding_learning_rate = embedding_learning_rate
        self._embedding_optim_factory = embedding_optim_factory
        self._energy_learning_rate = energy_learning_rate
        self._energy_optim_factory = energy_optim_factory
        self._energy_encoder_factory = check_encoder(energy_encoder_factory)
        self._energy_update_steps = energy_update_steps
        self._tau = tau
        self._n_critics = n_critics
        self._use_gpu = check_use_gpu(use_gpu)
        self._impl = impl

        # Index of the drop used by ``best_emb3``. Nothing in this class
        # reassigns it, so callers are expected to set it themselves.
        self.best_drop = -1
        # Embedding caches; they stay ``None`` until a subclass/caller
        # initialises them (e.g. ``DROP._init_impl_embeddings``).
        self._embeddings = None
        self._best_embeddings = None
        self._best_embeddings_loss = None

    def update_policy(self, drop_batch: list) -> Dict[str, float]:
        """Run one imitator update conditioned on the cached embeddings.

        Args:
            drop_batch: exactly five items
                ``[ns, ns_drop, batch, batch_drop, batch_init]``; only
                ``ns`` and ``batch`` are used. ``ns`` indexes
                ``self._embeddings`` (presumably the drop id of every
                transition in ``batch`` -- confirm against the caller).

        Returns:
            ``{"Actor": actor_loss}``.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        ns, _ns_drop, batch, _batch_drop, _batch_init = drop_batch
        emb = self._embeddings[ns]
        # NOTE(review): this call passes ``emb`` as a fourth argument while
        # ``_update`` calls ``update_imitator`` with three -- confirm that
        # the impl signature accepts both forms.
        [actor_loss] = self._impl.update_imitator(
            batch.observations, batch.actions, ns, emb
        )
        return {"Actor": actor_loss}

    def _update(self, drop_batch: list) -> Dict[str, float]:
        """Perform one full training step and refresh the embedding cache.

        The step order is: imitator update, critic update, target updates
        (critic and drop embedding), conservative critic update, then a
        fresh NumPy copy of all ``drop_num`` embeddings is pulled from the
        impl.

        Args:
            drop_batch: exactly five items
                ``[ns, ns_drop, batch, batch_drop, batch_init]``; only
                ``ns`` and ``batch`` are used.

        Returns:
            Losses keyed by ``"Q"``, ``"Qp"``, ``"Pi"`` and ``"Con"``.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        # NOTE: the energy update (``update_energy``) that consumed
        # ``ns_drop`` / ``batch_drop`` is currently disabled, so those
        # entries are unpacked but unused.
        ns, _ns_drop, batch, _batch_drop, _batch_init = drop_batch

        [actor_loss] = self._impl.update_imitator(
            batch.observations, batch.actions, ns
        )
        [critic_loss, critic_loss_pre, _q_tpn_pre] = self._impl.update_critic(
            batch, ns
        )
        self._impl.update_critic_target()
        self._impl.update_drop_emb_target()

        # Uniform samples in [0, drop_dim) of shape (batch_size, drop_dim),
        # handed to the conservative critic update.
        random_emb = (
            np.random.rand(self._batch_size, self._drop_dim) * self._drop_dim
        )
        [con_loss] = self._impl.update_critic_conservative(
            batch, ns, random_emb
        )

        all_drops = np.arange(self._drop_num)
        self._embeddings = self._impl.get_emb(all_drops).cpu().numpy().copy()

        return {
            "Q": critic_loss,
            "Qp": critic_loss_pre,
            "Pi": actor_loss,
            "Con": con_loss,
        }

    def _update_best(self, drop_batch: list, iter: int = 100) -> None:
        """Refine the cached embeddings and store the result as "best".

        Starting from ``self._embeddings``, ``update_best`` is applied
        ``iter`` times on the initial observations; the final embeddings
        and the first column of the last loss array are cached.

        Args:
            drop_batch: exactly two items ``[n, batch_init]``; ``n`` must
                have ``drop_num`` entries.
            iter: number of refinement steps; nothing is updated when it
                is not positive.
        """
        n, batch_init = drop_batch
        assert len(n) == self._drop_num
        if iter <= 0:
            return
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        embeddings = self._embeddings
        for _ in range(iter):
            # A copy is passed so the array from the previous step is never
            # handed to the impl twice.
            [losses, embeddings] = self._impl.update_best(
                batch_init.observations, embeddings.copy()
            )
        self._best_embeddings = embeddings.copy()
        self._best_embeddings_loss = losses[:, 0]

    def _get_N_embeddings(self, N: int) -> List[np.ndarray]:
        """Return the ``N`` best embeddings with the lowest cached loss.

        Args:
            N: number of embeddings to select.

        Returns:
            ``[index, embeddings]`` where ``index`` is the *full* index
            array (its first ``N`` entries point at the ``N`` smallest
            losses, in no particular order unless ``N`` covers every
            drop) and ``embeddings`` holds the corresponding ``N`` rows
            of ``self._best_embeddings``.
        """
        losses = self._best_embeddings_loss
        if N >= len(losses):
            # ``np.argpartition`` rejects ``kth == len(losses)``. Every
            # drop is requested anyway, so any permutation is a valid
            # partition; a full sort is used.
            index = np.argsort(losses)
        else:
            index = np.argpartition(losses, N)
        return [index, self._best_embeddings[index][:N]]

    def predict_value(
        self,
        x: Union[np.ndarray, List[Any]],
        action: Union[np.ndarray, List[Any]],
        e: Optional[np.ndarray] = None,
        with_std: bool = False,
    ) -> np.ndarray:
        """Return the impl's value prediction for ``(x, action)`` given ``e``.

        Args:
            x: observations.
            action: actions.
            e: embedding to condition on; defaults to the best embedding
                of ``self.best_drop`` (third item of ``best_emb3``).
            with_std: accepted but ignored by this implementation.

        Returns:
            Whatever ``self._impl.predict_value`` returns.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        if e is None:
            e = self.best_emb3[2]
        return self._impl.predict_value(x, action, e)

    def predict(
        self,
        x: Union[np.ndarray, List[Any]],
        e: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return the impl's best action for ``x`` given embedding ``e``.

        Args:
            x: observations.
            e: embedding to condition on; defaults to the best embedding
                of ``self.best_drop`` (third item of ``best_emb3``).

        Returns:
            Whatever ``self._impl.predict_best_action`` returns.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        if e is None:
            e = self.best_emb3[2]
        return self._impl.predict_best_action(x, e)

    @property
    def best_emb3(self) -> List[np.ndarray]:
        """``[best_drop, embedding, best embedding]`` for ``self.best_drop``.

        NOTE(review): ``best_drop`` starts at ``-1``, which passes the
        assertion below and, for NumPy arrays, selects the *last* row --
        confirm that callers always set ``best_drop`` first.
        """
        assert self.best_drop is not None
        chosen = self.best_drop
        return [chosen, self._embeddings[chosen], self._best_embeddings[chosen]]

    def best_emb3_adaptive(
        self,
        obs: np.ndarray,
        count: Optional[list] = None,
        iter: int = 10,
    ) -> List[np.ndarray]:
        """Pick the candidate embedding with the lowest loss for ``obs``.

        The candidates ``self._embeddings[count]`` are scored by
        ``update_best`` against ``obs`` repeated once per candidate. With
        ``iter == 0`` they are scored once and left as they are; otherwise
        they are refined ``iter`` times and the last losses are used.

        Args:
            obs: a single observation.
            count: drop indices to consider; all drops when ``None``.
            iter: number of refinement steps, must be non-negative.

        Returns:
            ``[index, self._embeddings[count[index]], candidate]`` where
            ``index`` is the position of the winner *within* ``count`` and
            ``candidate`` is its (possibly refined) embedding.

        Raises:
            ValueError: if ``iter`` is negative.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        if iter < 0:
            # A negative count previously skipped the loop and failed later
            # with an unbound ``losses``; reject it up front instead.
            raise ValueError(f"iter must be non-negative, got {iter}")

        if count is None:
            count = range(self._drop_num)
        embeddings = self._embeddings[count]
        obss = np.array([obs]).repeat(embeddings.shape[0], axis=0)

        if iter == 0:
            # Score only: the embeddings returned by the impl are discarded.
            [losses, _] = self._impl.update_best(obss, embeddings.copy())
        else:
            for _ in range(iter):
                [losses, embeddings] = self._impl.update_best(
                    obss, embeddings.copy()
                )

        flat_losses = losses.reshape(-1).tolist()
        index = flat_losses.index(min(flat_losses))
        index_ = count[index]
        # NOTE(review): the first item is the position inside ``count``,
        # not the drop id ``index_`` (the two coincide when ``count`` is
        # None). Kept as is for backward compatibility -- confirm intent.
        return [index, self._embeddings[index_], embeddings[index]]


class DROP(_DropBase):
    """DROP for continuous action spaces, implemented by ``DropImpl``.

    On top of the ``_DropBase`` configuration this class keeps a
    ``policy_type`` string that is handed to the impl.
    """

    _policy_type: str
    _impl: Optional[DropImpl]

    def __init__(
        self,
        *,
        drop_num: int = 1,
        drop_dim: int = 2,
        drop_size: float = 0,
        drop_seed: int = 1,
        drop_type: Optional[str] = None,
        drop_learning_rate: float = 3e-4,
        actor_learning_rate: float = 3e-4,
        critic_learning_rate: float = 3e-4,
        embedding_learning_rate: float = 3e-4,
        energy_learning_rate: float = 3e-3,
        drop_optim_factory: OptimizerFactory = AdamFactory(),
        actor_optim_factory: OptimizerFactory = AdamFactory(),
        critic_optim_factory: OptimizerFactory = AdamFactory(),
        embedding_optim_factory: OptimizerFactory = AdamFactory(),
        energy_optim_factory: OptimizerFactory = AdamFactory(),
        drop_encoder_factory: EncoderArg = "default",
        actor_encoder_factory: EncoderArg = "default",
        critic_encoder_factory: EncoderArg = "default",
        q_func_factory: DropQFuncArg = "mean",
        energy_encoder_factory: EncoderArg = "default",
        energy_update_steps: int = 1,
        batch_size: int = 100,
        n_frames: int = 1,
        n_steps: int = 1,
        gamma: float = 0.99,
        tau: float = 0.005,
        n_critics: int = 1,
        policy_type: str = "stochastic",
        use_gpu: UseGPUArg = False,
        scaler: ScalerArg = None,
        action_scaler: ActionScalerArg = None,
        reward_scaler: RewardScalerArg = None,
        impl: Optional[DropBaseImpl] = None,
        **kwargs: Any
    ):
        """Forward every option to ``_DropBase`` and remember ``policy_type``.

        All arguments except ``policy_type`` are passed to
        ``_DropBase.__init__`` under the same name; ``policy_type`` is
        stored and later given to ``DropImpl``.
        """
        super().__init__(
            # drop configuration
            drop_num=drop_num,
            drop_dim=drop_dim,
            drop_size=drop_size,
            drop_seed=drop_seed,
            drop_type=drop_type,
            # learning rates
            drop_learning_rate=drop_learning_rate,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            embedding_learning_rate=embedding_learning_rate,
            energy_learning_rate=energy_learning_rate,
            # optimizer factories
            drop_optim_factory=drop_optim_factory,
            actor_optim_factory=actor_optim_factory,
            critic_optim_factory=critic_optim_factory,
            embedding_optim_factory=embedding_optim_factory,
            energy_optim_factory=energy_optim_factory,
            # network factories
            drop_encoder_factory=drop_encoder_factory,
            actor_encoder_factory=actor_encoder_factory,
            critic_encoder_factory=critic_encoder_factory,
            energy_encoder_factory=energy_encoder_factory,
            q_func_factory=q_func_factory,
            # training-loop settings
            energy_update_steps=energy_update_steps,
            batch_size=batch_size,
            n_frames=n_frames,
            n_steps=n_steps,
            gamma=gamma,
            tau=tau,
            n_critics=n_critics,
            # device, preprocessing and injected impl
            use_gpu=use_gpu,
            scaler=scaler,
            action_scaler=action_scaler,
            reward_scaler=reward_scaler,
            impl=impl,
            **kwargs,
        )
        self._policy_type = policy_type

    def _create_impl(
        self, observation_shape: Sequence[int], action_size: int
    ) -> None:
        """Instantiate ``DropImpl`` from the stored settings and build it.

        The actor settings are passed under the impl's unprefixed names
        (``learning_rate``, ``optim_factory``, ``encoder_factory``) and
        ``drop_dim`` is passed as ``embedding_size``.
        ``energy_update_steps`` is not forwarded to the impl.

        Args:
            observation_shape: shape of a single observation.
            action_size: size of the action vector.
        """
        impl = DropImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            # drop embedding network
            embedding_size=self._drop_dim,
            drop_num=self._drop_num,
            drop_learning_rate=self._drop_learning_rate,
            drop_optim_factory=self._drop_optim_factory,
            drop_encoder_factory=self._drop_encoder_factory,
            # actor
            learning_rate=self._actor_learning_rate,
            optim_factory=self._actor_optim_factory,
            encoder_factory=self._actor_encoder_factory,
            # critic
            critic_learning_rate=self._critic_learning_rate,
            critic_optim_factory=self._critic_optim_factory,
            critic_encoder_factory=self._critic_encoder_factory,
            q_func_factory=self._q_func_factory,
            n_critics=self._n_critics,
            # embedding and energy optimisation
            embedding_learning_rate=self._embedding_learning_rate,
            embedding_optim_factory=self._embedding_optim_factory,
            energy_learning_rate=self._energy_learning_rate,
            energy_optim_factory=self._energy_optim_factory,
            energy_encoder_factory=self._energy_encoder_factory,
            # remaining hyper-parameters
            gamma=self._gamma,
            tau=self._tau,
            policy_type=self._policy_type,
            use_gpu=self._use_gpu,
            scaler=self._scaler,
            action_scaler=self._action_scaler,
            reward_scaler=self._reward_scaler,
        )
        # Publish the impl before building it, so the attribute is set even
        # if ``build`` raises.
        self._impl = impl
        impl.build()

    def _init_impl_embeddings(self) -> None:
        """Reset the embedding caches to zeros.

        ``_embeddings`` and ``_best_embeddings`` become separate
        ``(drop_num, drop_dim)`` zero arrays and ``_best_embeddings_loss``
        a zero vector of length ``drop_num``.
        """
        # Multiplying by 0. makes every entry zero; the ``rand`` call is
        # kept because it still draws from NumPy's global random state.
        zero_bank = np.random.rand(self._drop_num, self._drop_dim) * 0.
        self._embeddings = zero_bank
        self._best_embeddings = zero_bank.copy()
        self._best_embeddings_loss = np.zeros(self._drop_num)

    def get_action_type(self) -> ActionSpace:
        """Return ``ActionSpace.CONTINUOUS``."""
        return ActionSpace.CONTINUOUS


class DiscreteDROP(_DropBase):
    """DROP variant for discrete action spaces, backed by ``DropDiscreteImpl``."""

    _learning_rate: float
    _optim_factory: OptimizerFactory
    _encoder_factory: EncoderFactory
    _beta: float
    _impl: Optional[DropDiscreteImpl]

    def __init__(
        self,
        *,
        learning_rate: float = 1e-3,
        optim_factory: OptimizerFactory = AdamFactory(),
        encoder_factory: EncoderArg = "default",
        batch_size: int = 100,
        n_frames: int = 1,
        beta: float = 0.5,
        use_gpu: UseGPUArg = False,
        scaler: ScalerArg = None,
        impl: Optional[DropDiscreteImpl] = None,
        **kwargs: Any
    ):
        """Store the settings that ``_create_impl`` hands to the impl.

        Args:
            learning_rate: learning rate given to ``DropDiscreteImpl``.
            optim_factory: optimizer factory given to ``DropDiscreteImpl``.
            encoder_factory: encoder specification, normalised with
                ``check_encoder``.
            batch_size: forwarded to ``_DropBase``.
            n_frames: forwarded to ``_DropBase``.
            beta: stored and given to ``DropDiscreteImpl``.
            use_gpu: forwarded to ``_DropBase``.
            scaler: forwarded to ``_DropBase``.
            impl: optional pre-built implementation object.
            **kwargs: forwarded to ``_DropBase``.
        """
        super().__init__(
            learning_rate=learning_rate,
            optim_factory=optim_factory,
            encoder_factory=encoder_factory,
            batch_size=batch_size,
            n_frames=n_frames,
            use_gpu=use_gpu,
            scaler=scaler,
            impl=impl,
            **kwargs,
        )
        # ``_DropBase.__init__`` declares no ``learning_rate`` /
        # ``optim_factory`` / ``encoder_factory`` parameters, so the values
        # forwarded above only end up in its ``**kwargs``. They are stored
        # here explicitly; otherwise ``_create_impl`` would read attributes
        # that were never assigned.
        self._learning_rate = learning_rate
        self._optim_factory = optim_factory
        self._encoder_factory = check_encoder(encoder_factory)
        self._beta = beta

    def _create_impl(
        self, observation_shape: Sequence[int], action_size: int
    ) -> None:
        """Instantiate ``DropDiscreteImpl`` from the stored settings and build it.

        Args:
            observation_shape: shape of a single observation.
            action_size: number of discrete actions.
        """
        self._impl = DropDiscreteImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            learning_rate=self._learning_rate,
            optim_factory=self._optim_factory,
            encoder_factory=self._encoder_factory,
            beta=self._beta,
            use_gpu=self._use_gpu,
            scaler=self._scaler,
        )
        self._impl.build()

    def get_action_type(self) -> ActionSpace:
        """Return ``ActionSpace.DISCRETE``."""
        return ActionSpace.DISCRETE
