from typing import Tuple

import torch
import torch.nn as nn

from common.ours.models import MLP


class StateConverter(nn.Module):
    """Map states from one state space to another through a latent code.

    An encoder MLP maps an input state to a ``latent_dim`` vector, and a
    decoder MLP maps that latent vector to an output state. Both MLPs have
    no output activation.

    Args:
        in_state_dim: Size of the last dimension of the input state.
        out_state_dim: Size of the last dimension of the produced state.
        hid_dim: Hidden width shared by the encoder and decoder MLPs.
        latent_dim: Size of the intermediate latent code.
        num_hidden_layers: ``(encoder_layers, decoder_layers)``, the number
            of hidden layers for the encoder and decoder respectively.
        activation: Hidden-layer activation name forwarded to ``MLP``.
    """

    def __init__(
        self,
        in_state_dim: int,
        out_state_dim: int,
        hid_dim: int = 256,
        latent_dim: int = 256,
        num_hidden_layers: Tuple[int, int] = (4, 4),
        activation: str = "relu",
    ):
        super().__init__()

        encoder_layers = num_hidden_layers[0]
        # Previously the decoder also used index 0, so the second element
        # of ``num_hidden_layers`` was silently ignored.
        decoder_layers = num_hidden_layers[1]

        self.encoder = MLP(
            in_dim=in_state_dim,
            hid_dim=hid_dim,
            out_dim=latent_dim,
            n_hid_layers=encoder_layers,
            activation=activation,
            out_activation="none",
        )

        self.decoder = MLP(
            in_dim=latent_dim,
            hid_dim=hid_dim,
            out_dim=out_state_dim,
            n_hid_layers=decoder_layers,
            activation=activation,
            out_activation="none",
        )

    def forward(self, state):
        """Encode ``state`` and decode it into the output state space.

        Returns:
            ``(out_state, z)``: the decoded state and the latent code.
        """
        z = self.encoder(state)
        out_state = self.decoder(z)
        return out_state, z


class PositionalEncoder(nn.Module):
    """MLP that scores a (optionally conditioned) state with a sigmoid output.

    The network has a single output unit with a ``"sigmoid"`` output
    activation, and it expects inputs of width ``state_dim + cond_dim``.

    Args:
        state_dim: Size of the last dimension of ``state``.
        cond_dim: Size of the last dimension of ``cond``. Use 0 for no
            conditioning.
        hid_dim: Hidden width of the MLP.
        num_hidden_layers: Number of hidden layers in the MLP.
        activation: Hidden-layer activation name forwarded to ``MLP``.
    """

    def __init__(
        self,
        state_dim: int,
        cond_dim: int,
        hid_dim: int = 256,
        num_hidden_layers: int = 4,
        activation: str = "relu",
    ):
        super().__init__()

        self.cond_dim = cond_dim
        self.net = MLP(
            in_dim=state_dim + cond_dim,
            hid_dim=hid_dim,
            out_dim=1,
            n_hid_layers=num_hidden_layers,
            activation=activation,
            out_activation="sigmoid",
        )

    def forward(self, state, cond=None):
        """Score ``state``, appending ``cond`` on the last axis when used.

        ``cond`` is concatenated only if it is given AND ``cond_dim > 0``.
        NOTE(review): with ``cond_dim > 0`` and ``cond=None`` the MLP receives
        a narrower input than it was built for, so this presumably fails
        inside ``MLP``. Confirm that callers always pass ``cond`` in that case.
        """
        concat_cond = cond is not None and self.cond_dim > 0
        net_input = torch.cat((state, cond), dim=-1) if concat_cond else state
        return self.net(net_input)


class Discriminator(torch.nn.Module):
    """Discriminator with gradient reversal.

    The input is passed through a ``GradientReversalLayer`` (coefficient
    ``adv_coef``) before an MLP classifier with ``num_classes`` unnormalized
    outputs. As a result, gradients flowing back into the producer of ``z``
    are negated and scaled.

    Args:
        latent_dim: Width of ``z`` (and of ``z_alpha`` when ``sa_disc``).
        hid_dim: Hidden width of the MLP.
        num_classes: Number of output logits.
        cond_dim: Width of ``c``. Only used when ``task_cond`` is set.
        sa_disc: If True, ``z_alpha`` is concatenated after ``z``.
        num_hidden_layers: Number of hidden layers in the MLP.
        task_cond: If True, ``c`` is concatenated last.
        activation: Hidden-layer activation name forwarded to ``MLP``.
        sn: Accepted for interface compatibility. It is not used by this
            implementation.
        adv_coef: Gradient-reversal coefficient.
    """

    def __init__(
        self,
        latent_dim: int,
        hid_dim: int,
        num_classes: int = 2,
        cond_dim: int = 2,
        sa_disc: bool = False,
        num_hidden_layers: int = 4,
        task_cond: bool = False,
        activation: str = "relu",
        sn: bool = True,
        adv_coef: float = 1.0,
    ):
        super().__init__()
        self.task_cond = task_cond
        self.sa_disc = sa_disc

        # Input layout: [z, (z_alpha), (c)] along the last axis.
        disc_in_dim = (
            latent_dim * (2 if self.sa_disc else 1)
            + (cond_dim if self.task_cond else 0)
        )
        self.net = MLP(
            in_dim=disc_in_dim,
            hid_dim=hid_dim,
            out_dim=num_classes,
            n_hid_layers=num_hidden_layers,
            activation=activation,
            out_activation="none",
        )

        self.gradient_reversal_layer = GradientReversalLayer(scale=adv_coef)

    def forward(self, z, z_alpha=None, c=None):
        """Return class logits for ``z`` (plus optional ``z_alpha`` / ``c``)."""
        parts = [z]
        if self.sa_disc:
            parts.append(z_alpha)
        if self.task_cond:
            parts.append(c)
        disc_in = torch.cat(parts, dim=-1) if len(parts) > 1 else z

        return self.net(self.gradient_reversal_layer(disc_in))


class GradientReversalFunction(torch.autograd.Function):
    """Identity on the forward pass; negated, scaled gradient on backward.

    ``GradientReversalFunction.apply(x, scale)`` returns ``x`` unchanged.
    During backpropagation, the gradient w.r.t. ``x`` is the incoming
    gradient multiplied by ``-scale``. ``scale`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, input_forward: torch.Tensor,
                scale: torch.Tensor) -> torch.Tensor:
        # Keep the coefficient around for the backward pass.
        ctx.save_for_backward(scale)
        return input_forward

    @staticmethod
    def backward(
            ctx,
            grad_backward: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        (coef,) = ctx.saved_tensors
        flipped = -grad_backward * coef
        # One entry per forward input: the data tensor, then `scale` (None).
        return flipped, None


class GradientReversalLayer(nn.Module):
    """Module wrapper around ``GradientReversalFunction``.

    The forward pass is the identity. On backward, the gradient is
    multiplied by ``-scale``.

    Args:
        scale: Gradient-reversal coefficient.
    """

    def __init__(self, scale: float):
        super().__init__()
        # Registered as a buffer so that ``.to(device)`` / ``.cuda()`` move it
        # together with the module. ``persistent=False`` keeps it out of the
        # state_dict, matching the previous plain-attribute behaviour so that
        # existing checkpoints still load.
        self.register_buffer("scale", torch.tensor(scale), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return GradientReversalFunction.apply(x, self.scale)

    def extra_repr(self) -> str:
        return f"scale={self.scale.item()}"


class InverseDynamicsModel(nn.Module):
    """Predict the action linking two consecutive states.

    ``state`` and ``next_state`` are concatenated on the last axis and fed to
    an MLP with a ``"tanh"`` output activation and ``action_dim`` outputs.

    Args:
        state_dim: Size of the last dimension of each state.
        action_dim: Number of predicted action components.
        hid_dim: Hidden width of the MLP.
        num_hidden_layers: Number of hidden layers in the MLP.
        activation: Hidden-layer activation name forwarded to ``MLP``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hid_dim: int,
        num_hidden_layers: int = 2,
        activation: str = "relu",
    ):
        super().__init__()

        pair_dim = state_dim * 2  # current state + next state
        self.net = MLP(
            in_dim=pair_dim,
            hid_dim=hid_dim,
            out_dim=action_dim,
            n_hid_layers=num_hidden_layers,
            activation=activation,
            out_activation="tanh",
        )

    def forward(self, state, next_state):
        """Return the predicted action for the transition ``state -> next_state``."""
        transition = torch.cat((state, next_state), dim=-1)
        return self.net(transition)


class Policy(nn.Module):
    """MLP policy mapping states to ``tanh``-squashed actions.

    When ``image_observation`` is set, the module also owns a source and a
    target ``ConvEncoder`` (``image_encoder`` is an alias of the target one)
    and, if ``use_image_decoder`` is set, a source and a target
    ``ConvDecoder``. ``forward`` applies only the state MLP. The image
    modules are not called here.

    Args:
        state_dim: Input width of the policy MLP.
        action_dim: Number of action outputs.
        hid_dim: Hidden width of the MLP.
        num_hidden_layers: Number of hidden layers in the MLP.
        activation: Hidden-layer activation name forwarded to ``MLP``.
        image_observation: Whether to build the image encoders.
        image_state_dim: Encoder output width / decoder latent width.
        use_image_decoder: Whether to build the image decoders.
        coord_conv: Forwarded to ``ConvEncoder``.
        pretrained: Forwarded to ``ConvEncoder``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hid_dim: int,
        num_hidden_layers: int = 2,
        activation: str = "relu",
        image_observation: bool = False,
        image_state_dim: int = 0,
        use_image_decoder: bool = False,
        coord_conv: bool = False,
        pretrained: bool = False,
    ):
        super().__init__()

        self.net = MLP(
            in_dim=state_dim,
            hid_dim=hid_dim,
            out_dim=action_dim,
            n_hid_layers=num_hidden_layers,
            activation=activation,
            out_activation="tanh",
        )

        if not image_observation:
            return

        # The conv modules are imported only when image observations are used.
        from common.ours.models import ConvDecoder, ConvEncoder

        encoder_kwargs = dict(
            in_channels=3,
            out_dim=image_state_dim,
            pretrained=pretrained,
            coord_conv=coord_conv,
        )
        self.source_image_encoder = ConvEncoder(**encoder_kwargs)
        self.target_image_encoder = ConvEncoder(**encoder_kwargs)
        # Same module object registered under a second name.
        self.image_encoder = self.target_image_encoder

        if use_image_decoder:
            self.source_image_decoder = ConvDecoder(
                image_latent_dim=image_state_dim)
            self.target_image_decoder = ConvDecoder(
                image_latent_dim=image_state_dim)

    def forward(self, state):
        """Return the MLP action for ``state``."""
        return self.net(state)
