import math
import torch
import torch.nn as nn


from modules.complexPyTorch.complexLayers import (
    ComplexConv2d,
    ComplexBatchNorm2d,
    ComplexGELU,
    ComplexLinear
)

from layers import (
    ResidualBottleneck,
    DWConv2d,
    PhasorBlockC,
    SpectralBranches
)



class PsychoNet(nn.Module):
    """Hybrid real/complex-valued image classifier with spectral-domain heads.

    Pipeline (see ``forward``):

    1. Real-valued stem: ``conv1`` -> BatchNorm -> GELU -> optional max-pool.
    2. A stack of ``ResidualBottleneck`` blocks.
    3. Phasor Block (I): two depthwise-conv/BN/GELU stages produce an
       imaginary component that is added (times ``1j``) to the real features,
       followed by a complex conv/BN/GELU.
    4. A stack of ``PhasorBlockC`` blocks (an identity if none configured).
    5. A 2-D FFT over the last two dims; the DC bin is overwritten with a tiny
       constant, the spectrum is centered with ``fftshift``, optionally
       mean-normalized, companded, and optionally multiplied by ``filter``.
    6. ``SpectralBranches`` heads consume the spectrum (key ``'POST_PHASOR'``)
       plus any intermediate ResBlock / Phasor Block (C) outputs that were
       given an output name in the config. Their outputs are concatenated,
       flattened, passed through a ``ComplexLinear`` classifier, and the
       magnitude is returned.

    Args:
        dim: Base channel width; channel multipliers in the configs are
            scaled by it.
        in_channels: Number of input channels of ``conv1``.
        conv1_config: Stem conv settings indexed as
            ``(kernel_size, padding, stride)``.
        residual_block_expansion: ``expansion`` passed to every
            ``ResidualBottleneck``; also scales the input width of blocks
            after the first and divides the planes of the last block.
        residual_block_config: Sequence of ``(in_planes, planes, stride)`` or
            ``(in_planes, planes, stride, output_name)`` entries; plane counts
            are multipliers of ``dim``.
        phasori_config: Three entries for Phasor Block (I). Entries 0 and 1
            configure the depthwise convs as ``(in_mult, out_mult,
            kernel_size, padding, ..., kernels_per_layer)``; entry 2
            configures the complex conv as ``(in_mult, out_mult, stride)``.
        phasorc_config: Sequence of 8- or 9-field Phasor Block (C) entries
            ``(d_in, d_bottle, d_out, stride, depth, dw_kernel_size,
            dw_padding, dw_kernels_per_layer[, output_name])``, or ``None`` /
            empty for no such blocks.
        spectral_branch_config: Object whose ``.fields`` maps an input name to
            a per-branch ``SpectralBranches`` config.
        hadamard_filter_init_scale: Converted to ``float`` (when not ``None``)
            and forwarded to ``SpectralBranches`` as ``filter_init_scale``.
        hadamard_filter_init_type: Forwarded to ``SpectralBranches`` as
            ``filter_init_type``.
        phasor_use_shortcut: Stored on the instance; not used by this class.
        use_maxpool: Whether the stem applies a 3x3, stride-2 max-pool.
        num_classes: Output size of the classifier.
        img_size: Input image size; ``img_size // 16`` is passed as ``N`` to
            every ``SpectralBranches``.
        filter: Optional tensor multiplied into the centered spectrum,
            broadcast over the first two (batch, channel) dimensions.
        verbose: Print configuration details during construction.

    Raises:
        ValueError: If a residual block or Phasor Block (C) config entry has
            an unsupported number of fields.
    """

    def __init__(self,
                 dim,
                 in_channels,
                 conv1_config,
                 residual_block_expansion,
                 residual_block_config,
                 phasori_config,
                 phasorc_config,
                 spectral_branch_config,
                 hadamard_filter_init_scale=None,
                 hadamard_filter_init_type=None,
                 phasor_use_shortcut=True,
                 use_maxpool=True,
                 num_classes=1000,
                 img_size=224,
                 filter=None,
                 verbose=False
                 ):
        super().__init__()
        self.verbose = verbose

        self.conv1_config = conv1_config
        self.residual_block_expansion = residual_block_expansion
        self.residual_block_config = residual_block_config
        self.phasori_config = phasori_config
        self.phasorc_config = phasorc_config
        self.spectral_branch_config = spectral_branch_config
        self.hadamard_filter_init_type = hadamard_filter_init_type
        # Normalize to float (presumably because config loaders can deliver
        # values such as "2e-2" as strings). This normalized value is the one
        # handed to SpectralBranches below.
        self.hadamard_filter_init_scale = (
            float(hadamard_filter_init_scale)
            if hadamard_filter_init_scale is not None else None
        )
        self.use_maxpool = use_maxpool
        # NOTE(review): stored but not forwarded to PhasorBlockC below.
        self.phasor_use_shortcut = phasor_use_shortcut

        # Passed as N to SpectralBranches; presumably the spatial size of the
        # final feature map for an overall stride of 16 -- TODO confirm.
        self.Nmax = img_size // 16
        self.subdepth = dim
        self.in_channels = in_channels
        self.num_classes = num_classes
        # A plain tensor filter is registered as a non-persistent buffer so it
        # follows the module on .to()/.cuda() without adding state_dict keys.
        # Parameters and non-tensor values keep the plain assignment.
        if isinstance(filter, torch.Tensor) and not isinstance(filter, nn.Parameter):
            self.register_buffer('filter', filter, persistent=False)
        else:
            self.filter = filter
        self.unit = False  # if True, forward divides the spectrum by its mean
        self.compand = 1.25  # companding mode/exponent used in forward
        self.softshrink = False  # NOTE(review): not read in this class
        self.sparsity_threshold = 0.1  # NOTE(review): not read in this class

        if self.verbose:
            print("Companding Applied:", self.compand)

        # setup input layer
        self.conv1 = nn.Conv2d(self.in_channels, self.subdepth,
                               kernel_size=self.conv1_config[0],
                               stride=self.conv1_config[2],
                               bias=False,
                               padding=self.conv1_config[1])
        self.bn1 = nn.BatchNorm2d(self.subdepth)
        self.gelu1 = nn.GELU()
        if self.use_maxpool:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.maxpool = nn.Identity()

        # setup ResBlocks
        self.resblocks = nn.ModuleList()
        self.resblock_output_names = []
        n_resblocks = len(self.residual_block_config)
        for i, block_config in enumerate(self.residual_block_config):
            (in_planes, planes, stride), output_name = self._split_block_config(
                block_config, 3, "residual block")
            self.resblock_output_names.append(output_name)
            # Blocks after the first get an input width multiplied by the
            # expansion factor (presumably the previous block's expanded output).
            in_planes = int(in_planes * self.subdepth) * (self.residual_block_expansion if i > 0 else 1)
            # The last block's planes are divided by the expansion factor.
            is_last_layer = i == n_resblocks - 1
            planes = int(planes * self.subdepth) // (self.residual_block_expansion if is_last_layer else 1)
            self.resblocks.append(ResidualBottleneck(in_planes, planes, stride, expansion=self.residual_block_expansion))

        # setup Complex (I) block, which generates complementary imaginary features
        # for real valued input.
        # NOTE(review): phase_dwconv1 takes kernels_per_layer from
        # phasori_config[1], not [0]; confirm this is intended.
        self.phase_dwconv1 = DWConv2d(in_dim=self.subdepth * self.phasori_config[0][0],
                                      dim=self.subdepth * self.phasori_config[0][1],
                                      kernels_per_layer=self.phasori_config[1][-1],
                                      kernel_size=self.phasori_config[0][2],
                                      padding=self.phasori_config[0][3])
        self.phase_bn1 = nn.BatchNorm2d(self.subdepth * self.phasori_config[0][1])
        self.phase_gelu1 = nn.GELU()
        self.phase_dwconv2 = DWConv2d(in_dim=self.subdepth * self.phasori_config[1][0],
                                      dim=self.subdepth * self.phasori_config[1][1],
                                      kernels_per_layer=self.phasori_config[1][-1],
                                      kernel_size=self.phasori_config[1][2],
                                      padding=self.phasori_config[1][3])
        self.phase_bn2 = nn.BatchNorm2d(self.subdepth * self.phasori_config[1][1])
        self.phase_gelu2 = nn.GELU()
        self.cconv1 = ComplexConv2d(self.subdepth * self.phasori_config[2][0],
                                    self.subdepth * self.phasori_config[2][1],
                                    kernel_size=3,
                                    stride=self.phasori_config[2][2],
                                    bias=False,
                                    padding=1)
        self.cbn1 = ComplexBatchNorm2d(self.subdepth * self.phasori_config[2][1])
        self.cgelu1 = ComplexGELU(phase_amp=False)

        # set up Phasor Block (C)s
        self.phasor_blocks = nn.ModuleList()
        self.phasor_block_output_names = []
        if self.phasorc_config is None or len(self.phasorc_config) == 0:
            # handle no Phasor Block (C) case
            self.phasor_blocks.append(nn.Identity())
            self.phasor_block_output_names.append(None)
        else:
            for block_config in self.phasorc_config:
                (d_in, _d_bottle, d_out, stride, depth, dw_kernel_size,
                 dw_padding, dw_kernels_per_layer), output_name = self._split_block_config(
                    block_config, 8, "Phasor Block (C)")
                self.phasor_block_output_names.append(output_name)

                # NOTE(review): d_bottle is parsed but not passed to PhasorBlockC.
                self.phasor_blocks.append(PhasorBlockC(d_in=d_in * self.subdepth,
                                                       d_out=d_out * self.subdepth,
                                                       depth=depth,
                                                       stride=stride,
                                                       dw_kernel_size=dw_kernel_size,
                                                       dw_padding=dw_padding,
                                                       dw_kernels_per_layer=dw_kernels_per_layer))

        self.spectral_branches_blocks = nn.ModuleDict()
        for block_name, block_config in self.spectral_branch_config.fields.items():
            self.spectral_branches_blocks[block_name] = SpectralBranches(
                N=self.Nmax,
                subdepth=self.subdepth,
                spectral_branch_config=block_config,
                hadamard_norm_layer=ComplexBatchNorm2d,
                filter_init_type=self.hadamard_filter_init_type,
                filter_init_scale=self.hadamard_filter_init_scale
            )
        # Classifier input width: total of all branch output dims, scaled by dim.
        self.pooling_depth = self.subdepth * sum(
            sum(block.output_dims) for block in self.spectral_branches_blocks.values())
        self.classifier = ComplexLinear(self.pooling_depth, self.num_classes)

    @staticmethod
    def _split_block_config(block_config, n_fields, kind):
        """Split a block config into ``n_fields`` required values and an optional output name.

        Args:
            block_config: Sized sequence of config values.
            n_fields: Number of required leading fields.
            kind: Human-readable block kind used in the error message.

        Returns:
            ``(fields, output_name)`` where ``fields`` is a tuple of the first
            ``n_fields`` values and ``output_name`` is the extra trailing value,
            or ``None`` when the config has exactly ``n_fields`` entries.

        Raises:
            ValueError: If the config has neither ``n_fields`` nor
                ``n_fields + 1`` entries.
        """
        values = tuple(block_config)
        if len(values) == n_fields:
            return values, None
        if len(values) == n_fields + 1:
            return values[:n_fields], values[n_fields]
        raise ValueError(
            f"{kind} config must have {n_fields} or {n_fields + 1} entries, "
            f"got {len(values)}: {block_config!r}"
        )

    def forward(self, x):
        """Compute class scores for a batch of inputs.

        Args:
            x: Real-valued input for ``conv1``; presumably
                ``(batch, in_channels, H, W)`` -- TODO confirm.

        Returns:
            Magnitude (``.abs()``) of the complex classifier output, flattened
            per sample before the classifier.
        """
        # Features routed to SpectralBranches, keyed by configured output name.
        spectral_branches_inputs = {}

        # input layer
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.gelu1(x)
        x = self.maxpool(x)

        # ResBlocks
        for resblock, resblock_output_name in zip(self.resblocks, self.resblock_output_names):
            x = resblock(x)
            if resblock_output_name is not None:
                spectral_branches_inputs[resblock_output_name] = x.clone()

        # Phasor Block (I): learn an imaginary component from the real features
        px = self.phase_dwconv1(x)
        px = self.phase_bn1(px)
        px = self.phase_gelu1(px)
        px = self.phase_dwconv2(px)
        px = self.phase_bn2(px)
        px = self.phase_gelu2(px)

        # combine real and imag (learnt) into complex tensor for input
        x = x + 1j * px.type(torch.complex64)
        x = self.cconv1(x)
        x = self.cbn1(x)
        x = self.cgelu1(x)

        # Phasor Block (C)
        for phasor_block, phasor_block_output_name in zip(self.phasor_blocks, self.phasor_block_output_names):
            x = phasor_block(x)
            if phasor_block_output_name is not None:
                spectral_branches_inputs[phasor_block_output_name] = x.clone()
        x = x.to(torch.complex64)  # for torch.amp

        spectral_branches_inputs['POST_PHASOR'] = self._spectral_transform(x)

        # Each SpectralBranches returns an iterable of tensors; collect them all.
        spectral_branch_outputs = []
        for input_name, features in spectral_branches_inputs.items():
            spectral_branch_outputs.extend(self.spectral_branches_blocks[input_name](features))

        # aggregate outputs and predict logits
        out = torch.cat(spectral_branch_outputs, dim=1)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out.abs()

    def _spectral_transform(self, x):
        """Return the companded, DC-suppressed, centered 2-D spectrum of ``x``.

        The FFT is taken over the last two dimensions. The DC bin is
        overwritten with ``1e-12`` and then moved to the center by
        ``fftshift``. Companding changes the magnitude and keeps the phase.
        """
        Fx = torch.fft.fft2(x, norm='backward', dim=(-2, -1))
        # Replace the DC coefficient with a tiny value rather than exact zero
        # (presumably to keep magnitude companding away from 0 there).
        Fx[..., 0, 0] = 1e-12
        Fx = torch.fft.fftshift(Fx, dim=(-2, -1))  # move DC to center

        if self.unit:
            # divide each spectrum by its mean over the last two dims
            Fx = Fx / torch.mean(Fx, dim=(-2, -1), keepdim=True)

        # compand the signal by compressing the dynamic range of the magnitude
        if self.compand == 2:
            # square-root magnitude
            Fx = torch.sqrt(torch.abs(Fx)) * torch.exp(1j * Fx.angle())
        elif self.compand == 3:
            # mu-law: https://en.wikipedia.org/wiki/%CE%9C-law_algorithm
            mu = 255.
            Fx = torch.log(1. + mu * torch.abs(Fx)) / math.log(1. + mu) * torch.exp(1j * Fx.angle())
        elif self.compand is not None:
            # general power law |Fx| ** (1 / compand); compand == 3 is handled
            # as mu-law above, so this branch never computes a cube root
            Fx = torch.pow(torch.abs(Fx), 1. / self.compand) * torch.exp(1j * Fx.angle())

        if self.filter is not None:
            Fx *= self.filter[None, None, :, :]  # apply filter
        return Fx

    @classmethod
    def from_cfg(cls, cfg):
        """Build the model from a config object.

        Reads ``cfg.MODEL``; optional keys fall back to the defaults given in
        the ``getattr`` calls below.
        """
        return cls(
            dim=getattr(cfg.MODEL, 'DIM', 64),
            in_channels=getattr(cfg.MODEL, 'IN_CHANN', 3),
            num_classes=getattr(cfg.MODEL, 'N_CLASS', 1000),
            img_size=getattr(cfg.MODEL, 'IMG_SIZE', 224),
            conv1_config=cfg.MODEL.CONV1,
            residual_block_expansion=cfg.MODEL.RESIDUAL_BLOCKS.EXPANSION,
            residual_block_config=cfg.MODEL.RESIDUAL_BLOCKS.LAYERS,
            phasori_config=cfg.MODEL.PHASOR_BLOCK_I,
            phasorc_config=getattr(cfg.MODEL, 'PHASOR_BLOCKS', None),
            spectral_branch_config=cfg.MODEL.SPECTRAL_BRANCHES,
            hadamard_filter_init_type=getattr(cfg.MODEL, 'HADAMARD_FILTER_INIT_TYPE', 'weibull'),
            hadamard_filter_init_scale=getattr(cfg.MODEL, 'HADAMARD_FILTER_INIT_SCALE', 0.02),
            phasor_use_shortcut=getattr(cfg.MODEL, 'PHASOR_USE_SHORTCUT', True),
            use_maxpool=getattr(cfg.MODEL, 'USE_MAXPOOL', True)
        )
