"""Dataset classes for Office-Home and DomainNet."""

import os
import warnings
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union

import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class OfficeHomeDataset(Dataset):
    """Office-Home dataset with domain-aware splitting.

    Each sample is stored as a dict with ``'path'``, ``'label'`` and
    ``'domain'`` keys. Samples come either from a CSV index filtered by
    ``domain`` and ``split``, or, when no usable index is given, from scanning
    ``<data_root>/<domain>/<class_name>/<image>``.
    """

    # Lower-case file extensions accepted when scanning class directories.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(
        self,
        data_root: str,
        domain: str,
        split: str = "train",
        transform: Optional[transforms.Compose] = None,
        target_transform: Optional[transforms.Compose] = None,
        indices_file: Optional[str] = None,
        toy_mode: bool = False,
        toy_classes: int = 2,
        toy_samples_per_class: int = 200
    ):
        """Initialize Office-Home dataset.

        Args:
            data_root: Root directory containing domain folders
            domain: Domain name (Art, Clipart, Product, RealWorld)
            split: Split name (train, val, test). Only used when loading from
                ``indices_file``; a directory scan returns every image of the
                domain regardless of ``split``.
            transform: Image transformations
            target_transform: Label transformations
            indices_file: CSV file with image paths and labels. If it is given
                but does not exist, a warning is emitted and the dataset falls
                back to a directory scan.
            toy_mode: Whether to use toy mode with reduced data
            toy_classes: Number of classes in toy mode
            toy_samples_per_class: Samples per class in toy mode

        Raises:
            ValueError: If the directory scan is used and the domain
                directory does not exist.
        """
        self.data_root = Path(data_root)
        self.domain = domain
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.toy_mode = toy_mode

        if indices_file and os.path.exists(indices_file):
            self._load_from_indices(indices_file)
        else:
            if indices_file:
                # A silent fallback would ignore `split` (and could mix test
                # images into training), so make it visible to the caller.
                warnings.warn(
                    f"Indices file not found: {indices_file}; scanning "
                    f"{self.data_root / self.domain} instead (split "
                    f"'{self.split}' is ignored).",
                    stacklevel=2,
                )
            self._load_from_directory()

        if toy_mode:
            self._create_toy_subset(toy_classes, toy_samples_per_class)

    def _load_from_indices(self, indices_file: str):
        """Load the samples of this domain/split from a CSV index.

        The CSV must provide ``img_path`` (relative to ``data_root``),
        ``label``, ``domain`` and ``split`` columns. Rows whose image file is
        missing on disk are skipped, and the number skipped is reported.
        """
        df = pd.read_csv(indices_file)
        domain_df = df[(df['domain'] == self.domain) & (df['split'] == self.split)]

        self.samples = []
        num_missing = 0
        # Column-wise zip avoids building a Series per row as iterrows does.
        for rel_path, label, domain in zip(
            domain_df['img_path'], domain_df['label'], domain_df['domain']
        ):
            img_path = os.path.join(self.data_root, rel_path)
            if not os.path.exists(img_path):
                num_missing += 1
                continue
            self.samples.append({
                'path': img_path,
                'label': label,
                'domain': domain
            })

        print(f"Loaded {len(self.samples)} samples from {self.domain} ({self.split})")
        if num_missing:
            print(f"Skipped {num_missing} index entries whose image file was not found")

    def _load_from_directory(self):
        """Load samples by scanning ``<data_root>/<domain>/<class_name>/``.

        Class indices follow the sorted order of the class sub-directory
        names. Plain files and hidden directories are ignored so they cannot
        shift the label indices. Images are listed in sorted order so that
        the sample order, and therefore any toy subset, is reproducible.
        """
        domain_path = self.data_root / self.domain
        if not domain_path.is_dir():
            raise ValueError(f"Domain directory not found: {domain_path}")

        class_names = sorted(
            entry.name
            for entry in domain_path.iterdir()
            if entry.is_dir() and not entry.name.startswith('.')
        )

        self.samples = []
        for class_idx, class_name in enumerate(class_names):
            class_path = domain_path / class_name
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append({
                        'path': str(class_path / img_name),
                        'label': class_idx,
                        'domain': self.domain
                    })

        print(f"Loaded {len(self.samples)} samples from {self.domain} ({self.split})")

    def _create_toy_subset(self, num_classes: int, samples_per_class: int):
        """Create a toy subset for quick testing.

        Keeps the first ``samples_per_class`` samples of each of the
        ``num_classes`` smallest labels. Labels are NOT remapped, so if a low
        label is absent, the kept labels may exceed ``num_classes - 1``.
        """
        if not self.samples:
            return

        # Group samples by label, preserving their current order.
        class_groups = {}
        for sample in self.samples:
            class_groups.setdefault(sample['label'], []).append(sample)

        selected_classes = sorted(class_groups)[:num_classes]
        self.samples = [
            sample
            for label in selected_classes
            for sample in class_groups[label][:samples_per_class]
        ]
        print(f"Created toy subset: {len(self.samples)} samples from {len(selected_classes)} classes")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` for the sample at ``idx``.

        ``image`` is an RGB PIL image unless ``transform`` converts it, and
        ``label`` passes through ``target_transform`` when one is set.
        """
        sample = self.samples[idx]
        label = sample['label']

        # The context manager closes the underlying file handle deterministically.
        with Image.open(sample['path']) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label


class DomainNetDataset(Dataset):
    """DomainNet dataset with domain-aware splitting.

    Each sample is stored as a dict with ``'path'``, ``'label'`` and
    ``'domain'`` keys. Samples come either from a CSV index filtered by
    ``domain`` and ``split``, or, when no usable index is given, from scanning
    ``<data_root>/<domain>/<class_name>/<image>``.
    """

    # Lower-case file extensions accepted when scanning class directories.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(
        self,
        data_root: str,
        domain: str,
        split: str = "train",
        transform: Optional[transforms.Compose] = None,
        target_transform: Optional[transforms.Compose] = None,
        indices_file: Optional[str] = None,
        toy_mode: bool = False,
        toy_classes: int = 2,
        toy_samples_per_class: int = 200
    ):
        """Initialize DomainNet dataset.

        Args:
            data_root: Root directory containing domain folders
            domain: Domain name (clipart, infograph, painting, quickdraw, real, sketch)
            split: Split name (train, val, test). Only used when loading from
                ``indices_file``; a directory scan returns every image of the
                domain regardless of ``split``.
            transform: Image transformations
            target_transform: Label transformations
            indices_file: CSV file with image paths and labels. If it is given
                but does not exist, a warning is emitted and the dataset falls
                back to a directory scan.
            toy_mode: Whether to use toy mode with reduced data
            toy_classes: Number of classes in toy mode
            toy_samples_per_class: Samples per class in toy mode

        Raises:
            ValueError: If the directory scan is used and the domain
                directory does not exist.
        """
        self.data_root = Path(data_root)
        self.domain = domain
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.toy_mode = toy_mode

        if indices_file and os.path.exists(indices_file):
            self._load_from_indices(indices_file)
        else:
            if indices_file:
                # A silent fallback would ignore `split` (and could mix test
                # images into training), so make it visible to the caller.
                warnings.warn(
                    f"Indices file not found: {indices_file}; scanning "
                    f"{self.data_root / self.domain} instead (split "
                    f"'{self.split}' is ignored).",
                    stacklevel=2,
                )
            self._load_from_directory()

        if toy_mode:
            self._create_toy_subset(toy_classes, toy_samples_per_class)

    def _load_from_indices(self, indices_file: str):
        """Load the samples of this domain/split from a CSV index.

        The CSV must provide ``img_path`` (relative to ``data_root``),
        ``label``, ``domain`` and ``split`` columns. Rows whose image file is
        missing on disk are skipped, and the number skipped is reported.
        """
        df = pd.read_csv(indices_file)
        domain_df = df[(df['domain'] == self.domain) & (df['split'] == self.split)]

        self.samples = []
        num_missing = 0
        # Column-wise zip avoids building a Series per row as iterrows does.
        for rel_path, label, domain in zip(
            domain_df['img_path'], domain_df['label'], domain_df['domain']
        ):
            img_path = os.path.join(self.data_root, rel_path)
            if not os.path.exists(img_path):
                num_missing += 1
                continue
            self.samples.append({
                'path': img_path,
                'label': label,
                'domain': domain
            })

        print(f"Loaded {len(self.samples)} samples from {self.domain} ({self.split})")
        if num_missing:
            print(f"Skipped {num_missing} index entries whose image file was not found")

    def _load_from_directory(self):
        """Load samples by scanning ``<data_root>/<domain>/<class_name>/``.

        Class indices follow the sorted order of the class sub-directory
        names. Plain files and hidden directories are ignored so they cannot
        shift the label indices. Images are listed in sorted order so that
        the sample order, and therefore any toy subset, is reproducible.
        """
        domain_path = self.data_root / self.domain
        if not domain_path.is_dir():
            raise ValueError(f"Domain directory not found: {domain_path}")

        class_names = sorted(
            entry.name
            for entry in domain_path.iterdir()
            if entry.is_dir() and not entry.name.startswith('.')
        )

        self.samples = []
        for class_idx, class_name in enumerate(class_names):
            class_path = domain_path / class_name
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append({
                        'path': str(class_path / img_name),
                        'label': class_idx,
                        'domain': self.domain
                    })

        print(f"Loaded {len(self.samples)} samples from {self.domain} ({self.split})")

    def _create_toy_subset(self, num_classes: int, samples_per_class: int):
        """Create a toy subset for quick testing.

        Keeps the first ``samples_per_class`` samples of each of the
        ``num_classes`` smallest labels. Labels are NOT remapped, so if a low
        label is absent, the kept labels may exceed ``num_classes - 1``.
        """
        if not self.samples:
            return

        # Group samples by label, preserving their current order.
        class_groups = {}
        for sample in self.samples:
            class_groups.setdefault(sample['label'], []).append(sample)

        selected_classes = sorted(class_groups)[:num_classes]
        self.samples = [
            sample
            for label in selected_classes
            for sample in class_groups[label][:samples_per_class]
        ]
        print(f"Created toy subset: {len(self.samples)} samples from {len(selected_classes)} classes")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` for the sample at ``idx``.

        ``image`` is an RGB PIL image unless ``transform`` converts it, and
        ``label`` passes through ``target_transform`` when one is set.
        """
        sample = self.samples[idx]
        label = sample['label']

        # The context manager closes the underlying file handle deterministically.
        with Image.open(sample['path']) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label


def get_transforms(
    image_size: int = 224,
    is_training: bool = True,
    normalize: bool = True,
    resize_ratio: float = 256 / 224
) -> transforms.Compose:
    """Get image transformations.

    Training pipeline: random resized crop (80-100% of the image area),
    random horizontal flip and colour jitter. Evaluation pipeline: resize to
    ``round(image_size * resize_ratio)`` and then center-crop to
    ``image_size``.

    Args:
        image_size: Target (square) output size in pixels
        is_training: Whether to apply training augmentations
        normalize: Whether to normalize images with the ImageNet mean/std
        resize_ratio: Evaluation resize-to-crop ratio. The default 256/224
            gives the conventional "resize 256, crop 224" at ``image_size=224``.

    Returns:
        Composition of transforms
    """
    if is_training:
        transform_list = [
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor()
        ]
    else:
        # Round rather than truncate: floating-point products such as
        # 224 * (256 / 224) can land just below the intended integer.
        resize_size = int(round(image_size * resize_ratio))
        transform_list = [
            transforms.Resize(resize_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ]

    if normalize:
        transform_list.append(transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ))

    return transforms.Compose(transform_list)


def create_dataloader(
    dataset: Dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = False
) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` built from the given options.

    Args:
        dataset: Dataset whose samples the loader yields
        batch_size: Number of samples per batch
        shuffle: Whether to reshuffle the sample order each epoch
        num_workers: Number of worker processes used for loading
        pin_memory: Whether to copy batches into pinned (page-locked) memory
        drop_last: Whether to discard a final batch smaller than ``batch_size``

    Returns:
        A ``DataLoader`` over ``dataset`` configured with the options above
    """
    # Collect the loader settings first so they read as one configuration unit.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
    return DataLoader(dataset, **loader_options)


def get_dataset_class(dataset_name: str):
    """Get dataset class by name.

    Matching ignores case, surrounding whitespace, hyphens and underscores,
    so "officehome", "OfficeHome" and "office_home" all resolve to the same
    class.

    Args:
        dataset_name: Name of dataset (officehome, domainnet)

    Returns:
        Dataset class

    Raises:
        ValueError: If the name does not match a known dataset.
    """
    dataset_map = {
        'officehome': OfficeHomeDataset,
        'domainnet': DomainNetDataset
    }

    # Normalize common spelling variants to the canonical map keys.
    key = str(dataset_name).strip().lower().replace('-', '').replace('_', '')

    if key not in dataset_map:
        raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(dataset_map.keys())}")

    return dataset_map[key]
