import os
import time
import subprocess
import numpy as np
import pandas as pd
from typing import Any, List, Literal, Optional
from Bio import SeqIO

from src.constants import DATA_PATH


def cluster_sequence_identity(
    input_sequences,
    identity_threshold=0.8,
    word_size=4,
    input_file="input",
    output_file="output",
):
    """
    Cluster sequences based on sequence similarity using CD-HIT.

    Notes:
    - https://manpages.ubuntu.com/manpages/impish/man1/cd-hit-est.1.html
    - How to chose word size? https://github.com/weizhongli/cdhit/wiki/3.-User's-Guide#user-content-CDHITEST
       -n 10, 11 for thresholds 0.95 ~ 1.0
       -n 8,9    for thresholds 0.90 ~ 0.95
       -n 7      for thresholds 0.88 ~ 0.9
       -n 6      for thresholds 0.85 ~ 0.88
       -n 5      for thresholds 0.80 ~ 0.85
       -n 4      for thresholds 0.75 ~ 0.8
    """
    t0 = time.time()

    # Write input sequences to the temporary input file
    SeqIO.write(input_sequences, input_file, "fasta")

    # Run CD-HIT-EST
    cmd = [
        "cd-hit-est",
        "-i",
        input_file,
        "-o",
        output_file,
        "-c",
        str(identity_threshold),  # Sequence identity threshold (e.g., 90%)
        "-n",
        str(
            word_size
        ),  # Word size for sequence comparisson, larger is better (default: 2)
        "-M",
        str(0),  # Memory limit in MB (0 for unlimited)
        "-T",
        str(0),  # Number of threads (0 for using all CPUs)
    ]
    subprocess.run(cmd, check=True)

    # Read clustered sequences from the temporary output file
    # clustered_sequences = list(SeqIO.parse(output_file, "fasta"))

    # Process the clustering output
    seq_id_to_cluster = {}
    with open(output_file + ".clstr", "r") as f:
        current_cluster = None
        for line in f:
            if line.startswith(">"):
                current_cluster = int(line.strip().split(" ")[1])
            else:
                seq_id = line.split(">")[1].split("...")[0]
                seq_id_to_cluster[seq_id] = current_cluster

    # Delete temporary files
    os.remove(input_file)
    os.remove(output_file)
    os.remove(output_file + ".clstr")

    print(f"Total CPU time {(time.time() - t0)/60:.2f} m")

    return seq_id_to_cluster


def parse_qtmclust_cluster_file(file_path: str) -> List[List[Any]]:
    # Return a list of lists, where each inner list is a cluster of structures
    clusters = {}
    with open(file_path) as file:
        for line in file:
            columns = line.strip().split("\t")
            valid_columns = [
                col for col in columns if col
            ]  # filter out any empty columns
            # NOTE: the representative structure is the first (col=0) structure for a given cluster
            cluster_repr = valid_columns[0]
            clusters[cluster_repr] = valid_columns
    return list(clusters.values())


def run_qtmclust(
    chain_dir: str,
    chain_list_filepath: str,
    qtmclust_exec_path: str,
    output_cluster_filepath: Optional[str] = None,
    tm_cluster_threshold: float = 0.45,
    chain_ter_mode: Literal[0, 1, 2, 3] = 3,
    chain_split_mode: Literal[0, 1, 2] = 0,
) -> Optional[pd.DataFrame]:
    # Run qTMclust structural similarity clustering
    # For more information on `chain_ter_mode` and `chain_split_mode`, please see:
    # https://github.com/pylelab/USalign/blob/58b42af9d58436279c21b4f4074db87f072fcc21/qTMclust.cpp#L72
    # and
    # https://github.com/pylelab/USalign/blob/58b42af9d58436279c21b4f4074db87f072fcc21/qTMclust.cpp#L78
    cmd = [
        qtmclust_exec_path,
        "-dir",
        (chain_dir if chain_dir.endswith("/") else chain_dir + "/"),
        chain_list_filepath,
        "-TMcut",
        str(tm_cluster_threshold),
        "-ter",
        str(chain_ter_mode),
        "-split",
        str(chain_split_mode),
    ]
    if output_cluster_filepath is not None:
        cmd += ["-o", output_cluster_filepath]
    subprocess.run(" ".join(cmd), capture_output=True, shell=True)  # nosec
    if output_cluster_filepath is not None:
        output_clusters = parse_qtmclust_cluster_file(output_cluster_filepath)
        return output_clusters


def cluster_structure_similarity(
    input_pdb_files,
    similarity_threshold: float = 0.45,
    chain_list_filepath: str = "chain_list",
    output_cluster_filepath: str = "cluster.txt",
    chain_dir: str = os.path.join(DATA_PATH, "raw"),
    qtmclust_exec_path: str = "~/USalign/qTMclust",
):
    """
    Cluster structures based on their structural similarity using qTMclust.

    Credit: Alex Morehead

    Notes:
    - https://zhanggroup.org/US-align/
    - TM-score has values in (0,1] with 1 indicating an identical structure match,
      where a TM-score ≥0.5 (or 0.45) means the structures share the same global
      topology for proteins (or RNAs).
    """
    t0 = time.time()

    with open(chain_list_filepath, "w") as f:
        for pdb_file_index, pdb_file in enumerate(input_pdb_files):
            # record the name of each PDB file in a temporary text file input
            sample_name_without_extension = os.path.basename(
                os.path.splitext(pdb_file)[0]
            )
            sample_name_postfix = (
                "" if pdb_file_index == (len(input_pdb_files) - 1) else "\n"
            )
            f.write(f"{sample_name_without_extension}{sample_name_postfix}")

    clustered_structures = run_qtmclust(
        chain_dir=chain_dir,
        chain_list_filepath=chain_list_filepath,
        qtmclust_exec_path=qtmclust_exec_path,
        output_cluster_filepath=output_cluster_filepath,
        tm_cluster_threshold=similarity_threshold,  # note: clusters two chains if their TM-score is `similarity_threshold` or greater
        chain_ter_mode=0,  # note: reads all chains
        chain_split_mode=0,  # note: parses all chains in a complex as a single chain
    )

    # Delete temporary files
    os.remove(chain_list_filepath)
    os.remove(output_cluster_filepath)

    print(f"Total CPU time {(time.time() - t0)/60:.2f} m")

    return clustered_structures
