"""
Grid management utilities for creating and manipulating latent grids.
"""

import torch
import numpy as np
from typing import Tuple, Union, Optional


class GridManager:
    """
    Manages creation and manipulation of regular grids for different spatial dimensions.

    Grid points are generated with 'ij' (matrix) indexing: in the flattened
    ordering the last spatial axis varies fastest, and a flattened field of
    ``num_grid_points`` values reshapes to ``(*grid_resolution)`` with tensor
    axis ``i`` corresponding to spatial coordinate ``i``. All helpers in this
    class (reshaping, masks, nearest indices, interpolation) rely on that layout.
    """

    # Spatial dimensions supported by grid construction, reshaping, boundary
    # masks and interpolation.
    _SUPPORTED_DIMS = (1, 2, 3)

    def __init__(
        self,
        spatial_dim: int,
        grid_resolution: Union[int, Tuple[int, ...]],
        domain_bounds: Tuple[Tuple[float, float], ...],
        device: str = 'cpu'
    ):
        """
        Args:
            spatial_dim: Spatial dimension (1, 2, or 3)
            grid_resolution: Number of grid points per dimension; a single int
                is used for every dimension.
            domain_bounds: Bounds for each dimension [(min, max), ...]
            device: Device to create tensors on

        Raises:
            AssertionError: (when assertions are enabled) if
                ``grid_resolution`` or ``domain_bounds`` does not have
                ``spatial_dim`` entries.
        """
        self.spatial_dim = spatial_dim
        self._device = device
        self.domain_bounds = domain_bounds

        # Broadcast a scalar resolution to every dimension.
        if isinstance(grid_resolution, int):
            self.grid_resolution = (grid_resolution,) * spatial_dim
        else:
            self.grid_resolution = tuple(grid_resolution)

        assert len(self.grid_resolution) == spatial_dim, \
            f"grid_resolution length ({len(self.grid_resolution)}) must match spatial_dim ({spatial_dim})"
        assert len(domain_bounds) == spatial_dim, \
            f"domain_bounds length ({len(domain_bounds)}) must match spatial_dim ({spatial_dim})"

        # Distance between neighbouring points along each axis. Both bounds are
        # grid points, so there are (res - 1) intervals per axis.
        # NOTE: a resolution of 1 along any axis divides by zero here.
        self.grid_spacing = tuple(
            (bounds[1] - bounds[0]) / (res - 1)
            for bounds, res in zip(domain_bounds, self.grid_resolution)
        )

        # Total number of grid points
        self.num_grid_points = int(np.prod(self.grid_resolution))

        # Un-batched grids keyed by device string, built lazily by get_grid().
        self._grid_cache = {}

    def to(self, device):
        """
        Move grid manager to device and clear cache.

        Args:
            device: Target device, either a string or an object whose ``str()``
                names the device (e.g. ``torch.device``).

        Returns:
            ``self``, so calls can be chained.
        """
        device_str = device if isinstance(device, str) else str(device)
        if device_str != self._device:
            self._device = device_str
            # Cached grids were built on the previous device; drop them.
            self._grid_cache = {}
        return self

    @property
    def device(self):
        """Get current device."""
        return self._device

    @staticmethod
    def _meshgrid_ij(*axes: torch.Tensor):
        """Return ``torch.meshgrid(*axes)`` with 'ij' indexing on old and new PyTorch."""
        try:
            return torch.meshgrid(*axes, indexing='ij')
        except TypeError:
            # Older PyTorch has no `indexing` keyword; its default is 'ij'.
            return torch.meshgrid(*axes)

    def _create_grid(self, device: str) -> torch.Tensor:
        """
        Create regular grid points.

        Args:
            device: Device to create grid on

        Returns:
            Grid points, shape (total_points, spatial_dim), flattened in 'ij'
            order (last spatial axis fastest).

        Raises:
            ValueError: if ``spatial_dim`` is not 1, 2 or 3.
        """
        if self.spatial_dim not in self._SUPPORTED_DIMS:
            raise ValueError(f"Unsupported spatial dimension: {self.spatial_dim}")

        # 1D coordinates along each axis, endpoints included.
        axes = [
            torch.linspace(lower, upper, n_points, device=device)
            for (lower, upper), n_points in zip(self.domain_bounds, self.grid_resolution)
        ]

        if self.spatial_dim == 1:
            return axes[0].unsqueeze(-1)

        meshes = self._meshgrid_ij(*axes)
        # Column i holds spatial coordinate i of every point.
        return torch.stack([mesh.flatten() for mesh in meshes], dim=-1)

    def get_grid(
        self,
        batch_size: int = 1,
        device: Optional[str] = None
    ) -> torch.Tensor:
        """
        Get grid points for a batch.

        The result is an ``expand`` view of a cached grid: every batch entry
        shares the same memory. ``clone()`` it before modifying it in place.

        Args:
            batch_size: Number of batches
            device: Device to create grid on (if None, uses self._device)

        Returns:
            Grid points, shape (batch_size, num_grid_points, spatial_dim)
        """
        if device is None:
            device = self._device

        # Normalise to a string so str and torch.device share one cache entry.
        device_str = str(device)

        # Cache one base grid per device. Expanding to a batch is a free view,
        # so caching per batch size would only duplicate memory.
        base_grid = self._grid_cache.get(device_str)
        if base_grid is None:
            base_grid = self._create_grid(device_str)
            self._grid_cache[device_str] = base_grid

        return base_grid.unsqueeze(0).expand(batch_size, -1, -1)

    def reshape_to_grid(self, values: torch.Tensor) -> torch.Tensor:
        """
        Reshape flattened values back to grid shape.

        Args:
            values: Flattened values, shape (batch, channels, num_grid_points)

        Returns:
            Reshaped values, shape (batch, channels, *grid_resolution):
                1D: (batch, channels, grid_resolution[0])
                2D: (batch, channels, grid_resolution[0], grid_resolution[1])
                3D: (batch, channels, grid_resolution[0], grid_resolution[1],
                     grid_resolution[2])

        Raises:
            ValueError: if the point count does not match the grid, or
                ``spatial_dim`` is unsupported.
        """
        batch_size, channels, num_points = values.shape

        if num_points != self.num_grid_points:
            raise ValueError(
                f"Number of points ({num_points}) doesn't match grid "
                f"({self.num_grid_points})"
            )

        if self.spatial_dim not in self._SUPPORTED_DIMS:
            raise ValueError(f"Unsupported spatial_dim: {self.spatial_dim}")

        # Valid because points are flattened in 'ij' (row-major) order.
        return values.reshape(batch_size, channels, *self.grid_resolution)

    def flatten_from_grid(self, values: torch.Tensor) -> torch.Tensor:
        """
        Flatten grid-shaped values.

        Inverse of ``reshape_to_grid``.

        Args:
            values: Grid-shaped values, shape (batch, channels, *grid_dims)

        Returns:
            Flattened values, shape (batch, channels, num_grid_points)
        """
        batch_size, channels = values.shape[:2]
        return values.reshape(batch_size, channels, -1)

    def get_grid_spacing(self) -> Tuple[float, ...]:
        """Get grid spacing for each dimension."""
        return self.grid_spacing

    def nearest_grid_indices(self, x: torch.Tensor) -> torch.Tensor:
        """
        Find nearest grid indices for given points.

        Points outside the domain are clamped to the nearest edge index.
        Ties are broken by ``torch.round`` (round half to even).

        Args:
            x: Points, shape (batch, n_points, spatial_dim)

        Returns:
            Integer (int64) grid indices, shape (batch, n_points, spatial_dim)
        """
        indices = []

        for i in range(self.spatial_dim):
            lower, upper = self.domain_bounds[i]
            n_points = self.grid_resolution[i]

            # Map [lower, upper] linearly onto [0, n_points - 1].
            x_norm = (x[..., i] - lower) / (upper - lower) * (n_points - 1)
            indices.append(torch.clamp(torch.round(x_norm).long(), 0, n_points - 1))

        return torch.stack(indices, dim=-1)

    def get_boundary_mask(self, device: Optional[str] = None) -> torch.Tensor:
        """
        Get mask indicating which grid points are on the boundary.

        A point is on the boundary if its index is the first or last along any
        axis.

        Args:
            device: Device to create mask on (if None, uses self._device)

        Returns:
            Boolean mask, shape (num_grid_points,), in the same flattened order
            as ``get_grid``.

        Raises:
            ValueError: if ``spatial_dim`` is not 1, 2 or 3.
        """
        if device is None:
            device = self._device

        if self.spatial_dim not in self._SUPPORTED_DIMS:
            raise ValueError(f"Unsupported spatial_dim: {self.spatial_dim}")

        mask = torch.zeros(self.grid_resolution, dtype=torch.bool, device=device)
        # Mark the first and last hyper-plane along every axis.
        for axis in range(self.spatial_dim):
            index = [slice(None)] * self.spatial_dim
            for edge in (0, -1):
                index[axis] = edge
                mask[tuple(index)] = True
        return mask.flatten()

    def get_interior_mask(self, device: Optional[str] = None) -> torch.Tensor:
        """Get mask for interior (non-boundary) points, shape (num_grid_points,)."""
        return ~self.get_boundary_mask(device=device)

    def interpolate_to_points(
        self,
        grid_features: torch.Tensor,
        x: torch.Tensor,
        mode: str = 'bilinear'
    ) -> torch.Tensor:
        """
        Interpolate grid features to arbitrary points.

        Uses ``torch.nn.functional.grid_sample`` with these settings:
        - ``align_corners=True``: grid nodes sit exactly on the domain bounds.
        - ``padding_mode='border'``: points outside the domain take boundary
          values.

        Args:
            grid_features: Features on grid, shape (batch, channels, *grid_resolution)
            x: Target points, shape (batch, n_points, spatial_dim)
            mode: Interpolation mode ('nearest', 'bilinear', 'bicubic').
                'bicubic' is only honoured in 2D. In 1D and 3D, any mode other
                than 'nearest' uses linear / trilinear interpolation.

        Returns:
            Interpolated features, shape (batch, channels, n_points), with the
            dtype of ``grid_features``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
            NotImplementedError: if ``spatial_dim`` is not 1, 2 or 3.
        """
        if x.dim() != 3:
            raise ValueError(
                f"x must have shape (batch, n_points, spatial_dim), got {tuple(x.shape)}"
            )
        if self.spatial_dim not in self._SUPPORTED_DIMS:
            raise NotImplementedError(
                f"Interpolation not implemented for spatial_dim={self.spatial_dim}"
            )

        # Map each coordinate from [min, max] to grid_sample's [-1, 1] range.
        normalized = []
        for dim in range(self.spatial_dim):
            x_min, x_max = self.domain_bounds[dim]
            normalized.append(2 * (x[..., dim] - x_min) / (x_max - x_min) - 1)

        # grid_sample reads its sampling coordinates as (x, y[, z]):
        # - x indexes the LAST feature axis (W), y indexes H, z indexes D.
        # - The features use 'ij' layout, so spatial dim 0 is the slowest axis.
        # The coordinate order is therefore the reverse of the spatial-dim order.
        linear_mode = 'nearest' if mode == 'nearest' else 'bilinear'

        if self.spatial_dim == 1:
            # grid_sample has no 1D variant, so treat the line as an image of
            # height 1, i.e. features (batch, channels, 1, n_x). The real
            # coordinate indexes the width; the dummy height coordinate is 0.
            features = grid_features.unsqueeze(-2)
            coords = torch.stack(
                [normalized[0], torch.zeros_like(normalized[0])], dim=-1
            ).unsqueeze(1)  # (batch, 1, n_points, 2)
            sample_mode = linear_mode

        elif self.spatial_dim == 2:
            features = grid_features
            coords = torch.stack(
                [normalized[1], normalized[0]], dim=-1
            ).unsqueeze(1)  # (batch, 1, n_points, 2)
            sample_mode = mode

        else:
            # 5-D grid_sample: 'bilinear' is trilinear; 'bicubic' is unsupported.
            features = grid_features
            coords = torch.stack(
                [normalized[2], normalized[1], normalized[0]], dim=-1
            )[:, None, None]  # (batch, 1, 1, n_points, 3)
            sample_mode = linear_mode

        # grid_sample requires the sampling grid to share the input's dtype.
        coords = coords.to(features.dtype)

        interpolated = torch.nn.functional.grid_sample(
            features,
            coords,
            mode=sample_mode,
            padding_mode='border',
            align_corners=True
        )  # (batch, channels, 1, n_points) or (batch, channels, 1, 1, n_points)

        # Drop the singleton output axes -> (batch, channels, n_points).
        return interpolated.flatten(start_dim=2)