


import os.path

import numpy as np
from skimage.measure import block_reduce

from tqdm import tqdm

import aspire
from aspire.volume import Volume
from aspire.utils.rotation import Rotation

from logging import getLogger, basicConfig, INFO
# Logging is configured as an import-time side effect: basicConfig sets up the
# root logger at INFO level, and `logger` is this module's named logger.
basicConfig(level=INFO)
logger = getLogger(__name__)

# Directory that receives the generated .npy files: "data/rotmol3d" resolved
# four directory levels above the directory containing this file.
DATA_FOLDER = os.path.normpath(os.path.join(os.path.dirname(__file__), "../../../../data/rotmol3d/"))


def mask_outside_sphere(data_array: np.ndarray, radius: float, fill_value: float = 0) -> np.ndarray:
    """
    Mask everything outside a central sphere and crop to the sphere's bounding box.

    The sphere (an n-ball for n-dimensional input) is centred on the geometric
    centre of the array, i.e. at ``(dim - 1) / 2`` along every axis. Elements
    whose index lies farther than `radius` from that centre are replaced by
    `fill_value`; the result is then cropped to the smallest axis-aligned box
    that still contains every element inside the sphere.
    The original array is not modified.

    Args:
        data_array (np.ndarray): The input NumPy array (at least 1-D).
        radius (float): The radius of the central sphere, in index units.
                        Points on the boundary (distance == radius) are
                        included (<=). Python and NumPy real scalars are
                        accepted.
        fill_value (scalar): The value to assign to elements outside the sphere.
                             Defaults to 0.

    Returns:
        np.ndarray: A new array with the dtype of `data_array`, cropped to the
                    bounding box of the sphere, with elements outside the
                    sphere set to `fill_value`. Its shape is therefore at most
                    the shape of `data_array`.

    Raises:
        TypeError: If `data_array` is not a NumPy array.
        ValueError: If `data_array` is 0-dimensional or has a zero-length axis.
        ValueError: If `radius` is not a real number, is NaN, or is negative.
        ValueError: If no element of `data_array` lies inside the sphere
                    (possible for small radii when an axis has even length,
                    because the centre then falls between grid points).
    """
    # --- Input Validation ---
    if not isinstance(data_array, np.ndarray):
        raise TypeError("data_array must be a NumPy array.")
    if data_array.ndim == 0:
        raise ValueError("data_array must have at least one dimension.")
    if not all(isinstance(dim, int) and dim > 0 for dim in data_array.shape):
        raise ValueError("All dimensions in data_array.shape must be positive integers.")
    if (
        not isinstance(radius, (int, float, np.integer, np.floating))
        or np.isnan(radius)
        or radius < 0
    ):
        raise ValueError("radius must be a non-negative number.")

    # Square in Python float arithmetic so small NumPy integer types cannot overflow.
    radius_sq = float(radius) ** 2

    array_shape = data_array.shape
    ndim = data_array.ndim

    # Sparse index grids broadcast against each other, so the only full-size
    # temporaries are the distance array and the boolean mask (instead of
    # `ndim` dense index arrays).
    grids = np.indices(array_shape, sparse=True)
    distance_sq = sum(
        (grid - (dim - 1) / 2.0) ** 2 for grid, dim in zip(grids, array_shape)
    )
    mask_inside = distance_sq <= radius_sq

    if not mask_inside.any():
        raise ValueError(
            f"radius {radius} is too small for shape {array_shape}: "
            "no array element lies inside the sphere."
        )

    # Bounding box of the sphere: along each axis, the first and last index
    # whose hyperplane contains at least one inside element.
    crop_slices = []
    for axis in range(ndim):
        other_axes = tuple(a for a in range(ndim) if a != axis)
        hits = np.flatnonzero(mask_inside.any(axis=other_axes))
        crop_slices.append(slice(int(hits[0]), int(hits[-1]) + 1))
    crop_slices = tuple(crop_slices)

    # Crop first, then mask: only the cropped region is ever copied.
    cropped = data_array[crop_slices].copy()
    cropped[~mask_inside[crop_slices]] = fill_value
    return cropped


def gen_rotated_volumes(vol: Volume, mask_radius: int, downscale_factor: int, angles_rad: np.ndarray) -> list[np.ndarray]:
    """
    Produce one rotated, masked, downsampled and clipped array per angle.

    For every angle in `angles_rad` the volume is rotated about the 'z' axis,
    converted to NumPy (entry 0 of ``asnumpy()`` is used), masked and cropped
    with `mask_outside_sphere`, block-averaged with `block_reduce`, and finally
    has its negative values set to 0.

    Args:
        vol (Volume): Volume to rotate; it is cast to float64 before rotating.
        mask_radius (int): Radius passed to `mask_outside_sphere`.
        downscale_factor (int): `block_size` passed to `block_reduce`.
        angles_rad (np.ndarray): Rotation angles in radians.

    Returns:
        list[np.ndarray]: One processed array per angle, in input order.
    """
    logger.info(f'gen_rotated_volumes(vol, {mask_radius=}, {downscale_factor=}, {angles_rad=})')

    # Author's note: NUFFT gets mad if you try to rotate float32.
    vol_f64 = vol.astype(np.float64)

    def _rotate_mask_downsample(angle):
        # Full per-angle pipeline; log messages are emitted in pipeline order.
        logger.info(f'Rotating by angle {float(angle)} [rad] == {float(angle*180/np.pi)} [deg]...')
        rotated = vol_f64.rotate(Rotation.about_axis('z', angle))
        dense = rotated.asnumpy()[0]

        logger.info(f'Masking outside sphere of radius {mask_radius}...')
        cropped = mask_outside_sphere(dense, mask_radius)

        logger.info('Downsampling...')
        reduced = block_reduce(cropped, block_size=downscale_factor, func=np.mean)

        logger.info('Setting negative values to 0...')
        reduced[reduced < 0] = 0
        return reduced

    return [_rotate_mask_downsample(angle) for angle in tqdm(angles_rad)]


def prepare_rotated_molecule_dataset(emdb_id: str, mask_radius: int, downscale_factor: int, angles_rad: np.ndarray):
    """
    Download an EMDB volume, generate its rotated variants and save them to disk.

    The volume is fetched through ``aspire.downloader.emdb_<emdb_id>``, processed
    by `gen_rotated_volumes`, and written to `DATA_FOLDER` as two files sharing
    one stem: ``<stem>.npy`` (the rotated volumes) and ``<stem>_angles.npy``
    (the angles). `DATA_FOLDER` is created if it does not exist.

    Args:
        emdb_id (str): EMDB identifier; selects the `aspire.downloader` function.
        mask_radius (int): Radius forwarded to `gen_rotated_volumes`.
        downscale_factor (int): Downscale factor forwarded to `gen_rotated_volumes`.
        angles_rad (np.ndarray): Rotation angles in radians.

    Raises:
        AttributeError: If `aspire.downloader` has no ``emdb_<emdb_id>`` attribute.
    """
    downloader_func = getattr(aspire.downloader, f'emdb_{emdb_id}')

    # Create the output directory up front so a missing folder cannot make
    # np.save fail only after the expensive download and rotations are done.
    os.makedirs(DATA_FOLDER, exist_ok=True)

    logger.info(f"Downloading molecular volume EMDB{emdb_id} (or cache) using aspire.downloader...")
    vol = downloader_func()

    rotated_vols = gen_rotated_volumes(vol, mask_radius, downscale_factor, angles_rad)

    # Build both paths from one stem rather than str.replace on the full path,
    # which would also rewrite any ".npy" occurring in a directory name.
    output_stem = f"rotated_EMDB{emdb_id}_{mask_radius=}_{downscale_factor=}_n_angles={len(angles_rad)}"
    output_path = os.path.join(DATA_FOLDER, f"{output_stem}.npy")
    angles_path = os.path.join(DATA_FOLDER, f"{output_stem}_angles.npy")

    logger.info(f"Saving rotated volumes to {output_path}...")
    np.save(output_path, rotated_vols)

    logger.info(f"Saving angles to {angles_path}...")
    np.save(angles_path, angles_rad)


def prepare_all_datasets():
    """
    Generate every rotated-molecule dataset produced by this script.

    Uses 18 evenly spaced angles covering [0, 360) degrees and, for each
    downscale factor in (32, 16, 8), processes each listed EMDB entry with a
    mask radius of 128 via `prepare_rotated_molecule_dataset`.
    """
    angles_deg = np.linspace(0, 360, 18, endpoint=False)
    angles_rad = angles_deg * 2*np.pi/360
    logger.info(f"All angles [degrees]: {angles_deg}")
    logger.info(f"All angles [radians]: {angles_rad}")

    # (EMDB id, title) pairs; entry pages live at https://www.ebi.ac.uk/emdb/EMD-<id>
    molecules = (
        ('2660', "Cryo-EM structure of the Plasmodium falciparum 80S ribosome bound to the anti-protozoan drug emetine"),
        ('8012', "Cryo-EM structure of the yeast U4/U6.U5 tri-snRNP at 3.7 angstrom resolution."),
        ('14621', "Human coronavirus SARS-CoV-2 spike protein"),
        ('2484', "Pre-fusion structure of trimeric HIV-1 envelope glycoprotein determined by cryo-electron microscopy"),
    )
    mask_radius = 128

    for downscale_factor in (32, 16, 8):
        logger.info(f"Preparing all datasets with downscale factor {downscale_factor}...")

        for emdb_id, title in molecules:
            logger.info(f"==== EMD-{emdb_id}: {title}")
            prepare_rotated_molecule_dataset(emdb_id, mask_radius, downscale_factor, angles_rad)

# Script entry point: build all datasets only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    prepare_all_datasets()
