import torch
import torch.nn as nn

class InverseModel(nn.Module):
    """Inverse-dynamics network: predict the action linking two observations.

    The two observations are rescaled, concatenated along the channel axis,
    encoded by a ``DownSample`` network and mapped by an MLP to a vector of
    size ``action_shape[0]``. The final ``Tanh`` bounds every output
    component to (-1, 1).

    Args:
        obs_shape: Shape of a single observation; only ``obs_shape[0]`` (the
            channel count) is read here.
        action_shape: Shape of the action; only ``action_shape[0]`` is read.
    """

    def __init__(self, obs_shape, action_shape):
        super().__init__()
        channels = 128
        hidden = 64

        # Two observations are stacked on the channel axis, hence 2x channels.
        self.encoder = DownSample(2 * obs_shape[0], channels, use_bn=False)

        # The head is sized for a 2x2 spatial map coming out of the encoder.
        flat_dim = channels * 2 * 2
        head_layers = [
            nn.Linear(flat_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, action_shape[0]),
            nn.Tanh(),
        ]
        self.inverse_dynamics = nn.Sequential(*head_layers)

    def forward(self, obs, next_obs):
        """Return the predicted action for the transition ``obs -> next_obs``.

        Both inputs are mapped through ``v / 255 - 0.5`` before encoding
        (presumably raw pixel intensities in [0, 255] -- confirm with caller).
        The concatenated tensor must be ``torch.float`` (float32).
        """
        scaled = [frame / 255.0 - 0.5 for frame in (obs, next_obs)]
        stacked = torch.cat(scaled, dim=1)
        assert stacked.dtype == torch.float
        features = self.encoder(stacked)
        # Flatten everything except the batch axis for the MLP head.
        flat = features.reshape(len(features), -1)
        return self.inverse_dynamics(flat)

class DownSample(nn.Module):
    """Convolutional encoder: strided stem plus three residual/pool stages.

    The stem is a stride-2 3x3 convolution (optionally batch-normalised)
    followed by ReLU. Each of the three stages applies two ``ResidualBlock``
    modules and then an average pool. Given the kernel/stride/padding values
    below, and assuming the residual blocks keep the spatial size, an 84x84
    input shrinks 84 -> 42 -> 21 -> 7 -> 2.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count used by every layer after the stem.
        use_bn: Whether batch normalisation is applied in the stem and
            forwarded to the residual blocks.
    """

    def __init__(self, in_channels, out_channels, use_bn=True):
        super().__init__()
        self.use_bn = use_bn

        def make_residual_pair():
            # Two channel-preserving residual blocks form the body of a stage.
            return nn.ModuleList(
                [
                    ResidualBlock(out_channels, out_channels, use_bn=use_bn)
                    for _ in range(2)
                ]
            )

        # Stem: 3x3 convolution, stride 2, padding 1, no bias.
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=False)
        # The norm layer is always created; it is only applied when use_bn is set.
        self.bn = nn.BatchNorm2d(out_channels)

        self.resblocks1 = make_residual_pair()
        self.pooling1 = nn.AvgPool2d(3, stride=2, padding=1)

        self.resblocks2 = make_residual_pair()
        self.pooling2 = nn.AvgPool2d(5, stride=3, padding=1)

        self.resblocks3 = make_residual_pair()
        self.pooling3 = nn.AvgPool2d(4, stride=3, padding=0)

    def forward(self, x):
        """Encode ``x`` and return the pooled feature map."""
        feat = self.conv(x)
        if self.use_bn:
            feat = self.bn(feat)
        feat = torch.relu(feat)

        stages = (
            (self.resblocks1, self.pooling1),
            (self.resblocks2, self.pooling2),
            (self.resblocks3, self.pooling3),
        )
        for blocks, pool in stages:
            for block in blocks:
                feat = block(feat)
            feat = pool(feat)
        return feat
    
    
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    kernel_size, padding = 3, 1
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)


# Residual block
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions wrapped by a skip connection.

    Computes ``relu(conv2(relu(conv1(x))) + skip(x))``, with batch
    normalisation after each convolution when ``use_bn`` is set.

    When the main branch changes the tensor shape (``stride`` other than 1,
    or ``in_channels != out_channels``) the skip path uses a bias-free 1x1
    projection convolution with the same stride (followed by batch norm when
    ``use_bn`` is set) so both branches can be added. Without the projection
    such configurations fail in ``forward`` with a shape mismatch. When the
    shapes already agree the skip path is the plain identity and no extra
    parameters are created.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by the block.
        use_bn: Apply batch normalisation after each convolution.
        stride: Stride of the first convolution.
    """

    def __init__(self, in_channels, out_channels, use_bn=False, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        # Norm layers are always created; they are only applied when use_bn is set.
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.use_bn = use_bn

        # Conv2d stores its stride as a tuple, so this check treats 1, (1, 1)
        # and [1, 1] alike.
        changes_shape = self.conv1.stride != (1, 1) or in_channels != out_channels
        if changes_shape:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            ]
            if use_bn:
                projection.append(nn.BatchNorm2d(out_channels))
            self.shortcut = nn.Sequential(*projection)
        else:
            # Plain attribute (not a registered submodule): state_dict keys of
            # identity-shaped blocks are unaffected.
            self.shortcut = None

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        # getattr keeps instances unpickled from before `shortcut` existed working.
        shortcut = getattr(self, "shortcut", None)
        residual = x if shortcut is None else shortcut(x)

        out = self.conv1(x)
        if self.use_bn:
            out = self.bn1(out)
        out = nn.functional.relu(out)

        out = self.conv2(out)
        if self.use_bn:
            out = self.bn2(out)

        out += residual
        return nn.functional.relu(out)

class Transition(nn.Module):
    """Convolutional encoder followed by an MLP head.

    Four unpadded 3x3 convolutions (32 filters each, strides 2, 1, 1, 1)
    encode the observation; the flattened features pass through an MLP that
    ends in a linear layer of size ``action_shape[0]``.

    The flattened feature size ``repr_dim`` is derived from the spatial
    dimensions in ``obs_shape``, so the head matches the encoder for any
    input large enough to survive the four convolutions. For 84x84
    observations it equals ``32 * 35 * 35``.

    Args:
        obs_shape: ``(channels, height, width)`` of one observation.
        action_shape: Shape of the output; only ``action_shape[0]`` is read.

    Raises:
        AssertionError: If ``obs_shape`` does not have exactly three entries.
        ValueError: If the observation is too small for the conv stack.
    """

    # Layout of the conv stack; `_feature_dim` and `__init__` both read these
    # so the computed size cannot drift from the layers actually built.
    _NUM_FILTERS = 32
    _KERNEL_SIZE = 3
    _STRIDES = (2, 1, 1, 1)

    def __init__(self, obs_shape, action_shape):
        super().__init__()

        assert len(obs_shape) == 3
        self.repr_dim = self._feature_dim(obs_shape)

        conv_layers = []
        in_channels = obs_shape[0]
        for stride in self._STRIDES:
            conv_layers.append(
                nn.Conv2d(in_channels, self._NUM_FILTERS, self._KERNEL_SIZE, stride=stride)
            )
            conv_layers.append(nn.ReLU())
            in_channels = self._NUM_FILTERS
        self.convnet = nn.Sequential(*conv_layers)

        self.head = nn.Sequential(
            nn.Linear(self.repr_dim, 20),
            nn.ReLU(),
            nn.Linear(20, 40),
            nn.ReLU(),
            nn.Linear(40, 10),
            nn.ReLU(),
            nn.Linear(10, action_shape[0]),
        )
        # NOTE(review): `utils` is not imported in the visible part of this
        # file -- confirm it is in scope, otherwise this line raises NameError.
        self.apply(utils.weight_init)

    @classmethod
    def _feature_dim(cls, obs_shape):
        """Return the flattened size of the conv stack output for ``obs_shape``."""
        height, width = int(obs_shape[1]), int(obs_shape[2])
        for stride in cls._STRIDES:
            # Output size of an unpadded convolution along one axis.
            height = (height - cls._KERNEL_SIZE) // stride + 1
            width = (width - cls._KERNEL_SIZE) // stride + 1
            if height < 1 or width < 1:
                raise ValueError(
                    f"obs_shape {tuple(obs_shape)} is too small for the conv stack"
                )
        return cls._NUM_FILTERS * height * width

    def forward(self, obs):
        """Encode ``obs`` and return the head output.

        The input is mapped through ``v / 255 - 0.5`` first (presumably raw
        pixel intensities in [0, 255] -- confirm with caller).
        """
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        # reshape (not view) so non-contiguous activations also flatten.
        h = h.reshape(h.shape[0], -1)
        return self.head(h)