import numpy as np
import torch.nn.functional as F
from utils.commons.hparams import hparams
from .MIMOUNet import EBlock, DBlock, FAM, SCM, AFF
from .glow import Glow, calc_loss
from .layers import *
import torch
from torch import nn

from .unet.unet import AttentionBlock
from .vit import ViT
import torch.distributions as dist


class MIMOUNetPlus3(nn.Module):
    """Three-scale encoder/decoder network with optional latent conditioning.

    The backbone is built from ``EBlock``/``DBlock``/``AFF``/``FAM``/``SCM``
    modules working at full, half and quarter resolution.  Depending on the
    global ``hparams`` the bottleneck features are additionally

    * fused with ViT features of ``x_full`` (``hparams['use_attn']``),
    * concatenated with a latent ``z_vae`` (``hparams['use_vae']``) that is
      either sampled from the posterior encoder (training with ``x_gt``) or
      from the prior (otherwise),
    * where the prior is modelled by a conditional ``Glow`` flow when
      ``hparams['use_flow']`` is set.
    """

    # Input channels of ``vae_encoder``: ``x_gt`` and ``x`` concatenated on dim 1.
    _VAE_IN_CHANNELS = 7

    def __init__(self, num_res=20):
        """Build all sub-modules.

        :param num_res: forwarded as second argument to every backbone
            ``EBlock``/``DBlock``.
        """
        super().__init__()
        base_channel = 32
        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel * 2, num_res),
            EBlock(base_channel * 4, num_res),
        ])

        # Index 0-2: encoder-side convs (stride 1, 2, 2);
        # index 3-4: transposed convs (stride 2) used on the decoder side;
        # index 5: final projection to 3 output channels.
        self.feat_extract = nn.ModuleList([
            BasicConv(4, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        # 1x1 convs that halve the channel count after each skip concatenation.
        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        # 3-channel heads for the quarter- and half-resolution outputs.
        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
            ]
        )

        # base_channel * 7 == base_channel * (1 + 2 + 4): the three feature
        # maps handed to each AFF in ``forward``.
        self.AFFs = nn.ModuleList([
            AFF(base_channel * 7, base_channel * 1),
            AFF(base_channel * 7, base_channel * 2)
        ])

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4, c_in=4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2, c_in=4)

        self.c_vae = hparams.get('c_vae', 32)
        c_vae_hidden = self.c_vae
        # Posterior encoder: emits 2 * c_vae channels that ``forward`` splits
        # into a mean half and a log-std half.
        self.vae_encoder = self._make_latent_encoder(
            self._VAE_IN_CHANNELS, c_vae_hidden, self.c_vae * 2)
        torch.nn.init.xavier_uniform_(self.vae_encoder[-1].weight)
        torch.nn.init.zeros_(self.vae_encoder[-1].bias)
        self.vae_z_proj = nn.Conv2d(base_channel * 4 + self.c_vae, base_channel * 4, 1)

        # NOTE(review): the ViT is built for a fixed 1440x2160 image size, so
        # ``x_full`` presumably has to match it -- confirm against the caller.
        self.vit = ViT(image_size=(1440, 2160), patch_size=(30, 30), dim=256, depth=4, heads=4, mlp_dim=1024)
        self.attn = AttentionBlock(c_q=128, c_kv=256, channels=64)
        if hparams['use_flow']:
            c_cond = 16  # shared by the Glow constructor and the cond_encoder output
            self.flow = Glow(self.c_vae, c_cond, 1)
            self.cond_encoder = self._make_latent_encoder(4, c_vae_hidden, c_cond)

    @staticmethod
    def _make_latent_encoder(c_in, c_hidden, c_out):
        """Build the conv stack shared by ``vae_encoder`` and ``cond_encoder``."""
        strides = hparams['vae_strides']
        return nn.Sequential(
            BasicConv(c_in, c_hidden, kernel_size=3, relu=True, stride=strides[0]),
            EBlock(c_hidden, 5),
            BasicConv(c_hidden, c_hidden * 2, kernel_size=3, relu=True, stride=strides[1]),
            EBlock(c_hidden * 2, 5),
            nn.Conv2d(c_hidden * 2, c_out, 1)
        )

    def _sample_posterior(self, x, x_gt):
        """Encode ``cat([x_gt, x])`` and draw one reparameterised sample.

        :return: ``(z_q, m_q, std_q)`` -- the sample, the posterior mean and
            the posterior standard deviation (``exp`` of the log-std half).
        """
        stats = self.vae_encoder(torch.cat([x_gt, x], 1))
        m_q, logs_q = torch.split(stats, self.c_vae, dim=1)
        std_q = logs_q.exp()
        z_q = m_q + torch.randn_like(m_q) * std_q
        return z_q, m_q, std_q

    def _sample_flow_prior(self, x, x_gt, flow_cond, scale):
        """Draw a latent from the flow prior by inverting scaled Gaussian noise.

        The encoder + forward-flow pass is only used to obtain ``z_list[0]``
        as a shape/dtype/device template for the noise.  When ``x_gt`` is
        ``None`` an all-zero stand-in is fed to the encoder for that purpose
        (previously this case crashed inside ``torch.cat``).
        """
        if x_gt is None:
            x_gt = x.new_zeros(
                x.shape[0], self._VAE_IN_CHANNELS - x.shape[1], *x.shape[2:])
        z_q, _, _ = self._sample_posterior(x, x_gt)
        _, _, z_list = self.flow(z_q, cond=flow_cond)
        # z_vae = z_list[0]
        noise = torch.randn_like(z_list[0]) * scale
        return self.flow.reverse([noise], cond=flow_cond)

    def forward(self, x, x_gt, x_full):
        """Run the network.

        :param x: input tensor; its first 3 channels are added back as a
            residual to every output (4 channels according to the first
            ``feat_extract`` conv).
        :param x_gt: target tensor concatenated with ``x`` for the posterior
            encoder, or ``None``.
        :param x_full: tensor passed to the ViT when ``hparams['use_attn']``
            is set; unused otherwise.
        :return: ``(outputs, klds)``.  ``outputs`` is a list of three tensors
            at quarter, half and full resolution of ``x``.  ``klds`` is
            ``None`` unless the posterior path ran (``x_gt`` given,
            ``hparams['use_vae']`` set and module in training mode), in which
            case it holds one value per batch element, normalised by the
            latent's ``C * H * W``.
        """
        from utils.commons.hparams import hparams
        use_flow = hparams['use_flow']
        klds = None
        # ``cond_encoder`` only exists when the model was built with the flow,
        # so it must not be touched otherwise.
        flow_cond = self.cond_encoder(x) if use_flow else None

        z_q = None
        if x_gt is not None and hparams['use_vae'] and self.training:
            z_q, m_q, std_q = self._sample_posterior(x, x_gt)
            B, C, H, W = z_q.shape
            logqx = dist.Normal(m_q, std_q).log_prob(z_q).reshape(B, -1).sum(-1)
            # NOTE(review): without the flow this is only log q(z|x); no prior
            # log-density is subtracted -- confirm this is intended.
            klds = logqx

            if use_flow:
                log_p, logdet, _ = self.flow(z_q, cond=flow_cond)
                logpx = logdet + log_p
                klds = logqx - logpx

            klds = klds / C / H / W

        # Down-scaled copies of the input feed the SCM branches and serve as
        # residuals for the two low-resolution outputs.
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)

        outputs = list()

        # ---- encoder -------------------------------------------------------
        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)  # base_channel features, full resolution

        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)  # base_channel * 2 features, 1/2 resolution

        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)  # base_channel * 4 features, 1/4 resolution

        if hparams['use_attn']:
            x_full = self.vit(x_full)
            z = z + self.attn(z, x_full)

        # ---- latent injection at the bottleneck ----------------------------
        if hparams['use_vae']:
            # Factor by which the latent is repeated along H and W before it
            # is concatenated with the bottleneck features.
            m = np.prod(hparams['vae_strides']) // 4
            if x_gt is not None and self.training:
                z_vae = z_q
            elif use_flow:
                z_vae = self._sample_flow_prior(x, x_gt, flow_cond, hparams['vae_scale'])
            else:
                z_vae = torch.randn_like(
                    z[:, :self.c_vae, :z.shape[2] // m, :z.shape[3] // m]) * hparams['vae_scale']
            # print(x.shape, z.shape, z_vae.shape)
            z = self.vae_z_proj(torch.cat([
                z, z_vae.repeat_interleave(m, 2).repeat_interleave(m, 3)], 1))

        # ---- cross-scale feature fusion ------------------------------------
        z12 = F.interpolate(res1, scale_factor=0.5)
        z21 = F.interpolate(res2, scale_factor=2)
        z42 = F.interpolate(z, scale_factor=2)
        z41 = F.interpolate(z42, scale_factor=2)

        res2 = self.AFFs[1](z12, res2, z42)
        res1 = self.AFFs[0](res1, z21, z41)

        # ---- decoder: quarter -> half -> full resolution -------------------
        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        z = self.feat_extract[3](z)
        outputs.append(z_ + x_4[:, :3])

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        z = self.feat_extract[4](z)
        outputs.append(z_ + x_2[:, :3])

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z + x[:, :3])
        return outputs, klds

    def reparameterize(self, mu, logvar):
        """
        Draw a single sample ``mu + eps * std`` with ``eps ~ N(0, I)``.

        Note that this helper expects a log-*variance*, whereas ``forward``
        works with the log-*std* emitted by ``vae_encoder``.

        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sample with the same shape as ``mu``/``logvar``
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
