"""
Dataset classes for precomputed features with augmentation support.

This module implements dataset classes for working with precomputed feature
vectors, supporting data augmentation through multiple feature versions.
"""

import os
import torch
import random
from torch.utils.data import Dataset, DataLoader


class PrecomputedFeatureDataset(Dataset):
    """Dataset for precomputed features with augmentation support.

    Each sample is a dict with keys ``"features"``, ``"labels"``, ``"index"``
    and ``"augmented"``. In training mode, if augmented copies were loaded and
    ``use_augmentation`` is set, ``__getitem__`` returns the sample from a
    randomly chosen augmented copy with probability 0.8, and the original
    sample otherwise. In evaluation mode (``train(False)``) the original
    sample is always returned.

    Args:
        features_path: Path to a file written with ``torch.save`` holding the
            base features (must support ``len`` and indexing; ``.shape`` is
            used for verbose logging and augmentation shape checks).
        labels_path: Path to the labels file; must have as many entries as
            the features.
        verbose: Whether to print detailed logs.
        augmentation_paths: Iterable of ``(features_path, labels_path)`` pairs
            for augmented copies of the base data. A pair is skipped if either
            file is missing, fails to load, or its shapes differ from the base
            features/labels.
        use_augmentation: Whether to load and sample augmented versions.

    Raises:
        FileNotFoundError: If ``features_path`` or ``labels_path`` is missing.
        RuntimeError: If either base file cannot be loaded (the original
            exception is attached as ``__cause__``).
        ValueError: If features and labels have different lengths.
    """

    def __init__(self, features_path, labels_path, verbose=False,
                 augmentation_paths=None, use_augmentation=True):
        super().__init__()

        # Training mode is the default; switch with train(False) for evaluation.
        self.training = True
        self.use_augmentation = use_augmentation
        # Keep our own list: a caller-supplied iterator is not consumed before
        # the loading loop, and later mutation of the caller's list has no effect.
        self.augmentation_paths = list(augmentation_paths) if augmentation_paths is not None else []

        # Verify input files exist
        if not os.path.exists(features_path):
            raise FileNotFoundError(f"Features file not found: {features_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        # map_location="cpu": tensors saved from a GPU would otherwise be restored
        # onto that device; DataLoader cannot pin_memory() CUDA tensors, and
        # CUDA tensors do not mix well with forked worker processes.
        try:
            self.features = torch.load(features_path, map_location="cpu")
            if verbose:
                print(f"Loaded features: {self.features.shape}")
        except Exception as e:
            raise RuntimeError(f"Failed to load features from {features_path}: {e}") from e

        try:
            self.labels = torch.load(labels_path, map_location="cpu")
            if verbose:
                print(f"Loaded labels: {self.labels.shape}")
        except Exception as e:
            raise RuntimeError(f"Failed to load labels from {labels_path}: {e}") from e

        # Validate dimensions match
        if len(self.features) != len(self.labels):
            raise ValueError(f"Features ({len(self.features)}) and labels ({len(self.labels)}) count mismatch")

        # Parallel lists: augmented_features[k] pairs with augmented_labels[k].
        self.augmented_features = []
        self.augmented_labels = []
        if self.augmentation_paths and use_augmentation:
            self._load_augmentations(verbose)

    def _load_augmentations(self, verbose):
        """Load every augmentation pair whose shapes match the base data.

        Best-effort: a pair that is missing, fails to load, or has mismatched
        shapes is skipped (and reported when ``verbose``) instead of failing
        construction of the whole dataset.
        """
        for aug_feat_path, aug_label_path in self.augmentation_paths:
            name = os.path.basename(aug_feat_path)
            if not (os.path.exists(aug_feat_path) and os.path.exists(aug_label_path)):
                if verbose:
                    print(f"Skipping augmentation (missing file): {name}")
                continue
            try:
                aug_features = torch.load(aug_feat_path, map_location="cpu")
                aug_labels = torch.load(aug_label_path, map_location="cpu")
                # Augmented copies must be index-aligned with the base data so
                # that __getitem__ can use the same idx for every version.
                shapes_match = (aug_features.shape == self.features.shape
                                and aug_labels.shape == self.labels.shape)
            except Exception as e:
                if verbose:
                    print(f"Failed to load augmentation: {e}")
                continue
            if not shapes_match:
                if verbose:
                    print(f"Skipping augmentation (shape mismatch): {name}")
                continue
            self.augmented_features.append(aug_features)
            self.augmented_labels.append(aug_labels)
            if verbose:
                print(f"Loaded augmentation: {name}")

    def __len__(self):
        """Return the number of samples (the length of the base features)."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return one sample, possibly drawn from an augmented version.

        In training mode with augmentations loaded and enabled, an augmented
        version is used with probability 0.8 (``random.random() > 0.2``), and
        the specific version is chosen uniformly at random.

        Args:
            idx: Sample index.

        Returns:
            Dict with ``"features"``, ``"labels"``, ``"index"`` (the given idx)
            and ``"augmented"`` (whether an augmented version was used).
        """
        if self.training and self.augmented_features and self.use_augmentation and random.random() > 0.2:
            aug_idx = random.randint(0, len(self.augmented_features) - 1)
            return {
                "features": self.augmented_features[aug_idx][idx],
                "labels": self.augmented_labels[aug_idx][idx],
                "index": idx,
                "augmented": True
            }
        # Original data: evaluation mode, augmentation disabled/unavailable,
        # or the 20% non-augmented draw.
        return {
            "features": self.features[idx],
            "labels": self.labels[idx],
            "index": idx,
            "augmented": False
        }

    def train(self, mode=True):
        """Set training (True) or evaluation (False) mode.

        Args:
            mode: True for training mode, False for evaluation.

        Returns:
            Self, for method chaining.
        """
        self.training = mode
        return self


class PrecomputedFeatures:
    """Train/evaluation datasets and loaders for a directory of precomputed features.

    Files read from ``feature_dir``:

    * ``train_features.pt`` / ``train_labels.pt`` -- required training data.
    * ``train_features_aug{k}.pt`` / ``train_labels_aug{k}.pt`` for
      k = 1, 2, ... -- optional augmented copies; discovery stops at the first
      k for which either file is missing.
    * ``val_features.pt`` / ``val_labels.pt`` -- optional evaluation data. It
      is used only if BOTH files exist; otherwise the training files are used
      for evaluation.
    * ``classnames.txt`` -- optional, one class name per line.

    Attributes set: ``train_dataset``, ``train_loader``, ``test_dataset``,
    ``test_loader``, ``classnames``.

    Args:
        feature_dir: Path to directory with precomputed features
        batch_size: Batch size for dataloaders
        num_workers: Number of worker processes for dataloaders (0 = load in
            the main process)
        persistent_workers: Whether to keep worker processes alive (only
            applied when num_workers > 0)
        use_augmentation: Whether to use augmentations during training

    Raises:
        FileNotFoundError: If ``feature_dir`` or the training files are missing.
    """

    def __init__(self,
                 feature_dir,
                 batch_size=128,
                 num_workers=8,
                 persistent_workers=False,
                 use_augmentation=True):
        if not os.path.exists(feature_dir):
            raise FileNotFoundError(f"Feature directory not found: {feature_dir}")

        train_features_path = os.path.join(feature_dir, "train_features.pt")
        train_labels_path = os.path.join(feature_dir, "train_labels.pt")
        val_features_path = os.path.join(feature_dir, "val_features.pt")
        val_labels_path = os.path.join(feature_dir, "val_labels.pt")

        if not os.path.exists(train_features_path):
            raise FileNotFoundError(f"Train features not found at {train_features_path}")

        # Training data, with augmented copies sampled at __getitem__ time.
        self.train_dataset = PrecomputedFeatureDataset(
            train_features_path,
            train_labels_path,
            verbose=False,
            augmentation_paths=self._find_augmentation_paths(feature_dir),
            use_augmentation=use_augmentation
        )
        self.train_dataset.train(True)
        self.train_loader = self._make_loader(
            self.train_dataset, batch_size, num_workers, persistent_workers, shuffle=True
        )

        # Evaluate on the validation pair only when BOTH files exist; choosing
        # the features and labels paths independently could pair validation
        # features with training labels. Otherwise fall back to the train set.
        if os.path.exists(val_features_path) and os.path.exists(val_labels_path):
            test_features_path, test_labels_path = val_features_path, val_labels_path
        else:
            test_features_path, test_labels_path = train_features_path, train_labels_path

        # Evaluation data: no augmentation.
        self.test_dataset = PrecomputedFeatureDataset(
            test_features_path,
            test_labels_path,
            verbose=False,
            augmentation_paths=None,
            use_augmentation=False
        )
        self.test_dataset.train(False)
        self.test_loader = self._make_loader(
            self.test_dataset, batch_size, num_workers, persistent_workers, shuffle=False
        )

        self.classnames = self._load_classnames(feature_dir, self.train_dataset.labels)

    @staticmethod
    def _find_augmentation_paths(feature_dir):
        """Return consecutive (features, labels) aug pairs, starting at index 1."""
        augmentation_paths = []
        aug_idx = 1
        while True:
            aug_feat_path = os.path.join(feature_dir, f"train_features_aug{aug_idx}.pt")
            aug_label_path = os.path.join(feature_dir, f"train_labels_aug{aug_idx}.pt")
            if not (os.path.exists(aug_feat_path) and os.path.exists(aug_label_path)):
                return augmentation_paths
            augmentation_paths.append((aug_feat_path, aug_label_path))
            aug_idx += 1

    @staticmethod
    def _make_loader(dataset, batch_size, num_workers, persistent_workers, shuffle):
        """Build a DataLoader with the settings shared by the train and test loaders."""
        use_workers = num_workers > 0
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            persistent_workers=persistent_workers and use_workers,
            pin_memory=True,
            drop_last=False,
            # The timeout guards against hung worker processes. It must be 0
            # when num_workers == 0: PyTorch's single-process iterator
            # rejects a non-zero timeout.
            timeout=120 if use_workers else 0,
        )

    @staticmethod
    def _load_classnames(feature_dir, labels):
        """Read classnames.txt if present, else synthesize ``class_{i}`` names."""
        classnames_path = os.path.join(feature_dir, "classnames.txt")
        if os.path.exists(classnames_path):
            with open(classnames_path, "r", encoding="utf-8") as f:
                names = [line.strip() for line in f]
            # Drop trailing blank lines only, so that no real class name
            # shifts to a different index.
            while names and not names[-1]:
                names.pop()
            return names
        # NOTE(review): this counts distinct label values. If labels are not a
        # contiguous 0..K-1 range (or are one-hot / multi-dimensional), this
        # may not equal the number of classes; confirm against the producer.
        unique_labels = torch.unique(labels)
        return [f"class_{i}" for i in range(len(unique_labels))]