import torch
import numpy as np

class InvisiblePointEstimator:
    """Fill in invisible 2D keypoints with the mean of the visible ones.

    Each sample in a batch is handled independently: its coordinates are
    standardized using the statistics of its visible points, the invisible
    points are replaced by the mean of the visible points in that normalized
    space, and the result is mapped back to the original coordinate frame.
    Points that are not estimated (visible points, and every point of a
    sample with too few visible points) are returned exactly as given.
    """

    # Lower bound applied to the per-axis standard deviation so that
    # normalization never divides by zero when visible points coincide.
    _STD_FLOOR = 1e-6

    def __init__(self, min_visible_for_estimation=3):
        """
        Args:
            min_visible_for_estimation (int): Minimum number of visible points required
                                             in a sample to attempt estimation.
                                             Needs at least 2 for std dev, 3 for more stability.
        """
        self.min_visible_for_estimation = min_visible_for_estimation
        self._norm_means = None  # Per-sample means from the last normalization, (bs, D)
        self._norm_stds = None   # Per-sample stds from the last normalization, (bs, D)

    @staticmethod
    def _to_bool_mask(visibility_batch, device):
        """
        Returns visibility_batch as a boolean tensor on `device`.

        Indexing with an integer tensor gathers rows instead of masking them,
        so 0/1 flags must be converted to bool before they are used as a mask.
        """
        if not isinstance(visibility_batch, torch.Tensor):
            visibility_batch = torch.tensor(visibility_batch, dtype=torch.bool)
        return visibility_batch.to(device=device, dtype=torch.bool)

    def _normalize_coordinates_batch(self, coords_batch, visibility_batch):
        """
        Normalizes coordinates for each sample in the batch based on its visible points.

        The per-sample mean and std are stored on the instance so that
        `_denormalize_coordinates_batch` can invert the transform. Samples with
        fewer than 2 visible points are left unnormalized (mean=0, std=1).

        coords_batch: (bs, K, D) floating-point tensor of coordinates.
        visibility_batch: (bs, K) visibility flags, truthy if visible.
        Returns: normalized_coords_batch, same shape and dtype as coords_batch.
        """
        bs, K, D = coords_batch.shape
        visibility_batch = self._to_bool_mask(visibility_batch, coords_batch.device)

        # Keep the statistics in the input dtype so float64 input is not
        # silently truncated to float32.
        self._norm_means = torch.zeros(
            bs, D, dtype=coords_batch.dtype, device=coords_batch.device
        )
        self._norm_stds = torch.ones(
            bs, D, dtype=coords_batch.dtype, device=coords_batch.device
        )

        # Start from a copy and overwrite only the samples that get normalized;
        # unlike stacking a list, this also works for an empty batch.
        normalized = coords_batch.clone()

        for i in range(bs):
            sample_coords = coords_batch[i]                          # (K, D)
            visible_coords = sample_coords[visibility_batch[i]]      # (num_vis, D)

            if visible_coords.shape[0] < 2:
                # A standard deviation needs at least 2 points; leave the
                # sample as-is, matching the stored mean=0 / std=1.
                continue

            mean = torch.mean(visible_coords, dim=0)                 # (D,)
            std = torch.std(visible_coords, dim=0)                   # (D,)
            std = torch.clamp(std, min=self._STD_FLOOR)              # Avoid division by zero

            self._norm_means[i] = mean
            self._norm_stds[i] = std
            normalized[i] = (sample_coords - mean) / std

        return normalized

    def _denormalize_coordinates_batch(self, normalized_coords_batch):
        """
        Denormalizes coordinates for each sample in the batch using stored means and stds.

        Returns the input unchanged if no normalization has been performed yet.
        """
        if self._norm_means is None or self._norm_stds is None:
            # This case should ideally not be hit if normalize was called
            return normalized_coords_batch

        # Expand means and stds for broadcasting: (bs, 1, D)
        means_expanded = self._norm_means.unsqueeze(1)
        stds_expanded = self._norm_stds.unsqueeze(1)

        return normalized_coords_batch * stds_expanded + means_expanded

    def estimate_invisible(self, coords_batch, visibility_batch):
        """
        Estimates coordinates for invisible points in a batch.

        Args:
            coords_batch (torch.Tensor): Batch of coordinates (bs, K, 2).
                                         Non-tensor input and integer/bool tensors
                                         are converted to float32.
            visibility_batch (torch.Tensor): Batch of visibility flags (bs, K).
                                             Truthy for visible, falsy for invisible;
                                             any dtype is converted to bool.
        Returns:
            torch.Tensor: Batch of coordinates (bs, K, 2) with invisible points estimated.
                          Visible points, and all points of samples with fewer than
                          `min_visible_for_estimation` visible points, are returned
                          exactly as they were passed in.
        Raises:
            ValueError: If the last dimension of coords_batch is not 2.
        """
        if not isinstance(coords_batch, torch.Tensor):
            coords_batch = torch.tensor(coords_batch, dtype=torch.float32)
        elif not (coords_batch.is_floating_point() or coords_batch.is_complex()):
            # Mean/std are undefined for integer tensors; mirror the float32
            # conversion applied to non-tensor input.
            coords_batch = coords_batch.to(torch.float32)
        visibility_batch = self._to_bool_mask(visibility_batch, coords_batch.device)

        bs, K, D = coords_batch.shape
        if D != 2:
            raise ValueError("Coordinate dimension D must be 2.")

        # 1. Normalize coordinates
        normalized_coords_batch = self._normalize_coordinates_batch(coords_batch, visibility_batch)
        output_normalized_coords = normalized_coords_batch.clone()

        # Marks the points that actually receive an estimate.
        estimated_mask = torch.zeros_like(visibility_batch)

        # The mean of zero points is NaN, so at least one visible point is
        # required no matter how low the configured threshold is.
        required_visible = max(self.min_visible_for_estimation, 1)

        for i in range(bs):
            sample_visibility = visibility_batch[i]                  # (K,)
            num_visible = torch.sum(sample_visibility).item()
            if num_visible < required_visible:
                continue

            # Estimate mean from visible normalized points (MLE for Gaussian mean)
            visible_norm_coords = normalized_coords_batch[i][sample_visibility]
            mu_hat_norm = torch.mean(visible_norm_coords, dim=0)     # (2,)

            # Fill invisible points with this estimated normalized mean
            invisible_indices = ~sample_visibility
            output_normalized_coords[i, invisible_indices, :] = mu_hat_norm
            estimated_mask[i] = invisible_indices

        # 2. Denormalize coordinates
        estimated_coords_batch = self._denormalize_coordinates_batch(output_normalized_coords)

        # 3. Keep the caller's values wherever nothing was estimated, so the
        #    normalize/denormalize round trip cannot perturb them.
        return torch.where(estimated_mask.unsqueeze(-1), estimated_coords_batch, coords_batch)


if __name__ == '__main__':
    # --- Demo: fill in invisible keypoints for two small batches ---
    demo_estimator = InvisiblePointEstimator(min_visible_for_estimation=3)

    # Scenario 1: 2 samples x 5 keypoints x 2D, each sample has 3 visible points.
    sample_a = [[10, 20], [12, 22], [0, 0], [15, 25], [0, 0]]
    sample_b = [[100, 200], [0, 0], [105, 205], [108, 208], [0, 0]]
    coords_data = torch.tensor([sample_a, sample_b], dtype=torch.float32)

    # 1 marks a visible keypoint, 0 an invisible one:
    # sample 0 sees keypoints 0, 1, 3; sample 1 sees keypoints 0, 2, 3.
    visibility_data = torch.tensor(
        [
            [1, 1, 0, 1, 0],
            [1, 0, 1, 1, 0],
        ],
        dtype=torch.bool,
    )

    print("Original Coordinates:\n", coords_data)
    print("Original Visibility:\n", visibility_data)

    estimated_coords = demo_estimator.estimate_invisible(coords_data, visibility_data)
    print("\nEstimated Coordinates (invisible points filled):\n", estimated_coords)

    # Scenario 2: sample 0 has a single visible keypoint (below the threshold
    # of 3), sample 1 has four visible keypoints.
    sparse_sample = [[10, 20], [0, 0], [0, 0], [0, 0], [0, 0]]
    dense_sample = [[100, 200], [102, 202], [105, 205], [108, 208], [0, 0]]
    coords_data_few_vis = torch.tensor([sparse_sample, dense_sample], dtype=torch.float32)
    visibility_data_few_vis = torch.tensor(
        [
            [1, 0, 0, 0, 0],
            [1, 1, 1, 1, 0],
        ],
        dtype=torch.bool,
    )

    print("\nOriginal Coords (with few visible):\n", coords_data_few_vis)
    estimated_coords_few_vis = demo_estimator.estimate_invisible(
        coords_data_few_vis, visibility_data_few_vis
    )
    print("\nEstimated Coords (with few visible):\n", estimated_coords_few_vis)
    # Sample 0 falls below the visibility threshold, so no estimate is made
    # for it and its invisible keypoints keep their input value of (0, 0).