import pandas as pd
import numpy as np
from Bio.PDB.PDBParser import PDBParser
from collections import defaultdict
# One shared Bio.PDB parser instance, created once at import time.
parser = PDBParser()

# Divisor in the voxel density kernel exp(-distance / sigma_sqd) applied in
# pdb_to_point_dict. NOTE(review): despite its name it divides the plain
# (not squared) distance — confirm this is intended.
sigma_sqd = 1

# Van der Waals radius per element symbol. NOTE(review): units presumably
# Angstrom and no use is visible next to these helpers — confirm it is needed.
dic_VdW_radius = {"C": 1.70, "N": 1.55, "S": 1.80, "O": 1.52, "P": 1.80}

# Element symbol -> channel index. pdb_to_point_dict skips atoms whose element
# is not listed here and stores the index as the 4th component of each voxel key.
dic = {"C": 0, "N": 1, "S": 2, "O": 3, "P": 4}

# Cubic grid shapes with edge lengths 64, 96, 128 and 160.
# NOTE(review): presumably candidate grid sizes for callers — confirm.
bounding_boxes = [np.array([edge, edge, edge]) for edge in (64, 96, 128, 160)]


def shift_axes(PDB_DataFrame, resolution=1, padding=4):
    """Scale, reorder and shift the ``x``/``y``/``z`` columns of an atom table.

    The coordinates are processed in three steps:

    1. They are divided by ``resolution``.
    2. The axes are permuted so that ``x`` spans the largest range, ``y`` the
       second largest and ``z`` the smallest.
    3. Every axis is translated so that its minimum equals ``padding``.

    The values are NOT rounded; they remain floats.

    The ``"x"``, ``"y"`` and ``"z"`` columns of ``PDB_DataFrame`` are
    overwritten in place, and the same object is returned.

    Parameters
    ----------
    PDB_DataFrame : pandas.DataFrame
        Table with numeric ``"x"``, ``"y"`` and ``"z"`` columns. It must be
        non-empty, because ``np.min`` raises ``ValueError`` on an empty array.
    resolution : float, optional
        Divisor applied to every coordinate before shifting. Defaults to ``1``.
    padding : float, optional
        Value that the minimum of each axis is moved to. Defaults to ``4``,
        the value this function has always used. It leaves a margin for the
        ``coord - 3`` density window in ``pdb_to_point_dict``.

    Returns
    -------
    pandas.DataFrame
        ``PDB_DataFrame`` itself, with the coordinate columns replaced.
    """

    def _load(column):
        # Scaled values plus their minimum and extent, kept together so that
        # an axis swap moves all three at once.
        values = np.array(PDB_DataFrame[column]) / resolution
        low = np.min(values)
        return values, low, np.max(values) - low

    x_axis, y_axis, z_axis = _load("x"), _load("y"), _load("z")

    # Make x the longest axis. The comparison order, and therefore the tie
    # handling, is kept exactly as before so existing grids keep their
    # orientation.  Index 2 of each tuple is the axis extent.
    if x_axis[2] >= y_axis[2] and x_axis[2] >= z_axis[2]:
        pass  # x is already the longest axis
    elif y_axis[2] >= x_axis[2] and y_axis[2] >= z_axis[2]:
        x_axis, y_axis = y_axis, x_axis
    else:
        x_axis, z_axis = z_axis, x_axis

    # Make y the second-longest axis.
    if y_axis[2] < z_axis[2]:
        y_axis, z_axis = z_axis, y_axis

    # Translate each axis so that its minimum sits at ``padding``, then write it back.
    for column, (values, low, _extent) in zip(("x", "y", "z"), (x_axis, y_axis, z_axis)):
        PDB_DataFrame[column] = values - low + padding
    return PDB_DataFrame

def correct_errors(ls):
    """Re-split whitespace tokens of a PDB ATOM line that were glued together.

    The checks run in this order:

    * token 2 longer than 4 characters: split after 3 characters (atom and
      residue name glued);
    * token 4 longer than 1 character: split after 1 character (chain ID and
      residue number glued);
    * token 6, then token 7, longer than 8 characters: split 3 characters after
      the first ``"."`` (two coordinates glued);
    * token 9 longer than 4 characters: split after 4 characters (occupancy and
      temperature factor glued). Only token 10 is kept after the split pair.

    Returns the corrected token list. If nothing had to be split, this is the
    input list object itself; the input is never modified in place. A message
    is printed if the result still does not have 12 tokens. ``IndexError``
    propagates if the line has too few tokens for a check.
    """

    def split_token(tokens, index, cut):
        # New list in which tokens[index] is replaced by its two halves.
        return tokens[:index] + [tokens[index][:cut], tokens[index][cut:]] + tokens[index + 1:]

    tokens = ls
    if len(tokens[2]) > 4:  # atom name glued to residue name
        tokens = split_token(tokens, 2, 3)
    if len(tokens[4]) > 1:  # chain ID glued to residue number
        tokens = split_token(tokens, 4, 1)
    for index in (6, 7):  # x/y, then y/z coordinates glued together
        if len(tokens[index]) > 8:
            # Keep the decimal point plus three decimals in the first half.
            tokens = split_token(tokens, index, tokens[index].find(".") + 4)
    if len(tokens[9]) > 4:  # occupancy glued to temperature factor
        tokens = tokens[:9] + [tokens[9][:4], tokens[9][4:], tokens[10]]
    if len(tokens) != 12:
        print("list still has not length 12")
    return tokens

def point_dict_to_npy(point_dict):
    """Flatten a sparse voxel dictionary into an ``(n, 5)`` NumPy array.

    Each row is ``[key[0], key[1], key[2], key[3], round(value, 5)]``, in the
    dictionary's iteration order. For the keys built by ``pdb_to_point_dict``
    this is ``[x_idx, y_idx, z_idx, channel, density]``. An empty dictionary
    yields an empty ``(0, 5)`` array, so column indexing such as ``arr[:, 4]``
    works for every input.
    """
    rows = [[key[0], key[1], key[2], key[3], np.round(value, 5)] for key, value in point_dict.items()]
    if not rows:
        # np.array([]) would be 1-D with shape (0,); keep the 5-column layout.
        return np.empty((0, 5))
    return np.array(rows)


def pdb_to_point_dict(filename, max_input_size=160, resolution=1):
    """Convert the ATOM records of a PDB file into a sparse voxel density map.

    Parsing and grid placement:

    * ATOM lines are split on whitespace. Lines without 12 tokens are repaired
      with ``correct_errors``.
    * Coordinates are rescaled, reordered and shifted with ``shift_axes``.
    * Atoms lying at or beyond ``max_input_size - 3`` on any axis are dropped.
      This filter only runs when at least one coordinate exceeds that limit.

    Density contributions:

    * Each remaining atom whose element is a key of ``dic`` contributes
      ``exp(-d / sigma_sqd)`` to the voxels around it.
    * ``d`` is the distance from the atom to the voxel centre (index + 0.5).
    * Per axis, the voxels run from ``int(c - 3)`` up to (excluding)
      ``int(c + 4)``, clipped to ``[0, max_input_size)``.

    Parameters
    ----------
    filename : str
        Path to the PDB file.
    max_input_size : int, optional
        Edge length of the voxel grid. Defaults to ``160``.
    resolution : float, optional
        Passed to ``shift_axes``, which divides all coordinates by it.
        Defaults to ``1``.

    Returns
    -------
    collections.defaultdict
        Maps ``(x_idx, y_idx, z_idx, channel)`` to the summed density, where
        ``channel = dic[element]``. The map is empty when the file contains no
        ATOM records.
    """
    atoms = []
    with open(filename) as handle:
        for line in handle:
            tokens = line.split()
            # Skip blank lines and every non-ATOM record (HETATM, HEADER, ...).
            if not tokens or tokens[0] != "ATOM":
                continue
            if len(tokens) != 12:
                tokens = correct_errors(tokens)
            atoms.append(tokens)

    point_dict = defaultdict(float)
    if not atoms:
        # Nothing to voxelize; shift_axes cannot take the minimum of empty columns.
        return point_dict

    pdb = pd.DataFrame(atoms, columns=["record type", "atom ID", "atom name", "residue name", "chain ID",
                                       "residue ID", "x", "y", "z", "occupancy", "temperature factor", "atom"])
    for axis in ("x", "y", "z"):
        pdb[axis] = pd.to_numeric(pdb[axis])
    pdb = shift_axes(PDB_DataFrame=pdb, resolution=resolution)

    # Drop atoms too close to the far edge of the grid, but only when something
    # actually exceeds the limit (atoms exactly at the limit survive otherwise).
    limit = max_input_size - 3
    protein_size = max(pdb["x"].max(), pdb["y"].max(), pdb["z"].max())
    if protein_size > limit:
        pdb = pdb[(pdb["x"] < limit) & (pdb["y"] < limit) & (pdb["z"] < limit)]

    x_values = np.array(pdb["x"])
    y_values = np.array(pdb["y"])
    z_values = np.array(pdb["z"])
    atom_types = list(pdb["atom"])

    for x, y, z, atom in zip(x_values, y_values, z_values, atom_types):
        if atom not in dic:
            continue
        channel = dic[atom]

        # Voxel index window around the atom, clipped to the grid.
        c1_range = np.arange(max(0, int(x - 3)), min(max_input_size, int(x + 4)))
        c2_range = np.arange(max(0, int(y - 3)), min(max_input_size, int(y + 4)))
        c3_range = np.arange(max(0, int(z - 3)), min(max_input_size, int(z + 4)))
        c1_grid, c2_grid, c3_grid = np.meshgrid(c1_range, c2_range, c3_range, indexing="ij")

        # Distance from the atom to each voxel centre, then the exponential kernel.
        dx = c1_grid + 0.5 - x
        dy = c2_grid + 0.5 - y
        dz = c3_grid + 0.5 - z
        values = np.exp(-np.sqrt(dx**2 + dy**2 + dz**2) / sigma_sqd)

        for i, j, k, value in zip(c1_grid.ravel().tolist(), c2_grid.ravel().tolist(),
                                  c3_grid.ravel().tolist(), values.ravel().tolist()):
            point_dict[(i, j, k, channel)] += value

    return point_dict

def get_ca_coordinate_dict(filename: str, max_input_size: int = 160, resolution: float = 1.0):
    """Return a mapping from residue numbers to the (x, y, z) grid indices of their Cα atom.

    The coordinate transformation (axis re-ordering, shift and padding) and the
    edge truncation mirror those inside :func:`pdb_to_point_dict`, so the
    returned indices refer to the *same* voxel grid as the generated 3D point list.

    Parameters
    ----------
    filename : str
        Path to the PDB file.
    max_input_size : int, optional
        Size of the 3D grid that will later hold the point list. When any shifted
        coordinate exceeds ``max_input_size - 3``, atoms at or beyond that limit
        on any axis are discarded, exactly as in ``pdb_to_point_dict``.
        Defaults to ``160``.
    resolution : float, optional
        Resolution that was used when creating the point list. Must match the value that
        was passed to :func:`pdb_to_point_dict`. Defaults to ``1.0``.

    Returns
    -------
    dict[int | str, tuple[int, int, int]]
        Maps the *residue number* (column "residue ID" in the PDB file) to the
        integer grid coordinates ``(x, y, z)`` of the corresponding C-alpha atom.
        Residue IDs that are not plain integers (e.g. with an insertion code) are
        kept as strings. If several chains share the same residue number, the
        entry of the last chain in the file is kept. The dictionary is empty when
        the file has no ATOM records.
    """
    # Parse ATOM lines exactly like ``pdb_to_point_dict`` does.
    atoms = []
    with open(filename) as handle:
        for line in handle:
            tokens = line.split()
            if not tokens or tokens[0] != "ATOM":
                continue
            if len(tokens) != 12:
                tokens = correct_errors(tokens)
            atoms.append(tokens)

    if not atoms:
        return {}

    pdb = pd.DataFrame(
        atoms,
        columns=[
            "record type",
            "atom ID",
            "atom name",
            "residue name",
            "chain ID",
            "residue ID",
            "x",
            "y",
            "z",
            "occupancy",
            "temperature factor",
            "atom",
        ],
    )

    # Cast coordinate columns to float and apply the same axis-shift logic.
    for axis in ("x", "y", "z"):
        pdb[axis] = pd.to_numeric(pdb[axis])
    pdb = shift_axes(PDB_DataFrame=pdb, resolution=resolution)

    # Truncate exactly like ``pdb_to_point_dict``: the filter only runs when some
    # coordinate exceeds the limit, so an atom sitting exactly at the limit is
    # kept or dropped consistently in both functions.
    limit = max_input_size - 3
    protein_size = max(pdb["x"].max(), pdb["y"].max(), pdb["z"].max())
    if protein_size > limit:
        pdb = pdb[(pdb["x"] < limit) & (pdb["y"] < limit) & (pdb["z"] < limit)]

    # Select C-alpha atoms (atom name == "CA") and build the dictionary.
    ca_df = pdb[pdb["atom name"] == "CA"]

    ca_coord_dict = {}
    for raw_resid, x, y, z in zip(ca_df["residue ID"], ca_df["x"], ca_df["y"], ca_df["z"]):
        try:
            resid = int(raw_resid)
        except ValueError:
            # Non-numeric residue IDs (e.g. insertion codes) keep their string
            # form to avoid accidental key collisions.
            resid = raw_resid

        # Coordinates are >= the shift padding (positive), so int() equals floor.
        ca_coord_dict[resid] = (int(x), int(y), int(z))

    return ca_coord_dict