import numpy as np
import torch


class CnnNet(torch.nn.Module):
    """Four-layer convolutional encoder with ReLU after every convolution.

    Channel progression is ``in_frames -> 16 -> 32 -> 64 -> 32``. The first
    three convolutions use stride 2 and the last one stride 1, so a 48x48
    input comes out as a ``(N, 32, 3, 3)`` feature map.

    Args:
        in_frames: number of channels of the input image tensor.
    """

    def __init__(self, in_frames=3):
        super().__init__()

        conv = torch.nn.Conv2d
        # Layers are created in order conv1..conv4 so that seeded weight
        # initialisation and state-dict keys stay stable.
        self.conv1 = conv(in_channels=in_frames, out_channels=16,
                          kernel_size=4, stride=2, padding=1)
        self.conv2 = conv(in_channels=16, out_channels=32,
                          kernel_size=4, stride=2, padding=1)
        self.conv3 = conv(in_channels=32, out_channels=64,
                          kernel_size=3, stride=2, padding=0)
        self.conv4 = conv(in_channels=64, out_channels=32,
                          kernel_size=3, stride=1, padding=0)
        # Non-linearity applied after each convolution.
        self.phi = torch.relu

    def forward(self, x):
        """Run ``x`` through conv1..conv4, applying ``phi`` after each layer.

        Args:
            x: tensor accepted by ``conv1``, i.e. ``(N, in_frames, H, W)``.

        Returns:
            The activated output of ``conv4``.
        """
        features = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.phi(layer(features))
        return features


class CnnLstmNet(torch.nn.Module):
    """Apply a CNN encoder to every frame of a sequence, then run an LSTM.

    Args:
        cnn: module applied to a ``(batch * seq, c, h, w)`` tensor; its output
            is flattened per frame before being fed to the LSTM.
        lstm_hidden_dim: hidden size of the LSTM.
        lstm_layer_dim: number of stacked LSTM layers.
        in_dim: flattened feature size per time step produced by ``cnn``.
            The default of 288 equals 32 * 3 * 3, which is what this file's
            ``CnnNet`` yields for 48x48 inputs; pass another value when using
            a different encoder or input resolution.
    """

    def __init__(self, cnn, lstm_hidden_dim, lstm_layer_dim, in_dim=288):
        super(CnnLstmNet, self).__init__()

        self.in_dim = in_dim
        self.cnn = cnn
        self.hidden_dim = lstm_hidden_dim
        self.layer_dim = lstm_layer_dim
        self.lstm = torch.nn.LSTM(self.in_dim, self.hidden_dim, self.layer_dim,
                                  batch_first=True, dropout=0.0)

    def forward(self, x):
        """Encode a ``(batch, seq, c, h, w)`` tensor.

        Args:
            x: 5-D input; the second axis is treated as time.

        Returns:
            Whatever ``torch.nn.LSTM`` returns: ``(output, (h_n, c_n))`` with
            ``output`` of shape ``(batch, seq, hidden_dim)``.
        """
        b, s, c, h, w = x.shape

        # Zero initial hidden and cell state on the same dtype/device as x.
        # They are constants, so no gradient tracking is needed for them.
        h0 = x.new_zeros((self.layer_dim, b, self.hidden_dim))
        c0 = x.new_zeros((self.layer_dim, b, self.hidden_dim))

        # reshape time to batch dimension
        x = torch.reshape(x, shape=(b * s, c, h, w))
        # apply cnn encoder
        x = self.cnn(x)
        # split time and batch dimension, flattening the per-frame features
        x = torch.reshape(x, shape=(b, s, -1))

        return self.lstm(x, (h0, c0))


class ActionNet(torch.nn.Module):
    """Single linear layer mapping LSTM features to ``action_out_dim`` outputs.

    Args:
        lstm_hidden_dim: size of the incoming feature vector.
        action_out_dim: number of output units.
    """

    def __init__(self, lstm_hidden_dim, action_out_dim):
        super().__init__()

        self.hidden_dim, self.out_dim = lstm_hidden_dim, action_out_dim
        self.final = torch.nn.Linear(in_features=self.hidden_dim,
                                     out_features=self.out_dim)

    def forward(self, x):
        """Return the linear projection of ``x`` (no activation applied)."""
        projected = self.final(x)
        return projected


class CameraNet(torch.nn.Module):
    """Single linear layer mapping LSTM features to ``camera_out_dim`` outputs.

    Args:
        lstm_hidden_dim: size of the incoming feature vector.
        camera_out_dim: number of output units.
    """

    def __init__(self, lstm_hidden_dim, camera_out_dim):
        super().__init__()

        self.hidden_dim, self.out_dim = lstm_hidden_dim, camera_out_dim
        self.final = torch.nn.Linear(in_features=self.hidden_dim,
                                     out_features=self.out_dim)

    def forward(self, x):
        """Return the linear projection of ``x`` (no activation applied)."""
        projected = self.final(x)
        return projected


class ValueNet(torch.nn.Module):
    """Single linear layer mapping LSTM features to one scalar per sample.

    Args:
        lstm_hidden_dim: size of the incoming feature vector.
    """

    def __init__(self, lstm_hidden_dim):
        super().__init__()

        self.hidden_dim = lstm_hidden_dim
        self.final = torch.nn.Linear(in_features=self.hidden_dim, out_features=1)

    def forward(self, x):
        """Return the linear projection of ``x``; the last axis has size 1."""
        value = self.final(x)
        return value


class Net(torch.nn.Module):
    """Actor-critic network with separate CNN+LSTM encoders for policy and value.

    The policy branch (``cnn`` / ``cnn_lstm``) feeds ``action_net`` and
    ``camera_net``; the value branch (``cnn_val`` / ``cnn_lstm_val``) feeds
    ``value_net``. The two branches do not share any weights.

    Args:
        in_frames: number of input channels per frame.
        lstm_hidden_dim: hidden size of both LSTMs.
        lstm_layer_dim: number of layers of both LSTMs.
        action_out_dim: output size of the action head.
        camera_out_dim: output size of the camera head.
    """

    def __init__(self, in_frames=3, lstm_hidden_dim=256, lstm_layer_dim=1, action_out_dim=8, camera_out_dim=2):
        super(Net, self).__init__()

        # policy part
        # NOTE: self.cnn is also registered as self.cnn_lstm.cnn; both
        # attribute paths are kept so existing state-dict keys stay valid.
        self.cnn = CnnNet(in_frames=in_frames)
        self.cnn_lstm = CnnLstmNet(self.cnn, lstm_hidden_dim=lstm_hidden_dim, lstm_layer_dim=lstm_layer_dim)
        self.action_net = ActionNet(lstm_hidden_dim=lstm_hidden_dim, action_out_dim=action_out_dim)
        self.camera_net = CameraNet(lstm_hidden_dim=lstm_hidden_dim, camera_out_dim=camera_out_dim)

        # value part
        self.cnn_val = CnnNet(in_frames=in_frames)
        self.cnn_lstm_val = CnnLstmNet(self.cnn_val, lstm_hidden_dim=lstm_hidden_dim, lstm_layer_dim=lstm_layer_dim)
        self.value_net = ValueNet(lstm_hidden_dim=lstm_hidden_dim)

    def policy_parameters(self):
        """Return the parameters used to compute the action and camera outputs."""
        return self._shared_parameters() + list(self.action_net.parameters()) + list(self.camera_net.parameters())

    def value_parameters(self):
        """Return the parameters used to compute the value output.

        The value branch runs through ``cnn_lstm_val`` (see ``forward``), so
        its encoder is returned here rather than the policy encoder;
        otherwise an optimizer built from this list would never update the
        value encoder.
        """
        return list(self.cnn_lstm_val.parameters()) + list(self.value_net.parameters())

    def _shared_parameters(self):
        """Return the policy encoder (CNN + LSTM) parameters."""
        return list(self.cnn_lstm.parameters())

    def forward(self, x, action, camera):
        """Compute action, camera and value outputs from a frame sequence.

        Args:
            x: frame sequence passed to both encoders, ``(batch, seq, c, h, w)``.
            action: accepted for interface compatibility; not used.
            camera: accepted for interface compatibility; not used.

        Returns:
            Tuple ``(action_out, camera_out, value_out)``, each computed from
            the LSTM output at the last time step.
        """

        # policy forward: keep only the last time step of the LSTM output
        h = self.cnn_lstm(x)[0][:, -1, ...]
        action_out = self.action_net(h)
        camera_out = self.camera_net(h)

        # value forward: separate encoder, again last time step only
        h = self.cnn_lstm_val(x)[0][:, -1, ...]
        value_out = self.value_net(h)
        return action_out, camera_out, value_out


class BaseNetwork(Net):
    """``Net`` with a fixed configuration.

    Uses a single-layer LSTM with 256 hidden units, 11 action outputs and a
    camera head with zero outputs.
    """

    def __init__(self):
        preset = dict(
            lstm_hidden_dim=256,
            lstm_layer_dim=1,
            action_out_dim=11,
            camera_out_dim=0,
        )
        super().__init__(**preset)


class Network(BaseNetwork):
    """Dict-in / dict-out wrapper around ``BaseNetwork``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Run the network on a dict of inputs.

        Args:
            x: mapping with the keys ``'pov'``, ``'binary_actions'`` and
                ``'camera_actions'``.

        Returns:
            Dict with the keys ``"logits"``, ``"camera"`` and ``"value"``.
        """
        # Look the keys up in the same order as before so a missing key
        # raises the same KeyError.
        pov, binary_actions, camera_actions = (
            x[key] for key in ('pov', 'binary_actions', 'camera_actions')
        )

        outputs = super().forward(pov, binary_actions, camera_actions)
        logits, camera, value = outputs

        return dict(logits=logits, camera=camera, value=value)


if __name__ == "__main__":
    # main: smoke test that pushes one random batch through the network
    # and prints the shapes of the three outputs.

    def _random_tensor(*shape):
        # float32 tensor filled with standard-normal samples drawn by numpy
        return torch.from_numpy(np.random.randn(*shape).astype(np.float32))

    # Sampled in the order pov, binary_actions, camera_actions.
    batch = {
        "pov": _random_tensor(1, 10, 3, 48, 48),
        "binary_actions": _random_tensor(1, 10, 9),
        "camera_actions": _random_tensor(1, 10, 2),
    }

    net = Network()
    out = net(batch)
    shapes = [out[key].shape for key in ("logits", "camera", "value")]
    print("out_shape", *shapes)
