import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

class BabyAIResBlock(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=dim, out_channels=dim,
            kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(dim)
        self.conv2 = nn.Conv2d(
            in_channels=dim, out_channels=dim,
            kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        r = F.relu(self.bn1(self.conv1(x)))
        r = F.relu(self.bn2(self.conv2(r)))
        return x + r # add to the original signal

class CNNPartialObsInverseModel(nn.Module):
    """Convolutional inverse model: scores actions from two consecutive observations.

    The two observations are stacked along the channel axis, passed through
    a small CNN followed by two ``BabyAIResBlock`` modules, max-pooled over
    all spatial positions and mapped to ``num_actions`` outputs by a linear
    layer.
    """

    def __init__(self, num_actions, dim=128, end_pool=False):
        """Assemble the feature extractor and the linear head.

        Args:
            num_actions: Size of the output of the final linear layer.
            dim: Channel width of the second convolution and of the
                residual blocks.
            end_pool: When false, a 2x2 max-pool (stride 2, padding 1) is
                inserted before the residual blocks; when true it is left
                out and only the global spatial max in ``forward`` pools.
        """
        super().__init__()
        # The first convolution expects 6 channels, i.e. the channel counts
        # of the two stacked observations must sum to 6.
        layers = [
            nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        ]
        if not end_pool:
            layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1))
        layers.append(BabyAIResBlock(dim))
        layers.append(BabyAIResBlock(dim))
        self.cnn = nn.Sequential(*layers)
        self.head = nn.Linear(dim, num_actions)

    def forward(self, obs, next_obs):
        """Return the head's output for the pair ``(obs, next_obs)``.

        Args:
            obs: Observation tensor laid out as (batch, channels, H, W).
            next_obs: Following observation, same layout as ``obs``.

        Returns:
            Tensor of shape (batch, num_actions).
        """
        # Stack the two observations channel-wise and cast to float.
        stacked = torch.cat((obs, next_obs), dim=1).float()
        features = self.cnn(stacked)
        # Global spatial max: keep the largest activation of every channel.
        batch, channels, height, width = features.shape
        flat = features.view(batch, channels, height * width)
        pooled, _ = flat.max(dim=-1)
        return self.head(pooled)

class ImpalaResBlock(nn.Module):
    """Pre-activation residual block: (ReLU -> 3x3 conv) twice, plus a skip connection.

    Both convolutions keep ``filters`` channels and use stride 1 with
    padding 1, so the output shape equals the input shape.
    """

    def __init__(self, filters):
        """Build the ReLU/conv stack.

        Args:
            filters: Number of input and output channels of each convolution.
        """
        super().__init__()
        stages = []
        for _ in range(2):
            # The ReLU is deliberately not in-place: ``forward`` still needs
            # the untouched input for the skip connection.
            stages.append(nn.ReLU())
            stages.append(nn.Conv2d(filters, filters, 3, stride=1, padding=1))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` added to the output of the ReLU/conv stack."""
        residual = self.layers(x)
        return x + residual

class ImpalaConvSequence(nn.Module):
    """One convolutional stage: 3x3 conv, stride-2 max-pool, then two residual blocks."""

    def __init__(self, filters, input_size=None):
        """Create the stage's submodules.

        Args:
            filters: Number of output channels of the convolution and the
                channel width of both residual blocks.
            input_size: Number of input channels of the convolution; falls
                back to ``filters`` when ``None``.
        """
        super().__init__()
        in_channels = filters if input_size is None else input_size
        self.conv = nn.Conv2d(in_channels, filters, 3, stride=1, padding=1)
        self.resblock1 = ImpalaResBlock(filters)
        self.resblock2 = ImpalaResBlock(filters)
        self.pooling = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        """Apply conv, pooling and the two residual blocks, in that order."""
        # Pooling runs right after the convolution, so both residual blocks
        # operate on the downsampled feature map.
        stages = (self.conv, self.pooling, self.resblock1, self.resblock2)
        for stage in stages:
            x = stage(x)
        return x


class CNNFullyObsInverseModel(nn.Module):
    """Inverse model over dict observations carrying an image and an inventory.

    The images of the two observations are stacked along the channel axis
    and encoded by a stack of ``ImpalaConvSequence`` stages, then max-pooled
    over all spatial positions. The two inventories are concatenated and
    embedded by a linear layer. Both feature vectors are concatenated and
    fed to a two-layer MLP head with ``num_actions`` outputs.
    """

    def __init__(self, num_actions, arch=(32, 64, 64)):
        """Build the image encoder, the inventory embedding and the head.

        Args:
            num_actions: Size of the output of the final linear layer.
            arch: Iterable of channel widths, one per ``ImpalaConvSequence``
                stage. Must not be empty. The default is an immutable tuple
                so it cannot be shared and mutated across instances.

        Raises:
            ValueError: If ``arch`` is empty.
        """
        super().__init__()
        # Materialise once: allows any iterable and makes ``arch[-1]`` safe.
        arch = tuple(arch)
        if not arch:
            raise ValueError("arch must contain at least one filter count")

        # The first stage expects 6 channels (the two stacked images); every
        # later stage consumes the previous stage's output width.
        prev_size = 6
        impala_blocks = []
        for dim in arch:
            impala_blocks.append(ImpalaConvSequence(dim, input_size=prev_size))
            prev_size = dim
        self.cnn = nn.Sequential(*impala_blocks)
        # The two concatenated inventories must have 6 features in total.
        self.inv_emb = nn.Linear(6, 32)
        self.head = nn.Sequential(
            nn.Linear(arch[-1] + 32, arch[-1]),
            nn.ReLU(),
            nn.Linear(arch[-1], num_actions),
        )

    def forward(self, obs, next_obs):
        """Return the head's output for the pair ``(obs, next_obs)``.

        Args:
            obs: Mapping with an ``'image'`` tensor laid out as
                (batch, channels, H, W) and an ``'inventory'`` tensor laid
                out as (batch, features).
            next_obs: Following observation, same structure as ``obs``.

        Returns:
            Tensor of shape (batch, num_actions).
        """
        images = torch.cat((obs['image'], next_obs['image']), dim=1).float()
        features = self.cnn(images)
        # Global spatial max: keep the largest activation of every channel.
        b, c, h, w = features.shape
        visual = features.view(b, c, h * w).max(dim=-1)[0]
        # The inventory features
        inventory = torch.cat(
            (obs['inventory'], next_obs['inventory']), dim=1).float()
        inventory = self.inv_emb(inventory)
        return self.head(torch.cat((visual, inventory), dim=1))

