import torch
from torch import nn

from tasks.mino_unet.modules.MIMOUNet import EBlock, DBlock, AFF, FAM, SCM
from tasks.mino_unet.modules.layers import BasicConv
import torch
from torch import nn
from torch.nn import functional as F


class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through gradient.

    Args:
        num_embeddings: number of entries in the learned codebook.
        embedding_dim: size of each codebook vector. The channel dimension of
            the input is viewed as a sequence of vectors of this size.
        commitment_cost: weight applied to the encoder-side (commitment) term
            of the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # Codebook initialised uniformly in [-1/K, 1/K].
        bound = 1 / num_embeddings
        self._embedding = nn.Embedding(num_embeddings, embedding_dim)
        self._embedding.weight.data.uniform_(-bound, bound)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantize a BCHW tensor against the codebook.

        Returns:
            (loss, quantized, perplexity, encodings), where
            - ``quantized`` has the same BCHW shape as ``inputs``, and its
              gradient flows straight through to ``inputs``;
            - ``encodings`` is the one-hot assignment matrix with one row per
              flattened vector and one column per codebook entry.
        """
        # Move channels last so each location's features are contiguous.
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        codebook = self._embedding.weight
        vectors = channels_last.view(-1, self._embedding_dim)

        # Squared Euclidean distance: ||v||^2 + ||e||^2 - 2 <v, e>.
        sq_dist = (
            vectors.pow(2).sum(dim=1, keepdim=True)
            + codebook.pow(2).sum(dim=1)
            - 2 * (vectors @ codebook.t())
        )

        nearest = sq_dist.argmin(dim=1, keepdim=True)
        one_hot = torch.zeros(nearest.shape[0], self._num_embeddings, device=inputs.device)
        one_hot.scatter_(1, nearest, 1)

        quantized = (one_hot @ codebook).view(channels_last.shape)

        # Codebook term pulls codes toward encodings; the commitment term pulls
        # encodings toward their chosen codes.
        commitment = F.mse_loss(quantized.detach(), channels_last)
        codebook_term = F.mse_loss(quantized, channels_last.detach())
        loss = codebook_term + self._commitment_cost * commitment

        # Straight-through estimator: forward uses codes, backward is identity.
        quantized = channels_last + (quantized - channels_last).detach()

        usage = one_hot.mean(dim=0)
        perplexity = torch.exp(-(usage * torch.log(usage + 1e-10)).sum())

        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, one_hot


class MIMOUNetPlus4(nn.Module):
    """MIMO-UNet style restoration network conditioned on a discrete latent code.

    The network predicts gated residuals at 1/4, 1/2 and full input
    resolution. A single codebook vector from ``self.quantize`` per sample is
    broadcast over the bottleneck feature map and fused in by ``vae_z_proj``.
    When a ground-truth image is supplied, that vector comes from encoding
    ``(x_gt, x, x_gt - x)`` and quantizing it. Otherwise a random codebook
    entry is drawn.
    """

    def __init__(self, num_res=20):
        super(MIMOUNetPlus4, self).__init__()
        base_channel = 32
        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel * 2, num_res),
            EBlock(base_channel * 4, num_res),
        ])

        # [0..2]: encoder stem / downsamplers, [3..4]: transposed-conv upsamplers,
        # [5]: full-resolution head (1 gate channel + 3 residual channels).
        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 4, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        # 1x1 convs fusing the upsampled decoder features with the skip features.
        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        # Side heads at 1/4 and 1/2 resolution (1 gate channel + 3 residual channels).
        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 4, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 4, kernel_size=3, relu=False, stride=1),
            ]
        )

        # Input width 7 * base_channel = concat of the three scales' features (1 + 2 + 4).
        self.AFFs = nn.ModuleList([
            AFF(base_channel * 7, base_channel * 1),
            AFF(base_channel * 7, base_channel * 2)
        ])

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

        self.drop1 = nn.Dropout2d(0.1)
        self.drop2 = nn.Dropout2d(0.1)

        # Latent conditioning branch. Input is cat(x_gt, x, x_gt - x) = 9 channels.
        self.c_vae = 64
        c_vae_hidden = 64
        self.vae_encoder = nn.Sequential(
            BasicConv(9, c_vae_hidden, kernel_size=3, relu=True, stride=4),
            EBlock(c_vae_hidden, 5),
            BasicConv(c_vae_hidden, c_vae_hidden * 2, kernel_size=3, relu=True, stride=4),
            EBlock(c_vae_hidden * 2, 5),
            # NOTE(review): emits 2 * c_vae channels (looks like a leftover
            # mean/log-variance head); only c_vae of them are quantized below.
            nn.Conv2d(c_vae_hidden * 2, self.c_vae * 2, 1)
        )
        self.quantize = VectorQuantizer(8, self.c_vae, 0.25)
        self.vae_z_proj = nn.Conv2d(base_channel * 4 + self.c_vae, base_channel * 4, 1)

    @staticmethod
    def _gated_residual(feat, base, clamp_gate):
        """Return ``(feat[:, 1:] * gate + base, gate)`` with ``gate = sigmoid(feat[:, :1])``.

        When ``clamp_gate`` is positive the gate is clamped from below to it.
        """
        gate = torch.sigmoid(feat[:, :1])
        if clamp_gate > 0:
            gate = gate.clamp_min(clamp_gate)
        return feat[:, 1:] * gate + base, gate

    def _sample_condition(self, x, x_gt):
        """Return ``(z_vae, vq_loss)``: one codebook vector per sample, shape (B, c_vae).

        - Without ``x_gt``: a random codebook entry is drawn per sample and
          ``vq_loss`` is None.
        - With ``x_gt``: the encoder output is averaged over space, cut to
          ``c_vae`` channels and quantized, so ``z_vae`` is likewise a single
          codebook vector. ``vq_loss`` is the quantizer's loss.
        """
        if x_gt is None:
            codebook = self.quantize._embedding
            # Draw indices on the codebook's device so the lookup works on GPU.
            idx = torch.randint(0, self.quantize._num_embeddings, size=[x.shape[0]],
                                device=codebook.weight.device)
            return codebook(idx), None
        z_e = self.vae_encoder(torch.cat([x_gt, x, x_gt - x], 1))
        # NOTE(review): global-average-pool to one vector per sample and keep the
        # first c_vae channels so it matches the codebook width and the
        # vae_z_proj input -- confirm this is the intended reduction.
        z_e = z_e.mean(dim=(2, 3), keepdim=True)[:, :self.c_vae]
        vq_loss, z_q, _, _ = self.quantize(z_e)
        return z_q.flatten(1), vq_loss

    def forward(self, x, x_gt=None):
        """Restore ``x``, optionally conditioned on the ground truth ``x_gt``.

        Args:
            x: degraded image batch, BCHW with 3 channels. It is downsampled by
                2 twice; presumably H and W must be divisible by 4 for the
                skip connections to line up (TODO confirm).
            x_gt: optional ground truth with the same shape as ``x``. When
                given, it drives the latent code; otherwise a random code is
                used.

        Returns:
            ``(outputs, gate_xs, klds)``:
            - ``outputs``: predictions at 1/4, 1/2 and full resolution, in
              that order.
            - ``gate_xs``: the gates used. The 1/4 and 1/2 gates appear only
              when ``hparams['multiscale_gate']`` is set; the full-resolution
              gate is always last.
            - ``klds``: the quantizer loss when ``x_gt`` is given, else None.
        """
        from utils.commons.hparams import hparams
        multiscale_gate = hparams['multiscale_gate']
        clamp_gate = hparams['clamp_gate']

        z_vae, klds = self._sample_condition(x, x_gt)

        x_2 = F.interpolate(x, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)

        outputs = []
        gate_xs = []

        # Encoder.
        res1 = self.Encoder[0](self.feat_extract[0](x))
        z = self.FAM2(self.feat_extract[1](res1), z2)
        res2 = self.Encoder[1](z)
        z = self.FAM1(self.feat_extract[2](res2), z4)
        z = self.Encoder[2](z)

        # Broadcast the per-sample code over the bottleneck and fuse it in.
        cond = z_vae[:, :, None, None].expand(-1, -1, z.shape[2], z.shape[3])
        z = self.vae_z_proj(torch.cat([z, cond], 1))

        # Cross-scale feature fusion for the skip connections.
        z12 = F.interpolate(res1, scale_factor=0.5)
        z21 = F.interpolate(res2, scale_factor=2)
        z42 = F.interpolate(z, scale_factor=2)
        z41 = F.interpolate(z42, scale_factor=2)

        res2 = self.AFFs[1](z12, res2, z42)
        res1 = self.AFFs[0](res1, z21, z41)

        res2 = self.drop2(res2)
        res1 = self.drop1(res1)

        # Decoder, 1/4 resolution.
        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        z = self.feat_extract[3](z)
        if multiscale_gate:
            out, gate = self._gated_residual(z_, x_4, clamp_gate)
            gate_xs.append(gate)
        else:
            out = z_[:, 1:] + x_4
        outputs.append(out)

        # Decoder, 1/2 resolution.
        z = self.Convs[0](torch.cat([z, res2], dim=1))
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        z = self.feat_extract[4](z)
        if multiscale_gate:
            out, gate = self._gated_residual(z_, x_2, clamp_gate)
            gate_xs.append(gate)
        else:
            out = z_[:, 1:] + x_2
        outputs.append(out)

        # Decoder, full resolution (always gated).
        z = self.Convs[1](torch.cat([z, res1], dim=1))
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        out, gate = self._gated_residual(z, x, clamp_gate)
        outputs.append(out)
        gate_xs.append(gate)

        return outputs, gate_xs, klds
