import logging
import torch
import torch.nn as nn
from .model import Model, ResnetBlock, Normalize, get_timestep_embedding, nonlinearity
from .dct import DCTLayer

def _passthrough(x, temb):
    """Return ``x`` unchanged, ignoring ``temb``.

    Used in place of a ``ResnetBlock`` when a mode adds no extra block.  It is
    a named module-level function rather than a lambda so that model instances
    holding a reference to it remain picklable.
    """
    return x


class Model4Pretrained(Model):
    """``Model`` with frozen base parameters plus an additional output head.

    Everything built by ``Model.__init__`` has ``requires_grad`` switched off;
    the layers created afterwards in this constructor keep their default
    (trainable) state.  NOTE(review): the class name suggests the base weights
    are loaded from a checkpoint elsewhere -- confirm against the caller.

    Args:
        head_out_ch: Number of output channels of the second head
            (``conv_out2``).  In ``blockcirc`` modes it is also the side
            length of the square matrix produced by ``j_layer``.
        mode: Selects the extra layers:

            * ``'simple'``  -- second head is applied directly to the decoder
              features.
            * ``'complex'`` -- a ``ResnetBlock`` is inserted before the second
              head.
            * any string containing ``'blockcirc'`` -- additionally builds
              ``d_layer`` / ``j_layer`` on the bottleneck features and the
              DCT layers; if the string also contains ``'complex'``,
              ``ResnetBlock`` s are inserted before both the bottleneck
              branches and the second head.
        **kwargs: Forwarded unchanged to ``Model.__init__``.

    Raises:
        NotImplementedError: If ``mode`` matches none of the cases above.
    """

    def __init__(self, head_out_ch, mode="simple", **kwargs):
        super().__init__(**kwargs)
        # Freeze only what the base class created; layers added below are
        # constructed after this call and therefore stay trainable.
        self.requires_grad_(False)
        self.mode = mode
        logging.info('Model4Pretrained with mode=%s', self.mode)

        def resblock(channels):
            # Channel-preserving residual block conditioned on the timestep
            # embedding.
            return ResnetBlock(in_channels=channels,
                               out_channels=channels,
                               temb_channels=self.temb_ch,
                               dropout=self.dropout)

        if mode == 'simple':
            self.before_out = _passthrough
        elif mode == 'complex':
            self.before_out = resblock(self.block_in)
        elif 'blockcirc' in mode:
            sample_size = 64
            # Flattened size of the bottleneck feature map.
            # NOTE(review): the factor 4 ** 2 assumes a 4x4 spatial
            # bottleneck -- confirm against the base model configuration.
            mid_features = self.mid_ch * 4 ** 2

            # Bottleneck features -> (batch, sample_size, sample_size).
            self.d_layer = nn.Sequential(
                Normalize(self.mid_ch),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Flatten(),
                nn.Linear(mid_features, mid_features),
                Normalize(mid_features),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Linear(mid_features, sample_size ** 2),
                nn.Unflatten(-1, (sample_size, sample_size)),
            )
            # Bottleneck features -> (batch, head_out_ch, head_out_ch).
            self.j_layer = nn.Sequential(
                Normalize(self.mid_ch),
                nn.SiLU(),
                nn.Dropout(p=self.dropout),
                nn.Flatten(),
                nn.Linear(mid_features, head_out_ch ** 2),
                nn.Unflatten(-1, (head_out_ch, head_out_ch)),
            )

            # Not used by ``forward`` in this class.
            # NOTE(review): presumably consumed by external code -- confirm.
            self.DCT = DCTLayer(sample_size, DCTLayer.Mode.DCT)
            self.DCT2 = DCTLayer(sample_size, DCTLayer.Mode.DCT, squared=True)
            self.IDCT = DCTLayer(sample_size, DCTLayer.Mode.IDCT)

            if 'complex' in mode:
                self.before_mid = resblock(self.mid_ch)
                self.before_out = resblock(self.block_in)
            else:
                self.before_mid = _passthrough
                self.before_out = _passthrough
        else:
            raise NotImplementedError(
                'unsupported mode: {!r}'.format(mode))

        # Second output head, parallel to the base model's norm_out/conv_out.
        self.norm_out2 = Normalize(self.block_in)
        self.conv_out2 = torch.nn.Conv2d(self.block_in,
                                         head_out_ch,
                                         kernel_size=3,
                                         stride=1,
                                         padding=1)

    def forward(self, x, t):
        """Run the network and return both heads concatenated on dim 1.

        Args:
            x: Input tensor whose dimensions 2 and 3 must both equal
                ``self.resolution``.
            t: Timesteps, passed to ``get_timestep_embedding``.

        Returns:
            For modes without ``'blockcirc'``: ``torch.cat([out, out2], dim=1)``
            where ``out`` comes from the base head and ``out2`` from the
            second head.

            For ``'blockcirc'`` modes: a tuple ``(cat, (j, jjT, d))`` where
            ``cat`` is as above but with ``out2`` passed through a shifted
            softplus, ``j`` is the lower-triangular part of the ``j_layer``
            output, ``jjT`` is the batched product ``j @ j^T`` and ``d`` is
            the shifted-softplus of the ``d_layer`` output.
        """
        assert x.shape[2] == x.shape[3] == self.resolution
        blockcirc = 'blockcirc' in self.mode

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling: every intermediate activation is kept for the skip
        # connections consumed during upsampling
        hs = [self.conv_in(x)]
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](hs[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
                hs.append(h)
            if i_level != last_level:
                hs.append(stage.downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        if blockcirc:
            # Both extra branches read the same bottleneck features.
            h_mid = self.before_mid(h, temb)
            j = self.j_layer(h_mid)
            d = self.d_layer(h_mid)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        # end
        out = self.conv_out(nonlinearity(self.norm_out(h)))
        out2 = self.conv_out2(
            nonlinearity(self.norm_out2(self.before_out(h, temb))))
        if not blockcirc:
            return torch.cat([out, out2], dim=1)

        # softplus keeps these outputs strictly positive; the functional form
        # avoids building a new nn.Softplus module on every call.
        out2 = nn.functional.softplus(out2 - 5.)
        d = nn.functional.softplus(d - 5.)
        j = torch.tril(j)
        jjT = torch.bmm(j, torch.transpose(j, 1, 2))
        return torch.cat([out, out2], dim=1), (j, jjT, d)
