import random
import mdgen.residue_constants as rc
import os


# ---- Input locations ----
input_file = "splits/mdCATH.txt"  # Domain list to split; replace with your own file
error_domains_file = "./splits/erroneous_domains.txt"  # Domains to leave out of every split
pdb_dir = "./topology"  # Directory searched for "<domain>.pdb" files

# ---- Output locations (one CSV per split) ----
train_file = "splits/mdCATH_train.csv"
val_file = "splits/mdCATH_val.csv"
test_file = "splits/mdCATH_test.csv"

# ---- Split proportions (expected to add up to 1.0) ----
train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

# ---- Reproducibility: fixed seed, applied as soon as the module loads ----
random_seed = 42
random.seed(random_seed)

def three_to_one(res_name: str) -> str:
    """
    Map a 3-letter residue name onto its 1-letter code.

    The name is upper-cased before it is looked up in rc.restype_3to1,
    so the lookup is case-insensitive. Names missing from that table
    are reported as 'X'.
    """
    lookup_key = res_name.upper()
    one_letter = rc.restype_3to1.get(lookup_key, 'X')
    return one_letter

def parse_atom_sequence(pdb_path: str) -> str:
    """
    Build a one-letter sequence from the ATOM/HETATM lines of a PDB file.

    Each residue is identified by its residue number (columns 23-26) together
    with its insertion code (column 27), so residues such as 52, 52A and 52B
    are kept as separate entries instead of overwriting each other. Residues
    are emitted in ascending (residue number, insertion code) order.

    The chain ID is not part of the residue key: lines from different chains
    that share a residue number and insertion code collapse into one residue,
    and the last one seen wins.

    Args:
        pdb_path: Path to the PDB file.

    Returns:
        The sequence as a string of one-letter codes ('X' for names that
        three_to_one does not recognise). An empty string is returned when
        pdb_path is not an existing regular file or when no residue could
        be parsed from it.
    """
    if not os.path.isfile(pdb_path):
        return ""

    # Histidine variant names that are folded into plain HIS.
    histidine_variants = {"HSD", "HSE", "HSP"}

    # (residue_number, insertion_code) -> residue_name
    # A residue spans many ATOM lines; the last name seen for a key is kept
    # (usually they all match).
    residues = {}

    with open(pdb_path, "r") as f:
        for line in f:
            if not line.startswith(("ATOM", "HETATM")):
                continue

            try:
                res_num = int(line[22:26])
            except ValueError:
                # Blank, truncated or non-numeric residue number: skip the line
                continue

            # Upper-case first so the histidine check is case-insensitive,
            # matching the case-insensitive lookup done by three_to_one.
            res_name = line[17:20].strip().upper()
            if res_name in histidine_variants:
                res_name = "HIS"

            # Empty when the column is blank or the line ends before it.
            insertion_code = line[26:27].strip()

            residues[(res_num, insertion_code)] = res_name

    # Tuples sort by number first, then insertion code ('' sorts before 'A').
    return "".join(three_to_one(residues[key]) for key in sorted(residues))

def main():
    """
    Split the domain list into train/val/test CSV files.

    Domain names are read from input_file, those listed in
    error_domains_file are dropped, and the rest are shuffled with
    random_seed. Each split is written as a "name,seqres" CSV, where
    seqres comes from parsing "<domain>.pdb" inside pdb_dir.
    """
    # Print the residue lookup table, then validate split ratios
    print(rc.restype_3to1)
    ratio_total = train_ratio + val_ratio + test_ratio
    assert abs(ratio_total - 1.0) < 1e-8, "Ratios must sum to 1.0"

    # Gather the domains to exclude (blank lines are ignored)
    error_domains = set()
    with open(error_domains_file, "r") as f:
        for raw_line in f:
            name = raw_line.strip()
            if name:
                error_domains.add(name)

    # Gather the remaining domains, keeping file order
    all_domains = []
    with open(input_file, "r") as f:
        for raw_line in f:
            name = raw_line.strip()
            if name and name not in error_domains:
                all_domains.append(name)

    # Re-seed right before shuffling so the resulting order is reproducible
    random.seed(random_seed)
    random.shuffle(all_domains)

    # Train and val sizes are truncated to ints; test takes whatever is left
    total = len(all_domains)
    n_train = int(train_ratio * total)
    n_val = int(val_ratio * total)
    splits = [
        (train_file, all_domains[:n_train]),
        (val_file, all_domains[n_train:n_train + n_val]),
        (test_file, all_domains[n_train + n_val:]),
    ]

    # Write one CSV per split and collect the summary lines as we go
    summary = ["Dataset split into:"]
    for csv_path, names in splits:
        with open(csv_path, "w") as out:
            out.write("name,seqres\n")
            for name in names:
                seq = parse_atom_sequence(os.path.join(pdb_dir, f"{name}.pdb"))
                out.write(f"{name},{seq}\n")
        summary.append(f"- {csv_path}: {len(names)} entries")

    print("\n".join(summary))

# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
