import cv2
import h5py
import lz4.frame
import numpy as np
from io import BytesIO
from tqdm import tqdm
import sys
import os
import torch as th
import torch.nn as nn

class RGB2YCbCr(nn.Module):
    """Convert a single channel-first RGB image to YCbCr (ITU-R BT.601).

    The conversion is a 3x3 linear map followed by a per-channel offset of
    ``[0, 0.5, 0.5]``, so the two chroma channels are centred on 0.5 instead
    of 0.
    """

    def __init__(self):
        super().__init__()

        # BT.601 luma weights for red, green and blue (they sum to 1).
        kr = 0.299
        kg = 0.587
        kb = 0.114

        # The transformation matrix from RGB to YCbCr (ITU-R BT.601 conversion).
        # Rows are the output channels (Y, Cb, Cr); it is stored transposed so
        # that ``pixels @ matrix`` maps a trailing RGB axis to a YCbCr axis.
        self.register_buffer("matrix", th.tensor([
            [                  kr,                  kg,                    kb],
            [-0.5 * kr / (1 - kb), -0.5 * kg / (1 - kb),                  0.5],
            [                 0.5, -0.5 * kg / (1 - kr), -0.5 * kb / (1 - kr)]
        ]).t(), persistent=False)

        # Adjustments for each channel (moves Cb/Cr from [-0.5, 0.5] to [0, 1]
        # for inputs in [0, 1]).
        self.register_buffer("shift", th.tensor([0., 0.5, 0.5]), persistent=False)

    def forward(self, img):
        """Return the YCbCr version of ``img``.

        Args:
            img: One image laid out as ``(3, H, W)`` with the channels in
                R, G, B order. May be a NumPy array or a torch tensor.

        Returns:
            A ``(3, H, W)`` tensor holding the Y, Cb and Cr planes, on the
            same device and with the same dtype as the module's buffers.
        """
        if not th.is_tensor(img):
            # ascontiguousarray also copes with negatively strided views
            # (e.g. a flipped channel axis), which from_numpy rejects.
            img = th.from_numpy(np.ascontiguousarray(img))

        # Follow the buffers so the module keeps working after .to()/.double().
        img = img.to(device=self.matrix.device, dtype=self.matrix.dtype)

        # Contract the channel axis: (H, W, 3) x (3, 3) -> (H, W, 3).
        ycbcr = th.tensordot(img.permute(1, 2, 0), self.matrix, dims=1)
        return ycbcr.permute(2, 0, 1) + self.shift[:, None, None]

def compress_image(image, format='.jpg'):
    """Encode a channel-first image into a compressed byte buffer.

    Args:
        image: Array laid out as ``(C, H, W)``; it is multiplied by 255
            before being handed to OpenCV.
        format: File extension selecting the OpenCV encoder.

    Returns:
        A NumPy array copy of the encoded buffer.

    Raises:
        Exception: If OpenCV reports that encoding failed.
    """
    # OpenCV expects channel-last data, so move the channel axis to the end.
    pixels = image.transpose(1, 2, 0) * 255.0

    ok, encoded = cv2.imencode(format, pixels)
    if not ok:
        raise Exception("Failed to compress image")

    return np.array(encoded)

def decompress_image(buffer, format='.jpg'):
    """Decode a compressed image buffer into a channel-first float array.

    Args:
        buffer: Encoded image bytes in a form accepted by ``cv2.imdecode``.
        format: Unused by this function (the value is never passed to
            OpenCV).

    Returns:
        The decoded 3-channel image as ``(3, H, W)``, divided by 255. The
        channel order is whatever ``cv2.imdecode`` produces.

    Raises:
        ValueError: If OpenCV cannot decode the buffer.
    """
    # Decode image using OpenCV; imdecode signals failure by returning None
    # rather than raising.
    decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if decoded is None:
        raise ValueError("Failed to decompress image")

    return decoded.transpose(2, 0, 1) / 255.0

def _foreground_mask(ycbcr, eps=1e-6):
    """Compute the chroma-deviation map and soft foreground mask of one image.

    Args:
        ycbcr: Tensor indexed by channel first, with Cb at index 1 and Cr at
            index 2.
        eps: Lower bound applied to the normalising standard deviation.

    Returns:
        ``(color_error, color_mask)``: the per-pixel maximum absolute
        deviation of Cb/Cr from their image means, and that map standardised
        and squashed through ``relu(tanh(.))`` into ``[0, 1)``.
    """
    cb = ycbcr[1]
    cr = ycbcr[2]

    cb_error = th.abs(cb - th.mean(cb))
    cr_error = th.abs(cr - th.mean(cr))
    color_error = th.maximum(cb_error, cr_error)

    # A chroma-uniform image (e.g. grayscale) has zero spread; without the
    # clamp the division below would turn the whole mask into NaN.
    spread = th.clamp(th.std(color_error), min=eps)

    color_mask = th.relu(th.tanh((color_error - th.mean(color_error)) / spread))
    return color_error, color_mask


def _dump_debug_images(debug_dir, index, ycbcr, color_error, color_mask):
    """Write the Y/Cb/Cr planes, the error map and the mask as JPEG files."""
    planes = (
        ('y', ycbcr[0]),
        ('cb', ycbcr[1]),
        ('cr', ycbcr[2]),
        ('color', color_error),
        ('color_mask', color_mask),
    )
    for prefix, plane in planes:
        cv2.imwrite(os.path.join(debug_dir, f'{prefix}{index:010d}.jpg'), plane.numpy() * 255)


def compress_dataset(source_path, destination_path, *, debug_dir='.'):
    """Copy a dataset file and add a per-image ``foreground_mask`` dataset.

    The five datasets ``rgb_images``, ``image_instance_indices``,
    ``instance_mask_bboxes``, ``instance_masks_images`` and
    ``sequence_indices`` are copied unchanged. Every entry of ``rgb_images``
    is then decoded, converted to YCbCr, and turned into a uint8 mask that
    highlights pixels whose chroma deviates from the image mean.

    Args:
        source_path: Path of the HDF5 file to read.
        destination_path: Path of the HDF5 file to create (opened with
            mode ``'w'``).
        debug_dir: Directory receiving five diagnostic JPEGs per image
            (``y``, ``cb``, ``cr``, ``color``, ``color_mask``). Defaults to
            the current directory; pass ``None`` to skip writing them.

    Note:
        The mask resolution is taken from the first image, so the source
        must contain at least one image and all images are expected to share
        that size.
    """
    to_ycbcr = RGB2YCbCr()

    copied_names = [
        'rgb_images',
        'image_instance_indices',
        'instance_mask_bboxes',
        'instance_masks_images',
        'sequence_indices',
    ]

    if debug_dir is not None:
        # cv2.imwrite does not create missing directories.
        os.makedirs(debug_dir, exist_ok=True)

    # Open the source for reading and create the destination file.
    with h5py.File(source_path, 'r') as src_file, h5py.File(destination_path, 'w') as dst_file:

        with tqdm(total=len(copied_names), desc="Copying non-image datasets") as pbar:
            for name in copied_names:
                dst_file.copy(src_file[name], name)
                pbar.update(1)

        with tqdm(total=len(src_file['rgb_images']), desc="computing fg mask") as pbar:
            rgb_images = src_file['rgb_images']

            # Decode one image up front only to learn the spatial size.
            shape = decompress_image(rgb_images[0]).shape
            foreground_masks = dst_file.create_dataset(
                "foreground_mask",
                (len(rgb_images), 1, shape[1], shape[2]),
                dtype=np.uint8,
                compression='gzip',
                compression_opts=5,
                # One chunk per image so each mask is compressed independently.
                chunks=(1, 1, shape[1], shape[2])
            )

            for i, image in enumerate(rgb_images):
                # NOTE(review): decompress_image returns channels in OpenCV's
                # decode order (BGR for IMREAD_COLOR) while RGB2YCbCr treats
                # channel 0 as red -- confirm the stored channel order.
                image = to_ycbcr(decompress_image(image))

                color_error, color_mask = _foreground_mask(image)

                if debug_dir is not None:
                    _dump_debug_images(debug_dir, i, image, color_error, color_mask)

                foreground_masks[i] = (color_mask.numpy() * 255).astype(np.uint8)
                pbar.update(1)

def main(argv=None):
    """Validate the command line and run ``compress_dataset``.

    Args:
        argv: Full argument vector including the program name; defaults to
            ``sys.argv``.

    Exits with an error message when the arguments are missing, when source
    and destination refer to the same path, or when the destination already
    exists. Explicit checks are used instead of ``assert`` so they still run
    under ``python -O``; the destination is opened in truncating mode, so
    skipping them could destroy an existing file.
    """
    argv = sys.argv if argv is None else argv

    if len(argv) != 3:
        sys.exit("Please provide source and destination paths")

    # The command-line arguments are the source and destination paths.
    source_path, destination_path = argv[1], argv[2]

    # Compare absolute paths so "data.h5" and "./data.h5" count as equal.
    if os.path.abspath(source_path) == os.path.abspath(destination_path):
        sys.exit("Source and destination paths must be different")

    # Make sure the output file is not already there.
    if os.path.exists(destination_path):
        sys.exit("Destination file already exists")

    compress_dataset(source_path, destination_path)


if __name__ == "__main__":
    main()
