import cv2
import h5py
import lz4.frame
import numpy as np
from io import BytesIO
from tqdm import tqdm
import sys
import os
from collections import defaultdict
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config
import torch.nn.functional as F

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def compress_image(image, format='.jpg'):
    """Encode a channel-first float image into an in-memory image buffer.

    Args:
        image: array of shape (C, H, W) or (H, W). Values are multiplied by
            255, rounded and clipped to [0, 255], so inputs are expected to
            lie in [0, 1].
        format: file extension understood by ``cv2.imencode`` (e.g. '.jpg',
            '.png'); it selects the codec.

    Returns:
        The encoded bytes, as the numpy array returned by ``cv2.imencode``.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed.
    """
    image = np.asarray(image)
    # OpenCV expects channel-last (H, W, C); 2-D inputs are already (H, W).
    if image.ndim == 3:
        image = image.transpose(1, 2, 0)
    # Convert to 8-bit explicitly rather than relying on imencode's implicit
    # conversion of unsupported (float) depths, so the scaling, rounding and
    # saturation are visible here.
    pixels = np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)
    is_success, buffer = cv2.imencode(format, pixels)
    if not is_success:
        raise RuntimeError("Failed to compress image")
    return np.array(buffer)

def decompress_image(compressed_image):
    """Decode an encoded image buffer into a channel-first float array.

    Args:
        compressed_image: encoded image bytes as a uint8 numpy array (e.g.
            the output of ``compress_image``).

    Returns:
        float64 array of shape (C, H, W): the decoded pixels divided by 255.
        Single-channel images come back as (1, H, W). Channel order is
        whatever OpenCV decodes (OpenCV uses BGR(A) for colour images).
        For 8-bit images the values lie in [0, 1].

    Raises:
        ValueError: if OpenCV cannot decode the buffer.
    """
    image = cv2.imdecode(compressed_image, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise ValueError("Failed to decompress image")
    if image.ndim == 2:
        # Single-channel images (e.g. the depth maps written by compress_image)
        # decode to (H, W); add a channel axis so the layout is always (C, H, W).
        image = image[np.newaxis]
    else:
        image = image.transpose(2, 0, 1)
    return image / 255.0

def predict_depth(model, image, bboxes=None, device=DEVICE):
    """Predict depth for one image and squash it into (0, 1) in log space.

    The raw model output ``d`` is mapped to ``sigmoid(-(log d - mean) / std)``.
    Here ``mean`` and ``std`` are the mean and standard deviation of ``log d``
    over the whole output. Values with smaller-than-average log depth
    therefore map above 0.5.

    Args:
        model: object exposing ``infer(tensor)``. It is called on a batch of
            one. Presumably the ZoeDepth model from ``build_model``; confirm
            with the caller.
        image: numpy array holding a single image. A leading batch axis is
            added before inference. Assumed to be (C, H, W) in [0, 1]
            (TODO: confirm against the model's expectations).
        bboxes: currently unused; kept for interface compatibility.
        device: torch device the input tensor is moved to. It should match
            the device the model lives on.

    Returns:
        numpy array: the normalised model output with its leading batch
        axis removed.
    """
    # decompress_image produces float64. Cast to float32 so the input matches
    # the (presumably float32) model weights instead of raising a dtype error.
    tensor = torch.from_numpy(image).float().to(device).unsqueeze(0)

    with torch.no_grad():
        depth = model.infer(tensor)

        # Clamp before the log so a zero depth cannot yield -inf, which would
        # turn the mean, and with it the whole map, into NaN.
        log_depth = torch.log(depth.clamp_min(torch.finfo(depth.dtype).tiny))
        depth_avg = log_depth.mean()
        # A constant depth map has zero std; a small floor keeps the division
        # finite (every pixel then maps to 0.5).
        depth_std = log_depth.std().clamp_min(torch.finfo(log_depth.dtype).eps)

        # 1 / (1 + exp(z)) == sigmoid(-z); torch.sigmoid is the stable form.
        norm_sigmoid_depth = torch.sigmoid(-(log_depth - depth_avg) / depth_std)

    return norm_sigmoid_depth.squeeze(0).cpu().numpy()

def _copy_datasets(src_file, dst_file, names):
    """Copy each named dataset from src_file into dst_file unchanged."""
    for name in tqdm(names, desc="Copying non-image datasets"):
        dst_file.copy(src_file[name], name)


def _add_depth_dataset(src_file, dst_file, model):
    """Predict depth for every entry of src_file['rgb_images'] into dst_file['depth_images']."""
    rgb_images = src_file['rgb_images']
    depth_images = dst_file.create_dataset(
        'depth_images',
        (len(rgb_images),),
        dtype=h5py.vlen_dtype(np.dtype('uint8')),
    )
    for i, encoded in enumerate(tqdm(rgb_images, desc="compute depth")):
        # decompress_image yields float64; the model presumably expects float32.
        image = decompress_image(encoded).astype(np.float32)
        depth = predict_depth(model, image, None)
        # vlen uint8 rows must be 1-D; flatten in case imencode returned a
        # column-shaped buffer.
        depth_images[i] = compress_image(depth).reshape(-1)


def process_dataset(source_path, destination_path, model=None):
    """Copy an HDF5 file and add a compressed depth map for every RGB image.

    Steps:
    1. Copy the listed datasets verbatim from *source_path*.
    2. Decode every entry of ``rgb_images``, run it through the depth model
       and re-encode the result with ``compress_image``.
    3. Store the results in a new variable-length uint8 dataset named
       ``depth_images``.

    Everything is written to ``destination_path + '.tmp.hdf5'``, and that
    file is renamed to *destination_path* only once it is complete. On any
    error the temporary file is removed and the exception re-raised.

    Args:
        source_path: path of the input HDF5 file.
        destination_path: path of the HDF5 file to create.
        model: optional object exposing ``infer``. When None, a ZoeDepth
            model is built from the ("zoedepth", "infer") config and moved
            to DEVICE.
    """
    tmp_path = destination_path + '.tmp.hdf5'
    copied = (
        'image_instance_indices',
        'instance_mask_bboxes',
        'instance_masks_images',
        'sa_index',
        'foreground_mask',
        'rgb_images',
        'instance_masks',
    )

    if model is None:
        model = build_model(get_config("zoedepth", "infer")).to(DEVICE).eval()

    try:
        with h5py.File(source_path, 'r') as src_file, h5py.File(tmp_path, 'w') as dst_file:
            _copy_datasets(src_file, dst_file, copied)
            _add_depth_dataset(src_file, dst_file, model)
    except BaseException:
        # Don't leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # Publish the finished file under its final name in one step.
    os.replace(tmp_path, destination_path)

if __name__ == "__main__":
    # Usage: <script> SOURCE_HDF5 DESTINATION_HDF5
    # Explicit checks rather than assert: asserts are stripped under `python -O`.
    if len(sys.argv) != 3:
        sys.exit("Please provide source and destination paths")

    source, destination = sys.argv[1], sys.argv[2]
    if source == destination:
        sys.exit("Source and destination paths must be different")

    # Make sure the output file is not already there.
    if os.path.exists(destination):
        sys.exit("Destination file already exists")

    process_dataset(source, destination)

