import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import pandas as pd
import os
from sklearn.preprocessing import LabelEncoder

class NYUDepthV2Dataset(Dataset):
    """RGB / depth / scene-label samples listed in ``<root_dir>/metadata.csv``.

    The CSV is read once in ``__init__``. The columns this class reads are
    ``scene_label``, ``image_name``, ``depth_name`` and ``sample_index``.
    Images are loaded from ``<root_dir>/images`` and depth maps from
    ``<root_dir>/depths``.

    Attributes:
        root_dir: Dataset root directory.
        size: Side length of the square that images and depths are resized to.
        transform: Optional callable applied to the image tensor and to the
            depth tensor (two independent calls).
        metadata: Full CSV as a DataFrame, plus an integer ``label_int`` column.
        label_encoder: ``LabelEncoder`` fitted on every ``scene_label`` in the
            CSV, so the string -> int mapping is the same for every split.
        data: The rows belonging to the requested split.
    """

    # Depth values are clipped to [0, _MAX_DEPTH] and divided by it -> [0, 1].
    _MAX_DEPTH = 10.0
    # A depth map whose maximum exceeds this is divided by 1000 before clipping
    # (the original "mm to meters" heuristic).
    _MM_THRESHOLD = 255.0

    def __init__(self, root_dir, split='train', size=256, transform=None):
        """Read the metadata CSV, encode labels and select the split.

        Args:
            root_dir: Directory containing ``metadata.csv``, ``images/`` and
                ``depths/``.
            split: ``'train'`` selects the first 90% of CSV rows; any other
                value selects the remaining rows.
            size: Target width and height for loaded images and depth maps.
            transform: Optional callable applied in ``__getitem__``.

        Raises:
            FileNotFoundError: If ``metadata.csv`` is not present in ``root_dir``.
        """
        self.root_dir = root_dir
        self.size = size
        self.transform = transform

        # 1. Load Metadata CSV (The Single Source of Truth)
        csv_path = os.path.join(root_dir, "metadata.csv")
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Metadata not found at {csv_path}. Run download_and_extract.py first!")

        self.metadata = pd.read_csv(csv_path)

        # 2. Fit the encoder on the full CSV *before* splitting so that the
        #    integer ids do not depend on which split is requested.
        self.label_encoder = LabelEncoder()
        self.metadata['label_int'] = self.label_encoder.fit_transform(self.metadata['scene_label'])

        # 3. Positional 90/10 split that keeps the CSV row order (no shuffle).
        # NOTE(review): every value other than 'train' (typos included) falls
        # through to the validation tail -- confirm this is intended.
        split_idx = int(0.9 * len(self.metadata))
        if split == 'train':
            self.data = self.metadata.iloc[:split_idx].reset_index(drop=True)
        else:
            self.data = self.metadata.iloc[split_idx:].reset_index(drop=True)

        print(f"[NYUDepthV2Dataset] {split}: {len(self.data)} samples loaded from CSV.")

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Positional index into the selected split.

        Returns:
            A dict with keys ``'image'`` (float32 tensor, 3 x size x size,
            values in [0, 1]), ``'depth'`` (float32 tensor, 1 x size x size,
            values in [0, 1]), ``'scene_label'`` (raw CSV value),
            ``'scene_label_int'`` (int) and ``'index'`` (int, from the
            ``sample_index`` column). Tensor shapes are those before
            ``transform`` is applied.
        """
        row = self.data.iloc[idx]

        # HWC float array -> CHW tensor.
        img_array = self._load_rgb(row['image_name'], Image.BILINEAR)
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)

        # Nearest-neighbour resize (inside the loader) avoids blending depth
        # values across object boundaries.
        depth_array = self._load_depth(row['depth_name'], resize=True)
        depth_tensor = torch.from_numpy(depth_array).unsqueeze(0)

        if self.transform is not None:
            # NOTE(review): the transform is invoked separately for image and
            # depth, so a random transform would not stay spatially aligned
            # between the two -- confirm callers pass deterministic transforms.
            img_tensor = self.transform(img_tensor)
            depth_tensor = self.transform(depth_tensor)

        return {
            'image': img_tensor,
            'depth': depth_tensor,
            'scene_label': row['scene_label'],
            'scene_label_int': int(row['label_int']),
            'index': int(row['sample_index'])
        }

    # --- Loading helpers ---
    @classmethod
    def _normalize_depth(cls, depth_array):
        """Convert a raw depth array to float32 values in [0, 1]."""
        depth = np.asarray(depth_array, dtype=np.float32)
        # mm to meters check
        if depth.size and depth.max() > cls._MM_THRESHOLD:
            depth = depth / 1000.0
        return np.clip(depth, 0, cls._MAX_DEPTH) / cls._MAX_DEPTH

    def _load_rgb(self, image_name, resample=None):
        """Load an image as an (size, size, 3) float32 array in [0, 1]."""
        path = os.path.join(self.root_dir, "images", image_name)
        target = (self.size, self.size)
        with Image.open(path) as src:
            rgb = src.convert('RGB')
            # ``resample=None`` keeps PIL's default filter for this call.
            rgb = rgb.resize(target) if resample is None else rgb.resize(target, resample)
            return np.asarray(rgb, dtype=np.float32) / 255.0

    def _load_depth(self, depth_name, resize):
        """Load a depth map, optionally resize it, and normalise to [0, 1]."""
        path = os.path.join(self.root_dir, "depths", depth_name)
        with Image.open(path) as src:
            depth_img = src.resize((self.size, self.size), Image.NEAREST) if resize else src
            raw = np.asarray(depth_img, dtype=np.float32)
        return self._normalize_depth(raw)

    # --- Stats ---
    def compute_depth_stats(self):
        """Compute per-image statistics of the normalised depth maps.

        Depth maps are read at their stored resolution (no resize) and reduced
        one at a time, so memory use does not grow with the split size and
        maps of differing resolution are accepted.

        Returns:
            Four 1-D float32 arrays with one entry per row of the split:
            ``(mean, std, max, min)``. All are empty for an empty split.
        """
        means, stds, maxima, minima = [], [], [], []
        for depth_name in self.data['depth_name']:
            d = self._load_depth(depth_name, resize=False)
            means.append(d.mean())
            stds.append(d.std())
            maxima.append(d.max())
            minima.append(d.min())

        return tuple(np.array(v, dtype=np.float32) for v in (means, stds, maxima, minima))

    def compute_color_stats(self):
        """Compute per-image colour statistics of the resized RGB images.

        Images are resized to ``size`` x ``size`` with PIL's default filter
        (``__getitem__`` uses bilinear) and reduced one at a time.

        Returns:
            ``(mean_rgb, std_rgb)``: two 1-D float32 arrays with one entry per
            row of the split. ``mean_rgb[i]`` is the mean of image ``i``'s
            three per-channel means; ``std_rgb[i]`` is the standard deviation
            of its three per-channel standard deviations. Both are empty for
            an empty split.
        """
        # NOTE(review): despite the ``_rgb`` names, the original reduction
        # yields one scalar per image rather than one value per channel, and
        # the second value is a std-of-stds. That behaviour is preserved here;
        # confirm it is what downstream code expects.
        mean_rgb, std_rgb = [], []
        for image_name in self.data['image_name']:
            pixels = self._load_rgb(image_name).reshape(-1, 3)  # (H*W, 3)
            mean_rgb.append(pixels.mean(axis=0).mean())
            std_rgb.append(pixels.std(axis=0).std())

        return np.array(mean_rgb, dtype=np.float32), np.array(std_rgb, dtype=np.float32)

    # --- Properties ---
    @property
    def scene_labels_int(self):
        """Integer-encoded scene labels of the selected split."""
        return self.data['label_int'].values

    @property
    def scene_labels_str(self):
        """Raw scene labels of the selected split."""
        return self.data['scene_label'].values

    @property
    def classes(self):
        """Label classes known to the encoder (fitted on the full CSV)."""
        return self.label_encoder.classes_