import os
import glob
import random
import torchvision
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
import json
import torch

class GlobDataset(Dataset):
    """Image dataset built from glob patterns under one or more root directories.

    Each item is a dict with ``"pixel_values"`` (image resized with
    ``Resize(img_size)``, center-cropped to ``img_size`` and normalized from
    [0, 1] to [-1, 1]) and, when ``vit_norm`` is set, ``"pixel_values_vit"``
    (the same image resized/cropped to ``vit_input_resolution`` and normalized
    with the ImageNet mean/std).

    Args:
        root: A directory path, or a sequence of directory paths.
        img_size: Size passed to ``Resize`` and ``CenterCrop``.
        img_glob: Glob pattern(s), matched recursively relative to each root.
            With a single root this is one pattern or a sequence of patterns.
            With several roots, a single pattern string is applied to every
            root; otherwise a sequence with one entry per root is expected,
            where each entry is a pattern or a sequence of patterns. A
            one-entry sequence is applied to every root; entries beyond the
            number of roots are ignored. Note that a flat sequence of strings
            (such as the default) therefore assigns ONE pattern per root when
            several roots are given.
        data_portion: ``()`` (keep all files), a ``(start, end)`` pair of
            fractions in [0, 1], or one such entry per root. A single entry is
            applied to every root; entries beyond the number of roots are
            ignored.
        random_data_on_portion: Shuffle each root's file list with a fixed
            seed (42) before slicing a portion, so splits are reproducible.
        vit_norm: Also produce ``"pixel_values_vit"``.
        random_flip: Horizontally flip the image with probability 0.5.
        vit_input_resolution: Resize/crop size for the ViT branch.

    Raises:
        ValueError: If fewer glob or portion entries than roots are given, or a
            portion is not ``()`` / a ``(start, end)`` pair within [0, 1].
    """

    def __init__(self, root, img_size, img_glob=('**/*.png', '**/*.jpg', '**/*.jpeg'),
                 data_portion=(), random_data_on_portion=True,
                 vit_norm=False, random_flip=False, vit_input_resolution=448):
        super().__init__()
        roots, patterns_per_root = self._normalize_globs(root, img_glob)
        portions = self._normalize_portions(data_portion, len(roots))

        self.root = roots
        self.img_size = img_size
        self.episodes = []
        self.vit_norm = vit_norm
        self.random_flip = random_flip

        for r, patterns, portion in zip(roots, patterns_per_root, portions):
            # Union of all patterns for this root: a file matched by several
            # patterns is kept once, and sorting makes the order deterministic
            # before any seeded shuffle.
            found = set()
            for pattern in patterns:
                found.update(glob.glob(os.path.join(r, pattern), recursive=True))
            self.episodes += self._select_portion(
                sorted(found), portion, random_data_on_portion)

        # resize the shortest side to img_size and center crop
        self.transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        if vit_norm:
            # Second view normalized with the ImageNet channel statistics.
            self.transform_vit = transforms.Compose([
                transforms.Resize(vit_input_resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(vit_input_resolution),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

    @staticmethod
    def _normalize_globs(root, img_glob):
        """Return ``(roots, patterns_per_root)`` as two equal-length lists."""
        def as_patterns(entry):
            # A bare string (or other non-iterable) is a single pattern.
            if isinstance(entry, str) or not hasattr(entry, '__iter__'):
                return [entry]
            return list(entry)

        if isinstance(root, str) or not hasattr(root, '__iter__'):
            # Single root: img_glob is one pattern or a collection of patterns.
            return [root], [as_patterns(img_glob)]

        roots = list(root)
        if isinstance(img_glob, str) or not hasattr(img_glob, '__iter__'):
            entries = [img_glob]
        else:
            entries = list(img_glob)
        if len(entries) == 1:
            # One entry: apply it to every root (a plain zip would either
            # iterate the characters of a string or drop the other roots).
            entries = entries * len(roots)
        elif len(entries) < len(roots):
            raise ValueError(
                f"img_glob has {len(entries)} entries but {len(roots)} roots were given")
        return roots, [as_patterns(e) for e in entries[:len(roots)]]

    @staticmethod
    def _normalize_portions(data_portion, num_roots):
        """Return a list with one portion spec (``()`` or ``(start, end)``) per root."""
        if len(data_portion) == 0 or not all(hasattr(p, '__iter__') for p in data_portion):
            # Empty, or a single (start, end) pair of numbers.
            entries = [data_portion]
        else:
            entries = list(data_portion)
        if len(entries) == 1:
            entries = entries * num_roots
        elif len(entries) < num_roots:
            raise ValueError(
                f"data_portion has {len(entries)} entries but {num_roots} roots were given")
        return entries[:num_roots]

    @staticmethod
    def _select_portion(episodes, data_p, shuffle):
        """Return the ``data_p`` slice of ``episodes`` (a new list; input untouched)."""
        if len(data_p) not in (0, 2):
            raise ValueError(f"data_portion entries must be () or (start, end), got {data_p!r}")
        episodes = list(episodes)
        if len(data_p) == 0:
            return episodes
        start, end = data_p
        if not (min(start, end) >= 0.0 and max(start, end) <= 1.0):
            raise ValueError(f"data_portion fractions must lie in [0, 1], got {data_p!r}")
        if (start, end) == (0.0, 1.0):
            return episodes
        if shuffle:
            random.Random(42).shuffle(episodes)  # fixed seed -> reproducible split
        n = len(episodes)
        return episodes[int(n * start):int(n * end)]

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, i):
        # The context manager closes the file handle; convert() returns a
        # loaded copy that stays valid afterwards.
        with Image.open(self.episodes[i]) as img:
            image = img.convert("RGB")
        if self.random_flip and random.random() > 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # The same (possibly flipped) image feeds both views.
        example = {"pixel_values": self.transform(image)}
        if self.vit_norm:
            example["pixel_values_vit"] = self.transform_vit(image)
        return example

class ClevrTexDataset(Dataset):
    """Images paired with per-object JSON labels (ClevrTex-style layout).

    Expects images in ``root/images`` and one JSON file per image in
    ``root/labels``. Each JSON holds an ``"objects"`` list, and every object
    carries all keys in ``keys_to_log`` plus ``"3d_coords"``.

    Each item is a dict with:
        ``"pixel_values"``: the image resized with ``Resize(img_size)``
            (bilinear), center-cropped and normalized to [-1, 1].
        one entry per key in ``keys_to_log``: float tensor of length
            ``max_num_objects`` holding category indices, padded by repeating
            the last object's index (all zeros when there are no objects).
        ``"visibility"``: 1.0 for real objects, 0.0 for padding slots.
        ``"num_objects"``: number of objects in the sample.
        ``"feature_vectors"``: ``(max_num_objects, total_feature_dim)`` tensor.
            Each real row concatenates one one-hot vector per key, the
            normalized ``3d_coords`` and the visibility flag; padding rows are
            zero.

    Category-to-index mappings are built over every labelled image before
    ``data_portion`` is applied, so splits of the same root share one encoding.

    Raises (from ``__getitem__``):
        ValueError: If a sample has more than ``max_num_objects`` objects.
    """

    def __init__(
        self,
        root,
        img_size,
        max_num_objects=10,
        keys_to_log=('shape', 'size', 'material'),
        img_glob='*.png',
        data_portion=(),
        random_data_on_portion=True,
    ):
        super().__init__()
        self.root = root
        self.img_size = img_size
        self.max_num_objects = max_num_objects
        self.keys_to_log = keys_to_log

        # Setup paths for images and labels
        self.image_dir = os.path.join(root, 'images')
        self.label_dir = os.path.join(root, 'labels')

        # Setup image transform
        self.transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        # Gather all image paths and ensure label files exist
        self.episodes = []
        image_files = sorted(glob.glob(os.path.join(self.image_dir, img_glob)))
        print('Found {} images in {}'.format(len(image_files), self.image_dir))

        # Create label mappings from data
        self.labels_to_idx = {k: {} for k in self.keys_to_log}
        for img_path in image_files:
            img_name = os.path.basename(img_path)
            # NOTE(review): assumes each image name is the label stem plus a
            # fixed 9-character suffix (including the extension) — confirm
            # against the dataset's naming scheme.
            label_name = img_name[:-9] + '.json'
            label_path = os.path.join(self.label_dir, label_name)

            # Images without a matching label file are skipped silently.
            if os.path.exists(label_path):
                self.episodes.append((img_path, label_path))
                # Read label file to gather unique values
                with open(label_path, 'r') as f:
                    label_data = json.load(f)
                    for obj in label_data['objects']:
                        for key in self.keys_to_log:
                            value = obj[key]
                            if value not in self.labels_to_idx[key]:
                                self.labels_to_idx[key][value] = len(self.labels_to_idx[key])

        print("Created label mappings:")
        for key, mapping in self.labels_to_idx.items():
            print(f"{key}: {len(mapping)} unique values")
            print(mapping)

        if data_portion and data_portion != (0., 1.):
            if random_data_on_portion:
                random.Random(42).shuffle(self.episodes)  # fixed seed -> reproducible split
            start_idx = int(len(self.episodes) * data_portion[0])
            end_idx = int(len(self.episodes) * data_portion[1])
            self.episodes = self.episodes[start_idx:end_idx]

        # Calculate total feature dimension for one-hot encoding
        self.feature_dims = {k: len(v) for k, v in self.labels_to_idx.items()}
        self.total_feature_dim = sum(self.feature_dims.values()) + 4  # +3 for coords, +1 for visibility
        print(f"Total feature dimension: {self.total_feature_dim}")
        print(f"Feature dimensions: {self.feature_dims}")

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        img_path, label_path = self.episodes[idx]

        # Load and transform image; the context manager closes the file handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        example = {"pixel_values": self.transform(image)}

        # Load and process labels directly from JSON
        with open(label_path, 'r') as f:
            label_data = json.load(f)
        example.update(self._encode_labels(label_data))
        return example

    def _encode_labels(self, label_data):
        """Turn one parsed label JSON into padded per-object tensors."""
        objects = label_data['objects']
        num_objects = len(objects)
        if num_objects > self.max_num_objects:
            # Without this check padding is a no-op and the sample's tensors
            # come out longer than max_num_objects, breaking batching later.
            raise ValueError(
                f"Sample has {num_objects} objects but max_num_objects is {self.max_num_objects}")
        num_pad = self.max_num_objects - num_objects

        # Category index of every object, per logged key.
        indices = {
            key: [self.labels_to_idx[key][obj[key]] for obj in objects]
            for key in self.keys_to_log
        }

        encoded = {}
        for key in self.keys_to_log:
            key_indices = indices[key]
            if key_indices:
                # Pad by repeating the last object's index.
                encoded[key] = torch.tensor(key_indices + [key_indices[-1]] * num_pad).float()
            else:
                encoded[key] = torch.zeros(self.max_num_objects).float()
        encoded['visibility'] = torch.tensor([1.0] * num_objects + [0.0] * num_pad).float()
        encoded['num_objects'] = num_objects

        # One row per object: one-hot vector per key, normalized 3d_coords,
        # visibility flag.
        feature_vectors = []
        for obj_idx, obj in enumerate(objects):
            parts = []
            for key in self.keys_to_log:
                one_hot = torch.zeros(self.feature_dims[key])
                one_hot[indices[key][obj_idx]] = 1.0
                parts.append(one_hot)
            coords = torch.tensor(obj['3d_coords'])
            # Maps [-3, 3] to [0, 1]; assumes coordinates stay in that range.
            parts.append((coords + 3.) / 6)
            parts.append(torch.tensor([1.0]))
            feature_vectors.append(torch.cat(parts))

        zero_vector = torch.zeros(self.total_feature_dim)
        feature_vectors.extend([zero_vector] * num_pad)
        encoded['feature_vectors'] = torch.stack(feature_vectors)
        return encoded

class CelebADataset(Dataset):
    """Dataset for CelebA facial attributes classification.

    Reads the attribute annotation file ``root/attr_file`` and pairs every
    ``*.jpg`` found under ``root`` (searched recursively) with its binary
    attribute vector.

    Each item is a dict with ``"pixel_values"`` (image resized with
    ``Resize(img_size)``, center-cropped to ``img_size`` and normalized to
    [-1, 1]) and ``"labels"`` (float32 tensor with one 0/1 entry per
    attribute; all zeros, with a printed warning, when no annotation row
    matches the image name).
    """

    def __init__(
        self,
        root,
        img_size,
        attr_file='CelebAMask-HQ-attribute-anno.txt',
        data_portion=(),
        random_data_on_portion=True,
    ):
        super().__init__()
        self.root = root
        self.img_size = img_size

        # Read attribute annotations
        attr_path = os.path.join(root, attr_file)
        self.attr_names, self.attr_map = self._read_attr_file(attr_path)
        print(f"Loaded {len(self.attr_map)} images with {len(self.attr_names)} attributes")

        # Get all image paths. recursive=True makes '**' match any depth
        # (including root itself); without it '**' acts like a single '*'.
        self.episodes = sorted(glob.glob(os.path.join(root, '**/*.jpg'), recursive=True))
        print(f"Found {len(self.episodes)} images in {root}")

        # Apply data portion if specified
        if data_portion and data_portion != (0., 1.):
            if random_data_on_portion:
                random.Random(42).shuffle(self.episodes)  # fixed seed -> reproducible split
            start_idx = int(len(self.episodes) * data_portion[0])
            end_idx = int(len(self.episodes) * data_portion[1])
            self.episodes = self.episodes[start_idx:end_idx]

        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    @staticmethod
    def _read_attr_file(attr_path):
        """Parse an attribute file into ``(attr_names, attr_map)``.

        Line 1 holds the image count (parsed only to validate the header),
        line 2 the attribute names, and every further non-blank line is
        ``<image_name> <v_1> ... <v_k>`` with values in {-1, 1}, mapped to {0, 1}.
        """
        attr_map = {}
        with open(attr_path, 'r') as f:
            int(f.readline().strip())  # image count; raises on a malformed header
            attr_names = f.readline().strip().split()
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                attr_map[parts[0]] = [(int(x) + 1) // 2 for x in parts[1:]]
        return attr_names, attr_map

    def _lookup_attrs(self, img_name):
        """Return the 0/1 attribute list for ``img_name``, or zeros if none matches."""
        attrs = self.attr_map.get(img_name)
        if attrs is not None:
            return attrs

        # Fall back to alternative spellings of the numeric part of the name.
        # Only attempted when the name contains digits; otherwise the
        # candidates would degenerate to ".jpg", "00000.jpg", ... and could
        # match an unrelated row.
        numeric_part = ''.join(filter(str.isdigit, os.path.splitext(img_name)[0]))
        if numeric_part:
            for alt_name in (
                f"{numeric_part}.jpg",
                f"{numeric_part.zfill(5)}.jpg",
                f"{numeric_part.zfill(6)}.jpg",
                f"image{numeric_part}.jpg",
            ):
                if alt_name in self.attr_map:
                    return self.attr_map[alt_name]

        # If still not found, use zeros as a fallback
        print(f"Warning: No attributes found for {img_name}")
        return [0] * len(self.attr_names)

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        img_path = self.episodes[idx]
        img_name = os.path.basename(img_path)

        # Load and transform image; the context manager closes the file handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        pixel_values = self.transform(image)

        attrs = torch.tensor(self._lookup_attrs(img_name), dtype=torch.float32)
        return {
            "pixel_values": pixel_values,
            "labels": attrs
        }

if __name__ == "__main__":
    # Manual smoke test: index the MOVi-E training frames (machine-specific
    # path) and keep the [0, 0.9) portion of the seeded shuffle.
    movi_e_train_images = "/research/projects/object_centric/shared_datasets/movi/movi-e/movi-e-train-with-label/images/"
    smoke_test_dataset = GlobDataset(
        movi_e_train_images,
        256,
        img_glob="**/*.png",
        data_portion=(0.0, 0.9),
    )