import torch
from torch import nn
from torch.nn import functional as F

import sys
# sys.path.append('..')
import dino_wm.distributed_fn as dist_fn
from einops import rearrange

# Copyright 2018 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


# Borrowed from https://github.com/deepmind/sonnet and ported to PyTorch.


class Quantize(nn.Module):
    """Vector-quantization layer with an EMA-updated codebook.

    Every ``dim``-sized vector along the last axis of the input is replaced by
    its nearest codebook vector (squared Euclidean distance). In training mode
    the codebook is not learned by gradient descent; instead it is updated with
    exponential moving averages of the vectors assigned to each code. The
    per-step statistics are passed through ``dist_fn.all_reduce`` first
    (presumably a cross-process sum for distributed training — confirm in
    ``dino_wm.distributed_fn``).

    Args:
        dim: size of each vector (the last dimension of the input).
        n_embed: number of codebook entries.
        decay: EMA decay applied to the cluster counts and cluster sums.
        eps: Laplace-smoothing constant that keeps unused codes from causing a
            division by zero in the codebook update.

    Buffers:
        embed: codebook of shape (dim, n_embed), one code per column.
        cluster_size: EMA of how many vectors were assigned to each code.
        embed_avg: EMA of the sum of vectors assigned to each code,
            shape (dim, n_embed).
    """

    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        """Quantize ``input`` and, in training mode, update the codebook.

        Args:
            input: tensor whose last dimension is ``self.dim``.

        Returns:
            quantize: tensor shaped like ``input``. Its forward value is the
                selected codes; its gradient passes straight through to
                ``input``.
            diff: scalar commitment loss ``mean((codes - input) ** 2)``, with
                the codes detached so only ``input`` receives gradient.
            embed_ind: integer code indices of shape ``input.shape[:-1]``.
        """
        flatten = input.reshape(-1, self.dim)

        # Code assignment is a non-differentiable argmin. Computing it without
        # autograd avoids saving `flatten` and the codebook for a backward pass
        # that can never flow through this path.
        with torch.no_grad():
            dist = (
                flatten.pow(2).sum(1, keepdim=True)
                - 2 * flatten @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True)
            )
            flat_ind = dist.argmin(1)
        embed_ind = flat_ind.view(*input.shape[:-1])
        # Codes are looked up before the EMA update, so this step uses the
        # codebook as it was at the start of the call.
        quantize = self.embed_code(embed_ind)

        if self.training:
            # The EMA bookkeeping is pure buffer maintenance; keep it out of
            # the autograd graph. The one-hot matrix is only needed here.
            with torch.no_grad():
                embed_onehot = F.one_hot(flat_ind, self.n_embed).type(flatten.dtype)
                embed_onehot_sum = embed_onehot.sum(0)
                embed_sum = flatten.transpose(0, 1) @ embed_onehot

                dist_fn.all_reduce(embed_onehot_sum)
                dist_fn.all_reduce(embed_sum)

                self.cluster_size.data.mul_(self.decay).add_(
                    embed_onehot_sum, alpha=1 - self.decay
                )
                self.embed_avg.data.mul_(self.decay).add_(
                    embed_sum, alpha=1 - self.decay
                )
                n = self.cluster_size.sum()
                # Laplace smoothing: add eps to every count, then rescale so
                # the counts still sum to n. This prevents division by zero
                # for codes that have never been used.
                cluster_size = (
                    (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
                )
                embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
                self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        # Straight-through estimator: the forward value equals `quantize`, and
        # the gradient w.r.t. `input` is the identity.
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        """Return the codebook vectors for integer indices ``embed_id``.

        The output shape is ``embed_id.shape + (dim,)``.
        """
        return F.embedding(embed_id, self.embed.transpose(0, 1))


class ResBlock(nn.Module):
    """Pre-activation residual block: ``x + Conv1x1(ReLU(Conv3x3(ReLU(x))))``.

    Args:
        in_channel: channels of the input and of the output.
        channel: hidden channel count between the two convolutions.
    """

    def __init__(self, in_channel, channel):
        super().__init__()

        # The layer order fixes the state_dict keys (conv.1.*, conv.3.*).
        layers = [
            nn.ReLU(),  # not in-place: `input` is reused by the skip connection
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        # Add the skip connection in place on the freshly computed branch output.
        return self.conv(input).add_(input)


class Encoder(nn.Module):
    """Convolutional encoder that downsamples an image by ``stride`` (2 or 4).

    Args:
        in_channel: number of channels of the input image.
        channel: number of output channels (the first stride-2 convolution
            outputs ``channel // 2``).
        n_res_block: number of residual blocks after the downsampling stem.
        n_res_channel: hidden channel count inside each residual block.
        stride: total spatial downsampling factor; must be 2 or 4.

    Raises:
        ValueError: if ``stride`` is neither 2 nor 4.
    """

    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        else:
            # Only the two stems above exist; fail with a clear message rather
            # than an UnboundLocalError on `blocks` below.
            raise ValueError(f"stride must be either 2 or 4, got {stride}")

        for _ in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        """Map (N, in_channel, H, W) to (N, channel, H / stride, W / stride)."""
        return self.blocks(input)


# class Decoder(nn.Module):
#     def __init__(
#         self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
#     ):
#         super().__init__()

#         blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]

#         for i in range(n_res_block):
#             blocks.append(ResBlock(channel, n_res_channel))

#         blocks.append(nn.ReLU(inplace=True))

#         if stride == 4:
#             blocks.extend(
#                 [
#                     nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
#                     nn.ReLU(inplace=True),
#                     nn.ConvTranspose2d(
#                         channel // 2, out_channel, 4, stride=2, padding=1
#                     ),
#                 ]
#             )

#         elif stride == 2:
#             blocks.append(
#                 nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
#             )

#         self.blocks = nn.Sequential(*blocks)

#     def forward(self, input):
#         return self.blocks(input)

# class Decoder(nn.Module):
#     def __init__(self, in_channel, out_channel, channel, n_res_block, n_res_channel, scale_factor):
#         super().__init__()
#         # First, a convolution and some residual blocks.
#         blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
#         for _ in range(n_res_block):
#             blocks.append(ResBlock(channel, n_res_channel))
#         blocks.append(nn.ReLU(inplace=True))
#         # Use a transposed convolution to upsample by the desired scale factor.
#         # Output size formula for ConvTranspose2d:
#         #   output_size = (input_size - 1) * stride - 2*padding + kernel_size
#         # With kernel_size = scale_factor, stride = scale_factor, padding = 0,
#         # we get: output_size = (input_size - 1)*scale_factor + scale_factor = input_size * scale_factor.
#         blocks.append(nn.ConvTranspose2d(channel, out_channel, kernel_size=scale_factor, stride=scale_factor, padding=0))
#         self.blocks = nn.Sequential(*blocks)

#     def forward(self, x):
#         return self.blocks(x)

class Decoder(nn.Module):
    """Residual convolutional decoder followed by a transposed-conv upsampler.

    Layout: Conv3x3 -> ``n_res_block`` x ResBlock -> ReLU -> upsampling head.
    The head scales the spatial size by ``scale_factor``:

    * 2: ConvT(k=4, s=2, p=1) doubles H/W, ReLU, then ConvT(k=3, s=1, p=1)
      refines without changing H/W (e.g. 10x10 -> 20x20).
    * 7: ConvT(k=3, s=1, p=1) refines, ReLU, then ConvT(k=7, s=7, p=0)
      multiplies H/W by 7 (e.g. 20x20 -> 140x140).

    Args:
        in_channel: number of input channels (e.g. emb_dim).
        out_channel: number of output channels (e.g. 3 for RGB, or emb_dim
            when feeding another decoder stage).
        channel: hidden channel count inside the decoder.
        n_res_block: number of residual blocks.
        n_res_channel: hidden channel count inside each residual block.
        scale_factor: spatial upsampling factor; only 2 and 7 are supported.

    Raises:
        ValueError: if ``scale_factor`` is neither 2 nor 7.
    """

    def __init__(self, in_channel, out_channel, channel, n_res_block, n_res_channel, scale_factor):
        super().__init__()

        # Modules are created body-first, then head, which keeps parameter
        # initialisation order (and the Sequential indices) fixed.
        layers = [nn.Conv2d(in_channel, channel, 3, padding=1)]
        layers += [ResBlock(channel, n_res_channel) for _ in range(n_res_block)]
        layers.append(nn.ReLU(inplace=True))
        layers += self._upsampling_head(channel, out_channel, scale_factor)

        self.blocks = nn.Sequential(*layers)

    @staticmethod
    def _upsampling_head(channel, out_channel, scale_factor):
        """Build the three-layer transposed-conv head for ``scale_factor``."""
        half = channel // 2
        if scale_factor == 2:
            return [
                nn.ConvTranspose2d(channel, half, kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(half, out_channel, kernel_size=3, stride=1, padding=1),
            ]
        if scale_factor == 7:
            return [
                nn.ConvTranspose2d(channel, half, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(half, out_channel, kernel_size=7, stride=7, padding=0),
            ]
        raise ValueError("scale_factor must be either 2 or 7")

    def forward(self, x):
        return self.blocks(x)

class VQVAE(nn.Module):
    """Decode sequences of patch-embedding grids back to images.

    This module has no encoder. ``forward`` takes patch embeddings shaped
    (b, t, num_patches, emb_dim). It optionally vector-quantizes them with an
    EMA codebook (``Quantize``), then decodes them with two ``Decoder`` stages:
    x2 upsampling, then x7, for x14 overall (e.g. 10x10 patches -> 140x140).

    Args:
        in_channel: number of channels of the reconstructed image.
        channel: hidden channel count inside both decoders.
        n_res_block: residual blocks per decoder.
        n_res_channel: hidden channels inside each residual block.
        emb_dim: size of each patch embedding / codebook vector.
        n_embed: number of codebook entries.
        decay: EMA decay of the codebook updates.
        quantize: if False, embeddings bypass the codebook and the returned
            latent loss is zero.
    """

    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        emb_dim=64,
        n_embed=512,
        decay=0.99,
        quantize=True,
    ):
        super().__init__()

        self.quantize = quantize
        # Forward `decay` so the constructor argument actually controls the
        # codebook's EMA rate.
        self.quantize_b = Quantize(emb_dim, n_embed, decay=decay)
        # No explicit freezing is needed when quantize=False. Quantize keeps
        # its codebook in buffers (it has no nn.Parameters), and the codebook is
        # only updated inside Quantize.forward, which forward() skips in that case.

        # First stage: x2 upsampling (e.g. 10x10 -> 20x20), staying in emb_dim channels.
        self.upsample_b = Decoder(emb_dim, emb_dim, channel, n_res_block, n_res_channel, scale_factor=2)
        # Second stage: x7 upsampling to image channels (e.g. 20x20 -> 140x140).
        self.dec = Decoder(emb_dim, in_channel, channel, n_res_block, n_res_channel, scale_factor=7)

        self.info = f"in_channel: {in_channel}, channel: {channel}, n_res_block: {n_res_block}, n_res_channel: {n_res_channel}, emb_dim: {emb_dim}, n_embed: {n_embed}, decay: {decay}"

    def forward(self, input):
        """Decode patch embeddings into images.

        Args:
            input: tensor of shape (b, t, num_patches, emb_dim). ``num_patches``
                must be a perfect square, since patches are laid out on a square
                h x w grid.

        Returns:
            dec: reconstructed images of shape (b * t, in_channel, 14 * h, 14 * w).
            diff: 1-element tensor holding the quantizer's commitment loss, or
                zero when quantization is disabled.

        Raises:
            ValueError: if ``num_patches`` is not a perfect square.
        """
        num_patches = input.shape[2]
        # round() tolerates floating-point error in the square root; the check
        # below rejects non-square patch counts with a readable message.
        num_side_patches = round(num_patches ** 0.5)
        if num_side_patches * num_side_patches != num_patches:
            raise ValueError(
                f"num_patches must be a perfect square, got {num_patches}"
            )
        input = rearrange(input, "b t (h w) e -> (b t) h w e", h=num_side_patches, w=num_side_patches)

        if self.quantize:
            quant_b, diff_b, _ = self.quantize_b(input)
        else:
            # 0-d zero on the input's device/dtype, matching the 0-d loss from
            # Quantize, so both branches yield shape (1,) after the unsqueeze.
            quant_b, diff_b = input, input.new_zeros(())

        # (N, h, w, e) -> (N, e, h, w): the decoders' convolutions expect channels first.
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)
        dec = self.decode(quant_b)
        return dec, diff_b

    def decode(self, quant_b):
        """Decode a channels-first latent grid (N, emb_dim, h, w) to (N, in_channel, 14h, 14w)."""
        upsample_b = self.upsample_b(quant_b)  # (N, emb_dim, 2h, 2w)
        dec = self.dec(upsample_b)  # (N, in_channel, 14h, 14w)
        return dec

    def decode_code(self, code_b):
        """Decode integer codebook indices of shape (N, h, w) into images.

        Not called by forward(); kept from the upstream implementation, where
        it was used for sampling.
        """
        quant_b = self.quantize_b.embed_code(code_b)  # (N, h, w, emb_dim)
        quant_b = quant_b.permute(0, 3, 1, 2)  # (N, emb_dim, h, w)
        dec = self.decode(quant_b)
        return dec
    

if __name__ == '__main__':
    # Smoke test: batch of 2 sequences x 3 frames, each a 10x10 grid
    # (100 patches) of 384-dim embeddings; prints the decoded shape and loss.
    model = VQVAE(emb_dim=384)
    dummy_patches = torch.randn(2, 3, 100, 384)
    recon, latent_loss = model(dummy_patches)
    print(recon.shape, latent_loss)
