import os
import torch
import torch.multiprocessing as mp
from transformers import EsmForProteinFolding

# --- CONFIGURATION ---
# Model name or path handed to EsmForProteinFolding.from_pretrained by each worker.
MODEL_PATH = "facebook/esmfold_v1"
# Root directory that is walked recursively for files ending in ".fasta".
PARENT_DIR = "./Subset"  # folder containing subfolders with .fasta files
# Root directory for outputs; workers write to <STRUCTURES_DIR>/<dataset_name>/seq<N>.pdb.
STRUCTURES_DIR = "Structures"  # root folder for outputs
# Runs whenever this module is imported or executed, so the output root exists up front.
os.makedirs(STRUCTURES_DIR, exist_ok=True)


# --- FASTA READER ---
def read_fasta_sequences(fasta_path):
    """Read every sequence from a FASTA file.

    Lines starting with ">" are record headers: they delimit records and
    their text is discarded. Every other non-blank line is stripped of
    surrounding whitespace and concatenated into the current record.

    Blank (or whitespace-only) lines are ignored wherever they occur, and a
    header that is not followed by any sequence line produces no entry, so
    the returned list never contains empty strings.

    Args:
        fasta_path: Path of the FASTA file to read.

    Returns:
        A list of sequence strings, in file order.
    """
    sequences, current_seq = [], []
    with open(fasta_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # A blank line must not count as sequence data; otherwise a
                # leading blank line would emit a spurious "" record and
                # shift the index of every following sequence.
                continue
            if line.startswith(">"):
                # A new header closes the record collected so far (if any).
                if current_seq:
                    sequences.append("".join(current_seq))
                    current_seq = []
            else:
                current_seq.append(line)
    # Flush the final record, which has no header after it to close it.
    if current_seq:
        sequences.append("".join(current_seq))
    return sequences


# --- WORKER FUNCTION ---
def fold_worker(gpu_id, gpu_tasks):
    """Fold the sequences assigned to one GPU and write each result as a PDB file.

    Args:
        gpu_id: Index of the CUDA device to use; also the index into
            ``gpu_tasks`` that selects this worker's share of the work.
        gpu_tasks: List of per-GPU task lists. Each task is a
            ``(dataset_name, seq_id, seq)`` tuple.

    Each prediction is written to
    ``<STRUCTURES_DIR>/<dataset_name>/seq<seq_id>.pdb``. A failure on one
    sequence is reported on stdout and does not stop the remaining tasks.
    """
    device = torch.device(f"cuda:{gpu_id}")
    print(f"[GPU {gpu_id}] Loading model on {device}")

    # Load the weights once per worker, switch to eval mode, move to this GPU.
    model = EsmForProteinFolding.from_pretrained(MODEL_PATH).eval().to(device)

    for dataset_name, seq_id, seq in gpu_tasks[gpu_id]:
        try:
            # Inference only: no autograd bookkeeping needed.
            with torch.no_grad():
                pdb_text = model.infer_pdb(seq)

            out_dir = os.path.join(STRUCTURES_DIR, dataset_name)
            os.makedirs(out_dir, exist_ok=True)
            save_path = os.path.join(out_dir, f"seq{seq_id}.pdb")
            with open(save_path, "w") as handle:
                handle.write(pdb_text)

            print(f"[GPU {gpu_id}] [✓] Saved: {save_path}")
        except Exception as err:
            # Best-effort batch job: log the failure and move on to the next task.
            print(f"[GPU {gpu_id}] [❌] Failed for {dataset_name}, seq{seq_id}: {err}")


# --- MAIN ---
if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No CUDA devices found. This script requires at least 1 GPU.")

    print(f"[INFO] Found {num_gpus} GPUs")

    # Build one (dataset_name, 1-based index, sequence) task per sequence found
    # in any ".fasta" file under PARENT_DIR.
    all_tasks = []
    dataset_sources = {}  # dataset_name -> first FASTA path that produced it
    for subdir, dirs, files in os.walk(PARENT_DIR):
        # os.walk yields entries in arbitrary order; sorting (in place for
        # dirs, so the traversal itself is ordered) makes the task list and
        # therefore the per-GPU assignment reproducible between runs.
        dirs.sort()
        for fasta_file in sorted(files):
            if not fasta_file.endswith(".fasta"):
                continue
            fasta_path = os.path.join(subdir, fasta_file)

            dataset_name = os.path.basename(subdir)
            sequences = read_fasta_sequences(fasta_path)
            print(f"[INFO] Found {len(sequences)} sequence(s) in dataset '{dataset_name}'")

            # Output files are named only by dataset_name and sequence index,
            # so two FASTA files mapping to the same dataset_name collide.
            # Warn instead of renaming to keep the output layout unchanged.
            first_source = dataset_sources.setdefault(dataset_name, fasta_path)
            if first_source != fasta_path:
                print(
                    f"[WARN] Dataset '{dataset_name}' already has sequences from "
                    f"'{first_source}'; outputs for '{fasta_path}' may overwrite them"
                )

            for i, seq in enumerate(sequences, start=1):
                all_tasks.append((dataset_name, i, seq))

    if not all_tasks:
        # Nothing to do: avoid spawning workers that would each load the
        # model only to exit immediately.
        print(f"[INFO] No sequences found under '{PARENT_DIR}'; nothing to fold.")
    else:
        # Never start more workers than there are tasks, so no GPU loads the
        # model without having work. The strided slices give the same
        # round-robin assignment as appending task i to list i % num_workers.
        num_workers = min(num_gpus, len(all_tasks))
        gpu_tasks = [all_tasks[i::num_workers] for i in range(num_workers)]

        # mp.spawn calls fold_worker(rank, gpu_tasks) for rank in range(nprocs).
        mp.spawn(fold_worker, args=(gpu_tasks,), nprocs=num_workers, join=True)
