import os
import copy

import torch
import gym
import numpy as np
import torch.nn.functional as F

from torch import nn

from expground.types import Dict, DataArray, Any, Tensor, Tuple
from expground.logger import Log
from expground.algorithms import misc
from expground.algorithms.base_policy import Policy, Action, ActionDist, Logits
from expground.common.distributions import Distribution, make_proba_distribution
from expground.algorithms.ppo.config import DEFAULT_CONFIG
from expground.common.models import get_model
from expground.common.models.torch.vision import MODEL_CONFIG


class VisionPPO(Policy):
    """PPO policy backed by a vision model for rank-3 ``gym.spaces.Box`` observations.

    Holds an online network (``vision_net``) whose forward output is used as action
    logits and which exposes ``value_function()``, plus a target copy
    (``target_vision_net``) that is synchronised through ``update_target``.
    Input tensors are created on ``"cuda"`` when ``use_cuda`` is set, otherwise on
    ``"cpu"``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        model_config,
        custom_config,
        is_fixed: bool = False,
    ):
        """Build the online/target vision networks and the action distribution.

        Args:
            observation_space: A ``gym.spaces.Box`` whose shape has exactly three
                dimensions.
            action_space: Environment action space. A ``gym.spaces.Discrete``
                space enables the one-hot action output of ``compute_actions``.
            model_config: Overrides merged on top of a deep copy of the vision
                ``MODEL_CONFIG`` defaults; ``None`` means "no overrides".
            custom_config: Policy options. Keys read here are ``use_cuda``,
                ``use_sde`` (the legacy misspelling ``use_sed`` is still honoured)
                and ``dist_kwargs``.
            is_fixed: Forwarded unchanged to the base ``Policy``.
        """
        assert isinstance(observation_space, gym.spaces.Box), observation_space
        assert len(observation_space.shape) == 3, observation_space.shape

        # Deep-copy the defaults so per-instance overrides never mutate the shared
        # module-level MODEL_CONFIG.
        _model_config = copy.deepcopy(MODEL_CONFIG)
        _model_config.update(model_config or {})
        super(VisionPPO, self).__init__(
            observation_space, action_space, _model_config, custom_config, is_fixed
        )

        # Read the device flag once, from the same config that later drives tensor
        # placement, so network construction and input placement always agree.
        self.use_cuda = self.custom_config.get("use_cuda", False)

        model_cls = get_model(_model_config)
        self.vision_net = model_cls(
            observation_space=observation_space,
            action_space=action_space,
            use_cuda=self.use_cuda,
        )
        self.target_vision_net = model_cls(
            observation_space=observation_space,
            action_space=action_space,
            use_cuda=self.use_cuda,
        )

        self._is_discrete_action = isinstance(action_space, gym.spaces.Discrete)
        # "use_sed" was a misspelling of "use_sde"; prefer the correct key but keep
        # accepting the old one so existing configs keep their behaviour.
        use_sde = self.custom_config.get(
            "use_sde", self.custom_config.get("use_sed", False)
        )
        self._action_dist_handler = make_proba_distribution(
            action_space,
            use_sde=use_sde,
            dist_kwargs=self.custom_config.get("dist_kwargs", None),
        )

        if self.use_cuda:
            self.vision_net = self.vision_net.to("cuda")
            self.target_vision_net = self.target_vision_net.to("cuda")

        self.register_state(self.vision_net, "vision_net")
        self.register_state(self.target_vision_net, "target_vision_net")

        # Hard update (tau=1.0): start the target as an exact copy of the online net.
        self.update_target()

    @property
    def _tensor_device(self) -> str:
        """Device string on which this policy creates its input tensors."""
        return "cuda" if self.use_cuda else "cpu"

    def update_parameters(self, param):
        """Copy the tensors in ``param`` into ``vision_net``'s parameters, in order.

        Only the online network is written; call ``update_target`` to propagate
        the change. ``param`` is presumably ordered like
        ``vision_net.parameters()`` (TODO confirm with callers). ``zip`` stops at
        the shorter sequence, so a length mismatch is not detected here.
        """
        for own_param, src_param in zip(self.vision_net.parameters(), param):
            own_param.data.copy_(src_param.data)

    @property
    def is_discrete_action(self) -> bool:
        """Whether the action space is ``gym.spaces.Discrete``."""
        return self._is_discrete_action

    def value_function(self, critic_state: Tensor = None) -> Tensor:
        """Return the online network's value estimate.

        If ``critic_state`` is given it is first forwarded through ``vision_net``;
        otherwise whatever ``vision_net.value_function()`` reports for its most
        recent forward pass is returned (presumably cached by the model — confirm).
        """
        if critic_state is not None:
            self.vision_net(torch.as_tensor(critic_state, device=self._tensor_device))
        return self.vision_net.value_function()

    def target_value_function(self, critic_state: Tensor = None) -> Tensor:
        """Same as ``value_function`` but evaluated with ``target_vision_net``."""
        if critic_state is not None:
            self.target_vision_net(
                torch.as_tensor(critic_state, device=self._tensor_device)
            )
        return self.target_vision_net.value_function()

    def update_target(self, tau: float = 1.0):
        """Blend online weights into the target net via ``misc.soft_update``.

        ``tau=1.0`` (the default) is used as a hard copy elsewhere in this class.
        """
        misc.soft_update(self.target_vision_net, self.vision_net, tau)

    def parameters(self):
        """Return the online network's parameters (the target net is excluded)."""
        return self.vision_net.parameters()

    def compute_action(
        self, observation: DataArray, action_mask: DataArray, evaluate: bool
    ) -> Tuple[Action, ActionDist, Logits]:
        """Sample an action for ``observation`` without tracking gradients.

        Args:
            observation: Batch of observations; converted with ``torch.as_tensor``.
            action_mask: Optional mask passed to ``misc.masked_logits``; any
                array-like accepted by ``torch.as_tensor``.
            evaluate: Passed to ``misc.masked_logits`` as ``not evaluate``.

        Returns:
            ``(action, probs, logits)`` as numpy arrays. ``probs`` is read from
            ``distribution.probs``, which assumes a categorical-style distribution
            (TODO confirm for continuous action spaces).
        """
        with torch.no_grad():
            obs = torch.as_tensor(observation, device=self._tensor_device)
            action_logits = self.vision_net(obs)
            if action_mask is not None:
                # Build the mask on the logits' device: torch.from_numpy always
                # produced a CPU tensor, which breaks against CUDA logits.
                mask = torch.as_tensor(action_mask, device=action_logits.device)
                action_logits = misc.masked_logits(action_logits, mask, not evaluate)
            dist = self._action_dist_handler.proba_distribution(action_logits)
            probs = dist.distribution.probs
            action = self._action_dist_handler.sample()

        return (
            action.detach().cpu().numpy(),
            probs.detach().cpu().numpy(),
            action_logits.detach().cpu().numpy(),
        )

    def compute_actions(
        self,
        observation: Tensor,
        use_target: bool = False,
        action_mask=None,
        explore: bool = False,
    ) -> Tuple[Action, ActionDist, Logits]:
        """Compute actions with gradient tracking (no ``torch.no_grad`` here).

        Args:
            observation: Batch of observations; converted with ``torch.as_tensor``.
            use_target: Use ``target_vision_net`` instead of ``vision_net``.
            action_mask: Optional mask applied via ``misc.masked_logits``.
            explore: Sample stochastically when True, deterministically otherwise.

        Returns:
            ``(action, log_pi, entropy)``. For discrete spaces ``action`` is one-hot
            with shape ``(batch, n_logits)``. ``log_pi`` is the log-softmax over all
            logits along dim 1, not the log-prob of the chosen action.
        """
        net = self.target_vision_net if use_target else self.vision_net
        logits = net(torch.as_tensor(observation, device=self._tensor_device))

        if action_mask is not None:
            mask = torch.as_tensor(action_mask, device=logits.device)
            logits = misc.masked_logits(logits, mask, explore=explore)

        # The returned log-prob is discarded (log_pi is recomputed below); the call
        # presumably also prepares the distribution that entropy() reads.
        action, _ = self._action_dist_handler.log_prob_from_params(
            logits, deterministic=not explore
        )

        if self._is_discrete_action:
            index = action.reshape((-1, 1))
            one_hot = torch.zeros(
                logits.shape[0], logits.shape[1], device=index.device
            )
            one_hot.scatter_(1, index, 1)
            action = one_hot

        log_pi = F.log_softmax(logits, dim=1)
        entropy = self._action_dist_handler.entropy()
        return action, log_pi, entropy

    def to(self, device: str):
        """Move both networks so they match where input tensors will be created.

        ``use_cuda`` becomes True only when ``device`` names CUDA *and* the custom
        config enables ``use_cuda``. The networks are moved only when that flag
        changes, so they always live on the device used for inputs. ``None`` is a
        no-op. Returns ``self``.
        """
        if device is None:
            return self

        was_cuda = bool(self.use_cuda)
        wants_cuda = "cuda" in str(device) and bool(
            self.custom_config.get("use_cuda", False)
        )
        self.use_cuda = wants_cuda

        if wants_cuda != was_cuda:
            target = device if wants_cuda else "cpu"
            self.vision_net = self.vision_net.to(target)
            self.target_vision_net = self.target_vision_net.to(target)

        return self

    def save(self, path, global_step=0, hard: bool = False):
        """Save online and target network state dicts to ``path``.

        An existing file triggers a warning and is only overwritten when ``hard``
        is True. ``global_step`` is accepted but not used.
        """
        file_exist = os.path.exists(path)
        if file_exist:
            Log.warning("\t! detected existing model with path: {}".format(path))
        if (not file_exist) or hard:
            torch.save(
                {
                    "vision": self.vision_net.state_dict(),
                    "target_vision": self.target_vision_net.state_dict(),
                },
                path,
            )

    def load(self, path: str):
        """Load network weights saved by ``save`` onto the current device."""
        state_dict = torch.load(path, map_location=self._tensor_device)
        self.vision_net.load_state_dict(state_dict=state_dict["vision"])
        self.target_vision_net.load_state_dict(state_dict=state_dict["target_vision"])
        # NOTE(review): this hard update overwrites the target weights just loaded
        # from the checkpoint with the online weights — confirm this is intended.
        misc.soft_update(self.target_vision_net, self.vision_net, tau=1.0)
