#!/usr/bin/env python3

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import os
import requests
import h5py
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

def _download_file(url, filepath, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` into ``filepath``, going through a ``<filepath>.part`` file.

    The temporary file is renamed onto ``filepath`` only after the whole
    response body has been written, so a failed or interrupted transfer never
    leaves a truncated file at ``filepath`` (callers skip the download when
    that path already exists).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    tmp_path = filepath + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this an HTML error page would be saved as the dataset.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(tmp_path, 'wb') as fh, tqdm(total=total_size, unit='B', unit_scale=True) as bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    fh.write(chunk)
                    bar.update(len(chunk))
    os.replace(tmp_path, filepath)


def _read_matlab_strings(f, key):
    """Decode the strings referenced from row 0 of dataset ``f[key]``.

    Each entry of ``f[key][0]`` is an object reference; the referenced
    dataset holds character codes that are joined into one Python string
    (this is how a MATLAB cell array of char arrays appears through h5py).
    """
    return [''.join(chr(c) for c in f[ref][:].flatten()) for ref in f[key][0]]


def _item_presence_matrix(labels, num_samples, num_items):
    """Return a float ``(num_samples, num_items)`` 0/1 matrix of items present per sample.

    ``labels[i]`` is a label map for sample ``i``; label value 0 is skipped
    (presumably "unlabeled") and value ``k >= 1`` marks column ``k - 1``.
    """
    presence = np.zeros((num_samples, num_items))
    for i, label_map in enumerate(labels):
        present = np.unique(label_map)
        # Cast before shifting so the zero-based index is a signed integer
        # regardless of the stored label dtype.
        present = present[present != 0].astype(np.intp) - 1
        presence[i, present] = 1
    return presence


def _save_png_atomically(image, path):
    """Save ``image`` as PNG via a temporary file so ``path`` only ever holds a complete file."""
    tmp_path = path + ".part"
    image.save(tmp_path, format='PNG')
    os.replace(tmp_path, path)


def download_and_extract_nyu(data_root):
    """
    Download the NYU Depth V2 labeled dataset and extract it under ``data_root``.

    Produces:
    1. RGB images as ``<data_root>/images/NNNN.png``.
    2. Depth maps as ``<data_root>/depths/NNNN_depth.png`` (16-bit PNG holding
       the stored depth * 1000, rounded and clipped to the uint16 range;
       presumably millimetres if the file stores metres).
    3. ``<data_root>/metadata.csv`` with one row per sample: ``sample_index``,
       ``image_name``, ``depth_name``, ``scene_label`` and one 0/1 column per
       item name telling whether that item occurs in the sample's label map.
       Rows are in sample order, so this CSV is the single source of truth
       for image/depth/label alignment.

    The ``.mat`` file is only downloaded when it is not already present, and
    image/depth files that already exist are not rewritten.

    Args:
        data_root: Directory that receives the download and all outputs.

    Raises:
        requests.HTTPError: if the download request fails.
        ValueError: if the per-sample datasets in the file disagree in length.
    """
    url = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
    filename = "nyu_depth_v2_labeled.mat"
    filepath = os.path.join(data_root, filename)

    img_dir = os.path.join(data_root, "images")
    depth_dir = os.path.join(data_root, "depths")
    csv_path = os.path.join(data_root, "metadata.csv")

    # --- 1. Download ---
    if not os.path.exists(filepath):
        print(f"Downloading {filename}...")
        os.makedirs(data_root, exist_ok=True)
        _download_file(url, filepath)
    else:
        print(f"Found {filename}, skipping download.")

    # --- 2. Extract Data ---
    print(f"Extracting data to {data_root}...")
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(depth_dir, exist_ok=True)

    with h5py.File(filepath, 'r') as f:
        print("  Reading labels...")
        raw_scenes = _read_matlab_strings(f, 'sceneTypes')

        images = f['images']
        depths = f['depths']
        labels = f['labels']
        num_samples = images.shape[0]

        # All per-sample datasets must line up with the images; otherwise the
        # CSV rows would silently pair files with the wrong labels.
        lengths = {
            'images': num_samples,
            'depths': depths.shape[0],
            'labels': labels.shape[0],
            'sceneTypes': len(raw_scenes),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(f"Per-sample datasets disagree in length: {lengths}")

        # Which named items are visible in each image
        unique_item_names = _read_matlab_strings(f, 'names')
        print(f"  Found {len(unique_item_names)} unique item names.")
        multiclass_array = _item_presence_matrix(labels, num_samples, len(unique_item_names))

        print(f"  Processing {num_samples} samples (Images + Depth + CSV)...")

        image_names = [f"{i:04d}.png" for i in range(num_samples)]
        depth_names = [f"{i:04d}_depth.png" for i in range(num_samples)]

        for i in tqdm(range(num_samples)):
            img_path = os.path.join(img_dir, image_names[i])
            depth_path = os.path.join(depth_dir, depth_names[i])

            # Existing files are skipped so re-runs are fast; writes go through
            # a temp file, so an interrupted run never leaves a partial PNG.
            if not os.path.exists(img_path):
                # Reverse the axis order: the h5py view of the MATLAB array is
                # transposed relative to PIL's (row, col, channel) layout.
                img_data = np.transpose(images[i], (2, 1, 0))
                _save_png_atomically(Image.fromarray(img_data.astype('uint8')), img_path)

            if not os.path.exists(depth_path):
                depth_data = np.transpose(depths[i], (1, 0))
                # Round instead of truncating (float32 values just below an
                # integer would otherwise lose a unit) and clip so values
                # outside the uint16 range cannot wrap around.
                depth_mm = np.clip(
                    np.round(depth_data * 1000), 0, np.iinfo(np.uint16).max
                ).astype(np.uint16)
                _save_png_atomically(Image.fromarray(depth_mm, mode='I;16'), depth_path)

    # --- 3. Build & Save CSV ---
    # Built in one vectorized step instead of one DataFrame per sample.
    sample_df = pd.DataFrame({
        'sample_index': np.arange(num_samples),
        'image_name': image_names,
        'depth_name': depth_names,
        'scene_label': raw_scenes,
    })
    items_df = pd.DataFrame(multiclass_array, columns=unique_item_names)
    df = pd.concat([sample_df, items_df], axis=1)
    print(df.head())
    df.to_csv(csv_path, index=False)
    print(f"✓ Metadata saved to {csv_path}")
    print("Alignment is now guaranteed via this CSV file.")

if __name__ == "__main__":
    # The target directory is not set here: it comes from project_config,
    # so edit MM_BENCHMARKS_DATA_ROOT there to change where data is written.
    DATA_ROOT = project_config.MM_BENCHMARKS_DATA_ROOT
    download_and_extract_nyu(data_root=DATA_ROOT)