import torch
import torch.nn as nn
import torch.nn.functional as F
from algo.common.impala import ResNetBase

class DiscrimModel(nn.Module):
    """Score (x, p) pairs with an MLP over their concatenated embeddings.

    ``x`` is a batch of 256-dimensional feature vectors and ``p`` is a batch
    of ``n`` 2-dimensional vectors per sample. Each is projected to
    ``config.e_dim`` features, the ``x`` embedding is broadcast across the
    ``n`` entries of ``p``, and every (x, p_i) pair is mapped to one scalar.
    """

    def __init__(self, config):
        """Build the embedding layers and the scoring MLP.

        Args:
            config: object exposing ``e_dim``, the embedding width used for
                both inputs.
        """
        super(DiscrimModel, self).__init__()
        self.embedding_p = nn.Sequential(nn.Linear(2, config.e_dim), nn.ReLU())
        self.embedding_x = nn.Sequential(nn.Linear(256, config.e_dim), nn.ReLU())
        self.mlp = nn.Sequential(
            nn.Linear(config.e_dim + config.e_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 1),
        )

    def load(self, path):
        """Load parameters from a checkpoint written by :meth:`save`.

        Tensors are mapped onto the device this module currently lives on,
        so a checkpoint saved on a GPU can be restored on a CPU-only host.

        Args:
            path: file path (or file-like object) accepted by ``torch.load``.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(path, map_location=device))

    def save(self, path):
        """Write this module's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.state_dict(), path)

    def forward(self, x, p):
        """Return one score per (x, p_i) pair.

        Args:
            x: tensor of shape ``[b, 256]``.
            p: tensor of shape ``[b, n, 2]``.

        Returns:
            Tensor of shape ``[b, n, 1]`` holding the raw (unactivated)
            output of the final linear layer.
        """
        n = p.size(1)
        x = self.embedding_x(x)  # [b, e_dim]
        p = self.embedding_p(p)  # [b, n, e_dim]
        # Broadcast the per-sample embedding over the n entries of p. expand
        # returns a view, so no copy is made before the cat below.
        x = x.unsqueeze(1).expand(-1, n, -1)  # [b, n, e_dim]

        x = self.mlp(torch.cat((x, p), dim=-1))
        return x  # [b, n, 1]


class Model(nn.Module):
    """Encoder followed by an MLP head that outputs a softmax distribution.

    The encoder is a ``ResNetBase`` built with ``num_inputs=3``; its output,
    flattened per sample, must have 256 features to match the head's first
    linear layer. The head produces ``config.action_shape`` probabilities.
    """

    def __init__(self, config):
        """Build the encoder and the MLP head.

        Args:
            config: object exposing ``action_shape``, the size of the output
                distribution.
        """
        super(Model, self).__init__()
        self.encoder = ResNetBase(num_inputs=3)
        self.mlp = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, config.action_shape),
        )

    def load(self, path):
        """Load parameters from a checkpoint written by :meth:`save`.

        Tensors are mapped onto the device this module currently lives on,
        so a checkpoint saved on a GPU can be restored on a CPU-only host.

        Args:
            path: file path (or file-like object) accepted by ``torch.load``.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(path, map_location=device))

    def save(self, path):
        """Write this module's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.state_dict(), path)

    def forward(self, x, ret_z=False):
        """Encode ``x`` and return a probability distribution per sample.

        Args:
            x: input batch for the encoder (presumably 3-channel images,
                given ``num_inputs=3`` -- confirm against ``ResNetBase``).
            ret_z: when true, also return the flattened encoder features.

        Returns:
            Softmax probabilities over the last dimension, or the tuple
            ``(probabilities, z)`` when ``ret_z`` is true, where ``z`` is
            the encoder output flattened to ``[batch, features]``.
        """
        z = self.encoder(x)
        z = torch.flatten(z, 1)  # flatten all dimensions except batch
        logits = self.mlp(z)
        probs = F.softmax(logits, dim=-1)

        if ret_z:
            return probs, z
        return probs
