from collections import defaultdict

import numpy as np
from Bio.PDB import DSSP, PDBParser, Polypeptide, Selection, Superimposer
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from scipy.spatial import cKDTree


def validate_design_sequence(sequence, num_clashes, advanced_settings):
    """Build a free-text note string flagging potential problems with a design.

    Three checks are performed:

    * a clash note when ``num_clashes`` is positive,
    * one note per entry of ``advanced_settings["omit_AAs"]`` that occurs in
      ``sequence``,
    * an absorption note when the value derived from the extinction
      coefficient and the molecular weight is ``<= 2``.

    Args:
        sequence: Amino-acid sequence as a string of one-letter codes.
        num_clashes: Number of clashes found in the relaxed structure.
        advanced_settings: Mapping that must contain the key ``"omit_AAs"``,
            a comma-separated string of residues that should not appear in
            the design. ``None`` or an empty string means "no restriction".

    Returns:
        All notes joined by single spaces; an empty string when nothing was
        flagged.
    """
    note_array = []

    if num_clashes > 0:
        note_array.append("Relaxed structure contains clashes.")

    # A missing/empty setting must not produce notes: "".split(",") yields
    # [""], and "" is a substring of every sequence, which would otherwise
    # add a bogus "Contains: !" entry. Whitespace around entries ("C, M") is
    # stripped so they can actually match.
    omit_setting = advanced_settings["omit_AAs"] or ""
    for restricted_AA in omit_setting.split(","):
        restricted_AA = restricted_AA.strip()
        if restricted_AA and restricted_AA in sequence:
            note_array.append("Contains: " + restricted_AA + "!")

    analysis = ProteinAnalysis(sequence)

    # First element of the tuple returned by molar_extinction_coefficient().
    extinction_coefficient_reduced = analysis.molar_extinction_coefficient()[0]
    # Molecular weight divided by 1000 and rounded to two decimals.
    molecular_weight = round(analysis.molecular_weight() / 1000, 2)
    extinction_coefficient_reduced_1 = round(
        extinction_coefficient_reduced / molecular_weight * 0.01, 2
    )

    if extinction_coefficient_reduced_1 <= 2:
        note_array.append(
            f"Absorption value is {extinction_coefficient_reduced_1}, consider adding tryptophane to design."
        )

    return " ".join(note_array)


def target_pdb_rmsd(trajectory_pdb, starting_pdb, chain_ids_string):
    """Return the CA RMSD between a trajectory target and the starting target.

    Standard amino-acid residues of chain ``"A"`` in the trajectory structure
    are paired, in order, with the standard amino-acid residues collected from
    the listed chains of the starting structure. Pairing is purely positional
    (no sequence alignment); surplus residues on the longer side are ignored.

    Args:
        trajectory_pdb: Path to the trajectory PDB file.
        starting_pdb: Path to the starting PDB file.
        chain_ids_string: Comma-separated chain identifiers of the starting
            structure, e.g. ``"A,B"``. Whitespace around each id is stripped.

    Returns:
        RMSD after superposition, rounded to two decimals.

    Raises:
        KeyError: If a requested chain (or chain ``"A"`` of the trajectory)
            is not present in the first model.
    """
    parser = PDBParser(QUIET=True)
    structure_trajectory = parser.get_structure("trajectory", trajectory_pdb)
    structure_starting = parser.get_structure("starting", starting_pdb)

    chain_trajectory = structure_trajectory[0]["A"]

    residues_starting = []
    for chain_id in chain_ids_string.split(","):
        chain = structure_starting[0][chain_id.strip()]
        residues_starting.extend(
            residue for residue in chain if is_aa(residue, standard=True)
        )

    residues_trajectory = [
        residue for residue in chain_trajectory if is_aa(residue, standard=True)
    ]

    # Pair residues first (zip stops at the shorter list), then keep only the
    # pairs in which BOTH residues have a CA atom. Filtering each list on its
    # own would let a single missing CA shift every later pairing or leave the
    # two atom lists with different lengths, which Superimposer rejects.
    atoms_starting = []
    atoms_trajectory = []
    for residue_start, residue_traj in zip(residues_starting, residues_trajectory):
        if "CA" in residue_start and "CA" in residue_traj:
            atoms_starting.append(residue_start["CA"])
            atoms_trajectory.append(residue_traj["CA"])

    sup = Superimposer()
    sup.set_atoms(atoms_starting, atoms_trajectory)

    return round(sup.rms, 2)


def calculate_clash_score(pdb_file, threshold=2.4, only_ca=False):
    """Count atom pairs in a PDB file that lie closer than ``threshold``.

    Atoms whose element is ``"H"`` are ignored. Atoms from every model in the
    file are pooled into one neighbour search.

    Which close pairs count as clashes depends on ``only_ca``:

    * ``only_ca=False``: only pairs of atoms from different chains count.
    * ``only_ca=True``: only CA atoms are considered; pairs from different
      chains count, and same-chain pairs count when their residue numbers
      differ by more than one.

    Args:
        pdb_file: Path to the PDB file.
        threshold: Distance cutoff passed to the neighbour search.
        only_ca: Restrict the search to atoms named ``"CA"``.

    Returns:
        Number of clashing atom pairs; ``0`` when no atoms were retained.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_file)

    coords = []
    atom_owner = []  # (chain id, residue number) for the atom at the same index

    for model in structure:
        for chain in model:
            for residue in chain:
                for atom in residue:
                    if atom.element == "H":
                        continue
                    if only_ca and atom.get_name() != "CA":
                        continue
                    coords.append(atom.coord)
                    atom_owner.append((chain.id, residue.id[1]))

    # cKDTree needs a 2-D (n, 3) array; an empty list would be 1-D and raise.
    if not coords:
        return 0

    tree = cKDTree(coords)

    clash_count = 0
    # query_pairs yields each unordered pair once, so a plain counter suffices.
    for i, j in tree.query_pairs(threshold):
        chain_i, res_i = atom_owner[i]
        chain_j, res_j = atom_owner[j]

        if chain_i == chain_j:
            # In all-atom mode every intra-chain contact is ignored.
            if not only_ca:
                continue
            # In CA mode skip the same residue and direct sequence neighbours.
            if abs(res_i - res_j) <= 1:
                continue

        clash_count += 1

    return clash_count


def hotspot_residues(
    trajectory_pdb, binder_chain="B", atom_distance_cutoff=4.0, target_chain="A"
):
    """Find binder residues that have an atom close to the target chain.

    Args:
        trajectory_pdb: Path to the PDB file of the complex.
        binder_chain: Chain id of the binder in the first model.
        atom_distance_cutoff: Maximum atom-atom distance for a binder atom to
            count as contacting the target.
        target_chain: Chain id of the target in the first model.

    Returns:
        Dict mapping binder residue number (``residue.id[1]``) to its
        one-letter amino-acid code, for every standard amino-acid residue of
        the binder with at least one atom within the cutoff of any target
        atom. Empty when either chain has no atoms.

    Raises:
        KeyError: If either chain is missing from the first model.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("complex", trajectory_pdb)

    binder_atoms = Selection.unfold_entities(structure[0][binder_chain], "A")
    target_atoms = Selection.unfold_entities(structure[0][target_chain], "A")

    interacting_residues = {}

    # cKDTree cannot be built from an empty coordinate array.
    if not binder_atoms or not target_atoms:
        return interacting_residues

    binder_tree = cKDTree(np.array([atom.coord for atom in binder_atoms]))
    target_tree = cKDTree(np.array([atom.coord for atom in target_atoms]))

    # For each binder atom: indices of target atoms within the cutoff.
    neighbours = binder_tree.query_ball_tree(target_tree, atom_distance_cutoff)

    for binder_atom, close_indices in zip(binder_atoms, neighbours):
        if not close_indices:
            continue

        binder_residue = binder_atom.get_parent()
        binder_resname = binder_residue.get_resname()
        if binder_resname in Polypeptide.standard_aa_names:
            # three_to_index/index_to_one are used instead of three_to_one,
            # which newer Biopython releases no longer provide.
            interacting_residues[binder_residue.id[1]] = Polypeptide.index_to_one(
                Polypeptide.three_to_index(binder_resname)
            )

    return interacting_residues


def calc_ss_percentage(
    pdb_file, advanced_settings, chain_id="B", atom_distance_cutoff=4.0
):
    """Summarise secondary structure and B-factor averages for one chain.

    DSSP is run on the first model of ``pdb_file``. Each residue of
    ``chain_id`` that DSSP reports on is classified as helix (codes ``H``,
    ``G``, ``I``), sheet (code ``E``) or loop (anything else). Residues
    returned by ``hotspot_residues`` for the same chain are additionally
    tallied as interface residues.

    Args:
        pdb_file: Path to the PDB file.
        advanced_settings: Mapping that must contain ``"dssp_path"``, the
            DSSP executable handed to Biopython's ``DSSP``.
        chain_id: Chain to analyse.
        atom_distance_cutoff: Cutoff forwarded to ``hotspot_residues``.

    Returns:
        An 8-tuple: helix, sheet and loop percentages over the whole chain;
        helix, sheet and loop percentages over the interface residues; the
        mean per-residue B-factor of interface residues divided by 100; and
        the mean per-residue B-factor of helix/sheet residues divided by 100.
        The two averages are rounded to two decimals and are ``0`` when no
        residue contributed. (The variable names treat the B-factor column
        as pLDDT.)
    """
    model = PDBParser(QUIET=True).get_structure("protein", pdb_file)[0]
    dssp = DSSP(model, pdb_file, dssp=advanced_settings["dssp_path"])

    chain = model[chain_id]
    interface_ids = set(hotspot_residues(pdb_file, chain_id, atom_distance_cutoff))

    def mean_bfactor(res):
        # Average B-factor over all atoms of one residue.
        return sum(atom.bfactor for atom in res) / len(res)

    helix_codes = ("H", "G", "I")
    chain_tally = {"helix": 0, "sheet": 0, "loop": 0}
    interface_tally = {"helix": 0, "sheet": 0, "loop": 0}
    structured_scores = []
    interface_scores = []

    for residue in chain:
        res_number = residue.id[1]
        dssp_key = (chain_id, res_number)
        if dssp_key not in dssp:
            continue

        # Index 2 of a DSSP record is the secondary-structure code.
        code = dssp[dssp_key][2]
        if code in helix_codes:
            kind = "helix"
        elif code == "E":
            kind = "sheet"
        else:
            kind = "loop"

        chain_tally[kind] += 1

        if kind != "loop":
            structured_scores.append(mean_bfactor(residue))

        if res_number in interface_ids:
            interface_tally[kind] += 1
            interface_scores.append(mean_bfactor(residue))

    percentages = calculate_percentages(
        sum(chain_tally.values()), chain_tally["helix"], chain_tally["sheet"]
    )
    interface_percentages = calculate_percentages(
        sum(interface_tally.values()),
        interface_tally["helix"],
        interface_tally["sheet"],
    )

    if interface_scores:
        i_plddt = round(sum(interface_scores) / len(interface_scores) / 100, 2)
    else:
        i_plddt = 0

    if structured_scores:
        ss_plddt = round(sum(structured_scores) / len(structured_scores) / 100, 2)
    else:
        ss_plddt = 0

    return (*percentages, *interface_percentages, i_plddt, ss_plddt)


def calculate_percentages(total, helix, sheet):
    """Convert helix/sheet counts into helix, sheet and loop percentages.

    Loop is whatever remains of ``total`` after helix and sheet.

    Args:
        total: Total number of residues counted.
        helix: Number of helix residues.
        sheet: Number of sheet residues.

    Returns:
        ``(helix_percentage, sheet_percentage, loop_percentage)``, each
        rounded to two decimals. All three are ``0`` when ``total`` is not
        greater than zero.
    """
    # Nothing counted: avoid dividing by zero and report zeros.
    if not total > 0:
        return 0, 0, 0

    loop = total - helix - sheet
    helix_percentage, sheet_percentage, loop_percentage = (
        round((count / total) * 100, 2) for count in (helix, sheet, loop)
    )
    return helix_percentage, sheet_percentage, loop_percentage
