## Directly adapted from: https://github.com/a-r-j/CPDB/blob/main/cpdb/__init__.py

import gzip
import os
import pathlib
import sys
from typing import Dict, Optional, Union
from urllib.error import HTTPError, URLError
from urllib.request import urlopen

import numpy as np
import pandas as pd

# import pyximport
# pyximport.install()
# from .parser import parse_pdb_file, parse_pdb_string

# AlphaFold DB model version; interpolated into the download URL built by
# _fetch_af2 as the "model_v{AF2_VERSION}" suffix.
AF2_VERSION: int = 4


import numpy as np


def _model_serial(line, fallback):
    """Return the serial number of a MODEL record, or *fallback* if unreadable."""
    try:
        # Standard layout: the serial sits in columns 11-14.
        return int(line[10:14])
    except ValueError:
        pass
    try:
        # Non-aligned writers emit e.g. "MODEL 2"; take the first token
        # following the record name instead.
        return int(line[6:].split()[0])
    except (IndexError, ValueError):
        return fallback


def parse_pdb_to_dict(pdb_file):
    """Parse the ATOM/HETATM records of a PDB file into per-column arrays.

    Fields are read from fixed column slices of each line. Every other record
    type is ignored, except MODEL, which sets the model index attached to the
    atoms that follow it.

    Parameters:
    - pdb_file: Path of the PDB file; passed directly to ``open``.

    Returns:
    - dict: Maps each column name to a 1-D numpy array with one entry per
      ATOM/HETATM line. Text columns use ``object`` dtype, ``atom_number``,
      ``residue_number`` and ``model_idx`` use ``int32``, and the coordinates,
      ``occupancy`` and ``b_factor`` use ``float32``.

    Raises:
    - ValueError: If an atom serial, residue number or coordinate field
      cannot be converted to a number.

    Notes:
    - A blank occupancy defaults to 1.0 and a blank B-factor to 0.0.
    - Atoms that appear before any MODEL record get model index 1.
    - A MODEL record whose serial cannot be read from the standard columns
      falls back to the first token after the record name, and finally to
      the ordinal position of that MODEL record in the file.
    """
    # Output dtype per column; insertion order fixes the column order.
    dtypes = {
        "record_name": object,
        "atom_number": np.int32,
        "atom_name": object,
        "alt_loc": object,
        "residue_name": object,
        "chain_id": object,
        "residue_number": np.int32,
        "insertion": object,
        "x_coord": np.float32,
        "y_coord": np.float32,
        "z_coord": np.float32,
        "occupancy": np.float32,
        "b_factor": np.float32,
        "element_symbol": object,
        "charge": object,
        "model_idx": np.int32,
    }
    fields = {name: [] for name in dtypes}
    current_model_idx = 1  # Model index starts at 1
    models_seen = 0  # ordinal of the latest MODEL record, used as a fallback

    with open(pdb_file, "r") as file:
        for line in file:
            record = line[0:6].strip()

            if record == "MODEL":
                models_seen += 1
                current_model_idx = _model_serial(line, models_seen)
            elif record in ("ATOM", "HETATM"):
                occupancy = line[54:60].strip()
                b_factor = line[60:66].strip()

                fields["record_name"].append(record)
                fields["atom_number"].append(int(line[6:11].strip()))
                fields["atom_name"].append(line[12:16].strip())
                fields["alt_loc"].append(line[16:17].strip())
                fields["residue_name"].append(line[17:20].strip())
                fields["chain_id"].append(line[21:22].strip())
                fields["residue_number"].append(int(line[22:26].strip()))
                fields["insertion"].append(line[26:27].strip())
                fields["x_coord"].append(float(line[30:38].strip()))
                fields["y_coord"].append(float(line[38:46].strip()))
                fields["z_coord"].append(float(line[46:54].strip()))
                fields["occupancy"].append(float(occupancy) if occupancy else 1.0)
                fields["b_factor"].append(float(b_factor) if b_factor else 0.0)
                # Slices past the end of a short line yield "", so truncated
                # lines simply get empty element/charge values.
                fields["element_symbol"].append(line[76:78].strip())
                fields["charge"].append(line[78:80].strip())
                fields["model_idx"].append(current_model_idx)

    # Convert lists to numpy arrays with appropriate dtypes
    return {
        name: np.array(values, dtype=dtypes[name])
        for name, values in fields.items()
    }


# Example usage (adjust the file path as necessary)
# pdb_data = parse_pdb_to_dict("example.pdb")
# print(pdb_data)

import pandas as pd


def pdb_to_dataframe(pdb_file):
    """
    Parse a PDB file and return its atom records as a Pandas DataFrame.

    Parameters:
    - pdb_file: Path of the PDB file; handed unchanged to parse_pdb_to_dict.

    Returns:
    - pd.DataFrame: Built directly from the dictionary that
      parse_pdb_to_dict returns, so each dictionary key becomes a column.
    """
    # Parse and wrap in one step; no intermediate processing is applied.
    return pd.DataFrame(parse_pdb_to_dict(pdb_file))


# Example usage (adjust the file path as necessary):
# pdb_df = pdb_to_dataframe("example.pdb")  # takes the file path, not a parsed dict
# print(pdb_df)


def parse(
    fname: Optional[Union[str, os.PathLike]] = None,
    pdb_str: Optional[str] = None,
    pdb_code: Optional[str] = None,
    uniprot_id: Optional[str] = None,
    df: bool = True,
) -> Union[Dict[str, np.ndarray], pd.DataFrame]:
    """Parse a PDB structure from a file, a string, or a remote database.

    Exactly one source is used, chosen in this order of precedence:
    ``fname``, ``pdb_str``, ``pdb_code``, ``uniprot_id``.

    Parameters:
    - fname: Path to a PDB file. Names ending in ``pdb.gz`` or ``.ent.gz``
      are decompressed with gzip first.
    - pdb_str: PDB file content as text (a list of strings is concatenated).
    - pdb_code: Identifier passed to ``_fetch_pdb`` to download the entry.
    - uniprot_id: Identifier passed to ``_fetch_af2`` to download the model.
    - df: If True return a DataFrame, otherwise the raw dictionary of arrays.

    Returns:
    - The dictionary produced by ``parse_pdb_to_dict``, wrapped in a
      ``pd.DataFrame`` when ``df`` is True.

    Raises:
    - ValueError: If no source is given, or if a download returned nothing.
    """
    import tempfile  # local import: only needed for in-memory PDB text

    if fname is not None:
        path = os.fspath(fname)  # accepts str and any os.PathLike
        if not path.endswith(("pdb.gz", ".ent.gz")):
            parsed = parse_pdb_to_dict(path)
            return pd.DataFrame(parsed) if df else parsed
        with gzip.open(path, "rt", encoding="utf-8") as handle:
            pdb_str = handle.read()
    elif pdb_str is not None:
        if isinstance(pdb_str, list):
            pdb_str = "".join(pdb_str)
    elif pdb_code is not None:
        pdb_str = _fetch_pdb(pdb_code)
        if pdb_str is None:
            raise ValueError(f"Could not download PDB entry {pdb_code!r}.")
    elif uniprot_id is not None:
        pdb_str = _fetch_af2(uniprot_id)
        if pdb_str is None:
            raise ValueError(
                f"Could not download AlphaFold model for {uniprot_id!r}."
            )
    else:
        raise ValueError(
            "One of fname, pdb_str, pdb_code or uniprot_id must be provided."
        )

    # parse_pdb_to_dict only reads from a path, so in-memory text is staged in
    # a temporary file. It is written with the platform default encoding
    # because that is what parse_pdb_to_dict uses to read it back.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "structure.pdb")
        with open(tmp_path, "w") as handle:
            handle.write(pdb_str)
        parsed = parse_pdb_to_dict(tmp_path)

    return pd.DataFrame(parsed) if df else parsed


def _fetch_pdb(pdb_code: str) -> Optional[str]:
    """Download a PDB file from rcsb.org and return its content as text.

    Parameters:
    - pdb_code: PDB identifier; lower-cased before being put in the URL.

    Returns:
    - The response body decoded as UTF-8, or None when the request fails
      with an HTTPError or URLError (the error is printed, not raised).
    """
    url = f"https://files.rcsb.org/download/{pdb_code.lower()}.pdb"
    try:
        # The context manager closes the connection even if read() fails.
        with urlopen(url) as response:
            return response.read().decode("utf-8")
    except HTTPError as e:
        # HTTPError subclasses URLError, so it has to be handled first.
        print(f"HTTP Error {e.code}")
    except URLError as e:
        print(f"URL Error {e.args}")
    return None


def _fetch_af2(uniprot_id: str) -> Optional[str]:
    """Download a PDB file from https://alphafold.ebi.ac.uk/ as text.

    Parameters:
    - uniprot_id: Identifier; upper-cased before being put in the URL.

    Returns:
    - The response body decoded as UTF-8, or None when the request fails
      with an HTTPError or URLError (the error is printed, not raised).

    The URL requests the ``F1`` file at model version ``AF2_VERSION``.
    """
    url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id.upper()}-F1-model_v{AF2_VERSION}.pdb"
    try:
        # The context manager closes the connection even if read() fails.
        with urlopen(url) as response:
            return response.read().decode("utf-8")
    except HTTPError as e:
        # HTTPError subclasses URLError, so it has to be handled first.
        print(f"HTTP Error {e.code}")
    except URLError as e:
        print(f"URL Error {e.args}")
    return None
