# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import sys
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from .bottlenecks import GsnConditionalLocScaleShift
from .entropy_model_layers import EncoderSection, LearnedPosition, StartSym, Transformer, ConvTransBlock
from .swin_3d import BasicLayer, PatchMerging, PatchEmbed3D
from .layers_utils import make_embedding
from .patcher import Patcher
from einops import rearrange
from compressai.entropy_models import EntropyBottleneck, GaussianConditional

# Divisor applied to quantized latents before they are embedded for the
# decoder (`VCTEntropyModel._embed_latent_q_patched` divides by it).
_LATENT_NORM_FAC: float = 35.0


class PreviousLatent(NamedTuple):
    """A previously transmitted latent together with its encoded form.

    Attributes:
        quantized: the quantized latent. Note that
            `VCTEntropyModel.process_previous_latent_q` stores it after
            rearranging 'b d c h w -> b c d h w'.
        processed: the quantized latent after running it through an encoder.
            See `VCTEntropyModel.process_previous_latent_q` for more details.
    """

    quantized: Tensor
    processed: Tensor

def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Return an ``nn.Conv2d`` padded by ``kernel_size // 2`` on every side.

    For odd kernels this keeps the spatial size with ``stride=1`` and halves
    it (rounding up) with the default ``stride=2``.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
    )

def ste_round(x: Tensor) -> Tensor:
    """Round ``x`` in the forward pass while passing gradients straight through.

    The rounding residual is detached, so the backward pass sees the identity.
    """
    residual = (torch.round(x) - x).detach()
    return x + residual


class TemporalEntropyModelOut(NamedTuple):
    """Output of the VCT temporal entropy model.

    Attributes:
      perturbed_latent: the reconstructed latent. `VCTEntropyModel.forward`
        returns the straight-through-rounded latent plus its bounded
        correction term; `VCTEntropyModel.compress` returns the rounded (or
        decoded) latent. Both are unpatched back to the input's (H, W).
      bits: rate information for the latent. NOTE(review): `forward` stores the
        per-element likelihoods returned by the bottleneck here (patched,
        [b', seq_len, C]), whereas `compress` stores a 1-element float tensor
        with the total number of bits -- confirm which one a caller receives.
      features: (optional) entropy-model features to be used by a synthesis
        transform for dequantizing, unpatched to [B, channels, H, W].
    """

    perturbed_latent: Tensor
    bits: Tensor
    features: Optional[Tensor] = None


class VCTEntropyModel(nn.Module):
    """Temporal entropy model conditioned on previously transmitted latents.

    Pipeline used by `forward`:
      1. previous quantized latents are encoded by a per-frame 3-D Swin block
         (`process_previous_latent_q`) and may be fused pairwise
         (`fusion_prev`);
      2. the processed latents are concatenated along the time axis, refined
         by a joint 3-D Swin block and merged over time
         (`_get_transformer_output`), giving per-patch mean/scale features;
      3. a channel-wise model splits the current latent into `num_slices`
         slices and predicts each slice's mean and scale from those features
         plus the already-reconstructed slices.

    NOTE(review): `compress`, `validate_causal`, `_encode` and `_decode` are
    carried over from a token-autoregressive design. They call
    `_get_transformer_output` with a `latent_q_patched` keyword that the method
    does not accept, so they raise `TypeError` when they reach that call.
    """

    def __init__(
        self,
        num_channels: int = 192,
        context_len: int = 2,
        window_size_enc: int = 4,
        window_size_dec: int = 4,
        num_layers_encoder_sep: int = 6,
        num_layers_encoder_joint: int = 6,
        num_layers_encoder_fusion: int = 4,
        d_model: int = 192,
        num_head: int = 8,
        mlp_expansion: int = 4,
        drop_out_enc: float = 0.1,
        drop_out_dec: float = 0.1,
    ) -> None:
        """
        Temporal Entropy Model

        Args:
            num_channels: number of channels in the latent space,
                i.e. symbols per token. Defaults to 192.
            context_len: number of previous latents. Defaults to 2.
            window_size_enc: encoder window size; only used to compute
                `seq_len_enc`. Must be >= window_size_dec. Defaults to 4.
            window_size_dec: patch size used by `compress`; also sets
                `seq_len_dec`. Defaults to 4.
            num_layers_encoder_sep: depth of the per-frame Swin encoder
                (`swin3d_encoder`). Defaults to 6.
            num_layers_encoder_joint: depth of the joint Swin encoder
                (`swin3d_joint`). Defaults to 6.
            num_layers_encoder_fusion: depth of the Swin block used by
                `fusion_prev` (`swin3d_fusion`). Defaults to 4.
            d_model: input width of the mean/scale heads and output width of
                the encoder/decoder embeddings. Defaults to 192.
            num_head: number of attention heads in the three main Swin blocks.
                Defaults to 8.
            mlp_expansion: expansion *factor* for each MLP. Defaults to 4.
                NOTE(review): currently unused -- the Swin blocks hard-code
                mlp_ratio=4.
            drop_out_enc: dropout probability in encoder. Defaults to 0.1.
                NOTE(review): currently unused -- the Swin blocks use drop=0.
            drop_out_dec: dropout probability in decoder. Defaults to 0.1.
                NOTE(review): currently unused.

        Raises:
            ValueError: if window_size_enc < window_size_dec or if
                num_channels is negative.
        """
        super().__init__()
        if window_size_enc < window_size_dec:
            raise ValueError(
                f"window_size_enc={window_size_enc} cannot be lower "
                f"than window_size_dec={window_size_dec}."
            )
        if num_channels < 0:
            raise ValueError(f"num_channels={num_channels} cannot be negative")
        self.num_channels = num_channels
        self.window_size_enc = window_size_enc
        self.window_size_dec = window_size_dec

        self.d_model = d_model
        # Location/scale Gaussian bottleneck used to score (and code) each slice.
        self.bottleneck = GsnConditionalLocScaleShift(
            num_scales=256, num_means=100, min_scale=0.01, tail_mass=(2 ** (-8))
        )

        self.gs_condition = GaussianConditional(None)

        self.range_bottleneck = None

        self.context_len = context_len

        # Channel-wise model: the latent is split into `num_slices` slices and
        # each slice is conditioned on at most `max_support_slices` earlier ones.
        self.num_slices = 6
        self.max_support_slices = 6
        slice_width = self.num_channels // self.num_slices

        def _support_channels(slice_index: int) -> int:
            # Channels of cat([latent means or scales, earlier slices]) as
            # built in `forward` for the given slice.
            return self.num_channels + slice_width * min(
                slice_index, self.max_support_slices
            )

        # Per-frame encoder for previous latents (temporal window of 1).
        self.swin3d_encoder = BasicLayer(
            dim=self.num_channels,
            depth=num_layers_encoder_sep,
            num_heads=num_head,
            window_size=(1, 8, 8),
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
        )

        # Joint encoder over the concatenated context (temporal window of 2).
        self.swin3d_joint = BasicLayer(
            dim=self.num_channels,
            depth=num_layers_encoder_joint,
            num_heads=num_head,
            window_size=(2, 8, 8),
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
        )

        # Block used by `fusion_prev` to fuse two processed latents.
        self.swin3d_fusion = BasicLayer(
            dim=self.num_channels,
            depth=num_layers_encoder_fusion,
            num_heads=num_head,
            window_size=(2, 8, 8),
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
        )

        # Per-slice Swin refinement of the mean / scale support tensors.
        # (Module construction order is kept stable so seeded initialization
        # and state_dict keys do not change.)
        self.atten_mean = nn.ModuleList(
            nn.Sequential(
                BasicLayer(
                    dim=_support_channels(i),
                    depth=2,
                    num_heads=16,
                    window_size=(1, 4, 4),
                )
            )
            for i in range(self.num_slices)
        )
        self.atten_scale = nn.ModuleList(
            nn.Sequential(
                BasicLayer(
                    dim=_support_channels(i),
                    depth=2,
                    num_heads=16,
                    window_size=(1, 4, 4),
                )
            )
            for i in range(self.num_slices)
        )

        # Per-slice conv heads predicting the slice mean / scale.
        self.cc_mean_transforms = nn.ModuleList(
            nn.Sequential(
                conv(_support_channels(i), 224, stride=1, kernel_size=3),
                nn.GELU(),
                conv(224, 128, stride=1, kernel_size=3),
                nn.GELU(),
                conv(128, slice_width, stride=1, kernel_size=3),
            )
            for i in range(self.num_slices)
        )
        self.cc_scale_transforms = nn.ModuleList(
            nn.Sequential(
                conv(_support_channels(i), 224, stride=1, kernel_size=3),
                nn.GELU(),
                conv(224, 128, stride=1, kernel_size=3),
                nn.GELU(),
                conv(128, slice_width, stride=1, kernel_size=3),
            )
            for i in range(self.num_slices)
        )

        # Per-slice correction heads; input is cat([mean_support, y_hat_slice]).
        self.lrp_transforms = nn.ModuleList(
            nn.Sequential(
                conv(_support_channels(i) + slice_width, 224, stride=1, kernel_size=3),
                nn.GELU(),
                conv(224, 128, stride=1, kernel_size=3),
                nn.GELU(),
                conv(128, slice_width, stride=1, kernel_size=3),
            )
            for i in range(self.num_slices)
        )

        self.seq_len_dec = window_size_dec**2
        self.seq_len_enc = window_size_enc**2

        self.patcher = Patcher(8, "reflect")
        self.learned_zero = StartSym(hidden_dim=num_channels)

        self.dec_position = LearnedPosition(
            seq_length=self.seq_len_dec, hidden_dim=d_model
        )

        self.post_embedding_layernorm = nn.LayerNorm(d_model, eps=1e-6)

        self.encoder_embedding = make_embedding(
            input_dim=num_channels, hidden_dim=d_model
        )  # a single linear layer
        self.decoder_embedding = make_embedding(
            input_dim=num_channels, hidden_dim=d_model
        )  # a single linear layer

        # NOTE(review): not used by any method of this class.
        self.text_kv_proj = nn.Linear(512, self.num_channels)
        self.fusion_text_block = nn.MultiheadAttention(
            self.num_channels, num_heads=8, batch_first=True
        )

        def _make_final_heads(output_channels: int) -> nn.Module:
            # 3 stacked linear layers with GELU activations.
            return nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.GELU(),
                nn.Linear(d_model, d_model),
                nn.GELU(),
                nn.Linear(d_model, output_channels),
            )

        # NOTE(review): these heads take d_model inputs but are applied to
        # patched features of width num_channels in `_get_transformer_output`,
        # so they presumably require d_model == num_channels (true for the
        # defaults).
        self.mean_head = _make_final_heads(num_channels)
        self.scale_head = _make_final_heads(num_channels)

        # Temporal merges with patch size (2, 1, 1): two time steps -> one.
        self.CNN3DMerge = PatchEmbed3D(
            patch_size=(2, 1, 1), in_chans=self.num_channels, embed_dim=self.num_channels
        )

        self.CNN3DFusion_Prev = PatchEmbed3D(
            patch_size=(2, 1, 1), in_chans=self.num_channels, embed_dim=self.num_channels
        )

    @staticmethod
    def round_st(x: Tensor) -> Tensor:
        """
        Straight-through round: rounds in the forward pass, identity gradient.
        """
        return (torch.round(x) - x).detach() + x

    def fusion_prev(
        self, prev_state: PreviousLatent, old_state: PreviousLatent, text_embeddings
    ) -> PreviousLatent:
        """Fuse the processed features of two PreviousLatent objects.

        The two `processed` tensors are concatenated on dim 2 (time), passed
        through `swin3d_fusion` and merged back over time by
        `CNN3DFusion_Prev`. The `quantized` field of `prev_state` is kept.

        Args:
            prev_state: the more recent previous latent.
            old_state: the older previous latent.
            text_embeddings: NOTE(review): currently unused.

        Returns:
            A new PreviousLatent with the fused `processed` tensor.
        """
        prev = prev_state.processed  # [B,C,T,H,W]
        old = old_state.processed  # [B,C,T,H,W]

        fused = self.swin3d_fusion(torch.cat([prev, old], dim=2))  # [B,C,T*2,H,W]
        fused = self.CNN3DFusion_Prev(fused)  # [B,C,T,H,W]

        return PreviousLatent(quantized=prev_state.quantized, processed=fused)

    def process_previous_latent_q(
        self, previous_latent_quantized: Tensor
    ) -> PreviousLatent:
        """Process previous quantized latent by passing it through the encoder.

        This can be used if previous latents go through expensive transforms
        before being fed to the entropy model, and will be stored in the `processed`
        field of the `PreviousLatent` tuple.

        The output of this function applied to all quantized latents should
        be fed to the `forward` method. This is used to improve efficiency,
        as it avoids calling expensive processing of previous latents at
        each time step.

        Args:
            previous_latent_quantized: previous quantized latent that is to be processed,
                expected shape [B, 1, C, H, W]

        Returns:
            PreviousLatent whose `quantized` field holds the input rearranged to
            [B, C, 1, H, W] and whose `processed` field holds the
            `swin3d_encoder` output.
        """

        previous_latent_quantized = rearrange(
            previous_latent_quantized, "b d c h w -> b c d h w"
        )
        previous_latent_encoded = self.swin3d_encoder(
            previous_latent_quantized
        )  # (B, C, D, H, W)

        return PreviousLatent(previous_latent_quantized, processed=previous_latent_encoded)

    def _embed_latent_q_patched(self, latent_q_patched: Tensor) -> Tensor:
        """Embed current patched latent for decoder

        The input latent is normalized, embedded in d_model dimension, and
        positional encoding is added. NOTE(review): not called by any method
        of this class.

        Args:
            latent_q_patched: tensor of shape [b', seq_len_dec, C]

        Returns:
            tensor of shape [b', seq_len_dec, d_model]
        """
        latent_q_patched = latent_q_patched / _LATENT_NORM_FAC  # [b', seq_len_dec, C]
        latent_q_patched = self.decoder_embedding(
            latent_q_patched
        )  # [b', seq_len_dec, d_model]
        latent_q_patched = self.post_embedding_layernorm(
            latent_q_patched
        )  # [b', seq_len_dec, d_model]
        return self.dec_position(latent_q_patched)  # [b', seq_len_dec, d_model]

    def _get_transformer_output(
        self, *, encoded_patched: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict mean and scale features from the processed previous latents.

        Args:
            encoded_patched: processed previous latents concatenated along the
                time axis (dim 2), [B, C, T, H, W]. After the residual joint
                Swin block, `CNN3DMerge` (patch size (2, 1, 1)) merges time and
                dim 2 is squeezed, so this presumably expects T == 2 -- TODO
                confirm for context_len != 2.

        Returns:
            (mean, scale, encoded_patches): outputs of `mean_head` and
            `scale_head`, and the 8x8-patched features they were computed from,
            each of shape [B', seq_len (64), ...].
        """

        encoded_patched_joint = self.swin3d_joint(encoded_patched)
        encoded_patched = encoded_patched + encoded_patched_joint
        encoded_patched = self.CNN3DMerge(encoded_patched)
        encoded_patched = encoded_patched.squeeze(2)

        encoded_patches, _ = self.patcher(encoded_patched, 8)  # [B', 64, num_channels]

        mean = self.mean_head(encoded_patches)  # [B', 64, num_channels]
        scale = self.scale_head(encoded_patches)  # [B', 64, num_channels]

        return mean, scale, encoded_patches

    def _get_encoded_seqs(
        self, previous_latents: Sequence[PreviousLatent]
    ) -> List[Tensor]:
        """
        Extract the previously processed latents, repeating them if the
        number of processed latents is less than the context length.

        Args:
            previous_latents: sequence of size at least 1 and at most
                `context_len`, containing PreviousLatent objects. Only their
                `processed` tensors are used.

        Returns:
            List of `processed` tensors. When fewer than `context_len` are
            given, it is padded by repeating the oldest entry (supported for
            context_len 2 and 3).

        Raises:
            ValueError: if `previous_latents` is empty, or padding is needed
                for an unsupported `context_len`.
        """
        encoded_seqs = [p.processed for p in previous_latents]
        if not encoded_seqs:
            raise ValueError("previous_latents must contain at least one PreviousLatent")
        if len(encoded_seqs) < self.context_len:
            if self.context_len == 2:
                # encoded_seqs is a list of size 1
                return encoded_seqs * 2
            if self.context_len == 3:
                if len(encoded_seqs) == 1:
                    return encoded_seqs * 3
                # repeat the 0th twice
                return [encoded_seqs[0]] * 2 + [encoded_seqs[1]]
            # Previously this error was constructed but never raised.
            raise ValueError(f"Unsupported context_len={self.context_len}")
        return encoded_seqs

    def forward(
        self,
        latent_unquantized: Tensor,
        previous_latents: Sequence[PreviousLatent],
    ) -> TemporalEntropyModelOut:
        """Run the channel-wise entropy model on the current latent.

        Args:
            latent_unquantized: the latent to transmit (quantize). It is
                squeezed on dim 1 before patching, so presumably it has shape
                [B, 1, C, H, W] -- TODO confirm against callers.
            previous_latents: previously transmitted latents (as produced by
                `process_previous_latent_q`), at least one and at most
                `context_len`. Their `processed` tensors are concatenated on
                dim 2.

        Returns:
            TemporalEntropyModelOut with
                - perturbed_latent: reconstructed latent, unpatched and cropped
                  to (H, W)
                - bits: the bottleneck's second output (per-element
                  likelihoods) in patched form [B', 64, C]
                - features: the patched transformer features unpatched to (H, W)
        """
        H, W = latent_unquantized.shape[-2:]

        encoded_seqs = self._get_encoded_seqs(previous_latents=previous_latents)

        latent_means, latent_scales, dec_output = self._get_transformer_output(
            encoded_patched=torch.cat(encoded_seqs, dim=2),  # concat on time axis
        )

        # ---- channel-wise entropy model ----
        latent_unquantized = latent_unquantized.squeeze(1)
        # Keep the real patch grid so every unpatch below matches the input
        # resolution (it was previously hard-coded to a 2x2 grid).
        y_patched, num_patches = self.patcher(
            latent_unquantized, 8
        )  # [B', 64, num_channels]

        y_slices = y_patched.chunk(self.num_slices, -1)  # [B', 64, C/num_slices] each

        y_hat_slices = []
        y_likelihood = []

        for slice_index, y_slice in enumerate(y_slices):
            # Condition on at most `max_support_slices` earlier slices
            # (all of them when the limit is negative).
            if self.max_support_slices < 0:
                support_slices = y_hat_slices
            else:
                support_slices = y_hat_slices[: self.max_support_slices]

            # Mean branch -> [B', C + k*C/num_slices, 8, 8]
            mean_support = torch.cat([latent_means] + support_slices, dim=-1)
            mean_support = rearrange(mean_support, "b (h w) c -> b c h w", h=8, w=8)
            # The Swin block expects a time axis: add, then remove, a singleton one.
            mean_support = self.atten_mean[slice_index](mean_support.unsqueeze(2))
            mean_support = mean_support.squeeze(2)
            mu = self.cc_mean_transforms[slice_index](mean_support)
            mu = mu[:, :, :8, :8]

            # Scale branch, same structure as the mean branch.
            scale_support = torch.cat([latent_scales] + support_slices, dim=-1)
            scale_support = rearrange(scale_support, "b (h w) c -> b c h w", h=8, w=8)
            scale_support = self.atten_scale[slice_index](scale_support.unsqueeze(2))
            scale_support = scale_support.squeeze(2)
            scale = self.cc_scale_transforms[slice_index](scale_support)
            scale = scale[:, :, :8, :8]

            y_slice = rearrange(y_slice, "b (h w) c -> b c h w", h=8, w=8)
            _, y_slice_likelihood = self.bottleneck(y_slice, scale, mu)
            y_likelihood.append(y_slice_likelihood)

            # Straight-through rounding around the predicted mean, then a
            # bounded correction in (-0.5, 0.5) from `lrp_transforms`.
            y_hat_slice = ste_round(y_slice - mu) + mu
            lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
            lrp = self.lrp_transforms[slice_index](lrp_support)
            lrp = 0.5 * torch.tanh(lrp)
            # Out-of-place add: avoids mutating a tensor tracked by autograd.
            y_hat_slice = y_hat_slice + lrp

            y_hat_slice = rearrange(y_hat_slice, "b c h w -> b (h w) c", h=8, w=8)
            y_hat_slices.append(y_hat_slice)

        y_hat_patched = torch.cat(y_hat_slices, dim=-1)  # [B', 64, num_channels]
        y_hat = self.patcher.unpatch(
            (y_hat_patched, num_patches), crop=(H, W), channels_last=True
        )  # [B, C, H, W]

        y_likelihoods = torch.cat(y_likelihood, dim=1)
        y_likelihoods = rearrange(y_likelihoods, "b c h w -> b (h w) c", h=8, w=8)

        decoder_features = self.patcher.unpatch(
            x_patched=(dec_output, num_patches), crop=(H, W), channels_last=True
        )  # [B, C, H, W]

        return TemporalEntropyModelOut(
            perturbed_latent=y_hat, bits=y_likelihoods, features=decoder_features
        )

    def _get_mean_scale_jitted(
        self, *, encoded_patched: Tensor, latent_q_patched: Tensor
    ):
        """
        TODO implement JIT version of the transformer forward pass
            (mean, scale, dec_output = self._get_transformer_output(...))

        Args:
            encoded_patched: what we condition on, tensor of shape
                [b', context_len*seq_len_enc, d_model]
            latent_q_patched: tensor of shape [b', seq_len_dec, C]

        Returns:
            runs the transformer and returns mean, scale and decoder_output

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("JIT not supported yet.")

    def validate_causal(self, latent_q_patched: Tensor, encoded: Tensor) -> None:
        """
        Validate that the masking is causal by comparing one full pass with a
        token-by-token pass.

        NOTE(review): passes `latent_q_patched=` to `_get_transformer_output`,
        which does not accept it, so this raises `TypeError` as written.

        Raises:
            ValueError: if the two passes disagree.
        """
        # run model
        masked_means, masked_scales, _ = self._get_transformer_output(
            encoded_patched=encoded, latent_q_patched=latent_q_patched
        )
        # run model iteratively
        current_inp = torch.full_like(latent_q_patched, fill_value=10.0)
        autoreg_means = torch.full_like(latent_q_patched, fill_value=10.0)
        autoreg_scales = torch.full_like(latent_q_patched, fill_value=10.0)

        for i in range(self.seq_len_dec):
            if i > 0:  # first token is the learnt StartSym
                current_inp[:, i - 1, :] = latent_q_patched[:, i - 1, :]
            mean_i, scale_i, _ = self._get_transformer_output(
                encoded_patched=encoded, latent_q_patched=current_inp
            )
            autoreg_means[:, i, :] = mean_i[:, i, :]
            autoreg_scales[:, i, :] = scale_i[:, i, :]

        isclose_means = autoreg_means.isclose(masked_means).all()
        isclose_scales = autoreg_scales.isclose(masked_scales).all()
        causal = isclose_means and isclose_scales
        if not causal:
            msg_mean = "" if isclose_means else "means"
            msg_scales = "" if isclose_scales else "scales"
            raise ValueError(
                f"Larger than expected discrepancy: {msg_mean} {msg_scales}"
            )

    def compress(
        self,
        *,
        latent_unquantized: Tensor,
        previous_latents: Sequence[PreviousLatent],
        run_decode: bool = False,
        validate_causal: bool = False,
    ) -> TemporalEntropyModelOut:
        """
        Compress and decompress autoregressively. Can only handle batch size 1.

        NOTE(review): this path predates the current `forward`. It patches with
        `window_size_dec` rather than 8, concatenates `encoded_seqs` on dim -2
        (while `forward` uses dim 2), and `_encode` / `validate_causal` call
        `_get_transformer_output` with an unsupported `latent_q_patched`
        keyword, which raises `TypeError`.

        Args:
            latent_unquantized: unquantized latent of shape [1, C, H, W]
            previous_latents: a sequence of length at least 1 and at most `context_len`
                containing PreviousLatent objects which hold 2 tensors:
                    - quantized: [1, C, H, W]
                    - processed: latents passed through the encoder
            run_decode: bool, defaults to False. Whether to run the actual decoding
            validate_causal: bool, defaults to False. Whether to run
                `validate_causal` before encoding.

        Returns:
            TemporalEntropyModelOut object with the following components:
                - perturbed_latent: rounded (or decoded) latent with the same
                    shape as `latent_unquantized` -- [1, C, H, W]
                - bits: number of bits used to compress the input latent, float tensor
                - features: features tensor from the decoder, shape [1, d_model, H, W]
        """
        B, C, H, W = latent_unquantized.shape
        assert B == 1, "Cannot handle batch yet."

        encoded_seqs = self._get_encoded_seqs(previous_latents)
        # previously coded latents, floats
        encoded = torch.cat(encoded_seqs, -2)

        latent_patched, (n_h, n_w) = self.patcher(
            latent_unquantized, self.window_size_dec
        )  # [b', seq_len, C]

        if validate_causal:
            self.validate_causal(latent_q_patched=latent_patched, encoded=encoded)

        # Encoding: compress to strings - strings is a list of len seq_len_dec
        strings, extra = self._encode(latent_patched, encoded)
        means, scales, dec_output, quantized = extra.values()

        decoder_features = self.patcher.unpatch(
            (dec_output, (n_h, n_w)), crop=(H, W), channels_last=True
        )  # [1, d_model, H, W]

        # Count bits in each sequence, each string is a list of len 1
        bits = [sum(len(string[0]) * 8 for string in strings)]

        # decoding:
        if not run_decode:
            # For performance, since coding is lossless, real decode can be skipped
            decoded = torch.round(latent_unquantized)
        else:
            use_output_from_encode = True
            decoded = self._decode(
                strings,
                encoded,
                shape=(H, W, C),
                encode_means=means,
                encode_scales=scales,
                use_output_from_encode=use_output_from_encode,
            )
            dequantized = self.patcher.unpatch(
                x_patched=(quantized, (n_h, n_w)),
                crop=(H, W),
                channels_last=True,
            )
            if use_output_from_encode:
                # This should pass if `use_output_from_encode=True`
                assert (decoded == dequantized).all(), "Something went wrong!"

        return TemporalEntropyModelOut(
            perturbed_latent=decoded,
            bits=torch.tensor(bits, dtype=torch.float32),
            features=decoder_features,
        )

    def _encode(
        self, latent_patched: Tensor, encoded: Tensor
    ) -> Tuple[List[str], Dict[str, Tensor]]:
        """
        Compress patched latents to strings, one token at a time.

        NOTE(review): passes `latent_q_patched=` to `_get_transformer_output`,
        which does not accept it, so this raises `TypeError` as written.

        Args:
            latent_patched: unquantized latent of shape [b', seq_len_dec, C], where
                b' = #patches * actual batch size, which is 1 (for compress/decompress)
            encoded: the "features" used to code the latent, ie what we condition on
                to predict the distribution of the current latent (presumably with
                the same b' as `latent_patched`).

        Returns:
            strings (list of single-element string lists), extra (dict with
                tensors "means", "scales", "dec_output", "quantized").
                NB: In theory, nothing in the dict (e.g. means and scales) should
                be used at decode time. In practice, the below ._decode method
                uses them, to avoid issues with non-determinism of the transformer.
        """

        strings = []
        # all 3 tensors are [b', seq_len_dec, C]
        quantized = torch.full_like(latent_patched, fill_value=10.0)
        autoreg_means = torch.full_like(latent_patched, fill_value=100.0)
        autoreg_scales = torch.full_like(latent_patched, fill_value=100.0)
        # the decoder output has shape [b', seq_len_dec, d_model]
        dec_output_shape = (*latent_patched.shape[:-1], self.d_model)
        dec_output = torch.full(
            dec_output_shape,
            fill_value=100.0,
            dtype=torch.float32,
            device=quantized.device,
        )

        prev_mean = None
        prev_scale = None

        # add 0 at the end to code the very last symbol and then break
        for i in itertools.chain(range(self.seq_len_dec), [0]):
            if prev_mean is not None and prev_scale is not None:  # ensures i > 0
                latent_i = latent_patched[:, i - 1, :]
                # unsqueeze -- batch size 1
                quantized_i, string = self.bottleneck.compress(
                    inputs=latent_i.unsqueeze(0),
                    scales=prev_scale.unsqueeze(0),
                    means=prev_mean.unsqueeze(0),
                )  # [1, b', C] tensor of ints, [string] list of a string
                strings.append(string)  # each string is a list of len 1
                # quantized should contain the mean.
                quantized[:, i - 1, :] = quantized_i.squeeze(0)

                if i == 0:
                    break

            mean_i, scale_i, dec_output_i = self._get_transformer_output(
                encoded_patched=encoded, latent_q_patched=quantized
            )  # [b', seq_len_dec, C]*2, [b', seq_len_dec, d_model]

            prev_mean = autoreg_means[:, i, :] = mean_i[:, i, :]  # [b', C]
            prev_scale = autoreg_scales[:, i, :] = scale_i[:, i, :]  # [b', C]
            dec_output[:, i, :] = dec_output_i[:, i, :]  # assigning [b', d_model]

        # NOTE: `autoreg_means` and `autoreg_scales` must not be used at decode time.
        # However, due to transformer non-determinism, we return them and allow "fake"
        # decoding by setting use_output_from_encode=True in `._decode`
        extra = {
            "means": autoreg_means,
            "scales": autoreg_scales,
            "dec_output": dec_output,
            "quantized": quantized,
        }
        return strings, extra

    def _decode(
        self,
        strings: List[str],
        encoded: Tensor,
        shape: Sequence[int],
        encode_means: Tensor,
        encode_scales: Tensor,
        use_output_from_encode: bool = True,
    ) -> Tensor:
        """
        Decompress strings

        NOTE(review): consumes `strings` in place (pop(0)), and passes
        `latent_q_patched=` to `_get_transformer_output`, which does not accept
        it, so this raises `TypeError` as written.

        Args:
            strings: list of strings, length should be seq_len_dec
            encoded: previous latents, passed to transformer
            shape: H, W, C
            encode_means: means from encode (compress) step
            encode_scales: scales from encode (compress) step
            use_output_from_encode: use encode_means and encode_scales from the encode
                step (not real compression), or use the transformer to compute them.
                Note there could be a discrepancy between the two due to the
                non-deterministic nature of transformers (which makes them suboptimal
                for pmf prediction and use in compression), so for research purposes it
                could be acceptable to use the output from the encoder.

        Returns:
            A tensor of the decoded strings, unpatched and cropped to (H, W)
        """
        H, W, C = shape
        _device = encode_means.device
        fake_patched = self.patcher(
            torch.ones((1, C, H, W), device=_device), self.window_size_dec
        )  # placeholder object Patched(tensor, num_patches)
        decompressed = torch.full_like(
            fake_patched.tensor, fill_value=10.0
        )  # [b', seq_len_dec, C]

        prev_mean = None
        prev_scale = None
        for i in itertools.chain(range(self.seq_len_dec), [0]):
            if prev_mean is not None and prev_scale is not None:
                decoded_i = self.bottleneck.decompress(
                    strings=strings.pop(0),
                    scales=prev_scale.unsqueeze(0),
                    means=prev_mean.unsqueeze(0),
                )  # decompressed i-th token, [b', C]
                decompressed[:, i - 1, :] = decoded_i
                if i == 0:
                    break
            # predict mean and scale for the i-th token, given previously decompressed
            mean_i, scale_i, _ = self._get_transformer_output(
                encoded_patched=encoded, latent_q_patched=decompressed
            )  # [b', seq_len_dec, C]x2

            target_mean, target_scale = encode_means[:, i, :], encode_scales[:, i, :]
            actual_mean, actual_scale = mean_i[:, i, :], scale_i[:, i, :]

            if use_output_from_encode:
                # NOTE: To deal with non-determinism of the transformer, use the means
                # and the scales from encoding, and log errors. Note that this cannot
                # be done in practice.
                prev_mean = target_mean  # mean of current token, [b', C]
                prev_scale = target_scale  # scale of current token, [b', C]
                error_mean = (actual_mean - target_mean).abs().sum()
                error_scale = (actual_scale - target_scale).abs().sum()
                percent_of_total_mean = error_mean / target_mean.abs().sum()
                percent_of_total_scale = error_scale / target_scale.abs().sum()
                if percent_of_total_mean > 0.01 or percent_of_total_scale > 0.01:
                    print(
                        "Larger than expected discrepancy in transformer output found! ",
                        f"Decode step {i}: mean error = {100*percent_of_total_mean}%, ",
                        f"scale error = {100*percent_of_total_scale}%",
                        file=sys.stderr,
                    )
            else:
                prev_mean = actual_mean  # mean of current token, [b', C]
                prev_scale = actual_scale  # scale of current token, [b', C]

        assert not strings
        return self.patcher.unpatch(
            x_patched=(decompressed, fake_patched.num_patches),
            crop=(H, W),
            channels_last=True,
        )

    def update(self, force: bool = False) -> bool:
        """
        Updates the entropy bottleneck(s) CDF values.

        Needs to be called once after training to be able to later compress
        and decompress with an actual entropy coder.

        Args:
            force: overwrite previous values (default: False)

        Returns:
            updated: True if one of the bottlenecks was updated; False when the
                bottleneck has no `update` method.
        """
        check = getattr(self.bottleneck, "update", None)
        if check is not None:
            bottleneck_updated = self.bottleneck.update(force=force)
        else:
            bottleneck_updated = False

        return bottleneck_updated
