import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict

class RNA(nn.Module):
    """Wrap a frozen backbone whose BatchNorm outputs are modulated by FiLM.

    ``configure_model`` freezes ``model``, inserts a ``FiLMBlock`` after each
    ``nn.BatchNorm2d`` that immediately precedes an ``nn.ReLU`` in
    ``named_modules()`` order, and builds a ``FiLMGenerator`` with one output
    head per inserted block. Only the generator's parameters are exposed as
    trainable via ``get_trainable_params``.
    """

    def __init__(self, model, dim_coarse_label=1000):
        super().__init__()
        configured = configure_model(model, dim_coarse_label)
        self.model, self.film_blocks, self.film_generator = configured

    def forward(self, x, coarse_labels):
        """Load per-sample FiLM parameters, then run the wrapped model on ``x``.

        Args:
            x: model input; its first dimension is used as the batch size.
            coarse_labels: input to ``self.film_generator`` (presumably of
                width ``dim_coarse_label`` — confirm against caller).

        Returns:
            Whatever ``self.model(x)`` returns.
        """
        n = x.size(0)
        per_block_vectors = self.film_generator(coarse_labels)
        for block, raw in zip(self.film_blocks, per_block_vectors):
            # raw holds 2*C values per sample: the first C are gamma, the
            # remaining C are beta.
            gamma, beta = raw.view(n, 2, -1).unbind(dim=1)
            block.set_params(gamma, beta)
        return self.model(x)

    def get_trainable_params(self):
        """Return the FiLM generator's parameters (the backbone stays frozen)."""
        return self.film_generator.parameters()


class FiLMBlock(nn.Module):
    """Feature-wise affine modulation ``gamma * x + beta``.

    The block owns no learnable parameters. ``gamma`` and ``beta`` are provided
    by a caller through ``set_params`` before each forward pass and are applied
    per sample and per channel.
    """

    def __init__(self, param_size):
        super().__init__()
        # Number of channels modulated; used to size the generator head that
        # produces this block's gamma/beta.
        self.param_size = param_size
        self.set_params()

    def set_params(self, gamma=None, beta=None):
        """Store the scale ``gamma`` and shift ``beta`` used by ``forward``.

        Each is expected to hold one value per (sample, channel) of the input,
        i.e. be reshapeable to ``(x.shape[0], x.shape[1])``.
        """
        self.gamma = gamma
        self.beta = beta

    def forward(self, x):
        """Return ``gamma * x + beta``, broadcast over all dims after the channel dim.

        Args:
            x: tensor of shape ``(N, C, *)``.

        Raises:
            RuntimeError: if ``set_params`` has not been given both tensors.
        """
        if self.gamma is None or self.beta is None:
            raise RuntimeError(
                "FiLMBlock.forward called before set_params(gamma, beta)"
            )
        # (N, C) -> (N, C, 1, ..., 1) so the params broadcast over every
        # trailing (e.g. spatial) dimension of x, whatever its rank.
        shape = (x.shape[0], x.shape[1]) + (1,) * (x.dim() - 2)
        # Reshape into locals: forward must not overwrite the stored params.
        gamma = self.gamma.reshape(shape)
        beta = self.beta.reshape(shape)
        return gamma * x + beta


class FiLMGenerator(nn.Module):
    """Map a coarse-label vector to one FiLM parameter vector per block.

    A shared MLP encoder (``dim_coarse_label -> 128 -> 64``) feeds one decoder
    head per entry of ``channels``. Head ``i`` emits ``2 * channels[i]``
    values, which the caller splits into gamma and beta.
    """

    def __init__(self, channels, dim_coarse_label=1000):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(dim_coarse_label, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        heads = []
        for width in channels:
            heads.append(
                nn.Sequential(
                    nn.Linear(64, 64),
                    nn.ReLU(),
                    nn.Linear(64, 2 * width),
                )
            )
        self.decoders = nn.ModuleList(heads)

    def forward(self, x):
        """Return a list with one ``(..., 2 * channels[i])`` tensor per head."""
        shared = self.encoder(x)
        return [head(shared) for head in self.decoders]


def configure_model(model, dim_coarse_label=1000):
    """Prepare ``model`` for FiLM modulation and build the matching generator.

    Puts ``model`` in train mode and freezes all of its parameters. Every
    ``nn.BatchNorm2d`` that immediately precedes an ``nn.ReLU`` in
    ``model.named_modules()`` order is replaced (in place) by the wrapper
    returned from ``replace_BatchNorm2d``.

    Args:
        model: the ``nn.Module`` to adapt; it is modified in place.
        dim_coarse_label: input width of the returned ``FiLMGenerator``.

    Returns:
        ``(model, film_blocks, film_generator)``:
        ``film_blocks`` lists every ``FiLMBlock`` in ``model`` in
        ``modules()`` order, and ``film_generator`` has one head per block,
        sized by that block's ``param_size``.
    """
    model.train()
    model.requires_grad_(False)

    # Snapshot the traversal because the loop replaces submodules while
    # walking. Start with no predecessor so a ReLU yielded first (i.e. when
    # ``model`` itself is an nn.ReLU) does not reference an unbound name.
    prev_module, prev_name = None, None
    for name, module in tuple(model.named_modules()):
        if isinstance(module, nn.ReLU) and isinstance(prev_module, nn.BatchNorm2d):
            recursive_setattr(model, prev_name, replace_BatchNorm2d(prev_module))
        prev_module, prev_name = module, name

    film_blocks = [m for m in model.modules() if isinstance(m, FiLMBlock)]
    channels = [block.param_size for block in film_blocks]

    film_generator = FiLMGenerator(channels, dim_coarse_label)
    return model, film_blocks, film_generator


def replace_BatchNorm2d(module):
    """Wrap a ``BatchNorm2d`` so that its output passes through a ``FiLMBlock``.

    For an ``nn.BatchNorm2d`` the result is an ``nn.Sequential`` whose children
    are ``inner_bn`` (the same module object, not a copy) followed by ``FiLM``
    (a ``FiLMBlock`` sized by ``num_features``). Any other module is returned
    unchanged.
    """
    if not isinstance(module, nn.BatchNorm2d):
        return module
    layers = OrderedDict()
    layers['inner_bn'] = module
    layers['FiLM'] = FiLMBlock(module.num_features)
    return nn.Sequential(layers)


def recursive_setattr(obj, attr, value):
    """Assign ``value`` at a dotted attribute path on ``obj``.

    For example, ``'layer1.0.bn1'`` resolves ``obj.layer1`` and then its
    ``'0'`` attribute via ``getattr``, and sets ``bn1`` on the object reached.
    A path without dots is a plain ``setattr`` on ``obj``.
    """
    *path, leaf = attr.split('.')
    target = obj
    for part in path:
        target = getattr(target, part)
    setattr(target, leaf, value)
