from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence, Union

import numpy as np
import torch
from torch.optim import Optimizer

from ...gpu import Device
from ...models.builders import (
    create_deterministic_policy,
    create_deterministic_regressor,
    create_discrete_imitator,
    create_probablistic_regressor,
    create_squashed_normal_policy,
    create_ensemble_discrete_imitator,
)
from ...models.encoders import EncoderFactory
from ...models.optimizers import OptimizerFactory
from ...models.torch import (
    DeterministicRegressor,
    DiscreteImitator,
    Imitator,
    Policy,
    ProbablisticRegressor,
)
from ...preprocessing import ActionScaler, Scaler
from ...torch_utility import hard_sync, torch_api, train_api
from .base import TorchImplBase

import ipdb
import torch.nn.functional as F
import wandb

class BCBaseImpl(TorchImplBase, metaclass=ABCMeta):
    """Common training machinery for behavior-cloning implementations.

    Subclasses supply the imitator network through ``_build_network``. This
    class creates the optimizer, performs the gradient step and provides the
    loss helpers.

    ``compute_error`` and ``compute_weighted_error`` read ``self._beta`` and
    call ``compute_log_probs_with_logits`` on the imitator, so they only work
    for subclasses that define ``_beta`` and whose imitator has that method.
    """

    # Episode ids whose samples are down-weighted by default in
    # ``compute_weighted_error``.
    _DOWN_WEIGHTED_EPISODE_IDS = (
        92, 108, 125, 133, 192, 239, 447, 466, 285, 303, 317, 350,
        364, 400, 410, 48, 61, 467, 478, 524, 544, 129, 599,
    )

    _learning_rate: float
    _optim_factory: OptimizerFactory
    _encoder_factory: EncoderFactory
    _use_gpu: Optional[Device]
    _imitator: Optional[Imitator]
    _optim: Optional[Optimizer]
    wandb_iterator: int

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        learning_rate: float,
        optim_factory: OptimizerFactory,
        encoder_factory: EncoderFactory,
        use_gpu: Optional[Device],
        scaler: Optional[Scaler],
        action_scaler: Optional[ActionScaler],
    ):
        """Store the hyper-parameters; networks are created later by ``build``.

        Args:
            observation_shape: shape of a single observation.
            action_size: size of the action space.
            learning_rate: learning rate handed to the optimizer factory.
            optim_factory: factory creating the optimizer.
            encoder_factory: factory used by subclasses to build the network.
            use_gpu: device to move the model to, or a falsy value for CPU.
            scaler: optional observation scaler.
            action_scaler: optional action scaler.
        """
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            scaler=scaler,
            action_scaler=action_scaler,
            reward_scaler=None,
        )
        self._learning_rate = learning_rate
        self._optim_factory = optim_factory
        self._encoder_factory = encoder_factory
        self._use_gpu = use_gpu

        # initialized in build
        self._imitator = None
        self._optim = None

        # Counts calls to compute_error / compute_weighted_error. It was meant
        # as a custom step for W&B logging, which is currently disabled.
        self.wandb_iterator = 0

    def build(self) -> None:
        """Create the network, move it to the target device, then the optimizer."""
        self._build_network()

        if self._use_gpu:
            self.to_gpu(self._use_gpu)
        else:
            self.to_cpu()

        self._build_optim()

    @abstractmethod
    def _build_network(self) -> None:
        """Create ``self._imitator``; implemented by subclasses."""

    def _build_optim(self) -> None:
        """Create the optimizer over the imitator's parameters."""
        assert self._imitator is not None
        self._optim = self._optim_factory.create(
            self._imitator.parameters(), lr=self._learning_rate
        )

    # NOTE(review): ``scaler_targets`` still names ``obs_t`` although that
    # parameter was replaced by ``batch`` -- confirm that the observations in
    # ``batch`` are still scaled somewhere.
    @train_api
    @torch_api(scaler_targets=["obs_t"], action_scaler_targets=["act_t"])
    def update_imitator(self, batch, act_t: torch.Tensor) -> np.ndarray:
        """Run one optimization step and return the loss.

        Args:
            batch: mini-batch object; its ``observations`` attribute is read
                by the loss functions.
            act_t: target actions.

        Returns:
            The loss value, detached and converted to a NumPy array.
        """
        assert self._optim is not None

        self._optim.zero_grad()
        loss = self.compute_loss_2(batch, act_t)
        loss.backward()
        self._optim.step()

        return loss.cpu().detach().numpy()

    def _nll_and_penalty(
        self, batch, action: torch.Tensor
    ) -> Sequence[torch.Tensor]:
        """Return the per-sample NLL and the squared logits for ``batch``."""
        assert self._imitator is not None
        log_probs, logits = self._imitator.compute_log_probs_with_logits(
            batch.observations
        )
        penalty = logits ** 2
        nll_loss = F.nll_loss(
            log_probs, action.view(-1).long(), reduction="none"
        )
        return nll_loss, penalty

    def compute_error(
        self, batch, action: torch.Tensor
    ) -> torch.Tensor:
        """Return mean NLL plus ``_beta`` times the mean squared logit.

        Requires ``self._beta`` and an imitator that implements
        ``compute_log_probs_with_logits``.
        """
        nll_loss, penalty = self._nll_and_penalty(batch, action)
        self.wandb_iterator += 1
        return nll_loss.mean() + self._beta * penalty.mean()

    def compute_weighted_error(
        self,
        batch,
        action: torch.Tensor,
        episode_ids: Optional[Sequence[int]] = None,
        factor: float = 10.0,
    ) -> torch.Tensor:
        """Like ``compute_error`` but down-weights samples of some episodes.

        Samples whose ``batch.ep_ids`` entry is in ``episode_ids`` have both
        their NLL and their squared-logit penalty divided by ``factor`` before
        averaging.

        Args:
            batch: mini-batch object providing ``observations`` and a tensor
                ``ep_ids``.
            action: target actions.
            episode_ids: episode ids to down-weight; defaults to
                ``_DOWN_WEIGHTED_EPISODE_IDS``. An empty sequence disables the
                down-weighting.
            factor: divisor applied to the selected samples; must be non-zero.

        Returns:
            Scalar loss tensor.
        """
        if episode_ids is None:
            episode_ids = self._DOWN_WEIGHTED_EPISODE_IDS

        nll_loss, penalty = self._nll_and_penalty(batch, action)

        # Flag every sample that belongs to one of the selected episodes.
        ep_ids = batch.ep_ids
        loss_mask = torch.zeros_like(ep_ids, dtype=torch.bool)
        for episode_id in episode_ids:
            loss_mask = loss_mask | (ep_ids == episode_id)
        loss_mask = loss_mask.reshape(-1)

        # Scale out of place so that tensors in the autograd graph are never
        # mutated. Dividing by 1.0 leaves the unselected samples unchanged.
        divisor = torch.ones_like(nll_loss)
        divisor[loss_mask] = factor
        nll_loss = nll_loss / divisor
        # Broadcast the per-sample divisor over the trailing penalty dims.
        penalty = penalty / divisor.view(
            (-1,) + (1,) * (penalty.dim() - 1)
        )

        self.wandb_iterator += 1

        return nll_loss.mean() + self._beta * penalty.mean()

    def compute_loss_2(
        self, batch, act_t: torch.Tensor
    ) -> torch.Tensor:
        """Return the training loss for ``batch``.

        Uses ``compute_error`` when the imitator provides
        ``compute_log_probs_with_logits``. Otherwise it defers to
        ``compute_loss``, so imitators without that method (presumably the
        regressors built by other subclasses) can still be trained.
        """
        assert self._imitator is not None
        if hasattr(self._imitator, "compute_log_probs_with_logits"):
            return self.compute_error(batch, act_t)
        return self.compute_loss(batch.observations, act_t)

    def compute_loss(
        self, obs_t: torch.Tensor, act_t: torch.Tensor
    ) -> torch.Tensor:
        """Return the imitator's own error for the given observations/actions."""
        assert self._imitator is not None
        return self._imitator.compute_error(obs_t, act_t)

    def _predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
        """Return the imitator's output for ``x``."""
        assert self._imitator is not None
        return self._imitator(x)

    def predict_value(
        self, x: np.ndarray, action: np.ndarray, with_std: bool
    ) -> np.ndarray:
        """Always raise: behavior cloning has no value function.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("BC does not support value estimation")


class BCImpl(BCBaseImpl):
    """Behavior cloning with a regressor imitator chosen by ``policy_type``.

    ``policy_type`` must be ``"deterministic"`` or ``"stochastic"``; any other
    value makes ``_build_network`` and ``policy`` raise ``ValueError``.
    """

    _policy_type: str
    _imitator: Optional[Union[DeterministicRegressor, ProbablisticRegressor]]

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        learning_rate: float,
        optim_factory: OptimizerFactory,
        encoder_factory: EncoderFactory,
        policy_type: str,
        use_gpu: Optional[Device],
        scaler: Optional[Scaler],
        action_scaler: Optional[ActionScaler],
    ):
        """Forward the shared arguments to the base class and keep ``policy_type``."""
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            learning_rate=learning_rate,
            optim_factory=optim_factory,
            encoder_factory=encoder_factory,
            use_gpu=use_gpu,
            scaler=scaler,
            action_scaler=action_scaler,
        )
        self._policy_type = policy_type

    def _build_network(self) -> None:
        """Create the regressor that matches ``policy_type``.

        Raises:
            ValueError: if ``policy_type`` is not a supported value.
        """
        if self._policy_type == "deterministic":
            self._imitator = create_deterministic_regressor(
                self._observation_shape,
                self._action_size,
                self._encoder_factory,
            )
        elif self._policy_type == "stochastic":
            self._imitator = create_probablistic_regressor(
                self._observation_shape,
                self._action_size,
                self._encoder_factory,
                min_logstd=-4.0,
                max_logstd=15.0,
            )
        else:
            # The f-prefix was missing, so the placeholder was shown literally.
            raise ValueError(f"invalid policy_type: {self._policy_type}")

    @property
    def policy(self) -> Policy:
        """Return a newly created policy holding a copy of the imitator parameters.

        A new policy object is created on every access.

        Raises:
            ValueError: if ``policy_type`` is not a supported value.
        """
        assert self._imitator is not None

        policy: Policy
        if self._policy_type == "deterministic":
            policy = create_deterministic_policy(
                self._observation_shape,
                self._action_size,
                self._encoder_factory,
            )
        elif self._policy_type == "stochastic":
            # NOTE(review): these log-std bounds differ from the ones used in
            # _build_network (-4.0 / 15.0) -- confirm this is intentional.
            policy = create_squashed_normal_policy(
                self._observation_shape,
                self._action_size,
                self._encoder_factory,
                min_logstd=-20.0,
                max_logstd=2.0,
            )
        else:
            raise ValueError(f"invalid policy_type: {self._policy_type}")

        # copy parameters
        hard_sync(policy, self._imitator)

        return policy

    @property
    def policy_optim(self) -> Optimizer:
        """Return the optimizer created in ``build``."""
        assert self._optim is not None
        return self._optim


class DiscreteBCImpl(BCBaseImpl):
    """Behavior cloning with a discrete imitator, optionally an ensemble.

    ``beta`` is handed to the imitator builder. ``ensemble_size`` selects
    between the single imitator (value ``1``) and the ensemble builder (any
    other value).
    """

    _beta: float
    _imitator: Optional[DiscreteImitator]
    _ensemble_size: int

    def __init__(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        learning_rate: float,
        optim_factory: OptimizerFactory,
        encoder_factory: EncoderFactory,
        beta: float,
        use_gpu: Optional[Device],
        scaler: Optional[Scaler],
        ensemble_size=1,
    ):
        """Set up the base class without an action scaler and keep the extras."""
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            learning_rate=learning_rate,
            optim_factory=optim_factory,
            encoder_factory=encoder_factory,
            use_gpu=use_gpu,
            scaler=scaler,
            # this implementation always passes no action scaler
            action_scaler=None,
        )
        self._beta, self._ensemble_size = beta, ensemble_size

    def _build_network(self) -> None:
        """Create the imitator: an ensemble unless ``ensemble_size`` is 1."""
        # Positional arguments common to both builders.
        shared_args = (
            self._observation_shape,
            self._action_size,
            self._beta,
            self._encoder_factory,
        )
        if self._ensemble_size != 1:
            self._imitator = create_ensemble_discrete_imitator(
                *shared_args, self._ensemble_size
            )
            return
        self._imitator = create_discrete_imitator(*shared_args)

    def _predict_best_action(self, x: torch.Tensor) -> torch.Tensor:
        """Return the argmax over dimension 1 of the imitator output for ``x``."""
        assert self._imitator is not None
        scores = self._imitator(x)
        return scores.argmax(dim=1)

    def compute_loss(
        self, obs_t: torch.Tensor, act_t: torch.Tensor
    ) -> torch.Tensor:
        """Return the imitator error, casting the actions to ``long`` first."""
        assert self._imitator is not None
        targets = act_t.long()
        return self._imitator.compute_error(obs_t, targets)
