from re import L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable

from utils import convtrans2D_output_size


class HVAE(nn.Module):
    """Hierarchical VAE built on top of a pretrained VGG19-BN feature extractor.

    The VGG ``features`` stack is cut at ``encoder_split_inds`` into consecutive
    blocks. After every block (including the last one) a small MLP maps the
    pooled feature map to a Gaussian latent (``mu``, ``logvar``). The decoder is
    a fixed stack of four transposed-conv layers whose module list is cut into
    the same number of blocks; each latent is projected, resized and *added* to
    the input of its decoder block.

    Args:
        encoder_split_inds: indices into the VGG ``features`` module list where an
            intermediate latent is inserted. The latent at the end of the feature
            extractor is always present and must not be listed. At most 3 entries.
        encoder_filters_at_split: number of channels output by the VGG layers
            before each split index (one entry per split). Not inferred, so it
            MUST match the VGG architecture.
        out_size: spatial size of the reconstruction, an int or (H, W).
        fc_hidden1, fc_hidden2: hidden widths of the encoder/decoder MLPs.
        CNN_embed_dim: dimensionality of every latent vector.
        dropout_rate: dropout probability used in the MLPs.
        train_vgg: if False, the VGG weights are frozen (requires_grad=False).
    """

    # Cut points into the flat decoder module list (12 modules = 4 layers x
    # [ConvTranspose2d, BatchNorm2d, activation]) and indices into the list of
    # decoder layer-input shapes built by ``_build_decoder``, keyed by the number
    # of intermediate encoder splits.
    # NOTE(review): with 2 splits the cuts (4 and 8) fall inside a layer rather
    # than between layers; left unchanged so existing checkpoints still load.
    _DECODER_SPLITS = {
        0: ((0, 12), (0,)),
        1: ((0, 6, 12), (0, 2)),
        2: ((0, 4, 8, 12), (0, 2, 3)),
        3: ((0, 3, 6, 9, 12), (0, 1, 2, 3, 4)),
    }

    def __init__(self,  encoder_split_inds=(26,),
                        encoder_filters_at_split=(256,),
                        out_size=(112, 112),
                        fc_hidden1=256,
                        fc_hidden2=128,
                        CNN_embed_dim=64,
                        dropout_rate=0.1,
                        train_vgg=False):
        super(HVAE, self).__init__()
        # Copy into fresh lists: no shared mutable defaults, and tuples are accepted.
        encoder_split_inds = list(encoder_split_inds)
        encoder_filters_at_split = list(encoder_filters_at_split)
        # Up to 3 splits are supported; the matching decoder splits are hard-coded.
        assert len(encoder_split_inds) <= 3, "Upto 3 splits allowed"
        if len(encoder_filters_at_split) < len(encoder_split_inds):
            raise ValueError(
                "encoder_filters_at_split needs one channel count per split index, got "
                f"{len(encoder_filters_at_split)} for {len(encoder_split_inds)} splits"
            )
        if isinstance(out_size, (tuple, list)):
            self.out_size = tuple(out_size)
        else:
            self.out_size = (out_size, out_size)
        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim

        # Encoder: pretrained VGG19-BN feature extractor. The model is not forced
        # onto the GPU here; move the whole HVAE with .to(device) / .cuda().
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=True`` in favour
        # of ``weights=``; kept for compatibility with older torchvision versions.
        vgg = models.vgg19_bn(pretrained=True)
        if not train_vgg:
            for param in vgg.parameters():
                param.requires_grad = False
        feature_modules = vgg.features

        # Replace in-place ReLUs with out-of-place ones (presumably so a feature map
        # handed to a latent branch is not overwritten in place by the next block).
        for idx, module in enumerate(feature_modules):
            if isinstance(module, nn.ReLU):
                feature_modules[idx] = nn.ReLU()

        # Decoder: 4 transposed-conv layers -> 3 * 4 = 12 modules.
        self.filters = [128, 64, 32, 16, 3]
        self.kernels = [5, 5, 5, 5]
        self.strides = [2, 2, 2, 2]
        decoder_modules, decoder_shapes = self._build_decoder()

        decoder_cuts, shape_ids = self._DECODER_SPLITS[len(encoder_split_inds)]
        # (C, H, W) expected at the input of each decoder block.
        self.decoder_shape_at_split = [decoder_shapes[k] for k in shape_ids]
        # Encoder cut points; a slice end past the last module is harmless.
        encoder_cuts = [0] + encoder_split_inds + [len(feature_modules) + 1]
        # The final VGG19 feature map has 512 channels.
        encoder_filters = encoder_filters_at_split + [512]

        # Module lists (one entry per latent level).
        self.feature_blocks = nn.ModuleList()
        self.enc_to_latent = nn.ModuleList()
        self.latent_to_dec = nn.ModuleList()
        self.mu_layers = nn.ModuleList()
        self.logvar_layers = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()

        for i, (enc_start, enc_end) in enumerate(zip(encoder_cuts, encoder_cuts[1:])):
            self.feature_blocks.append(nn.Sequential(*feature_modules[enc_start:enc_end]))
            dec_start, dec_end = decoder_cuts[i], decoder_cuts[i + 1]
            self.decoder_blocks.append(nn.Sequential(*decoder_modules[dec_start:dec_end]))

            # Encoder features -> hidden -> (mu, logvar); latent -> decoder seed.
            enc_to_latent = self._make_enc_to_latent(encoder_filters[i], dropout_rate)
            mu_layer = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)
            logvar_layer = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)
            latent_to_dec = self._make_latent_to_dec(self.decoder_shape_at_split[i][0], dropout_rate)

            self.enc_to_latent.append(enc_to_latent)
            self.mu_layers.append(mu_layer)
            self.logvar_layers.append(logvar_layer)
            self.latent_to_dec.append(latent_to_dec)

    def _build_decoder(self):
        """Build the flat decoder module list and the (C, H, W) shape at each layer input.

        Returns:
            (modules, shapes): ``modules`` holds ``[ConvTranspose2d, BatchNorm2d, act]``
            per layer (ReLU, or Sigmoid for the last layer); ``shapes[k]`` is the input
            shape of layer ``k`` (starting from a 4x4 seed) and ``shapes[-1]`` is the
            output shape of the last layer.
        """
        modules = []
        shapes = [(self.filters[0], 4, 4)]
        n_layers = len(self.filters) - 1
        layer_specs = zip(self.filters[:-1], self.filters[1:], self.kernels, self.strides)
        for i, (c_in, c_out, kernel, stride) in enumerate(layer_specs):
            modules.append(nn.ConvTranspose2d(in_channels=c_in,
                                              out_channels=c_out,
                                              kernel_size=kernel,
                                              stride=stride))
            modules.append(nn.BatchNorm2d(c_out, momentum=0.01))
            modules.append(nn.ReLU() if i < n_layers - 1 else nn.Sigmoid())
            height_width = convtrans2D_output_size(shapes[-1][1:],
                                                   kernel_size=kernel,
                                                   stride=stride)
            shapes.append((c_out, *height_width))
        return modules, shapes

    def _make_enc_to_latent(self, in_channels, dropout_rate):
        """MLP mapping a feature map (pooled to 2x2) to ``fc_hidden2`` features."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((2, 2)),
            nn.Flatten(),
            nn.Linear(in_channels * 2 * 2, self.fc_hidden1),
            nn.BatchNorm1d(self.fc_hidden1, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.fc_hidden1, self.fc_hidden2),
            nn.BatchNorm1d(self.fc_hidden2, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

    def _make_latent_to_dec(self, out_channels, dropout_rate):
        """MLP mapping a latent vector to a flattened (out_channels, 4, 4) decoder seed."""
        return nn.Sequential(
            nn.Linear(self.CNN_embed_dim, self.fc_hidden2),
            nn.BatchNorm1d(self.fc_hidden2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.fc_hidden2, out_channels * 4 * 4),
            nn.BatchNorm1d(out_channels * 4 * 4),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
        )

    def encode(self, x):
        """Run the feature blocks and return the posterior parameters of every level.

        Returns:
            (mus, logvars): lists with one (batch, CNN_embed_dim) tensor per level,
            ordered from the first split to the end of the feature extractor.
        """
        mus, logvars = [], []
        levels = zip(self.feature_blocks, self.enc_to_latent, self.mu_layers, self.logvar_layers)
        for features, to_latent, mu_layer, logvar_layer in levels:
            # Each feature block continues from the previous block's output.
            x = features(x)
            hidden = to_latent(x)
            mus.append(mu_layer(hidden))
            logvars.append(logvar_layer(hidden))
        return mus, logvars

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode a list of latents into an image batch.

        Args:
            z: list with (at least) one latent per decoder block, in the order of
                ``self.decoder_blocks`` (i.e. the order returned by ``encode``).
        Returns:
            The output of the last decoder block (which ends in a Sigmoid),
            bilinearly resized to ``self.out_size``.
        """
        x = None
        for i, block in enumerate(self.decoder_blocks):
            channels, height, width = self.decoder_shape_at_split[i]
            h = self.latent_to_dec[i](z[i]).view(-1, channels, 4, 4)
            h = F.interpolate(h, size=(height, width), mode="bilinear")
            # The first block has no upstream activation: start from the projected
            # latent itself instead of a device-specific zeros tensor.
            x = block(h if x is None else x + h)

        # Increase resolution to the requested output size.
        return F.interpolate(x, size=self.out_size, mode="bilinear")

    def forward(self, x):
        """Encode ``x``, sample one latent per level and decode.

        Returns:
            (reconstruction, z, mus, logvars), each list having one entry per level.
        """
        mus, logvars = self.encode(x)
        z = [self.reparameterize(mu, logvar) for mu, logvar in zip(mus, logvars)]
        out = self.decode(z)
        return out, z, mus, logvars


class DecodeModel(nn.Module):
    """Map fMRI voxel responses to the latents of a frozen, pretrained HVAE and decode them.

    One MLP per ROI predicts one HVAE latent; the list of predicted latents is fed
    to ``HVAE.decode``.

    Args:
        hvae_params: keyword arguments forwarded to ``HVAE``.
        hvae_weights: anything ``torch.load`` accepts (e.g. a file path) holding
            the HVAE state dict.
        num_voxels_per_roi: dict keyed by ROI name (str) or tuple of ROI names,
            giving the number of voxels of each ROI.
        ndec_params: dict with keys ``"fc_hidden1"``, ``"fc_hidden2"`` and
            ``"dropout_rate"`` configuring each ROI's MLP.
    """

    def __init__(self, hvae_params,
                       hvae_weights,
                       num_voxels_per_roi,  # dict keyed by roi name
                       ndec_params) -> None:
        super().__init__()

        self.hvae = HVAE(**hvae_params)

        # Load the pretrained model. map_location="cpu" lets GPU-saved checkpoints
        # load on CPU-only hosts; load_state_dict copies into the params' device.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        self.hvae.load_state_dict(torch.load(hvae_weights, map_location="cpu"))

        # Freeze the HVAE and keep it in eval mode (see ``train`` below).
        for param in self.hvae.parameters():
            param.requires_grad = False
        self.hvae.eval()

        # The neural decoding model: maps each ROI's voxel space to one HVAE latent.
        self.neural_decoder = nn.ModuleDict()
        for roi, n_voxels in num_voxels_per_roi.items():
            self.neural_decoder[self._roi_key(roi)] = self._make_roi_decoder(
                n_voxels, ndec_params, self.hvae.CNN_embed_dim)

    @staticmethod
    def _roi_key(roi):
        """ModuleDict key for an ROI: the name itself, or the concatenated names of an ROI tuple."""
        return ''.join(roi)

    @staticmethod
    def _make_roi_decoder(n_voxels, ndec_params, embed_dim):
        """MLP mapping one ROI's voxels to a latent vector of size ``embed_dim``."""
        hidden1, hidden2 = ndec_params["fc_hidden1"], ndec_params["fc_hidden2"]
        dropout = ndec_params["dropout_rate"]
        return nn.Sequential(
            nn.Linear(in_features=n_voxels, out_features=hidden1),
            nn.BatchNorm1d(hidden1, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden1, out_features=hidden2),
            nn.BatchNorm1d(hidden2, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden2, out_features=embed_dim),
        )

    def train(self, mode=True):
        """Set train/eval mode for the ROI decoders while keeping the frozen HVAE in eval.

        ``nn.Module.train`` recurses into every child, so without this override
        ``model.train()`` would put the HVAE's BatchNorm layers back into training
        mode (updating their running statistics despite ``requires_grad=False``)
        and re-enable its dropout.
        """
        super().train(mode)
        self.hvae.eval()
        return self

    def forward(self, x):
        """Decode a batch of fMRI data.

        Args:
            x: ordered dict mapping each ROI key (as in ``num_voxels_per_roi``, a name
                or a tuple of names) to a (batch, num_voxels) tensor. Iteration order
                defines the latent order, which must match ``self.hvae.decoder_blocks``.
        Returns:
            (img, z): the decoded image batch and the list of predicted latents.
        """
        z = [self.neural_decoder[self._roi_key(roi)](voxels) for roi, voxels in x.items()]
        img = self.hvae.decode(z)
        return img, z


