from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm


class MLP(nn.Module):
    """Stack of linear layers, each followed by an activation module.

    The network is ``in_dim -> hid_dim -> ... -> hid_dim -> out_dim``. Every
    linear layer (including the last one) is followed by the module returned
    by ``_get_activation``; use ``out_activation='none'`` for a linear output.
    """

    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        out_dim: int,
        n_hid_layers: int,
        activation: str = 'relu',
        out_activation: str = 'relu',
        use_spectral_norm: bool = False,
    ):
        """Build the layer stack.

        Args:
            in_dim: Size of the last dimension of the input.
            hid_dim: Width of every hidden layer.
            out_dim: Size of the last dimension of the output.
            n_hid_layers: Number of hidden layers. The input projection is
                always created, so values below 1 still give one hidden layer.
            activation: Name of the module placed after each hidden layer.
            out_activation: Name of the module placed after the output layer.
            use_spectral_norm: Wrap every linear layer in ``spectral_norm``.
        """
        super().__init__()

        # The first hidden layer is unconditional; only the remaining
        # ``n_hid_layers - 1`` (never fewer than zero) are optional.
        n_extra = max(n_hid_layers - 1, 0)
        widths = [in_dim, hid_dim] + [hid_dim] * n_extra

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(self._get_linear(fan_in, fan_out, use_spectral_norm))
            modules.append(self._get_activation(name=activation, dim=fan_out))

        # Output projection with its own (possibly different) activation.
        modules.append(self._get_linear(hid_dim, out_dim, use_spectral_norm))
        modules.append(self._get_activation(name=out_activation, dim=out_dim))

        self.net = nn.Sequential(*modules)

    @staticmethod
    def _get_linear(in_dim: int, out_dim: int, sn: bool = False):
        """Return an ``nn.Linear``, wrapped in ``spectral_norm`` when ``sn``."""
        linear = nn.Linear(in_dim, out_dim)
        if not sn:
            return linear
        return spectral_norm(linear)

    @staticmethod
    def _get_activation(name: str, dim: int):
        """Create the module registered under ``name``.

        ``dim`` is only used by ``'ln'``, which builds a ``LayerNorm`` over a
        last dimension of that size.

        Raises:
            ValueError: If ``name`` is not one of the supported keys.
        """
        # Factories are lazy so that only the requested module is built.
        factories = {
            'relu': lambda: nn.ReLU(),
            'none': lambda: nn.Identity(),
            'tanh': lambda: nn.Tanh(),
            'gelu': lambda: nn.GELU(),
            'mish': lambda: nn.Mish(),
            'leaky_relu': lambda: nn.LeakyReLU(0.2),
            'ln': lambda: nn.LayerNorm(normalized_shape=dim),
        }
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f'Invalid activation={name} is detected.')
        return factory()

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)


class Policy(torch.nn.Module):
    """Three-stage policy network: encoder -> core -> head.

    * ``encoder`` maps the concatenation of state and domain ID to a latent
      vector ``z``.
    * ``core`` maps the concatenation of ``z`` and the condition vector to a
      latent vector ``alpha``.
    * ``head`` maps ``alpha`` plus the domain ID (optionally zeroed out) and,
      if requested, the raw state to the output, through a final ``tanh``.
    """

    def __init__(
        self,
        state_dim: int,
        cond_dim: int,
        out_dim: int,
        domain_dim: int,
        latent_dim: int,
        hid_dim: int,
        num_hidden_layers: Tuple[int, int, int] = (3, 5, 3),
        no_head_domain: bool = False,
        activation: str = 'relu',
        repr_activation: str = 'relu',
        enc_sn: bool = False,
        decode_with_state: bool = False,
    ):
        """Create the encoder, core and head MLPs.

        Args:
            state_dim: Size of the last dimension of the state input.
            cond_dim: Size of the last dimension of the condition vector.
            out_dim: Size of the last dimension of the head output.
            domain_dim: Size of the last dimension of the domain ID.
            latent_dim: Output width of both the encoder and the core.
            hid_dim: Hidden width shared by all three MLPs.
            num_hidden_layers: Hidden-layer counts for the encoder, core and
                head, read by index 0, 1 and 2. Any indexable sequence works;
                the default is an immutable tuple so it cannot be shared
                and mutated across instances.
            no_head_domain: If true, the head receives zeros in place of the
                domain ID (its input width is unchanged).
            activation: Hidden activation name passed to every MLP.
            repr_activation: Output activation name of the encoder and core.
            enc_sn: Enable spectral normalisation in the encoder only.
            decode_with_state: If true, the state is appended to the head
                input.
        """
        super().__init__()
        self.state_dim = state_dim
        self.cond_dim = cond_dim
        self.hid_dim = hid_dim
        self.latent_dim = latent_dim
        self.domain_dim = domain_dim
        self.no_head_domain = no_head_domain
        self.decode_with_state = decode_with_state

        self.encoder = MLP(
            in_dim=state_dim + domain_dim,
            hid_dim=hid_dim,
            out_dim=latent_dim,
            n_hid_layers=num_hidden_layers[0],
            activation=activation,
            out_activation=repr_activation,
            use_spectral_norm=enc_sn,
        )

        self.core = MLP(
            in_dim=latent_dim + cond_dim,
            hid_dim=hid_dim,
            out_dim=latent_dim,
            n_hid_layers=num_hidden_layers[1],
            activation=activation,
            out_activation=repr_activation,
        )

        # The head always sees alpha and a domain-sized slot; the state is
        # appended only when decode_with_state is set (mirrors forward()).
        head_in_dim = latent_dim + domain_dim
        if decode_with_state:
            head_in_dim += state_dim

        self.head = MLP(
            in_dim=head_in_dim,
            hid_dim=hid_dim,
            out_dim=out_dim,
            n_hid_layers=num_hidden_layers[2],
            activation=activation,
            out_activation='tanh',
        )

    def forward(self, s, c, d):
        """Run the policy.

        Args:
            s: State tensor.
            c: Condition vector (e.g. task ID).
            d: Domain ID tensor.

        All inputs are concatenated along the last dimension, so they must
        agree on every other dimension.

        Returns:
            Tuple ``(out, z, alpha)``: the head output, the encoder output
            and the core output.
        """
        # encoding
        sd = torch.cat((s, d), dim=-1)
        z = self.encoder(sd)

        # core
        zc = torch.cat((z, c), dim=-1)
        alpha = self.core(zc)

        # decoding: zeros keep the head input width fixed while hiding the
        # domain ID from the head.
        d_in = torch.zeros_like(d) if self.no_head_domain else d
        alpha_d = torch.cat((alpha, d_in), dim=-1)

        if self.decode_with_state:
            alpha_d = torch.cat((alpha_d, s), dim=-1)

        out = self.head(alpha_d)

        return out, z, alpha


class Discriminator(torch.nn.Module):
    """MLP classifier over latent vectors with optional extra inputs.

    The input is ``z``, optionally followed by ``z_alpha`` (``sa_disc``) and
    by the condition vector ``c`` (``task_cond``), all joined on the last
    dimension. The output has ``num_classes`` entries and no final activation.
    """

    def __init__(
        self,
        latent_dim: int,
        hid_dim: int,
        num_classes: int = 2,
        cond_dim: int = 2,
        sa_disc: bool = False,
        num_hidden_layer: int = 4,
        task_cond: bool = False,
        activation: str = 'relu',
        sn: bool = False,
    ):
        """Create the classifier.

        Args:
            latent_dim: Width of ``z`` (and of ``z_alpha`` when used).
            hid_dim: Hidden width of the MLP.
            num_classes: Width of the output.
            cond_dim: Width of ``c``; only counted when ``task_cond`` is set.
            sa_disc: Expect ``z_alpha`` as a second latent input.
            num_hidden_layer: Number of hidden layers of the MLP.
            task_cond: Expect the condition vector ``c`` as an input.
            activation: Hidden activation name of the MLP.
            sn: Enable spectral normalisation in the MLP.
        """
        super().__init__()
        self.task_cond = task_cond
        self.sa_disc = sa_disc

        # Two latent vectors when sa_disc is on, one otherwise.
        feature_dim = latent_dim * 2 if self.sa_disc else latent_dim
        feature_dim += cond_dim if self.task_cond else 0

        self.net = MLP(
            in_dim=feature_dim,
            hid_dim=hid_dim,
            out_dim=num_classes,
            n_hid_layers=num_hidden_layer,
            activation=activation,
            out_activation='none',
            use_spectral_norm=sn,
        )

    def forward(self, z, z_alpha=None, c=None):
        """Classify ``z`` together with whichever extra inputs are enabled.

        ``z_alpha`` is used only when ``sa_disc`` is set and ``c`` only when
        ``task_cond`` is set; otherwise they are ignored.
        """
        pieces = [z]
        if self.sa_disc:
            pieces.append(z_alpha)
        if self.task_cond:
            pieces.append(c)

        # Skip the concatenation entirely when z is the only input.
        features = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=-1)
        return self.net(features)


class ReconstructionDecoder(nn.Module):
    """MLP that maps a latent vector to a state-sized vector.

    Uses three hidden layers and no activation on the output.
    """

    def __init__(self,
                 latent_dim: int,
                 state_dim: int,
                 hid_dim: int,
                 activation: str = 'relu'):
        """Create the decoder.

        Args:
            latent_dim: Width of the input.
            state_dim: Width of the output.
            hid_dim: Hidden width of the MLP.
            activation: Hidden activation name of the MLP.
        """
        super().__init__()
        mlp_config = dict(
            in_dim=latent_dim,
            hid_dim=hid_dim,
            out_dim=state_dim,
            n_hid_layers=3,
            activation=activation,
            out_activation='none',
        )
        self.net = MLP(**mlp_config)

    def forward(self, x):
        """Decode ``x`` with the underlying MLP."""
        decoded = self.net(x)
        return decoded
