from collections import OrderedDict
from collections.abc import Callable
from typing import Optional, Union

from beartype import beartype
from einops import pack
import torch
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm
from torch.distributions import Normal

from helpers import logger


# Lower and upper bounds, in that order, of the log standard deviation
# produced by `TanhGaussActor` (consumed by `TanhGaussActor.bound_log_std`).
SAC_LOG_STD_BOUNDS = [-5., 2.]


@beartype
def log_module_info(model: nn.Module):
    """Log the architecture of `model` and how many trainable parameters it has.

    Only parameters with `requires_grad=True` are counted. Counts of at least
    one million (resp. one thousand) are displayed in millions (resp.
    thousands), rounded to two decimals.
    """

    def _humanize(count) -> str:
        # pick the largest unit the count reaches, if any
        if count // 10 ** 6 > 0:
            return f"{round(count / 10 ** 6, 2)} M"
        if count // 10 ** 3:
            return f"{round(count / 10 ** 3, 2)} k"
        return str(count)

    trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
    logger.info("logging model specs")
    logger.info(model)
    logger.info(f"total trainable params: {_humanize(trainable)}.")


@beartype
def init(constant_bias: float = 0.) -> Callable[[nn.Module], None]:
    """Perform orthogonal initialization.

    Returns a callable meant to be passed to `nn.Module.apply`. For each
    visited module:

    * `nn.Conv2d`, `nn.Linear`, `nn.Bilinear`: the weight is initialized with
      `nn.init.orthogonal_` and the bias (if any) is filled with
      `constant_bias`;
    * `nn.BatchNorm2d`, `nn.LayerNorm`: the weight (if any) is set to 1 and the
      bias (if any) to 0;
    * any other module is left untouched.

    Args:
        constant_bias: value used to fill the biases of conv/linear layers.

    Returns:
        An in-place initializer taking a single module.
    """
    # local import: only needed to see through weight parametrizations
    from torch.nn.utils import parametrize

    def _stored_weight(m: nn.Module) -> torch.Tensor:
        # When "weight" is parametrized (e.g. by
        # `torch.nn.utils.parametrizations.spectral_norm`, as done by `snwrap`),
        # `m.weight` is a tensor computed from the stored parameter, so writing
        # into it in place would not touch the parameter at all. Initialize the
        # underlying stored tensor instead.
        # NOTE(review): spectral_norm's power-iteration vectors were estimated
        # from the pre-init weight; per the PyTorch docs they are refreshed by
        # training-mode forward passes — confirm this is acceptable.
        if parametrize.is_parametrized(m, "weight"):
            return m.parametrizations.weight.original
        return m.weight

    def _init(m: nn.Module) -> None:
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.Bilinear)):
            nn.init.orthogonal_(_stored_weight(m))
            if m.bias is not None:
                nn.init.constant_(m.bias, constant_bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
            # norm layers built without affine parameters have weight/bias None
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    return _init


@beartype
def snwrap(*, enabled: bool = False) -> Callable[[nn.Module], nn.Module]:
    """Spectral normalization wrapper.

    Returns a callable that wraps `nn.Linear`, `nn.Bilinear` and `nn.Conv2d`
    modules with `spectral_norm` when `enabled` is True. Every other module,
    and every module when `enabled` is False, is returned unchanged.
    """
    eligible = (nn.Linear, nn.Bilinear, nn.Conv2d)

    def _snwrap(m: nn.Module) -> nn.Module:
        if not enabled or not isinstance(m, eligible):
            return m
        return spectral_norm(m)

    return _snwrap


@beartype
def nlfact(activation: str) -> Callable[[], nn.Module]:
    """Non-linearity factory.

    Validates `activation` immediately and returns a zero-argument callable
    that builds a fresh in-place activation module on every call.

    Supported names:
        "mish"          -> nn.Mish(inplace=True)
        "relu"          -> nn.ReLU(inplace=True)
        "leaky_<slope>" -> nn.LeakyReLU(<slope>, inplace=True), e.g. "leaky_0.2";
                           the slope is the text after the last "_".

    Raises:
        ValueError: if `activation` does not match any of the forms above.
    """
    if activation == "mish":
        return lambda: nn.Mish(inplace=True)
    if activation == "relu":
        return lambda: nn.ReLU(inplace=True)
    if activation.startswith("leaky"):
        _, sep, leak_text = activation.rpartition("_")
        if not sep:
            raise ValueError(
                f"invalid activation: {activation!r} (expected 'leaky_<slope>')")
        try:
            leak = float(leak_text)
        except ValueError:
            raise ValueError(
                f"invalid activation: {activation!r} (bad slope {leak_text!r})") from None
        return lambda: nn.LeakyReLU(leak, inplace=True)
    raise ValueError(f"invalid activation: {activation!r}")


class Discriminator(nn.Module):
    """Two-hidden-layer MLP producing a single unnormalized score (logit).

    The first input is always present; depending on `input_mode` it is
    concatenated with a second input of observation width ("ss"), of action
    width ("sa"), or used alone ("s"). Each hidden block is a linear layer
    (spectrally normalized if requested), optional dropout (p=0.01), then the
    configured activation.
    """

    @beartype
    def __init__(self,
                 ob_shape: tuple[int, ...],
                 ac_shape: tuple[int, ...],
                 hid_dims: tuple[int, int],
                 input_mode: str,
                 activation: str,
                 *,
                 spectral_norm: bool,
                 dropout: bool,
                 device: Union[str, torch.device]):
        super().__init__()
        ob_dim, ac_dim = ob_shape[-1], ac_shape[-1]

        apply_sn = snwrap(enabled=spectral_norm)
        craft_nl = nlfact(activation=activation)

        # width of the second input for each supported mode
        extra_width = {"ss": ob_dim, "sa": ac_dim, "s": 0}
        if input_mode not in extra_width:
            raise ValueError("invalid input mode")
        in_dim = ob_dim + extra_width[input_mode]

        def _hidden_block(fan_in: int, fan_out: int) -> nn.Sequential:
            # linear -> dropout (or pass-through) -> activation
            return nn.Sequential(OrderedDict([
                ("fc", apply_sn(nn.Linear(fan_in, fan_out, device=device))),
                ("do", nn.Dropout(p=0.01) if dropout else nn.Identity(fan_out)),
                ("nl", craft_nl()),
            ]))

        self.fc_stack = nn.Sequential(OrderedDict([
            ("fc_block_1", _hidden_block(in_dim, hid_dims[0])),
            ("fc_block_2", _hidden_block(hid_dims[0], hid_dims[1])),
        ]))
        self.head = nn.Linear(hid_dims[1], 1, device=device)

        # orthogonal init: hidden stack first, then the output head
        for part in (self.fc_stack, self.head):
            part.apply(init())

    @beartype
    def forward(self, input_a: torch.Tensor, input_b: Optional[torch.Tensor]) -> torch.Tensor:
        """Return raw logits (no sigmoid) for `input_a` (+ `input_b` if given).

        When `input_b` is given, the two inputs are joined with einops
        `pack(..., "b *")`, i.e. concatenated after the leading batch dim.
        """
        x = input_a if input_b is None else pack([input_a, input_b], "b *")[0]
        return self.head(self.fc_stack(x))


class RandomPredictor(nn.Module):

    def __init__(self,
                 ob_shape: tuple[int, ...],
                 ac_shape: tuple[int, ...],
                 hid_dims: tuple[int, int],
                 input_mode: str,
                 activation: str,
                 out_size: int,
                 out_scale: Optional[float],
                 *,
                 spectral_norm: bool,
                 dropout: bool,
                 v2: bool,
                 device: Union[str, torch.device],
                 make_untrainable: bool = False):
        super().__init__()
        ob_dim = ob_shape[-1]
        ac_dim = ac_shape[-1]
        self.input_mode = input_mode
        self.activation = activation
        self.spectral_norm = spectral_norm
        self.dropout = dropout
        self.v2 = v2
        self.make_untrainable = make_untrainable
        self.out_scale = out_scale

        apply_sn = snwrap(enabled=self.spectral_norm)
        craft_nl = nlfact(activation=self.activation)

        if self.v2:
            # define the input dimensions
            assert self.input_mode == "sa", "s and ss modes not allowed here"
            if self.make_untrainable:
                # assemble the layers and output heads
                self.mini_fc_stack_a = nn.Sequential(OrderedDict([
                    ("fc_block_1", nn.Sequential(OrderedDict([
                        ("fc", apply_sn(nn.Linear(ob_dim, 2 * hid_dims[0], device=device))),
                    ]))),
                ]))
                self.mini_fc_stack_b = nn.Sequential(OrderedDict([
                    ("fc_block_1", nn.Sequential(OrderedDict([
                        ("fc", apply_sn(nn.Linear(ac_dim, hid_dims[0], device=device))),
                        ("nl", craft_nl()),
                    ]))),
                    ("fc_block_2", nn.Sequential(OrderedDict([
                        ("fc", apply_sn(nn.Linear(hid_dims[0], hid_dims[1], device=device))),
                    ]))),
                ]))
                self.mini_fc_stack_o = nn.Sequential(OrderedDict([
                    ("fc_block_1", nn.Sequential(OrderedDict([
                        ("nl", craft_nl()),
                        ("fc", apply_sn(nn.Linear(hid_dims[1], out_size, device=device))),
                    ]))),
                ]))
                # perform initialization
                self.mini_fc_stack_a.apply(init())
                self.mini_fc_stack_b.apply(init())
                self.mini_fc_stack_o.apply(init())
                # prevent the weights from ever being updated
                for stack in [self.mini_fc_stack_a, self.mini_fc_stack_b, self.mini_fc_stack_o]:
                    for param in stack.parameters():
                        param.requires_grad = False
            else:
                # assemble the layers and output heads
                self.bilinear = apply_sn(nn.Bilinear(ob_dim, ac_dim, hid_dims[0], device=device))
                # a bit weird, but `Sequential.forward` only takes 1 argument
                self.fc_stack = nn.Sequential(OrderedDict([
                    ("fc_block_1", nn.Sequential(OrderedDict([
                        ("nl", craft_nl()),
                    ]))),
                    ("fc_block_2", nn.Sequential(OrderedDict([
                        ("fc", apply_sn(nn.Linear(hid_dims[0], hid_dims[1], device=device))),
                        ("nl", craft_nl()),
                    ]))),
                    ("fc_block_3", nn.Sequential(OrderedDict([
                        ("fc", apply_sn(nn.Linear(hid_dims[1], out_size, device=device))),
                    ]))),
                ]))
                # perform initialization
                self.bilinear.apply(init())
                self.fc_stack.apply(init())
        else:
            # define the input dimension
            in_dim = ob_dim
            match self.input_mode:
                case "ss":
                    in_dim += ob_dim
                case "sa":
                    in_dim += ac_dim
                case "s":
                    pass
                case _:
                    raise ValueError("invalid input mode")

            # assemble the layers and output heads
            self.fc_stack = nn.Sequential(OrderedDict([
                ("fc_block_1", nn.Sequential(OrderedDict([
                    ("fc", apply_sn(nn.Linear(in_dim, hid_dims[0], device=device))),
                    ("do", (nn.Dropout(p=0.01) if self.dropout else nn.Identity(hid_dims[0]))),
                    ("nl", craft_nl()),
                ]))),
                ("fc_block_2", nn.Sequential(OrderedDict([
                    ("fc", apply_sn(nn.Linear(hid_dims[0], hid_dims[1], device=device))),
                    ("do", (nn.Dropout(p=0.01) if self.dropout else nn.Identity(hid_dims[1]))),
                    ("nl", craft_nl()),
                ]))),
                ("fc_block_3", nn.Sequential(OrderedDict([
                    ("fc", apply_sn(nn.Linear(hid_dims[1], out_size, device=device))),
                ]))),
            ]))

            # perform initialization (orthogonal)
            self.fc_stack.apply(init())

            if self.make_untrainable:
                # prevent the weights from ever being updated
                for param in self.fc_stack.parameters():
                    param.requires_grad = False

    @beartype
    def forward(self, input_a: torch.Tensor, input_b: Optional[torch.Tensor]) -> torch.Tensor:
        if self.v2:
            if self.make_untrainable:
                # assemble prior network
                gamma_beta = self.mini_fc_stack_a(input_a)  # `input_a` is always s_t
                gamma, beta = torch.chunk(gamma_beta, 2, dim=-1)
                h = self.mini_fc_stack_b(input_b)  # `input_b` is either a_t or s_{t+1}
                film = gamma * h + beta
                x = self.mini_fc_stack_o(film)
            else:
                # assemble predictor network
                x = self.bilinear(input_a, input_b)
                x = self.fc_stack(x)
        else:
            if input_b is not None:
                x, _ = pack([input_a, input_b], "b *")  # concatenate along last dim
            else:
                x = input_a
            x = self.fc_stack(x)

        if self.out_scale is not None:
            return self.out_scale * torch.tanh(x)

        return x


class Critic(nn.Module):
    """MLP mapping an (observation, action) pair to a single scalar.

    Each of the two hidden blocks is linear -> (LayerNorm if `layer_norm`,
    else identity) -> ReLU, followed by a linear head of width 1.
    """

    @beartype
    def __init__(self,
                 ob_shape: tuple[int, ...],
                 ac_shape: tuple[int, ...],
                 hid_dims: tuple[int, int],
                 *,
                 layer_norm: bool,
                 device: Union[str, torch.device]):
        super().__init__()
        in_dim = ob_shape[-1] + ac_shape[-1]
        self.layer_norm = layer_norm
        norm_layer = nn.LayerNorm if layer_norm else nn.Identity

        def _hidden_block(fan_in: int, fan_out: int) -> nn.Sequential:
            # nn.Identity accepts and ignores the LayerNorm constructor args
            return nn.Sequential(OrderedDict([
                ("fc", nn.Linear(fan_in, fan_out, device=device)),
                ("ln", norm_layer(fan_out, device=device)),
                ("nl", nn.ReLU()),
            ]))

        self.fc_stack = nn.Sequential(OrderedDict([
            ("fc_block_1", _hidden_block(in_dim, hid_dims[0])),
            ("fc_block_2", _hidden_block(hid_dims[0], hid_dims[1])),
        ]))
        self.head = nn.Linear(hid_dims[1], 1, device=device)

        # orthogonal init: hidden stack first, then the output head
        for part in (self.fc_stack, self.head):
            part.apply(init())

    @beartype
    def forward(self, ob: torch.Tensor, ac: torch.Tensor) -> torch.Tensor:
        """Join `ob` and `ac` with einops `pack(..., "b *")` and score them."""
        joint = pack([ob, ac], "b *")[0]
        return self.head(self.fc_stack(joint))


class Actor(nn.Module):
    """Deterministic policy with tanh-squashed, rescaled outputs.

    An MLP (linear -> optional LayerNorm -> ReLU, twice) feeds a linear head of
    width `ac_shape[-1]`. The head output is passed through `tanh` and mapped
    affinely so that -1 -> `min_ac` and 1 -> `max_ac`.
    """

    @beartype
    def __init__(self,
                 ob_shape: tuple[int, ...],
                 ac_shape: tuple[int, ...],
                 hid_dims: tuple[int, int],
                 min_ac: torch.Tensor,
                 max_ac: torch.Tensor,
                 *,
                 exploration_noise: float,
                 layer_norm: bool,
                 device: Union[str, torch.device]):
        super().__init__()
        ob_dim = ob_shape[-1]
        ac_dim = ac_shape[-1]
        self.layer_norm = layer_norm

        # assemble the last layers and output heads
        self.fc_stack = nn.Sequential(OrderedDict([
            ("fc_block_1", nn.Sequential(OrderedDict([
                ("fc", nn.Linear(ob_dim, hid_dims[0], device=device)),
                ("ln", (nn.LayerNorm if self.layer_norm else nn.Identity)(hid_dims[0],
                                                                          device=device)),
                ("nl", nn.ReLU()),
            ]))),
            ("fc_block_2", nn.Sequential(OrderedDict([
                ("fc", nn.Linear(hid_dims[0], hid_dims[1], device=device)),
                ("ln", (nn.LayerNorm if self.layer_norm else nn.Identity)(hid_dims[1],
                                                                          device=device)),
                ("nl", nn.ReLU()),
            ]))),
        ]))
        self.head = nn.Linear(hid_dims[1], ac_dim, device=device)

        # perform initialization
        self.fc_stack.apply(init())
        self.head.apply(init())

        # register buffers: action rescaling from tanh's [-1, 1] to [min_ac, max_ac].
        # Moved to `device` like every layer above; otherwise bounds given on
        # another device would cause a device mismatch in `forward`.
        self.register_buffer("action_scale",
            ((max_ac - min_ac) / 2.0).to(device))
        self.register_buffer("action_bias",
            ((max_ac + min_ac) / 2.0).to(device))
        # register buffers: exploration (noise std, relative to action_scale)
        self.register_buffer("exploration_noise",
            torch.as_tensor(exploration_noise, device=device))

    @beartype
    def forward(self, ob: torch.Tensor) -> torch.Tensor:
        """Return the deterministic action for `ob`, within [min_ac, max_ac]."""
        x = self.fc_stack(ob)
        x = self.head(x)
        return torch.tanh(x) * self.action_scale + self.action_bias

    @beartype
    def exploit(self, ob: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return the deterministic action under the key "action"."""
        ac = self(ob)
        return {"action": ac}

    @beartype
    def explore(self, ob: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return the action perturbed by Gaussian noise under the key "action".

        The noise std per dimension is `action_scale * exploration_noise`. The
        noisy action is not clipped to the action bounds here.
        """
        ac = self(ob)
        return {
            "action": ac + torch.randn_like(ac).mul(self.action_scale * self.exploration_noise),
        }


class TanhGaussActor(nn.Module):
    """Stochastic policy: a Gaussian over pre-activations, squashed by tanh.

    An MLP (linear -> optional LayerNorm -> ReLU, twice) feeds a linear head of
    width `2 * ac_shape[-1]`, split into the mean and the (bounded) log std.
    Samples are passed through `tanh` and mapped affinely so that -1 ->
    `min_ac` and 1 -> `max_ac`.

    NOTE(review): `generator` is stored as `self.rng`, but no method of this
    class uses it; `get_action` samples via `rsample`, which uses PyTorch's
    default RNG. Confirm whether `self.rng` is consumed elsewhere.
    """

    @beartype
    def __init__(self,
                 ob_shape: tuple[int, ...],
                 ac_shape: tuple[int, ...],
                 hid_dims: tuple[int, int],
                 min_ac: torch.Tensor,
                 max_ac: torch.Tensor,
                 *,
                 generator: torch.Generator,
                 layer_norm: bool,
                 device: Union[str, torch.device]):
        super().__init__()
        ob_dim = ob_shape[-1]
        ac_dim = ac_shape[-1]
        self.rng = generator
        self.layer_norm = layer_norm

        # assemble the last layers and output heads
        self.fc_stack = nn.Sequential(OrderedDict([
            ("fc_block_1", nn.Sequential(OrderedDict([
                ("fc", nn.Linear(ob_dim, hid_dims[0], device=device)),
                ("ln", (nn.LayerNorm if self.layer_norm else nn.Identity)(hid_dims[0],
                                                                          device=device)),
                ("nl", nn.ReLU()),
            ]))),
            ("fc_block_2", nn.Sequential(OrderedDict([
                ("fc", nn.Linear(hid_dims[0], hid_dims[1], device=device)),
                ("ln", (nn.LayerNorm if self.layer_norm else nn.Identity)(hid_dims[1],
                                                                          device=device)),
                ("nl", nn.ReLU()),
            ]))),
        ]))
        self.head = nn.Linear(hid_dims[1], 2 * ac_dim, device=device)

        # perform initialization
        self.fc_stack.apply(init())
        self.head.apply(init())

        # register buffers: action rescaling from tanh's [-1, 1] to [min_ac, max_ac].
        # Moved to `device` like every layer above; otherwise bounds given on
        # another device would cause a device mismatch in `get_action`.
        self.register_buffer("action_scale",
            ((max_ac - min_ac) / 2.0).to(device))
        self.register_buffer("action_bias",
            ((max_ac + min_ac) / 2.0).to(device))

    @staticmethod
    @beartype
    def bound_log_std(log_std: torch.Tensor) -> torch.Tensor:
        """Stability trick from OpenAI SpinUp / Denis Yarats.

        Smoothly maps any real `log_std` into the interval given by
        `SAC_LOG_STD_BOUNDS` through an affine transform of `tanh`.
        """
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = SAC_LOG_STD_BOUNDS
        return log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)

    @beartype
    def forward(self, ob: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (pre-tanh) Gaussian mean and standard deviation for `ob`."""
        x = self.fc_stack(ob)
        mean, log_std = self.head(x).chunk(2, dim=-1)
        log_std = self.bound_log_std(log_std)
        std = log_std.exp()
        return mean, std

    @beartype
    def get_action(self, ob: torch.Tensor) -> dict[str, torch.Tensor]:
        """Sample a squashed action and its log-probability.

        Returns a dict with:
            "sample": reparameterized sample, squashed and rescaled to bounds;
            "log_prob": log-density of the sample, with the tanh/rescaling
                change-of-variables correction, summed over the last (action)
                dim with keepdim=True;
            "mode": squashed and rescaled mean.
        """
        mean, std = self(ob)
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # enforcing action bound (log |d action / d x_t|, eps for stability)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        # sum over the action dim; identical to dim=1 for (batch, ac_dim) inputs
        log_prob = log_prob.sum(-1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return {"sample": action, "log_prob": log_prob, "mode": mean}
