import argparse
import os

import h5py
import mrcfile
import numpy as np
from chimerax.core.commands import run
from chimerax.core.session import Session

if "session" not in locals():
    # ChimeraX presumably injects `session` when it runs this script; fall
    # back to a fresh Session when it is not already defined.
    session = Session()

parser = argparse.ArgumentParser(description="Convert PDB to density map.")
# Both directories are required: every later step builds paths from them, and
# a missing value would otherwise surface as a confusing "None/..." path or a
# TypeError deep inside os.path.
parser.add_argument(
    "--source_dir",
    type=str,
    required=True,
    help="Source directory containing PDB files.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Output directory for density maps.",
)
parser.add_argument(
    "--grid_spacing", type=float, default=1.0, help="Grid spacing for density map."
)
parser.add_argument("--grid_size", type=int, default=32, help="Grid size for density map.")
parser.add_argument("--output_mrc", action="store_true", help="Output MRC file.")
args = parser.parse_args()

source_dir = args.source_dir
output_dir = args.output_dir
grid_spacing = args.grid_spacing
grid_size = args.grid_size
output_mrc = args.output_mrc


def save_volumes_as_files(
    density_maps, origin=(0.0, 0.0, 0.0), grid_spacing=0.5, output_dir="volumes"
):
    """Save each volume in ``density_maps`` as a separate MRC file.

    Files are named ``volume_001.mrc``, ``volume_002.mrc``, ... in iteration
    order inside ``output_dir``; existing files with the same name are
    overwritten.

    Args:
        density_maps: Iterable of arrays, one per volume (e.g. a stacked
            array whose first axis is the frame). Each is passed unchanged to
            ``mrcfile``'s ``set_data``.
        origin: Origin written to every file's header; only the first three
            components are used.
        grid_spacing: Voxel size written to every file's header.
        output_dir: Destination directory, created (with parents) if missing.
            An empty string means the current working directory.
    """
    if output_dir:
        # exist_ok removes the race between an exists() check and makedirs().
        os.makedirs(output_dir, exist_ok=True)

    for index, volume in enumerate(density_maps, start=1):
        filename = os.path.join(output_dir, f"volume_{index:03d}.mrc")
        # NOTE(review): arrays are written as given, with no axis reordering.
        # An earlier draft flipped axis 0 here; confirm the orientation in
        # ChimeraX if the saved maps look mirrored.
        with mrcfile.new(filename, overwrite=True) as mrc:
            mrc.set_data(volume)
            mrc.voxel_size = grid_spacing
            mrc.header.origin = (origin[0], origin[1], origin[2])
        print(f"Saved {filename}")


def save_h5py(file_name, key, positions, densities, grid_size, grid_spacing):
    """Write one group of voxel data into a new HDF5 file.

    ``file_name`` is opened in ``"w"`` mode, so any existing file is
    truncated. A single group named ``key`` receives two datasets,
    ``positions`` and ``densities``, plus the grid parameters stored as the
    group attributes ``grid_size`` and ``grid_spacing``.
    """
    datasets = {"positions": positions, "densities": densities}
    attributes = {"grid_size": grid_size, "grid_spacing": grid_spacing}

    with h5py.File(file_name, "w") as h5_file:
        grp = h5_file.create_group(key)
        for dataset_name, payload in datasets.items():
            grp.create_dataset(dataset_name, data=payload)
        for attr_name, attr_value in attributes.items():
            grp.attrs[attr_name] = attr_value


def reduce_grid(voxel_array):
    """Crop a stack of grids to the smallest box holding any non-zero voxel.

    The first axis is treated as the frame axis and is kept whole. Bounds on
    the remaining (spatial) axes are the union over all frames, so every frame
    is cropped identically.

    Args:
        voxel_array: Array of shape ``(frames, d1, d2, ...)``; in this script
            a 4-D stack of 3-D density maps.

    Returns:
        ``(reduced_grid, offset)``:
        - ``reduced_grid`` is the cropped array (a slice of the input).
        - ``offset`` is an integer array with one entry per spatial axis,
          giving the index of the crop's first voxel in the original grid.
        If no voxel is non-zero, the input is returned unchanged with a zero
        offset instead of failing on an empty reduction.
    """
    voxel_array = np.asarray(voxel_array)

    # A spatial voxel is kept when it is non-zero in any frame. Collapsing the
    # frame axis first gives the same bounds as np.nonzero on the full array
    # without materialising index arrays for every non-zero voxel of every frame.
    occupied = voxel_array.any(axis=0)
    nonzero_indices = np.nonzero(occupied)

    if nonzero_indices[0].size == 0:
        return voxel_array, np.zeros(occupied.ndim, dtype=np.intp)

    lower = np.array([idx.min() for idx in nonzero_indices])
    upper = np.array([idx.max() for idx in nonzero_indices])

    # Keep every frame; slice each spatial axis to [lower, upper] inclusive.
    crop = tuple(slice(lo, hi + 1) for lo, hi in zip(lower, upper))
    reduced_grid = voxel_array[(slice(None),) + crop]

    return reduced_grid, lower


def calculate_bounding_box(model):
    """Calculate the bounding box that covers the entire molecule over all frames.

    Steps the model through coordinate sets ``1..model.num_coordsets`` with
    the ChimeraX ``coordset`` command and accumulates the per-axis minimum
    and maximum of the atom coordinates. As a side effect, the model is left
    on its last coordinate set.

    Args:
        model: ChimeraX atomic structure with one or more coordinate sets.

    Returns:
        ``(min_coords, max_coords)``: two length-3 float arrays.

    Raises:
        ValueError: if the model has no atoms or no coordinate sets, since no
            finite bounding box exists.
    """
    num_frames = model.num_coordsets
    if num_frames < 1 or len(model.atoms) == 0:
        raise ValueError("Model has no atoms or coordinate sets to bound.")

    min_coords = np.full(3, np.inf)
    max_coords = np.full(3, -np.inf)

    for frame in range(1, num_frames + 1):
        run(session, f"coordset #{model.id_string} {frame}")  # Set the frame
        # Atoms.coords yields every atom position of the active frame as one
        # (N, 3) array, avoiding a Python-level loop over individual atoms.
        coords = model.atoms.coords
        min_coords = np.minimum(min_coords, coords.min(axis=0))
        max_coords = np.maximum(max_coords, coords.max(axis=0))

    return min_coords, max_coords


def create_initial_grid(grid_size: int, grid_spacing=0.5, origin=(0, 0, 0), model_id=10):
    """Create an empty reference grid model in ChimeraX.

    Runs ``vop new`` with ``grid_size`` as its ``size`` option (a single
    value, presumably applied to all three axes; confirm against the ChimeraX
    docs), the given spacing and origin, and the id ``model_id``. Later
    ``molmap ... onGrid`` calls can then reference that model. The caller is
    responsible for choosing an origin that covers the molecule.

    Returns:
        The ``origin`` argument, unchanged.
    """
    x0, y0, z0 = origin[0], origin[1], origin[2]
    command = (
        f"vop new size {grid_size} gridSpacing {grid_spacing} "
        f"origin {x0},{y0},{z0} modelid {model_id}"
    )
    run(session, command)
    return origin


def generate_density_maps(model, resolution=1, model_id=10):
    """Generate density maps for each frame using the initial grid parameters.

    For each coordinate set ``1..model.num_coordsets`` this:
    1. activates the frame;
    2. runs ``molmap #<model> <resolution> onGrid #<model_id>``;
    3. copies the values of the map model that command created;
    4. closes that temporary map.

    Args:
        model: ChimeraX atomic structure whose coordinate sets are the frames.
        resolution: Resolution argument passed to ``molmap``.
        model_id: ID of the existing reference grid model that every map is
            computed on.

    Returns:
        List with one numpy array per frame, copied from
        ``volume.data.full_matrix()``.

    Raises:
        RuntimeError: if ``molmap`` did not add a new volume model.
    """
    density_maps = []

    num_frames = model.num_coordsets
    for frame in range(1, num_frames + 1):
        run(session, f"coordset #{model.id_string} {frame}")  # Set the frame

        # Snapshot the open models so the map created by molmap can be found
        # by difference, not by a fixed position in the model list.
        existing = {id(m) for m in session.models.list()}
        run(
            session, f"molmap #{model.id_string} {resolution} onGrid #{model_id}"
        )  # Generate density map

        # Exact class-name match excludes VolumeSurface child models, whose
        # class name also contains "Volume".
        new_volumes = [
            m
            for m in session.models.list()
            if id(m) not in existing and type(m).__name__ == "Volume"
        ]
        if not new_volumes:
            raise RuntimeError("No volume model found")

        volume = new_volumes[-1]
        # Copy so the stored frame cannot alias data owned by the closed model.
        grid_data = volume.data.full_matrix().copy()
        density_maps.append(grid_data)

        run(session, f"close #{volume.id_string}")

    return density_maps


def save_density_maps(density_maps, file_path="density_maps.npy"):
    """Write ``density_maps`` to ``file_path`` in NumPy ``.npy`` format.

    ``np.save`` converts the input to a single array first, so a list of
    equally shaped maps is stored stacked along a new leading axis.
    """
    np.save(file=file_path, arr=density_maps)


model_id = 10  # Model ID given to the reference grid created by `vop new`.

# Open the topology, then the trajectory (presumably loaded as additional
# coordinate sets of the same structure). Paths are quoted so directories
# containing spaces survive ChimeraX command parsing.
topology_path = os.path.join(source_dir, "alanine-dipeptide.pdb")
trajectory_path = os.path.join(source_dir, "traj3_cent.dcd")
run(session, f'open "{topology_path}"')
run(session, f'open "{trajectory_path}"')

# Get the model containing the atomic structure (assumed to be the first model).
model = session.models.list()[0]

# Centre a grid of grid_size * grid_spacing per axis on the molecule's
# bounding box over all frames.
min_coords, max_coords = calculate_bounding_box(model)
model_center_coords = (min_coords + max_coords) / 2
grid_extent = grid_size * grid_spacing
origin = model_center_coords - grid_extent / 2

molecule_extent = max_coords - min_coords
if np.any(molecule_extent > grid_extent):
    # molmap onGrid can only fill voxels of the reference grid.
    print(
        f"Warning: molecule extent {molecule_extent} exceeds grid extent "
        f"{grid_extent}; density outside the grid will be lost."
    )

create_initial_grid(
    grid_size=grid_size, grid_spacing=grid_spacing, origin=origin, model_id=model_id
)

# Generate density maps and stack them along a leading frame axis.
density_maps = generate_density_maps(model, model_id=model_id)
density_maps = np.stack(density_maps, axis=0)

# The HDF5 file below is always written, so the output directory must exist
# even when --output_mrc is not given.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

if output_mrc:
    save_volumes_as_files(
        density_maps, origin=origin, grid_spacing=grid_spacing, output_dir=output_dir
    )

# Crop all frames to the smallest box containing non-zero density.
density_maps, offset_indices = reduce_grid(density_maps)

# Index every voxel of the cropped stack, drop the frame column, and express
# positions as fractions of the full grid shifted by the crop offset.
positions = np.indices(density_maps.shape).reshape(density_maps.ndim, -1).T
offset = offset_indices / grid_size
positions = positions[:, 1:] / grid_size + offset
densities = density_maps.flatten()

# np.indices varies the frame index slowest, so rows group cleanly per frame.
num_frames = density_maps.shape[0]
positions = positions.reshape(num_frames, -1, positions.shape[1])
densities = densities.reshape(num_frames, -1)

save_h5py(
    file_name=os.path.join(output_dir, "data.h5"),
    key="0001",
    positions=positions,
    densities=densities,
    grid_size=grid_size,
    grid_spacing=grid_spacing,
)

run(session, "exit")
