import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import pandas as pd
import os
from sklearn.preprocessing import LabelEncoder

class SUNRGBDDataset(Dataset):
    """SUN RGB-D RGB/depth/scene-label dataset driven by ``metadata.csv``.

    Each sample pairs ``<root_dir>/images/<image_name>`` with
    ``<root_dir>/depths/<depth_name>``, both resized to ``size x size``, plus
    the row's scene label as a string and as an integer code.

    Args:
        root_dir: Directory holding ``metadata.csv``, ``images/`` and ``depths/``.
        split: ``'train'`` selects the first 90% of CSV rows; ANY other value
            selects the remaining 10%.
        size: Side length in pixels that both modalities are resized to.
        transform: Optional callable applied to the image tensor and, in a
            separate call, to the depth tensor.

    Raises:
        FileNotFoundError: If ``metadata.csv`` does not exist.
    """

    def __init__(self, root_dir, split='train', size=256, transform=None):
        self.root_dir = root_dir
        self.size = size
        self.transform = transform

        # 1. Load Metadata CSV (the single source of truth for samples/labels).
        csv_path = os.path.join(root_dir, "metadata.csv")
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Metadata not found at {csv_path}. Run download_and_extract_sunrgbd.py first!")

        self.metadata = pd.read_csv(csv_path)

        # 2. Fit the label encoder on the FULL CSV (before splitting) so train and
        # val share one label -> int mapping. SUN RGB-D has many more scene
        # categories than NYU, so we fit on all of them.
        self.label_encoder = LabelEncoder()
        # Cast to str so NaN / non-string labels encode as their string form.
        self.metadata['scene_label'] = self.metadata['scene_label'].astype(str)
        self.metadata['label_int'] = self.label_encoder.fit_transform(self.metadata['scene_label'])

        # 3. Train/Val split (simple 90/10) on the DataFrame, in CSV row order.
        # NOTE(review): no shuffling -- if the CSV is grouped (e.g. by sensor or
        # scene), the val split will not be representative. Confirm CSV order.
        split_idx = int(0.9 * len(self.metadata))
        if split == 'train':
            self.data = self.metadata.iloc[:split_idx].reset_index(drop=True)
        else:
            self.data = self.metadata.iloc[split_idx:].reset_index(drop=True)

        print(f"[SUNRGBDDataset] {split}: {len(self.data)} samples loaded from CSV.")

    def __len__(self):
        """Number of samples in this split."""
        return len(self.data)

    @staticmethod
    def _decode_depth(depth_array, max_depth=10.0):
        """Convert a raw depth array to normalized depth in [0, 1].

        Applied per image:
          1. If every non-zero pixel exceeds 3000, the values are treated as
             packed depth (depth << 3) and shifted right by 3 bits.
          2. If the maximum exceeds 100, the values are treated as millimeters
             and divided by 1000 (meters).
          3. Depth is clipped to ``[0, max_depth]`` meters and divided by
             ``max_depth``.
        Zero pixels (missing depth) stay zero. The input is not modified.

        Args:
            depth_array: 2-D array of raw depth values.
            max_depth: Clip / normalization ceiling in meters.

        Returns:
            float32 array of the same shape with values in [0, 1].
        """
        depth_array = np.asarray(depth_array, dtype=np.float32)

        # A closest valid pixel farther than 3 m is implausible for these
        # indoor scenes, which suggests the 3-bit-packed format instead of
        # plain mm. Only valid (>0) pixels are checked, so dropout is ignored.
        valid_mask = depth_array > 0
        if valid_mask.any() and depth_array[valid_mask].min() > 3000:
            depth_array = (depth_array.astype(np.uint16) >> 3).astype(np.float32)

        # Values above 100 are assumed to be millimeters.
        if depth_array.max() > 100.0:
            depth_array = depth_array / 1000.0

        depth_array = np.clip(depth_array, 0, max_depth)
        return depth_array / max_depth

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys
              'image': float32 tensor (3, size, size), values in [0, 1];
              'depth': float32 tensor (1, size, size), decoded by
                       ``_decode_depth`` (clipped to 10 m, scaled to [0, 1]);
              'scene_label': label string;
              'scene_label_int': encoded label (int);
              'index': the row's ``sample_index`` (int).
            If ``transform`` is set, it has been applied to 'image' and 'depth'.
        """
        row = self.data.iloc[idx]

        img_path = os.path.join(self.root_dir, "images", row['image_name'])
        depth_path = os.path.join(self.root_dir, "depths", row['depth_name'])

        # Some SUN images are grayscale or RGBA, so force 3 channels. ``with``
        # closes the file handle promptly (matters with many DataLoader workers).
        with Image.open(img_path) as im:
            img = im.convert('RGB').resize((self.size, self.size), Image.BILINEAR)
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)  # (3, H, W)

        # Depth PNGs are expected to be 16-bit (standardized by the extraction
        # step -- TODO confirm). NEAREST keeps raw values intact (no blending),
        # so the packed-format check in _decode_depth still works after resizing.
        with Image.open(depth_path) as dm:
            depth_img = dm.resize((self.size, self.size), Image.NEAREST)
        depth_array = self._decode_depth(np.array(depth_img).astype(np.float32))
        depth_tensor = torch.from_numpy(depth_array).unsqueeze(0)  # (1, H, W)

        if self.transform:
            # NOTE(review): the transform is called separately on each tensor.
            # A random (e.g. flip/crop) transform will therefore misalign image
            # and depth, and a 3-channel-only transform will fail on depth.
            img_tensor = self.transform(img_tensor)
            depth_tensor = self.transform(depth_tensor)

        return {
            'image': img_tensor,
            'depth': depth_tensor,
            'scene_label': row['scene_label'],
            'scene_label_int': int(row['label_int']),
            'index': int(row['sample_index'])
        }

    # --- Stats ---
    def compute_depth_stats(self):
        """Per-image statistics of the normalized depth maps in this split.

        Depth is decoded with the same ``_decode_depth`` used by ``__getitem__``,
        but at each file's native resolution (no resize). Stats are computed
        image by image, so differing resolutions are fine and memory stays flat.

        Returns:
            (means, stds, maxs, mins): four arrays of length ``len(self)``, one
            entry per image, in split order.
        """
        means, stds, max_d, min_d = [], [], [], []
        # Warning: iterating ~10k images might be slow.
        print("Computing depth stats (this might take a moment)...")
        for depth_name in self.data['depth_name']:
            d_path = os.path.join(self.root_dir, "depths", depth_name)
            with Image.open(d_path) as dm:
                d = self._decode_depth(np.array(dm).astype(np.float32))
            means.append(d.mean())
            stds.append(d.std())
            max_d.append(d.max())
            min_d.append(d.min())

        return np.array(means), np.array(stds), np.array(max_d), np.array(min_d)

    @staticmethod
    def _rgb_mean_std(images):
        """Per-channel mean and population std over an iterable of (H, W, 3) arrays.

        Uses float64 running sums of x and x**2, so memory does not grow with
        the number of images and the images may differ in size.

        Returns:
            (mean, std): float32 arrays of shape (3,).

        Raises:
            ValueError: If ``images`` yields no pixels.
        """
        total = np.zeros(3, dtype=np.float64)
        total_sq = np.zeros(3, dtype=np.float64)
        count = 0
        for img in images:
            pixels = np.asarray(img, dtype=np.float64).reshape(-1, 3)
            total += pixels.sum(axis=0)
            total_sq += np.square(pixels).sum(axis=0)
            count += pixels.shape[0]
        if count == 0:
            raise ValueError("cannot compute color stats over zero images")
        mean = total / count
        # Var = E[x^2] - E[x]^2; clamp tiny negative round-off before sqrt.
        var = np.maximum(total_sq / count - mean ** 2, 0.0)
        return mean.astype(np.float32), np.sqrt(var).astype(np.float32)

    def compute_color_stats(self):
        """Per-channel RGB mean and std over every pixel of this split.

        Images are preprocessed like ``__getitem__``: converted to RGB, resized
        to ``size x size`` with bilinear filtering, and scaled to [0, 1].

        Returns:
            (mean_rgb, std_rgb): float32 arrays of shape (3,), in R, G, B order.

        Raises:
            ValueError: If the split is empty.
        """
        print("Computing color stats (this might take a moment)...")

        def _iter_images():
            # Yield one preprocessed (H, W, 3) image at a time.
            for image_name in self.data['image_name']:
                i_path = os.path.join(self.root_dir, "images", image_name)
                with Image.open(i_path) as im:
                    img = im.convert('RGB').resize((self.size, self.size), Image.BILINEAR)
                yield np.array(img).astype(np.float32) / 255.0

        return self._rgb_mean_std(_iter_images())

    # --- Properties ---
    @property
    def scene_labels_int(self):
        """Encoded scene labels for this split, in sample order."""
        return self.data['label_int'].values

    @property
    def scene_labels_str(self):
        """Scene label strings for this split, in sample order."""
        return self.data['scene_label'].values

    @property
    def classes(self):
        """All label strings known to the encoder (fit on the full CSV)."""
        return self.label_encoder.classes_