import torch
import torch.nn as nn

from large_rl.policy.arch.mlp import MLP


class DDPG_OUNoise(object):
    """Ornstein-Uhlenbeck-style exploration noise with a persistent state tensor.

    Each call to :meth:`noise` applies one discrete step
    ``state <- state + (theta * (mu - state) + sigma * N(0, 1)) * scale``
    and returns the updated state. ``sigma`` is pre-multiplied by
    ``env_max_action`` so the perturbation magnitude follows the action bound.

    Args:
        dim_action: shape of the state tensor (anything ``torch.ones`` accepts,
            e.g. an int or a tuple).
        mu: value the state is pulled toward; also the value used on reset.
        theta: mean-reversion rate of the drift term.
        sigma: std of the Gaussian perturbation before scaling by ``env_max_action``.
        env_max_action: multiplier applied to ``sigma``.
        device: torch device for the state and the sampled noise.
    """

    def __init__(self, dim_action, mu=0.0, theta=0.15, sigma=0.2, env_max_action=1., device="cpu"):
        self.dim_action = dim_action
        self.mu = mu
        self.theta = theta
        self.sigma = sigma * env_max_action
        self._device = device
        self.state = self._ones() * self.mu

    def _ones(self):
        """Return a tensor of ones shaped like the state, on the configured device."""
        return torch.ones(self.dim_action, device=self._device)

    def noise(self, scale):
        """Advance the process by one step and return the new state.

        Args:
            scale: multiplier applied to the whole increment (drift + diffusion).

        Returns:
            The updated ``self.state`` tensor (the same object stored on the instance).
        """
        prev = self.state
        drift = self.theta * (self.mu - prev)
        diffusion = self.sigma * torch.randn(*prev.shape, device=self._device)
        self.state = prev + (drift + diffusion) * scale
        return self.state

    def reset(self, _id=None):
        """Reset the state to ``mu``.

        Args:
            _id: optional index into the state. When given, only ``state[_id]`` is
                reset, in place, so the tensor most recently returned by
                :meth:`noise` (the same object) observes the change as well.
                When ``None``, the state is replaced by a fresh tensor.
        """
        if _id is None:
            self.state = self._ones() * self.mu
        else:
            self.state[_id] = self._ones()[_id] * self.mu


class GaussianNoise(object):
    """Uncorrelated Gaussian exploration noise.

    Every call to :meth:`noise` draws a fresh sample; no state is carried between
    calls, so :meth:`reset` does nothing.

    Args:
        dim_action: shape of each sample; either an int (1-D sample) or a
            sequence of ints such as a tuple or ``torch.Size``.
        mu: mean of the distribution.
        sigma: std of the distribution before scaling by ``env_max_action``.
        device: torch device on which samples are created.
        env_max_action: multiplier applied to ``sigma``.
        **kwargs: accepted and ignored (lets callers pass extra noise options).
    """

    def __init__(self, dim_action, mu=0.0, sigma=0.2, device="cpu", env_max_action=1., **kwargs):
        self.dim_action = dim_action
        self.mu = mu
        # Noise std follows the action bound.
        self.sigma = sigma * env_max_action
        self._device = device

    @staticmethod
    def _as_size(dim_action):
        """Return ``dim_action`` as a tuple usable as a tensor shape."""
        # torch.normal(float, float, size=...) requires a sequence of ints and
        # rejects a bare int, so wrap scalar dimensions in a 1-tuple.
        try:
            return tuple(dim_action)
        except TypeError:  # scalar dimension, e.g. an int
            return (int(dim_action),)

    def noise(self, scale=1.):
        """Draw one sample of shape ``dim_action`` and multiply it by ``scale``."""
        size = self._as_size(self.dim_action)
        return torch.normal(mean=self.mu, std=self.sigma, size=size, device=self._device) * scale

    def reset(self, id=None):  # noqa: A002 - name kept for keyword-call compatibility
        """No-op: this noise has no state to reset. ``id`` is ignored."""
        pass


class Actor(nn.Module):
    """Deterministic actor: an MLP followed by an optional output activation.

    Behaviour is driven by the ``args`` mapping, which ``forward`` reads for the keys
    ``"DEBUG_type_activation"``, ``"env_max_action"``, ``"WOLP_if_joint_actor"`` and
    ``"WOLP_cascade_list_len"`` (the last two only on the joint-actor path).
    """

    def __init__(self, dim_in=32, dim_hiddens="64_64", dim_out=20, if_init_layer=False, if_norm_each=True,
                 if_norm_final=True, args=None):
        super().__init__()
        self._args = args
        # type_hidden_act_fn=10 is passed through as-is; its meaning is defined by MLP.
        mlp_kwargs = dict(
            dim_in=dim_in,
            dim_hiddens=dim_hiddens,
            dim_out=dim_out,
            type_hidden_act_fn=10,
            if_init_layer=if_init_layer,
            if_norm_each=if_norm_each,
            if_norm_final=if_norm_final,
        )
        self.model = MLP(**mlp_kwargs)

    def forward(self, inputs, if_joint_update=False, knn_function=None):
        """Compute the actor output for ``inputs``.

        Args:
            inputs: fed to ``self.model``. The model output is assumed to be
                ``(batch, dim_out)`` (TODO confirm against MLP).
            if_joint_update: only consulted when ``args["WOLP_if_joint_actor"]`` is
                truthy. If True, the singleton axis is squeezed away and a flat
                ``(batch, dim_out)`` tensor is returned.
            knn_function: unused by this method.

        Returns:
            For a 2-D model output, one of three shapes:
            ``(batch, 1, dim_out)`` by default, ``(batch, dim_out)`` for a joint update,
            or ``(batch, list_len, dim_out // list_len)`` on the joint-actor path otherwise.

        Activation by ``args["DEBUG_type_activation"]``:
            ``"tanh"``: ``tanh`` scaled by ``args["env_max_action"]``.
            ``"sigmoid"``: plain ``sigmoid``, unscaled.
            Any other value: raw model output.
        """
        cfg = self._args
        # Insert a singleton axis at dim 1 after the batch dimension.
        action = self.model(inputs)[:, None, :]

        activation = cfg["DEBUG_type_activation"]
        if activation == "tanh":
            action = torch.tanh(action) * cfg["env_max_action"]
        elif activation == "sigmoid":
            action = torch.sigmoid(action)

        if not cfg["WOLP_if_joint_actor"]:
            return action
        if if_joint_update:
            # Joint update consumes the flat concatenation of all list actions.
            return action.squeeze(1)

        # Split the concatenated output into list_len per-slot actions.
        _, _, joint_dim = action.shape
        list_len = cfg["WOLP_cascade_list_len"]
        return action.reshape(action.shape[0], list_len, joint_dim // list_len)
