from utils.nets import BasicNeuralNetwork
from torch import nn
from agent.args import DEVICE3
from agent.params import PARAMS
import torch

# Device onto which the submodules built in this file are moved via `.to(DEVICE)`;
# an alias of `DEVICE3` imported from `agent.args`.
DEVICE = DEVICE3


class MLPModel(BasicNeuralNetwork):
    """MLP mapping a window of states to per-step centroid scores and offsets.

    Each state is embedded with a linear layer to ``2 * state_dim`` features,
    the embeddings of the whole window are flattened and passed through an
    MLP trunk, and the trunk output is split into ``seq_len - 1`` hidden
    vectors of size ``action_hidden_dim``. Two heads read every hidden
    vector: a log-softmax over ``num_centers`` classes and a linear layer
    producing one ``act_dim`` offset per centre.

    ``hidden_size``, ``num_layers`` and ``num_centers`` are read from
    ``PARAMS``; the window length and the hidden width come from the
    constructor arguments.
    """

    def __init__(
        self,
        state_dim: int,
        action_hidden_dim: int,
        act_dim: int,
        seq_len: int,
    ):
        """Build the embedding, trunk and output heads.

        Args:
            state_dim: Size of the last axis of the ``states`` input.
            action_hidden_dim: Width of each per-step hidden vector fed to
                the two heads.
            act_dim: Size of each offset vector produced per centre.
            seq_len: Number of steps in an input window; the model emits
                ``seq_len - 1`` outputs per window.

        Raises:
            AssertionError: If ``PARAMS["transformer"]["num_layers"] < 2``.
        """
        super().__init__()

        transformer_cfg = PARAMS["transformer"]
        self.state_dim = state_dim
        self.hidden_size = transformer_cfg["hidden_size"]
        # The layers below were previously sized partly from the constructor
        # arguments and partly from PARAMS["transformer"]["max_step_len"] and
        # ``act_dim * 2``, so any disagreement between the two only surfaced
        # as a shape error inside ``forward``. Using the arguments everywhere
        # keeps construction and ``forward`` consistent; it is identical to
        # the old behaviour whenever the old model could run at all.
        self.seq_len = seq_len
        self.num_centers = PARAMS["num_centers"]
        self.num_layers = transformer_cfg["num_layers"]
        assert self.num_layers >= 2

        self.act_dim = act_dim
        self.action_hidden_dim = action_hidden_dim

        print(
            f"in mlp model: {self.num_centers}, {self.num_layers}, {self.hidden_size}, {self.action_hidden_dim}"
        )

        embed_dim = self.state_dim * 2

        # Submodules are created in a fixed order (embed, trunk, heads, mask)
        # so parameter initialisation and state_dict layout stay stable.
        self.state_embed = nn.Linear(state_dim, embed_dim).to(DEVICE)

        # Trunk: input layer, (num_layers - 2) hidden layers, output layer,
        # each followed by a ReLU.
        nets = [
            nn.Linear(self.seq_len * embed_dim, self.hidden_size),
            nn.ReLU(),
        ]
        for _ in range(self.num_layers - 2):
            nets.extend(
                [
                    nn.Linear(self.hidden_size, self.hidden_size),
                    nn.ReLU(),
                ]
            )
        nets.extend(
            [
                nn.Linear(
                    self.hidden_size,
                    self.action_hidden_dim * (self.seq_len - 1),
                ),
                nn.ReLU(),
            ]
        )
        self.net = nn.Sequential(*nets).to(DEVICE)

        self.centroid_prob = nn.Sequential(
            nn.Linear(self.action_hidden_dim, self.num_centers),
            nn.LogSoftmax(dim=-1),
        ).to(DEVICE)
        self.offset = nn.Linear(
            self.action_hidden_dim, self.act_dim * self.num_centers
        ).to(DEVICE)
        # Single learned vector substituted for the embedding of masked steps.
        self.mask = nn.Embedding(1, embed_dim).to(DEVICE)

    def forward(self, states, masks=None):
        """Run the model on a batch of state windows.

        Args:
            states: Tensor of shape ``(bs, seq_len, state_dim)``.
            masks: Optional boolean tensor of shape ``(bs, seq_len)``;
                positions set to ``True`` have their state embedding replaced
                by the learned mask vector. Every row must mask the same
                number of positions.

        Returns:
            A pair ``(log_probs, offsets)`` where ``log_probs`` has shape
            ``(bs, seq_len - 1, num_centers)`` (log-softmax over the last
            axis) and ``offsets`` has shape
            ``(bs, seq_len - 1, num_centers, act_dim)``.

        Raises:
            AssertionError: If the input shapes, the mask dtype, or the
                per-row mask counts violate the constraints above.
        """
        bs, seq_length, state_dim = states.shape
        assert state_dim == self.state_dim
        if masks is not None:
            assert masks.shape == states.shape[:-1]
            assert masks.dtype == torch.bool
            # All rows must contain the same number of masked positions.
            assert len(masks.sum(dim=-1).unique()) == 1

        assert seq_length == self.seq_len

        state_embeddings = self.state_embed(states)
        if masks is not None:
            # Row 0 of the one-entry embedding table is the mask vector;
            # reading the weight directly avoids building an index tensor on
            # a hard-coded device for every call.
            state_embeddings[masks] = self.mask.weight[0]

        flat_embeddings = state_embeddings.reshape(
            (bs, self.seq_len * self.state_dim * 2)
        )
        action_hiddens = self.net(flat_embeddings).reshape(
            (bs, self.seq_len - 1, self.action_hidden_dim)
        )

        log_probs = self.centroid_prob(action_hiddens)
        offsets = self.offset(action_hiddens).reshape(
            (bs, self.seq_len - 1, self.num_centers, self.act_dim)
        )
        return log_probs, offsets
