import torch
import torch.nn as nn


class FCN(nn.Module):
    """Linear heads producing command logits and argument outputs.

    Args:
        d_model: Size of the last dimension of the input features.
        n_commands: Output width of the command head (one logit per command).
        n_args: Number of arguments predicted per position.
        args_dim: Logits per argument when ``abs_targets`` is False.
        abs_targets: If True, the argument head emits a single value per
            argument (``n_args`` outputs) instead of ``args_dim`` logits each.
    """

    def __init__(self, d_model, n_commands, n_args, args_dim=256, abs_targets=False):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets

        self.command_fcn = nn.Linear(d_model, n_commands)

        # One value per argument for absolute targets; otherwise a flat block of
        # n_args * args_dim logits that forward() splits per argument.
        args_out = n_args if abs_targets else n_args * args_dim
        self.args_fcn = nn.Linear(d_model, args_out)

    def forward(self, out):
        """Apply both heads to ``out``.

        Args:
            out: Tensor of shape ``[..., d_model]``; any number of leading
                dimensions is accepted (e.g. ``[S, N, d_model]``).

        Returns:
            ``(command_logits, args_logits)`` where ``command_logits`` has shape
            ``[..., n_commands]`` and ``args_logits`` has shape
            ``[..., n_args]`` if ``abs_targets`` else ``[..., n_args, args_dim]``.
        """
        command_logits = self.command_fcn(out)  # [..., n_commands]
        args_logits = self.args_fcn(out)  # [..., n_args] or [..., n_args * args_dim]

        if not self.abs_targets:
            # Split only the last dimension; leading dimensions are kept as-is.
            leading = args_logits.shape[:-1]
            args_logits = args_logits.reshape(*leading, self.n_args, self.args_dim)

        return command_logits, args_logits


class ArgumentFCN(nn.Module):
    """Argument-only output head.

    Args:
        d_model: Size of the last dimension of the input features.
        n_args: Number of arguments predicted per position.
        args_dim: Logits per argument when ``abs_targets`` is False (also the
            hidden width of the ``abs_targets`` head).
        abs_targets: If True, emit one value per argument (regression);
            otherwise ``args_dim`` logits per argument (classification).
    """

    def __init__(self, d_model, n_args, args_dim=256, abs_targets=False):
        super().__init__()

        self.n_args = n_args
        self.args_dim = args_dim
        self.abs_targets = abs_targets

        # abs_targets switches the head from classification to regression.
        if abs_targets:
            # NOTE(review): there is no nonlinearity between these two Linear
            # layers, so together they form a single affine map. The two-layer
            # structure is kept so parameter names (args_fcn.0.*, args_fcn.1.*)
            # stay unchanged.
            self.args_fcn = nn.Sequential(
                nn.Linear(d_model, n_args * args_dim),
                nn.Linear(n_args * args_dim, n_args),
            )
        else:
            self.args_fcn = nn.Linear(d_model, n_args * args_dim)

    def forward(self, out):
        """Apply the argument head to ``out``.

        Args:
            out: Tensor of shape ``[..., d_model]``; any number of leading
                dimensions is accepted (e.g. ``[S, N, d_model]``).

        Returns:
            Tensor of shape ``[..., n_args]`` if ``abs_targets`` else
            ``[..., n_args, args_dim]``.
        """
        args_logits = self.args_fcn(out)  # [..., n_args] or [..., n_args * args_dim]

        if not self.abs_targets:
            # Split only the last dimension; leading dimensions are kept as-is.
            leading = args_logits.shape[:-1]
            args_logits = args_logits.reshape(*leading, self.n_args, self.args_dim)

        return args_logits


class HierarchFCN(nn.Module):
    """Two linear heads (visibility logits and a ``dim_z`` vector) over ``dim_z`` features.

    Args:
        d_model: Accepted for interface compatibility; not used by this module
            (both heads read ``dim_z``-sized features).
        dim_z: Size of the input's last dimension and of the predicted ``z``.
    """

    def __init__(self, d_model, dim_z):
        super().__init__()

        self.visibility_fcn = nn.Linear(dim_z, 2)
        self.z_fcn = nn.Linear(dim_z, dim_z)

    def forward(self, out):
        """Apply both heads and prepend a singleton axis to each result.

        Args:
            out: Rank-3 tensor whose last dimension is ``dim_z``. Unpacking the
                shape raises ``ValueError`` for any other rank.

        Returns:
            ``(visibility_logits, z)`` with shapes ``[1, G, N, 2]`` and
            ``[1, G, N, dim_z]``, where ``G, N = out.shape[:2]``.
        """
        _, _, _ = out.shape  # rank check only: unpacking fails unless out is 3-D

        # Visibility head is evaluated first, then the z head.
        heads = (self.visibility_fcn, self.z_fcn)
        return tuple(head(out).unsqueeze(0) for head in heads)


class ResNet(nn.Module):
    """Four residual blocks applied in sequence, each computing ``z + ReLU(Linear(z))``.

    Args:
        d_model: Feature size of the input's last dimension; every block maps
            ``d_model -> d_model`` so the shape is preserved.
    """

    def __init__(self, d_model):
        super().__init__()

        # Blocks are registered as attributes linear1..linear4, in that order,
        # giving parameter names linear{i}.0.weight / linear{i}.0.bias.
        for idx in range(1, 5):
            block = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU())
            setattr(self, f"linear{idx}", block)

    def forward(self, z):
        """Run linear1..linear4 in order, adding each block's output to its input."""
        for block in (self.linear1, self.linear2, self.linear3, self.linear4):
            z = z + block(z)
        return z
