import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class PolicyNet(nn.Module):
    """Actor-critic network with a shared trunk, a policy head and a value head.

    Two inputs are encoded by separate MLPs (a main feature vector and a
    "style" vector), concatenated along the last axis, passed through a shared
    MLP, and then fed to a value head (one scalar per sample) and a policy head
    (one logit per action).

    Args:
        state_dim: two-element sequence; ``state_dim[0]`` is the feature size
            of the main input and ``state_dim[1]`` the feature size of the
            style input.
        action_dim: number of discrete actions (size of the policy output).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        # Output widths of the two encoders. The shared trunk consumes their
        # concatenation, so its input size is derived rather than hard-coded.
        main_width = 128
        style_width = 32

        self.mlp_1 = nn.Sequential(
            nn.Linear(state_dim[0], main_width),
            nn.LeakyReLU(),
            nn.Linear(main_width, main_width),
            nn.LeakyReLU(),
        )

        self.mlp_2 = nn.Sequential(
            nn.Linear(state_dim[1], style_width),
            nn.LeakyReLU(),
            nn.Linear(style_width, style_width),
            nn.LeakyReLU(),
        )

        self.mlp = nn.Sequential(
            nn.Linear(main_width + style_width, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
        )

        self.value_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 1),
        )

        self.policy_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, action_dim),
        )

        # Normalise over the last (action) axis, matching the concatenation
        # axis in ``forward``; this is identical to dim=1 for 2-D batches but
        # also works for unbatched and higher-rank inputs.
        self.softmax = nn.Softmax(dim=-1)

    def init(self):
        """Re-initialise every Linear/Conv2d weight in place with an orthogonal matrix.

        This is not called by ``__init__``; callers must invoke it explicitly.
        Biases are left at PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # ``orthogonal_`` is the in-place API; the un-suffixed
                # ``orthogonal`` is deprecated.
                nn.init.orthogonal_(m.weight)

    def forward(self, input_tensors, style):
        """Run both encoders, the shared trunk, and the two heads.

        Args:
            input_tensors: main input whose last dimension is ``state_dim[0]``.
            style: style input whose last dimension is ``state_dim[1]``; its
                leading dimensions must match ``input_tensors`` so the encoded
                features can be concatenated.

        Returns:
            ``(prob, log_prob, value)``: action probabilities and their logs,
            both with last dimension ``action_dim``, and the value estimate
            with last dimension 1.
        """
        main_features = self.mlp_1(input_tensors)
        style_features = self.mlp_2(style)

        # Concatenate the two encodings and process them with the shared trunk.
        shared = self.mlp(torch.cat([main_features, style_features], dim=-1))

        value = self.value_head(shared)
        logits = self.policy_head(shared)

        # Softmax and log_softmax are numerically stable internally, so no
        # manual max-subtraction is required.
        prob = self.softmax(logits)
        log_prob = F.log_softmax(logits, dim=-1)

        return prob, log_prob, value


if __name__ == '__main__':
    # Smoke test: build a small network and run a single forward pass on
    # random inputs, then print the (prob, log_prob, value) tuple.
    n_samples = 2
    feature_dim, style_dim = 120, 6
    model = PolicyNet([feature_dim, style_dim], 5)
    features = torch.rand((n_samples, feature_dim))
    style_vec = torch.rand((n_samples, style_dim))
    outputs = model(features, style_vec)
    print(outputs)


