from collections import defaultdict

import numpy as np
import torch

from torch_scatter import scatter_mean, scatter_sum
from torch_geometric.utils import to_dense_batch

from flexdock.geometry.ops import rigid_transform_kabsch

try:
    from relaxflow.utils.forces import (
        get_torsion_force,
        get_bond_force,
        get_angle_force,
    )
except Exception:
    pass


def set_time(batch, t):
    """Stamp the time value ``t`` onto the batch, per graph and per node.

    Writes ``batch.complex_t`` (one entry per graph, on the device of
    ``batch["conf_idx"]``) and a ``node_t`` tensor on each of the
    ``"ligand"``, ``"receptor"`` and ``"atom"`` stores (one entry per node,
    on the device of that store's ``pos``). The batch is modified in place
    and nothing is returned.
    """
    graph_device = batch["conf_idx"].device
    batch.complex_t = torch.full((batch.num_graphs,), t, device=graph_device)

    for node_type in ("ligand", "receptor", "atom"):
        store = batch[node_type]
        store.node_t = torch.full((store.num_nodes,), t, device=store.pos.device)


def center_complex(data):
    """Translate every complex so the centroid of its nearby atoms is at the origin.

    For each graph, the mean position of the atoms selected by
    ``data["atom"].nearby_atom_mask`` is subtracted from the ligand and atom
    positions. Receptor positions are then re-derived from the shifted atom
    positions via ``data["atom"].ca_mask``.

    The position tensors are shifted in place and ``data`` itself is returned.
    A graph with no selected atoms gets a zero center, i.e. it is left where
    it is.
    """
    atoms = data["atom"]
    nearby = atoms.nearby_atom_mask

    # Fix the output size explicitly: without dim_size, scatter_mean sizes its
    # result from the largest *selected* batch index, so trailing graphs with
    # no nearby atoms would make the gather below index out of range.
    num_graphs = int(atoms.batch.max()) + 1
    atom_center = scatter_mean(
        atoms.pos[nearby],
        atoms.batch[nearby],
        dim=0,
        dim_size=num_graphs,
    )

    data["ligand"].pos -= atom_center[data["ligand"].batch]
    data["atom"].pos -= atom_center[data["atom"].batch]
    data["receptor"].pos = data["atom"].pos[data["atom"].ca_mask]
    return data


def sigmoid(t):
    """Return the logistic function ``1 / (1 + e**(-t))`` of ``t``."""
    decay = np.e ** (-t)
    return 1 / (1 + decay)


def sigmoid_schedule(t, k):
    """Warp ``t`` with a sigmoid of steepness ``k``, rescaled so 0 -> 0 and 1 -> 1.

    Computes ``(s(t) - s(0)) / (s(1) - s(0))`` with ``s(x) = sigmoid(k * x)``.
    """

    def _scaled(x):
        return sigmoid(k * x)

    start = _scaled(0)
    end = _scaled(1)
    return (_scaled(t) - start) / (end - start)


def exponential_schedule(t, a):
    """Warp ``t`` exponentially with rate ``a``, rescaled so 0 -> 0 and 1 -> 1.

    Computes ``(exp(a * t) - 1) / (exp(a) - 1)``. ``np.expm1`` is used so the
    result stays accurate when ``a`` is close to zero, and the analytic limit
    ``t`` is returned for ``a == 0`` instead of evaluating 0 / 0.

    Args:
        t: Scalar or array of times.
        a: Exponential rate.

    Returns:
        The warped times, with the same shape as ``t``.
    """
    if np.ndim(a) == 0 and a == 0:
        # lim_{a -> 0} expm1(a * t) / expm1(a) = t
        return np.multiply(t, 1.0)
    return np.expm1(a * t) / np.expm1(a)

def gaussian_ppf_torch(p, loc=0.0, scale=1.0, alpha_exp: float = 1.0):
    """Evaluate a Gaussian quantile at ``0.5 * p**alpha_exp``, capped at zero.

    The probability ``0.5 * p**alpha_exp`` is clamped into ``[eps, 1 - eps]``
    (``eps`` being the machine epsilon of the working dtype) so that
    ``erfinv`` stays finite, mapped through the standard normal quantile
    ``sqrt(2) * erfinv(2 * q - 1)``, scaled by ``scale``, shifted by ``loc``
    and finally clamped to be at most ``0``.

    Args:
        p: Scalar or tensor. Integer and boolean inputs are promoted to the
            default floating dtype so that ``loc`` and ``scale`` are not
            truncated to integers when cast to ``p``'s dtype.
        loc: Mean of the Gaussian; converted to ``p``'s device and dtype.
        scale: Standard deviation; converted to ``p``'s device and dtype.
        alpha_exp: Exponent applied to ``p`` before halving.

    Returns:
        Tensor of non-positive values.
    """
    p = torch.as_tensor(p)
    if not p.is_floating_point():
        # An integer p (e.g. t=1) would otherwise force loc/scale to an
        # integer dtype below and silently drop their fractional part.
        p = p.to(torch.get_default_dtype())
    loc = torch.as_tensor(loc, device=p.device, dtype=p.dtype)
    scale = torch.as_tensor(scale, device=p.device, dtype=p.dtype)

    p = 0.5 * (p ** alpha_exp)
    eps = torch.finfo(p.dtype).eps
    p = p.clamp(min=eps, max=1 - eps)

    sqrt_two = torch.sqrt(torch.tensor(2.0, device=p.device, dtype=p.dtype))
    z = sqrt_two * torch.special.erfinv(2 * p - 1)
    v = loc + scale * z
    return torch.clamp(v, max=0.0)

def _pb_geometry_metrics(
    prefix,
    dists,
    graph_idx,
    num_graphs,
    bond_mask,
    angle_mask,
    non_ba_mask,
    lower_bond,
    upper_bond,
    lower_angle,
    upper_angle,
    lower_steric,
):
    """Per-graph bond/angle/steric bound-violation sums and per-edge means.

    Produces the keys ``{prefix}_batch_{bond,angle,steric}_loss`` followed by
    the matching ``..._loss_mean`` keys (sum divided by the edge count,
    with the count floored at 1).
    """
    per_edge = {
        "bond": (
            torch.clip(lower_bond - dists, min=0)[bond_mask]
            + torch.clip(dists - upper_bond, min=0)[bond_mask],
            graph_idx[bond_mask],
        ),
        "angle": (
            torch.clip(lower_angle - dists, min=0)[angle_mask]
            + torch.clip(dists - upper_angle, min=0)[angle_mask],
            graph_idx[angle_mask],
        ),
        "steric": (
            torch.clip(lower_steric - dists, min=0)[non_ba_mask],
            graph_idx[non_ba_mask],
        ),
    }
    metrics = {}
    for term, (loss, owner) in per_edge.items():
        metrics[f"{prefix}_batch_{term}_loss"] = scatter_sum(
            loss, owner, dim=0, dim_size=num_graphs
        )
    for term, (loss, owner) in per_edge.items():
        edge_count = scatter_sum(
            torch.ones_like(loss), owner, dim=0, dim_size=num_graphs
        )
        metrics[f"{prefix}_batch_{term}_loss_mean"] = (
            metrics[f"{prefix}_batch_{term}_loss"] / edge_count.clamp_min(1)
        )
    return metrics


def _pb_overlap_metrics(prefix, lig_pos, batch, overlap_buffer):
    """Per-graph ligand-vs-atom vdW overlap sum and per-pair mean for ``lig_pos``.

    Uses the un-inflated ``overlap_buffer`` and the current
    ``batch["atom"].pos``; padded pairs are excluded from sum and count.
    """
    lig_dense_pos, lig_dense_mask = to_dense_batch(lig_pos, batch["ligand"].batch)
    atom_dense_pos, atom_dense_mask = to_dense_batch(
        batch["atom"].pos, batch["atom"].batch
    )
    lig_dense_radii, _ = to_dense_batch(
        batch["ligand"].vdw_radii, batch["ligand"].batch
    )
    atom_dense_radii, _ = to_dense_batch(
        batch["atom"].vdw_radii, batch["atom"].batch
    )
    pair_dists = torch.linalg.norm(
        atom_dense_pos.unsqueeze(1) - lig_dense_pos.unsqueeze(2), dim=-1
    )
    overlaps = torch.clip(
        lig_dense_radii.unsqueeze(2)
        + atom_dense_radii.unsqueeze(1)
        - pair_dists
        - overlap_buffer,
        min=0,
    )
    pair_mask = lig_dense_mask.unsqueeze(2) & atom_dense_mask.unsqueeze(1)
    overlaps[~pair_mask] = torch.nan
    total = torch.nansum(overlaps, dim=(-1, -2))
    pair_count = pair_mask.sum(dim=(-1, -2)).to(overlaps.dtype)
    return {
        f"{prefix}_batch_overlap_loss": total,
        f"{prefix}_batch_overlap_loss_mean": total / pair_count.clamp_min(1),
    }


@torch.no_grad()
def apply_posebusters_bias(
    batch,
    bond_buffer: float = 0.25 - 1e-3,
    angle_buffer: float = 0.25 - 1e-3,
    steric_buffer: float = 0.2 - 1e-3,
    overlap_buffer: float = 0.75 - 1e-3,
    step_scale: float = 0.05,
    max_step_total: float = 0.25,
    alpha_exp: float = 1.0,
    extra_buffer_ratio: float = 1.0,
    t=None,
):
    """Apply a lightweight geometry bias to ligand positions using PoseBusters-style constraints.

    This heuristically nudges ligand atoms to reduce violations while sampling,
    without requiring gradients. It considers bond/angle bounds, a steric lower
    bound, and ligand-receptor overlap (vdW) push-apart. All displacements are
    computed from the incoming positions, summed per ligand atom, capped in
    norm, and applied once to ``batch["ligand"].pos``.

    The buffers are inflated by a time-dependent factor in
    ``[1, 1 + extra_buffer_ratio]`` derived from ``gaussian_ppf_torch``, so the
    tolerated deviation depends on ``t``.

    This function is best-effort: it never raises. If ``t`` is missing or any
    step fails, a ``RuntimeWarning`` is emitted and ``(batch, {})`` is
    returned. A failure during metric computation happens after the positions
    have already been updated.

    Args:
        batch: Heterogeneous batch providing ``"ligand"`` and ``"atom"`` stores
            (``pos``, ``batch``, ``vdw_radii``) and the
            ``("ligand", "lig_edge", "ligand")`` edge store
            (``posebusters_edge_index``, ``lower_bound``, ``upper_bound``,
            ``posebusters_bond_mask``, ``posebusters_angle_mask``).
        bond_buffer: Relative tolerance on bond-edge distance bounds.
        angle_buffer: Relative tolerance on angle-edge distance bounds.
        steric_buffer: Relative tolerance on the lower bound of all other edges.
        overlap_buffer: Tolerated vdW overlap between ligand and atom pairs.
        step_scale: Fraction of each violation converted into displacement.
        max_step_total: Cap on the norm of the combined per-atom displacement.
        alpha_exp: Exponent forwarded to ``gaussian_ppf_torch``.
        extra_buffer_ratio: Upper limit on the additional buffer inflation.
        t: Current time; required (must be a single value).

    Returns:
        ``(batch, metrics)`` where ``metrics`` maps ``pre_*`` / ``post_*``
        names to per-graph tensors, or ``{}`` on failure.
    """
    import warnings

    if t is None:
        # Previously a ValueError raised inside the catch-all below, i.e. a
        # silent no-op. Keep the no-op but make it visible.
        warnings.warn(
            "apply_posebusters_bias: t must be provided; no bias applied",
            RuntimeWarning,
        )
        return batch, {}

    try:
        pos = batch["ligand"].pos
        edge_store = batch["ligand", "lig_edge", "ligand"]
        edge_index = edge_store.posebusters_edge_index
        lower_bounds = edge_store.lower_bound
        upper_bounds = edge_store.upper_bound
        bond_mask = edge_store.posebusters_bond_mask
        angle_mask = edge_store.posebusters_angle_mask
        non_ba_mask = ~(bond_mask | angle_mask)

        # Current edge vectors, lengths and unit directions (src -> tgt).
        src = pos.index_select(0, edge_index[0])
        tgt = pos.index_select(0, edge_index[1])
        vec = tgt - src
        d = torch.linalg.norm(vec, dim=-1, keepdim=True).clamp_min(1e-8)
        u = vec / d
        d_s = d.squeeze(-1)

        # Time-dependent buffer inflation factor (single value).
        t_val = torch.as_tensor(t, device=pos.device, dtype=pos.dtype)
        v_val = gaussian_ppf_torch(
            p=t_val, loc=0.0, scale=(1.0 - t_val), alpha_exp=alpha_exp
        )
        eff_scalar = torch.clamp(
            1.0 - v_val, min=1.0, max=(1.0 + extra_buffer_ratio)
        )
        eff_e = eff_scalar.view(1, 1).expand_as(d).squeeze(-1)

        upper_bond = upper_bounds * (1.0 + eff_e * bond_buffer)
        lower_bond = lower_bounds * (1.0 - eff_e * bond_buffer)
        upper_angle = upper_bounds * (1.0 + eff_e * angle_buffer)
        lower_angle = lower_bounds * (1.0 - eff_e * angle_buffer)
        lower_steric = lower_bounds * (1.0 - eff_e * steric_buffer)

        offsets = torch.zeros_like(pos)

        def _accumulate(mask, desired_delta, pull: bool):
            # Split each violating edge's correction evenly between its two
            # endpoints; pull moves them together, push moves them apart.
            if not mask.any():
                return
            disp = step_scale * desired_delta[mask].unsqueeze(-1) * u[mask]
            sgn = 1.0 if pull else -1.0
            offsets.index_add_(0, edge_index[0][mask], sgn * 0.5 * disp)
            offsets.index_add_(0, edge_index[1][mask], -sgn * 0.5 * disp)

        # Bond
        _accumulate(
            bond_mask & (d_s < lower_bond),
            (lower_bond - d_s).clamp_min(0.0),
            pull=False,
        )
        _accumulate(
            bond_mask & (d_s > upper_bond),
            (d_s - upper_bond).clamp_min(0.0),
            pull=True,
        )
        # Angle
        _accumulate(
            angle_mask & (d_s < lower_angle),
            (lower_angle - d_s).clamp_min(0.0),
            pull=False,
        )
        _accumulate(
            angle_mask & (d_s > upper_angle),
            (d_s - upper_angle).clamp_min(0.0),
            pull=True,
        )
        # Steric (only push apart)
        _accumulate(
            non_ba_mask & (d_s < lower_steric),
            (lower_steric - d_s).clamp_min(0.0),
            pull=False,
        )
        geom_offsets = offsets

        # Overlap bias: push ligand atoms away from atoms whose vdW overlap
        # exceeds the (inflated) buffer.
        lig_dense_pos, lig_dense_mask = to_dense_batch(
            batch["ligand"].pos, batch["ligand"].batch
        )
        atom_dense_pos, atom_dense_mask = to_dense_batch(
            batch["atom"].pos, batch["atom"].batch
        )
        lig_dense_radii, _ = to_dense_batch(
            batch["ligand"].vdw_radii, batch["ligand"].batch
        )
        atom_dense_radii, _ = to_dense_batch(
            batch["atom"].vdw_radii, batch["atom"].batch
        )

        # Pairwise vectors and distances ligand -> atom.
        l2a_vec = atom_dense_pos.unsqueeze(1) - lig_dense_pos.unsqueeze(2)  # (B, N_lig, N_atom, 3)
        d_pair = torch.linalg.norm(l2a_vec, dim=-1).clamp_min(1e-8)  # (B, N_lig, N_atom)
        pair_mask = lig_dense_mask.unsqueeze(2) & atom_dense_mask.unsqueeze(1)

        overlap_buffer_b = overlap_buffer * eff_scalar
        vdw_overlap = (
            lig_dense_radii.unsqueeze(2) + atom_dense_radii.unsqueeze(1) - d_pair
        )
        overlap_excess = torch.clip(vdw_overlap - overlap_buffer_b, min=0.0)
        overlap_excess = overlap_excess * pair_mask.float()

        u_la = l2a_vec / d_pair.unsqueeze(-1)
        disp_pairs = -step_scale * overlap_excess.unsqueeze(-1) * u_la  # (B, N_lig, N_atom, 3)
        overlap_offsets_dense = disp_pairs.sum(dim=2)  # (B, N_lig, 3)

        # Dense -> ragged: the valid entries of each row are a leading run in
        # node order, so mask indexing restores the original node ordering.
        overlap_offsets = overlap_offsets_dense[lig_dense_mask].to(pos.dtype)

        # Combine offsets, cap the per-node displacement, and apply once.
        combined_offsets = geom_offsets + overlap_offsets
        comb_norm = torch.linalg.norm(
            combined_offsets, dim=-1, keepdim=True
        ).clamp_min(1e-8)
        comb_scale = torch.clamp(max_step_total / comb_norm, max=1.0)
        combined_offsets = combined_offsets * comb_scale
        batch["ligand"].pos = pos + combined_offsets

        # Metrics before and after the update, on the same inflated bounds.
        lig_batch = batch["ligand"].batch
        graph_idx = lig_batch.index_select(0, edge_index[0])
        num_graphs = int(lig_batch.max().item()) + 1

        pos_post = batch["ligand"].pos
        d_post = torch.linalg.norm(
            pos_post.index_select(0, edge_index[1])
            - pos_post.index_select(0, edge_index[0]),
            dim=-1,
        )

        metrics = {}
        for prefix, dists, lig_pos in (("pre", d_s, pos), ("post", d_post, pos_post)):
            metrics.update(
                _pb_geometry_metrics(
                    prefix,
                    dists,
                    graph_idx,
                    num_graphs,
                    bond_mask,
                    angle_mask,
                    non_ba_mask,
                    lower_bond,
                    upper_bond,
                    lower_angle,
                    upper_angle,
                    lower_steric,
                )
            )
            try:
                metrics.update(
                    _pb_overlap_metrics(prefix, lig_pos, batch, overlap_buffer)
                )
            except Exception as exc:
                # Overlap metrics are optional; keep the geometry metrics.
                warnings.warn(
                    f"apply_posebusters_bias: {prefix} overlap metrics skipped "
                    f"({type(exc).__name__}: {exc})",
                    RuntimeWarning,
                )

        return batch, metrics
    except Exception as exc:
        # Be conservative: never break sampling because of the bias step, but
        # do not hide the failure either.
        warnings.warn(
            f"apply_posebusters_bias failed ({type(exc).__name__}: {exc}); "
            "returning batch without metrics",
            RuntimeWarning,
        )
        return batch, {}


def _dense_roundtrip(x, batch_vec):
    """Convert ragged tensor x to dense via to_dense_batch and back.

    Returns ``(x_rec, dense, mask)``: the reconstructed flat tensor together
    with the padded tensor and validity mask produced by ``to_dense_batch``.
    """
    dense, mask = to_dense_batch(x, batch_vec)
    # The valid entries of every row form a leading run in node order, so
    # boolean-mask indexing concatenates them exactly as a per-graph slice
    # loop would, without one host sync per graph. With no valid entries it
    # yields an empty tensor of shape (0, *x.shape[1:]).
    x_rec = dense[mask]
    return x_rec, dense, mask


@torch.no_grad()
def sampling_on_batch(
    model,
    batch,
    inference_steps,
    x_zero_pred=False,
    save_traj=True,
    schedule_type="uniform",
    schedule_param=1.0,
    pb_bias_steps=5,
    pb_final_steps=20,
    pb_extra_buffer_ratio=5,
):
    """Integrate the model over a time schedule from 0 to 1, with geometry bias.

    After every integration step the complex is re-centered and
    ``apply_posebusters_bias`` is applied ``pb_bias_steps`` times. ``batch`` is
    modified in place.

    Args:
        model: Callable returning ``(lig_pred, atom_pred)`` for a batch.
        batch: Heterogeneous batch with ``"ligand"``, ``"atom"`` and
            ``"receptor"`` stores.
        inference_steps: Number of integration steps.
        x_zero_pred: If True, the model output is treated as a target position
            and converted to an update via ``(pred - pos) * dt / (1 - t)``;
            otherwise it is treated as a velocity and scaled by ``dt``.
        save_traj: If True, return the stacked per-step positions.
        schedule_type: ``"sigmoid"`` or ``"exponential"`` warp the time grid;
            any other value leaves it uniform.
        schedule_param: Parameter of the chosen warp.
        pb_bias_steps: Bias iterations after each integration step.
        pb_final_steps: Bias iterations of the final projection at ``t=1``.
        pb_extra_buffer_ratio: ``extra_buffer_ratio`` passed to the bias.

    Returns:
        If ``save_traj``: ``(lig_traj, atom_traj)`` with the step index on
        dimension 1 (initial state plus one frame per step). Otherwise the
        ligand and atom positions of a projected copy of the batch.
    """
    batch = center_complex(batch)

    # Frames are collected in lists and stacked once at the end instead of
    # re-concatenating the whole trajectory on every step.
    lig_frames, atom_frames = [], []
    if save_traj:
        lig_frames.append(batch["ligand"].pos.clone())
        atom_frames.append(batch["atom"].pos.clone())

    t_schedule = np.linspace(0, 1, inference_steps + 1)
    if schedule_type == "sigmoid":
        t_schedule = sigmoid_schedule(t_schedule, schedule_param)
    elif schedule_type == "exponential":
        t_schedule = exponential_schedule(t_schedule, schedule_param)

    for t, t_next in zip(t_schedule[:-1], t_schedule[1:]):
        dt = t_next - t
        set_time(batch, t)

        lig_pred, atom_pred = model(batch)

        if x_zero_pred:
            lig_update = (lig_pred - batch["ligand"].pos) * dt / (1 - t)
            atom_update = (atom_pred - batch["atom"].pos) * dt / (1 - t)
        else:
            lig_update = lig_pred * dt
            atom_update = atom_pred * dt
        batch["ligand"].pos += lig_update
        batch["atom"].pos += atom_update
        batch["receptor"].pos = batch["atom"].pos[batch["atom"].ca_mask]
        batch = center_complex(batch)

        for _ in range(pb_bias_steps):
            batch, _ = apply_posebusters_bias(
                batch, t=t, extra_buffer_ratio=pb_extra_buffer_ratio
            )
            batch = center_complex(batch)

        if save_traj:
            lig_frames.append(batch["ligand"].pos.clone())
            atom_frames.append(batch["atom"].pos.clone())

    if save_traj:
        # The final projection only affects the non-trajectory return value,
        # so it is not computed here.
        return torch.stack(lig_frames, dim=1), torch.stack(atom_frames, dim=1)

    # Final projection on a copy so that `batch` keeps the unprojected state.
    batch_proj = batch.clone()
    for _ in range(pb_final_steps):
        batch_proj, _ = apply_posebusters_bias(
            batch_proj, t=1, extra_buffer_ratio=pb_extra_buffer_ratio
        )
        batch_proj = center_complex(batch_proj)
    return batch_proj["ligand"].pos, batch_proj["atom"].pos


def sampling_on_confs(
    model,
    conf_loader,
    inference_steps,
    x_zero_pred=False,
    save_traj=False,
    device=None,
    schedule_type="uniform",
    schedule_param=1.0,
):
    """Run ``sampling_on_batch`` over every batch of ``conf_loader`` and collect the results.

    Each batch is moved to ``device`` and sampled; the flat per-node outputs
    are converted to dense per-graph tensors with ``to_dense_batch`` and
    concatenated along the graph dimension. With ``save_traj`` the result is
    permuted with ``(2, 0, 1, 3)``, which puts the trajectory-step dimension
    first.

    Returns:
        ``(lig_pred, atom_pred)`` dense tensors.
    """
    lig_chunks = []
    atom_chunks = []
    for batch in conf_loader:
        on_device = batch.to(device)
        # Flat outputs: one row per node of the batch.
        lig_out, atom_out = sampling_on_batch(
            model,
            on_device,
            inference_steps,
            x_zero_pred=x_zero_pred,
            save_traj=save_traj,
            schedule_type=schedule_type,
            schedule_param=schedule_param,
        )
        lig_dense = to_dense_batch(lig_out, batch["ligand"].batch)[0]
        atom_dense = to_dense_batch(atom_out, batch["atom"].batch)[0]
        lig_chunks.append(lig_dense)
        atom_chunks.append(atom_dense)

    lig_pred = torch.cat(lig_chunks, dim=0)
    atom_pred = torch.cat(atom_chunks, dim=0)
    if not save_traj:
        return lig_pred, atom_pred
    return lig_pred.permute(2, 0, 1, 3), atom_pred.permute(2, 0, 1, 3)


def compute_min_rmsds(graph, lig_pred, atom_pred):
    """Compute ligand and atom RMSDs after aligning predictions onto the target atoms.

    A rigid transform is fitted with ``rigid_transform_kabsch`` from
    ``atom_pred`` to ``graph["atom"].tgt_pos`` and applied to both the ligand
    and atom predictions. The RMSD is then taken over the second-to-last
    (node) dimension.

    Returns:
        Dict with ``"lig_rmsds"`` and ``"atom_rmsds"``. Despite the function
        name, no minimum over conformers is taken here.
    """
    R, tr = rigid_transform_kabsch(atom_pred, graph["atom"].tgt_pos)
    rotation_t = R.swapaxes(-1, -2)
    shift = tr.unsqueeze(-2)

    def _aligned_rmsd(pred, target):
        moved = pred @ rotation_t + shift
        sq_dist = torch.sum((moved - target) ** 2, dim=-1)
        return torch.sqrt(torch.mean(sq_dist, dim=-1))

    lig_rmsds = _aligned_rmsd(lig_pred, graph["ligand"].tgt_pos)
    atom_rmsds = _aligned_rmsd(atom_pred, graph["atom"].tgt_pos)
    return {"lig_rmsds": lig_rmsds, "atom_rmsds": atom_rmsds}


def compute_energy_ratios(graph, lig_pred, atom_pred):
    """Compare bond/angle/torsion energies of predictions against reference poses.

    Energies are the second element returned by ``get_bond_force``,
    ``get_angle_force`` and ``get_torsion_force``, evaluated with the
    force-field parameters stored on ``graph["ligand"]`` and ``graph["atom"]``.
    Each predicted energy is divided by the energy of the source pose
    (``src_pos`` with its first two axes swapped) and of the target pose
    (``tgt_pos``).

    Returns:
        Dict with keys ``{src,tgt}_{lig,atom}_{bond,angle,torsion}_energy_ratio``.
    """

    def _term_energy(term, coords, store):
        # Each force helper returns a 3-tuple; only the energy is used.
        if term == "bond":
            _, energy, _ = get_bond_force(
                coords, store.bond_index, store.bond_k, store.bond_r_0
            )
        elif term == "angle":
            _, energy, _ = get_angle_force(
                coords, store.angle_index, store.angle_k, store.angle_theta_0
            )
        else:
            _, energy, _ = get_torsion_force(
                coords,
                store.torsion_index,
                store.torsion_k,
                store.torsion_n,
                store.torsion_phi_0,
            )
        return energy

    def _all_energies(lig_pos, atom_pos):
        # Insertion order: lig_bond, atom_bond, lig_angle, atom_angle,
        # lig_torsion, atom_torsion.
        energies = {}
        for term in ("bond", "angle", "torsion"):
            for mol, node_type, coords in (
                ("lig", "ligand", lig_pos),
                ("atom", "atom", atom_pos),
            ):
                energies[f"{mol}_{term}"] = _term_energy(
                    term, coords, graph[node_type]
                )
        return energies

    pred_energies = _all_energies(lig_pred, atom_pred)

    energy_ratio_dict = {}
    for key in ("src", "tgt"):
        if key == "src":
            ref_lig_pos = graph["ligand"].src_pos.swapaxes(0, 1)
            ref_atom_pos = graph["atom"].src_pos.swapaxes(0, 1)
        else:
            ref_lig_pos = graph["ligand"].tgt_pos
            ref_atom_pos = graph["atom"].tgt_pos
        ref_energies = _all_energies(ref_lig_pos, ref_atom_pos)

        ratios = {
            f"{key}_{name}_energy_ratio": pred_energy / ref_energies[name]
            for name, pred_energy in pred_energies.items()
        }
        energy_ratio_dict.update(ratios)
    return energy_ratio_dict


def aggregate_metric_dicts(metric_dicts):
    """Average each metric over a sequence of metric dicts.

    Values are grouped by key across all dicts (a key may be absent from some
    of them) and averaged over the dicts that contain it.

    Returns:
        Dict mapping ``"avg_" + key`` to the mean of that key's values.
    """
    collected = {}
    for entry in metric_dicts:
        for name, value in entry.items():
            collected.setdefault(name, []).append(value)

    averaged = {}
    for name, values in collected.items():
        averaged["avg_" + name] = sum(values) / len(values)
    return averaged
