import numpy as np
import torch


class CnnNet(torch.nn.Module):
    """Per-frame convolutional encoder: four Conv2d layers, each followed by ReLU.

    Takes 3-channel images. With the kernel/stride/padding values below, a
    48x48 frame is reduced 48 -> 24 -> 12 -> 5 -> 3, giving a 64x3x3 feature
    map (576 values, the size ``TimeDistCnnNet.in_dim`` expects).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation and
        # registration order stay stable.
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0)
        self.conv4 = torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0)
        self.phi = torch.relu  # activation applied after every conv

    def forward(self, x):
        """Run ``x`` (N, 3, H, W) through conv -> ReLU four times."""
        out = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = self.phi(conv(out))
        return out


class AttentionNet(torch.nn.Module):
    """Predict one softmax attention weight per frame of a clip.

    Every frame is encoded independently by a small conv stack. The per-frame
    features of one clip are then concatenated and mapped by two linear layers
    to one logit per frame, followed by a softmax over the frame axis.

    Because the linear layers see the whole clip at once, the clip length is
    fixed when the module is constructed.

    Args:
        seq_len: number of frames per clip. Defaults to 32, the value that was
            previously hard-coded into ``linear1`` (1024 = 32 * 32) and
            ``linear2`` (32 outputs).
    """

    # Flattened conv4 output per frame: 8 channels x 2 x 2. This assumes
    # 48x48 input frames (48 -> 24 -> 12 -> 6 -> 2), as in the __main__ demo;
    # other frame sizes change this count.
    _FRAME_FEATURES = 8 * 2 * 2

    def __init__(self, seq_len=32):
        super(AttentionNet, self).__init__()

        self.seq_len = seq_len
        self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
        self.conv3 = torch.nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.conv4 = torch.nn.Conv2d(32, 8, kernel_size=4, stride=2, padding=0)
        # The input is all frames of one clip concatenated; the output is one
        # logit per frame.
        self.linear1 = torch.nn.Linear(seq_len * self._FRAME_FEATURES, 512)
        self.linear2 = torch.nn.Linear(512, seq_len)
        self.phi = torch.relu

    def forward(self, x):
        """Compute per-frame attention weights.

        Args:
            x: clip tensor of shape (b, s, c, h, w); ``s`` must equal
                ``self.seq_len`` for the linear layers to line up.

        Returns:
            Tensor of shape (b, s, 1, 1, 1) whose entries sum to 1 over dim 1.
            The trailing singleton dims let it broadcast against ``x``.
        """
        b, s, c, h, w = x.shape
        # Fold time into the batch so each frame goes through the convs alone.
        x = torch.reshape(x, shape=(b * s, c, h, w))

        # Apply the shared per-frame encoder.
        x = self.phi(self.conv1(x))
        x = self.phi(self.conv2(x))
        x = self.phi(self.conv3(x))
        x = self.phi(self.conv4(x))

        # The frames of each clip are contiguous after the fold above, so this
        # concatenates one clip's frame features into a single row.
        x = torch.reshape(x, shape=(b, -1))
        x = self.phi(self.linear1(x))
        x = self.linear2(x)
        x = torch.softmax(x, dim=1)  # normalise over frames
        # Equivalent to unsqueeze(-1) three times: (b, s) -> (b, s, 1, 1, 1).
        return x[:, :, None, None, None]


class TimeDistCnnNet(torch.nn.Module):
    """Encode a clip of frames into a single hidden vector.

    Pipeline:
      1. Compute weights with ``attention`` and multiply them into the raw input frames.
      2. Encode every frame independently with ``cnn``, with time folded into the
         batch dimension.
      3. Average the feature maps over time.
      4. Flatten to ``in_dim`` features, apply a linear projection to
         ``hidden_dim``, then ReLU.

    Args:
        cnn: per-frame encoder. Its output is indexed as rank 4 and must
            flatten to ``in_dim`` (576) values per clip; e.g. CnnNet on 48x48
            frames gives 64x3x3. Confirm this for other encoders or frame sizes.
        attention: callable mapping the (b, s, c, h, w) clip to weights that
            broadcast against it.
        hidden_dim: size of the returned feature vector.
    """

    def __init__(self, cnn, attention, hidden_dim):
        super().__init__()

        self.in_dim = 576  # flattened, time-averaged feature size fed to self.linear
        self.cnn = cnn
        self.attention = attention
        self.hidden_dim = hidden_dim
        self.linear = torch.nn.Linear(self.in_dim, self.hidden_dim)

    def forward(self, x):
        """Map a (b, s, c, h, w) clip to a (b, hidden_dim) non-negative vector."""
        batch, steps, channels, height, width = x.shape

        # Weight the raw frames by their attention scores before encoding.
        weights = self.attention(x)
        weighted = weights * x

        # Encode all frames of all clips in one pass, then restore the time axis.
        frames = weighted.reshape(batch * steps, channels, height, width)
        feats = self.cnn(frames)
        feats = feats.reshape(batch, steps, feats.shape[1], feats.shape[2], feats.shape[3])

        # Pool over time with an unweighted mean; attention was already
        # applied to the input frames above.
        pooled = feats.mean(dim=1)

        hidden = self.linear(pooled.view(batch, self.in_dim))
        return torch.relu(hidden)


class ActionNet(torch.nn.Module):
    """Linear head projecting hidden features to ``action_out_dim`` logits."""

    def __init__(self, hidden_dim, action_out_dim):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim = action_out_dim
        self.final = torch.nn.Linear(in_features=hidden_dim, out_features=action_out_dim)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape (..., out_dim)."""
        logits = self.final(x)
        return logits


class CameraNet(torch.nn.Module):
    """Linear head projecting hidden features to ``camera_out_dim`` outputs."""

    def __init__(self, hidden_dim, camera_out_dim):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim = camera_out_dim
        self.final = torch.nn.Linear(in_features=hidden_dim, out_features=camera_out_dim)

    def forward(self, x):
        """Return the raw linear output of shape (..., out_dim); no activation."""
        camera = self.final(x)
        return camera


class Net(torch.nn.Module):
    """Full model: shared clip encoder feeding an action head and a camera head.

    ``self.cnn`` and ``self.attention`` are also passed into ``self.td_cnn``,
    so the same module instances are reachable under both attribute paths.

    NOTE(review): ``camera_head`` is stored but not read anywhere in this
    class; the camera head is always built and evaluated.
    """

    def __init__(self, camera_head=True, hidden_dim=256, action_out_dim=8, camera_out_dim=2):
        super().__init__()
        # Keep this construction order: it fixes parameter initialisation and
        # registration order.
        self.camera_head = camera_head
        self.cnn = CnnNet()
        self.attention = AttentionNet()
        self.td_cnn = TimeDistCnnNet(self.cnn, self.attention, hidden_dim=hidden_dim)
        self.action_net = ActionNet(hidden_dim=hidden_dim, action_out_dim=action_out_dim)
        self.camera_net = CameraNet(hidden_dim=hidden_dim, camera_out_dim=camera_out_dim)

    def forward(self, x, action, camera):
        """Encode clip ``x`` and return ``(action_logits, camera_output)``.

        NOTE(review): ``action`` and ``camera`` are accepted for interface
        compatibility but are not used.
        """
        features = self.td_cnn(x)
        return self.action_net(features), self.camera_net(features)


class BaseNetwork(Net):
    """``Net`` with fixed configuration: camera head on, 256 hidden units,
    8 action logits and 2 camera outputs."""

    def __init__(self):
        super().__init__(
            camera_head=True,
            hidden_dim=256,
            action_out_dim=8,
            camera_out_dim=2,
        )


class Network(BaseNetwork):
    """Dict-in / dict-out wrapper around ``BaseNetwork``.

    The input mapping must provide ``'pov'``, ``'binary_actions'`` and
    ``'camera_actions'``. The output is ``{"logits": ..., "camera": ...}``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Keys are read in the order pov, binary_actions, camera_actions.
        logits, camera = super().forward(
            x['pov'],
            x['binary_actions'],
            x['camera_actions'],
        )
        return {"logits": logits, "camera": camera}


if __name__ == "__main__":
    # Smoke test: one clip of 32 random 3x48x48 frames plus dummy action
    # tensors, then print the output shapes.
    def _random(*shape):
        return torch.from_numpy(np.random.randn(*shape).astype(np.float32))

    # The dict literal evaluates in order, so the numpy RNG draws happen as
    # pov, binary_actions, camera_actions.
    sample = {
        "pov": _random(1, 32, 3, 48, 48),
        "binary_actions": _random(1, 32, 9),
        "camera_actions": _random(1, 32, 2),
    }

    net = Network()
    out = net(sample)
    print("out_shape", out["logits"].shape, out["camera"].shape)
