import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch_dct import dct_2d, idct_2d, dct, idct
from utils.utils import LambdaModule

# RGB to YCbCr
class RGB2YCbCr(nn.Module):
    """Convert a batch of RGB images to YCbCr using the ITU-R BT.601 weights.

    The module holds two non-persistent buffers (they follow ``.to(device)``
    but are not written to the ``state_dict``):

    * ``matrix`` -- the 3x3 conversion matrix, stored transposed so that a
      channel-last image can be right-multiplied by it.
    * ``shift``  -- the per-channel offset ``[0, 0.5, 0.5]`` added after the
      matrix product (moves Cb and Cr by 0.5).
    """

    def __init__(self):
        super().__init__()

        # BT.601 luma weights for red, green and blue.
        kr, kg, kb = 0.299, 0.587, 0.114

        # Rows are the output channels (Y, Cb, Cr), columns the inputs (R, G, B).
        rows = [
            [kr, kg, kb],
            [-0.5 * kr / (1 - kb), -0.5 * kg / (1 - kb), 0.5],
            [0.5, -0.5 * kg / (1 - kr), -0.5 * kb / (1 - kr)],
        ]
        self.register_buffer("matrix", th.tensor(rows).t(), persistent=False)

        # Offset applied to (Y, Cb, Cr) after the linear map.
        self.register_buffer("shift", th.tensor([0., 0.5, 0.5]), persistent=False)

    def forward(self, img):
        """Return the YCbCr version of ``img``.

        Parameters
        ----------
        img : tensor of shape (batch, 3, height, width)
            RGB image batch; the channel axis must be axis 1.

        Raises
        ------
        ValueError
            If ``img`` is not 4D or its second dimension is not 3.
        """
        if not (img.ndim == 4 and img.shape[1] == 3):
            raise ValueError('Input image must be 4D tensor with a size of 3 in the second dimension.')

        # Move channels last, contract them with the matrix, move them back.
        channels_last = img.permute(0, 2, 3, 1)
        converted = th.tensordot(channels_last, self.matrix, dims=1)
        return converted.permute(0, 3, 1, 2) + self.shift[None, :, None, None]


# YCbCr to RGB
class YCbCr2RGB(nn.Module):
    """Convert a batch of YCbCr images to RGB using the ITU-R BT.601 weights.

    The per-channel ``shift`` of ``[0, 0.5, 0.5]`` is subtracted first, the
    result is multiplied by the 3x3 ``matrix`` and finally clamped to
    ``[0, 1]``. Both tensors are non-persistent buffers, so they follow
    ``.to(device)`` but are not stored in the ``state_dict``.
    """

    def __init__(self):
        super().__init__()

        # BT.601 luma weights for red, green and blue.
        kr, kg, kb = 0.299, 0.587, 0.114

        # Rows are the output channels (R, G, B), columns the inputs (Y, Cb, Cr).
        rows = [
            [1, 0, 2 - 2 * kr],
            [1, -kb / kg * (2 - 2 * kb), -kr / kg * (2 - 2 * kr)],
            [1, 2 - 2 * kb, 0],
        ]
        self.register_buffer("matrix", th.tensor(rows).t(), persistent=False)

        # Offset removed from (Y, Cb, Cr) before the linear map.
        self.register_buffer("shift", th.tensor([0., 0.5, 0.5]), persistent=False)

    def forward(self, img):
        """Return the RGB version of ``img``, clamped to ``[0, 1]``.

        Parameters
        ----------
        img : tensor of shape (batch, 3, height, width)
            YCbCr image batch; the channel axis must be axis 1.

        Raises
        ------
        ValueError
            If ``img`` is not 4D or its second dimension is not 3.
        """
        if not (img.ndim == 4 and img.shape[1] == 3):
            raise ValueError('Input image must be 4D tensor with a size of 3 in the second dimension.')

        # Remove the offset, then contract the channel axis with the matrix.
        centered = img - self.shift[None, :, None, None]
        channels_last = centered.permute(0, 2, 3, 1)
        rgb = th.tensordot(channels_last, self.matrix, dims=1).permute(0, 3, 1, 2)

        # Out-of-gamut values are clipped to the valid RGB range.
        return th.clamp(rgb, 0, 1)

class JpegTransfrom(nn.Module):
    """JPEG-style block transform that keeps only low-frequency DCT coefficients.

    In ``'forward'`` mode an image is converted to YCbCr, split into
    ``patch_size`` x ``patch_size`` patches and transformed with an
    orthonormal 2D DCT. For every channel only the coefficients ``(i, j)``
    with ``i + j < level`` are kept; the kept coefficients of all three
    channels become the channels of the output, giving one output pixel per
    patch. ``'backward'`` mode scatters such coefficients back into zeroed
    patches, applies the inverse DCT and converts back to RGB.

    Parameters
    ----------
    level_Y, level_Cb, level_Cr : int
        Size of the kept low-frequency triangle for each channel. A level
        ``n`` keeps ``n * (n + 1) / 2`` coefficients; ``0`` drops the channel.
    mode : str
        ``'forward'`` (image -> coefficients) or ``'backward'``.
    patch_size : int
        Side length of the square DCT patches.
    type : str
        ``'rgb'`` (case-insensitive) for three-channel input. Any other value
        is treated as single-channel data: it is repeated to three channels
        on the way in and averaged back to one channel on the way out.
    heatmap_dir : str or None
        Directory the coefficient-statistics heatmaps are written to on every
        forward-mode call (the historical behaviour, default ``'out'``). The
        directory is not created; OpenCV skips the write if it is missing.
        Pass ``None`` to disable the dump; statistics are still accumulated.

    Buffers
    -------
    indices : (K, 3) long tensor, non-persistent
        ``(channel, row, col)`` of every kept coefficient.
    stats : (5, patch_size, patch_size) double tensor
        Running per-coefficient min, sum, sum of squares, count and max,
        pooled over all patches and all three channels.
    """

    def __init__(self, level_Y, level_Cb, level_Cr, mode='forward', patch_size=8, type='RGB', heatmap_dir='out'):
        super(JpegTransfrom, self).__init__()
        assert mode in ['forward', 'backward'], "Mode must be either 'forward' or 'backward'"
        self.mode = mode
        self.patch_size = patch_size
        self.type = type.lower()
        self.heatmap_dir = heatmap_dir

        self.to_YCbCr = RGB2YCbCr()
        self.to_RGB = YCbCr2RGB()

        masks = th.stack((self.triu(level_Y), self.triu(level_Cb), self.triu(level_Cr)))
        # Registered as a buffer so the indices follow `.to(device)`;
        # non-persistent so existing checkpoints keep loading unchanged.
        self.register_buffer('indices', th.nonzero(masks), persistent=False)

        self.register_buffer('stats', self._fresh_stats(patch_size))

    @staticmethod
    def _fresh_stats(patch_size):
        """Return the initial statistics tensor (min=+inf, max=-inf, rest 0)."""
        stats = th.zeros(5, patch_size, patch_size).double()
        # Starting min/max at 0 would cap the minimum at 0 and floor the
        # maximum at 0 regardless of the data.
        stats[0] = float('inf')
        stats[4] = float('-inf')
        return stats

    @staticmethod
    def _pad_amount(size, patch_size):
        """Return how many elements must be appended to reach a multiple of ``patch_size``."""
        return (-size) % patch_size

    @staticmethod
    def _rescale01(values):
        """Linearly rescale a numpy array to [0, 1]; constant or NaN input yields zeros."""
        low = values.min()
        span = values.max() - low
        if span > 0:
            return (values - low) / span
        return np.zeros_like(values)

    def triu(self, level):
        """Return a (patch_size, patch_size) 0/1 mask of the kept coefficients.

        The mask is 1 at positions ``(i, j)`` with ``i + j < level`` and 0
        elsewhere; a ``level`` of 0 or less gives an all-zero mask.

        Raises
        ------
        ValueError
            If ``level`` exceeds ``patch_size``.
        """
        if level > self.patch_size:
            raise ValueError(f'level {level} exceeds patch_size {self.patch_size}')

        patch = th.zeros((self.patch_size, self.patch_size))
        if level > 0:
            # Flipping an upper-triangular matrix left-right gives the
            # top-left anti-diagonal triangle.
            patch[:level, :level] = th.triu(th.ones((level, level))).flip(1)

        return patch

    def forward(self, x):
        """Compress (``mode='forward'``) or reconstruct (``mode='backward'``) ``x``."""
        if self.mode == 'forward':
            if self.type != 'rgb':
                # Single-channel data is fed through the RGB pipeline as grey.
                x = x.repeat(1, 3, 1, 1)

            return self.dct2d_and_extract(self.to_YCbCr(x))
        else:
            x = self.reconstruct_and_idct2d(x)

            if self.type != 'rgb':
                # Neutral chroma, so that all three RGB channels equal Y.
                x[:, 1:] = 0.5

            x = self.to_RGB(x)

            if self.type != 'rgb':
                x = th.mean(x, dim=1, keepdim=True)

            return x

    # TODO remove padding and use only Y for greyscale in dct

    def dct2d_and_extract(self, x):
        """Patch-wise DCT of ``x`` followed by selection of the kept coefficients.

        Parameters
        ----------
        x : tensor of shape (B, 3, H, W)
            H and W are zero-padded at the bottom/right to the next multiple
            of ``patch_size`` when necessary.

        Returns
        -------
        tensor of shape (B, K, ceil(H / patch_size), ceil(W / patch_size))
            ``K`` is the number of kept coefficients (rows of ``indices``).

        Side effects
        ------------
        Updates ``stats`` and, unless ``heatmap_dir`` is ``None``, rewrites
        the heatmap images on every call.
        """
        B, _, H, W = x.shape
        p = self.patch_size

        # Pad up to the next multiple of the patch size (0 when already aligned).
        pad_h = self._pad_amount(H, p)
        pad_w = self._pad_amount(W, p)
        x = F.pad(x, (0, pad_w, 0, pad_h))
        x = rearrange(x, 'b c (h p1) (w p2) -> (b h w) c p1 p2', p1=p, p2=p)

        X = dct_2d(x, norm='ortho')

        self._update_stats(X)
        if self.heatmap_dir is not None:
            self._dump_heatmaps()

        X_extracted = X[:, self.indices[:, 0], self.indices[:, 1], self.indices[:, 2]]
        X_extracted = rearrange(
            X_extracted, '(b h w) c -> b c h w',
            b=B, h=(H + pad_h) // p, w=(W + pad_w) // p,
        )

        return X_extracted

    def _update_stats(self, X):
        """Fold the DCT patches ``X`` of shape (N, C, p, p) into the running statistics."""
        with th.no_grad():
            # Pool patches and channels: (N * C, p, p).
            z = X.detach().double().flatten(0, 1)
            self.stats[0] = th.minimum(z.amin(dim=0), self.stats[0])
            self.stats[1] = self.stats[1] + z.sum(dim=0)
            self.stats[2] = self.stats[2] + (z ** 2).sum(dim=0)
            self.stats[3] = self.stats[3] + z.shape[0]
            self.stats[4] = th.maximum(z.amax(dim=0), self.stats[4])

    def _dump_heatmaps(self):
        """Write min/mean/std/max heatmaps (linear and log-scaled) as 512x512 PNGs."""
        # Imported here because OpenCV is only needed for this diagnostic
        # output and is not imported at module level.
        import cv2

        with th.no_grad():
            count = self.stats[3]
            mean = self.stats[1] / count
            # Rounding can push the variance marginally below zero.
            variance = (self.stats[2] / count - mean ** 2).clamp_min(0)
            maps = {
                'min': self.stats[0],
                'mean': mean,
                'std': th.sqrt(variance),
                'max': self.stats[4],
            }

        for name, values in maps.items():
            linear = self._rescale01(values.detach().cpu().numpy())
            # The log variant is computed from the normalised map, then renormalised.
            logscaled = self._rescale01(np.log(linear + 1e-6))

            for suffix, heatmap in (('', linear), ('_log', logscaled)):
                heatmap = cv2.resize(heatmap, (512, 512), interpolation=cv2.INTER_AREA)
                heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
                cv2.imwrite(f'{self.heatmap_dir}/heatmap_{name}{suffix}.png', heatmap)

    def reconstruct_and_idct2d(self, X):
        """Scatter kept coefficients into zeroed patches and apply the inverse DCT.

        Parameters
        ----------
        X : tensor of shape (B, K, H, W)
            One pixel per patch, ``K`` kept coefficients per pixel.

        Returns
        -------
        tensor of shape (B, 3, H * patch_size, W * patch_size)
            Padding added by the forward transform is not removed.
        """
        B, _, H, W = X.shape
        p = self.patch_size

        X = rearrange(X, 'b c h w -> (b h w) c')
        patches = th.zeros((B * H * W, 3, p, p), dtype=X.dtype, device=X.device)
        rows = th.arange(B * H * W, device=X.device)[:, None]
        patches[rows, self.indices[:, 0], self.indices[:, 1], self.indices[:, 2]] = X

        x = idct_2d(patches, norm='ortho')

        x = rearrange(x, '(b h w) c p1 p2 -> b c (h p1) (w p2)', b=B, h=H, w=W)
        return th.real(x)

class LociInputOutputCompression(nn.Module):
    """
    The LociInputOutputCompression reduces the resolution of an input by 16x
    and compresses it using a JPEG-style DCT transform, or rescales it using
    bilinear interpolation.

    Attributes
    ----------
    type : str
        Input type of the module, one of ['rgb', 'mask', 'depth'] (case-insensitive).
    mode : str
        Compression mode of the module, one of ['low', 'medium', 'high'] (case-insensitive).
    direction : str
        Compression direction, one of ['forward', 'backward'] (case-insensitive).
    transform : nn.Module
        The module applied by `forward`.

    Compression Modes
    -----------------
    The percentages are the retained size for RGB, mask and depth combined.

    * 'low' : 6.25% of the original size.
        - For 'RGB' type: level_Y=8, level_Cb=3, level_Cr=3.
        - For 'mask' and 'depth' types: interpolation is used.
    * 'medium' : 21.4% of the original size.
        - For 'RGB' type: level_Y=14, level_Cb=3, level_Cr=3.
        - For 'mask' type: level_Y=13, level_Cb=0, level_Cr=0.
        - For 'depth' type: level_Y=11, level_Cb=0, level_Cr=0.
    * 'high' : 33.75% of the original size.
        - For 'RGB' type: level_Y=16, level_Cb=7, level_Cr=7.
        - For 'mask' and 'depth' types: level_Y=15, level_Cb=0, level_Cr=0.

    Raises
    ------
    ValueError
        If `mode` or `type` is unknown, or `direction` is unknown for the
        interpolation path.
    """

    # (mode, type) -> (level_Y, level_Cb, level_Cr) of the JPEG-style settings.
    _JPEG_LEVELS = {
        ('low', 'rgb'): (8, 3, 3),
        ('medium', 'rgb'): (14, 3, 3),
        ('medium', 'mask'): (13, 0, 0),
        ('medium', 'depth'): (11, 0, 0),
        ('high', 'rgb'): (16, 7, 7),
        ('high', 'mask'): (15, 0, 0),
        ('high', 'depth'): (15, 0, 0),
    }

    def __init__(self, type='RGB', mode='medium', direction='forward'):
        super(LociInputOutputCompression, self).__init__()
        self.type = type.lower()
        self.mode = mode.lower()
        self.direction = direction.lower()

        if self.mode not in ('low', 'medium', 'high'):
            raise ValueError(f'Unknown mode {self.mode}')

        levels = self._JPEG_LEVELS.get((self.mode, self.type))
        if levels is not None:
            level_Y, level_Cb, level_Cr = levels
            self.transform = JpegTransfrom(
                level_Y=level_Y, level_Cb=level_Cb, level_Cr=level_Cr,
                mode=self.direction, type=self.type, patch_size=16,
            )

        elif self.mode == 'low' and self.type in ('mask', 'depth'):
            # The branch must be chosen by `direction`; `mode` is always
            # 'low' here and never equals 'forward' or 'backward'.
            if self.direction == 'forward':
                # 4x bilinear downscale, then fold each 4x4 block into 16 channels.
                self.transform = nn.Sequential(
                    nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=False),
                    nn.PixelUnshuffle(4)
                )

            elif self.direction == 'backward':
                # Undo the forward steps in reverse order: unfold the 16
                # channels first, then upscale 4x.
                self.transform = nn.Sequential(
                    nn.PixelShuffle(4),
                    nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
                )

            else:
                raise ValueError(f'Unknown direction {self.direction}')

        else:
            raise ValueError(f'Unknown type {self.type}')

    @staticmethod
    def num_channels(type, mode):
        """Return the number of channels of the compressed representation.

        Parameters
        ----------
        type : str
            One of ['rgb', 'mask', 'depth'] (case-insensitive).
        mode : str
            One of ['low', 'medium', 'high'] (case-insensitive).

        Raises
        ------
        ValueError
            If the combination of `mode` and `type` is unknown.
        """
        type = type.lower()
        mode = mode.lower()

        levels = LociInputOutputCompression._JPEG_LEVELS.get((mode, type))
        if levels is not None:
            # A level n keeps the n * (n + 1) / 2 coefficients of a triangle.
            return sum(level * (level + 1) // 2 for level in levels)

        if mode == 'low' and type in ('mask', 'depth'):
            # PixelUnshuffle(4) turns the single input channel into 4 * 4 channels.
            return 16

        raise ValueError(f'Unknown mode {mode} or type {type}')

    def forward(self, x):
        """Apply the configured compression or decompression transform to `x`."""
        return self.transform(x)

# TODO: develop a skip that includes the most significant frequencies

"""
if __name__ == '__main__':

    from data.datasets.hdf5_lightning_objects import HDF5_Dataset
    from utils.optimizers import Ranger
    from torch.utils.data import Dataset, DataLoader
    from utils.loss import L1SSIMLoss, MaskedL1SSIMLoss
    from utils.io import Timer, UEMA
    from nn.convnext_v2 import *
    import cv2

    DEVICE = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

    size = (256,256)
    dataset = HDF5_Dataset('/media/chief/data/Datasets-HDF5-Compressed/Kubric-Datasets/movi-e-train-256x256.hdf5', size)

    #base_channels = 32
    #blocks= [2,4,6,4]
    base_channels = 16
    blocks= [1,2,3,2]

    default_autoencoder = nn.Sequential(
        PatchDownscale(1, base_channels),
        *[ConvNeXtBlock(base_channels) for _ in range(blocks[0])],
        PatchDownscale(base_channels, base_channels * 2, 2),
        *[ConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])],
        PatchDownscale(base_channels * 2, base_channels * 4, 2),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        PatchDownscale(base_channels * 4, base_channels * 8, 2),
        *[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        #nn.Conv2d(base_channels * 8, 4, 1),
        #nn.Conv2d(4, base_channels * 8, 1),
        #*[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        PatchUpscale(base_channels * 8, base_channels * 4, 2),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        PatchUpscale(base_channels * 4, base_channels * 2, 2),
        *[ConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])],
        PatchUpscale(base_channels * 2, base_channels, 2),
        *[ConvNeXtBlock(base_channels) for _ in range(blocks[0])],
        PatchUpscale(base_channels, 1, 4),
    ).to(DEVICE)

    input_channels = 66# LociInputOutputCompression.num_channels("depth", "medium")

    compressed_decoder = nn.Sequential(
        nn.Conv2d(input_channels, base_channels * 4, 1),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        PatchDownscale(base_channels * 4, base_channels * 8, 2),
        *[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        #nn.Conv2d(base_channels * 8, 4, 1),
        #nn.Conv2d(4, base_channels * 8, 1),
        #*[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        PatchUpscale(base_channels * 8, base_channels * 4, 2),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        nn.Conv2d(base_channels * 4, input_channels, 1),
    ).to(DEVICE)

    compress = LociInputOutputCompression("depth", "medium", "forward").to(DEVICE)
    decompress = LociInputOutputCompression("depth", "medium", "backward").to(DEVICE)

    default_optimizer = Ranger(default_autoencoder.parameters(), lr=3e-4)
    compressed_optimizer = Ranger(compressed_decoder.parameters(), lr=3e-4)

    print(f'Default Parameters: {sum(p.numel() for p in default_autoencoder.parameters())}')
    print(f'Compressed Parameters: {sum(p.numel() for p in compressed_decoder.parameters())}')

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    l1ssim = L1SSIMLoss().to(DEVICE)

    avg_loss = UEMA()
    avg_rgb_l1 = UEMA()
    avg_rgb_ssim = UEMA()

    timer = Timer()

    mode = "compressed_depth"
    num_updates = 0

    for epoch in range(100):
        for i, batch in enumerate(dataloader):
            num_updates += 1

            rgb = batch[1].to(DEVICE) # depth
            
            if mode == "default_depth":
                rgb_out = default_autoencoder(rgb)

                loss, l1, ssim = l1ssim(rgb_out, rgb)

                default_optimizer.zero_grad()
                loss.backward()
                default_optimizer.step()

            else:
                with th.no_grad():
                    rgb_compressed = compress(rgb)

                rgb_out_compressed = compressed_decoder(rgb_compressed)

                with th.no_grad():
                    rgb_out = decompress(rgb_out_compressed)

                _, l1, ssim = l1ssim(rgb_out, rgb)

                loss = th.mean((rgb_out_compressed - rgb_compressed)**2)

                compressed_optimizer.zero_grad()
                loss.backward()
                compressed_optimizer.step()

            avg_loss.update(loss.item())
            avg_rgb_l1.update(l1.item())
            avg_rgb_ssim.update(ssim.item())

            print("[{}|{}|{}|{:.2f}%] {}, Loss: {:.2e}, L1: {:.2e}, SSIM: {:.2e}".format(
                epoch, i, len(dataloader), i/len(dataloader)*100,
                str(timer),
                float(avg_loss),
                float(avg_rgb_l1), 
                float(avg_rgb_ssim), 
            ), flush=True)

            if i % 100 == 0:
                if mode != "default_depth":
                    with th.no_grad():
                        rgb_dec = decompress(rgb_compressed)
                for b in range(rgb.shape[0]):
                    cv2.imwrite(f'out/rgb_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', rgb[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', rgb_out[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    if mode != "default_depth":
                        cv2.imwrite(f'out/rgb_compressed_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', rgb_dec[b].permute(1,2,0).detach().cpu().numpy() * 255)

                    abs_error = th.clip(th.mean(th.abs(rgb_out[b] - rgb[b]), dim=0) * 5, 0, 1)
                    cv2.imwrite(f'out/abs_error_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', abs_error.detach().cpu().numpy() * 255)

            if i % 10000 == 0:
                th.save(default_autoencoder.state_dict(), f'out/{mode}_autoencoder{epoch:03d}_{i:06d}.pth')
"""

#"""
if __name__ == '__main__':
    # Ad-hoc experiment: trains an image autoencoder (or a decoder operating
    # on compressed coefficients) on RGB frames and periodically dumps
    # reconstructions and checkpoints to `out/`. Which variant is trained is
    # selected by the hard-coded `mode` string further down.

    from data.datasets.hdf5_lightning_objects import HDF5_Dataset
    from utils.optimizers import Ranger
    from torch.utils.data import Dataset, DataLoader
    from utils.loss import L1SSIMLoss, MaskedL1SSIMLoss
    from utils.io import Timer, UEMA
    # Presumably provides PatchDownscale, ConvNeXtBlock and PatchUpscale
    # used below -- verify against nn/convnext_v2.
    from nn.convnext_v2 import *
    import cv2

    DEVICE = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

    size = (256,256)
    dataset = HDF5_Dataset('/media/chief/data/Datasets-HDF5-Compressed/Kubric-Datasets/movi-e-train-256x256.hdf5', size)

    #base_channels = 32
    #blocks= [2,4,6,4]
    base_channels = 12
    blocks= [1,2,3,2]

    # Pixel-space autoencoder: four downscaling stages followed by five
    # upscaling stages, with ConvNeXt blocks in between. Note that
    # `blocks[0]//2` is 0 for the values above, so the half-width stage
    # contains no ConvNeXt block.
    default_autoencoder = nn.Sequential(
        PatchDownscale(3, base_channels),
        *[ConvNeXtBlock(base_channels) for _ in range(blocks[0])],
        PatchDownscale(base_channels, base_channels * 2, 2),
        *[ConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])],
        PatchDownscale(base_channels * 2, base_channels * 4, 2),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        PatchDownscale(base_channels * 4, base_channels * 8, 2),
        *[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        #nn.Conv2d(base_channels * 8, 4, 1),
        #nn.Conv2d(4, base_channels * 8, 1),
        #*[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        PatchUpscale(base_channels * 8, base_channels * 4, 2),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        PatchUpscale(base_channels * 4, base_channels * 2, 2),
        *[ConvNeXtBlock(base_channels * 2) for _ in range(blocks[1])],
        PatchUpscale(base_channels * 2, base_channels, 2),
        *[ConvNeXtBlock(base_channels) for _ in range(blocks[0])],
        PatchUpscale(base_channels, base_channels // 2, 2),
        *[ConvNeXtBlock(base_channels // 2) for _ in range(blocks[0]//2)],
        PatchUpscale(base_channels // 2, 3, 2),
    ).to(DEVICE)

    # 117 is the number of coefficients kept by the 'medium' RGB setting
    # (level_Y=14, level_Cb=3, level_Cr=3): 105 + 6 + 6.
    input_channels = 117# LociInputOutputCompression.num_channels("rgb", "medium")

    # Autoencoder that works directly on the compressed coefficients: maps
    # `input_channels` in and the same number of channels out.
    compressed_decoder = nn.Sequential(
        nn.Conv2d(input_channels, base_channels * 4, 1),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        PatchDownscale(base_channels * 4, base_channels * 8, 2),
        *[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        #nn.Conv2d(base_channels * 8, 4, 1),
        #nn.Conv2d(4, base_channels * 8, 1),
        #*[ConvNeXtBlock(base_channels * 8) for _ in range(blocks[3])],
        PatchUpscale(base_channels * 8, base_channels * 4, 2),
        *[ConvNeXtBlock(base_channels * 4) for _ in range(blocks[2])],
        nn.Conv2d(base_channels * 4, input_channels, 1),
    ).to(DEVICE)

    # Fixed (non-trainable) transforms between pixel and coefficient space.
    compress = LociInputOutputCompression("rgb", "medium", "forward").to(DEVICE)
    decompress = LociInputOutputCompression("rgb", "medium", "backward").to(DEVICE)

    # One optimizer per model; only the one matching `mode` is stepped.
    default_optimizer = Ranger(default_autoencoder.parameters(), lr=3e-4)
    compressed_optimizer = Ranger(compressed_decoder.parameters(), lr=3e-4)

    print(f'Default Parameters: {sum(p.numel() for p in default_autoencoder.parameters())}')
    print(f'Compressed Parameters: {sum(p.numel() for p in compressed_decoder.parameters())}')

    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

    # Only used for reporting (L1, SSIM) outside of the "default" mode.
    l1ssim = L1SSIMLoss().to(DEVICE)

    # Running averages for the progress line (semantics defined in utils.io).
    avg_loss = UEMA()
    avg_rgb_l1 = UEMA()
    avg_rgb_ssim = UEMA()

    timer = Timer()

    # Training variant:
    #   "default"   -- pixel-space autoencoder trained with the L1/SSIM loss,
    #   "jpeg_loss" -- pixel-space autoencoder trained mostly in coefficient space,
    #   anything else -- `compressed_decoder` trained on the coefficients.
    mode = "jpeg_loss"
    num_updates = 0

    for epoch in range(100):
        for i, batch in enumerate(dataloader):
            num_updates += 1

            # Presumably the RGB frames are the first element of a batch --
            # verify against HDF5_Dataset.
            rgb = batch[0].to(DEVICE)
            
            if mode == "default":
                rgb_out = default_autoencoder(rgb)

                loss, l1, ssim = l1ssim(rgb_out, rgb)

                default_optimizer.zero_grad()
                loss.backward()
                default_optimizer.step()

            elif mode == "jpeg_loss":
                rgb_out = default_autoencoder(rgb)

                # The target coefficients need no gradient.
                with th.no_grad():
                    rgb_compressed = compress(rgb)

                # Gradients flow through `compress` into the autoencoder here.
                rgb_out_compressed = compress(rgb_out)

                # 85% MSE in coefficient space plus 15% L1 in pixel space.
                loss = th.mean((rgb_out_compressed - rgb_compressed)**2) * 0.85
                loss = loss + th.mean(th.abs(rgb_out - rgb)) * 0.15

                # Metrics only; the combined loss returned here is discarded.
                _, l1, ssim = l1ssim(rgb_out, rgb)

                default_optimizer.zero_grad()
                loss.backward()
                default_optimizer.step()
            
            else:
                with th.no_grad():
                    rgb_compressed = compress(rgb)

                rgb_out_compressed = compressed_decoder(rgb_compressed)

                # Decompression is only needed for the pixel-space metrics.
                with th.no_grad():
                    rgb_out = decompress(rgb_out_compressed)

                _, l1, ssim = l1ssim(rgb_out, rgb)

                # The decoder is trained purely on the coefficient-space MSE.
                loss = th.mean((rgb_out_compressed - rgb_compressed)**2)

                compressed_optimizer.zero_grad()
                loss.backward()
                compressed_optimizer.step()

            avg_loss.update(loss.item())
            avg_rgb_l1.update(l1.item())
            avg_rgb_ssim.update(ssim.item())

            print("[{}|{}|{}|{:.2f}%] {}, Loss: {:.2e}, L1: {:.2e}, SSIM: {:.2e}".format(
                epoch, i, len(dataloader), i/len(dataloader)*100,
                str(timer),
                float(avg_loss),
                float(avg_rgb_l1), 
                float(avg_rgb_ssim), 
            ), flush=True)

            # Every 100 batches: dump input, output, the compress/decompress
            # round trip of the input, and an amplified error map.
            if i % 100 == 0:
                if mode != "default":
                    with th.no_grad():
                        rgb_dec = decompress(rgb_compressed)
                # File names are keyed by floor(log2(num_updates)), so dumps
                # falling into the same power-of-two bucket overwrite each other.
                # NOTE(review): cv2.imwrite interprets three-channel data as
                # BGR; if the dataset yields RGB the saved colours are
                # swapped -- confirm.
                for b in range(rgb.shape[0]):
                    cv2.imwrite(f'out/rgb_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', rgb[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', rgb_out[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    if mode != "default":
                        cv2.imwrite(f'out/rgb_compressed_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', rgb_dec[b].permute(1,2,0).detach().cpu().numpy() * 255)

                    # Mean absolute error over channels, amplified 5x and clipped to [0, 1].
                    abs_error = th.clip(th.mean(th.abs(rgb_out[b] - rgb[b]), dim=0) * 5, 0, 1)
                    cv2.imwrite(f'out/abs_error_{mode}_{int(np.log2(num_updates)):03d}_{b:06d}.jpg', abs_error.detach().cpu().numpy() * 255)

            # Checkpoint every 10000 batches (this includes batch 0 of every epoch).
            # NOTE(review): `default_autoencoder` is saved in every mode, also
            # when `compressed_decoder` is the model being trained -- confirm
            # this is intended.
            if i % 10000 == 0:
                th.save(default_autoencoder.state_dict(), f'out/{mode}_autoencoder{epoch:03d}_{i:06d}.pth')
#"""
