"""
Implementation of projection methods for dimension reduction.

This file contains functions to construct all random projection methods for dimension reduction.
"""

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Tuple
import logging

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Dict, List, Union, Optional

import numpy as np
import torch
from torch import Tensor

from ..utils.common import vectorize, get_parameter_chunk_sizes

# Module-level logger; the projector constructors in this module emit DEBUG
# messages through it describing the dimensions they were built with.
logger = logging.getLogger(__name__)

class ProjectionType(str, Enum):
    """Kinds of random matrix a projector can draw its entries from.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``ProjectionType.normal == "normal"``).
    """

    normal = "normal"  # iid Gaussian entries
    rademacher = "rademacher"  # iid +/-1 entries
    identity = "identity"  # no random matrix; used by column-selection methods


class AbstractProjector(ABC):
    """Interface shared by every projector in this module.

    Concrete subclasses must implement ``__init__``, ``project`` and
    ``free_memory``; the base ``__init__`` only records the hyperparameters
    common to all projectors on the instance.
    """

    @abstractmethod
    def __init__(
        self,
        feature_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: Union[str, ProjectionType],
        device: Union[str, torch.device],
    ) -> None:
        """Store the hyperparameters common to all projectors.

        Args:
            feature_dim: Size of each input feature vector. Typically the
                number of parameters in the model, i.e. the length of the
                gradient vectors.
            proj_dim: Size of each projected vector.
            seed: Random seed used to generate the sketching (projection)
                matrix.
            proj_type: Distribution of the random matrix entries. A random
                (Johnson-Lindenstrauss) projection approximately preserves
                distances for several choices of random matrix; the
                implementations here cover iid Gaussian and iid Rademacher
                entries.
            device: Device on which the projection is computed.
        """
        (
            self.feature_dim,
            self.proj_dim,
            self.seed,
            self.proj_type,
            self.device,
        ) = (feature_dim, proj_dim, seed, proj_type, device)

    @abstractmethod
    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """Project a batch of features down to ``proj_dim`` dimensions.

        Args:
            features: A batch of features, or a dictionary of batched
                features.
            ensemble_id: Identifier of the ensemble member; lets an
                implementation derive a distinct projection matrix per member.

        Returns:
            Tensor: The projected features.
        """

    @abstractmethod
    def free_memory(self) -> None:
        """Release memory held by the projector."""


class BasicProjector(AbstractProjector):
    """A simple block-wise implementation of the projection.

    The projection matrix is generated on-device in blocks of at most
    ``block_size`` output columns, each block from its own seeded generator
    state. Each block's result is written into the matching columns of the
    returned sketch.
    """

    def __init__(
        self,
        feature_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: Union[str, ProjectionType],
        device: torch.device,
        block_size: int = 100,
        dtype: torch.dtype = torch.float32,
        ensemble_id: int = 0,
        method: str = "Gaussian",
        active_indices: Optional[Tensor] = None,
        pre_compute: bool = False,
    ) -> None:
        """Initializes hyperparameters for BasicProjector.

        Args:
            feature_dim: Dimension of the features to be projected.
                Typically, this equals the number of parameters in the model
                (dimension of the gradient vectors).
            proj_dim: Dimension after the projection.
            seed: Random seed for the generation of the sketching
                (projection) matrix.
            proj_type: The random projection (JL
                transform) guarantees that distances will be approximately
                preserved for a variety of choices of the random matrix. Here,
                we provide an implementation for matrices with iid Gaussian
                entries and iid Rademacher entries.
            device: Device to use.
            block_size: Maximum number of projection columns materialized at
                once. ``min(block_size, proj_dim)`` columns form one block and
                ``ceil(proj_dim / block_size)`` blocks are used.
            dtype: The dtype of the projection matrix and of the returned
                sketch.
            ensemble_id: A unique ID for this ensemble; it offsets the seeds
                of every block.
            method: Stored on the instance but not otherwise used here.
            active_indices: Accepted for signature compatibility with
                CudaProjector; ignored by this class.
            pre_compute: Accepted for signature compatibility with
                CudaProjector; ignored by this class.
        """
        # The base class already stores proj_type (and the other common args).
        super().__init__(feature_dim, proj_dim, seed, proj_type, device)

        self.block_size = min(self.proj_dim, block_size)
        self.num_blocks = math.ceil(self.proj_dim / self.block_size)
        self.dtype = dtype
        self.ensemble_id = ensemble_id
        self.method = method  # TODO: currently unused

        self.proj_matrix = self._allocate_proj_matrix()
        self.proj_matrix_available = True

        self.generator = torch.Generator(device=self.device)

        self.get_generator_states()
        self.generate_sketch_matrix(self.generator_states[0])

        logger.debug(f"Initialized BasicProjector with dimensions {feature_dim} -> {proj_dim}")

    def _allocate_proj_matrix(self) -> Tensor:
        """Allocate an uninitialized (feature_dim, block_size) matrix buffer."""
        return torch.empty(
            self.feature_dim,
            self.block_size,
            dtype=self.dtype,
            device=self.device,
        )

    def free_memory(self) -> None:
        """Delete the projection matrix; it is regenerated on next use.

        Calling this more than once is a no-op.
        """
        if self.proj_matrix_available:
            del self.proj_matrix
            self.proj_matrix_available = False

    def get_generator_states(self) -> None:
        """Set generator seeds for each block.

        Block ``i`` is seeded with
        ``seed + 1000 * i + 100000 * ensemble_id``; the resulting generator
        states are stored in ``self.generator_states`` (one per block).
        """
        self.generator_states = []
        self.seeds = []
        self.jl_size = self.feature_dim * self.block_size

        for i in range(self.num_blocks):
            s = self.seed + int(1e3) * i + int(1e5) * self.ensemble_id
            self.seeds.append(s)
            self.generator = self.generator.manual_seed(s)
            self.generator_states.append(self.generator.get_state())

    def generate_sketch_matrix(self, generator_state: Tensor) -> None:
        """Restore a generator state and fill the projection matrix from it.

        Args:
            generator_state: A single generator state, as returned by
                ``torch.Generator.get_state()``. Each block has its own state.

        Raises:
            KeyError: Projection type is not recognized.
        """
        if not self.proj_matrix_available:
            self.proj_matrix = self._allocate_proj_matrix()
            self.proj_matrix_available = True

        self.generator.set_state(generator_state)

        if self.proj_type in {ProjectionType.normal, "normal"}:
            self.proj_matrix.normal_(generator=self.generator)
        elif self.proj_type in {ProjectionType.rademacher, "rademacher"}:
            # Bernoulli {0, 1} mapped to {-1, +1}.
            self.proj_matrix.bernoulli_(p=0.5, generator=self.generator)
            self.proj_matrix *= 2.0
            self.proj_matrix -= 1.0
        else:
            msg = f"Projection type {self.proj_type} not recognized."
            raise KeyError(msg)

    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """Performs the random projection on the feature matrix.

        Args:
            features: A batch of features or a dictionary
                of batch of features.
            ensemble_id: A unique ID for this ensemble. If it differs from the
                current one, the per-block generator states are re-derived.

        Returns:
            Tensor: The projected features, of shape
                (batch, proj_dim) and dtype ``self.dtype``.
        """
        if isinstance(features, dict):
            features = vectorize(features, device=self.device)
        elif features.device.type != self.device:
            features = features.to(self.device)
        features = features.to(dtype=self.dtype)
        sketch = torch.zeros(
            size=(features.size(0), self.proj_dim),
            dtype=self.dtype,
            device=self.device,
        )

        if ensemble_id != self.ensemble_id:
            self.ensemble_id = ensemble_id
            self.get_generator_states()  # regenerate random seeds for new ensemble_id
            if self.num_blocks == 1:
                self.generate_sketch_matrix(self.generator_states[0])

        if self.num_blocks == 1:
            if not self.proj_matrix_available:
                # free_memory() released the cached single-block matrix;
                # rebuild it deterministically from the saved generator state.
                self.generate_sketch_matrix(self.generator_states[0])
            torch.matmul(features.data, self.proj_matrix, out=sketch)
        else:
            for ind, state in enumerate(self.generator_states):
                self.generate_sketch_matrix(state)

                st = ind * self.block_size
                ed = min((ind + 1) * self.block_size, self.proj_dim)
                # The last block may be narrower than block_size.
                sketch[:, st:ed] = features @ self.proj_matrix[:, : (ed - st)]
        # features were cast to self.dtype above, so this keeps self.dtype.
        return sketch.type(features.dtype)


class CudaProjector(AbstractProjector):
    """Projector implemented using CUDA.

    A performant implementation of the projection
    for CUDA with compute capability >= 7.0.
    """

    def __init__(
        self,
        feature_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: ProjectionType,
        device: Union[str, torch.device],
        max_batch_size: int,
        method: str,
        active_indices: Optional[Tensor] = None,
        pre_compute: bool = False,
    ) -> None:
        """Initializes hyperparameters for CudaProjector.

        Args:
            feature_dim: Dimension of the features to be projected.
                Typically, this equals the number of parameters in the model
                (dimension of the gradient vectors).
            proj_dim: Dimension we project *to* during the projection step
            seed: Random seed.
            proj_type: Type of randomness to use for
                projection matrix (rademacher or normal).
            device: CUDA device to use.
            max_batch_size: Explicitly constrains the batch size of
                the CudaProjector is going to use for projection.
                Set this if you get a 'The batch size of the CudaProjector is
                too large for your GPU' error. Must be either 8, 16, or 32.
            method: The method used for the projection. One of "FJLT",
                "SJLT", "Rademacher", "Gaussian", "Random" or "Localize".
            active_indices: The indices of the features to be considered.
                Defaults to all ``feature_dim`` features.
            pre_compute: If True, the projection matrix (or SJLT module) is
                built once here instead of on every ``project`` call.

        Raises:
            ValueError: When attempting to use this on a non-CUDA device.
            ModuleNotFoundError: When fast_jl (FJLT) or the SJLT CUDA
                extension (SJLT) is not installed.
        """
        super().__init__(feature_dim, proj_dim, seed, proj_type, device)
        self.max_batch_size = max_batch_size

        if active_indices is None:
            active_indices = torch.arange(feature_dim, device=device)

        self.active_indices = active_indices.to(device)

        # if active_indices is a single element, then it will be a 0-dim tensor
        if self.active_indices.dim() == 0:
            self.active_indices = self.active_indices.unsqueeze(0)

        self.pre_compute = pre_compute

        if isinstance(device, str):
            device = torch.device(device)

        if device.type != "cuda":
            err = "CudaProjector only works on a CUDA device; \
            Either switch to a CUDA device, or use the BasicProjector"
            raise ValueError(err)

        self.num_sms = torch.cuda.get_device_properties(
            device.index,
        ).multi_processor_count

        self.method = method
        if self.method == "FJLT":
            try:
                import fast_jl

                # test run to catch at init time if projection goes through
                fast_jl.project_rademacher_8(
                    torch.zeros(8, 1_000, device="cuda"),
                    512,
                    0,
                    self.num_sms,
                )
            except ImportError:
                msg = "You should make sure to install the CUDA projector \
                (the fast_jl library)."
                raise ModuleNotFoundError(msg) from None
        elif self.method == "SJLT":
            try:
                # Imported eagerly so a missing extension fails at init time.
                from .sjlt.sjlt_cuda import SJLTProjection  # noqa: F401

                # Number of output slots each active coordinate is hashed to.
                self.c = 1
                if self.pre_compute:
                    self.sjlt_cuda_module = self._build_sjlt_module()
            except ImportError:
                msg = "You should make sure that SJLT CUDA version can be installed correctly."
                raise ModuleNotFoundError(msg) from None
        elif self.method == "Rademacher":
            if self.pre_compute:
                self.proj_matrix = self._rademacher_matrix()
        elif self.method == "Gaussian":
            if self.pre_compute:
                self.proj_matrix = self._gaussian_matrix()
        elif self.method == "Random":
            # Keep a seeded random subset of proj_dim active features.
            if self.active_indices.numel() > proj_dim:
                torch.manual_seed(self.seed)
                indices = torch.randperm(self.active_indices.numel())[:proj_dim]
                self.active_indices = self.active_indices[indices]

        logger.debug(f"Initialized CudaProjector with method {method} and dimensions {feature_dim} -> {proj_dim}")

    def _rademacher_matrix(self) -> Tensor:
        """Build the seeded +/-1 matrix of shape (#active features, proj_dim).

        Note: re-seeds torch's global RNG with ``self.seed``.
        """
        active_dim = self.active_indices.numel()
        proj_matrix = torch.empty(
            active_dim,
            self.proj_dim,
            device=self.device,
        )
        torch.manual_seed(self.seed)
        proj_matrix.bernoulli_(p=0.5)
        proj_matrix *= 2.0
        proj_matrix -= 1.0
        return proj_matrix

    def _gaussian_matrix(self) -> Tensor:
        """Build the seeded bfloat16 Gaussian matrix (#active features, proj_dim).

        Note: re-seeds torch's global RNG with ``self.seed``.
        """
        active_dim = self.active_indices.numel()
        torch.manual_seed(self.seed)
        return torch.randn(
            active_dim,
            self.proj_dim,
            device=self.device,
            dtype=torch.bfloat16,
        )

    def _build_sjlt_module(self):
        """Build an SJLTProjection whose hashes/signs are drawn from ``self.seed``.

        Raises:
            ImportError: If the SJLT CUDA extension is unavailable.
        """
        from .sjlt.sjlt_cuda import SJLTProjection

        torch.manual_seed(self.seed)
        active_dim = self.active_indices.numel()
        rand_indices = torch.randint(self.proj_dim, (active_dim, self.c), device=self.device)
        rand_signs = torch.randint(0, 2, (active_dim, self.c), device=self.device) * 2 - 1

        sjlt_cuda_module = SJLTProjection(active_dim, self.proj_dim, self.c, device=self.device)
        sjlt_cuda_module.rand_indices.copy_(rand_indices)
        sjlt_cuda_module.rand_signs.copy_(rand_signs.to(torch.int8))
        return sjlt_cuda_module

    def project(
        self,
        features: Union[dict, Tensor],
        ensemble_id: int,
    ) -> Tensor:
        """Performs the random projection on the feature matrix.

        Args:
            features: A batch of features or a dictionary
                of batch of features.
            ensemble_id: A unique ID for this ensemble. Only the FJLT method
                uses it (to offset the seed).

        Raises:
            RuntimeError: The batch size of the CudaProjector is too large for
                your GPU.
            RuntimeError: Too many resources requested for launch CUDA.
            ModuleNotFoundError: The SJLT CUDA extension is not installed.
            ValueError: ``self.method`` is not a supported method.

        Returns:
            Tensor: The projected features.
        """
        if isinstance(features, dict):
            features = vectorize(features, device=self.device)
        elif features.device.type != self.device:
            features = features.to(self.device)

        if self.method == "FJLT":
            # TODO: fix due to the update
            import fast_jl

            # fast_jl ships kernels specialised for batch sizes 8/16/32.
            batch_size = features.shape[0]
            effective_batch_size = 32
            min_proj_batch_size = 8
            if batch_size <= min_proj_batch_size:
                effective_batch_size = min_proj_batch_size
            elif batch_size <= min_proj_batch_size * 2:
                effective_batch_size = min_proj_batch_size * 2
            effective_batch_size = min(self.max_batch_size, effective_batch_size)

            # Use the enum's *value*: since Python 3.11 an f-string of a
            # (str, Enum) member renders as "ProjectionType.rademacher",
            # which names no fast_jl kernel.
            proj_type_name = ProjectionType(self.proj_type).value
            fn = getattr(fast_jl, f"project_{proj_type_name}_{effective_batch_size}")

            try:
                result = fn(
                    features,
                    self.proj_dim,
                    self.seed + int(1e4) * ensemble_id,
                    self.num_sms,
                )
            except RuntimeError as e:
                if "CUDA error: too many resources requested for launch" in str(e):
                    # provide a more helpful error message
                    msg = "The batch size of the CudaProjector is too large for your GPU. \
                        Reduce it by using the proj_max_batch_size argument.\
                        \nOriginal error:"
                    raise RuntimeError(msg) from e
                raise
        elif self.method == "SJLT":
            features = features[:, self.active_indices]
            if self.pre_compute:
                sjlt_cuda_module = self.sjlt_cuda_module
            else:
                try:
                    # Rebuilt from self.seed on every call.
                    sjlt_cuda_module = self._build_sjlt_module()
                except ImportError:
                    msg = "You should make sure that SJLT CUDA version can be installed correctly."
                    raise ModuleNotFoundError(msg) from None
            with torch.no_grad():
                result = sjlt_cuda_module(features)
        elif self.method in ("Rademacher", "Gaussian"):
            if self.pre_compute:
                proj_matrix = self.proj_matrix
            elif self.method == "Rademacher":
                proj_matrix = self._rademacher_matrix()
            else:
                proj_matrix = self._gaussian_matrix()

            features = features[:, self.active_indices]
            # matmul rejects mixed dtypes (the Gaussian matrix is bfloat16,
            # the Rademacher one float32), so match the features' dtype.
            # This is a no-op when the dtypes already agree.
            result = features @ proj_matrix.to(features.dtype) / (self.proj_dim ** 0.5)
        elif self.method in ("Random", "Localize"):
            # Column selection only; no random matrix is applied.
            result = features[:, self.active_indices]
        else:
            msg = f"Unknown projection method {self.method!r}."
            raise ValueError(msg)

        return result

    def free_memory(self) -> None:
        """A no-op method."""


class ChunkedCudaProjector:
    """Chunked CudaProjector implemented using CUDA.

    This projector is used when (# dim of features)*(# batch size) is too large.
    The feature dimension is split into consecutive chunks, each chunk is
    projected by its own projector, and the per-chunk projections are summed.
    If the features are gradients, then (# dim of features) equals the number
    of parameters in the model.
    """

    def __init__(
        self,
        projector_per_chunk: list,
        max_chunk_size: int,
        dim_per_chunk: list,
        feature_batch_size: int,
        proj_max_batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Initializes hyperparameters for ChunkedCudaProjector.

        Args:
            projector_per_chunk: A list of projectors. Specifying
                the projector used by each chunk.
            max_chunk_size: The maximum size of each chunk.
            dim_per_chunk: The number of feature dimensions per chunk.
            feature_batch_size: The batch size of input feature.
            proj_max_batch_size: The maximum batch size for each projector.
            device: Device to use. Will be "cuda" or "cpu".
            dtype: The dtype of the projected matrix.
        """
        self.projector_per_chunk = projector_per_chunk
        self.proj_dim = self.projector_per_chunk[0].proj_dim
        self.proj_type = self.projector_per_chunk[0].proj_type
        self.dim_per_chunk = dim_per_chunk
        self.feature_batch_size = feature_batch_size
        self.max_chunk_size = max_chunk_size
        self.proj_max_batch_size = proj_max_batch_size
        self.device = device
        self.dtype = dtype
        self.input_allocated = False

        logger.debug(f"Initialized ChunkedCudaProjector with {len(projector_per_chunk)} chunks")

    def allocate_input(self) -> None:
        """Allocate the (feature_batch_size, max_chunk_size) staging buffer once."""
        if self.input_allocated:
            return

        self.ch_input = torch.zeros(
            size=(self.feature_batch_size, self.max_chunk_size),
            device=self.device,
            dtype=self.dtype,
        )

        self.input_allocated = True

    def free_memory(self) -> None:
        """Frees up memory used by the projector."""
        if not self.input_allocated:
            return

        del self.ch_input
        self.input_allocated = False

    def dict_project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """Performs the random projection on the feature matrix.

        Dictionary entries are flattened and packed, in order, into a staging
        buffer; each time the next entry would overflow ``max_chunk_size``,
        the packed chunk is projected by its projector and accumulated.

        Args:
            features: A batch of features or a dictionary
                of batch of features. A plain tensor is handled by
                ``project``.
            ensemble_id: A unique ID for this ensemble.

        Raises:
            ValueError: The number of accumulated #feature dim does not match
                dim_per_chunk.

        Returns:
            Tensor: The projected features.
        """
        if not isinstance(features, dict):
            # A tensor has no per-entry structure to pack; slice it directly.
            return self.project(features, ensemble_id)

        self.allocate_input()
        ch_output = torch.zeros(
            size=(self.feature_batch_size, self.proj_dim),
            device=self.device,
            dtype=self.dtype,
        )
        pointer = 0
        # iterate over feature entries, keep a counter of #dim so far, and when
        # the next entry would overflow max_chunk_size, project and accumulate.
        projector_index = 0
        vector_dim = 1
        for p in features.values():
            # Entries with at most one dimension get a trailing feature axis;
            # higher-rank entries are flattened after the batch axis.
            if len(p.shape) <= vector_dim:
                p_flat = p.data.unsqueeze(-1)
            else:
                p_flat = p.data.flatten(start_dim=1)

            feature_dim_size = p_flat.size(1)
            # if current accumulated params exceed max_chunk_size,
            # then stop accumulation.
            if pointer + feature_dim_size > self.max_chunk_size:
                # The packed width must match what this chunk's projector expects.
                if pointer != self.dim_per_chunk[projector_index]:
                    msg = "Current number of accumulated #dim does not match \
                    the #feature dim of current chunk."
                    raise ValueError(msg)
                # project and accumulate
                ch_output.add_(
                    self.projector_per_chunk[projector_index].project(
                        self.ch_input[:, :pointer].contiguous(),
                        ensemble_id=ensemble_id,
                    ),
                )
                # reset counter
                pointer = 0
                projector_index += 1

            # continue accumulation
            actual_bs = min(self.ch_input.size(0), p_flat.size(0))
            self.ch_input[:actual_bs, pointer : pointer + feature_dim_size].copy_(
                p_flat,
            )
            pointer += feature_dim_size

        # at the end, project the last accumulated chunk
        if pointer != self.dim_per_chunk[projector_index]:
            msg = "Current number of accumulated #dim does not match \
                    the #feature dim of current chunk."
            raise ValueError(msg)

        # project and accumulate
        ch_output[:actual_bs].add_(
            self.projector_per_chunk[projector_index].project(
                self.ch_input[:actual_bs, :pointer].contiguous(),
                ensemble_id=ensemble_id,
            ),
        )

        return ch_output[:actual_bs]

    def project(self, features: Union[dict, Tensor], ensemble_id: int) -> Tensor:
        """Performs the random projection on the feature matrix.

        Args:
            features: A batch of features or a dictionary
                of batch of features.
            ensemble_id: A unique ID for this ensemble.

        Returns:
            Tensor: The projected features, with one row per input row.
        """
        # force the input to be Tensor for now
        # TODO: support dict input
        if isinstance(features, dict):
            features = vectorize(features, device=self.device)

        # Size the output from the actual batch, not feature_batch_size: a
        # smaller trailing batch would otherwise fail to add into the buffer
        # or, with a single row, silently broadcast into every row.
        ch_output = torch.zeros(
            size=(features.size(0), self.proj_dim),
            device=self.device,
            dtype=self.dtype,
        )

        pointer = 0
        for chunk_idx, chunk_dim in enumerate(self.dim_per_chunk):
            ch_output.add_(
                self.projector_per_chunk[chunk_idx].project(
                    features[:, pointer : pointer + chunk_dim].contiguous(),
                    ensemble_id=ensemble_id,
                ),
            )

            pointer += chunk_dim

        return ch_output


def make_random_projector(
    param_shape_list: List,
    feature_batch_size: int,
    proj_dim: int,
    proj_max_batch_size: int,
    device: str,
    proj_seed: int = 0,
    method: str = "Gaussian",
    *,
    use_half_precision: bool = True,
    active_indices: Optional[Tensor] = None,
    pre_compute: bool = False,
) -> Union[BasicProjector, CudaProjector, ChunkedCudaProjector]:
    """Initialize a random projector from the sizes of the features to project.

    Args:
        param_shape_list: A list of numbers whose sum is the total number of
            features to be projected. A typical example is the per-sample
            parameter size of each module in a torch.nn.Module model.
        feature_batch_size: The batch size of each tensor in the feature
            about to be projected. The typical type of feature are gradients of
            torch.nn.Module model but is not restricted to this.
        proj_dim: Dimension of the projected feature.
        proj_max_batch_size: The maximum batch size used by fast_jl if the
            CudaProjector is used; also passed to get_parameter_chunk_sizes to
            decide how features are chunked. Must be a multiple of 8. The
            maximum batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for
            H100 GPUs.
        device: "cuda" or "cpu" (a ``torch.device("cpu")`` is also accepted).
        proj_seed: Random seed used by the projector. Defaults to 0.
        method: The method used for the projection on GPU. One of "FJLT",
            "SJLT", "Rademacher", "Gaussian", "Random" or "Localize". On CPU a
            BasicProjector with a normal matrix is always used.
        use_half_precision: If True, torch.float16 will be used for all
            computations and arrays will be stored in torch.float16.
        active_indices: The indices of the features to be considered.
        pre_compute: If True, the projection construction will be pre-computed

    Returns:
        A projector object; call its ``project(features, ensemble_id)``
        method to project features to ``proj_dim`` dimensions.

    Raises:
        AttributeError: possible attribute error when initializing CudaProjector.
        ImportError: fast_jl is not installed.
        RuntimeError: Too many resources requested for launch CUDA. Try reduce
            proj_max_batch_size.
        ValueError: ``method`` is not a supported GPU projection method.
    """
    dtype = torch.float16 if use_half_precision else torch.float32
    # the total feature dim
    feature_dim = sum(param_shape_list)

    if str(device) == "cpu":
        # Sampling from bernoulli distribution is not supported for
        # dtype float16 on CPU; playing it safe here by defaulting to
        # normal projection, rather than rademacher.
        return BasicProjector(
            feature_dim=feature_dim,
            proj_dim=proj_dim,
            seed=proj_seed,
            proj_type=ProjectionType.normal,
            dtype=dtype,
            device=device,
            method=method,
            active_indices=active_indices,
            pre_compute=pre_compute,
        )

    if method == "FJLT":
        import fast_jl

        # Fail fast (ImportError/RuntimeError/AttributeError propagate) if the
        # fast_jl kernel cannot handle a feature vector of this size.
        test_feature = torch.ones(1, feature_dim).cuda()
        num_sms = torch.cuda.get_device_properties(
            "cuda",
        ).multi_processor_count
        fast_jl.project_rademacher_8(
            test_feature,
            proj_dim,
            0,
            num_sms,
        )
        proj_type = ProjectionType.rademacher
    elif method in ("SJLT", "Rademacher"):
        proj_type = ProjectionType.rademacher
    elif method == "Gaussian":
        proj_type = ProjectionType.normal
    elif method in ("Random", "Localize"):
        proj_type = ProjectionType.identity
    else:
        msg = (
            f"Unknown projection method {method!r}; expected one of "
            "'FJLT', 'SJLT', 'Rademacher', 'Gaussian', 'Random', 'Localize'."
        )
        raise ValueError(msg)

    # TODO: make this support dict input
    # currently, only tensor input will be considered
    max_chunk_size, param_chunk_sizes = get_parameter_chunk_sizes(
        param_shape_list,
        proj_max_batch_size,
    )
    if len(param_chunk_sizes) > 1:  # we have to use the ChunkedCudaProjector
        rng = np.random.default_rng(proj_seed)
        # different seeds for each chunk
        seeds = rng.integers(
            low=0,
            high=500,
            size=len(param_chunk_sizes),
        )
        # NOTE(review): every chunk receives the same `active_indices`; if these
        # are global feature positions they will not line up with a chunk's
        # local columns — confirm the intended semantics.
        projector_per_chunk = [
            CudaProjector(
                feature_dim=chunk_size,
                proj_dim=proj_dim,
                seed=int(chunk_seed),  # numpy integer -> plain Python int
                proj_type=proj_type,
                max_batch_size=proj_max_batch_size,
                device=device,
                method=method,
                active_indices=active_indices,
                pre_compute=pre_compute,
            )
            for chunk_seed, chunk_size in zip(seeds, param_chunk_sizes)
        ]
        return ChunkedCudaProjector(
            projector_per_chunk,
            max_chunk_size,
            param_chunk_sizes,
            feature_batch_size,
            proj_max_batch_size,
            device,
            dtype,
        )

    return CudaProjector(
        feature_dim=feature_dim,
        proj_dim=proj_dim,
        seed=proj_seed,
        proj_type=proj_type,
        max_batch_size=proj_max_batch_size,
        device=device,
        method=method,
        active_indices=active_indices,
        pre_compute=pre_compute,
    )


def random_project(
    feature: Union[Dict[str, Tensor], Tensor],
    feature_batch_size: int,
    proj_dim: int,
    proj_max_batch_size: int = 32,
    device: str = "cuda",
    proj_seed: int = 0,
    method: str = "Gaussian",
    *,
    use_half_precision: bool = True,
    active_indices: Optional[Tensor] = None,
    pre_compute: bool = False,
) -> Callable:
    """Build a function that randomly projects features to a smaller dimension.

    ``feature`` serves as a template: the size of each of its entries decides
    how the underlying projector is constructed, and the returned function
    then projects whatever features it is called with.

    Args:
        feature: Template of the feature to be projected. Either a tensor of
            size [feature_batch_size, feature_dim] or, typically for gradients
            of a torch.nn.Module, a dictionary shaped like the result of
            model.named_parameters().
        feature_batch_size: The batch size of each tensor in the feature
            about to be projected. The typical type of feature are gradients of
            torch.nn.Module model but is not restricted to this.
        proj_dim: Dimension of the projected feature.
        proj_max_batch_size: The maximum batch size used by fast_jl if the
            CudaProjector is used. Must be a multiple of 8. The maximum
            batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for H100 GPUs.
        device: "cuda" or "cpu".
        proj_seed: Random seed used by the projector. Defaults to 0.
        method: The method used for the projection.
        use_half_precision: If True, torch.float16 will be used for all
            computations and arrays will be stored in torch.float16.
        active_indices: The indices of the features to be considered.
        pre_compute: If True, the projection construction will be pre-computed

    Returns:
        A function that projects a feature to ``proj_dim`` dimensions.
    """
    template_tensors = feature.values() if isinstance(feature, dict) else (feature,)
    # Per-sample size of every entry: total element count over the batch size.
    param_shape_list = [
        tensor.numel() // feature_batch_size for tensor in template_tensors
    ]

    projector = make_random_projector(
        param_shape_list=param_shape_list,
        feature_batch_size=feature_batch_size,
        proj_dim=proj_dim,
        proj_max_batch_size=proj_max_batch_size,
        device=device,
        proj_seed=proj_seed,
        method=method,
        use_half_precision=use_half_precision,
        active_indices=active_indices,
        pre_compute=pre_compute,
    )

    def _random_project_func(
        feature: Union[Dict[str, Tensor], Tensor],
        ensemble_id: int = 0,
    ) -> Tensor:
        """Project ``feature`` with the projector built above.

        Args:
            feature: The feature to project: a tensor of size
                [feature_batch_size, feature_dim], or a dictionary shaped like
                the result of model.named_parameters().
            ensemble_id: A unique ID for this ensemble. Defaults to 0.

        Returns:
            The projected result of feature, which is a tensor with size
                [feature_batch_size, proj_dim].
        """
        return projector.project(feature, ensemble_id)

    return _random_project_func