from .unet import UNetModel, TimestepEmbedSequential, ResBlock
import logging
import torch as th
import torch.nn as nn
from .nn import (
    SiLU,
    conv_nd,
    zero_module,
    normalization,
    timestep_embedding,
    scale_module,
)
from .dct import DCTLayer


class UNetModel3OutChannels(UNetModel):
    """UNetModel that trims a 6-channel output down to its first 3 channels.

    Any other output channel count is passed through untouched.
    """

    def __init__(self, **kwargs):
        # Keyword-only construction; all options go straight to UNetModel.
        super().__init__(**kwargs)

    def forward(self, x, timesteps, y=None):
        result = super().forward(x, timesteps, y)
        if result.size(1) != 6:
            return result
        # Two chunks of 3 channels each; only the leading chunk is returned.
        leading, _ = result.split(3, dim=1)
        return leading


class UNetModel4Pretrained(UNetModel):
    """UNetModel with a frozen backbone, extended with trainable extra heads.

    Every parameter created by ``UNetModel.__init__`` is frozen with
    ``requires_grad_(False)`` (the backbone is presumably loaded from a
    pretrained checkpoint elsewhere). The heads built afterwards stay
    trainable: ``out2`` in every mode, plus ``d_layer``/``j_layer`` in the
    ``blockcirc*`` modes and ``d1_layer``/``d2_layer``/``j_layer`` in
    ``3kronecker`` mode.

    :param head_out_channels: output channels of the ``out2`` head; also the
        side length of the square ``j`` matrix produced by ``j_layer``.
    :param mode: one of ``'simple'``, ``'complex'``, ``'blockcirc'``,
        ``'3kronecker'`` or ``'blockcirc_complex'``.
    :param kwargs: forwarded to ``UNetModel``. ``dims`` must be present; the
        ResBlock-based modes (``'complex'``, ``'blockcirc_complex'``) also
        require ``use_scale_shift_norm``.
    :raises NotImplementedError: if ``mode`` is not one of the supported values.
    """

    # Supported values of ``mode``; validated before the backbone is built.
    _MODES = ('simple', 'complex', 'blockcirc', '3kronecker', 'blockcirc_complex')

    def __init__(self, head_out_channels, mode='simple', **kwargs):
        if mode not in self._MODES:
            # Fail fast and say why, instead of building the whole backbone first.
            raise NotImplementedError(
                'UNetModel4Pretrained: unsupported mode {!r}, expected one of {}'.format(
                    mode, self._MODES
                )
            )
        super().__init__(**kwargs)
        # Freeze everything UNetModel created; heads added below remain trainable.
        self.requires_grad_(False)
        self.mode = mode
        logging.info('UNetModel4Pretrained with mode=%s', self.mode)
        dims = kwargs["dims"]

        if mode == 'simple':
            self.out2 = nn.Sequential(*self._build_out2_layers(dims, head_out_channels))
        elif mode == 'complex':
            self.out2 = TimestepEmbedSequential(
                self._build_res_block(self.out_ch, dims, kwargs["use_scale_shift_norm"]),
                *self._build_out2_layers(dims, head_out_channels),
            )
        elif mode == 'blockcirc':
            self.out2 = nn.Sequential(*self._build_out2_layers(dims, head_out_channels))
            sample_size = 32
            # The Linear layers expect nn.Flatten of the middle-block features to
            # give mid_ch * 4 ** 2 values, i.e. a 4x4 middle-block map for 2-D
            # inputs -- NOTE(review): confirm against image_size / channel_mult.
            flat = self.mid_ch * 4 ** 2
            self.d_layer = nn.Sequential(
                normalization(self.mid_ch),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Flatten(),
                nn.Linear(flat, flat),
                normalization(flat),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                zero_module(nn.Linear(flat, sample_size ** 2)),
                nn.Unflatten(-1, (sample_size, sample_size)),
            )
            self.j_layer = nn.Sequential(
                normalization(self.mid_ch),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Flatten(),
                nn.Linear(flat, head_out_channels ** 2),
                nn.Unflatten(-1, (head_out_channels, head_out_channels)),
            )
            self._attach_dct_layers(sample_size)
        elif mode == '3kronecker':
            self.out2 = nn.Sequential(*self._build_out2_layers(dims, head_out_channels))
            self.d1_layer = self._build_single_channel_head(dims)
            self.d2_layer = self._build_single_channel_head(dims)
            self.j_layer = nn.Sequential(
                normalization(self.mid_ch),
                nn.SiLU(),
                nn.Flatten(),
                nn.Linear(self.mid_ch * 4 ** 2, head_out_channels ** 2),
                nn.Unflatten(-1, (head_out_channels, head_out_channels)),
            )
        else:  # 'blockcirc_complex' (the only remaining value of _MODES)
            use_scale_shift_norm = kwargs["use_scale_shift_norm"]
            self.out2 = TimestepEmbedSequential(
                self._build_res_block(self.out_ch, dims, use_scale_shift_norm),
                *self._build_out2_layers(dims, head_out_channels),
            )
            sample_size = 64
            # Same flattening assumption as 'blockcirc', but with an 8x8 map
            # (NOTE(review): confirm).
            flat = self.mid_ch * 8 ** 2
            self.d_layer = TimestepEmbedSequential(
                self._build_res_block(self.mid_ch, dims, use_scale_shift_norm),
                normalization(self.mid_ch),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Flatten(),
                zero_module(nn.Linear(flat, sample_size ** 2)),
                nn.Unflatten(-1, (sample_size, sample_size)),
            )
            self.j_layer = TimestepEmbedSequential(
                self._build_res_block(self.mid_ch, dims, use_scale_shift_norm),
                normalization(self.mid_ch),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Flatten(),
                nn.Linear(flat, head_out_channels ** 2),
                nn.Unflatten(-1, (head_out_channels, head_out_channels)),
            )
            self._attach_dct_layers(sample_size)

    def _build_out2_layers(self, dims, head_out_channels):
        """Return [norm, SiLU, zero_module'd 3x3 conv] layers for the ``out2`` head."""
        return [
            normalization(self.out_ch),
            SiLU(),
            zero_module(conv_nd(dims, self.model_channels, head_out_channels, 3, padding=1)),
        ]

    def _build_res_block(self, channels, dims, use_scale_shift_norm):
        """Return a ResBlock over ``channels`` input channels using the time embedding."""
        return ResBlock(
            channels,
            self.time_embed_dim,
            self.dropout,
            dims=dims,
            use_checkpoint=self.use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )

    def _build_single_channel_head(self, dims):
        """Return a norm -> SiLU -> 3x3 conv head with 1 output channel, scaled by 0.1."""
        return nn.Sequential(
            normalization(self.out_ch),
            SiLU(),
            scale_module(conv_nd(dims, self.model_channels, 1, 3, padding=1), 0.1),
        )

    def _attach_dct_layers(self, sample_size):
        """Register DCT, squared-DCT and inverse-DCT layers of size ``sample_size``."""
        self.DCT = DCTLayer(sample_size, DCTLayer.Mode.DCT)
        self.DCT2 = DCTLayer(sample_size, DCTLayer.Mode.DCT, squared=True)
        self.IDCT = DCTLayer(sample_size, DCTLayer.Mode.IDCT)

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: in 'simple'/'complex' modes, the channel-wise concatenation of
            the backbone output (trimmed to 3 channels if it had 6) and the
            ``out2`` head. In 'blockcirc'/'blockcirc_complex' modes, a pair
            ``(features, (j, j @ j^T, d))`` where ``j`` is lower-triangular
            with shape [N x head_out_channels x head_out_channels] and ``d`` is
            [N x sample_size x sample_size]. In '3kronecker' mode, a pair
            ``(features, (j, d1, d2))``.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)

        # The matrix heads read the middle-block features, before the decoder.
        if self.mode == 'blockcirc_complex':
            j = self.j_layer(h, emb)
            d = self.d_layer(h, emb)
        elif self.mode == 'blockcirc':
            j = self.j_layer(h)
            d = self.d_layer(h)
        elif self.mode == '3kronecker':
            j = self.j_layer(h)

        for module in self.output_blocks:
            h = module(th.cat([h, hs.pop()], dim=1), emb)
        h = h.type(x.dtype)
        out = self.out(h)
        if out.size(1) == 6:
            # Keep only the first 3 channels (presumably dropping learned-variance
            # outputs of the pretrained backbone -- confirm).
            out = out.split(3, dim=1)[0]
        # Modes whose out2 starts with a ResBlock need the time embedding too.
        out2 = self.out2(h, emb) if 'complex' in self.mode else self.out2(h)

        if 'blockcirc' in self.mode:
            out2 = out2 + 1.
            # Shifted softplus keeps d strictly positive.
            d = nn.functional.softplus(d - 5.)
            j = th.tril(j)
            jjT = th.bmm(j, th.transpose(j, 1, 2))  # symmetric PSD by construction
            return th.cat([out, out2], dim=1), (j, jjT, d)
        if self.mode == '3kronecker':
            out2 = nn.functional.softplus(out2 - 5.)
            d1 = self.d1_layer(h)[:, 0, :, :]
            d2 = self.d2_layer(h)[:, 0, :, :]
            return th.cat([out, out2], dim=1), (j, d1, d2)
        return th.cat([out, out2], dim=1)
