 

from overrides import overrides
from typing import Dict, Tuple
import torch
import torch.nn as nn
from torch.nn import Parameter

from  .flow import Flow
from  .actnorm import ActNorm2dFlow
from  .conv import Conv1x1Flow
from  .cdvft import Conv2dCDVFT
from  .nice import NICE
from .utils import squeeze2d, unsqueeze2d, split2d, unsplit2d

import copy


class Prior(Flow):
    """
    Prior for the multi-scale architecture.

    Forward pipeline: ActNorm2dFlow -> [Conv1x1Flow] -> [Conv2dCDVFT] -> NICE.
    ``backward`` calls the same layers' ``backward`` in the reverse order.
    Bracketed layers are only built when ``use_conv1by1`` / ``use_intermediate_perm``
    is set; otherwise the attribute is None.

    ``permute_channel``, ``reverse_perm`` and ``bi_direction`` are accepted so the
    signature matches the other blocks in this file, but they are not used here.
    """
    def __init__(self, in_channels, hidden_channels=None, s_channels=None, scale=True, inverse=False, factor=2, use_conv1by1=False,
                 use_intermediate_perm=False, permute_channel=False,
                 reverse_perm=False, bi_direction=False,
                 cdvft_partitions=7, LU_decomposed=False):
        super(Prior, self).__init__(inverse)
        self.actnorm = ActNorm2dFlow(in_channels, inverse=inverse)
        self.use_conv1by1 = use_conv1by1
        if use_conv1by1:
            self.conv1x1 = Conv1x1Flow(in_channels, inverse=inverse, LU_decomposed=LU_decomposed)
        else:
            self.conv1x1 = None

        if use_intermediate_perm:
            # NOTE(review): unlike the other sub-layers, no ``inverse`` flag is passed here — confirm intended.
            self.cdvft = Conv2dCDVFT(in_channels, n2_1=cdvft_partitions)
        else:
            self.cdvft = None
        self.nice = NICE(in_channels, hidden_channels=hidden_channels, s_channels=s_channels, scale=scale, inverse=inverse, factor=factor)
        self.z1_channels = self.nice.z1_channels

        # Name -> module view holding the same module objects as the attributes above
        # (disabled layers are stored as None).
        self.module_dict = nn.ModuleDict({
            'actnorm': self.actnorm,
            'conv1x1': self.conv1x1,
            'cdvft': self.cdvft,
            'nice': self.nice
        })

    def sync(self):
        """Call ``sync()`` on the optional Conv1x1 and CDVFT layers, when present."""
        if self.conv1x1 is not None:
            self.conv1x1.sync()  # TODO: maybe remove this call (could cause bugs, check sync for butterfly flow)
        if self.cdvft is not None:
            self.cdvft.sync()  # TODO: maybe remove this call (could cause bugs, check sync for butterfly flow)

    def get_cdvft(self):
        """Return this prior's CDVFT layer (possibly None) wrapped as ``{'cdvft': ...}``."""
        return nn.ModuleDict({'cdvft': self.cdvft})

    def assign_module(self, target_module_dict):
        '''
        Replace this prior's CDVFT layer.

        target_module_dict: mapping such as the ``nn.ModuleDict`` returned by
            ``get_cdvft``; if it contains a ``'cdvft'`` key, that value (a module
            or None) becomes ``self.cdvft``.
        '''
        if 'cdvft' in target_module_dict:
            self.cdvft = target_module_dict['cdvft']
            # ``module_dict`` aliases the sub-layers; keep it in sync so the replaced
            # module is no longer registered (parameters / state_dict) through it.
            self.module_dict['cdvft'] = self.cdvft

    def get_parameters(self):
        """
        Split parameters into two groups.

        Returns:
            (model_params, cdvft_params): ActNorm + NICE (+ Conv1x1) parameters,
            and the CDVFT layer's parameters (empty list when there is none).
        """
        model_params = list(self.actnorm.parameters()) + list(self.nice.parameters())
        if self.conv1x1 is not None:
            model_params += list(self.conv1x1.parameters())

        if self.cdvft is not None:
            cdvft_params = list(self.cdvft.parameters())
        else:
            cdvft_params = []

        return model_params, cdvft_params

    @overrides
    def forward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ActNorm, optional Conv1x1, optional CDVFT, then NICE; return (out, summed logdet)."""
        out, logdet_accum = self.actnorm.forward(input)

        if self.conv1x1 is not None:
            out, logdet = self.conv1x1.forward(out)
            logdet_accum = logdet_accum + logdet

        if self.cdvft is not None:
            out, logdet = self.cdvft.forward(out)
            logdet_accum = logdet_accum + logdet

        out, logdet = self.nice.forward(out, s=s)
        logdet_accum = logdet_accum + logdet
        return out, logdet_accum

    @overrides
    def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo ``forward``: NICE, optional CDVFT, optional Conv1x1, then ActNorm backward."""
        out, logdet_accum = self.nice.backward(input, s=s)

        if self.cdvft is not None:
            out, logdet = self.cdvft.backward(out)
            logdet_accum = logdet_accum + logdet

        if self.conv1x1 is not None:
            out, logdet = self.conv1x1.backward(out)
            logdet_accum = logdet_accum + logdet
        out, logdet = self.actnorm.backward(out)
        logdet_accum = logdet_accum + logdet
        return out, logdet_accum

    @overrides
    def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run each sub-layer's ``init`` in forward order; return (out, summed logdet)."""
        out, logdet_accum = self.actnorm.init(data, init_scale=init_scale)

        if self.conv1x1 is not None:
            out, logdet = self.conv1x1.init(out, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet

        if self.cdvft is not None:
            out, logdet = self.cdvft.init(out, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet

        out, logdet = self.nice.init(out, s=s, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        return out, logdet_accum


class CDVFTGlowStep(Flow):
    """
    A single Glow step.

    Forward order: ActNorm2dFlow -> [Conv1x1Flow] -> [Conv2dCDVFT] -> NICE coupling.
    Bracketed layers exist only when ``use_conv1by1`` / ``use_intermediate_perm`` is
    set (otherwise the attribute is None). ``backward`` undoes them in the opposite order.

    ``weight_init_types``, ``reverse_perm``, ``bi_direction``, ``permute_channel`` and
    ``rotate`` are accepted for signature compatibility but are not used here.
    """
    def __init__(self, in_channels, hidden_channels=512, s_channels=0, scale=True, inverse=False,
                 coupling_type='conv', slice=None, heads=1, pos_enc=True, dropout=0.0, use_conv1by1=False,
                 use_intermediate_perm=False, weight_init_types=None, reverse_perm=False, bi_direction=False,
                 cdvft_partitions=7, permute_channel=False,
                 LU_decomposed=False, rotate=False):
        super(CDVFTGlowStep, self).__init__(inverse)
        # Construction order (actnorm, conv1x1, cdvft, coupling) fixes both the
        # sub-module registration order and the order of random initialisation.
        self.actnorm = ActNorm2dFlow(in_channels, inverse=inverse)
        self.use_conv1by1 = use_conv1by1
        self.conv1x1 = (
            Conv1x1Flow(in_channels, inverse=inverse, LU_decomposed=LU_decomposed)
            if use_conv1by1 else None
        )
        self.cdvft = (
            Conv2dCDVFT(in_channels, n2_1=cdvft_partitions)
            if use_intermediate_perm else None
        )
        self.coupling = NICE(
            in_channels, hidden_channels=hidden_channels, s_channels=s_channels,
            scale=scale, inverse=inverse, type=coupling_type, slice=slice,
            heads=heads, pos_enc=pos_enc, dropout=dropout,
        )

    def sync(self):
        """Sync the optional Conv1x1 layer (the CDVFT layer is not synced here)."""
        conv = self.conv1x1
        if conv is not None:
            conv.sync()

    def get_cdvft(self):
        """Return this step's CDVFT layer (possibly None) wrapped as ``{'cdvft': ...}``."""
        return nn.ModuleDict({'cdvft': self.cdvft})

    def assign_module(self, target_module_dict):
        '''
        Replace this step's CDVFT layer.

        target_module_dict: mapping such as the one returned by ``get_cdvft``;
            only its ``'cdvft'`` entry (if present) is used.
        '''
        if 'cdvft' not in target_module_dict:
            return
        self.cdvft = target_module_dict['cdvft']

    def get_parameters(self):
        """
        Return ``(model_params, cdvft_params)``: ActNorm + coupling (+ Conv1x1)
        parameters, and the CDVFT layer's parameters (empty when absent).
        """
        model_params = [*self.actnorm.parameters(), *self.coupling.parameters()]
        if self.conv1x1 is not None:
            model_params.extend(self.conv1x1.parameters())
        cdvft_params = [] if self.cdvft is None else list(self.cdvft.parameters())
        return model_params, cdvft_params

    @overrides
    def forward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """ActNorm, then each present optional layer, then the coupling; logdets are summed."""
        out, logdet_accum = self.actnorm.forward(input)
        for layer in (self.conv1x1, self.cdvft):
            if layer is None:
                continue
            out, layer_logdet = layer.forward(out)
            logdet_accum = logdet_accum + layer_logdet
        out, coupling_logdet = self.coupling.forward(out, s=s)
        return out, logdet_accum + coupling_logdet

    @overrides
    def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Inverse of ``forward``: coupling, optional CDVFT, optional Conv1x1, ActNorm."""
        out, logdet_accum = self.coupling.backward(input, s=s)
        for layer in (self.cdvft, self.conv1x1):
            if layer is None:
                continue
            out, layer_logdet = layer.backward(out)
            logdet_accum = logdet_accum + layer_logdet
        out, actnorm_logdet = self.actnorm.backward(out)
        return out, logdet_accum + actnorm_logdet

    @overrides
    def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run each sub-layer's ``init`` in forward order, summing the logdets."""
        out, logdet_accum = self.actnorm.init(data, init_scale=init_scale)
        for layer in (self.conv1x1, self.cdvft):
            if layer is not None:
                out, layer_logdet = layer.init(out, init_scale=init_scale)
                logdet_accum = logdet_accum + layer_logdet
        out, coupling_logdet = self.coupling.init(out, s=s, init_scale=init_scale)
        return out, logdet_accum + coupling_logdet

    def visualize_weights(self):
        """Delegate to the CDVFT layer's ``visualize_weights``; None when there is no CDVFT layer."""
        return None if self.cdvft is None else self.cdvft.visualize_weights()

class CDVFTGlowTopBlock(Flow):
    """
    Top CDVFTGlow block: a stack of ``num_steps`` CDVFTGlowStep layers.

    The squeeze that precedes this block is applied by the caller (CDVFT_Glow).
    ``hidden_channels`` defaults to 512 for the steps when None.
    """
    def __init__(self, num_steps, in_channels, scale=True, inverse=False, use_conv1by1=False,
                 use_intermediate_perm=False, permute_channel=False, hidden_channels=None,
                 weight_init_types=None, reverse_perm=False,bi_direction=False,
                 cdvft_partitions=7,  LU_decomposed=False, rotate=False):
        super(CDVFTGlowTopBlock, self).__init__(inverse)

        glowstep_hidden = 512 if hidden_channels is None else hidden_channels
        self.steps = nn.ModuleList([
            CDVFTGlowStep(in_channels, scale=scale, inverse=inverse, use_conv1by1=use_conv1by1,
                          use_intermediate_perm=use_intermediate_perm, weight_init_types=weight_init_types,
                          reverse_perm=reverse_perm, bi_direction=bi_direction, hidden_channels=glowstep_hidden,
                          LU_decomposed=LU_decomposed, cdvft_partitions=cdvft_partitions, rotate=rotate,
                          permute_channel=permute_channel)
            for _ in range(num_steps)
        ])

        # Name -> module view; 'steps' is the same ModuleList object as self.steps.
        self.module_dict = nn.ModuleDict({'steps': self.steps})

    @staticmethod
    def _run_steps(out, steps, apply):
        """Thread ``out`` through ``apply(step, out)`` for every step, summing per-batch logdets."""
        logdet_accum = out.new_zeros(out.size(0))
        for step in steps:
            out, logdet = apply(step, out)
            logdet_accum = logdet_accum + logdet
        return out, logdet_accum

    def sync(self):
        """Sync every step."""
        for step in self.steps:
            step.sync()

    def get_cdvft(self):
        """Return ``{'steps': [step.get_cdvft() for each step]}`` as nested module containers."""
        per_step = nn.ModuleList([step.get_cdvft() for step in self.steps])
        return nn.ModuleDict({'steps': per_step})

    def assign_module(self, target_module_dict):
        '''
        Hand each step its replacement CDVFT mapping.

        target_module_dict: mapping whose ``'steps'`` entry holds one per-step
            mapping (as produced by ``get_cdvft``), same length as ``self.steps``.
        '''
        replacements = target_module_dict['steps']
        assert len(self.steps) == len(replacements)
        for idx, replacement in enumerate(replacements):
            self.steps[idx].assign_module(replacement)

    def get_parameters(self):
        """Concatenate ``(model_params, cdvft_params)`` over all steps, in step order."""
        model_params, cdvft_params = [], []
        for step in self.steps:
            step_model, step_cdvft = step.get_parameters()
            model_params.extend(step_model)
            cdvft_params.extend(step_cdvft)
        return model_params, cdvft_params

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every step forward in order; return (out, summed logdet)."""
        return self._run_steps(input, self.steps, lambda step, x: step.forward(x))

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every step backward in reverse order; return (out, summed logdet)."""
        return self._run_steps(input, reversed(self.steps), lambda step, x: step.backward(x))

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every step's ``init`` in order; return (out, summed logdet)."""
        return self._run_steps(data, self.steps, lambda step, x: step.init(x, init_scale=init_scale))

    def visualize_weights(self):
        """Return the list of each step's ``visualize_weights()`` result (None entries possible)."""
        return [step.visualize_weights() for step in self.steps]

class CDVFTGlowInternalBlock(Flow):
    """
    CDVFTGlow internal block: ``num_steps`` CDVFTGlowStep layers followed by a Prior.

    The squeeze before and the split after this block are applied by the caller
    (CDVFT_Glow). When ``hidden_channels`` is None the steps use 512 hidden channels
    and the Prior receives None.
    """
    def __init__(self, num_steps, in_channels, scale=True, inverse=False, use_conv1by1=False,
                 use_intermediate_perm=False, permute_channel=False, hidden_channels=None,
                 weight_init_types=None, reverse_perm=False,bi_direction=False,
                 cdvft_partitions=None, LU_decomposed=False, rotate=False):
        super(CDVFTGlowInternalBlock, self).__init__(inverse)
        if hidden_channels is None:
            glowstep_hidden = 512
            prior_hidden = None
        else:
            glowstep_hidden = prior_hidden = hidden_channels
        steps = [CDVFTGlowStep(in_channels, scale=scale, inverse=inverse, use_conv1by1=use_conv1by1,
                                use_intermediate_perm=use_intermediate_perm,hidden_channels=glowstep_hidden,
                                weight_init_types=weight_init_types, reverse_perm=reverse_perm,bi_direction=bi_direction,
                                cdvft_partitions=cdvft_partitions, LU_decomposed=LU_decomposed, rotate=rotate, permute_channel=permute_channel) for _ in range(num_steps)]
        self.steps = nn.ModuleList(steps)
        # The prior is always built with inverse=True, independent of this block's ``inverse``.
        self.prior = Prior(in_channels, scale=scale, inverse=True,  use_conv1by1=use_conv1by1,
                           use_intermediate_perm=use_intermediate_perm, hidden_channels=prior_hidden,
                           reverse_perm=reverse_perm,bi_direction=bi_direction,
                           cdvft_partitions=cdvft_partitions,
                           LU_decomposed=LU_decomposed, permute_channel=permute_channel)

        # Name -> module view holding the same objects as self.steps / self.prior.
        self.module_dict = nn.ModuleDict(
            {'steps': self.steps,
             'prior': self.prior}
        )

    def sync(self):
        """Sync every step, then the prior."""
        for step in self.steps:
            step.sync()
        self.prior.sync()

    def get_cdvft(self):
        """Return ``{'steps': [per-step cdvft dicts], 'prior': prior cdvft dict}``."""
        cdvfts = nn.ModuleList([module.get_cdvft() for module in self.steps])

        prior_cf = self.prior.get_cdvft()

        return nn.ModuleDict({'steps': cdvfts, 'prior': prior_cf})

    def assign_module(self, target_module_dict):
        '''
        Distribute replacement CDVFT modules to the steps and the prior.

        target_module_dict: mapping shaped like the output of ``get_cdvft``:
            ``'steps'`` (one entry per step) and ``'prior'``.
        '''
        assert len(self.steps) == len(target_module_dict['steps'])
        for i, val in enumerate(target_module_dict['steps']):
            self.steps[i].assign_module(val)

        self.prior.assign_module(target_module_dict['prior'])

    def get_parameters(self):
        """Concatenate ``(model_params, cdvft_params)`` over the steps, then the prior."""
        model_params, cdvft_params = [], []
        for module in self.steps:
            mp, cp = module.get_parameters()
            model_params += mp
            cdvft_params += cp
        mp, cp = self.prior.get_parameters()
        model_params += mp
        cdvft_params += cp
        return model_params, cdvft_params

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the steps then the prior forward; return (out, summed logdet)."""
        out = input
        # one logdet entry per batch element
        logdet_accum = input.new_zeros(input.size(0))
        for step in self.steps:
            out, logdet = step.forward(out)
            logdet_accum = logdet_accum + logdet
        out, logdet = self.prior.forward(out)
        logdet_accum = logdet_accum + logdet
        return out, logdet_accum

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the prior backward, then the steps backward in reverse order."""
        out, logdet_accum = self.prior.backward(input)
        for step in reversed(self.steps):
            out, logdet = step.backward(out)
            logdet_accum = logdet_accum + logdet
        return out, logdet_accum

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run each step's ``init`` then the prior's ``init``.

        Returns ``(out, logdet)`` — previously mis-annotated as a nested tuple.
        """
        out = data
        # one logdet entry per batch element
        logdet_accum = data.new_zeros(data.size(0))
        for step in self.steps:
            out, logdet = step.init(out, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
        out, logdet = self.prior.init(out, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        return out, logdet_accum

    def visualize_weights(self):
        """Return each step's ``visualize_weights()`` result (the prior is not included)."""
        ret = []
        for step in self.steps:
            ret.append(step.visualize_weights())

        return ret

class CDVFT_Glow(Flow):
    """
    Multi-scale Glow flow with CDVFT layers.

    If ``use_intermediate_perm`` is False, a single ``Conv2dCDVFT`` (``self.cdvft``)
    is applied to the raw input before the multi-scale blocks. If it is True, no
    global CDVFT is built and each step / prior builds its own instead.

    Level ``i`` squeezes its input by ``factors[i]`` and runs ``blocks[i]``. Every
    level except the last is a ``CDVFTGlowInternalBlock`` whose output is split.
    The factored-out part is set aside and re-attached (unsplit2d / unsqueeze2d)
    when the top block finishes.
    """
    def __init__(self, levels, num_steps, factors, in_channels, scale=True, inverse=False,
                 use_conv1by1=False, hidden_channels=None,
                 use_intermediate_perm=False, LU_decomposed=False, permute_channel=False,
                 weight_init_types=None, model_layers=None, reverse_perm=False, bi_direction=False,
                 cdvft_partitions=7, rotate=False):
        super(CDVFT_Glow, self).__init__(inverse)
        assert levels > 1, 'CDVFT_Glow should have at least 2 levels.'
        assert levels == len(num_steps)
        blocks = []
        self.levels = levels
        self.factors = factors
        self.use_conv1by1 = use_conv1by1

        if not use_intermediate_perm:
            self.cdvft = Conv2dCDVFT(in_channels,n2_1=cdvft_partitions)
        else:
            self.cdvft = None

        self.inverse = inverse

        for level in range(levels):
            # squeeze2d by factor f multiplies the channel count by f**2
            in_channels = in_channels * (self.factors[level]**2)
            if level == levels - 1:
                macow_block = CDVFTGlowTopBlock(num_steps[level], in_channels, scale=scale, inverse=inverse, use_conv1by1=use_conv1by1,  LU_decomposed=LU_decomposed,
                                                 use_intermediate_perm=use_intermediate_perm, weight_init_types=weight_init_types, reverse_perm=reverse_perm,
                                                 bi_direction=bi_direction, hidden_channels=hidden_channels,
                                                 cdvft_partitions=cdvft_partitions, rotate=rotate, permute_channel=permute_channel)
                blocks.append(macow_block)
            else:
                macow_block = CDVFTGlowInternalBlock(num_steps[level], in_channels, scale=scale, inverse=inverse,  use_conv1by1=use_conv1by1, LU_decomposed=LU_decomposed,
                                                      use_intermediate_perm=use_intermediate_perm, weight_init_types=weight_init_types, reverse_perm=reverse_perm,
                                                      bi_direction=bi_direction,hidden_channels=hidden_channels,
                                                      cdvft_partitions=cdvft_partitions, rotate=rotate, permute_channel=permute_channel)
                blocks.append(macow_block)
                # the split after this block keeps channels // factor for the next level
                in_channels = in_channels // self.factors[level]

        self.blocks = nn.ModuleList(blocks)

        # Name -> module view holding the same objects as self.blocks / self.cdvft.
        self.module_dict = nn.ModuleDict({
            'blocks': self.blocks,
            'cdvft': self.cdvft
        })

    def sync(self):
        """Sync every block (the global CDVFT, if any, is not synced here)."""
        for block in self.blocks:
            block.sync()

    def get_cdvft(self):
        """Return ``{'blocks': [per-block cdvft dicts], 'cdvft': global CDVFT or None}``."""
        cdvfts = {'blocks': nn.ModuleList([module.get_cdvft() for module in self.blocks]),
                  'cdvft': self.cdvft}

        return nn.ModuleDict(cdvfts)

    def assign_module(self, target_module_dict):
        '''
        Distribute replacement CDVFT modules to the blocks and the global CDVFT.

        target_module_dict: mapping shaped like the output of ``get_cdvft``:
            ``'blocks'`` (one entry per block) and optionally ``'cdvft'``.
        '''
        assert len(self.blocks) == len(target_module_dict['blocks'])
        for i, val in enumerate(target_module_dict['blocks']):
            self.blocks[i].assign_module(val)

        if 'cdvft' in target_module_dict:
            self.cdvft = target_module_dict['cdvft']
            # keep the ``module_dict`` alias in sync so the replaced module is no
            # longer registered (parameters / state_dict) through it
            self.module_dict['cdvft'] = self.cdvft

    def get_parameters(self):
        """Concatenate ``(model_params, cdvft_params)`` over blocks, plus the global CDVFT's params."""
        model_params, cdvft_params = [], []
        for module in self.blocks:
            mp, cp = module.get_parameters()
            model_params += mp
            cdvft_params += cp

        if self.cdvft is not None:
            cdvft_params += list(self.cdvft.parameters())
        return model_params, cdvft_params

    def _merge_levels(self, out, outputs):
        """
        Re-assemble the multi-scale result after the top block has run.

        out: output of the top block. outputs: factored-out parts, one per internal
        level in level order; the list is consumed (popped) by this call.
        """
        out = unsqueeze2d(out, factor=self.factors[self.levels - 1])
        for level in reversed(range(self.levels - 1)):
            out2 = outputs.pop()
            out = unsqueeze2d(unsplit2d([out, out2]), factor=self.factors[level])
        assert len(outputs) == 0
        return out

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Global CDVFT (if any), then squeeze/block/split per level, then re-assemble."""
        logdet_accum = input.new_zeros(input.size(0))
        out = input
        outputs = []

        if self.cdvft is not None:
            out, logdet = self.cdvft(out)
            logdet_accum = logdet_accum + logdet

        for i, block in enumerate(self.blocks):
            out = squeeze2d(out, factor=self.factors[i])
            out, logdet = block.forward(out)
            logdet_accum = logdet_accum + logdet
            if isinstance(block, CDVFTGlowInternalBlock):
                out1, out2 = split2d(out, out.size(1) // self.factors[i])
                outputs.append(out2)
                out = out1

        out = self._merge_levels(out, outputs)
        return out, logdet_accum

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inverse of ``forward``.

        Every squeeze/split/unsqueeze uses the factor of the level it belongs to.
        The previous version used ``factors[-1]`` for the first squeeze and indexed
        the reversed blocks with the forward index, so it only inverted ``forward``
        when all factors were equal.
        """
        outputs = []
        out = input
        # Undo _merge_levels: peel off the factored-out part of each internal level.
        for level in range(self.levels - 1):
            out = squeeze2d(out, factor=self.factors[level])
            out1, out2 = split2d(out, out.size(1) // self.factors[level])
            outputs.append(out2)
            out = out1
        out = squeeze2d(out, factor=self.factors[self.levels - 1])

        logdet_accum = input.new_zeros(input.size(0))
        for level in reversed(range(self.levels)):
            block = self.blocks[level]
            if isinstance(block, CDVFTGlowInternalBlock):
                out2 = outputs.pop()
                out = unsplit2d([out, out2])
            out, logdet = block.backward(out)
            logdet_accum = logdet_accum + logdet
            out = unsqueeze2d(out, factor=self.factors[level])

        if self.cdvft is not None:
            out, logdet = self.cdvft.backward(out)
            logdet_accum = logdet_accum + logdet
        assert len(outputs) == 0
        return out, logdet_accum

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Same pipeline as ``forward`` but calls each block's ``init``."""
        logdet_accum = data.new_zeros(data.size(0))
        out = data
        if self.cdvft is not None:
            # NOTE(review): the global CDVFT is run through its forward here, whereas
            # Prior / CDVFTGlowStep call ``cdvft.init`` during init — confirm intended.
            out, logdet = self.cdvft(out)
            logdet_accum = logdet_accum + logdet

        outputs = []
        for i, block in enumerate(self.blocks):
            out = squeeze2d(out, factor=self.factors[i])
            out, logdet = block.init(out, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
            if isinstance(block, CDVFTGlowInternalBlock):
                out1, out2 = split2d(out, out.size(1) // self.factors[i])
                outputs.append(out2)
                out = out1

        out = self._merge_levels(out, outputs)
        return out, logdet_accum

    @classmethod
    def from_params(cls, params: Dict) -> "CDVFT_Glow":
        """Construct the flow from a keyword-argument dict (``cls`` so subclasses work)."""
        return cls(**params)

    def visualize_weights(self):
        """Return each block's ``visualize_weights()`` result (global CDVFT not included)."""
        viz = []
        for module in self.blocks:
            imgs = module.visualize_weights()
            viz.append(imgs)

        return viz

CDVFT_Glow.register('CDVFT_glow')

# class Prior(Flow):
#     """
#     prior for multi-scale architecture
#     """
#     def __init__(self, in_channels, hidden_channels=None, s_channels=None, scale=True, inverse=False, factor=2, LU_decomposed=False):
#         super(Prior, self).__init__(inverse)
#         self.actnorm = ActNorm2dFlow(in_channels, inverse=inverse)
#         self.cdvft = CDVFTLinear(in_channels,in_channels)
#         self.nice = NICE(in_channels, hidden_channels=hidden_channels, s_channels=s_channels, scale=scale, inverse=inverse, factor=factor)
#         self.z1_channels = self.nice.z1_channels

#     # def sync(self):
#     #     self.cdvft.sync()

#     @overrides
#     def forward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
#         out, logdet_accum = self.actnorm.forward(input)

#         _,_,h,w = out.size()
#         out = self.cdvft.forward(out)
#         logdet_accum += self.cdvft.logabsdet().mul(h*w)

#         out, logdet = self.nice.forward(out, s=s)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
#         out, logdet_accum = self.nice.backward(input, s=s)

#         _,_,h,w = out.size()
#         out = self.cdvft.reverse(out)
#         logdet = -self.cdvft.logabsdet().mul(h*w)
#         logdet_accum = logdet_accum + logdet

#         out, logdet = self.actnorm.backward(out)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
#         out, logdet_accum = self.actnorm.init(data, init_scale=init_scale)

#         _,_,h,w = out.size()
#         out= self.cdvft.init(out, init_scale=init_scale)
#         logdet = -self.cdvft.logabsdet().mul(h*w)
#         logdet_accum = logdet_accum + logdet

#         out, logdet = self.nice.init(out, s=s, init_scale=init_scale)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum


# class CDVFTGlowStep(Flow):
#     """
#     A step of Glow. A Conv1x1 followed with a NICE
#     """
#     def __init__(self, in_channels, hidden_channels=512, s_channels=0, scale=True, inverse=False,
#                  coupling_type='conv', slice=None, heads=1, pos_enc=True, dropout=0.0, LU_decomposed=False):
#         super(CDVFTGlowStep, self).__init__(inverse)
#         self.actnorm = ActNorm2dFlow(in_channels, inverse=inverse)
#         self.cdvft = CDVFTLinear(in_channels,in_channels)
#         self.coupling = NICE(in_channels, hidden_channels=hidden_channels, s_channels=s_channels,
#                              scale=scale, inverse=inverse, type=coupling_type, slice=slice, heads=heads, pos_enc=pos_enc, dropout=dropout)

#     def sync(self):
#         self.cdvft.sync()

#     @overrides
#     def forward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
#         out, logdet_accum = self.actnorm.forward(input)

#         _,_,h,w = out.size()
#         out = self.cdvft.forward(out)
#         logdet_accum += self.cdvft.logabsdet().mul(h*w)

#         out, logdet = self.coupling.forward(out, s=s)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
#         out, logdet_accum = self.coupling.backward(input, s=s)

#         _,_,h,w = out.size()
#         out = self.cdvft.reverse(out)
#         logdet = -self.cdvft.logabsdet().mul(h*w)
#         logdet_accum = logdet_accum + logdet

#         out, logdet = self.actnorm.backward(out)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
#         out, logdet_accum = self.actnorm.init(data, init_scale=init_scale)

#         out, logdet = self.cdvft.init(out, init_scale=init_scale)
#         logdet_accum = logdet_accum + logdet

#         out, logdet = self.coupling.init(out, s=s, init_scale=init_scale)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum


# class CDVFTGlowTopBlock(Flow):
#     """
#     Glow Block (squeeze at beginning)
#     """
#     def __init__(self, num_steps, in_channels, scale=True, inverse=False, LU_decomposed=False):
#         super(CDVFTGlowTopBlock, self).__init__(inverse)
#         steps = [CDVFTGlowStep(in_channels, scale=scale, inverse=inverse, LU_decomposed=LU_decomposed) for _ in range(num_steps)]
#         self.steps = nn.ModuleList(steps)

#     def sync(self):
#         for step in self.steps:
#             step.sync()

#     @overrides
#     def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         out = input
#         # [batch]
#         logdet_accum = input.new_zeros(input.size(0))
#         for step in self.steps:
#             out, logdet = step.forward(out)
#             logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         logdet_accum = input.new_zeros(input.size(0))
#         out = input
#         for step in reversed(self.steps):
#             out, logdet = step.backward(out)
#             logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
#         out = data
#         # [batch]
#         logdet_accum = data.new_zeros(data.size(0))
#         for step in self.steps:
#             out, logdet = step.init(out, init_scale=init_scale)
#             logdet_accum = logdet_accum + logdet
#         return out, logdet_accum


# class CDVFTGlowInternalBlock(Flow):
#     """
#     Glow Internal Block (squeeze at beginning and split at end)
#     """
#     def __init__(self, num_steps, in_channels, scale=True, inverse=False, LU_decomposed=False):
#         super(CDVFTGlowInternalBlock, self).__init__(inverse)
#         steps = [CDVFTGlowStep(in_channels, scale=scale, inverse=inverse, LU_decomposed=LU_decomposed) for _ in range(num_steps)]
#         self.steps = nn.ModuleList(steps)
#         self.prior = Prior(in_channels, scale=scale, inverse=True, LU_decomposed=LU_decomposed)

#     def sync(self):
#         for step in self.steps:
#             step.sync()
#         self.prior.sync()

#     @overrides
#     def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         out = input
#         # [batch]
#         logdet_accum = input.new_zeros(input.size(0))
#         for step in self.steps:
#             out, logdet = step.forward(out)
#             logdet_accum = logdet_accum + logdet
#         out, logdet = self.prior.forward(out)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         # [batch]
#         out, logdet_accum = self.prior.backward(input)
#         for step in reversed(self.steps):
#             out, logdet = step.backward(out)
#             logdet_accum = logdet_accum + logdet
#         return out, logdet_accum

#     @overrides
#     def init(self, data, init_scale=1.0) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
#         out = data
#         # [batch]
#         logdet_accum = data.new_zeros(data.size(0))
#         for step in self.steps:
#             out, logdet = step.init(out, init_scale=init_scale)
#             logdet_accum = logdet_accum + logdet
#         out, logdet = self.prior.init(out, init_scale=init_scale)
#         logdet_accum = logdet_accum + logdet
#         return out, logdet_accum


# class CDVFTGlow(Flow):
#     """
#     Glow
#     """
#     def __init__(self, levels, num_steps, factors, in_channels, scale=True, inverse=False, LU_decomposed=False,**kwargs):
#         super(CDVFTGlow, self).__init__(inverse)
#         assert levels > 1, 'Glow should have at least 2 levels.'
#         assert levels == len(num_steps) == len(factors)
#         blocks = []
#         self.factors = factors
#         self.levels = levels
#         for level in range(levels):
#             if level == levels - 1:
#                 in_channels = in_channels * (factors[level]**2)
#                 macow_block = CDVFTGlowTopBlock(num_steps[level], in_channels, scale=scale, inverse=inverse, LU_decomposed=LU_decomposed)
#                 blocks.append(macow_block)
#             else:
#                 in_channels = in_channels * (factors[level]**2)
#                 macow_block = CDVFTGlowInternalBlock(num_steps[level], in_channels, scale=scale, inverse=inverse, LU_decomposed=LU_decomposed)
#                 blocks.append(macow_block)
#                 in_channels = in_channels // factors[level]
#         self.blocks = nn.ModuleList(blocks)

#     def sync(self):
#         for block in self.blocks:
#             block.sync()

#     def get_parameters(self):
#         return list(self.parameters()), []

#     @overrides
#     def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         logdet_accum = input.new_zeros(input.size(0))
#         out = input
#         outputs = []
#         for i, block in enumerate(self.blocks):
#             out = squeeze2d(out, factor=self.factors[i])
#             out, logdet = block.forward(out)
#             logdet_accum = logdet_accum + logdet
#             if isinstance(block, CDVFTGlowInternalBlock):
#                 out1, out2 = split2d(out, out.size(1) // self.factors[i])
#                 outputs.append(out2)
#                 out = out1

#         out = unsqueeze2d(out, factor=self.factors[-1])
#         for j in reversed(range(self.levels - 1)):
#             out2 = outputs.pop()
#             out = unsqueeze2d(unsplit2d([out, out2]), factor=self.factors[j])
#         assert len(outputs) == 0
#         return out, logdet_accum

#     @overrides
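#     def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         # Undo forward()'s final unsqueeze/unsplit chain in reverse order:
#         # forward() ends with an unsqueeze by factors[0], so squeeze by factors[0] first,
#         # and only squeeze by factors[-1] at the very end (matters when factors differ).
#         outputs = []
#         out = input
#         for j in range(self.levels - 1):
#             out = squeeze2d(out, factor=self.factors[j])
#             out1, out2 = split2d(out, out.size(1) // self.factors[j])
#             outputs.append(out2)
#             out = out1
#         out = squeeze2d(out, factor=self.factors[-1])

#         logdet_accum = input.new_zeros(input.size(0))
#         # Walk blocks from the top level down; block i was preceded in forward() by a
#         # squeeze with factors[i], so its inverse is an unsqueeze with factors[i].
#         for i in reversed(range(self.levels)):
#             block = self.blocks[i]
#             if isinstance(block, CDVFTGlowInternalBlock):
#                 out2 = outputs.pop()
#                 out = unsplit2d([out, out2])
#             out, logdet = block.backward(out)
#             logdet_accum = logdet_accum + logdet
#             out = unsqueeze2d(out, factor=self.factors[i])
#         assert len(outputs) == 0
#         return out, logdet_accum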

#     @overrides
#     def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
#         logdet_accum = data.new_zeros(data.size(0))
#         out = data
#         outputs = []
#         for i, block in enumerate(self.blocks):
#             out = squeeze2d(out, factor=self.factors[i])
#             out, logdet = block.init(out, init_scale=init_scale)
#             logdet_accum = logdet_accum + logdet
#             if isinstance(block, CDVFTGlowInternalBlock):
#                 out1, out2 = split2d(out, out.size(1) // self.factors[i])
#                 outputs.append(out2)
#                 out = out1

#         out = unsqueeze2d(out, factor=self.factors[-1])
#         for j in reversed(range(self.levels - 1)):
#             out2 = outputs.pop()
#             out = unsqueeze2d(unsplit2d([out, out2]), factor=self.factors[j])
#         assert len(outputs) == 0
#         return out, logdet_accum

#     @classmethod
#     def from_params(cls, params: Dict) -> "CDVFTGlow":
#         return CDVFTGlow(**params)

#     def visualize_weights(self, *args):
#         pass

# CDVFTGlow.register('cdvft_glow')

