import io
from collections import defaultdict
from functools import lru_cache
from dataclasses import replace
from typing import Tuple

from smart_open import open
import torch
import numpy as np
from rdkit.Chem import GetPeriodicTable

from coarsebind_public.coarsebind.data.const import tokens, token_ids
from coarsebind_public.coarsebind.io_schema import IOSchemaCoarseBind


@lru_cache(maxsize=None)
def from_proc_npz(file_path: str):
    """
    Load every array of a processed ``.npz`` archive into a plain dict.

    Args:
        file_path: Location of the archive; it is opened in binary mode with
            the module-level ``open``.

    Returns:
        dict mapping each key stored in the archive to its loaded value.

    Notes:
        - Results are memoized per ``file_path`` with no size limit, and the
          very same dict object is handed back on every call for a given
          path, so callers should treat the result as read-only.
        - SECURITY: ``allow_pickle=True`` lets ``np.load`` unpickle object
          arrays, which can execute arbitrary code. Only use this on archives
          from a trusted source.
    """

    with open(file_path, "rb") as f:
        # NpzFile reads members lazily; materialize them all while the
        # underlying handle is still open, then close the archive explicitly.
        with np.load(f, allow_pickle=True) as archive:
            pdb_data = dict(archive.items())

    return pdb_data


def assign_pdb_template(io_schema: IOSchemaCoarseBind):
    """
    Attach template coordinates, a presence mask and residue indices to ``io_schema``.

    Currently disabled: the function raises ``NotImplementedError`` as its very
    first statement, so everything after the ``raise`` is unreachable and is
    kept only as a starting point for a future fix.

    What the unreachable code would do:
        - load the processed arrays at ``io_schema.template_path``;
        - select the single chain whose name equals ``io_schema.chain_id``;
        - copy per-residue coordinates, presence flags and residue indices into
          arrays of length ``len(io_schema.res_type)``, writing only the
          positions where ``io_schema.potency_ligand_mask`` is False;
        - return a copy of ``io_schema`` with ``coarse_cofold_template_coords``,
          ``template_mask`` and ``template_res_idxs`` replaced.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("broken atm")

    # ---- unreachable from here on; see the docstring ----

    # Processed structure arrays for the template (memoized by path).
    data = from_proc_npz(io_schema.template_path)
    # Keep only the chain record(s) whose name matches the requested chain id.
    chain_mask = data["chains"]["name"] == io_schema.chain_id
    chain = data["chains"][chain_mask]

    # Exactly one matching chain is required.
    if len(chain) == 0:
        raise ValueError(
            f"Chain {io_schema.chain_id} not found in PDB file {io_schema.template_path}"
        )
    if len(chain) > 1:
        raise ValueError(
            f"Multiple chains found for {io_schema.chain_id} in PDB file {io_schema.template_path}"
        )

    chain = chain[0]

    # The chain record gives the offset ("res_idx") and count ("res_num") of
    # its rows inside data["residues"].
    start_idx = chain["res_idx"]
    end_idx = start_idx + chain["res_num"]
    chain_residues = data["residues"][start_idx:end_idx]

    # Per-residue coordinates, looked up in the atom table through each
    # residue's "atom_disto" index.
    res_coords = data["atoms"][chain_residues["atom_disto"]]["coords"]
    xtal_is_present_mask = chain_residues["is_present"]

    sequence = chain["sequence"]

    # NOTE(review): only the lengths are compared here, although the message
    # reports a sequence mismatch — confirm whether a content check is intended.
    if len(sequence) != len(io_schema.sequence):
        raise ValueError(
            f"Sequence mismatch: expected {io_schema.sequence}, got {sequence} in pdb entry {io_schema.rcsb_id}"
        )

    # Non-ligand positions receive the template coordinates; ligand positions stay at zero.
    template_coords = np.zeros((len(io_schema.res_type), 3), dtype=np.float32)
    template_coords[~io_schema.potency_ligand_mask] = res_coords

    # Non-ligand positions receive the presence flags; ligand positions stay True.
    template_mask = np.ones((len(io_schema.res_type),), dtype=bool)
    template_mask[~io_schema.potency_ligand_mask] = xtal_is_present_mask

    # TODO: actually need to offset this to compare directly w template
    # Same residue slice as `chain_residues` above, reduced to its "res_idx" field.
    template_res_idxs = data["residues"][chain["res_idx"] : chain["res_idx"] + chain["res_num"]][
        "res_idx"
    ].copy()

    # Non-ligand positions receive the residue indices; ligand positions stay 0.
    pad_template_res_idxs = np.zeros(len(io_schema.res_type), dtype=np.int32)
    pad_template_res_idxs[~io_schema.potency_ligand_mask] = template_res_idxs

    # dataclasses.replace builds a new schema object instead of mutating the input.
    io_schema = replace(
        io_schema,
        coarse_cofold_template_coords=template_coords,
        template_mask=template_mask,
        template_res_idxs=pad_template_res_idxs,
    )

    return io_schema


def get_binding_site_heatmap_pdb(
    io_schema: IOSchemaCoarseBind,
    coords: np.ndarray,
    binding_probs: np.ndarray,
    alt_loc="",
):
    """
    Render the non-ligand positions as CA-only PDB ATOM lines coloured by binding probability.

    Args:
        io_schema: Schema whose ``disto_output`` supplies ``potency_ligand_mask``,
            ``res_type`` and (optionally) ``template_res_idxs``.
        coords: Coordinates for all positions; rows where the ligand mask is
            True are dropped here.
        binding_probs: One value per non-ligand position (already masked by
            the caller). It is written to the B-factor column scaled by 100.
        alt_loc: Value for the alternate-location column.

    Returns:
        str: One newline-terminated ATOM line per non-ligand position, all on
        chain "A", with serial numbers starting at 1. Residue numbers come
        from ``template_res_idxs`` when it is set, otherwise from a 0-based
        running index. The string is empty when there are no such positions.
    """
    disto = io_schema.disto_output
    is_residue = ~disto.potency_ligand_mask

    # binding_probs arrives pre-masked; coords still covers every position.
    residue_coords = coords[is_residue]

    # TODO..
    if disto.template_res_idxs is None:
        residue_numbers = np.arange(is_residue.sum())
    else:
        residue_numbers = disto.template_res_idxs[is_residue]

    # Constant columns: every position is emitted as a carbon "CA" atom on chain A.
    record = "ATOM  "
    atom_name = "CA"
    element = "C"
    chain_label = "A"
    occupancy = 1.00

    lines = []
    for row, token in enumerate(disto.res_type[is_residue]):
        res_name = tokens[token]  # three-letter residue name, e.g. "LYS"
        b_factor = binding_probs[row] * 100.0
        x, y, z = residue_coords[row]
        serial = row + 1

        # Fixed-width PDB ATOM layout.
        lines.append(
            f"{record}{serial:5d}"
            f" {atom_name:^4s}"
            f"{alt_loc:1s}"
            f"{res_name:>3s} "
            f"{chain_label:1s}"
            f"{residue_numbers[row]:4d}"
            f"    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}"
            f"{occupancy:6.2f}{b_factor:6.2f}"
            f"          {element:>2s}\n"
        )

    return "".join(lines)


def _format_pdb_atom_record(
    record,
    serial,
    atom_name,
    alt_loc,
    res_name,
    chain_id,
    res_seq,
    xyz,
    occupancy,
    b_factor,
    element,
):
    """Return one newline-terminated, fixed-width PDB ATOM/HETATM line."""
    x, y, z = xyz
    return (
        f"{record}{serial:5d}"
        f" {atom_name:^4s}"
        f"{alt_loc:1s}"
        f"{res_name:>3s} "
        f"{chain_id:1s}"
        f"{res_seq:4d}"
        f"    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}"
        f"{occupancy:6.2f}{b_factor:6.2f}"
        f"          {element:>2s}\n"
    )


def _format_conect_records(bonds):
    """Return newline-terminated CONECT lines for an iterable of (atom, atom) index pairs."""
    # Record every bond in both directions so each atom lists all its partners.
    neighbours = defaultdict(list)
    for first_idx, second_idx in bonds:
        neighbours[first_idx].append(second_idx)
        neighbours[second_idx].append(first_idx)

    lines = []
    # Sorted by atom index for a clean, ordered file.
    for atom_idx in sorted(neighbours):
        partners = neighbours[atom_idx]
        # A CONECT line holds at most 4 partners; longer lists spill onto extra lines.
        for start in range(0, len(partners), 4):
            partner_fields = "".join(f"{p:>5}" for p in partners[start : start + 4])
            lines.append(f"CONECT{atom_idx:>5}{partner_fields}\n")
    return lines


def generate_pdb(
    io_schema: IOSchemaCoarseBind,
    coords,
    model_number=1,
    alt_loc="",
):
    """
    Write one coordinate frame as two PDB MODEL blocks: protein and ligand.

    Args:
        io_schema: Schema whose ``disto_output`` supplies ``res_type``,
            ``norm_bin_entropy``, ``potency_ligand_mask``, ``asym_id``,
            ``res_num`` and optionally ``res_name`` /
            ``coarse_cofold_template_coords``. ``mol_enc_io`` (optional)
            supplies ligand ``atoms`` and ``bonds``.
        coords: Coordinates for all positions of this frame, indexable by the
            boolean masks and yielding an ``(x, y, z)`` triple per position.
        model_number: Number written in the ``MODEL`` header.
        alt_loc: Value for the alternate-location column.

    Returns:
        tuple[str, str]: ``(protein_pdb_str, ligand_pdb_str)``, each wrapped in
        ``MODEL`` / ``ENDMDL``.
        - Protein: one "CA" ATOM line per position where the ligand mask is
          False, serials starting at 1.
        - Ligand: one HETATM line (residue "LIG") per position where the mask
          is True, serials starting at 0, followed by CONECT lines built from
          the bonds.
        - The B-factor column holds the per-position mean of
          ``norm_bin_entropy`` over axis 1, times 100.

    Notes:
        - When no template coordinates are set, the whole frame is negated if
          element ``[0][0]`` of the principal-axes matrix of the non-ligand
          coordinates is negative, so frames share one sign convention.
        - ``io_schema.mol_enc_io.bonds`` is never modified; earlier code
          applied the serial offset in place on the caller's array.
        - NOTE(review): with ``mol_enc_io`` unset, the atom table is empty, so
          any ligand position raises ``IndexError`` — confirm that ligand
          positions never occur without ``mol_enc_io``.
        - NOTE(review): ligand serials are 0-based, presumably so they line up
          with 0-based bond indices — confirm downstream readers accept that.
    """
    disto = io_schema.disto_output
    res_type = disto.res_type
    mean_norm_bin_entropy = disto.norm_bin_entropy.mean(1)

    ligand_mask = disto.potency_ligand_mask
    residue_mask = ~ligand_mask

    if disto.coarse_cofold_template_coords is None:
        # Pick a canonical sign for the frame from the non-ligand coordinates.
        principal_axes, _, _ = calculate_principal_axes(coords[residue_mask])
        # Positive is the canonical choice.
        if principal_axes[0][0] < 0:
            coords = -1 * coords

    pt = GetPeriodicTable()
    occupancy = 1.00

    asym_id = disto.asym_id
    res_num = disto.res_num

    # ---- protein positions -------------------------------------------------
    # Masked views are taken once up front; indexing with the boolean mask
    # inside the loop would rebuild each array on every iteration.
    residue_tokens = res_type[residue_mask]
    residue_chain_ids = asym_id[residue_mask]
    residue_numbers = res_num[residue_mask]
    residue_coords = coords[residue_mask]
    residue_entropy = mean_norm_bin_entropy[residue_mask]
    residue_names = disto.res_name[residue_mask] if disto.res_name is not None else None

    protein_lines = []
    for i, res_tok in enumerate(residue_tokens):
        # Prefer explicit residue names; fall back to the token table (e.g. "LYS").
        res_name = residue_names[i] if residue_names is not None else tokens[res_tok]
        protein_lines.append(
            _format_pdb_atom_record(
                record="ATOM  ",  # standard residues
                serial=i + 1,
                atom_name="CA",  # alpha carbon stands in for the residue
                alt_loc=alt_loc,
                res_name=res_name,
                chain_id=str(residue_chain_ids[i]),
                res_seq=residue_numbers[i],
                xyz=residue_coords[i],
                occupancy=occupancy,
                # Scaled per element (not on the whole array) so the scalar
                # arithmetic is the same as before.
                b_factor=residue_entropy[i] * 100.0,
                element="C",
            )
        )

    # ---- ligand positions --------------------------------------------------
    # Ligand serials restart here; bond indices get the same offset so the
    # CONECT lines refer to the HETATM serials written below.
    ligand_serial_start = 0

    if io_schema.mol_enc_io is not None:
        bonds = io_schema.mol_enc_io.bonds
        atoms = io_schema.mol_enc_io.atoms
    else:
        # Fallback in case no MolEnc data is available
        bonds = np.array([])
        atoms = np.array([])

    # Out-of-place on purpose: `bonds += ...` would write into the array owned
    # by io_schema, and this function runs once per frame.
    bonds = bonds + ligand_serial_start

    ligand_chain_ids = asym_id[ligand_mask]
    ligand_numbers = res_num[ligand_mask]
    ligand_coords = coords[ligand_mask]
    ligand_entropy = mean_norm_bin_entropy[ligand_mask]

    ligand_lines = []
    for i, _ in enumerate(res_type[ligand_mask]):
        element = pt.GetElementSymbol(int(atoms[i])).upper()
        ligand_lines.append(
            _format_pdb_atom_record(
                record="HETATM",  # non-residues
                serial=ligand_serial_start + i,
                atom_name=element + str(i + 1),
                alt_loc=alt_loc,
                res_name="LIG",  # default for ligands
                chain_id=str(ligand_chain_ids[i]),
                res_seq=ligand_numbers[i],
                xyz=ligand_coords[i],
                occupancy=occupancy,
                b_factor=ligand_entropy[i] * 100.0,
                element=element,
            )
        )

    ligand_lines.extend(_format_conect_records(bonds))

    model_header = f"MODEL     {model_number}\n"
    model_footer = "ENDMDL\n"

    protein_pdb_str = model_header + "".join(protein_lines) + model_footer
    ligand_pdb_str = model_header + "".join(ligand_lines) + model_footer

    return protein_pdb_str, ligand_pdb_str


def calculate_principal_axes(coords):
    """
    Calculates the principal axes of a set of 3D coordinates.

    Args:
        coords (array-like): N rows of (x, y, z); anything ``np.asarray``
            turns into an Nx3 array. At least two rows are needed for the
            sample covariance to be defined.

    Returns:
        tuple: A tuple containing:
            - principal_axes (np.ndarray): A 3x3 matrix where each column is a principal axis
                                          (eigenvector of the covariance matrix), ordered from
                                          largest to smallest eigenvalue. The sign of each
                                          column is whatever ``np.linalg.eig`` returns.
            - eigenvalues (np.ndarray): A 1D array of eigenvalues, sorted in descending order.
            - centroid (np.ndarray): The centroid of the input coordinates, shape (3,).

    Raises:
        ValueError: If the input is not a 2-D array with exactly 3 columns.
    """
    coords = np.asarray(coords)
    # Check the rank first: indexing shape[1] on a 1-D input would raise
    # IndexError instead of the ValueError promised above.
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError("Input coordinates must be an Nx3 array.")

    # 1. Centroid (unweighted mean position)
    centroid = np.mean(coords, axis=0)

    # 2. Center the data
    centered_coords = coords - centroid

    # 3. Covariance matrix; rowvar=False: columns are the variables (x, y, z),
    #    rows are the observations.
    covariance_matrix = np.cov(centered_coords, rowvar=False)

    # 4. Eigen-decomposition (columns of `eigenvectors` are the eigenvectors).
    # NOTE(review): np.linalg.eigh is the usual choice for a symmetric matrix,
    # but it may return eigenvectors with a different sign; eig is kept because
    # callers in this module branch on the sign of the result.
    eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)

    # 5. Order by descending eigenvalue, permuting the eigenvector columns to match.
    order = np.argsort(eigenvalues)[::-1]

    return eigenvectors[:, order], eigenvalues[order], centroid


def frames_to_pdb(io_schema: IOSchemaCoarseBind, coords, output_filename=None) -> Tuple[str, str]:
    """
    Build multi-model PDB text for a stack of coordinate frames.

    Args:
        io_schema: Passed through unchanged to ``generate_pdb`` for every frame.
        coords: Array with three dimensions (frames, atoms, xyz).
        output_filename: Accepted but not used by this function; nothing is
            written to disk here.

    Returns:
        Tuple[str, str]: ``(protein_pdb, ligand_pdb)`` — the per-frame outputs
        of ``generate_pdb`` concatenated in frame order. Frame ``k`` (0-based)
        is emitted with model number ``k + 1``. Both strings are empty when
        there are no frames.
    """
    assert len(coords.shape) == 3, "Coordinates should be 3D (frames, atoms, xyz)"

    protein_models = []
    ligand_models = []

    num_frames = coords.shape[0]
    for frame_idx in range(num_frames):
        protein_model, ligand_model = generate_pdb(
            io_schema,
            coords[frame_idx],
            model_number=frame_idx + 1,
        )
        protein_models.append(protein_model)
        ligand_models.append(ligand_model)

    return "".join(protein_models), "".join(ligand_models)
