import torch
from torch import nn
from torch.nn.utils import spectral_norm
import numpy as np
from collections import OrderedDict

from .conv import conv2d_size_out


class ConvNetFiLM(nn.Module):
    """Convolutional encoder with FiLM conditioning on a language embedding.

    The network is a stack of ``nn.Conv2d`` layers. After every convolution
    the feature map is modulated per channel as ``gamma * x + beta`` and then
    passed through a ReLU; ``gamma`` and ``beta`` for all layers come from a
    single linear layer applied to the language embedding. The final feature
    map is flattened to ``(N, cnn_output_dim)``.

    Args:
        input_n_channel: ``in_channels`` of the first convolution.
            NOTE(review): ``forward`` concatenates two coordinate channels
            to the image before the first convolution, so this value has to
            equal (image channels + 2) for ``forward`` to run -- confirm
            callers account for that.
        append_dim: stored on the instance; not otherwise used by this class.
        cnn_kernel_size: kernel size of each conv layer.
        cnn_stride: stride of each conv layer (at least one per layer).
        cnn_padding: padding of each conv layer, or ``None`` for no padding.
        output_n_channel: number of output channels of each conv layer.
        img_size: input image size, either a scalar (square image) or a
            ``(height, width)`` pair. Only used to compute the output size.
        verbose: if true, print the conv module list after construction.
        use_sm: accepted for interface compatibility; not read by this class.
        use_bn: accepted for interface compatibility; not read by this class.
        use_spec: wrap every convolution in spectral normalization.
        use_residual: accepted for interface compatibility; not read by this
            class.
        lang_dim: size of the language embedding fed to the FiLM generator.
    """

    def __init__(
        self,
        input_n_channel=1,  # not counting z_conv
        append_dim=0,  # not counting z_mlp
        cnn_kernel_size=[5, 3],
        cnn_stride=[2, 1],
        cnn_padding=None,
        output_n_channel=[16, 32],
        img_size=128,
        verbose=True,
        use_sm=True,
        use_bn=True,
        use_spec=False,
        use_residual=False,
        #
        lang_dim=768,
    ):

        super(ConvNetFiLM, self).__init__()
        # Copy so that the shared (mutable) default list, or a list owned by
        # the caller, is never aliased by the instance.
        self.output_n_channel = list(output_n_channel)

        self.append_dim = append_dim
        assert len(cnn_kernel_size) == len(output_n_channel), (
            "The length of the kernel_size list does not match with the " +
            "#channel list!")
        self.n_conv_layers = len(cnn_kernel_size)
        # zip() below would silently build fewer layers if the stride list
        # were too short, leaving cnn_output_dim and the FiLM generator
        # inconsistent with the actual layers.
        assert len(cnn_stride) >= self.n_conv_layers, (
            "The stride list is shorter than the kernel_size list!")

        if np.isscalar(img_size):
            height = img_size
            width = img_size
        else:
            height, width = img_size

        # ModuleList holding one nn.Sequential (containing the conv) per
        # layer; FiLM and ReLU are applied between entries in forward().
        self.moduleList = nn.ModuleList()

        #= CNN: W' = (W - kernel_size + 2*padding) / stride + 1
        for i, (kernel_size, stride, out_channels) in enumerate(
                zip(cnn_kernel_size, cnn_stride, self.output_n_channel)):

            # Add conv
            padding = 0
            if cnn_padding is not None:
                padding = cnn_padding[i]
            if i == 0:
                in_channels = input_n_channel
            else:
                in_channels = self.output_n_channel[i - 1]
            module = nn.Sequential()
            conv_layer = nn.Conv2d(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding)
            if use_spec:
                conv_layer = spectral_norm(conv_layer)
            module.add_module("conv_1", conv_layer)

            # No activation inside the module: the ReLU has to come after
            # the FiLM modulation, so it is applied in forward().

            # Add module
            self.moduleList.append(module)

            # Update height and width of images after modules
            height, width = conv2d_size_out([height, width], kernel_size,
                                            stride, padding)

        #= Flatten
        self.flatten_module = nn.Sequential(OrderedDict([('flatten',
                                                          nn.Flatten())]))
        cnn_output_dim = int(self.output_n_channel[-1] * height * width)
        self.cnn_output_dim = cnn_output_dim

        if verbose:
            print(self.moduleList)

        # FiLM: one (beta, gamma) pair per output channel of every layer.
        self.film_generator = nn.Linear(lang_dim,
                                        2 * sum(self.output_n_channel))
        self.relu = nn.ReLU(inplace=True)

    def get_output_dim(self):
        """Return the flattened feature size computed from ``img_size``."""
        return self.cnn_output_dim

    def forward(self, x, lang):
        """Encode an image batch conditioned on a language embedding.

        Args:
            x: image batch of shape ``(N, C, H, W)``, or ``(N, H, W)`` which
                is treated as a single-channel batch.
            lang: language embedding; the FiLM generator output is reshaped
                to ``(N, 2, -1)``, so one embedding per image is expected.

        Returns:
            Flattened features of shape ``(N, -1)``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # Nx1xHxW
        N, _, H, W = x.shape

        # Coordinate map - 2 channels in [-1, 1], created on the same device
        # as the input (no hard-coded .cuda()) and with separate axes so
        # that non-square images work as well.
        if x.is_floating_point():
            coord_dtype = x.dtype
        else:
            coord_dtype = torch.get_default_dtype()
        coordinate_x = torch.linspace(
            -1, 1, W, dtype=coord_dtype, device=x.device,
        ).view(1, 1, 1, W).expand(N, 1, H, W)
        coordinate_y = torch.linspace(
            -1, 1, H, dtype=coord_dtype, device=x.device,
        ).view(1, 1, H, 1).expand(N, 1, H, W)
        x = torch.cat([x, coordinate_x, coordinate_y], 1)

        # Row 0 holds the betas and row 1 the gammas of all layers, laid out
        # layer after layer along the last axis.
        film_vector = self.film_generator(lang).view(N, 2, -1)

        channel = 0
        for ind, module in enumerate(self.moduleList):
            x = module(x)

            # FiLM: slice out this layer's parameters and broadcast them
            # over the spatial dimensions.
            n_channel = self.output_n_channel[ind]
            beta = film_vector[:, 0, channel:(channel + n_channel)]
            gamma = film_vector[:, 1, channel:(channel + n_channel)]
            beta = beta.view(N, x.size(1), 1, 1)
            gamma = gamma.view(N, x.size(1), 1, 1)
            channel += n_channel
            x = gamma * x + beta

            x = self.relu(x)
        x = self.flatten_module(x)
        return x
