from torch.utils.data import Dataset, DataLoader
import h5py
import torch
import numpy as np
import os
import threading
from functools import lru_cache

def build_dataloader(training_size: int, batch_size: int = 48):
    """
    Build the training and validation dataloaders.

    The split is fixed rather than random: the trajectory ids listed in
    ``held_out`` form the validation set, and every other id in ``range(600)``
    forms the training set.

    Args:
        training_size: Not used by this function. It appears to be a leftover
            from an earlier random train/val split; it is kept so existing
            call sites keep working.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_dataloader, val_dataloader)``. The training loader shuffles
        and drops the last incomplete batch; the validation loader does
        neither.
    """
    n_trajectories = 600
    held_out = np.array([476, 105, 389, 1, 558, 80, 205, 34, 508, 427, 454, 366, 91, 339, 345, 241, 13, 315,
                         387, 273, 166, 594, 484, 585, 504, 243, 562, 189, 475, 510, 58, 474, 560, 252, 21, 313,
                         459, 160, 276, 191, 385, 413, 491, 343, 308, 130, 99, 372, 87, 458, 330, 214, 466, 121,
                         20, 71, 106, 270, 435, 102], dtype=np.int32)
    # held_out contains no duplicates, so the cheaper assume_unique path is safe.
    train_ids = np.setdiff1d(np.arange(n_trajectories, dtype=np.int32), held_out, assume_unique=True)

    # Parallel workers help with HDF5 reads; capped at 6 to limit per-worker overhead.
    # Always >= 1, which persistent_workers=True requires.
    worker_count = min(os.cpu_count() or 4, 6)

    train_set = OptimizedCFDDataset(train_ids)
    val_set = OptimizedCFDDataset(held_out)

    # Settings shared by both loaders. pin_memory only pins host memory; it
    # performs no CUDA operations here.
    shared = dict(
        batch_size=batch_size,
        num_workers=worker_count,
        pin_memory=True,
        persistent_workers=True,  # keep workers alive across epochs
    )

    # drop_last keeps every training batch the same size.
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **shared)
    val_loader = DataLoader(val_set, shuffle=False, **shared)
    return train_loader, val_loader


class OptimizedCFDDataset(Dataset):
    """
    Map-style dataset of consecutive-time-step pairs read from one HDF5 file.

    Sample ``i`` corresponds to one ``(traj_id, time_id)`` pair. It is returned
    as ``(state_t, state_t_plus_1)``: two float32 tensors of shape
    ``(len(fields),) + data_shape``, holding every field in ``fields`` at
    ``time_id`` and at ``time_id + 1``. Each HDF5 dataset named in ``fields``
    is indexed as ``[traj_id, time_id, ...]``.

    The shared HDF5 handle is opened lazily, at most once per process. A handle
    opened in the parent is therefore never reused by forked DataLoader
    workers. The dataset is also picklable, so spawn-based workers work too.
    """
    # Class-level state shared by every instance within one process.
    _file_lock = threading.RLock()  # guards the handle, the counter and the caches
    _shared_file = None             # lazily opened h5py.File (see _get_file)
    _shared_file_pid = None         # pid of the process that opened _shared_file
    _reference_count = 0            # number of registered live instances
    _cache = {}                     # (traj_id, time_id) -> sample tuple, FIFO-evicted
    _MAX_CACHE_SIZE = 256  # Adjust based on your available RAM
    _FIELD_CACHE_SIZE = 32          # per-instance LRU capacity for raw field reads

    def __init__(self, indices):
        """
        Args:
            indices: 1-D array-like of trajectory ids to include.
        """
        self.data_path = ""
        self.fields = ['Vx', 'Vy', 'Vz', 'density', 'pressure']
        self.field_count = len(self.fields)

        # Enumerate every (traj_id, time_id) pair up front.
        self.indices = self._generate_indices(indices)

        # Read the sample shape through a short-lived handle. The shared handle
        # is opened on first access, in whichever process actually reads
        # samples.
        self._get_shapes()

        with type(self)._file_lock:
            type(self)._reference_count += 1
        # Only registered instances decrement the count in __del__.
        self._registered = True

    def _open_file(self):
        """Open ``data_path`` read-only, using the same settings for every handle."""
        return h5py.File(
            self.data_path, 'r',
            libver='latest',               # latest HDF5 file-format features
            swmr=True,                     # single-writer/multiple-reader read mode
            rdcc_nbytes=64 * 1024 * 1024,  # 64 MB raw-data chunk cache
        )

    def _get_file(self):
        """Return this process's shared handle, opening it on first use."""
        cls = type(self)
        pid = os.getpid()
        with cls._file_lock:
            # A handle inherited across fork belongs to another process, so it
            # is replaced instead of being reused.
            if cls._shared_file is None or cls._shared_file_pid != pid:
                cls._shared_file = self._open_file()
                cls._shared_file_pid = pid
            return cls._shared_file

    def _get_shapes(self):
        """Record the shape of one ``[traj, time]`` slice of the first field."""
        with self._open_file() as f:
            sample = f[self.fields[0]][0, 0]
        self.original_shape = sample.shape
        # No downsampling: samples keep their stored shape.
        self.data_shape = self.original_shape

    def __getstate__(self):
        """Pickle support, needed by DataLoader workers that use spawn/forkserver."""
        state = self.__dict__.copy()
        # Cached arrays are rebuilt on demand. Registration is redone by
        # __setstate__ in the receiving process.
        state.pop("_field_cache", None)
        state.pop("_registered", None)
        return state

    def __setstate__(self, state):
        """Restore a pickled instance and register it in this process."""
        self.__dict__.update(state)
        cls = type(self)
        with cls._file_lock:
            cls._reference_count += 1
        self._registered = True

    def __del__(self):
        """Unregister; close the shared handle when the last instance goes away."""
        if not getattr(self, "_registered", False):
            # __init__ failed before registering, or the instance was never
            # initialised. Nothing to undo.
            return
        cls = type(self)
        with cls._file_lock:
            cls._reference_count -= 1
            if cls._reference_count == 0 and cls._shared_file is not None:
                cls._shared_file.close()
                cls._shared_file = None
                cls._shared_file_pid = None

    def _generate_indices(self, indices):
        """
        Return an ``(len(indices) * 20, 2)`` array of ``(traj_id, time_id)`` rows.

        Time ids run from 0 to 19 so that ``time_id + 1`` is also read. This
        assumes each trajectory stores at least 21 time steps (TODO confirm
        against the file). Rows are grouped by trajectory, in the order given.
        """
        time_steps = np.arange(20, dtype=np.int32)
        traj_grid, time_grid = np.meshgrid(indices, time_steps, indexing='ij')
        return np.stack([traj_grid.ravel(), time_grid.ravel()], axis=1)

    def __len__(self):
        """Return the number of ``(traj_id, time_id)`` samples."""
        return len(self.indices)

    def _read_field(self, field, traj_id, time_id):
        """
        Return ``field[traj_id, time_id]`` as a NumPy array, memoised per instance.

        A small LRU cache (``_FIELD_CACHE_SIZE`` entries) can reuse the
        ``time_id + 1`` read of one sample as the ``time_id`` read of another.
        The cache is a plain dict on the instance. ``functools.lru_cache`` on a
        method would store ``self`` in a class-wide cache, keeping the
        instance alive while it has entries, so ``__del__`` could never
        release the shared handle. Callers must not mutate the returned array.
        """
        key = (field, int(traj_id), int(time_id))
        with type(self)._file_lock:
            cache = self.__dict__.setdefault("_field_cache", {})
            data = cache.pop(key, None)
            if data is None:
                # Copy so the cached array never aliases reader-owned memory.
                data = self._get_file()[field][key[1], key[2]].copy()
                if cache and len(cache) >= self._FIELD_CACHE_SIZE:
                    del cache[next(iter(cache))]  # evict least recently used
            cache[key] = data  # (re)insert as most recently used
            return data

    def __getitem__(self, idx):
        """
        Return ``(state_t, state_t_plus_1)`` for sample ``idx``.

        Results are memoised in the class-level ``_cache``, a FIFO holding at
        most ``_MAX_CACHE_SIZE`` samples (0 disables it). A cached result is
        returned as the same tensor objects, so callers must not modify them
        in place.
        """
        traj_id, time_id = (int(v) for v in self.indices[idx])
        cache_key = (traj_id, time_id)
        cls = type(self)

        with cls._file_lock:
            cached = cls._cache.get(cache_key)
        if cached is not None:
            return cached

        # Fresh arrays per sample. Tensors built from them own their data, so
        # no defensive copy is needed.
        shape = (self.field_count,) + tuple(self.data_shape)
        state_in = np.empty(shape, dtype=np.float32)
        state_out = np.empty(shape, dtype=np.float32)
        for i, field in enumerate(self.fields):
            state_in[i] = self._read_field(field, traj_id, time_id)
            state_out[i] = self._read_field(field, traj_id, time_id + 1)

        # CPU tensors only; device transfer is left to the caller.
        result = (torch.from_numpy(state_in), torch.from_numpy(state_out))

        with cls._file_lock:
            if cls._MAX_CACHE_SIZE > 0 and cache_key not in cls._cache:
                if len(cls._cache) >= cls._MAX_CACHE_SIZE:
                    # Dicts keep insertion order, so the first key is the oldest.
                    del cls._cache[next(iter(cls._cache))]
                cls._cache[cache_key] = result
        return result