'''Sample transforms for designing normalizing flows'''
# TODO: move this file and `neural_networks.py` into a separate networks folder

import pdb
import torch
import torch.nn.functional as F
import torch.nn as nn

from nflows.nn import nets as nets
from nflows.transforms.base import CompositeTransform
from nflows.transforms.coupling import (
    AdditiveCouplingTransform,
    AffineCouplingTransform,
    PiecewiseRationalQuadraticCouplingTransform
)
from nflows.transforms.lu import LULinear
from nflows.transforms.permutations import RandomPermutation, Permutation
from nflows.transforms.conv import OneByOneConvolution
from nflows.transforms.normalization import BatchNorm

class SimpleFlowTransform(CompositeTransform):
    '''Stack of alternating-mask coupling transforms for flat or channel data.

    Builds ``num_layers`` coupling transforms. The +1/-1 mask is sign-flipped
    after every coupling, so successive couplings act on complementary subsets
    of the features (or of the channels when ``net="cnn"``). When
    ``include_linear`` is set, each coupling is followed by a mixing step:

    * ``net="cnn"``: a ``OneByOneConvolution`` over ``data_shape[0]`` channels;
    * otherwise: ``RandomPermutation`` + ``LULinear``, optionally preceded by
      ``BatchNorm`` when ``do_batchnorm`` is set.

    Args:
        features: Length of the coupling mask (number of input features for
            flat data). For ``net="cnn"`` the mask is truncated to
            ``data_shape[0]`` entries.
        hidden_features: Hidden width (MLP) or hidden channel count (CNN) of
            each residual conditioner network.
        num_layers: Number of coupling transforms.
        num_blocks_per_layer: Residual blocks in each conditioner network.
        include_linear: Add a mixing transform after every coupling.
        num_bins, tail_bound: Spline options. They are forwarded (together with
            ``tails="linear"``) to every coupling constructor except affine or
            additive couplings, which do not accept them.
        activation, dropout_probability, batch_norm_within_layers: Forwarded
            to the conditioner networks.
        coupling_constructor: Callable building one coupling transform from
            ``mask`` and ``transform_net_create_fn``.
        net: ``"cnn"`` selects convolutional conditioners; any other value
            selects MLP (``ResidualNet``) conditioners.
        data_shape: Required when ``net="cnn"``; only ``data_shape[0]`` (the
            channel count) is used.
        do_batchnorm: Insert ``BatchNorm`` before each linear mixing step.
            Flat data only; has no effect when ``include_linear`` is False.
        conditioning: Unused; kept for interface compatibility.
        conditioning_dimension: Size of the context input of the conditioner
            networks, or None for no context.

    Raises:
        AssertionError: if ``do_batchnorm`` is combined with ``net="cnn"``, or
            ``net="cnn"`` is used without ``data_shape``.
    '''

    def __init__(
        self,
        features,
        hidden_features,
        num_layers,
        num_blocks_per_layer,
        include_linear=True,
        num_bins=8,
        tail_bound=1.0,
        activation=F.relu,
        dropout_probability=0.0,
        batch_norm_within_layers=False,
        coupling_constructor=PiecewiseRationalQuadraticCouplingTransform,
        net="mlp",
        data_shape=None,
        do_batchnorm=False,
        conditioning=None,
        conditioning_dimension=None
    ):
        is_cnn = net == "cnn"
        assert not (do_batchnorm and is_cnn), "Batchnorm only implemented for 1D inputs"
        if is_cnn:
            # Must be checked before data_shape is indexed below.
            assert data_shape is not None

        self.model_type = net

        # Alternating mask; flipped in place after every coupling layer.
        mask = torch.ones(features)
        mask[::2] = -1
        if is_cnn:
            # The convolutional variant masks channels, not flat features.
            mask = mask[:data_shape[0]]

        # The MLP conditioner has always received 0 (not None) when no
        # conditioning is requested; kept so its parameter layout is unchanged.
        mlp_context_features = 0 if conditioning_dimension is None else conditioning_dimension

        def create_resnet(in_features, out_features):
            '''Build the conditioner network for one coupling layer.'''
            if is_cnn:
                # nflows' ConvResidualNet calls its context argument
                # `context_channels`; it has no `context_features` parameter.
                return nets.ConvResidualNet(
                    in_features,
                    out_features,
                    hidden_channels=hidden_features,
                    context_channels=conditioning_dimension,
                    num_blocks=num_blocks_per_layer,
                    activation=activation,
                    dropout_probability=dropout_probability,
                    use_batch_norm=batch_norm_within_layers,
                )
            return nets.ResidualNet(
                in_features,
                out_features,
                hidden_features=hidden_features,
                num_blocks=num_blocks_per_layer,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=batch_norm_within_layers,
                context_features=mlp_context_features,
            )

        # Spline options are understood only by spline couplings; affine
        # couplings (and additive ones, which subclass them) reject them.
        is_affine = isinstance(coupling_constructor, type) and issubclass(
            coupling_constructor, AffineCouplingTransform
        )
        spline_kwargs = {} if is_affine else dict(
            tails="linear",
            num_bins=num_bins,
            tail_bound=tail_bound,
        )

        layers = []
        for _ in range(num_layers):
            layers.append(
                coupling_constructor(
                    mask=mask,
                    transform_net_create_fn=create_resnet,
                    **spline_kwargs,
                )
            )
            mask *= -1

            if not include_linear:
                continue

            if is_cnn:
                linear_transform = CompositeTransform([
                    OneByOneConvolution(data_shape[0], identity_init=True)
                ])
            else:
                if do_batchnorm:
                    layers.append(BatchNorm(features))
                linear_transform = CompositeTransform([
                    RandomPermutation(features=features),
                    LULinear(features, identity_init=True),
                ])
            layers.append(linear_transform)

        super().__init__(layers)


class SimpleNSFTransform(SimpleFlowTransform):
    '''SimpleFlowTransform whose couplings are rational-quadratic splines.

    All keyword arguments are forwarded to SimpleFlowTransform. The coupling
    constructor is fixed, so also passing ``coupling_constructor`` raises a
    TypeError (duplicate keyword argument).
    '''

    def __init__(self, **kwargs):
        spline_coupling = PiecewiseRationalQuadraticCouplingTransform
        super().__init__(coupling_constructor=spline_coupling, **kwargs)

class SharedSimpleFlowTransform(CompositeTransform):
    '''Three-section coupling flow (start / middle / end) with shareable sections.

    Every section is a stack of coupling layers built like SimpleFlowTransform's
    (alternating +1/-1 mask, optional linear mixing after each coupling). A
    section flagged as shared is not kept by this instance; its layers are
    supplied later through ``add_shared_module`` (presumably taken from another
    flow so that several flows share parameters - confirm with callers).

    Args:
        features: Length of the coupling mask (number of input features for
            flat data). For ``net="cnn"`` the mask is truncated to
            ``data_shape[0]`` entries.
        hidden_features: Hidden width / channel count of conditioner networks.
        share_start, share_middle, share_end: If true, the matching section is
            left out and must be provided via ``add_shared_module``.
        num_layers: 3-sequence of coupling-layer counts (start, middle, end).
        num_blocks_per_layer: 3-sequence of residual-block counts used by the
            conditioner networks of each section.
        include_linear: Add a mixing transform after every coupling.
        num_bins, tail_bound: Spline options. They are forwarded (with
            ``tails="linear"``) to every coupling constructor except affine or
            additive couplings, which do not accept them.
        activation, dropout_probability, batch_norm_within_layers: Forwarded
            to the conditioner networks.
        coupling_constructor: Callable building one coupling transform.
        net: ``"cnn"`` for convolutional conditioners, otherwise MLP.
        data_shape: Required when ``net="cnn"``; ``data_shape[0]`` is used.
        do_batchnorm: Insert ``BatchNorm`` before each linear mixing step (flat
            data only; no effect when ``include_linear`` is False).

    Raises:
        AssertionError: if ``do_batchnorm`` is combined with ``net="cnn"``, or
            ``net="cnn"`` is used without ``data_shape``.
    '''

    def __init__(
        self,
        features,
        hidden_features,
        share_start,
        share_middle,
        share_end,
        num_layers,
        num_blocks_per_layer,
        include_linear=True,
        num_bins=8,
        tail_bound=1.0,
        activation=F.relu,
        dropout_probability=0.0,
        batch_norm_within_layers=False,
        coupling_constructor=PiecewiseRationalQuadraticCouplingTransform,
        net="mlp",
        data_shape=None,
        do_batchnorm=False
    ):
        is_cnn = net == "cnn"
        assert not (do_batchnorm and is_cnn), "Batchnorm only implemented for 1D inputs"
        if is_cnn:
            # Must be checked before data_shape is indexed below.
            assert data_shape is not None

        self.model_type = net

        # Alternating mask shared by all sections; flipped after every coupling.
        mask = torch.ones(features)
        mask[::2] = -1
        if is_cnn:
            # The convolutional variant masks channels, not flat features.
            mask = mask[:data_shape[0]]

        # Spline options are understood only by spline couplings; affine
        # couplings (and additive ones, which subclass them) reject them.
        is_affine = isinstance(coupling_constructor, type) and issubclass(
            coupling_constructor, AffineCouplingTransform
        )
        spline_kwargs = {} if is_affine else dict(
            tails="linear",
            num_bins=num_bins,
            tail_bound=tail_bound,
        )

        def make_create_resnet(num_blocks):
            '''Return a conditioner factory whose networks use `num_blocks` blocks.'''
            def create_resnet(in_features, out_features):
                if is_cnn:
                    return nets.ConvResidualNet(
                        in_features,
                        out_features,
                        hidden_channels=hidden_features,
                        num_blocks=num_blocks,
                        activation=activation,
                        dropout_probability=dropout_probability,
                        use_batch_norm=batch_norm_within_layers,
                    )
                return nets.ResidualNet(
                    in_features,
                    out_features,
                    hidden_features=hidden_features,
                    num_blocks=num_blocks,
                    activation=activation,
                    dropout_probability=dropout_probability,
                    use_batch_norm=batch_norm_within_layers,
                )
            return create_resnet

        def make_layers(count, num_blocks):
            '''Build `count` couplings (plus mixing layers), advancing the mask.'''
            nonlocal mask
            create_resnet = make_create_resnet(num_blocks)

            layers = []
            for _ in range(count):
                layers.append(
                    coupling_constructor(
                        mask=mask,
                        transform_net_create_fn=create_resnet,
                        **spline_kwargs,
                    )
                )
                mask *= -1

                if not include_linear:
                    continue

                if is_cnn:
                    linear_transform = CompositeTransform([
                        OneByOneConvolution(data_shape[0], identity_init=True)
                    ])
                else:
                    if do_batchnorm:
                        layers.append(BatchNorm(features))
                    linear_transform = CompositeTransform([
                        RandomPermutation(features=features),
                        LULinear(features, identity_init=True),
                    ])
                layers.append(linear_transform)

            # Returned only after all `count` layers are built.
            return layers

        self.temp_modules = []

        sections = (
            ("start_layers", share_start),
            ("middle_layers", share_middle),
            ("end_layers", share_end),
        )
        for index, (attr_name, is_shared) in enumerate(sections):
            section_layers = make_layers(num_layers[index], num_blocks_per_layer[index])
            if is_shared:
                # Built and discarded only so the mask alternation stays in
                # step with the layout where this section is present.
                continue
            setattr(self, attr_name, section_layers)
            self.temp_modules.extend(section_layers)

        super().__init__(self.temp_modules)

    def add_shared_module(self, module, location):
        '''Install the layers of `module` as one section and rebuild the flow.

        Args:
            module: Transform exposing ``_transforms`` (e.g. a
                CompositeTransform); its layers become the section.
            location: One of ``"start"``, ``"middle"``, ``"end"``.

        Sections that are shared but not supplied yet contribute no layers, so
        shared sections may be added one at a time, in any order.
        '''
        assert location in ["start", "middle", "end"]

        shared_layers = list(module._transforms)
        if location == "start":
            self.start_layers = shared_layers
            self.has_beginning = True
        elif location == "middle":
            self.middle_layers = shared_layers
            self.has_middle = True
        else:
            self.end_layers = shared_layers
            self.has_end = True

        # Re-running the base __init__ rebuilds the transform list; note it also
        # re-initialises nn.Module bookkeeping (e.g. the training flag).
        super().__init__(
            getattr(self, "start_layers", [])
            + getattr(self, "middle_layers", [])
            + getattr(self, "end_layers", [])
        )


class SharedSimpleNSFTransform(SharedSimpleFlowTransform):
    '''SharedSimpleFlowTransform whose couplings are rational-quadratic splines.

    All keyword arguments are forwarded to SharedSimpleFlowTransform. The
    coupling constructor is fixed, so also passing ``coupling_constructor``
    raises a TypeError (duplicate keyword argument).
    '''

    def __init__(self, **kwargs):
        spline_coupling = PiecewiseRationalQuadraticCouplingTransform
        super().__init__(coupling_constructor=spline_coupling, **kwargs)