import copy
from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence, Union, List

import math
import numpy as np
import torch
from torch.optim import Optimizer

from ...gpu import Device
from ...models.builders import (
    create_drop_deterministic_policy,
    create_drop_deterministic_regressor,
    create_drop_discrete_imitator,
    create_drop_probablistic_regressor,
    create_drop_squashed_normal_policy,
    create_drop_continuous_q_function,
    create_drop_continuous_q_function_ens,
    create_drop_energy_function,
    create_parameter
)
from ...models.encoders import EncoderFactory
from ...models.optimizers import OptimizerFactory
from ...models.q_functions import DropQFunctionFactory, DropContinuousMeanQFunction
from ...models.torch import (
    DropDeterministicRegressor,
    DropDiscreteImitator,
    DropImitator,
    DropPolicy,
    DropProbablisticRegressor,
    DropEnsembleContinuousQFunction,
    DropEnsembleQFunction
)
from ...preprocessing import ActionScaler, RewardScaler, Scaler
from ...torch_utility import TorchMiniBatch, DropTorchMiniBatch, soft_sync, hard_sync, torch_api, train_api, drop_torch_api, eval_api
from .base import DropTorchImplBase
from .utility import DropContinuousQFunctionMixin


class DropBaseImpl(DropContinuousQFunctionMixin, DropTorchImplBase, metaclass=ABCMeta):
    """Abstract torch implementation shared by the Drop algorithm variants.

    It owns an embedding network (``_drop_emb``), an imitator, an ensemble
    critic with a target copy, and an energy function, each trained by its
    own optimizer.  Subclasses supply the imitator via ``_build_network``.
    """

    # learning rate / factories for the embedding network (``_drop_emb``)
    _drop_learning_rate: float
    _drop_optim_factory: OptimizerFactory
    _drop_encoder_factory: EncoderFactory
    # learning rate / factories for the imitator
    _learning_rate: float
    _optim_factory: OptimizerFactory
    _encoder_factory: EncoderFactory

    # critic (ensemble Q-function) configuration
    _critic_learning_rate: float
    _critic_optim_factory: OptimizerFactory
    _critic_encoder_factory: EncoderFactory
    _q_func_factory: DropQFunctionFactory
    # optimizer settings used when an embedding tensor is optimized directly
    _embedding_learning_rate: float
    _embedding_optim_factory: OptimizerFactory
    # energy function configuration
    _energy_learning_rate: float
    _energy_optim_factory: OptimizerFactory
    _energy_encoder_factory: EncoderFactory
    # NOTE(review): annotated ``int`` but ``__init__`` assigns 0.1 and
    # ``update_emb_norm_weight`` stores floats -- ``float`` looks intended.
    _emb_norm_weight: int
    # NOTE(review): annotated ``int``; starts at 1 but
    # ``update_con_ada_weight`` stores floats -- ``float`` looks intended.
    _con_ada_weight: int
    _gamma: float
    _tau: float
    _n_critics: int

    _use_gpu: Optional[Device]
    _imitator: Optional[DropImitator]
    _optim: Optional[Optimizer]

    _q_func: Optional[DropEnsembleContinuousQFunction]
    _targ_q_func: Optional[DropEnsembleContinuousQFunction]
    # _q_func: Optional[DropContinuousMeanQFunction]
    # _targ_q_func: Optional[DropContinuousMeanQFunction]
    _critic_optim: Optional[Optimizer]

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        embedding_size: int,
        drop_num: int,
        drop_learning_rate: float,
        drop_optim_factory: OptimizerFactory,
        drop_encoder_factory: EncoderFactory,
        learning_rate: float,
        optim_factory: OptimizerFactory,
        encoder_factory: EncoderFactory,
        critic_learning_rate: float,
        critic_optim_factory: OptimizerFactory,
        critic_encoder_factory: EncoderFactory,
        q_func_factory: DropQFunctionFactory,
        embedding_learning_rate: float,
        embedding_optim_factory: OptimizerFactory,
        energy_learning_rate: float,
        energy_optim_factory: OptimizerFactory,
        energy_encoder_factory: EncoderFactory,
        gamma: float,
        tau: float,
        n_critics: int,
        use_gpu: Optional[Device],
        scaler: Optional[Scaler],
        action_scaler: Optional[ActionScaler],
        reward_scaler: Optional[RewardScaler],
    ):
        """Store the hyper-parameters; no torch objects are created here.

        Networks and optimizers are created later by :meth:`build`.

        Args:
            observation_shape: forwarded to the base class and the builders.
            action_size: forwarded to the base class and the builders.
            embedding_size: forwarded to the base class and the builders.
            drop_num: number of classes of the one-hot index fed to the
                embedding network.
            drop_learning_rate: learning rate of the embedding optimizer.
            drop_optim_factory: factory of the embedding optimizer.
            drop_encoder_factory: factory of the embedding network.
            learning_rate: learning rate of the imitator optimizer.
            optim_factory: factory of the imitator optimizer.
            encoder_factory: encoder factory of the imitator.
            critic_learning_rate: learning rate of the critic optimizer.
            critic_optim_factory: factory of the critic optimizer.
            critic_encoder_factory: encoder factory of the critic.
            q_func_factory: Q-function factory of the critic.
            embedding_learning_rate: learning rate used by ``update_best``.
            embedding_optim_factory: optimizer factory used by
                ``update_best``.
            energy_learning_rate: learning rate of the energy optimizer.
            energy_optim_factory: factory of the energy optimizer.
            energy_encoder_factory: encoder factory of the energy function.
            gamma: discount; the critic losses use ``gamma ** n_steps``.
            tau: coefficient passed to ``soft_sync`` for target updates.
            n_critics: number of ensemble members of the critic.
            use_gpu: device to move the models to, or ``None`` for CPU.
            scaler: forwarded to the base class.
            action_scaler: forwarded to the base class.
            reward_scaler: forwarded to the base class.
        """
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            embedding_size=embedding_size,
            drop_num=drop_num,
            scaler=scaler,
            action_scaler=action_scaler,
            reward_scaler=reward_scaler,
        )
        self._drop_learning_rate = drop_learning_rate
        self._drop_optim_factory = drop_optim_factory
        self._drop_encoder_factory = drop_encoder_factory
        self._learning_rate = learning_rate
        self._optim_factory = optim_factory
        self._encoder_factory = encoder_factory
        self._critic_learning_rate = critic_learning_rate
        self._critic_optim_factory = critic_optim_factory
        self._critic_encoder_factory = critic_encoder_factory
        self._q_func_factory = q_func_factory
        self._embedding_learning_rate = embedding_learning_rate
        self._embedding_optim_factory = embedding_optim_factory
        self._energy_learning_rate = energy_learning_rate
        self._energy_optim_factory = energy_optim_factory
        self._energy_encoder_factory = energy_encoder_factory
        self._gamma = gamma
        self._tau = tau
        self._n_critics = n_critics
        self._use_gpu = use_gpu

        # adaptive loss weights and the step sizes used to adapt them
        self._emb_norm_weight = 0.1
        self._emb_norm_weight_lr = 0.0005
        self._con_ada_weight = 1
        self._con_ada_weight_lr = 0.0005

        # initialized in build
        self._drop_emb = None
        self._targ_drop_emb = None
        self._drop_optim = None
        self._imitator = None
        self._optim = None
        self._q_func = None
        self._targ_q_func = None
        self._critic_optim = None

        self._energy_func = None
        self._energy_optim = None

    def build(self) -> None:
        # setup torch models
        self._build_drop()
        self._build_network()
        self._build_critic()
        self._build_energy()

        # setup target networks
        self._targ_q_func = copy.deepcopy(self._q_func)
        self._targ_drop_emb = copy.deepcopy(self._drop_emb)

        if self._use_gpu:
            self.to_gpu(self._use_gpu)
        else:
            self.to_cpu()

        self._build_drop_optim()
        self._build_optim()
        self._build_critic_optim()
        self._build_energy_optim()
    
    def _build_drop(self) -> None:
        self._drop_emb = self._drop_encoder_factory.create((self._drop_num,))

    def _build_drop_optim(self) -> None:
        self._drop_optim = self._drop_optim_factory.create(
            self._drop_emb.parameters(), lr=self._drop_learning_rate
        )

    @abstractmethod
    def _build_network(self) -> None:
        pass

    def _build_optim(self) -> None:
        assert self._imitator is not None
        self._optim = self._optim_factory.create(
            self._imitator.parameters(), lr=self._learning_rate
        )
    
    def _build_critic(self) -> None:
        """Create the ensemble Q-function with ``self._n_critics`` members."""
        critic = create_drop_continuous_q_function_ens(
            self._observation_shape,
            self._action_size,
            self._embedding_size,
            self._critic_encoder_factory,
            self._q_func_factory,
            n_ensembles=self._n_critics,
        )
        self._q_func = critic
    
    def _build_energy(self) -> None:
        """Create the energy function from the energy encoder factory."""
        energy = create_drop_energy_function(
            self._observation_shape,
            self._action_size,
            self._embedding_size,
            self._energy_encoder_factory,
        )
        self._energy_func = energy

    def _build_critic_optim(self) -> None:
        assert self._q_func is not None
        self._critic_optim = self._critic_optim_factory.create(
            self._q_func.parameters(), lr=self._critic_learning_rate
        )
    
    def _build_energy_optim(self) -> None:
        assert self._energy_func is not None
        self._energy_optim = self._energy_optim_factory.create(
            self._energy_func.parameters(), lr=self._energy_learning_rate
        )

    

    @train_api
    @torch_api(scaler_targets=["obs_t"], action_scaler_targets=["act_t"])
    def update_imitator(
        self, obs_t: torch.Tensor, act_t: torch.Tensor, ns: torch.Tensor, emb: Optional[torch.Tensor] = None
    ) -> list:
        """Run one optimisation step of the imitator and the embedding.

        Args:
            obs_t: observations.
            act_t: actions to imitate.
            ns: indices that are one-hot encoded (``self._drop_num`` classes)
                and fed to the embedding network; only used when ``emb`` is
                ``None``.
            emb: optional pre-computed embedding that replaces the lookup.

        Returns:
            A one-element list holding the square root of the loss as a
            numpy value.
        """
        assert self._optim is not None

        optimizers = (self._optim, self._drop_optim)
        for optim in optimizers:
            optim.zero_grad()

        if emb is None:
            one_hot = torch.nn.functional.one_hot(
                ns.view(-1).long(), num_classes=self._drop_num
            ).float()
            # the embedding network is switched to eval mode for this lookup;
            # the result is not detached, so gradients still reach it
            self._drop_emb.eval()
            emb = self._drop_emb(one_hot)

        loss = self.compute_loss(obs_t, act_t, emb)
        loss.backward()
        for optim in optimizers:
            optim.step()

        root_loss = np.sqrt(loss.cpu().detach().numpy())
        return [root_loss]

    def compute_loss(
        self, obs_t: torch.Tensor, act_t: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        assert self._imitator is not None
        return self._imitator.compute_error(obs_t, act_t, e)

    @train_api
    @drop_torch_api()
    def update_critic(self, batch: DropTorchMiniBatch, ns: torch.Tensor) -> list:
        """Run one optimisation step of the critic and the embedding.

        The optimised objective is the critic loss plus the "pre" critic
        loss plus the mean per-row embedding norm scaled by
        ``self._emb_norm_weight``; that weight is then adapted via
        ``update_emb_norm_weight``.

        Args:
            batch: mini-batch of transitions.
            ns: indices that are one-hot encoded (``self._drop_num`` classes)
                and fed to the embedding network.

        Returns:
            ``[sqrt(loss), sqrt(loss_pre), q_tpn_pre]`` as numpy values.
        """
        assert self._critic_optim is not None

        self._critic_optim.zero_grad()
        self._drop_optim.zero_grad()

        ns = torch.nn.functional.one_hot(
                ns.view(-1).long(), num_classes=self._drop_num
            ).float()
        emb = self._drop_emb(ns)
        # emb = torch.sigmoid(emb)

        # The imitation loss is deliberately not evaluated here: it was never
        # part of ``loss_all``, so computing it only cost an extra imitator
        # forward pass in train mode.
        q_tpn = self.compute_target(batch, emb)
        loss = self.compute_critic_loss(batch, emb, q_tpn)

        q_tpn_pre = self.compute_target_pre(batch, emb)
        loss_pre = self.compute_critic_loss_pre(batch, emb, q_tpn_pre)

        # regularise the embedding norm with an adaptive weight
        loss_all = loss + loss_pre + torch.norm(emb, dim=1).mean() * self._emb_norm_weight

        loss_all.backward()
        self._critic_optim.step()
        self._drop_optim.step()

        self.update_emb_norm_weight(emb, None, self._embedding_size)

        return [np.sqrt(loss.cpu().detach().numpy()),
                np.sqrt(loss_pre.cpu().detach().numpy()),
                q_tpn_pre.cpu().detach().numpy()]
    
    def compute_target(self, batch: DropTorchMiniBatch, e: torch.Tensor) -> torch.Tensor:
        """Evaluate the target critic on the batch's next step, without gradients.

        The next action is read from the batch (``batch.next_actions``)
        and the ensemble is reduced with ``"mean"``.
        """
        target_critic = self._targ_q_func
        assert target_critic is not None
        target_critic.eval()
        with torch.no_grad():
            next_action = batch.next_actions
            return target_critic.compute_target(
                batch.next_observations,
                next_action,
                e,
                reduction="mean",
            )
    
    def compute_target_pre(self, batch: DropTorchMiniBatch, e: torch.Tensor) -> torch.Tensor:
        """Evaluate the target critic on the batch's current step, without gradients.

        Uses ``batch.observations`` / ``batch.actions`` and reduces the
        ensemble with ``"mean"``.
        """
        # same guard as compute_target: fail clearly if build() has not run
        assert self._targ_q_func is not None
        self._targ_q_func.eval()
        with torch.no_grad():
            return self._targ_q_func.compute_target(
                batch.observations,
                batch.actions,
                e,
                reduction="mean",
            )

    def compute_critic_loss(
        self, batch: DropTorchMiniBatch, e: torch.Tensor, q_tpn: torch.Tensor
    ) -> torch.Tensor:
        """Return the critic's error against the target ``q_tpn``.

        The discount passed to the critic is ``self._gamma ** batch.n_steps``.
        """
        critic = self._q_func
        assert critic is not None
        discount = self._gamma**batch.n_steps
        return critic.compute_error(
            observations=batch.observations,
            actions=batch.actions,
            embeddings=e,
            rewards=batch.rewards,
            target=q_tpn,
            terminals=batch.terminals,
            gamma=discount,
        )
    
    def compute_critic_loss_pre(
        self, batch: DropTorchMiniBatch, e: torch.Tensor, q_tpn_pre: torch.Tensor
    ) -> torch.Tensor:
        """Return the critic's "pre" error against the target ``q_tpn_pre``.

        Forwards the batch's next step, ``Inits`` and ``Rs`` to the critic's
        ``compute_error_pre`` with discount ``self._gamma ** batch.n_steps``.
        """
        critic = self._q_func
        assert critic is not None
        discount = self._gamma**batch.n_steps
        return critic.compute_error_pre(
            next_observations=batch.next_observations,
            next_actions=batch.next_actions,
            embeddings=e,
            rewards=batch.rewards,
            terminals=batch.terminals,
            target_pre=q_tpn_pre,
            Inits=batch.Inits,
            Rs=batch.Rs,
            gamma=discount,
        )
    
    
    @train_api
    @drop_torch_api()
    def update_energy(
        self,
        batch: DropTorchMiniBatch,
        batch_drop: DropTorchMiniBatch,
        ns: torch.Tensor,
        ns_drop: torch.Tensor,
    ) -> list:
        """Run one optimisation step of the energy function.

        ``batch`` is scored against a target of ones and ``batch_drop``
        against a target of zeros, both with the (detached) embedding looked
        up from ``ns``.  ``ns_drop`` is accepted but not used.

        Returns:
            ``[loss_1, loss_2]`` as numpy values (before the 0.1 scaling
            that is applied to the optimised sum).
        """
        assert self._energy_optim is not None

        self._energy_optim.zero_grad()
        self._drop_optim.zero_grad()

        one_hot = torch.nn.functional.one_hot(
            ns.view(-1).long(), num_classes=self._drop_num
        ).float()
        self._drop_emb.eval()
        # detached: no gradient reaches the embedding network from this loss
        emb = self._drop_emb(one_hot).detach()

        n_rows = emb.shape[0]
        target = torch.ones((n_rows, 1)).to(self.device)
        target_drop = torch.zeros((n_rows, 1)).to(self.device)

        loss_1 = self.compute_energy_loss(batch, emb, target)
        loss_2 = self.compute_energy_loss(batch_drop, emb, target_drop)
        loss = (loss_1 + loss_2) * 0.1

        loss.backward()
        self._energy_optim.step()
        self._drop_optim.step()

        losses = (loss_1, loss_2)
        return [item.cpu().detach().numpy() for item in losses]
    
    def compute_energy_loss(
        self, batch: DropTorchMiniBatch, emb: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Return the energy function's error for a batch, embedding and target."""
        energy = self._energy_func
        assert energy is not None
        return energy.compute_error(
            batch.observations,
            batch.actions,
            batch.next_observations,
            emb,
            target,
        )
    
    def update_emb_norm_weight(
        self, emb: torch.Tensor, emb_drop: Optional[torch.Tensor], drop_dim: int
    ) -> None:
        with torch.no_grad():
            new_weight = self._emb_norm_weight + \
                self._emb_norm_weight_lr * (torch.norm(emb, dim=1).max().cpu().detach().item() - math.sqrt(drop_dim)) + \
                0. if emb_drop is None else \
                self._emb_norm_weight_lr * (torch.norm(emb_drop, dim=1).mean().cpu().detach().item() - math.sqrt(drop_dim))
            new_weight = max(new_weight, 0.0)
            self._emb_norm_weight = new_weight

    @train_api
    @drop_torch_api()
    def update_critic_conservative(
        self, batch: DropTorchMiniBatch, ns: torch.Tensor, emb_con: torch.Tensor
    ) -> list:
        """Run one conservative regularisation step of the critic.

        Minimises ``-(q - q_con).mean() * self._con_ada_weight`` where ``q``
        is the value of the batch actions under the (detached) embedding
        looked up from ``ns`` and ``q_con`` is the value of the imitator's
        action under ``emb_con``.  Only the critic optimizer is stepped; the
        weight is then adapted via ``update_con_ada_weight``.

        Returns:
            A one-element list with the loss as a numpy value.
        """
        critic_optim = self._critic_optim
        critic_optim.zero_grad()

        one_hot = torch.nn.functional.one_hot(
            ns.view(-1).long(), num_classes=self._drop_num
        ).float()
        self._drop_emb.eval()
        emb = self._drop_emb(one_hot).detach()

        observations = batch.observations
        q = self.predict_value(observations, batch.actions, emb)
        # the imitator's action is detached so only the critic is updated
        con_action = self._predict_best_action(observations, emb_con).detach()
        q_con = self.predict_value(observations, con_action, emb_con)
        loss = - (q - q_con).mean() * self._con_ada_weight

        loss.backward()
        critic_optim.step()

        self.update_con_ada_weight(batch, emb, emb_con)
        return [loss.cpu().detach().numpy()]
    
    def update_con_ada_weight(
        self, batch: DropTorchMiniBatch, emb: torch.Tensor, emb_con: torch.Tensor
    ) -> None:
        """Adapt ``self._con_ada_weight`` from the current value gap.

        The weight moves by ``self._con_ada_weight_lr`` times
        ``mean(q_con - q) - 2.0``, both values being computed for the batch
        actions, and is clamped at zero from below.
        """
        with torch.no_grad():
            q = self.predict_value(batch.observations, batch.actions, emb)
            q_con = self.predict_value(batch.observations, batch.actions, emb_con)
            gap = (q_con - q).mean().cpu().detach().item()
        candidate = self._con_ada_weight + self._con_ada_weight_lr * (gap - 2.0)
        self._con_ada_weight = max(candidate, 0.0)

    def _predict_best_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        assert self._imitator is not None
        return self._imitator(x, e)
        # return self.policy.sample(x, e)

    @torch_api()
    def predict_value(
        self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return the online critic's value, reducing the ensemble with ``"min"``."""
        critic = self._q_func
        return critic.compute_target(x, action, e, reduction="min")
    
    # def predict_value(
    #     self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor, with_std: bool
    # ) -> torch.Tensor:
    #     return self._q_func.compute_target(
    #             x,
    #             action,
    #             e,
    #             reduction="mean", #min
    #         )

    @eval_api
    @torch_api()
    def get_emb(self, ns: torch.Tensor) -> torch.Tensor:
        """Look up the embedding for the given indices, without gradients.

        Args:
            ns: indices that are one-hot encoded with ``self._drop_num``
                classes before being fed to the embedding network.

        Returns:
            The embedding network's output (a tensor, not a list).
        """
        with torch.no_grad():
            ns = torch.nn.functional.one_hot(
                ns.view(-1).long(), num_classes=self._drop_num
            ).float()
            emb = self._drop_emb(ns)
            # emb = torch.sigmoid(emb)
        return emb

    
    @eval_api
    @torch_api()
    def update_best(
        self, observations: torch.Tensor, emb: torch.Tensor
    ) -> List[np.ndarray]:
        """Take one gradient step on ``emb`` that increases the predicted value.

        A fresh optimizer is created for the embedding on every call from
        ``self._embedding_optim_factory``; the objective is the negated
        value of the imitator's action under ``emb``.

        Args:
            observations: observations to evaluate.
            emb: embedding to optimise.

        Returns:
            ``[losses, emb]`` as numpy arrays: the per-sample negated values
            computed before the step, and the embedding after the step.
        """
        # ``Tensor.to`` is not in-place, so its result must be kept.  Detaching
        # first makes the tensor on the target device a leaf, which the
        # optimizer requires.
        emb = emb.detach().to(self.device).requires_grad_()
        emb_optim = self._embedding_optim_factory.create(
            [emb], lr=self._embedding_learning_rate
        )
        emb_optim.zero_grad()
        q = self.predict_value(observations, self._predict_best_action(observations, emb), emb)
        losses = - q
        loss = losses.mean()
        loss.backward()
        emb_optim.step()
        return [losses.cpu().detach().numpy(), emb.cpu().detach().numpy()]



    def _sample_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        assert self._imitator is not None
        # return self.policy.sample(x, e)
        return self._imitator(x, e)

    def update_critic_target(self) -> None:
        """Soft-sync the target critic towards the online critic with ``self._tau``."""
        online, target = self._q_func, self._targ_q_func
        assert online is not None
        assert target is not None
        soft_sync(target, online, self._tau)
    
    def update_drop_emb_target(self) -> None:
        """Soft-sync the target embedding network towards the online one with ``self._tau``."""
        target, online = self._targ_drop_emb, self._drop_emb
        soft_sync(target, online, self._tau)
    
    @property
    def q_function(self) -> DropEnsembleQFunction:
        """The online critic; asserts that it is set (truthy)."""
        critic = self._q_func
        assert critic
        return critic

    @property
    def q_function_optim(self) -> Optimizer:
        assert self._critic_optim
        return self._critic_optim


class DropImpl(DropBaseImpl):
    """Continuous-action implementation whose imitator is a regressor.

    ``policy_type`` selects a deterministic or a probabilistic regressor in
    ``_build_network``.
    """

    # "deterministic" or "stochastic"; any other value is rejected
    _policy_type: str
    _imitator: Optional[Union[DropDeterministicRegressor, DropProbablisticRegressor]]

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        embedding_size: int,
        drop_num: int,
        drop_learning_rate: float,
        drop_optim_factory: OptimizerFactory,
        drop_encoder_factory: EncoderFactory,
        learning_rate: float,
        optim_factory: OptimizerFactory,
        encoder_factory: EncoderFactory,
        critic_learning_rate: float,
        critic_optim_factory: OptimizerFactory,
        critic_encoder_factory: EncoderFactory,
        q_func_factory: DropQFunctionFactory,
        embedding_learning_rate: float,
        embedding_optim_factory: OptimizerFactory,
        energy_learning_rate: float,
        energy_optim_factory: OptimizerFactory,
        energy_encoder_factory: EncoderFactory,
        gamma: float,
        tau: float,
        n_critics: int,
        policy_type: str,
        use_gpu: Optional[Device],
        scaler: Optional[Scaler],
        action_scaler: Optional[ActionScaler],
        reward_scaler: Optional[RewardScaler],
    ):
        """Forward every shared setting to the base class and store ``policy_type``.

        All arguments except ``policy_type`` are passed unchanged, by
        keyword, to ``DropBaseImpl.__init__``.
        """
        base_settings = dict(
            observation_shape=observation_shape,
            action_size=action_size,
            embedding_size=embedding_size,
            drop_num=drop_num,
            drop_learning_rate=drop_learning_rate,
            drop_optim_factory=drop_optim_factory,
            drop_encoder_factory=drop_encoder_factory,
            learning_rate=learning_rate,
            optim_factory=optim_factory,
            encoder_factory=encoder_factory,
            critic_learning_rate=critic_learning_rate,
            critic_optim_factory=critic_optim_factory,
            critic_encoder_factory=critic_encoder_factory,
            q_func_factory=q_func_factory,
            embedding_learning_rate=embedding_learning_rate,
            embedding_optim_factory=embedding_optim_factory,
            energy_learning_rate=energy_learning_rate,
            energy_optim_factory=energy_optim_factory,
            energy_encoder_factory=energy_encoder_factory,
            gamma=gamma,
            tau=tau,
            n_critics=n_critics,
            use_gpu=use_gpu,
            scaler=scaler,
            action_scaler=action_scaler,
            reward_scaler=reward_scaler,
        )
        super().__init__(**base_settings)
        self._policy_type = policy_type

    def _build_network(self) -> None:
        if self._policy_type == "deterministic":
            self._imitator = create_drop_deterministic_regressor(
                self._observation_shape,
                self._action_size,
                self._embedding_size,
                self._encoder_factory,
            )
        elif self._policy_type == "stochastic":
            self._imitator = create_drop_probablistic_regressor(
                self._observation_shape,
                self._action_size,
                self._embedding_size,
                self._encoder_factory,
                min_logstd=-4.0,
                max_logstd=15.0,
            )
        else:
            raise ValueError("invalid policy_type: {self._policy_type}")

    @property
    def policy(self) -> DropPolicy:
        """Build a policy network and copy the imitator's parameters into it.

        NOTE(review): this property is currently disabled. The unconditional
        ``assert False`` below raises ``AssertionError`` before any policy is
        created, so everything after it is unreachable unless Python runs
        with assertions stripped (``-O``). Confirm whether it should be
        re-enabled or replaced with an explicit error.
        """
        assert self._imitator
        assert False

        policy: DropPolicy
        if self._policy_type == "deterministic":
            policy = create_drop_deterministic_policy(
                self._observation_shape,
                self._action_size,
                self._embedding_size,
                self._encoder_factory,
            )
        elif self._policy_type == "stochastic":
            # log-std bounds here (-20.0, 2.0) differ from the ones used for
            # the imitator in _build_network (-4.0, 15.0)
            policy = create_drop_squashed_normal_policy(
                self._observation_shape,
                self._action_size,
                self._embedding_size,
                self._encoder_factory,
                min_logstd=-20.0,
                max_logstd=2.0,
            )
        else:
            raise ValueError(f"invalid policy_type: {self._policy_type}")

        # copy parameters
        hard_sync(policy, self._imitator)

        # place the copy on the configured GPU, or on the CPU when none is set
        return policy.cuda(f"cuda:{self._use_gpu.get_id()}") if self._use_gpu else policy.cpu()

    @property
    def policy_optim(self) -> Optimizer:
        assert self._optim
        return self._optim


class DropDiscreteImpl(DropBaseImpl):
    """Discrete-action variant whose imitator is a discrete imitator.

    NOTE(review): as written this class cannot be instantiated -- see the
    note on the ``super().__init__`` call in ``__init__``.
    """

    # passed to create_drop_discrete_imitator in _build_network
    _beta: float
    _imitator: Optional[DropDiscreteImitator]

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        embedding_size: int,
        learning_rate: float,
        optim_factory: OptimizerFactory,
        encoder_factory: EncoderFactory,
        beta: float,
        use_gpu: Optional[Device],
        scaler: Optional[Scaler],
    ):
        """Store ``beta`` and forward the remaining settings to the base class."""
        # NOTE(review): DropBaseImpl.__init__ declares many more parameters
        # without defaults (drop_num, the drop_* / critic_* / embedding_* /
        # energy_* settings, q_func_factory, gamma, tau, n_critics,
        # reward_scaler). This call omits them, so it raises TypeError.
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            embedding_size=embedding_size,
            learning_rate=learning_rate,
            optim_factory=optim_factory,
            encoder_factory=encoder_factory,
            use_gpu=use_gpu,
            scaler=scaler,
            action_scaler=None,
        )
        self._beta = beta

    def _build_network(self) -> None:
        """Create the discrete imitator."""
        self._imitator = create_drop_discrete_imitator(
            self._observation_shape,
            self._action_size,
            self._embedding_size,
            self._beta,
            self._encoder_factory,
        )

    def _predict_best_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the argmax over dim 1 of the imitator's output."""
        assert self._imitator is not None
        return self._imitator(x, e).argmax(dim=1)

    def compute_loss(
        self, obs_t: torch.Tensor, act_t: torch.Tensor, emb_t: torch.Tensor
    ) -> torch.Tensor:
        """Return the imitator's error, with the actions cast to ``long``."""
        assert self._imitator is not None
        return self._imitator.compute_error(obs_t, act_t.long(), emb_t)
