"""
Dataset module for Multi-Scale Attention U-Net medical image segmentation
Generates synthetic medical images with ground truth segmentation masks
"""

import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

@dataclass
class DatasetConfig:
    """Configuration for the synthetic medical segmentation dataset.

    Generation fields:
        image_size: ``(H, W)`` of every generated image and mask.
        num_classes: Number of mask channels; one structure class per channel.
        structures_per_image: Inclusive ``(min, max)`` number of structures
            drawn per image (``random.randint`` semantics).
        scale_range: Inclusive ``(min, max)`` side length, in pixels, of each
            square structure patch.
        noise_level: Standard deviation of the additive Gaussian pixel noise.

    Split sizes:
        train_size, val_size, test_size: Number of samples in each split.

    Augmentation fields:
        flip_prob: Probability of each (horizontal / vertical) flip.
        rotation_range: Rotation angle is drawn uniformly from
            ``[-rotation_range, rotation_range]`` degrees.
        scale_range_aug: ``(low, high)`` range of the random zoom factor.
        brightness_range: ``(low, high)`` range of the multiplicative
            brightness factor.

    Raises:
        ValueError: From ``__post_init__`` when a value would make sample
            generation or augmentation fail later (usually deep inside a
            DataLoader worker, where the error is much harder to trace).
    """
    image_size: Tuple[int, int] = (512, 512)
    num_classes: int = 5
    train_size: int = 10000
    val_size: int = 2000
    test_size: int = 1000
    structures_per_image: Tuple[int, int] = (1, 3)
    scale_range: Tuple[int, int] = (32, 256)  # structure patch side length, not the zoom range
    noise_level: float = 0.01
    flip_prob: float = 0.5
    rotation_range: int = 15
    scale_range_aug: Tuple[float, float] = (0.9, 1.1)  # augmentation zoom factors
    brightness_range: Tuple[float, float] = (0.8, 1.2)

    def __post_init__(self) -> None:
        """Reject configurations that cannot produce valid samples."""
        height, width = self.image_size
        if height < 1 or width < 1:
            raise ValueError(f"image_size must be positive, got {self.image_size}")
        if self.num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {self.num_classes}")
        for split in ("train_size", "val_size", "test_size"):
            if getattr(self, split) < 0:
                raise ValueError(f"{split} must be >= 0, got {getattr(self, split)}")

        low, high = self.structures_per_image
        if not 0 <= low <= high:
            raise ValueError(
                f"structures_per_image must satisfy 0 <= min <= max, got {self.structures_per_image}"
            )
        low, high = self.scale_range
        if not 1 <= low <= high:
            raise ValueError(f"scale_range must satisfy 1 <= min <= max, got {self.scale_range}")

        if self.noise_level < 0:
            raise ValueError(f"noise_level must be >= 0, got {self.noise_level}")
        if not 0.0 <= self.flip_prob <= 1.0:
            raise ValueError(f"flip_prob must be in [0, 1], got {self.flip_prob}")
        # A zoom factor <= 0 would ask OpenCV for an empty or negative resize target.
        if min(self.scale_range_aug) <= 0:
            raise ValueError(f"scale_range_aug factors must be > 0, got {self.scale_range_aug}")
    brightness_range: Tuple[float, float] = (0.8, 1.2)

class SyntheticMedicalDataset(Dataset):
    """Synthetic medical image dataset with segmentation masks.

    Samples are generated on the fly. Sample ``idx`` is a deterministic
    function of ``(mode, idx)``: ``__getitem__`` seeds a private
    ``random.Random`` and a private ``np.random.RandomState`` with
    ``self.seed + idx`` and all generation draws from those.

    The global ``random`` / ``np.random`` state is never reseeded. As a
    result, a ``transform`` that draws from the global generators (such as
    ``MedicalDataAugmentation``) gets fresh randomness every call instead of
    replaying the same augmentation for a given index each epoch.

    Before any transform, each item is ``(image, mask)``:
        image: float32 tensor ``(3, H, W)``, min-max normalised to [0, 1].
        mask: float32 tensor ``(num_classes, H, W)``, one binary channel per class.

    The per-sample generators are stored on the instance. Concurrent
    ``__getitem__`` calls on one instance from several threads are therefore
    not safe (separate DataLoader worker processes each hold their own copy).
    """

    # Base seed per split; sample ``idx`` is generated from ``seed + idx``.
    _SPLIT_SEEDS = {'train': 42, 'val': 43, 'test': 44}

    def __init__(self, config: DatasetConfig, mode: str = 'train', transform: Optional[Callable] = None):
        """
        Args:
            config: Dataset configuration (image size, class count, split sizes, ...).
            mode: ``'train'``, ``'val'`` or ``'test'``; selects split size and base seed.
            transform: Optional callable ``(image, mask) -> (image, mask)`` applied
                to the generated tensors.

        Raises:
            ValueError: If ``mode`` is not one of the known split names.
        """
        if mode not in self._SPLIT_SEEDS:
            raise ValueError(f"mode must be one of {sorted(self._SPLIT_SEEDS)}, got {mode!r}")

        self.config = config
        self.mode = mode
        self.transform = transform

        split_sizes = {'train': config.train_size, 'val': config.val_size, 'test': config.test_size}
        self.size = split_sizes[mode]
        self.seed = self._SPLIT_SEEDS[mode]

        # Private generators, reseeded per sample in __getitem__. Initialised
        # here so the _generate_* helpers also work when called directly.
        self._rng = random.Random(self.seed)
        self._np_rng = np.random.RandomState(self.seed)

        self.indices = list(range(self.size))

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate sample ``idx`` (negative indices count from the end).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
                also terminates plain ``for sample in dataset`` iteration.
        """
        position = idx + self.size if idx < 0 else idx
        if not 0 <= position < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")

        sample_seed = self.seed + position
        self._rng = random.Random(sample_seed)
        self._np_rng = np.random.RandomState(sample_seed)

        image, mask = self._generate_synthetic_data()

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

    def _generate_synthetic_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compose one sample: textured background, random structures, then noise.

        Returns:
            ``(image, mask)`` as float32 tensors of shape ``(3, H, W)`` and
            ``(num_classes, H, W)``.
        """
        H, W = self.config.image_size
        num_classes = self.config.num_classes

        image = self._generate_background(np.zeros((H, W, 3), dtype=np.float32))
        mask = np.zeros((H, W, num_classes), dtype=np.float32)

        num_structures = self._rng.randint(*self.config.structures_per_image)
        for _ in range(num_structures):
            class_id = self._rng.randint(0, num_classes - 1)
            structure_image, structure_mask = self._generate_anatomical_structure(class_id)

            # Random top-left corner so the patch lies fully inside the canvas
            # (patch size is clamped to the canvas in _generate_anatomical_structure).
            h, w = structure_image.shape[:2]
            max_y, max_x = H - h, W - w
            y = self._rng.randint(0, max_y) if max_y > 0 else 0
            x = self._rng.randint(0, max_x) if max_x > 0 else 0

            # Overlaps: keep the brighter pixel and the union of labels.
            image[y:y + h, x:x + w] = np.maximum(image[y:y + h, x:x + w], structure_image)
            mask[y:y + h, x:x + w, class_id] = np.maximum(mask[y:y + h, x:x + w, class_id], structure_mask)

        noise = self._np_rng.normal(0, self.config.noise_level, image.shape)
        image = np.clip(image + noise, 0, 1)

        # Min-max normalise; epsilon guards against a perfectly flat image.
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)

        # HWC numpy -> CHW tensors.
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float()
        mask_tensor = torch.from_numpy(mask).permute(2, 0, 1).float()
        return image_tensor, mask_tensor

    def _generate_background(self, image: np.ndarray) -> np.ndarray:
        """Fill ``image`` (H, W, 3) in place with a dark blue-gray base plus a
        smooth sinusoidal texture of amplitude 0.1, and return it."""
        H, W, C = image.shape

        base_color = np.array([0.1, 0.1, 0.15])  # Dark blue-gray

        # Sum of three sinusoidal "octaves" over a [0, 4] x [0, 4] grid.
        X, Y = np.meshgrid(np.linspace(0, 4, W), np.linspace(0, 4, H))
        noise1 = np.sin(X) * np.cos(Y)
        noise2 = np.sin(X * 2) * np.cos(Y * 2) * 0.5
        noise3 = np.sin(X * 4) * np.cos(Y * 4) * 0.25
        combined_noise = noise1 + noise2 + noise3

        # Rescale to [0, 1]; a constant field (e.g. a 1x1 image) has zero span.
        span = combined_noise.max() - combined_noise.min()
        if span > 0:
            combined_noise = (combined_noise - combined_noise.min()) / span
        else:
            combined_noise = np.zeros_like(combined_noise)

        for c in range(C):
            image[:, :, c] = base_color[c] + combined_noise * 0.1

        return image

    def _generate_anatomical_structure(self, class_id: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate one square structure patch for ``class_id``.

        Classes 0-3 are heart, liver, kidney and lung. Any other id renders a
        brain shape. The side length is drawn from ``config.scale_range`` and
        clamped to the smaller image side, so the patch always fits the canvas.
        """
        min_size, max_size = self.config.scale_range
        size = self._rng.randint(min_size, max_size)
        H, W = self.config.image_size
        size = min(size, H, W)

        if class_id == 0:  # Heart
            return self._generate_heart(size)
        if class_id == 1:  # Liver
            return self._generate_liver(size)
        if class_id == 2:  # Kidney
            return self._generate_kidney(size)
        if class_id == 3:  # Lung
            return self._generate_lung(size)
        return self._generate_brain(size)

    def _generate_heart(self, size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a filled, upright heart with a red radial gradient.

        Returns:
            ``(image, mask)``: float32 arrays of shape ``(size, size, 3)`` and ``(size, size)``.
        """
        # Classic parametric heart curve.
        t = np.linspace(0, 2 * np.pi, size)
        x = 16 * np.sin(t) ** 3
        y = 13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t)

        # Map onto pixel coordinates. The curve's y axis points up while image
        # rows grow downward, so y is flipped to keep the heart upright.
        # cv2.fillPoly only accepts int32 vertices (int64 raises).
        x = ((x - x.min()) / (x.max() - x.min()) * (size - 1)).astype(np.int32)
        y = ((y.max() - y) / (y.max() - y.min()) * (size - 1)).astype(np.int32)

        mask = np.zeros((size, size), dtype=np.float32)
        cv2.fillPoly(mask, [np.column_stack((x, y))], 1.0)

        # Radial intensity: 1 at the patch centre, falling to 0 at radius
        # size/2 and clamped at 0 beyond it (corners would otherwise go negative).
        rows, cols = np.indices((size, size))
        center = size // 2
        dist = np.sqrt((rows - center) ** 2 + (cols - center) ** 2)
        intensity = np.clip(1.0 - dist / (size / 2), 0.0, 1.0)

        color = np.array([0.8, 0.2, 0.2], dtype=np.float32)  # Red gradient
        image = np.where(mask[..., None] > 0, intensity[..., None] * color, 0.0).astype(np.float32)

        return image, mask

    def _generate_liver(self, size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a brown filled ellipse (semi-axes size//3 by size//2)."""
        mask = np.zeros((size, size), dtype=np.float32)
        center = (size // 2, size // 2)
        axes = (size // 3, size // 2)

        y, x = np.ogrid[:size, :size]
        ellipse = ((x - center[0]) / axes[0]) ** 2 + ((y - center[1]) / axes[1]) ** 2 <= 1
        mask[ellipse] = 1.0

        # 3x3 morphological closing. A filled ellipse is already convex, so
        # this has little or no visible effect on the outline.
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        image = np.zeros((size, size, 3), dtype=np.float32)
        image[mask > 0] = [0.6, 0.4, 0.2]  # Brown color

        return image, mask

    def _generate_kidney(self, size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a dark-brown bean shape from a 100-point parametric outline."""
        mask = np.zeros((size, size), dtype=np.float32)
        center = (size // 2, size // 2)

        t = np.linspace(0, 2 * np.pi, 100)
        x = center[0] + (size // 4) * (np.cos(t) + 0.3 * np.cos(3 * t))
        y = center[1] + (size // 3) * (np.sin(t) - 0.1 * np.sin(3 * t))

        # cv2.fillPoly only accepts int32 vertices (int64 raises).
        outline = np.column_stack((x, y)).astype(np.int32)
        cv2.fillPoly(mask, [outline], 1.0)

        image = np.zeros((size, size, 3), dtype=np.float32)
        image[mask > 0] = [0.5, 0.3, 0.1]  # Dark brown

        return image, mask

    def _generate_lung(self, size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate two overlapping light-gray ellipses (left and right lobes)."""
        mask = np.zeros((size, size), dtype=np.float32)
        y, x = np.ogrid[:size, :size]

        # Left lobe
        left_center = (size // 3, size // 2)
        left_axes = (size // 4, size // 3)
        left_lobe = ((x - left_center[0]) / left_axes[0]) ** 2 + ((y - left_center[1]) / left_axes[1]) ** 2 <= 1
        mask[left_lobe] = 1.0

        # Right lobe
        right_center = (2 * size // 3, size // 2)
        right_axes = (size // 4, size // 3)
        right_lobe = ((x - right_center[0]) / right_axes[0]) ** 2 + ((y - right_center[1]) / right_axes[1]) ** 2 <= 1
        mask[right_lobe] = 1.0

        image = np.zeros((size, size, 3), dtype=np.float32)
        image[mask > 0] = [0.7, 0.7, 0.8]  # Light gray

        return image, mask

    def _generate_brain(self, size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a filled disc with five random radial cut lines ("folds")."""
        mask = np.zeros((size, size), dtype=np.float32)
        center = (size // 2, size // 2)

        y, x = np.ogrid[:size, :size]
        base_circle = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= (size // 2) ** 2
        mask[base_circle] = 1.0

        # Each fold erases a 2-px line between radius size//4 and size//3 at a random angle.
        for _ in range(5):
            angle = self._rng.uniform(0, 2 * np.pi)
            start_x = center[0] + (size // 4) * np.cos(angle)
            start_y = center[1] + (size // 4) * np.sin(angle)
            end_x = center[0] + (size // 3) * np.cos(angle)
            end_y = center[1] + (size // 3) * np.sin(angle)
            cv2.line(mask, (int(start_x), int(start_y)), (int(end_x), int(end_y)), 0, 2)

        image = np.zeros((size, size, 3), dtype=np.float32)
        image[mask > 0] = [0.8, 0.6, 0.4]  # Flesh color

        return image, mask

class MedicalDataAugmentation:
    """Random paired augmentation for ``(image, mask)`` segmentation samples.

    Geometric operations (horizontal flip, vertical flip, rotation, zoom) are
    applied identically to image and mask. Brightness scaling touches the
    image only.

    Images are resampled bilinearly. Masks use nearest-neighbour
    interpolation so label maps stay binary instead of gaining fractional
    values along structure boundaries.

    All randomness is drawn from the global ``random`` module.
    """

    def __init__(self, config: DatasetConfig):
        """Store the config that supplies the probabilities and ranges."""
        self.config = config

    def __call__(self, image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one random augmentation to a sample.

        Args:
            image: Channel-first CPU tensor ``(C, H, W)``.
            mask: Channel-first CPU tensor ``(K, H, W)`` with the same ``H, W``.

        Returns:
            Augmented ``(image, mask)`` as float32 tensors with the input shapes.
        """
        cfg = self.config
        image_np = image.permute(1, 2, 0).numpy()
        mask_np = mask.permute(1, 2, 0).numpy()

        # Flips via numpy slicing work for any channel count. The copy removes
        # the negative strides that torch.from_numpy rejects.
        if random.random() < cfg.flip_prob:  # horizontal
            image_np = np.ascontiguousarray(image_np[:, ::-1])
            mask_np = np.ascontiguousarray(mask_np[:, ::-1])

        if random.random() < cfg.flip_prob:  # vertical
            image_np = np.ascontiguousarray(image_np[::-1])
            mask_np = np.ascontiguousarray(mask_np[::-1])

        # Random rotation about the exact pixel-grid centre.
        if random.random() < 0.5:
            angle = random.uniform(-cfg.rotation_range, cfg.rotation_range)
            h, w = image_np.shape[:2]
            center = ((w - 1) / 2.0, (h - 1) / 2.0)
            matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
            image_np = self._keep_channel_axis(
                cv2.warpAffine(image_np, matrix, (w, h), flags=cv2.INTER_LINEAR))
            mask_np = self._keep_channel_axis(
                cv2.warpAffine(mask_np, matrix, (w, h), flags=cv2.INTER_NEAREST))

        # Random zoom, then centre-crop (zoom in) or zero-pad (zoom out) back to H x W.
        if random.random() < 0.5:
            scale = random.uniform(*cfg.scale_range_aug)
            h, w = image_np.shape[:2]
            new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
            image_np = self._keep_channel_axis(
                cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_LINEAR))
            mask_np = self._keep_channel_axis(
                cv2.resize(mask_np, (new_w, new_h), interpolation=cv2.INTER_NEAREST))
            image_np = self._center_crop_or_pad(image_np, h, w)
            mask_np = self._center_crop_or_pad(mask_np, h, w)

        # Random brightness (image only), kept within [0, 1].
        if random.random() < 0.5:
            brightness = random.uniform(*cfg.brightness_range)
            image_np = np.clip(image_np * brightness, 0, 1)

        image = torch.from_numpy(image_np).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask_np).permute(2, 0, 1).float()
        return image, mask

    @staticmethod
    def _keep_channel_axis(array: np.ndarray) -> np.ndarray:
        """Re-add the trailing channel axis that OpenCV drops for single-channel input."""
        return array[:, :, np.newaxis] if array.ndim == 2 else array

    @staticmethod
    def _center_crop_or_pad(array: np.ndarray, height: int, width: int) -> np.ndarray:
        """Centre-crop and/or zero-pad an ``(H', W', C)`` array to ``(height, width, C)``."""
        top = max((array.shape[0] - height) // 2, 0)
        left = max((array.shape[1] - width) // 2, 0)
        array = array[top:top + height, left:left + width]

        pad_h = height - array.shape[0]
        pad_w = width - array.shape[1]
        if pad_h or pad_w:
            array = np.pad(
                array,
                ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2), (0, 0)),
                mode='constant',
            )
        return array

def create_dataloaders(config: DatasetConfig, batch_size: int = 16, num_workers: int = 4,
                       pin_memory: Optional[bool] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train, validation, and test dataloaders.

    Only the training split gets ``MedicalDataAugmentation`` and shuffling.
    Validation and test are not augmented and are iterated in index order.

    Args:
        config: Dataset configuration shared by all three splits.
        batch_size: Samples per batch for every loader.
        num_workers: Worker processes per loader.
        pin_memory: Whether to use pinned host memory. ``None`` (default)
            enables it only when CUDA is available. Pinning brings no benefit
            without a GPU, and PyTorch warns in that case.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    train_dataset = SyntheticMedicalDataset(config, mode='train', transform=MedicalDataAugmentation(config))
    val_dataset = SyntheticMedicalDataset(config, mode='val')
    test_dataset = SyntheticMedicalDataset(config, mode='test')

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader

if __name__ == "__main__":
    # Smoke test 1: generate one training sample and report its shapes.
    config = DatasetConfig()
    dataset = SyntheticMedicalDataset(config, mode='train')

    print(f"Dataset size: {len(dataset)}")
    first_image, first_mask = dataset[0]  # generate once, reuse for both prints
    print(f"Image shape: {first_image.shape}")
    print(f"Mask shape: {first_mask.shape}")
    print(f"Number of classes: {config.num_classes}")

    # Smoke test 2: build all three loaders and pull the first three training batches.
    train_loader, val_loader, test_loader = create_dataloaders(config, batch_size=4)

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

    # zip against range(3) stops after three batches without a manual break.
    for batch_idx, (images, masks) in zip(range(3), train_loader):
        print(f"Batch {batch_idx}: images {images.shape}, masks {masks.shape}")
