import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import json
from tqdm import tqdm
import numpy as np
from scipy.ndimage import sobel  # For Sobel gradient (Methodology §3.2.1)


# ------------------------------
# Image Transformations (Matches Eq.1: conv + Sobel features)
# ------------------------------
class SobelTransform:
    """Append per-channel Sobel gradient magnitudes to an image tensor (§3.2.1, Eq.1).

    For an input of shape ``[C, H, W]`` the output is:

    * ``[2*C, H, W]`` when ``concat_input`` is True — the input channels
      followed by their gradient magnitudes (image + structural features);
    * ``[C, H, W]`` when ``concat_input`` is False — magnitudes only.

    Args:
        concat_input: keep the input channels in front of the Sobel channels.
            Defaults to True so the transformed image carries both RGB and
            gradient channels, as the Eq.1 input expects. Pass False for the
            magnitude-only output.
    """

    def __init__(self, concat_input=True):
        self.concat_input = concat_input

    def __call__(self, img_tensor):
        # Sobel is linear, so any value range works; get_transforms() applies
        # this after Normalize, i.e. on the mean/std-normalized tensor.
        channels = img_tensor.detach().cpu().numpy()
        # Filter each channel as a 2-D image: scipy's sobel smooths along every
        # non-derivative axis, so filtering the [C, H, W] stack directly would
        # also blur across channels.
        grad_x = np.stack([sobel(ch, axis=1) for ch in channels], axis=0)
        grad_y = np.stack([sobel(ch, axis=0) for ch in channels], axis=0)
        magnitude = torch.from_numpy(np.hypot(grad_x, grad_y)).to(
            dtype=img_tensor.dtype, device=img_tensor.device
        )
        if not self.concat_input:
            return magnitude
        return torch.cat([img_tensor, magnitude], dim=0)


def get_transforms(img_size=224, is_train=True):
    """Build the image preprocessing pipeline used by every dataset in this module.

    Eval:  Resize(img_size, bicubic) -> ToTensor -> Normalize -> SobelTransform
    Train: RandomResizedCrop(img_size, scale 0.8-1.0) -> RandomHorizontalFlip
           -> ToTensor -> Normalize -> SobelTransform

    No learned (conv) features are computed here; SobelTransform supplies the
    hand-crafted structural features of Eq.1 and decides the output channel layout.

    Args:
        img_size: side length of the square output image.
        is_train: enable random crop/flip augmentation.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a tensor.
    """
    if is_train:
        # RandomResizedCrop already yields img_size x img_size, so the extra
        # Resize the eval path needs would only be a same-size copy here.
        spatial = [
            transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # InterpolationMode enum instead of the PIL int constant, which
        # torchvision deprecated for the `interpolation` argument.
        spatial = [
            transforms.Resize((img_size, img_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
        ]

    return transforms.Compose([
        *spatial,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
        # Structural (Sobel) features for Eq.1; applied to the normalized tensor.
        SobelTransform(),
    ])


# ------------------------------
# 1. ImageNet-1K Dataset (Zero-Shot Classification §4.2.1)
# ------------------------------
class ImageNetDataset(Dataset):
    """ImageNet-1K folder dataset for zero-shot classification (§4.2.1).

    Expects ``root/<split>/<class_name>/<image files>``. Labels are the
    0-based positions of the class names in ``class_names_path``.

    Args:
        root: dataset root directory.
        split: sub-folder of ``root`` to read (``'train'`` enables augmentation).
        img_size: output image side length.
        class_names_path: text file with one class folder name per line;
            blank lines are ignored.

    Raises:
        FileNotFoundError: if a listed class folder does not exist.
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root, split='val', img_size=224, class_names_path='imagenet_classes.txt'):
        self.root = root
        self.split = split
        self.img_size = img_size
        self.transform = get_transforms(img_size, is_train=(split == 'train'))

        # Fixed class order from the file -> label indices are stable across runs.
        self.classes = self._read_class_names(class_names_path)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = self._collect_samples(
            root,
            split,
            tqdm(self.classes, desc=f"Loading ImageNet {split}"),
            self.class_to_idx,
        )

    @staticmethod
    def _read_class_names(path):
        """Return the stripped, non-empty lines of *path* in file order."""
        # Skipping blanks avoids a bogus '' class from e.g. a trailing empty line.
        with open(path, 'r', encoding='utf-8') as f:
            return [name for name in (line.strip() for line in f) if name]

    @classmethod
    def _collect_samples(cls, root, split, classes, class_to_idx):
        """List ``(image_path, label)`` for every image file under each class folder.

        Files within a class folder are sorted by name, because ``os.listdir``
        order is filesystem-dependent and would make sample order irreproducible.
        """
        samples = []
        for class_name in classes:
            cls_dir = os.path.join(root, split, class_name)
            if not os.path.isdir(cls_dir):
                raise FileNotFoundError(f"Class dir {cls_dir} not found (check ImageNet root)")
            label = class_to_idx[class_name]
            for img_name in sorted(os.listdir(cls_dir)):
                if img_name.lower().endswith(cls.IMG_EXTENSIONS):
                    samples.append((os.path.join(cls_dir, img_name), label))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        try:
            # Context manager releases the file handle deterministically.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Failed to load {img_path}: {e}") from e

        # Transformed tensor; its channel layout is set by get_transforms()
        # (which ends with SobelTransform).
        img_tensor = self.transform(image)
        return img_tensor, label


# ------------------------------
# 2. COCO Dataset (Cross-Modal Retrieval §4.2.2: Karpathy 1K Split)
# ------------------------------
class COCODataset(Dataset):
    """COCO image-caption pairs restricted to a Karpathy split (§4.2.2).

    Items are ``(image_tensor, caption, img_id)``. For ``split='train'`` every
    caption of an image becomes its own sample; for other splits each image
    contributes a single sample with its first caption ("" if it has none).

    Args:
        root: COCO root directory.
        ann_file: COCO-format captions JSON with ``images`` and ``annotations``.
        karpathy_split_file: JSON object mapping split name -> list of image ids.
        split: split name to read from ``karpathy_split_file``.
        img_size: output image side length.
        img_subdir: image folder under ``root``. Defaults to ``f"{split}2017"``
            (the previous hard-coded behaviour). Pass it explicitly when the
            split's images live elsewhere — e.g. COCO's ``test2017`` folder
            holds the unannotated test images, not the Karpathy test split.

    Raises:
        KeyError: if ``split`` is not a key of ``karpathy_split_file``.
    """

    def __init__(self, root, ann_file, karpathy_split_file, split='val', img_size=224,
                 img_subdir=None):
        self.root = root
        self.split = split
        self.img_size = img_size
        self.transform = get_transforms(img_size, is_train=(split == 'train'))

        # Karpathy split ids (required for the §4.1 evaluation protocol).
        with open(karpathy_split_file, 'r', encoding='utf-8') as f:
            splits = json.load(f)
        if split not in splits:
            raise KeyError(
                f"Split '{split}' not found in {karpathy_split_file} "
                f"(available: {sorted(splits)})"
            )
        # COCO annotation ids are ints; split files may store them as strings,
        # which would otherwise silently match nothing.
        self.karpathy_img_ids = {int(img_id) for img_id in splits[split]}

        with open(ann_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if img_subdir is None:
            img_subdir = f"{split}2017"
        self.img_id_to_info = self._index_images(
            data, self.karpathy_img_ids, os.path.join(root, img_subdir)
        )
        self.samples = self._build_samples(self.img_id_to_info, all_captions=(split == 'train'))

    @staticmethod
    def _index_images(data, img_ids, img_dir):
        """Map each wanted image id to ``{'path': ..., 'captions': [...]}``.

        Captions are collected from ``data['annotations']`` in file order;
        annotations of images outside *img_ids* are ignored.
        """
        img_id_to_info = {
            img['id']: {'path': os.path.join(img_dir, img['file_name']), 'captions': []}
            for img in data['images']
            if img['id'] in img_ids
        }
        for ann in data['annotations']:
            info = img_id_to_info.get(ann['image_id'])
            if info is not None:
                info['captions'].append(ann['caption'])
        return img_id_to_info

    @staticmethod
    def _build_samples(img_id_to_info, all_captions):
        """Expand the index into ``(img_id, path, caption)`` samples.

        With *all_captions* every caption is a sample; otherwise one sample per
        image using its first caption, or "" when it has none.
        """
        samples = []
        for img_id, info in img_id_to_info.items():
            captions = info['captions']
            if all_captions:
                samples.extend((img_id, info['path'], cap) for cap in captions)
            else:
                samples.append((img_id, info['path'], captions[0] if captions else ""))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_id, img_path, caption = self.samples[idx]
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Channel layout is set by get_transforms() (ends with SobelTransform).
        img_tensor = self.transform(image)
        return img_tensor, caption, img_id


# ------------------------------
# 3. ADE20K Dataset (Fine-Grained Alignment §4.2.3)
# ------------------------------
class ADE20KDataset(Dataset):
    """ADE20K samples for fine-grained region/text alignment (§4.2.3).

    Each item is a dict with the transformed ``image``, the resized
    segmentation ``mask``, the first ``caption`` (or "") and one
    ``object_texts`` entry per annotated object.

    Args:
        root: directory that ``file_name`` / ``mask_file_name`` are relative to.
        ann_file: JSON list of records with keys ``split``, ``file_name``,
            ``mask_file_name`` and optionally ``captions`` and ``objects``
            (``[{"name": ..., "attributes": [...]}]``).
        split: records whose ``split`` equals this value are kept.
        img_size: side length of the square image and mask.
    """

    def __init__(self, root, ann_file, split='val', img_size=224):
        self.root = root
        self.split = split
        self.img_size = img_size
        self.transform = get_transforms(img_size, is_train=(split == 'train'))
        # Nearest-neighbour resizing keeps only values present in the source
        # mask (bilinear would blend neighbouring labels at region borders);
        # ToTensor afterwards gives the same float [1, H, W] scaled by 1/255.
        self.mask_transform = transforms.Compose([
            transforms.Resize((img_size, img_size),
                              interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

        with open(ann_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        self.samples = self._collect_samples(self.annotations, root, split)

    @staticmethod
    def _collect_samples(annotations, root, split):
        """Keep records of *split* whose image and mask files both exist."""
        samples = []
        for item in annotations:
            if item['split'] != split:
                continue
            img_path = os.path.join(root, item['file_name'])
            mask_path = os.path.join(root, item['mask_file_name'])
            if not (os.path.exists(img_path) and os.path.exists(mask_path)):
                continue
            captions = item.get('captions') or []
            samples.append({
                'img_path': img_path,
                'mask_path': mask_path,
                'caption': captions[0] if captions else "",
                # Object info for text-driven mask transfer.
                'objects': item.get('objects') or [],
            })
        return samples

    @staticmethod
    def _object_texts(objects):
        """Render each object as ``"name attr1 attr2 ..."`` (no trailing space without attributes)."""
        return [' '.join([str(obj['name']), *(obj.get('attributes') or [])]) for obj in objects]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        with Image.open(sample['img_path']) as img:
            image = img.convert('RGB')
        # NOTE(review): if masks are RGB-encoded class maps, converting to 'L'
        # mixes channels and loses labels — confirm the mask file format.
        with Image.open(sample['mask_path']) as m:
            mask = m.convert('L')

        # NOTE(review): for split='train', self.transform applies
        # RandomResizedCrop/RandomHorizontalFlip to the image only, so the mask
        # is not spatially aligned with it. TODO: use a joint image/mask transform.
        img_tensor = self.transform(image)
        mask_tensor = self.mask_transform(mask)

        return {
            'image': img_tensor,
            'mask': mask_tensor,
            'caption': sample['caption'],
            'object_texts': self._object_texts(sample['objects']),
        }


# ------------------------------
# Dataset Factory (Matches paper's 3 core datasets)
# ------------------------------
def get_dataset(name, **kwargs):
    """Instantiate one of the paper's three core datasets by short name.

    Args:
        name: ``'imagenet'``, ``'coco'`` or ``'ade20k'``.
        **kwargs: forwarded unchanged to the selected dataset's constructor.

    Raises:
        ValueError: if *name* is not a registered dataset.
    """
    registry = {
        'imagenet': ImageNetDataset,
        'coco': COCODataset,
        'ade20k': ADE20KDataset,
    }
    dataset_cls = registry.get(name)
    if dataset_cls is None:
        raise ValueError(f"Dataset {name} not supported. Choose {list(registry.keys())}")
    return dataset_cls(**kwargs)