import biotite.structure as struc
import biotite.structure.io as strucio
import torch

from proteinfoundation.utils.pdb_utils import from_pdb_file


def load_alpha_carbon_coordinates(pdb_file):
    prot = from_pdb_file(pdb_file)
    mask = torch.Tensor(prot.atom_mask).long().bool()
    coors_atom37 = torch.Tensor(prot.atom_positions)
    mask_ca = mask[:, 1]
    return coors_atom37[mask_ca, 1, :]


def compute_ca_metrics(pdb_path):
    try:
        coors = load_alpha_carbon_coordinates(pdb_path)
        consecutive_ca_ca_distances = torch.norm(coors[1:, :] - coors[:-1, :], dim=-1)
        pairwise_ca_ca_distances = torch.norm(
            coors[None, :, :] - coors[:, None, :], dim=-1
        )
        num_collisions = (
            torch.sum(
                (pairwise_ca_ca_distances > 0.01) & (pairwise_ca_ca_distances < 2.0)
            )
            / 2.0
        )

        return {
            "ca_ca_dist_avg": torch.mean(consecutive_ca_ca_distances),
            "ca_ca_dist_median": torch.median(consecutive_ca_ca_distances),
            "ca_ca_dist_std": torch.std(consecutive_ca_ca_distances),
            "ca_ca_dist_min": torch.min(consecutive_ca_ca_distances),
            "ca_ca_dist_max": torch.max(consecutive_ca_ca_distances),
            "ca_ca_collisions(2A)": num_collisions,
        }
    except Exception as e:
        print(f"Error in ca-ca metrics {e}")
        return {
            "ca_ca_dist_avg": 0.0,
            "ca_ca_dist_median": 0.0,
            "ca_ca_dist_std": 0.0,
            "ca_ca_dist_min": 0.0,
            "ca_ca_dist_max": 0.0,
            "ca_ca_collisions(2A)": 0.0,
        }


def compute_ss_metrics(pdb_path):
    try:
        stack = strucio.load_structure(pdb_path)
        sse = struc.annotate_sse(stack)
        a = (sse == "a").sum()
        b = (sse == "b").sum()
        c = (sse == "c").sum()
        tot = a + b + c
    except Exception:
        a = 0.0
        b = 0.0
        c = 0.0
        tot = 1.0

    return {
        "biot_alpha": a / tot,
        "biot_beta": b / tot,
        "biot_coil": c / tot,
    }


def compute_structural_metrics(pdb_path):

    metrics_ss = compute_ss_metrics(pdb_path)
    metrics_ca_ca = compute_ca_metrics(pdb_path)
    return {**metrics_ss, **metrics_ca_ca}


def compute_structural_metrics(pdb_path):

    metrics_ss = compute_ss_metrics(pdb_path)
    metrics_ca_ca = compute_ca_metrics(pdb_path)
    return {**metrics_ss, **metrics_ca_ca}
