import os
import numpy as np
import torch
from typing import Callable, Optional, List, Tuple
from sklearn.model_selection import train_test_split
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset

# Assuming .augment module is available for NDA transforms
from .augment import get_nda_transform_no_resize

class DVSGestureFrameDataset(Dataset):
    """
    DVS Gesture dataset that converts event streams into temporal frames.

    Each sample is an ``.npz`` file holding the event arrays ``t``, ``x``,
    ``y`` and ``p``. Events are accumulated into ``num_bins`` frames with two
    polarity channels, producing a float tensor of shape
    ``[num_bins, 2, height, width]``.

    Expected layout: ``root_dir/{train,test}/<numeric class dir>/*.npz``.
    The ``train`` and ``validation`` splits are carved out of the original
    ``train`` directory with a stratified split; ``train_full`` returns that
    directory unsplit and ``test`` reads the ``test`` directory.
    """

    _VALID_SPLITS = ('train', 'validation', 'train_full', 'test')

    # Upper bound on replacement draws when a sample fails the quality check,
    # so a mostly-rejected dataset cannot recurse/loop without end.
    _MAX_RESAMPLE_ATTEMPTS = 10

    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        validation_split: float = 0.2,
        random_seed: int = 42,
        transform: Optional[Callable] = None,
        config: Optional[dict] = None,
        num_bins: int = 16,
        width: int = 128,
        height: int = 128,
        encoding_mode: str = "rate",
        preserve_sparsity: bool = False,
        apply_augmentation: bool = True,
        **kwargs
    ):
        """
        Args:
            root_dir: Directory containing the ``train`` and ``test`` folders.
            split: One of ``train``, ``validation``, ``train_full``, ``test``.
            validation_split: Fraction of the train folder held out for
                validation (passed to ``train_test_split`` as ``test_size``).
            random_seed: Seed for the per-class shuffle and the split.
            transform: Callable applied to the frame tensor; takes precedence
                over ``config``.
            config: If given (and ``transform`` is None), used to build the
                transform via ``get_nda_transform_no_resize``.
            num_bins: Number of temporal bins (first output dimension).
            width: Frame width; events with ``x`` outside it are dropped.
            height: Frame height; events with ``y`` outside it are dropped.
            encoding_mode: Stored on the instance; not read by this class.
            preserve_sparsity: Normalise each (bin, channel) plane by its own
                maximum instead of by the global maximum.
            apply_augmentation: Stored on the instance; not read by this class.
            **kwargs: Accepted and ignored so shared config dicts can be
                forwarded unchanged.

        Raises:
            ValueError: If ``split`` is not a recognised split name.
        """
        self.root_dir = root_dir
        self.split = split
        self.validation_split = validation_split
        self.random_seed = random_seed
        self.num_bins = num_bins
        self.width = width
        self.height = height
        self.encoding_mode = encoding_mode
        self.preserve_sparsity = preserve_sparsity
        self.apply_augmentation = apply_augmentation

        # Data quality filtering parameters (filter is off by default)
        self.min_spike_ratio = 0.08
        self.max_spike_ratio = 0.22
        self.quality_check_enabled = False

        # NDA Transform setup: explicit transform > config-derived > none
        if transform is not None:
            self.transform = transform
        elif config is not None:
            self.transform = get_nda_transform_no_resize(config, split, "dvsgesture")
        else:
            self.transform = None

        # Collect and split files
        self.file_paths, self.labels = self._collect_and_split_files()

    def _collect_and_split_files(self) -> Tuple[List[str], List[int]]:
        """
        Collect files and perform stratified split for train/validation.

        Raises:
            ValueError: If ``self.split`` is not a recognised split name.
        """
        # Fail fast on a bad split name instead of after scanning the disk.
        if self.split not in self._VALID_SPLITS:
            raise ValueError(f"Unknown split: {self.split}")

        if self.split == 'test':
            return self._collect_original_split('test')

        train_files, train_labels = self._collect_original_split('train')

        if self.split == 'train_full':
            return train_files, train_labels

        train_indices, val_indices = train_test_split(
            range(len(train_files)),
            test_size=self.validation_split,
            stratify=train_labels,
            random_state=self.random_seed,
            shuffle=True
        )

        selected_indices = train_indices if self.split == 'train' else val_indices
        selected_files = [train_files[i] for i in selected_indices]
        selected_labels = [train_labels[i] for i in selected_indices]

        return selected_files, selected_labels

    def _collect_original_split(self, split_name: str) -> Tuple[List[str], List[int]]:
        """
        Collect files for original train/test splits.

        Only numerically named sub-directories are treated as classes, ordered
        by their integer value; labels are their position in that order.
        Stray files or non-numeric folders are ignored.
        """
        split_dir = os.path.join(self.root_dir, split_name)
        class_dirs = sorted(
            (
                name for name in os.listdir(split_dir)
                if name.isdecimal() and os.path.isdir(os.path.join(split_dir, name))
            ),
            key=int,
        )

        file_paths = []
        labels = []
        for cls_idx, cls_name in enumerate(class_dirs):
            cls_dir = os.path.join(split_dir, cls_name)
            # os.listdir order is unspecified; sort so the seeded shuffle
            # yields the same order on every machine.
            cls_files = sorted(
                os.path.join(cls_dir, fname)
                for fname in os.listdir(cls_dir)
                if fname.endswith(".npz")
            )
            # A private generator keeps the permutation tied to the seed
            # without resetting NumPy's global random state.
            np.random.RandomState(self.random_seed).shuffle(cls_files)
            file_paths.extend(cls_files)
            labels.extend([cls_idx] * len(cls_files))
        return file_paths, labels

    def __len__(self) -> int:
        return len(self.file_paths)

    def _load_sample(self, idx: int) -> Tuple[np.ndarray, int]:
        """Read one event file, closing it afterwards, and return (frames, label)."""
        with np.load(self.file_paths[idx]) as evt:
            t, x, y, p = evt["t"], evt["x"], evt["y"], evt["p"]
        return self.events_to_frames_snn_optimized(t, x, y, p), self.labels[idx]

    def _passes_quality_check(self, frames: np.ndarray) -> bool:
        """Return True when the fraction of non-zero entries is within bounds."""
        spike_ratio = (frames > 0).mean()
        return self.min_spike_ratio <= spike_ratio <= self.max_spike_ratio

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Load event data, convert to frames, apply quality check, and augment if enabled.

        When the quality check is enabled and the sample fails it, a random
        replacement sample (with its own label) is drawn, at most
        ``_MAX_RESAMPLE_ATTEMPTS`` times; the last draw is used regardless.
        """
        frames, label = self._load_sample(idx)

        if self.quality_check_enabled:
            for _ in range(self._MAX_RESAMPLE_ATTEMPTS):
                if self._passes_quality_check(frames):
                    break
                new_idx = int(np.random.randint(0, len(self.file_paths)))
                frames, label = self._load_sample(new_idx)

        frames = torch.from_numpy(frames).float()

        if self.transform:
            frames = self.transform(frames)

        return frames, label

    def events_to_frames_snn_optimized(self, t, x, y, p) -> np.ndarray:
        """
        Convert event streams to temporal frames with SNN-optimized encoding.

        Timestamps are min-max normalised and passed through a square root
        before binning. Each event adds ``1 + 0.3 * bin / num_bins`` to its
        (bin, polarity, y, x) cell, where polarity channel 1 is ``p > 0``.
        The result is squashed with ``tanh`` after dividing by the maximum:
        per (bin, channel) plane when ``preserve_sparsity`` is set, otherwise
        by the global maximum.

        Returns:
            float32 array of shape ``(num_bins, 2, height, width)``.
        """
        t = np.asarray(t, dtype=np.float64)
        x = np.asarray(x, dtype=np.int32)
        y = np.asarray(y, dtype=np.int32)
        p = np.asarray(p, dtype=np.int32)

        frames = np.zeros((self.num_bins, 2, self.height, self.width), dtype=np.float32)

        if len(t) == 0:
            return frames

        t_min, t_max = t.min(), t.max()

        if t_max == t_min:
            # All events share one timestamp: avoid dividing by zero.
            bins = np.zeros(len(t), dtype=np.int32)
        else:
            t_norm = (t - t_min) / (t_max - t_min)
            t_adaptive = np.sqrt(t_norm)
            bins = np.clip((t_adaptive * self.num_bins).astype(np.int32), 0, self.num_bins - 1)

        # Drop events that fall outside the frame.
        valid = (x >= 0) & (x < self.width) & (y >= 0) & (y < self.height)
        bins_valid = bins[valid]
        channels = (p[valid] > 0).astype(np.intp)
        weights = (1.0 + 0.3 * (bins_valid / self.num_bins)).astype(np.float32)

        # Unbuffered scatter-add so repeated (bin, channel, y, x) cells
        # accumulate, replacing a per-event Python loop.
        np.add.at(frames, (bins_valid, channels, y[valid], x[valid]), weights)

        if self.preserve_sparsity:
            for t_bin in range(self.num_bins):
                for ch in range(2):
                    channel_data = frames[t_bin, ch]
                    max_val = channel_data.max()
                    if max_val > 0:
                        frames[t_bin, ch] = np.tanh(channel_data / max_val)
        else:
            global_max = frames.max()
            if global_max > 0:
                frames = np.tanh(frames / global_max)

        return frames

class CIFAR10DVSDataset(Dataset):
    """
    CIFAR10-DVS dataset that converts event streams into temporal frames.

    Each sample is an ``.npz`` file with event arrays stored either as
    ``t/x/y/p`` or as ``timestamp/x/y/polarity``. Events are accumulated into
    ``num_bins`` frames with two polarity channels, producing a float tensor
    of shape ``[num_bins, 2, height, width]``.

    All files under ``root_dir`` form one pool that is split with a stratified
    split. ``validation`` and ``test`` both return the held-out part of that
    split; ``train_full`` returns the whole pool.
    """
    CLASS_NAMES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]

    _VALID_SPLITS = ('train', 'validation', 'test', 'train_full')

    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        validation_split: float = 0.2,
        random_seed: int = 42,
        transform: Optional[Callable] = None,
        config: Optional[dict] = None,
        num_bins: int = 16,
        width: int = 128,
        height: int = 128,
        encoding_mode: str = "rate",
        preserve_sparsity: bool = False,
        apply_augmentation: bool = True,
        **kwargs
    ):
        """
        Args:
            root_dir: Directory holding the class folders (or the files).
            split: One of ``train``, ``validation``, ``test``, ``train_full``.
            validation_split: Held-out fraction (passed to
                ``train_test_split`` as ``test_size``).
            random_seed: Seed for the per-class shuffle and the split.
            transform: Callable applied to the frame tensor; takes precedence
                over ``config``.
            config: If given (and ``transform`` is None), used to build the
                transform via ``get_nda_transform_no_resize``.
            num_bins: Number of temporal bins (first output dimension).
            width: Frame width; events with ``x`` outside it are dropped.
            height: Frame height; events with ``y`` outside it are dropped.
            encoding_mode: Stored on the instance; not read by this class.
            preserve_sparsity: Normalise each (bin, channel) plane by its own
                maximum instead of by the global maximum.
            apply_augmentation: Stored on the instance; not read by this class.
            **kwargs: Accepted and ignored so shared config dicts can be
                forwarded unchanged.

        Raises:
            ValueError: If ``split`` is not a recognised split name.
        """
        self.root_dir = root_dir
        self.split = split
        self.validation_split = validation_split
        self.random_seed = random_seed
        self.num_bins = num_bins
        self.width = width
        self.height = height
        self.encoding_mode = encoding_mode
        self.preserve_sparsity = preserve_sparsity
        self.apply_augmentation = apply_augmentation

        # Data quality filtering parameters (stored; not read by this class)
        self.min_spike_ratio = 0.05
        self.max_spike_ratio = 0.25
        self.quality_check_enabled = False

        # NDA Transform setup: explicit transform > config-derived > none
        if transform is not None:
            self.transform = transform
        elif config is not None:
            self.transform = get_nda_transform_no_resize(config, split, "cifar10_dvs")
        else:
            self.transform = None

        # Collect and split files
        self.file_paths, self.labels = self._collect_and_split_files()

    def _collect_and_split_files(self) -> Tuple[List[str], List[int]]:
        """
        Collect all files and perform stratified split for train/validation.

        Raises:
            ValueError: If ``self.split`` is not a recognised split name.
        """
        # Fail fast on a bad split name instead of after scanning the disk.
        if self.split not in self._VALID_SPLITS:
            raise ValueError(f"Unknown split: {self.split}")

        all_files, all_labels = self._collect_all_files()

        if self.split == 'train_full':
            return all_files, all_labels

        train_indices, val_indices = train_test_split(
            range(len(all_files)),
            test_size=self.validation_split,
            stratify=all_labels,
            random_state=self.random_seed,
            shuffle=True
        )

        # 'validation' and 'test' intentionally share the held-out indices.
        selected_indices = train_indices if self.split == 'train' else val_indices
        selected_files = [all_files[i] for i in selected_indices]
        selected_labels = [all_labels[i] for i in selected_indices]

        return selected_files, selected_labels

    def _collect_all_files(self) -> Tuple[List[str], List[int]]:
        """
        Collect all CIFAR10-DVS files across classes.

        Labels are the index of the class in ``CLASS_NAMES``.
        """
        file_paths = []
        labels = []
        for cls_idx, cls_name in enumerate(self.CLASS_NAMES):
            cls_files = self._collect_class_files(cls_name, cls_idx)
            file_paths.extend(cls_files)
            labels.extend([cls_idx] * len(cls_files))
        return file_paths, labels

    def _collect_class_files(self, class_name: str, class_idx: int) -> List[str]:
        """
        Collect files for a specific class, handling different directory structures.

        Tried in order: ``root_dir/<class_name>/``, ``root_dir/<class_idx>/``,
        then ``.npz`` files directly in ``root_dir`` whose lower-cased name
        contains ``class_name``. The result is shuffled deterministically
        from ``random_seed + class_idx``.
        """
        named_dir = os.path.join(self.root_dir, class_name)
        numeric_dir = os.path.join(self.root_dir, str(class_idx))

        if os.path.isdir(named_dir):
            # Case 1: Class-named directories
            cls_files = [os.path.join(named_dir, fname) for fname in os.listdir(named_dir)
                         if fname.endswith(".npz")]
        elif os.path.isdir(numeric_dir):
            # Case 2: Numeric directories (0,1,2,...)
            cls_files = [os.path.join(numeric_dir, fname) for fname in os.listdir(numeric_dir)
                         if fname.endswith(".npz")]
        else:
            # Case 3: Root directory with class-named files
            cls_files = [os.path.join(self.root_dir, fname) for fname in os.listdir(self.root_dir)
                         if fname.endswith(".npz") and class_name in fname.lower()]

        # os.listdir order is unspecified; sort so the seeded shuffle yields
        # the same order on every machine.
        cls_files.sort()
        # A private generator keeps the permutation tied to the seed without
        # resetting NumPy's global random state.
        np.random.RandomState(self.random_seed + class_idx).shuffle(cls_files)
        return cls_files

    def __len__(self) -> int:
        return len(self.file_paths)

    def _read_events(self, path: str):
        """
        Read the (t, x, y, p) arrays from one event file and close it.

        Raises:
            KeyError: If the archive uses neither supported key layout.
        """
        with np.load(path) as evt:
            keys = set(evt.files)
            if {"t", "x", "y", "p"} <= keys:
                return evt["t"], evt["x"], evt["y"], evt["p"]
            if {"timestamp", "x", "y", "polarity"} <= keys:
                return evt["timestamp"], evt["x"], evt["y"], evt["polarity"]
        raise KeyError("Unknown data structure")

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Load event data, convert to frames, and apply the transform if one is set.

        Raises:
            KeyError: If the event file uses an unsupported key layout.
        """
        path = self.file_paths[idx]
        label = self.labels[idx]

        t, x, y, p = self._read_events(path)

        frames = self.events_to_frames_snn_optimized(t, x, y, p)
        frames = torch.from_numpy(frames).float()

        if self.transform:
            frames = self.transform(frames)

        return frames, label

    def events_to_frames_snn_optimized(self, t, x, y, p) -> np.ndarray:
        """
        Convert event streams to temporal frames with SNN-optimized encoding.

        Timestamps are min-max normalised and passed through a square root
        before binning. Each event adds ``1 + 0.3 * bin / num_bins`` to its
        (bin, polarity, y, x) cell, where polarity channel 1 is ``p > 0``.
        The result is squashed with ``tanh`` after dividing by the maximum:
        per (bin, channel) plane when ``preserve_sparsity`` is set, otherwise
        by the global maximum.

        Returns:
            float32 array of shape ``(num_bins, 2, height, width)``.
        """
        t = np.asarray(t, dtype=np.float64)
        x = np.asarray(x, dtype=np.int32)
        y = np.asarray(y, dtype=np.int32)
        p = np.asarray(p, dtype=np.int32)

        frames = np.zeros((self.num_bins, 2, self.height, self.width), dtype=np.float32)

        if len(t) == 0:
            return frames

        t_min, t_max = t.min(), t.max()

        if t_max == t_min:
            # All events share one timestamp: avoid dividing by zero.
            bins = np.zeros(len(t), dtype=np.int32)
        else:
            t_norm = (t - t_min) / (t_max - t_min)
            t_adaptive = np.sqrt(t_norm)
            bins = np.clip((t_adaptive * self.num_bins).astype(np.int32), 0, self.num_bins - 1)

        # Drop events that fall outside the frame.
        valid = (x >= 0) & (x < self.width) & (y >= 0) & (y < self.height)
        bins_valid = bins[valid]
        channels = (p[valid] > 0).astype(np.intp)
        weights = (1.0 + 0.3 * (bins_valid / self.num_bins)).astype(np.float32)

        # Unbuffered scatter-add so repeated (bin, channel, y, x) cells
        # accumulate, replacing a per-event Python loop.
        np.add.at(frames, (bins_valid, channels, y[valid], x[valid]), weights)

        if self.preserve_sparsity:
            for t_bin in range(self.num_bins):
                for ch in range(2):
                    channel_data = frames[t_bin, ch]
                    max_val = channel_data.max()
                    if max_val > 0:
                        frames[t_bin, ch] = np.tanh(channel_data / max_val)
        else:
            global_max = frames.max()
            if global_max > 0:
                frames = np.tanh(frames / global_max)

        return frames

class ImageNetDataset(Dataset):
    """
    ImageNet dataset that repeats each RGB image along a temporal axis.

    Images are resized, converted to tensors, normalised with the standard
    ImageNet mean/std, and repeated ``num_bins`` times to give a tensor of
    shape ``[num_bins, 3, height, width]``.

    Class folders are the sub-directories of ``root_dir`` whose names start
    with ``n``. Supports pre-split datasets by disabling internal splitting
    when ``validation_split=0.0``.
    """

    # Accepted image suffixes, compared case-insensitively.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(
        self,
        root_dir: str,
        split: str = "train",
        validation_split: float = 0.2,
        random_seed: int = 42,
        transform: Optional[Callable] = None,
        config: Optional[dict] = None,
        num_bins: int = 16,
        width: int = 128,
        height: int = 128,
        max_samples_per_class: Optional[int] = None,
        apply_augmentation: bool = True,
        **kwargs
    ):
        """
        Args:
            root_dir: Directory containing the class folders.
            split: One of ``train``, ``validation``, ``test``, ``train_full``
                (ignored for file selection when ``validation_split == 0.0``).
            validation_split: Held-out fraction; ``0.0`` disables splitting.
            random_seed: Seed for the per-class shuffle and the split.
            transform: Callable applied to the repeated tensor; takes
                precedence over ``config``.
            config: If given (and ``transform`` is None), used to build the
                transform via ``get_nda_transform_no_resize``.
            num_bins: Number of temporal repeats (stored as ``self.T``).
            width: Output image width.
            height: Output image height.
            max_samples_per_class: Optional cap on files kept per class,
                applied after the per-class shuffle.
            apply_augmentation: Stored as True only for the ``train`` split;
                not otherwise read by this class.
            **kwargs: Accepted and ignored so shared config dicts can be
                forwarded unchanged.

        Raises:
            ValueError: If no class directories are found, or if ``split`` is
                unknown while splitting is enabled.
        """
        self.root_dir = root_dir
        self.split = split
        self.validation_split = validation_split
        self.random_seed = random_seed
        self.T = num_bins
        self.width = width
        self.height = height
        self.max_samples_per_class = max_samples_per_class
        self.apply_augmentation = apply_augmentation and (split == 'train')

        # ImageNet preprocessing pipeline.
        # torchvision's Resize takes (h, w), so height must come first.
        self.image_transform = transforms.Compose([
            transforms.Resize((height, width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # NDA Transform setup: explicit transform > config-derived > none
        if transform is not None:
            self.transform = transform
        elif config is not None:
            self.transform = get_nda_transform_no_resize(config, split, "imagenet")
        else:
            self.transform = None

        # Load class mapping
        self.class_to_idx, self.idx_to_class = self._load_class_mapping()

        # Collect and split files (skip internal split if validation_split=0.0)
        self.file_paths, self.labels = self._collect_and_split_files()

    def _load_class_mapping(self) -> Tuple[dict, dict]:
        """
        Scan ImageNet class directories and create mapping.

        Returns:
            ``(class_to_idx, idx_to_class)`` where indices follow the sorted
            order of the directory names.

        Raises:
            ValueError: If ``root_dir`` has no sub-directory starting with ``n``.
        """
        class_dirs = sorted(
            d for d in os.listdir(self.root_dir)
            if d.startswith('n') and os.path.isdir(os.path.join(self.root_dir, d))
        )
        if not class_dirs:
            raise ValueError(f"No ImageNet class directories found in {self.root_dir}")
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_dirs)}
        idx_to_class = {idx: cls_name for cls_name, idx in class_to_idx.items()}
        return class_to_idx, idx_to_class

    def _collect_and_split_files(self) -> Tuple[List[str], List[int]]:
        """
        Collect files and perform stratified split (skipped if validation_split=0.0).

        ``validation`` and ``test`` both return the held-out part of the split.

        Raises:
            ValueError: If splitting is enabled and ``self.split`` is unknown.
        """
        all_files, all_labels = self._collect_all_files()

        # Pre-split datasets and 'train_full' use everything that was found.
        if self.validation_split == 0.0 or self.split == 'train_full':
            return all_files, all_labels

        train_indices, val_indices = train_test_split(
            range(len(all_files)),
            test_size=self.validation_split,
            stratify=all_labels,
            random_state=self.random_seed,
            shuffle=True
        )

        if self.split == 'train':
            selected_indices = train_indices
        elif self.split == 'validation' or self.split == 'test':
            selected_indices = val_indices
        else:
            raise ValueError(f"Unknown split: {self.split}")

        selected_files = [all_files[i] for i in selected_indices]
        selected_labels = [all_labels[i] for i in selected_indices]

        return selected_files, selected_labels

    def _collect_all_files(self) -> Tuple[List[str], List[int]]:
        """
        Collect all ImageNet files across classes.

        Each class is shuffled deterministically from
        ``random_seed + class_idx`` and then truncated to
        ``max_samples_per_class`` when that is set.
        """
        file_paths = []
        labels = []
        for class_name, class_idx in self.class_to_idx.items():
            class_dir = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            # Sort first: os.listdir order is unspecified, and the seeded
            # shuffle is only reproducible from a fixed starting order.
            cls_files = sorted(
                os.path.join(class_dir, fname)
                for fname in os.listdir(class_dir)
                if fname.lower().endswith(self._IMAGE_EXTENSIONS)
            )
            if not cls_files:
                continue
            # A private generator keeps the permutation tied to the seed
            # without resetting NumPy's global random state.
            np.random.RandomState(self.random_seed + class_idx).shuffle(cls_files)
            if self.max_samples_per_class is not None:
                cls_files = cls_files[:self.max_samples_per_class]
            file_paths.extend(cls_files)
            labels.extend([class_idx] * len(cls_files))
        return file_paths, labels

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Load RGB image, apply ImageNet normalization, repeat temporally, and
        apply the transform if one is set.
        """
        path = self.file_paths[idx]
        label = self.labels[idx]

        image = self.load_rgb_image(path)  # [3, H, W]

        snn_input = image.unsqueeze(0).repeat(self.T, 1, 1, 1)  # [T, 3, H, W]

        if self.transform:
            snn_input = self.transform(snn_input)

        return snn_input, label

    def load_rgb_image(self, path: str) -> torch.Tensor:
        """
        Load RGB image and apply ImageNet normalization.
        """
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = self.image_transform(image)  # [3, H, W], normalized
        return image_tensor

# Dataset factory function
def create_dataset(dataset_name: str, config: Optional[dict] = None, **kwargs):
    """
    Build the dataset registered under ``dataset_name`` (case-insensitive).

    ``config`` and every extra keyword argument are forwarded unchanged to the
    selected dataset class.

    Raises:
        ValueError: If the name matches none of the supported aliases.
    """
    dataset_name = dataset_name.lower()

    # Pick the class first, then construct it in a single place.
    if dataset_name in ('dvs_gesture', 'dvsgesture', 'gesture'):
        dataset_cls = DVSGestureFrameDataset
    elif dataset_name in ('cifar10_dvs', 'cifar10dvs', 'cifar10-dvs'):
        dataset_cls = CIFAR10DVSDataset
    elif dataset_name in ('imagenet', 'ilsvrc', 'ilsvrc2012'):
        dataset_cls = ImageNetDataset
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}. Supported: ['dvs_gesture', 'cifar10_dvs', 'imagenet']")

    return dataset_cls(config=config, **kwargs)

# DataLoader creation helper
def create_snn_dataloader(dataset, batch_size=16, shuffle=True, num_workers=0, **kwargs):
    """
    Create DataLoader optimized for SNN datasets.

    Defaults applied unless overridden through ``kwargs``:
    ``pin_memory=False``, ``drop_last`` equal to ``shuffle``, and
    ``persistent_workers`` enabled only when ``num_workers > 0``.

    Args:
        dataset: Dataset to wrap.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Number of loader worker processes.
        **kwargs: Extra ``DataLoader`` arguments; these take precedence over
            the defaults listed above.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    from torch.utils.data import DataLoader

    # Merge defaults with caller overrides; passing both as literal keywords
    # would raise "got multiple values for keyword argument".
    loader_kwargs = {
        'pin_memory': False,
        'drop_last': bool(shuffle),
        'persistent_workers': num_workers > 0,
    }
    loader_kwargs.update(kwargs)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        **loader_kwargs
    )

# Config-based dataset creation helper
def create_dataset_with_config(dataset_name: str, config: dict, split: str = "train", **kwargs):
    """
    Create a dataset from a config dict via ``create_dataset``.

    Keyword arguments are layered in increasing precedence: the built-in
    defaults (``split``, ``config``, ``apply_augmentation=True``), then the
    entries of ``config['dataset']``, then explicit ``kwargs``.
    """
    section_overrides = config.get('dataset', {})
    base_arguments = {
        'split': split,
        'config': config,
        'apply_augmentation': True,
    }
    # Later mappings win on key collisions.
    merged_arguments = {**base_arguments, **section_overrides, **kwargs}
    return create_dataset(dataset_name, **merged_arguments)

# Dataset information utility
def get_dataset_info(dataset_name: str) -> dict:
    """
    Look up static metadata for a dataset by (case-insensitive) name.

    Returns:
        A dict with at least ``name``, ``num_classes``, ``input_size``,
        ``channels`` and ``description``.

    Raises:
        ValueError: If the name matches none of the known aliases.
    """
    dataset_name = dataset_name.lower()

    if dataset_name in ('dvs_gesture', 'dvsgesture', 'gesture'):
        info = {}
        info['name'] = 'DVS Gesture'
        info['num_classes'] = 11
        info['input_size'] = 128
        info['channels'] = 2
        info['description'] = 'DVS Gesture Recognition Dataset'
        return info

    if dataset_name in ('cifar10_dvs', 'cifar10dvs', 'cifar10-dvs'):
        info = {}
        info['name'] = 'CIFAR10-DVS'
        info['num_classes'] = 10
        info['input_size'] = 128
        info['channels'] = 2
        info['class_names'] = CIFAR10DVSDataset.CLASS_NAMES
        info['description'] = 'CIFAR10-DVS Object Classification Dataset'
        return info

    if dataset_name in ('imagenet', 'ilsvrc', 'ilsvrc2012'):
        info = {}
        info['name'] = 'ImageNet'
        info['num_classes'] = 1000
        info['input_size'] = 128
        info['channels'] = 3
        info['normalization'] = 'ImageNet Standard (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])'
        info['value_range'] = '[-2.1, +2.6]'
        info['description'] = 'ImageNet Large Scale Visual Recognition Dataset with RGB direct usage'
        return info

    # No alias matched.
    raise ValueError(f"Unknown dataset: {dataset_name}")

# Script entry guard: running this file directly is a deliberate no-op,
# because the module only provides definitions meant to be imported.
if __name__ == "__main__":
    pass  # No execution; module for import
