import os
import subprocess
import time

import torch
from tqdm import tqdm
import concurrent.futures


def run(prot_a, prot_b, usalign_path="./RAG/USalign"):
    """Align two structure files with USalign and return sequences plus TM-score.

    Args:
        prot_a: Path to the first structure file.
        prot_b: Path to the second structure file.
        usalign_path: USalign executable to invoke. Defaults to the
            ``./RAG/USalign`` binary relative to the working directory.

    Returns:
        A ``(seq1, seq2, tm)`` tuple, or ``None``.

        ``seq1`` and ``seq2`` are the lines found 5 and 3 lines from the end
        of USalign's stdout. They are expected to be the two aligned
        sequences, with gaps as ``-``.

        ``tm`` is the larger of the first two ``TM-score=`` values printed.

        ``None`` is returned when either file is missing, USalign exits
        non-zero, the output has fewer than 5 lines, or no ``TM-score=``
        line is found.
    """
    if not (os.path.isfile(prot_a) and os.path.isfile(prot_b)):
        return None

    result = subprocess.run(
        [usalign_path, prot_a, prot_b],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return None

    lines = result.stdout.splitlines()
    # The aligned sequences are read at fixed offsets from the end of the
    # output; guard so truncated/unexpected output yields None instead of
    # an IndexError.
    if len(lines) < 5:
        return None
    seq1, seq2 = lines[-5], lines[-3]

    # Lines look like "TM-score= 0.53 (...)"; split() tolerates extra spaces.
    scores = [float(line.split()[1]) for line in lines if line.startswith("TM-score=")]
    if not scores:
        return None

    # Keep the larger of the first two reported scores. Slicing (rather than
    # indexing scores[1]) also handles output with a single TM-score line.
    return seq1, seq2, max(scores[:2])


def process_file(rag_f, rag_dir, data_dir, database_dir):
    """Collect structurally similar, non-leaking neighbours for one RAG file.

    ``os.path.join(rag_dir, rag_f)`` is read as whitespace-separated
    ``name1 name2 score`` lines; blank lines are ignored.

    A pair is aligned with :func:`run` only when its listed score is above
    0.5 and not within 1e-4 of 1.0 (treated as the same structure). The
    alignment is ``data_dir/<name1>.pdb`` against
    ``database_dir/<name2>.pdb``.

    An aligned pair is then dropped when:

    * the alignment fails;
    * either gap-free aligned sequence is a substring of the other;
    * the returned TM-score is not above 0.5.

    Args:
        rag_f: File name inside ``rag_dir``.
        rag_dir: Directory holding the RAG candidate lists.
        data_dir: Directory of query ``.pdb`` files.
        database_dir: Directory of database ``.pdb`` files.

    Returns:
        ``(rag_f, data)`` when at least one neighbour survives, else ``None``.
        ``data`` is a list of dicts with keys ``tm_score``, ``aligned1``,
        ``aligned2`` and ``seq``. ``seq`` is ``aligned1`` with surrounding
        whitespace and ``-`` removed.
    """
    with open(os.path.join(rag_dir, rag_f)) as f:
        lines = f.readlines()

    data = []
    for line in lines:
        fields = line.split()
        # Blank lines carry no pair; unpacking them into three names would
        # raise ValueError and abort processing of this file.
        if not fields:
            continue
        name1, name2, score = fields
        score = float(score)

        # Remove the same structure to avoid data leakage.
        if abs(score - 1.0) < 1e-4:
            continue

        if score > 0.5:
            gt_path = os.path.join(data_dir, name1 + ".pdb")
            rag_path = os.path.join(database_dir, name2 + ".pdb")

            run_result = run(gt_path, rag_path)
            if run_result is None:
                continue

            seq1, seq2, tm = run_result

            # Remove sub-sequences to avoid data leakage.
            seq = seq1.strip().replace("-", "")
            rag_seq = seq2.strip().replace("-", "")

            if seq in rag_seq or rag_seq in seq:
                continue

            if tm > 0.5:
                data.append({
                    "tm_score": tm,
                    "aligned1": seq1,
                    "aligned2": seq2,
                    "seq": seq,
                })

    if data:
        return rag_f, data
    return None


def run_cath():
    """Build and save retrieval-neighbour dictionaries for every CATH 4.2 split.

    Each of the ``train``, ``validation`` and ``test`` splits is processed
    in turn:

    * Every entry listed in ``./cath42/<split>/rag_seq/`` is passed to
      :func:`process_file` on a 64-worker thread pool. Query structures come
      from ``../data/cath_download_42/<split>/`` and database structures from
      ``../data/SwissProt/flatten/``.
    * Truthy results are gathered into a dict mapping the RAG file name to
      its neighbour list.
    * A summary line is printed. The average counts every listed file, hits
      or not.
    * The dict is written with ``torch.save`` to
      ``./cath42/<split>/data_dict_mul.pt``.

    An exception raised inside a worker resurfaces via ``Future.result()``.
    """
    worker_count = 64
    swissprot_dir = "../data/SwissProt/flatten/"

    for split in ("train", "validation", "test"):
        rag_dir = f"./cath42/{split}/rag_seq/"
        query_dir = f"../data/cath_download_42/{split}/"
        rag_files = os.listdir(rag_dir)
        neighbours_by_file = {}

        with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as pool:
            futures = [
                pool.submit(process_file, name, rag_dir, query_dir, swissprot_dir)
                for name in rag_files
            ]
            finished = concurrent.futures.as_completed(futures)
            for fut in tqdm(finished, total=len(rag_files), desc=f"Processing {split}"):
                outcome = fut.result()
                if not outcome:
                    continue
                file_name, neighbours = outcome
                neighbours_by_file[file_name] = neighbours

        if not rag_files:
            print("No files to process.")
        else:
            hit_count = len(neighbours_by_file)
            avg_neighbours = sum(map(len, neighbours_by_file.values())) / len(rag_files)
            print(f"Total hits: {hit_count}, Average neighbors per file: {avg_neighbours:.4f}")

        torch.save(neighbours_by_file, f"./cath42/{split}/data_dict_mul.pt")


if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse run/process_file)
    # does not launch the full pipeline over every split.
    # perf_counter is monotonic, so it is the right clock for elapsed time.
    start = time.perf_counter()
    run_cath()
    print(f"Time: {time.perf_counter() - start}")
