import torch
from torch import nn
from .layers import ConvEncoderBlock, ConvDecoderBlock
from .ArchitectureSampler import SampleNetworkArchitecture


class AdaptiveEncoder(nn.Module):
    """Convolutional encoder whose depth is selected by a sampled architecture.

    A fixed stack of ``ConvEncoderBlock`` layers is created up front. On every
    forward pass ``SampleNetworkArchitecture`` supplies a mask matrix and the
    number of layers to run, and only that prefix of the stack is applied.
    """

    def __init__(self, in_channels, num_channels=400, a_prior=2.0, b_prior=2.0, num_samples=5, truncation=40, device=torch.device("cuda:0")):
        """
        Parameters
        ----------
        in_channels : number of channels fed to the first block
        num_channels : number of channels produced by every block
        a_prior, b_prior : forwarded unchanged to the architecture sampler
        num_samples : forwarded to the architecture sampler and stored
        truncation : size of the layer stack (the first, pooling block is
            always created, the remaining ``truncation - 1`` follow it)
        device : device handed to the sampler and that every block is moved to
        """
        super().__init__()
        self.mode = "NN"
        self.in_channels = in_channels
        self.num_channels = num_channels
        self.truncation = truncation
        self.num_samples = num_samples
        self.device = device

        # instance of a stick breaking process
        self.architecture_sampler = SampleNetworkArchitecture(num_neurons=num_channels,
                                                              a_prior=a_prior,
                                                              b_prior=b_prior,
                                                              num_samples=num_samples,
                                                              truncation=truncation,
                                                              device=self.device)

        # Stack of conv layers up to the given truncation level.
        # Each block receives the mask itself; masking / residual handling is
        # delegated to ConvEncoderBlock.

        # The first block is the only one created with pool=True (downsampling).
        first_block = ConvEncoderBlock(self.in_channels, self.num_channels, kernel_size=5, pool=True)
        self.layers = nn.ModuleList([first_block.to(self.device)])

        # The remaining blocks map num_channels -> num_channels with a padded
        # 3x3 kernel and residual=True.
        for _ in range(1, self.truncation):
            block = ConvEncoderBlock(self.num_channels, self.num_channels, kernel_size=3, padding=1, residual=True)
            self.layers.append(block.to(self.device))

    def _forward(self, x, mask_matrix, threshold):
        """Apply the first ``threshold`` blocks to ``x`` for one sampled architecture.

        Parameters
        ----------
        x : input passed to the first block
        mask_matrix : masks of a single sample; column ``layer_idx`` along the
            second axis is handed to block ``layer_idx``
        threshold : number of blocks to run

        Returns
        -------
        Output of the last block that was applied (``x`` itself if none ran).
        """
        # The stack has a fixed length, so a threshold beyond it can never be
        # honoured. Clamp in both train and eval mode instead of letting
        # training fail with an IndexError.
        depth = len(self.layers)
        if threshold > depth:
            threshold = depth

        for layer_idx in range(threshold):
            x = self.layers[layer_idx](x, mask_matrix[:, layer_idx])

        return x

    def forward(self, x, num_samples=5):
        """
        Encodes the data with several sampled architectures

        Parameters
        ----------
        x : data; the same input is used for every sampled architecture
        num_samples : Number of architectures to sample. This is the value
            passed to the sampler for this call; the ``num_samples`` given to
            the constructor is not consulted here.

        Returns
        -------
        act_vec : Tensor
            outputs of the sampled architectures stacked along a new leading
            dimension of size ``num_samples``
        """

        # sample architecture from beta-bernoulli process
        mask_matrix, _, n_layers, _ = self.architecture_sampler(num_samples)

        outputs = [self._forward(x, mask_matrix[i], n_layers) for i in range(num_samples)]
        return torch.stack(outputs, dim=0)

    def get_kl(self):
        """Return the KL term reported by the architecture sampler."""
        return self.architecture_sampler.get_kl()

    def get_mask(self, num_samples=5):
        """Draw ``num_samples`` architectures and return ``(mask_matrix, n_layers)``."""
        mask_matrix, _, n_layers, _ = self.architecture_sampler(num_samples)
        return mask_matrix, n_layers


class AdaptiveDecoder(nn.Module):
    """Convolutional decoder whose depth is selected by a sampled architecture.

    A fixed stack of ``ConvDecoderBlock`` layers is created up front. On every
    forward pass ``SampleNetworkArchitecture`` supplies a mask matrix and the
    number of layers to run; the selected prefix of the stack is applied and
    the result is passed through a fixed output stage.
    """

    def __init__(self, out_channels, num_channels=400, a_prior=2.0, b_prior=2.0, num_samples=5, truncation=40, device=torch.device("cuda:0")):
        """
        Parameters
        ----------
        out_channels : number of channels produced by the output stage
        num_channels : number of channels consumed and produced by every block
        a_prior, b_prior : forwarded unchanged to the architecture sampler
        num_samples : forwarded to the architecture sampler
        truncation : number of blocks in the layer stack
        device : device handed to the sampler and that every sub-module is
            moved to
        """
        super().__init__()
        self.out_channels = out_channels
        self.num_channels = num_channels
        self.device = device

        # instance of a stick breaking process
        self.architecture_sampler = SampleNetworkArchitecture(num_neurons=num_channels,
                                                              a_prior=a_prior,
                                                              b_prior=b_prior,
                                                              num_samples=num_samples,
                                                              truncation=truncation,
                                                              device=self.device)

        # stack of decoder blocks, all num_channels -> num_channels with residual=True
        self.layers = nn.ModuleList([])
        for _ in range(truncation):
            block = ConvDecoderBlock(self.num_channels, self.num_channels, residual=True)
            self.layers.append(block.to(self.device))

        # Output stage: 2x nearest-neighbour upsampling followed by a transposed
        # convolution down to out_channels. It is moved to the same device as
        # the blocks above so the decoder is usable without an extra .to() call.
        self.output_layer = nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2),
                                          nn.ConvTranspose2d(num_channels, out_channels, kernel_size=3, stride=1)).to(self.device)

    def _forward(self, x, mask_matrix, threshold):
        """Apply the first ``threshold`` blocks to ``x`` for one sampled architecture.

        Parameters
        ----------
        x : input passed to the first block
        mask_matrix : masks of a single sample; column ``layer_idx`` along the
            second axis is handed to block ``layer_idx``
        threshold : number of blocks to run

        Returns
        -------
        Output of the last block that was applied (``x`` itself if none ran).
        The output stage is NOT applied here.
        """
        # The stack has a fixed length, so a threshold beyond it can never be
        # honoured. Clamp in both train and eval mode instead of letting
        # training fail with an IndexError.
        depth = len(self.layers)
        if threshold > depth:
            threshold = depth

        for layer_idx in range(threshold):
            x = self.layers[layer_idx](x, mask_matrix[:, layer_idx])

        return x

    def forward(self, x, num_samples=5):
        """
        Decodes the data with several sampled architectures

        Parameters
        ----------
        x : data indexed per sample; ``x[i]`` is decoded with the i-th sampled
            architecture, so its leading dimension must cover ``num_samples``
        num_samples : Number of architectures to sample. This is the value
            passed to the sampler for this call.

        Returns
        -------
        act_vec : Tensor
            outputs of the sampled architectures (after the output stage)
            stacked along a new leading dimension of size ``num_samples``
        """

        # sample architecture from beta-bernoulli process
        mask_matrix, _, n_layers, _ = self.architecture_sampler(num_samples)

        outputs = [
            self.output_layer(self._forward(x[i], mask_matrix[i], n_layers))
            for i in range(num_samples)
        ]
        return torch.stack(outputs, dim=0)

    def get_kl(self):
        """Return the KL term reported by the architecture sampler."""
        return self.architecture_sampler.get_kl()