import argparse
import os

import h5py
import mrcfile
import numpy as np
from chimerax.core.commands import run
from chimerax.core.session import Session

# Reuse a `session` object that is already defined in this namespace;
# otherwise construct a new ChimeraX Session for the commands below.
# NOTE(review): presumably ChimeraX pre-populates `session` when it executes
# the script itself -- confirm against how the script is launched.
if "session" not in locals():
    session = Session()

# Command-line interface.
# --source_dir and --output_dir are mandatory: without them the values would be
# None, which makes os.listdir() scan the current directory and turns the
# output path into the literal "None/structures.h5".
parser = argparse.ArgumentParser(description="Convert PDB to density map.")
parser.add_argument(
    "--source_dir",
    type=str,
    required=True,
    help="Source directory containing PDB files.",
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Output directory for density maps.",
)
parser.add_argument(
    "--grid_spacing", type=float, default=1.0, help="Grid spacing for density map."
)
parser.add_argument("--grid_size", type=int, default=16, help="Grid size for density map.")
parser.add_argument("--output_mrc", action="store_true", help="Output MRC file.")
args = parser.parse_args()

# Module-level aliases used by the processing code further down.
source_dir = args.source_dir
output_dir = args.output_dir
grid_spacing = args.grid_spacing
grid_size = args.grid_size
output_mrc = args.output_mrc


def clean():
    """Issue ChimeraX ``delete`` commands for ligand, ions and solvent, in that order."""
    for target in ("ligand", "ions", "solvent"):
        run(session, f"delete {target}")


def min_max_coord(model):
    """Return the per-axis minimum and maximum of a model's atom coordinates.

    Args:
        model: Object whose ``atoms`` iterable yields items with a ``coord``
            attribute.

    Returns:
        A ``(min_coords, max_coords)`` tuple of arrays, each reduced over all
        atoms along axis 0.
    """
    atom_positions = np.array([atom.coord for atom in model.atoms])
    lower_corner = np.min(atom_positions, axis=0)
    upper_corner = np.max(atom_positions, axis=0)
    return lower_corner, upper_corner


def create_initial_grid(grid_size: int, grid_spacing=0.5, origin=(0, 0, 0), model_id=10):
    """Create a new ChimeraX volume with the ``vop new`` command.

    Args:
        grid_size: Value passed as the ``size`` option.
        grid_spacing: Value passed as the ``gridSpacing`` option.
        origin: Indexable with at least three components, passed as the
            comma-separated ``origin`` option.
        model_id: Value passed as the ``modelid`` option.

    Returns:
        ``origin``, unchanged.
    """
    origin_text = f"{origin[0]},{origin[1]},{origin[2]}"
    command = (
        f"vop new size {grid_size} gridSpacing {grid_spacing} "
        f"origin {origin_text} modelid {model_id}"
    )
    run(session, command)
    return origin


def reduce_grid(voxel_array):
    """Crop a voxel grid to the bounding box of its non-zero entries.

    Args:
        voxel_array: N-dimensional array-like of values (this file passes a
            3-D density grid).

    Returns:
        A ``(reduced_grid, offset)`` tuple. ``reduced_grid`` is the slice of
        ``voxel_array`` spanning the smallest to the largest non-zero index
        along every axis. ``offset`` is an integer array holding the index of
        that slice's first element in the original array. When the input has
        no non-zero value, ``reduced_grid`` has length 0 along every axis and
        ``offset`` is all zeros.
    """
    voxel_array = np.asarray(voxel_array)
    nonzero_indices = np.nonzero(voxel_array)

    # min()/max() raise on an empty index array, so an all-zero grid is
    # answered explicitly with an empty crop.
    if nonzero_indices[0].size == 0:
        empty_window = (slice(0, 0),) * voxel_array.ndim
        return voxel_array[empty_window], np.zeros(voxel_array.ndim, dtype=np.intp)

    # Per-axis bounds of the occupied region.
    offset = np.array([axis_indices.min() for axis_indices in nonzero_indices])
    upper = np.array([axis_indices.max() for axis_indices in nonzero_indices])

    # The upper bound is inclusive, hence the +1 in the slice stop.
    window = tuple(slice(low, high + 1) for low, high in zip(offset, upper))
    return voxel_array[window], offset


def process_pdb(filepath, grid_size=100, grid_spacing=0.5, resolution=1):
    """Open a structure file and return its density sampled on a new grid.

    The file is opened, ligand/ions/solvent are deleted via ``clean()``, a
    grid centred on the midpoint of the atoms' bounding box is created, and
    ``molmap`` is run onto that grid.

    Args:
        filepath: Path handed to the ChimeraX ``open`` command.
        grid_size: Grid size forwarded to ``create_initial_grid``.
        grid_spacing: Grid spacing forwarded to ``create_initial_grid``.
        resolution: Resolution argument of the ``molmap`` command.

    Returns:
        A ``(grid_data, origin)`` tuple: the full matrix of the selected
        volume and the grid origin computed here.
    """
    run(session, f"open {filepath}")
    clean()

    structure = session.models[0]
    lower, upper = min_max_coord(model=structure)

    # Shift the origin by half the grid extent so the bounding-box midpoint
    # sits at the grid centre.
    half_extent = grid_size * grid_spacing / 2
    origin = (lower + upper) / 2 - half_extent

    grid_id = 10
    create_initial_grid(
        grid_size=grid_size,
        grid_spacing=grid_spacing,
        origin=origin,
        model_id=grid_id,
    )
    run(session, f"molmap #{structure.id_string} {resolution} onGrid #{grid_id}")

    # NOTE(review): this takes the third model whose class name contains
    # "Volume", so it depends on the order of models in the session --
    # confirm this is always the molmap result.
    volume_like = [m for m in session.models.list() if "Volume" in m.__class__.__name__]
    grid_data = volume_like[2].data.full_matrix()
    return grid_data, origin


def save_structures(structures, filename):
    """Write the positions and densities of every structure to an HDF5 file.

    Args:
        structures: Mapping from structure name to a mapping that has at least
            the keys ``"positions"`` and ``"densities"``.
        filename: Path of the HDF5 file, opened in ``"w"`` mode.

    One group is created per structure name, holding a ``positions`` and a
    ``densities`` dataset. Other entries of a structure are not written.
    """
    with h5py.File(filename, "w") as h5_file:
        for name, entry in structures.items():
            group = h5_file.create_group(name)
            for field in ("positions", "densities"):
                group.create_dataset(field, data=entry[field])


def save_volumes_mrc(
    volume,
    filename,
    origin=(0.0, 0.0, 0.0),
    grid_spacing=0.5,
):
    """Write a single volume array to an MRC file.

    Args:
        volume: Array handed to ``set_data``.
        filename: Destination path; the file is created with ``overwrite=True``.
        origin: Indexable whose first three components become the header origin.
        grid_spacing: Value assigned to the file's ``voxel_size``.

    A ``Saved <filename>`` line is printed once the header has been filled in.
    """
    with mrcfile.new(filename, overwrite=True) as mrc_out:
        mrc_out.set_data(volume)
        mrc_out.voxel_size = grid_spacing
        header_origin = tuple(origin[axis] for axis in range(3))
        mrc_out.header.origin = header_origin
        print(f"Saved {filename}")


# Per-structure results, keyed by the PDB file name without its extension.
structures = {}

# Create the destination up front so a missing directory is reported before
# any structure is processed rather than when the results are finally written.
os.makedirs(output_dir, exist_ok=True)

# os.listdir() returns entries in arbitrary order; sorting makes repeated runs
# over the same directory process the files in the same sequence.
for pdb_file in sorted(os.listdir(source_dir)):
    if not pdb_file.endswith(".pdb"):
        continue

    pdb_file_path = os.path.join(source_dir, pdb_file)
    grid_data, origin = process_pdb(
        pdb_file_path, grid_size=grid_size, grid_spacing=grid_spacing
    )
    reduced_grid, offset_indices = reduce_grid(grid_data)

    # Index coordinates of every voxel in the cropped grid, one row per voxel,
    # in the same (C) order as reduced_grid.flatten().
    positions = np.indices(reduced_grid.shape).reshape(3, -1).T

    # Shift back by the crop offset; both terms are divided by grid_size.
    offset = offset_indices / grid_size
    positions = positions / grid_size + offset
    densities = reduced_grid.flatten()

    # Keep only voxels with positive density.
    occupied = densities > 0
    positions = positions[occupied]
    densities = densities[occupied]

    # splitext() removes only the trailing ".pdb", so names that contain
    # additional dots stay distinct instead of collapsing onto one key.
    pdb_name = os.path.splitext(pdb_file)[0]
    structures[pdb_name] = {"positions": positions, "densities": densities, "origin": origin}

    if output_mrc:
        # The full (uncropped) grid is kept for the MRC export below.
        structures[pdb_name]["grid"] = grid_data

    # Close everything so the next file starts from an empty session.
    run(session, "close all")

save_structures(structures, f"{output_dir}/structures.h5")

if output_mrc:
    for name, structure in structures.items():
        save_volumes_mrc(
            structure["grid"],
            f"{output_dir}/{name}.mrc",
            grid_spacing=grid_spacing,
            origin=structure["origin"],
        )
