import os
# Debug switches, read once at import time. `os.environ.get` yields either a
# string or None and the flags are only truth-tested in this module, so ANY
# non-empty value (including "0") switches them on; leave them unset to disable.
# When set, ImagePool creates a "debug_images" directory and writes every
# generated image into it as a PNG.
DISTRL_DEBUG_IMAGEPOOL_SAVE_IMAGE = os.environ.get("DISTRL_DEBUG_IMAGEPOOL_SAVE_IMAGE", None)
# When set, enables extra per-rank log lines during GPU-based pool synchronization.
DISTRL_DEBUG_DETAIL_LOG = os.environ.get("DISTRL_DEBUG_DETAIL_LOG", None)

import torch
from tqdm import tqdm
from typing import List, Tuple, Optional
import pickle
import logging
import math
import random
import numpy as np
import torch.distributed as dist
import time
from PIL import Image

# Import FID computation utilities
from distrl.autofid.fid import calculate_activation_statistics, calculate_frechet_distance
from distrl.models.diffusion_models.edm2.original.calculate_metrics import InceptionV3Detector

logger = logging.getLogger(__name__)

class ImagePool:
    """
    Class for managing a pool of generated images, their noise vectors (xT), prompts, and Inception features.
    Supports distributed generation and feature extraction.
    """
    def __init__(
        self,
        output_dir: str,
        device: torch.device,
        is_distributed: bool = False,
        pool_size: int = 10000,
        batch_size: int = 4,
        image_size: int = 512,
        rank: int = 0,
        world_size: int = 1,
        dist_method: str = "filesystem",  # Options: "filesystem", "gpu"
    ):
        """
        Build an empty pool, create its output directory and load the Inception model.

        Args:
            output_dir: Directory the pool is saved to / loaded from (created if missing)
            device: Device used for generation and feature extraction
            is_distributed: True when several processes cooperate on one pool
            pool_size: Upper bound on the number of pool entries
            batch_size: Number of images generated per pipeline call
            image_size: Height and width passed to latent preparation
            rank: Index of this process among the cooperating processes
            world_size: Number of cooperating processes
            dist_method: How ranks exchange their shards ("filesystem" or "gpu")
        """
        # Location and compute device.
        self.output_dir, self.device = output_dir, device

        # Distributed topology and the synchronization strategy.
        self.is_distributed = is_distributed
        self.rank, self.world_size = rank, world_size
        self.dist_method = dist_method

        # Capacity and generation settings.
        self.pool_size = pool_size
        self.batch_size = batch_size
        self.image_size = image_size

        # Parallel lists: entry i of each list describes the same pool item.
        self.noise_vectors = []       # xT noise vectors
        self.prompts = []             # prompt text per item
        self.inception_features = []  # Inception feature per item
        self.prompt_indices = []      # index of the originating prompt per item

        # Ground-truth FID statistics (numpy arrays), filled by load_gt_fid_stats.
        self.gt_mu: Optional[np.ndarray] = None
        self.gt_sigma: Optional[np.ndarray] = None

        # Lazily computed statistics of the pool itself plus their validity flag.
        self._pool_mu: Optional[np.ndarray] = None
        self._pool_sigma: Optional[np.ndarray] = None
        self._pool_stats_valid = False

        # The directory must exist before anything is written into it.
        os.makedirs(self.output_dir, exist_ok=True)

        # Feature extractor used for every generated image.
        self._setup_inception_model()

        # Optional dump directory for generated images (debug switch).
        if DISTRL_DEBUG_IMAGEPOOL_SAVE_IMAGE:
            self.debug_output_dir = os.path.join(self.output_dir, "debug_images")
            os.makedirs(self.debug_output_dir, exist_ok=True)

    def _setup_inception_model(self):
        """Instantiate the InceptionV3 detector and move its network to ``self.device``."""
        detector = InceptionV3Detector()
        self.inception_model = detector
        # Only the wrapped ``model`` attribute is moved; the detector object itself stays as is.
        detector.model.to(self.device)

    def _reset_pool(self) -> None:
        """Drop every pool entry and forget the cached pool statistics."""
        # Each parallel list gets its own fresh, empty list object.
        for list_attr in ("noise_vectors", "prompts", "inception_features", "prompt_indices"):
            setattr(self, list_attr, [])

        # Cached mean / covariance no longer describe anything.
        self._pool_mu = self._pool_sigma = None
        self._invalidate_pool_stats()

    def fill_the_pool(
        self,
        pipe,
        prompts: List[str],
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
        force_fill: bool = False,
    ) -> None:
        """
        Fill the image pool with balanced images per prompt.
        Each prompt will have approximately the same number of images.

        While this method runs, the environment variable ``DISTRL_STATUS_FILL_POOL``
        is "1"; it is set back to "0" on every exit path (normal completion, early
        return after loading a saved pool, or an exception).

        Args:
            pipe: Generation pipeline. It must expose ``unet``, ``vae``,
                ``prepare_latents(...)`` and be callable with prompts and latents.
            prompts: List of prompts to use
            num_inference_steps: Number of denoising steps
            guidance_scale: Guidance scale for classifier-free guidance
            seed: Random seed for reproducibility (each batch uses ``seed + batch offset``)
            force_fill: When True, regenerate even if ``image_pool.pkl`` already exists

        Raises:
            ValueError: If generation is required but ``prompts`` is empty, or if
                ``dist_method`` is unknown in a distributed run.
        """
        os.environ["DISTRL_STATUS_FILL_POOL"] = "1"
        try:
            self._fill_the_pool_impl(
                pipe,
                prompts,
                num_inference_steps,
                guidance_scale,
                seed,
                force_fill,
            )
        finally:
            # Always clear the status flag, including on early return and on errors.
            os.environ["DISTRL_STATUS_FILL_POOL"] = "0"

    def _prepare_inception_input(self, image) -> torch.Tensor:
        """Convert one generated image into a uint8 CHW tensor on ``self.device``."""
        if isinstance(image, np.ndarray):
            # Arrays are scaled by 255 and transposed HWC -> CHW.
            # NOTE(review): assumes float values in [0, 1] - confirm against the pipeline.
            scaled = (image * 255).astype(np.uint8)
            return torch.from_numpy(scaled.transpose(2, 0, 1)).to(self.device)
        if isinstance(image, torch.Tensor):
            if image.dtype == torch.uint8:
                return image.to(self.device)  # already uint8, layout kept
            # Non-uint8 tensors are scaled by 255 and truncated to uint8, layout kept.
            return (image * 255).to(dtype=torch.uint8).to(self.device)
        raise NotImplementedError("Only numpy and torch.Tensor format is supported")

    def _fill_the_pool_impl(
        self,
        pipe,
        prompts: List[str],
        num_inference_steps: int,
        guidance_scale: float,
        seed: Optional[int],
        force_fill: bool,
    ) -> None:
        """Generate this rank's share of the pool and synchronize it across ranks."""
        dtype = next(pipe.unet.parameters()).dtype
        pipe.unet.eval()
        pipe.vae.eval()

        # Check if we can load the complete pool from disk first
        complete_pool_path = os.path.join(self.output_dir, "image_pool.pkl")
        if os.path.exists(complete_pool_path) and not force_fill:
            if self.load_from_disk():
                logger.info("Successfully loaded existing complete image pool")
                return

        if not prompts:
            # Without this guard the per-prompt computation below divides by zero.
            raise ValueError("fill_the_pool requires at least one prompt")

        # NOTE(review): entries already in the pool are kept and the new images are
        # appended to them; callers that want a clean refill must reset the pool first.

        # Calculate images per prompt (ceil to ensure we have at least pool_size)
        images_per_prompt = math.ceil(self.pool_size / len(prompts))

        # Calculate total images after ceiling (may exceed pool_size slightly)
        total_images = images_per_prompt * len(prompts)

        logger.info(f"Generating {images_per_prompt} images per prompt for {len(prompts)} prompts")
        logger.info(f"Total images: {total_images} (requested: {self.pool_size})")

        # Distribute prompts among ranks - each rank handles a contiguous subset of prompts
        prompts_per_rank = math.ceil(len(prompts) / self.world_size)
        start_idx = self.rank * prompts_per_rank
        end_idx = min(start_idx + prompts_per_rank, len(prompts))

        # Handle edge case where a rank might get no prompts
        if start_idx >= len(prompts):
            logger.warning(f"Rank {self.rank} has no prompts to process")
            rank_prompts = []
        else:
            rank_prompts = prompts[start_idx:end_idx]
            logger.info(f"Rank {self.rank} handling prompts {start_idx} to {end_idx-1} ({len(rank_prompts)} prompts)")

        noise_vectors = []
        prompt_texts = []
        prompt_indices = []  # For tracking original prompt indices
        inception_features = []

        # Repeat every prompt images_per_prompt times, keeping copies of a prompt adjacent;
        # the prompt-index arithmetic below relies on this ordering.
        expanded_prompts = []
        for prompt in rank_prompts:
            expanded_prompts.extend([prompt] * images_per_prompt)

        # Loop-invariant: number of latent channels (unwrap a `.module` wrapper if present)
        unet = pipe.unet.module if hasattr(pipe.unet, 'module') else pipe.unet
        num_channels_latents = unet.in_channels

        # Process in batches
        batch_iterator = range(0, len(expanded_prompts), self.batch_size)
        # Add tqdm progress bar for the main process only
        if self.rank == 0:
            batch_iterator = tqdm(
                batch_iterator,
                desc="Filling image pool",
                total=math.ceil(len(expanded_prompts) / self.batch_size),
            )

        for batch_idx in batch_iterator:
            batch_prompts = expanded_prompts[batch_idx:batch_idx + self.batch_size]
            current_batch_size = len(batch_prompts)

            # Create a generator for deterministic noise
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(seed + batch_idx)
            else:
                generator = None

            # Generate noise latents
            noise_batch = pipe.prepare_latents(
                batch_size=current_batch_size,
                num_channels_latents=num_channels_latents,
                height=self.image_size,
                width=self.image_size,
                dtype=dtype,
                device=self.device,
                generator=generator,
                latents=None,
            )

            # Generate images from exactly these noise vectors so they can be stored
            # alongside the resulting features.
            with torch.no_grad():
                output = pipe(
                    batch_prompts,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    latents=noise_batch,
                    return_dict=True,
                    output_type="pt",  # tensor output requested; presumably BCHW - TODO confirm
                )

            batch_images = output.images

            # Convert every image to uint8 CHW and stack into one batch
            processed_images = [self._prepare_inception_input(image) for image in batch_images]
            batch_processed = torch.stack(processed_images).to(self.device)

            with torch.no_grad():
                # Extract features for the entire batch at once
                batch_features = self.inception_model(batch_processed)

            # Split features back into a list of per-image CPU tensors
            batch_features = [feat.cpu() for feat in batch_features]

            # Store results
            noise_vectors.extend([noise_batch[j].clone() for j in range(current_batch_size)])
            prompt_texts.extend(batch_prompts)
            # Position i in expanded_prompts belongs to global prompt start_idx + i // images_per_prompt
            prompt_indices.extend(
                [start_idx + i // images_per_prompt for i in range(batch_idx, batch_idx + current_batch_size)]
            )
            inception_features.extend(batch_features)

            if DISTRL_DEBUG_IMAGEPOOL_SAVE_IMAGE:
                for i in range(current_batch_size):
                    image = batch_processed[i]
                    Image.fromarray(image.cpu().numpy().transpose(1, 2, 0)).save(f"{self.debug_output_dir}/{batch_prompts[i].replace('an image of ', '').replace(' ', '_')}.r{self.rank}b{batch_idx}i{i}.png")

            # Free up memory
            del batch_images, batch_processed, processed_images
            torch.cuda.empty_cache()

        # Add to pool with metadata
        self.add_to_pool_with_metadata(
            noise_vectors=noise_vectors,
            prompts=prompt_texts,
            inception_features=inception_features,
            prompt_indices=prompt_indices
        )

        # Handle distributed synchronization based on method
        if self.is_distributed:
            if self.dist_method == "filesystem":
                self.distribute_with_filesystem(complete_pool_path)
            elif self.dist_method == "gpu":
                self.distribute_with_gpu()
            else:
                raise ValueError(f"Unknown distribution method: {self.dist_method}")

    def distribute_with_filesystem(self, complete_pool_path: str) -> None:
        """
        Synchronize image pool across ranks using filesystem.
        Every rank writes its shard, rank 0 merges all shards into the complete
        pool file, and then every rank loads that file.

        Args:
            complete_pool_path: Path at which the complete pool file is expected.
                The file itself is written and read through ``save_to_disk()`` /
                ``load_from_disk()`` with their default filename, so this should be
                ``<output_dir>/image_pool.pkl``.

        Raises:
            TimeoutError: If the complete pool file does not become visible in time.
            RuntimeError: If the complete pool file cannot be loaded.
        """
        # Save rank-specific data to disk
        self.save_to_disk(filename=f"image_pool_rank_{self.rank}.pkl")

        logger.info(f"Rank {self.rank} finished generating images, saved to file")
        # First barrier: every rank has written its shard.
        dist.barrier()

        # Only rank 0 combines all files
        if self.rank == 0:
            logger.info("Rank 0 combining all rank files")
            self.combine_rank_files()

            # Sort and trim if needed
            self.sort_by_prompt_indices()

            # Trim to exactly pool_size if we exceeded it
            if len(self.noise_vectors) > self.pool_size:
                logger.info(f"Trimming pool from {len(self.noise_vectors)} to {self.pool_size} images")
                self.trim_to_size(self.pool_size)

            # Save the complete pool
            self.save_to_disk()
            logger.info("Complete pool saved to disk")

        # Second barrier: nobody reads before rank 0 has finished writing. Checking
        # only for the file's existence is not enough, because a file left over from
        # an earlier run (or one still being written) would be picked up instead.
        dist.barrier()

        # The file may still take a moment to become visible to this process
        # (e.g. on a shared filesystem), so poll - but never forever.
        max_wait_time = 300  # seconds
        wait_interval = 5  # seconds
        deadline = time.time() + max_wait_time
        while not os.path.exists(complete_pool_path):
            if time.time() > deadline:
                raise TimeoutError(
                    f"Rank {self.rank} timed out after {max_wait_time} seconds waiting for {complete_pool_path}"
                )
            logger.info(f"Rank {self.rank} waiting for complete pool file")
            time.sleep(wait_interval)

        # Load the complete pool; a failure here would leave ranks with different pools.
        if not self.load_from_disk():
            raise RuntimeError(f"Rank {self.rank} failed to load complete pool from {complete_pool_path}")
        logger.info(f"Rank {self.rank} loaded complete pool")

    def distribute_with_gpu(self) -> None:
        """
        Exchange pool shards between ranks with ``dist.all_gather_object``.

        Every rank contributes its local entries, rebuilds the pool from the
        shards of all ranks, sorts it by prompt index and trims it to
        ``pool_size``; a final barrier keeps the ranks in step.
        """
        verbose = DISTRL_DEBUG_DETAIL_LOG

        if verbose:
            logger.info(f"Rank {self.rank} gathering data via GPU communication, data size: {len(self.noise_vectors)}")

        # Tensors are moved to CPU before being handed to the object collective.
        payload = dict(
            noise_vectors=[vec.cpu() for vec in self.noise_vectors],
            prompts=self.prompts,
            inception_features=[feat.cpu() for feat in self.inception_features],
            prompt_indices=self.prompt_indices,
        )

        # One slot per rank, filled in by the collective.
        shards = [None for _ in range(self.world_size)]
        dist.all_gather_object(shards, payload)

        # Start from an empty pool and append the shards in rank order.
        self._reset_pool()
        for shard in shards:
            if shard is None:
                continue
            self.noise_vectors += [vec.to(self.device) for vec in shard["noise_vectors"]]
            self.prompts += shard["prompts"]
            self.inception_features += [feat.to(self.device) for feat in shard["inception_features"]]
            self.prompt_indices += shard["prompt_indices"]

        # Every rank sorts and trims its own copy.
        self.sort_by_prompt_indices()

        current_size = len(self.noise_vectors)
        if current_size > self.pool_size:
            logger.info(f"Rank {self.rank} trimming pool from {len(self.noise_vectors)} to {self.pool_size} images")
            self.trim_to_size(self.pool_size)

        if verbose:
            logger.info(f"Rank {self.rank} completed GPU distribution, final pool size: {len(self.noise_vectors)}")

        # Wait until all ranks have finished rebuilding their pool.
        dist.barrier()

    def combine_rank_files(self) -> None:
        """
        Combine the rank-specific pool files into a single complete pool.
        This method should only be called by rank 0; other ranks log a warning
        and return without changing anything.

        The current pool is replaced by the concatenation of the files
        ``image_pool_rank_<r>.pkl`` for ``r`` in ``range(world_size)``. A file that
        is missing after the wait period, cannot be read, or is malformed is
        skipped as a whole (best effort) and reported in the log; it never
        contributes only some of its lists to the pool.
        """
        if self.rank != 0:
            logger.warning("combine_rank_files should only be called by rank 0")
            return

        # Clear the current pool
        self.noise_vectors = []
        self.prompts = []
        self.inception_features = []
        self.prompt_indices = []

        max_wait_time = 300  # 5 minutes timeout per rank file
        wait_interval = 5  # Check every 5 seconds

        # Load each rank's file
        ranks_loaded = 0
        total_expected_images = 0
        for rank in range(self.world_size):
            rank_file = os.path.join(self.output_dir, f"image_pool_rank_{rank}.pkl")

            # Wait for rank file to exist with timeout
            start_time = time.time()
            while not os.path.exists(rank_file):
                if time.time() - start_time > max_wait_time:
                    logger.error(f"Timeout waiting for rank file {rank_file} after {max_wait_time} seconds")
                    break

                logger.info(f"Waiting for rank file {rank_file} to be created...")
                time.sleep(wait_interval)

            if not os.path.exists(rank_file):
                logger.error(f"Rank file {rank_file} still does not exist after waiting {max_wait_time} seconds, skipping")
                continue

            logger.info(f"Loading rank file {rank_file}")

            try:
                # SECURITY: pickle executes arbitrary code on load; only use on files
                # written by this program into a trusted output directory.
                with open(rank_file, "rb") as f:
                    data = pickle.load(f)

                # Build all four lists first, so a missing key or a bad entry cannot
                # leave the pool's parallel lists with different lengths.
                rank_noise = [x.to(self.device) for x in data["noise_vectors"]]
                rank_prompts = list(data["prompts"])
                rank_features = [x.to(self.device) for x in data["inception_features"]]
                rank_indices = list(data["prompt_indices"])

                if not (len(rank_noise) == len(rank_prompts) == len(rank_features) == len(rank_indices)):
                    raise ValueError("lists in rank file have mismatched lengths")
            except Exception as e:
                logger.error(f"Failed to load rank file {rank_file}: {e}")
                continue

            # Add to the current pool
            self.noise_vectors.extend(rank_noise)
            self.prompts.extend(rank_prompts)
            self.inception_features.extend(rank_features)
            self.prompt_indices.extend(rank_indices)

            logger.info(f"Added {len(rank_noise)} images from rank {rank}")
            ranks_loaded += 1
            total_expected_images += len(rank_noise)

        logger.info(f"Combined pool has {len(self.noise_vectors)} images from {ranks_loaded}/{self.world_size} ranks")
        logger.info(f"Expected approximately {total_expected_images} images - if total doesn't match, some ranks failed")

        if ranks_loaded < self.world_size:
            logger.warning(f"Missing {self.world_size - ranks_loaded} rank files - pool may be incomplete")

        # Invalidate cached statistics
        self._invalidate_pool_stats()

    def extract_inception_features(self, images: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Extract Inception v3 features from images.

        Args:
            images: List of image tensors

        Returns:
            List of feature tensors
        """
        features = []

        with torch.no_grad():
            for image in images:
                # Process image to match Inception input requirements
                if isinstance(image, np.ndarray):
                    # Convert from numpy array (HWC format) to tensor (CHW format)
                    image = (image * 255).astype(np.uint8)
                    processed_image = torch.from_numpy(image.transpose(2, 0, 1)).to(self.device)  # HWC -> CHW
                elif isinstance(image, torch.Tensor):
                    if image.dtype == torch.uint8:
                        processed_image = image.to(self.device)  # Already in BCHW format
                    else:
                        # Convert from float [0,1] to uint8 [0,255], keeping BCHW format
                        processed_image = (image * 255).to(dtype=torch.uint8).to(self.device)
                else:
                    raise NotImplementedError("PIL images are not supported")

                # Add batch dimension (CHW -> BCHW)
                processed_image = processed_image.unsqueeze(0).to(self.device)

                # Extract feature
                feature = self.inception_model(processed_image)
                features.append(feature.cpu())

        return features

    def add_to_pool_with_metadata(
        self,
        noise_vectors: List[torch.Tensor],
        prompts: List[str],
        inception_features: List[torch.Tensor],
        prompt_indices: List[int]
    ) -> None:
        """
        Append new entries to the pool and enforce the size limit.

        Args:
            noise_vectors: Noise vectors (xT) of the new entries
            prompts: Prompt text of the new entries
            inception_features: Inception features of the new entries
            prompt_indices: Index of the originating prompt for each new entry
        """
        # Pair each pool list with the matching batch of new items.
        additions = (
            (self.noise_vectors, noise_vectors),
            (self.prompts, prompts),
            (self.inception_features, inception_features),
            (self.prompt_indices, prompt_indices),
        )
        for target, new_items in additions:
            target.extend(new_items)

        # Cached statistics describe the old contents only.
        self._invalidate_pool_stats()

        # Cut back to pool_size when the additions pushed the pool over it.
        if len(self.noise_vectors) > self.pool_size:
            self.trim_to_size(self.pool_size)

    def sort_by_prompt_indices(self) -> None:
        """
        Reorder the pool so that entries appear in ascending prompt-index order.

        The sort is stable: entries that share a prompt index keep their
        relative order. All four parallel lists are permuted identically.
        """
        entry_count = len(self.prompt_indices)
        if entry_count == 0:
            return

        # Permutation that sorts the prompt indices.
        order = sorted(range(entry_count), key=self.prompt_indices.__getitem__)

        def permuted(items):
            return [items[pos] for pos in order]

        self.noise_vectors = permuted(self.noise_vectors)
        self.prompts = permuted(self.prompts)
        self.inception_features = permuted(self.inception_features)
        self.prompt_indices = permuted(self.prompt_indices)

    def trim_to_size(self, size: int) -> None:
        """
        Trim the pool to a specific size, preserving balanced distribution across prompts.

        Every prompt index gets a quota of ``ceil(size / num_prompts)`` entries; to
        hit ``size`` exactly, the quota of the lowest prompt indices is reduced by
        one. For each prompt the earliest entries (in current pool order) are kept.
        If some prompts have fewer entries than their quota, the shortfall is filled
        deterministically from unused entries of the other prompts, one per prompt
        per round in ascending prompt-index order. No entry is ever duplicated and
        no randomness is involved, so every rank trimming the same pool obtains the
        same result. The relative order of the kept entries is unchanged.

        Does nothing when the pool already has at most ``size`` entries.

        Args:
            size: Target size for the pool (non-negative)

        Raises:
            ValueError: If ``size`` is negative or the pool's parallel lists
                disagree in length.
        """
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")

        total = len(self.noise_vectors)
        if total <= size:
            return

        if len(self.prompt_indices) != total:
            raise ValueError(
                f"Pool is inconsistent: {total} noise vectors but {len(self.prompt_indices)} prompt indices"
            )

        # Positions of every prompt's entries, in current pool order.
        positions_by_prompt = {}
        for position, prompt_idx in enumerate(self.prompt_indices):
            positions_by_prompt.setdefault(prompt_idx, []).append(position)

        unique_prompts = sorted(positions_by_prompt)
        num_unique_prompts = len(unique_prompts)

        # Quota per prompt. Using ceil overshoots by `excess` (0 <= excess < number
        # of prompts); the lowest prompt indices each give up one slot to compensate.
        imgs_per_prompt = math.ceil(size / num_unique_prompts)
        excess = imgs_per_prompt * num_unique_prompts - size

        kept_positions = []
        spare_positions = {}
        for order, prompt_idx in enumerate(unique_prompts):
            quota = imgs_per_prompt - 1 if order < excess else imgs_per_prompt
            positions = positions_by_prompt[prompt_idx]
            kept_positions.extend(positions[:quota])
            spare_positions[prompt_idx] = positions[quota:]

        # Prompts with fewer entries than their quota leave a shortfall. Because the
        # pool holds more than `size` entries, the spare entries always suffice.
        shortfall = size - len(kept_positions)
        depth = 0
        while shortfall > 0:
            for prompt_idx in unique_prompts:
                spares = spare_positions[prompt_idx]
                if depth < len(spares):
                    kept_positions.append(spares[depth])
                    shortfall -= 1
                    if shortfall == 0:
                        break
            depth += 1

        # Keep the survivors in their original relative order.
        kept_positions.sort()

        # Update pool with trimmed data
        self.noise_vectors = [self.noise_vectors[i] for i in kept_positions]
        self.prompts = [self.prompts[i] for i in kept_positions]
        self.inception_features = [self.inception_features[i] for i in kept_positions]
        self.prompt_indices = [self.prompt_indices[i] for i in kept_positions]

        # Invalidate cached statistics after trimming
        self._invalidate_pool_stats()

        logger.info(f"Trimmed pool to exactly {len(self.noise_vectors)} images")

    def save_to_disk(self, filename: str = "image_pool.pkl") -> None:
        """
        Save the image pool to disk.

        Args:
            filename: Name of the file to save the pool to
        """
        save_path = os.path.join(self.output_dir, filename)
        logger.info(f"Saving image pool to {save_path}")

        # Prepare data for saving
        save_data = {
            "noise_vectors": [x.cpu() for x in self.noise_vectors],
            "prompts": self.prompts,
            "inception_features": [x.cpu() for x in self.inception_features],
            "prompt_indices": self.prompt_indices,
        }

        with open(save_path, "wb") as f:
            pickle.dump(save_data, f)

    def load_from_disk(self, filename: str = "image_pool.pkl") -> bool:
        """
        Load the image pool from disk.

        Args:
            filename: Name of the file to load the pool from

        Returns:
            True if loading was successful, False otherwise
        """
        load_path = os.path.join(self.output_dir, filename)

        if not os.path.exists(load_path):
            logger.warning(f"Image pool file {load_path} does not exist")
            return False

        logger.info(f"Loading image pool from {load_path}")

        try:
            with open(load_path, "rb") as f:
                data = pickle.load(f)

            self.noise_vectors = [x.to(self.device) for x in data["noise_vectors"]]
            self.prompts = data["prompts"]
            self.inception_features = [x.to(self.device) for x in data["inception_features"]]

            # Load prompt indices if available
            if "prompt_indices" in data:
                self.prompt_indices = data["prompt_indices"]
            else:
                # Fallback: Create sequential indices
                logger.warning("Prompt indices not found in saved data, creating default indices")
                self.prompt_indices = list(range(len(self.prompts)))

            logger.info(f"Loaded {len(self.noise_vectors)} images from pool")
            return True

        except Exception as e:
            logger.error(f"Failed to load image pool: {e}")
            return False

    def _invalidate_pool_stats(self) -> None:
        """Flag the cached pool mean/covariance as stale so the next request recomputes them."""
        # Only the flag changes; the stale arrays stay until they are overwritten.
        self._pool_stats_valid = False

    def _update_pool_stats(self) -> None:
        """
        Recompute the cached pool statistics unless they are still valid.

        Raises:
            ValueError: If a recomputation is needed but the pool has no features.
        """
        if self._pool_stats_valid:
            return  # cache is current, nothing to do

        if not self.inception_features:
            raise ValueError("No inception features in pool")

        # One row per pool entry, as a numpy array on the host.
        stacked_np = torch.stack(self.inception_features).cpu().numpy()

        # Mean and covariance via the numpy implementation.
        self._pool_mu, self._pool_sigma = calculate_activation_statistics(stacked_np)
        self._pool_stats_valid = True

    def compute_fid_statistics(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute mean and covariance of inception features for FID calculation.
        Uses cached statistics if available.

        Returns:
            Tuple of (mean, covariance) as numpy arrays
        """
        self._update_pool_stats()
        if self._pool_mu is None or self._pool_sigma is None:
            raise ValueError("Failed to compute pool statistics")
        return self._pool_mu, self._pool_sigma

    def _compute_fid_score(self, mu1: np.ndarray, sigma1: np.ndarray,
                         mu2: np.ndarray, sigma2: np.ndarray) -> float:
        """
        Return the Frechet distance between two sets of statistics.

        Thin wrapper around ``calculate_frechet_distance``; the arguments are
        forwarded unchanged and in the same order.

        Args:
            mu1, sigma1: Mean and covariance of first distribution (numpy arrays)
            mu2, sigma2: Mean and covariance of second distribution (numpy arrays)

        Returns:
            FID score
        """
        distance = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
        return distance

    def sample_replacement_indices(self, num_groups: int, num_samples: int) -> Tuple[List[List[int]], List[List[str]]]:
        """
        Sample indices and their corresponding prompts for replacement.
        Each group will have num_samples unique indices, but indices can be repeated across groups.

        Args:
            num_groups: Number of groups to sample
            num_samples: Number of samples per group

        Returns:
            Tuple of:
            - List[List[int]]: List of groups, each containing num_samples unique indices
            - List[List[str]]: List of groups, each containing the prompts corresponding to the indices
        """
        if num_samples > len(self.noise_vectors):
            raise ValueError(f"Requested {num_samples} samples but pool only has {len(self.noise_vectors)} images")

        # Sample indices for each group
        indices_groups = []
        prompts_groups = []

        for _ in range(num_groups):
            # Sample num_samples unique indices
            group_indices = np.random.choice(len(self.noise_vectors), size=num_samples, replace=False)
            indices_groups.append(group_indices.tolist())

            # Get corresponding prompts
            group_prompts = [self.prompts[idx] for idx in group_indices]
            prompts_groups.append(group_prompts)

        return indices_groups, prompts_groups

    def compute_fid(self,
                   temp_features: Optional[List[torch.Tensor]] = None,
                   temp_indices: Optional[List[int]] = None) -> float:
        """
        Compute FID score between features and ground truth.
        If temp_features and temp_indices are provided, compute FID for the temporary features.
        Otherwise, compute FID for the current pool.

        Args:
            temp_features: Optional list of temporary features to replace some pool features
            temp_indices: Optional list of indices where to replace features

        Returns:
            FID score between features and ground truth
        """
        if self.gt_mu is None or self.gt_sigma is None:
            raise ValueError("Ground truth FID statistics not loaded. Call load_gt_fid_stats first.")

        # If temporary features provided, compute their FID
        if temp_features is not None and temp_indices is not None:
            if len(temp_features) != len(temp_indices):
                raise ValueError("Number of temporary features must match number of indices")

            # Create temporary feature list
            temp_pool_features = self.inception_features.copy()
            for feat, idx in zip(temp_features, temp_indices):
                if idx >= len(temp_pool_features):
                    raise ValueError(f"Index {idx} out of range")
                temp_pool_features[idx] = feat

            # Stack features and convert to numpy for statistics computation
            temp_pool_features = [feat.cpu() for feat in temp_pool_features]
            temp_features_tensor = torch.stack(temp_pool_features)
            temp_features_np = temp_features_tensor.numpy()

            # Compute statistics using numpy implementation
            temp_mu_np, temp_sigma_np = calculate_activation_statistics(temp_features_np)

            # Calculate FID between temporary pool and GT
            return self._compute_fid_score(temp_mu_np, temp_sigma_np, self.gt_mu, self.gt_sigma)

        # Otherwise compute FID for current pool
        pool_mu, pool_sigma = self.compute_fid_statistics()
        return self._compute_fid_score(pool_mu, pool_sigma, self.gt_mu, self.gt_sigma)

    def load_gt_fid_stats(self, stats_path: str) -> None:
        """
        Load ground truth FID statistics from a pickle file.

        Args:
            stats_path: Path to the pickle file containing GT FID statistics
        """
        try:
            if stats_path.endswith(".pkl"):
                with open(stats_path, 'rb') as f:
                    stats = pickle.load(f)
                stats = stats.get("fid", stats)
            elif stats_path.endswith(".npz"):
                stats = np.load(stats_path)
            else:
                raise ValueError(f"Unsupported file extension: {stats_path}")

            self.gt_mu = stats['mu']
            self.gt_sigma = stats['sigma']
            logger.info(f"Successfully loaded GT FID statistics from {stats_path}")
        except Exception as e:
            logger.error(f"Failed to load GT FID statistics: {e}")
            raise

    @classmethod
    def from_args(cls, args, accelerator, pool_output_dir=None):
        """
        Alternate constructor driven by parsed arguments and an accelerator.

        Args:
            args: Argument namespace; reads ``output_dir`` (only when no override
                is given), ``num_fid_images``, ``g_batch_size``, ``image_size``
                and ``dist_method``
            accelerator: Object providing ``device``, ``num_processes`` and
                ``process_index``
            pool_output_dir: Directory to use instead of ``<args.output_dir>/image_pool``

        Returns:
            ImagePool instance
        """
        if pool_output_dir is None:
            output_dir = os.path.join(args.output_dir, "image_pool")
        else:
            output_dir = pool_output_dir

        settings = {
            "output_dir": output_dir,
            "device": accelerator.device,
            # More than one process means the pool has to be synchronized.
            "is_distributed": accelerator.num_processes > 1,
            "pool_size": args.num_fid_images,
            "batch_size": args.g_batch_size,
            "image_size": args.image_size,
            "rank": accelerator.process_index,
            "world_size": accelerator.num_processes,
            "dist_method": args.dist_method,
        }
        return cls(**settings)
