from typing import Any, Dict, Sequence, Tuple
from openfold.np import protein
from openfold.np.relax import amber_minimize, utils
import numpy as np


class AmberRelaxation:
    """Relax a protein structure through ``amber_minimize.run_pipeline``.

    Instances only hold configuration. Every option given to the
    constructor is forwarded unchanged to ``amber_minimize.run_pipeline``
    each time :meth:`process` is called.
    """

    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
        use_gpu: bool,
    ):
        """Store the minimization settings.

        Args:
            max_iterations: Forwarded as ``max_iterations`` to run_pipeline.
            tolerance: Forwarded as ``tolerance`` to run_pipeline.
            stiffness: Forwarded as ``stiffness`` to run_pipeline.
            exclude_residues: Forwarded as ``exclude_residues`` to
                run_pipeline. It is stored by reference, not copied.
            max_outer_iterations: Forwarded as ``max_outer_iterations`` to
                run_pipeline.
            use_gpu: Forwarded as ``use_gpu`` to run_pipeline.
        """
        # NOTE(review): the meaning of each option (units, ranges) is
        # defined by amber_minimize.run_pipeline; confirm there.
        self._use_gpu = use_gpu
        self._max_outer_iterations = max_outer_iterations
        self._exclude_residues = exclude_residues
        self._stiffness = stiffness
        self._tolerance = tolerance
        self._max_iterations = max_iterations

    def process(
        self, *, prot: protein.Protein
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Minimize ``prot`` and return the relaxed structure with diagnostics.

        Args:
            prot: Structure to relax.

        Returns:
            A 3-tuple ``(pdb_string, debug_data, violations)``:
              * ``pdb_string``: PDB text produced by
                ``amber_minimize.clean_protein(prot)``. Its coordinates are
                replaced by the minimized positions, and its b-factors are
                replaced by ``prot.b_factors``.
              * ``debug_data``: dict with keys ``initial_energy``,
                ``final_energy``, ``attempts`` (all copied from the pipeline
                result) and ``rmsd`` between the initial and minimized
                positions.
              * ``violations``: the pipeline's
                ``structural_violations["total_per_residue_violations_mask"]``.

        Raises:
            Whatever ``utils.assert_equal_nonterminal_atom_types`` raises
            when the relaxed PDB's atom mask disagrees with ``prot``'s
            (presumably AssertionError; confirm in utils).
        """
        result = amber_minimize.run_pipeline(
            prot=prot,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            stiffness=self._stiffness,
            exclude_residues=self._exclude_residues,
            max_outer_iterations=self._max_outer_iterations,
            use_gpu=self._use_gpu,
        )

        relaxed_positions = result["pos"]
        initial_positions = result["posinit"]

        # Root-mean-square displacement. The sum of squared deltas is divided
        # by the number of rows in posinit. This assumes posinit is
        # (num_atoms, 3); TODO confirm against run_pipeline.
        delta = initial_positions - relaxed_positions
        row_count = initial_positions.shape[0]
        rmsd = np.sqrt(np.sum(np.square(delta)) / row_count)

        debug_data = dict(
            initial_energy=result["einit"],
            final_energy=result["efinal"],
            attempts=result["min_attempts"],
            rmsd=rmsd,
        )

        # Rebuild a PDB from the cleaned input, then swap in the minimized
        # coordinates and the caller's original b-factors.
        cleaned_pdb = amber_minimize.clean_protein(prot)
        relaxed_pdb = utils.overwrite_b_factors(
            utils.overwrite_pdb_coordinates(cleaned_pdb, relaxed_positions),
            prot.b_factors,
        )

        # Sanity check: re-parse the output and compare its atom mask with
        # the input's (helper name says terminal atoms are not compared).
        reparsed_mask = protein.from_pdb_string(relaxed_pdb).atom_mask
        utils.assert_equal_nonterminal_atom_types(reparsed_mask, prot.atom_mask)

        violation_summary = result["structural_violations"]
        return (
            relaxed_pdb,
            debug_data,
            violation_summary["total_per_residue_violations_mask"],
        )
