import io
import time
from typing import Collection, Optional, Sequence

from absl import logging
from openfold.np import (
    protein,
    residue_constants,
)
import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils
import ml_collections
import numpy as np
import openmm
from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure

# Unit conventions for this module: energies are kcal/mol and lengths are
# angstroms. Raw floats passed in by callers are wrapped with these units, and
# OpenMM quantities are converted back with ``value_in_unit`` before returning.
ENERGY, LENGTH = unit.kilocalories_per_mole, unit.angstroms


def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
    """Return whether ``atom`` is selected by the restraint set ``rset``.

    Args:
        atom: OpenMM topology atom to test.
        rset: Restraint set name. ``"non_hydrogen"`` selects every atom whose
            element name is not ``"hydrogen"``; ``"c_alpha"`` selects atoms
            named ``"CA"``.

    Returns:
        True if the atom should be positionally restrained, else False.

    Raises:
        ValueError: If ``rset`` is not one of the recognised restraint sets
            (instead of silently returning ``None``, which would contradict
            the ``bool`` return annotation).
    """
    if rset == "non_hydrogen":
        return atom.element.name != "hydrogen"
    if rset == "c_alpha":
        return atom.name == "CA"
    raise ValueError(
        f"Unknown restraint set {rset!r}; expected 'non_hydrogen' or 'c_alpha'."
    )


def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int],
):
    """Add harmonic positional restraints to ``system`` and attach them.

    Each selected atom is tethered to its coordinates in ``reference_pdb`` by
    the energy ``0.5 * k * |r - r0|^2``, where ``k`` is a single global
    parameter set to ``stiffness``.

    Args:
        system: OpenMM system that receives the new ``CustomExternalForce``.
        reference_pdb: Structure supplying the atoms and anchor positions.
        stiffness: Spring constant (an OpenMM quantity, energy / length^2).
        rset: Restraint set passed to ``will_restrain``; must be
            ``"non_hydrogen"`` or ``"c_alpha"``.
        exclude_residues: Residue indices whose atoms are never restrained.

    Raises:
        ValueError: If ``rset`` is not a recognised restraint set. This is an
            explicit check rather than ``assert`` so it survives ``python -O``.
    """
    if rset not in ("non_hydrogen", "c_alpha"):
        raise ValueError(
            f"Unknown restraint set {rset!r}; expected 'non_hydrogen' or 'c_alpha'."
        )

    force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
    force.addGlobalParameter("k", stiffness)
    for p in ["x0", "y0", "z0"]:
        force.addPerParticleParameter(p)

    # Membership is tested once per atom, so build an O(1) lookup set once.
    excluded = set(exclude_residues)
    positions = reference_pdb.positions
    for i, atom in enumerate(reference_pdb.topology.atoms()):
        if atom.residue.index in excluded:
            continue
        if will_restrain(atom, rset):
            force.addParticle(i, positions[i])
    logging.info(
        "Restraining %d / %d particles.",
        force.getNumParticles(),
        system.getNumParticles(),
    )
    system.addForce(force)


def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int],
    use_gpu: bool,
):
    """Energy-minimize a PDB structure with the Amber99sb force field in OpenMM.

    The system uses HBond constraints. Positional restraints are added only
    when ``stiffness`` is strictly positive. A ``LangevinIntegrator(0, 0.01,
    0.0)`` is created because ``Simulation`` needs one; no dynamics steps are
    taken, only ``minimizeEnergy``.

    Args:
        pdb_str: Input structure as PDB text.
        max_iterations: Forwarded to ``Simulation.minimizeEnergy``.
        tolerance: Forwarded to ``Simulation.minimizeEnergy`` (OpenMM quantity).
        stiffness: Restraint spring constant (OpenMM quantity).
        restraint_set: Restraint set name used by ``_add_restraints``.
        exclude_residues: Residue indices exempt from restraints.
        use_gpu: Run on the ``"CUDA"`` platform if True, otherwise ``"CPU"``.

    Returns:
        A dict with keys ``"einit"``/``"efinal"`` (potential energy in kcal/mol
        before/after), ``"posinit"``/``"pos"`` (positions in angstroms as
        numpy arrays before/after), and ``"min_pdb"`` (minimized PDB text).
    """
    pdb = openmm_app.PDBFile(io.StringIO(pdb_str))

    system = openmm_app.ForceField("amber99sb.xml").createSystem(
        pdb.topology, constraints=openmm_app.HBonds
    )
    no_stiffness = 0 * ENERGY / (LENGTH**2)
    if stiffness > no_stiffness:
        _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

    simulation = openmm_app.Simulation(
        pdb.topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU"),
    )
    simulation.context.setPositions(pdb.positions)

    def snapshot():
        # Read energy and positions from the current context state.
        current = simulation.context.getState(getEnergy=True, getPositions=True)
        energy = current.getPotentialEnergy().value_in_unit(ENERGY)
        coords = current.getPositions(asNumpy=True).value_in_unit(LENGTH)
        return energy, coords, current

    einit, posinit, _ = snapshot()
    simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
    efinal, pos, final_state = snapshot()

    return {
        "einit": einit,
        "posinit": posinit,
        "efinal": efinal,
        "pos": pos,
        "min_pdb": _get_pdb_string(simulation.topology, final_state.getPositions()),
    }


def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
    """Serialize ``topology`` with ``positions`` and return the PDB text."""
    buffer = io.StringIO()
    try:
        openmm_app.PDBFile.writeFile(topology, positions, buffer)
        return buffer.getvalue()
    finally:
        # Release the in-memory buffer regardless of success.
        buffer.close()


def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
    """Verify that cleaning did not move any atom present in the reference.

    Residues of the two structures are paired in order. Pairing uses ``zip``,
    so residues beyond the shorter structure are not compared. For every
    reference atom, each cleaned atom with the same name in the paired residue
    must have exactly equal coordinates. Atoms that exist only in the cleaned
    structure, such as added hydrogens, have no namesake and are ignored.

    Args:
        pdb_cleaned_string: PDB text produced by the cleaning step.
        pdb_ref_string: PDB text of the structure before cleaning.

    Raises:
        ValueError: If paired residues have different names, or if a shared
            atom's coordinates differ between the two structures.
    """
    cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
    reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

    cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
    ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

    for ref_res, cl_res in zip(
        reference.topology.residues(), cleaned.topology.residues()
    ):
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if ref_res.name != cl_res.name:
            raise ValueError(
                f"Residue {cl_res} in the cleaned structure does not match "
                f"reference residue {ref_res}."
            )
        # Index the cleaned residue's atoms by name once, rather than
        # rescanning them for every reference atom. Lists preserve any
        # duplicate names in their original order.
        cl_atoms_by_name = {}
        for cat in cl_res.atoms():
            cl_atoms_by_name.setdefault(cat.name, []).append(cat)
        for rat in ref_res.atoms():
            for cat in cl_atoms_by_name.get(rat.name, ()):
                if not np.array_equal(cl_xyz[cat.index], ref_xyz[rat.index]):
                    raise ValueError(
                        f"Coordinates of cleaned atom {cat} do not match "
                        f"coordinates of reference atom {rat}."
                    )


def _check_residues_are_well_defined(prot: protein.Protein):
    """Reject proteins that contain a residue with no atoms.

    The atom mask is summed over its last axis to count present atoms per
    residue. A count of zero for any residue is an error.

    Raises:
        ValueError: If at least one residue has an all-zero atom mask.
    """
    atoms_per_residue = prot.atom_mask.sum(axis=-1)
    has_empty_residue = np.any(atoms_per_residue == 0)
    if not has_empty_residue:
        return
    raise ValueError(
        "Amber minimization can only be performed on proteins with"
        " well-defined residues. This protein contains at least"
        " one residue with no atoms."
    )


def _check_atom_mask_is_ideal(prot):
    """Compare ``prot``'s atom mask against the ideal mask for its residues.

    The ideal mask comes from ``protein.ideal_atom_mask(prot)``. The comparison
    and any error are delegated to ``utils.assert_equal_nonterminal_atom_types``.
    Judging by its name, that helper presumably ignores terminal-only atoms;
    confirm in ``openfold.np.relax.utils``.
    """
    utils.assert_equal_nonterminal_atom_types(
        prot.atom_mask, protein.ideal_atom_mask(prot)
    )


def clean_protein(prot: protein.Protein, checks: bool = True):
    """Render ``prot`` as PDB text that has been repaired for OpenMM.

    Pipeline:
    1. Validate the atom mask.
    2. Convert to PDB.
    3. Run ``cleanup.fix_pdb`` and then ``cleanup.clean_structure``, which
       share an ``alterations_info`` dict that is logged afterwards.
    4. Re-serialize through ``openmm_app.PDBFile``.

    Args:
        prot: Protein to clean.
        checks: If True, verify via ``_check_cleaned_atoms`` that no
            pre-existing atom moved during cleaning.

    Returns:
        The cleaned structure as PDB text.
    """
    _check_atom_mask_is_ideal(prot)

    original_pdb = protein.to_pdb(prot)
    alterations_info = {}
    repaired = cleanup.fix_pdb(io.StringIO(original_pdb), alterations_info)
    structure = PdbStructure(io.StringIO(repaired))
    cleanup.clean_structure(structure, alterations_info)

    logging.info("alterations info: %s", alterations_info)

    openmm_pdb = openmm_app.PDBFile(structure)
    cleaned = _get_pdb_string(openmm_pdb.getTopology(), openmm_pdb.getPositions())
    if checks:
        _check_cleaned_atoms(cleaned, original_pdb)
    return cleaned


def make_atom14_positions(prot):
    """Populate atom14-format features and their symmetry-swapped variants.

    ``prot`` is a mutable mapping (``find_violations`` passes a plain dict).
    It must provide ``"aatype"``, ``"all_atom_mask"`` and
    ``"all_atom_positions"``. The keys listed below are written into it in
    place, and the same object is returned.

    Written keys:
        atom14_atom_exists, atom14_gt_exists, atom14_gt_positions,
        residx_atom14_to_atom37, residx_atom37_to_atom14, atom37_atom_exists,
        atom14_alt_gt_positions, atom14_alt_gt_exists, atom14_atom_is_ambiguous.

    NOTE(review): this assumes the following shapes; confirm against callers.
        aatype: (N,) integer array with values in ``[0, len(restypes)]``,
            where the last index is the extra "UNK" row.
        all_atom_mask: (N, 37).
        all_atom_positions: (N, 37, 3).
    The ``take_along_axis(..., axis=1)`` calls require ranks 2 and 3
    respectively.
    """
    num_restypes = len(residue_constants.restypes)
    num_atom37 = len(residue_constants.atom_types)

    # Per-restype lookup tables: one row per residue type plus one for UNK.
    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []
    for rt in residue_constants.restypes:
        atom_names = residue_constants.restype_name_to_atom14_names[
            residue_constants.restype_1to3[rt]
        ]
        # Empty names are padding slots. They map to atom37 index 0 and are
        # zeroed out by restype_atom14_mask.
        restype_atom14_to_atom37.append(
            [(residue_constants.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [atom_name_to_idx14.get(name, 0) for name in residue_constants.atom_types]
        )
        restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])

    # All-zero row for the unknown residue type.
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * num_atom37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=int)
    restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=int)
    restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

    aatype = prot["aatype"]

    residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
    residx_atom14_mask = restype_atom14_mask[aatype]

    # A ground-truth atom14 slot is present only if the atom exists for the
    # residue type AND is present in the input atom37 mask.
    residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
        prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
    ).astype(np.float32)

    residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
        np.take_along_axis(
            prot["all_atom_positions"],
            residx_atom14_to_atom37[..., None],
            axis=1,
        )
    )

    prot["atom14_atom_exists"] = residx_atom14_mask
    prot["atom14_gt_exists"] = residx_atom14_gt_mask
    prot["atom14_gt_positions"] = residx_atom14_gt_positions

    prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)

    residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
    prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)

    restype_atom37_mask = np.zeros([num_restypes + 1, num_atom37], dtype=np.float32)
    for restype, restype_letter in enumerate(residue_constants.restypes):
        restype_name = residue_constants.restype_1to3[restype_letter]
        for atom_name in residue_constants.residue_atoms[restype_name]:
            restype_atom37_mask[restype, residue_constants.atom_order[atom_name]] = 1

    prot["atom37_atom_exists"] = restype_atom37_mask[aatype]

    # Three-letter residue names indexed by restype, with "UNK" for the last row.
    restype_3 = [
        residue_constants.restype_1to3[res] for res in residue_constants.restypes
    ] + ["UNK"]

    # One 14x14 permutation matrix per residue type. It is the identity unless
    # the residue appears in residue_atom_renaming_swaps.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        atom14_names = residue_constants.restype_name_to_atom14_names[resname]
        correspondences = np.arange(14)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Row i is one-hot at column correspondences[i]. The matrix is built
        # once, after every swap for this residue has been applied, rather
        # than being rebuilt inside the per-swap loop.
        all_matrices[resname] = np.eye(14, dtype=np.float32)[correspondences]
    renaming_matrices = np.stack([all_matrices[name] for name in restype_3])

    renaming_transform = renaming_matrices[aatype]

    alternative_gt_positions = np.einsum(
        "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
    )
    prot["atom14_alt_gt_positions"] = alternative_gt_positions

    alternative_gt_mask = np.einsum(
        "ra,rab->rb", residx_atom14_gt_mask, renaming_transform
    )
    prot["atom14_alt_gt_exists"] = alternative_gt_mask

    # Flag every atom14 slot that participates in a renaming swap.
    restype_atom14_is_ambiguous = np.zeros((num_restypes + 1, 14), dtype=np.float32)
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        restype = residue_constants.restype_order[
            residue_constants.restype_3to1[resname]
        ]
        atom14_names = residue_constants.restype_name_to_atom14_names[resname]
        for atom_name1, atom_name2 in swap.items():
            restype_atom14_is_ambiguous[restype, atom14_names.index(atom_name1)] = 1
            restype_atom14_is_ambiguous[restype, atom14_names.index(atom_name2)] = 1

    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[aatype]

    return prot


def find_violations(prot_np: protein.Protein):
    """Compute structural violations of ``prot_np`` treated as a prediction.

    A feature batch is assembled from the protein with float32 positions and
    mask and an all-ones ``seq_mask``, then expanded by
    ``make_atom14_positions``. The ground-truth atom14 positions are passed as
    the "predicted" positions to the openfold loss helpers.

    Returns:
        A tuple ``(violations, violation_metrics)`` as returned by
        ``loss.find_structural_violations_np`` and
        ``loss.compute_violation_metrics_np``.
    """
    batch = make_atom14_positions(
        {
            "aatype": prot_np.aatype,
            "all_atom_positions": prot_np.atom_positions.astype(np.float32),
            "all_atom_mask": prot_np.atom_mask.astype(np.float32),
            "residue_index": prot_np.residue_index,
            "seq_mask": np.ones_like(prot_np.aatype, np.float32),
        }
    )

    # Tolerances handed to the violation finder; see openfold.utils.loss for
    # their precise interpretation.
    tolerances = ml_collections.ConfigDict(
        {
            "violation_tolerance_factor": 12,
            "clash_overlap_tolerance": 1.5,
        }
    )
    violations = loss.find_structural_violations_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        config=tolerances,
    )
    metrics = loss.compute_violation_metrics_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        violations=violations,
    )
    return violations, metrics


def get_violation_metrics(prot: protein.Protein):
    """Return violation metrics for ``prot`` with per-residue details added.

    Starting from the metrics produced by ``find_violations``, this adds:
        residue_violations: indices where
            ``total_per_residue_violations_mask`` is non-zero.
        num_residue_violations: the number of such indices.
        structural_violations: the raw violations mapping.
    """
    violations, metrics = find_violations(prot)
    flagged = np.flatnonzero(violations["total_per_residue_violations_mask"])
    for key, value in (
        ("residue_violations", flagged),
        ("num_residue_violations", len(flagged)),
        ("structural_violations", violations),
    ):
        metrics[key] = value
    return metrics


def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None,
    use_gpu: bool,
):
    """Minimize ``pdb_string`` with OpenMM, retrying on failure.

    Args:
        pdb_string: Input structure as PDB text.
        max_iterations: Forwarded to the OpenMM minimizer.
        tolerance: Minimizer tolerance as a raw float, wrapped in ``ENERGY``
            (kcal/mol).
        stiffness: Restraint spring constant as a raw float, wrapped in
            ``ENERGY / LENGTH**2`` (kcal/mol/A^2).
        restraint_set: Restraint set name used by ``_add_restraints``.
        max_attempts: Maximum number of minimization attempts.
        exclude_residues: Residue indices exempt from restraints.
        use_gpu: Use the CUDA platform instead of CPU.

    Returns:
        The dict from ``_openmm_minimize`` plus:
            opt_time: wall-clock seconds across all attempts.
            min_attempts: number of attempts used.

    Raises:
        ValueError: If every attempt fails. The last underlying exception is
            chained as ``__cause__``.
    """
    exclude_residues = exclude_residues or []

    tolerance = tolerance * ENERGY
    stiffness = stiffness * ENERGY / (LENGTH**2)

    start = time.perf_counter()
    last_error = None
    for attempt in range(1, max_attempts + 1):
        logging.info("Minimizing protein, attempt %d of %d.", attempt, max_attempts)
        try:
            ret = _openmm_minimize(
                pdb_string,
                max_iterations=max_iterations,
                tolerance=tolerance,
                stiffness=stiffness,
                restraint_set=restraint_set,
                exclude_residues=exclude_residues,
                use_gpu=use_gpu,
            )
        # Deliberately broad: any OpenMM failure should trigger a retry.
        except Exception as e:
            last_error = e
            logging.warning(
                "Minimization attempt %d of %d failed: %s", attempt, max_attempts, e
            )
            continue
        ret["opt_time"] = time.perf_counter() - start
        ret["min_attempts"] = attempt
        return ret
    raise ValueError(
        f"Minimization failed after {max_attempts} attempts."
    ) from last_error


def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    use_gpu: bool,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Iteratively relax ``prot`` until no residue violations remain.

    Each outer iteration minimizes the current structure. It then recomputes
    violation metrics and adds every violating residue to the set excluded
    from restraints for the next iteration. The loop stops when
    ``violations_per_residue`` reaches 0 or after ``max_outer_iterations``.

    Args:
        prot: Protein to relax; every residue must have at least one atom.
        stiffness: Restraint spring constant in kcal/mol/A^2.
        use_gpu: Use the CUDA platform instead of CPU.
        max_outer_iterations: Maximum number of outer iterations (>= 1).
        place_hydrogens_every_iteration: Re-clean (re-add hydrogens to) the
            minimized structure between iterations instead of reusing the
            raw minimized PDB.
        max_iterations: Forwarded to the OpenMM minimizer. NOTE: OpenMM
            treats 0 as "no iteration limit"; confirm for the installed
            version.
        tolerance: Minimizer tolerance in kcal/mol.
        restraint_set: ``"non_hydrogen"`` or ``"c_alpha"``.
        max_attempts: Retries per minimization.
        checks: Verify that cleaning does not move existing atoms. Applied
            both to the initial clean and to every per-iteration re-clean.
        exclude_residues: Residue indices initially exempt from restraints.

    Returns:
        The result dict of the last iteration: minimization outputs,
        violation metrics, ``num_exclusions`` and ``iteration``.

    Raises:
        ValueError: If ``max_outer_iterations`` < 1. Previously this crashed
            with ``UnboundLocalError`` because no result was ever produced.
    """
    if max_outer_iterations < 1:
        raise ValueError(
            f"max_outer_iterations must be at least 1, got {max_outer_iterations}."
        )

    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    exclude_residues = set(exclude_residues or [])
    violations = np.inf
    iteration = 0

    while violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(
            pdb_string=pdb_string,
            exclude_residues=exclude_residues,
            max_iterations=max_iterations,
            tolerance=tolerance,
            stiffness=stiffness,
            restraint_set=restraint_set,
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )
        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            # Honour the caller's ``checks`` setting here as well.
            pdb_string = clean_protein(prot, checks=checks)
        else:
            pdb_string = ret["min_pdb"]
        ret.update(get_violation_metrics(prot))
        ret.update(
            {
                "num_exclusions": len(exclude_residues),
                "iteration": iteration,
            }
        )
        violations = ret["violations_per_residue"]
        # Violating residues are freed from restraints in the next iteration.
        exclude_residues = exclude_residues.union(ret["residue_violations"])

        logging.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"],
            ret["efinal"],
            ret["opt_time"],
            ret["num_residue_violations"],
            ret["num_exclusions"],
        )
        iteration += 1
    return ret


def get_initial_energies(
    pdb_strs: Sequence[str],
    stiffness: float = 0.0,
    restraint_set: str = "non_hydrogen",
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Evaluate the potential energy of each PDB structure without minimizing.

    A single OpenMM system is built from the topology of the *first*
    structure and reused for every entry: Amber99sb, HBond constraints, CPU
    platform. All structures are therefore presumably expected to share that
    topology; confirm with callers. When ``stiffness`` > 0, positional
    restraints anchored on the first structure are added.

    Args:
        pdb_strs: PDB-format strings.
        stiffness: Restraint spring constant in kcal/mol/A^2.
        restraint_set: ``"non_hydrogen"`` or ``"c_alpha"``.
        exclude_residues: Residue indices exempt from restraints.

    Returns:
        A list of plain floats in kcal/mol, one per input. A structure whose
        energy cannot be evaluated gets the sentinel ``1e20``, also a plain
        float. Previously this was a ``unit.Quantity``, which made the list
        mixed-type. An empty input yields an empty list.
    """
    exclude_residues = exclude_residues or []
    # Sentinel for failed evaluations, same type and units as successful ones.
    failed_energy = 1e20

    openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs]
    if not openmm_pdbs:
        # Nothing to evaluate; avoid IndexError on openmm_pdbs[0] below.
        return []

    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
    )
    stiffness = stiffness * ENERGY / (LENGTH**2)
    if stiffness > 0 * ENERGY / (LENGTH**2):
        _add_restraints(
            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
        )
    simulation = openmm_app.Simulation(
        openmm_pdbs[0].topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CPU"),
    )
    energies = []
    for pdb in openmm_pdbs:
        try:
            simulation.context.setPositions(pdb.positions)
            state = simulation.context.getState(getEnergy=True)
            energy = state.getPotentialEnergy().value_in_unit(ENERGY)
        # Best-effort: one bad structure must not abort the whole batch.
        except Exception as e:
            logging.error("Error getting initial energy, returning large value %s", e)
            energy = failed_energy
        energies.append(energy)
    return energies
