import torch

from .conv import ConvNet
from .conv_film import ConvNetFiLM


def tie_weights(src, trg):
    """Make ``trg`` share the ``weight`` and ``bias`` of ``src``.

    Both layers must be instances of exactly the same class (checked with an
    ``assert``). After the call, ``trg.weight`` and ``trg.bias`` are the very
    same objects held by ``src``, so the two layers stay tied: updating one
    updates the other. A ``None`` bias on ``src`` is copied as ``None``.
    """
    assert type(trg) == type(src)
    # Assign weight first, then bias -- goes through nn.Module.__setattr__,
    # exactly like plain attribute assignment would.
    for attr_name in ("weight", "bias"):
        setattr(trg, attr_name, getattr(src, attr_name))


class Encoder(torch.nn.Module):
    """Conv layers: image encoder built from one or two ConvNet trunks.

    Args:
        input_n_channel: Number of image channels fed to each conv trunk.
            With ``use_film`` the trunk itself is built for two extra
            channels (a coordinate map, per the original comment).
        img_sz: Forwarded to the trunk as ``img_size``.
        kernel_sz: Forwarded to the trunk as ``cnn_kernel_size``.
        stride: Forwarded to the trunk as ``cnn_stride``.
        padding: Forwarded to the trunk as ``cnn_padding``.
        n_channel: Forwarded to the trunk as ``output_n_channel``.
        dual_conv: If True, build a second, identically configured trunk
            ``conv_1``. In ``forward``, the first ``input_n_channel`` input
            channels go to ``conv`` and the remaining ones to ``conv_1``.
            The two outputs are concatenated along the last dimension.
        use_sm, use_spec, use_bn, use_residual: Forwarded to the trunk
            unchanged.
        use_film: Use ``ConvNetFiLM`` instead of ``ConvNet``. The FiLM trunk
            receives the ``lang`` argument of ``forward``.
        device: Device each trunk is moved to after construction.
        verbose: Print a header line, and forwarded to each trunk.
    """

    def __init__(
        self,
        input_n_channel,
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        dual_conv=False,
        use_sm=True,
        use_spec=False,
        use_bn=False,
        use_residual=False,
        use_film=False,
        device='cpu',
        verbose=True,
    ):
        super().__init__()
        self.use_film = use_film
        self.dual_conv = dual_conv
        # Channel count of the raw image per trunk; used to split the input
        # in dual mode. Stored before the FiLM adjustment below.
        self.input_n_channel = input_n_channel
        if use_film:
            input_n_channel += 2    # coordinate map
            # NOTE(review): forward() does not add these two channels here,
            # so presumably ConvNetFiLM appends them itself -- TODO confirm.
        if verbose:
            print(
                "The neural network for encoder has the architecture as below:"
            )
        conv_type = ConvNetFiLM if use_film else ConvNet
        # Both trunks share one configuration.
        conv_kwargs = dict(
            input_n_channel=input_n_channel,
            cnn_kernel_size=kernel_sz,
            cnn_stride=stride,
            cnn_padding=padding,
            output_n_channel=n_channel,
            img_size=img_sz,
            use_sm=use_sm,
            use_spec=use_spec,
            use_bn=use_bn,
            use_residual=use_residual,
            verbose=verbose,
        )
        self.conv = conv_type(**conv_kwargs).to(device)
        if dual_conv:
            self.conv_1 = conv_type(**conv_kwargs).to(device)

    def _apply_conv(self, conv, image, lang):
        """Run one trunk; the language input is passed only to FiLM trunks."""
        if self.use_film:
            return conv(image, lang=lang)
        return conv(image)

    def forward(self, image, detach=False, lang=None):
        """Encode ``image``.

        Args:
            image: A batched 4-D tensor with channels at dim 1, or a single
                3-D image, which is given a leading batch dim of size 1.
            detach: If True, detach the result from the autograd graph.
            lang: Conditioning input passed to FiLM trunks (ignored otherwise).

        Returns:
            The trunk output. In dual mode, ``conv``'s output and ``conv_1``'s
            output are concatenated along the last dimension.
        """
        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        if self.dual_conv:
            # Channels past the first input_n_channel belong to the second trunk.
            # NOTE(review): conv_1 was built for the same channel count, so the
            # input presumably carries 2 * input_n_channel channels -- confirm.
            image_1 = image[:, self.input_n_channel:, :, :]
            image = image[:, :self.input_n_channel, :, :]

        out = self._apply_conv(self.conv, image, lang)

        if self.dual_conv:
            # The second half must go through its own trunk conv_1. It is built
            # in __init__ and tied in copy_conv_weights_from, and previously it
            # was never used.
            out_1 = self._apply_conv(self.conv_1, image_1, lang)
            out = torch.cat((out, out_1), dim=-1)

        if detach:
            out = out.detach()
        return out

    @staticmethod
    def _tie_conv_net(src_net, trg_net, tie_film):
        """Tie every Conv2d (and optionally the FiLM generator) of two trunks."""
        for source_module, module in zip(src_net.moduleList, trg_net.moduleList):
            # children() works for both Sequential and nn.Module
            for source_layer, layer in zip(source_module.children(),
                                           module.children()):
                if isinstance(layer, torch.nn.Conv2d):
                    tie_weights(src=source_layer, trg=layer)
        if tie_film:
            tie_weights(src=src_net.film_generator, trg=trg_net.film_generator)

    def copy_conv_weights_from(self, source):
        """
        Tie convolutional layers between two same encoders.

        ``source`` must be an Encoder with the same configuration. Its Conv2d
        parameters (and, with FiLM, its ``film_generator`` parameters) are
        shared into this encoder, for ``conv`` and, in dual mode, ``conv_1``.
        """
        self._tie_conv_net(source.conv, self.conv, self.use_film)
        if self.dual_conv:
            self._tie_conv_net(source.conv_1, self.conv_1, self.use_film)

    def get_output_dim(self):
        """Output size of the trunk, doubled in dual mode (two concatenated outputs)."""
        factor = 2 if self.dual_conv else 1
        return self.conv.get_output_dim() * factor
