import random
import json
import jsonlines
import ray
import mmap
import time
import sys
from ray.util.queue import Queue
sys.path.append(".")
from utils.misc import execute
from utils.ray_tools import ProgressBar
from tqdm import tqdm
import pathlib
import glob
import subprocess
import os
import traceback
def split_list(_list, n):
    """Cut ``_list`` into ``n`` consecutive slices of equal nominal size.

    The nominal slice size is the list length divided by ``n``, rounded up.
    Exactly ``n`` slices are returned, so when the items run out the last
    slice is shorter and any slices after it are empty.

    Args:
        _list: Any sliceable sequence.
        n: Number of slices to produce.

    Returns:
        A list of ``n`` slices of ``_list``, in order.
    """
    size = 1 + (len(_list) - 1) // n
    pieces = []
    for index in range(n):
        start = index * size
        pieces.append(_list[start:start + size])
    return pieces

# CPUs requested from Ray for each ``process_jobs`` task.
N_CPU_PER_THREAD = 1
# Number of ``process_jobs`` tasks launched by the driver code below.
n_thread = 200

# Placeholder input locations; set them before running.
# PDBBind_dir is globbed with '*/' appended to find the instance directories.
PDBBind_dir = "/path/to/dir"
# MSA_dir is searched for "<pdb_id><chain_id>.fasta" files.
MSA_dir = "/path/to/dir"
# AF2DB_dir is searched for "<MSA_id>.pdb" structure files.
AF2DB_dir = "/path/to/dir"


def remove_gap_of_primary_sequence(primary_sequence, candidate_sequence):
    """Drop every alignment column where the primary sequence has a gap.

    Both arguments are rows of the same alignment. Each position at which
    ``primary_sequence`` is ``"-"`` is removed from both rows; gaps in
    ``candidate_sequence`` are left in place.

    Args:
        primary_sequence: Aligned reference row; ``"-"`` marks a gap.
        candidate_sequence: Aligned row of the same length.

    Returns:
        A tuple ``(primary_without_gap, candidate_without_gap)`` of two
        strings of equal length.

    Raises:
        AssertionError: If the two rows differ in length.
    """
    if len(primary_sequence) != len(candidate_sequence):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            "aligned sequences differ in length: "
            f"{len(primary_sequence)} != {len(candidate_sequence)}"
        )

    kept_columns = [
        (primary_char, candidate_char)
        for primary_char, candidate_char in zip(primary_sequence, candidate_sequence)
        if primary_char != "-"
    ]
    # Join once instead of growing the strings one character at a time.
    primary_sequence_without_gap = "".join(p for p, _ in kept_columns)
    candidate_sequence_without_gap = "".join(c for _, c in kept_columns)
    return primary_sequence_without_gap, candidate_sequence_without_gap


@ray.remote(num_cpus=N_CPU_PER_THREAD)
def process_jobs(id,jobs_queue,actor):
    """Ray task that drains ``jobs_queue``, running each job it takes.

    A job that raises is logged with its traceback and does not stop the
    worker. The progress actor is notified once per job taken from the
    queue, whether that job succeeded or failed.

    Args:
        id: Worker index; used only in the start-up log line. (The name
            shadows the builtin but is kept for interface compatibility.)
        jobs_queue: Shared ``ray.util.queue.Queue``; each item is passed to
            ``execute_one_job``.
        actor: Progress actor handle exposing ``update.remote(n)``.

    Returns:
        1, once the queue has been found empty.
    """
    # Function-scope import so this block does not depend on new file-level
    # imports.
    from ray.util.queue import Empty

    print("start process",id)
    while True:
        try:
            # Non-blocking get. Checking ``empty()`` and then calling a
            # blocking ``get()`` lets another worker take the last item in
            # between, which would leave this worker waiting forever.
            job = jobs_queue.get(block=False)
        except Empty:
            break

        try:
            execute_one_job(job)
        except Exception:
            print(f"failed: {job}")
            traceback.print_exc()

        try:
            actor.update.remote(1)
        except Exception:
            # Progress reporting is best-effort: report the problem but keep
            # processing jobs.
            print(f"progress update failed for: {job}")
    return 1


def execute_one_job(PDBBind_instance_dir):
    """Run TMalign between one PDBBind pocket chain and its MSA-hit structures.

    Steps:
      1. Derive the PDB id from the directory path and the chain id from the
         last character of the first ``*.fasta`` file name found in it.
      2. Skip the job if the pocket-position file, the pocket-chain PDB file
         or the MSA file is missing.
      3. Collect ids from the MSA file: the last space-separated token of
         every even-numbered line, stopping once more than 200 distinct ids
         have been gathered.
      4. For each id with a ``<id>.pdb`` file under ``AF2DB_dir``, run
         ``TMalign <id>.pdb <chain pdb> -m rotation_matrix/<id>.txt`` and
         write both TM-scores and the gap-stripped aligned sequences to
         ``rotation_matrix/<id>_TMscore.txt``.

    Args:
        PDBBind_instance_dir: Instance directory path. It must end with
            ``"/"``: the PDB id is the second-to-last ``"/"``-separated
            component, and file names are appended to the path directly.

    Returns:
        Always 1, whether the job ran or was skipped for a missing input.

    Raises:
        subprocess.CalledProcessError: If TMalign exits with a non-zero status.
        IndexError, ValueError: If TMalign's output does not have the line
            layout assumed below.
        Exception: Whatever ``remove_gap_of_primary_sequence`` raises for
            aligned rows of unequal length.
    """
    print("#######################")
    pdb_id = PDBBind_instance_dir.split("/")[-2]

    # Only the FASTA file *name* is needed (for the chain id); its content is
    # never used, so the file is not read.
    fasta_files = glob.glob(PDBBind_instance_dir + '/*.fasta')
    if not fasta_files:
        return 1
    chain_id = fasta_files[0].split("/")[-1].split(".")[0][-1]

    # The pocket-position file must exist for the job to run, but its content
    # is not used here.
    pocket_position_file = PDBBind_instance_dir + pdb_id + chain_id + '_pocket_position.txt'
    if not os.path.exists(pocket_position_file):
        print("position_file not exist")
        return 1

    chain_pdb_file = PDBBind_instance_dir + pdb_id + '_pocket_chain.pdb'
    if not os.path.exists(chain_pdb_file):
        print("chain_pdb_file not exist")
        return 1

    MSA_file = MSA_dir + f"/{pdb_id}" + f"{chain_id}" + ".fasta"
    if not os.path.exists(MSA_file):
        print("MSA_file not exist")
        return 1

    # Stream the MSA file so the early exit below avoids loading all of it.
    # Even-numbered lines are treated as headers (two lines per record).
    MSA_ids = set()
    with open(MSA_file) as f:
        for line_number, line in enumerate(f):
            if line_number % 2:
                continue
            MSA_ids.add(line.strip().split(" ")[-1])
            if len(MSA_ids) > 200:
                break

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    rotation_matrix_dir = PDBBind_instance_dir + 'rotation_matrix/'
    os.makedirs(rotation_matrix_dir, exist_ok=True)

    for MSA_id in MSA_ids:
        MSA_pdb_file = AF2DB_dir + f"/{MSA_id}.pdb"
        if not os.path.exists(MSA_pdb_file):
            continue

        rotation_matrix_file = rotation_matrix_dir + f"{MSA_id}.txt"
        out_bytes = subprocess.check_output(
            ['TMalign', MSA_pdb_file, chain_pdb_file, "-m", rotation_matrix_file]
        )
        out_text = out_bytes.decode('utf-8').strip().split("\n")

        # Fixed line positions in TMalign's stdout after stripping:
        # 12 and 13 carry the two TM-scores, 17 and 19 the aligned rows of
        # the first (MSA) and second (pocket chain) structure.
        # NOTE(review): positions depend on the installed TMalign version.
        TMscore1 = float(out_text[12].split(" ")[1])
        TMscore2 = float(out_text[13].split(" ")[1])
        sequence_from_TMalign, MSA_aligned_sequence = remove_gap_of_primary_sequence(
            out_text[19], out_text[17]
        )

        TMalign_file = rotation_matrix_dir + f"{MSA_id}_TMscore.txt"
        with open(TMalign_file, "w") as f:
            f.write("TMscore normalized to chain_pdb:" + str(TMscore2) + "\n")
            f.write("TMscore normalized to MSA_pdb:" + str(TMscore1) + "\n")
            f.write("Aligned sequence : \n")
            f.write(sequence_from_TMalign + "\n")
            f.write(MSA_aligned_sequence + "\n")

    print("finish pdb_id: ", pdb_id)
    return 1






# Discover the PDBBind instance directories.
# NOTE(review): PDBBind_dir is concatenated directly with '*/', so it is
# expected to end with a path separator - confirm the configured value.
PDBBind_instance_dirs = glob.glob(PDBBind_dir + '*/')
print('Number of PDBBind instances: {}'.format(len(PDBBind_instance_dirs)))

# Keep only instances that have a FASTA file and a matching MSA file. The
# pocket-position and pocket-chain files are checked later, inside
# execute_one_job.
uncompleted_jobs=[]
for PDBBind_instance_dir in PDBBind_instance_dirs:
    pdb_id=PDBBind_instance_dir.split("/")[-2]
    fasta_dir=glob.glob(PDBBind_instance_dir + '/*.fasta')
    if len(fasta_dir)==0:
        continue
    chain_id=fasta_dir[0].split("/")[-1].split(".")[0][-1]

    MSA_file=MSA_dir+f"/{pdb_id}"+f"{chain_id}"+".fasta"
    if not os.path.exists(MSA_file):
        continue
    uncompleted_jobs.append(PDBBind_instance_dir)
print("uncompleted jobs:",len(uncompleted_jobs))

# Put every job on a shared Ray queue that the workers pull from.
job_queue = Queue()
for job in tqdm(uncompleted_jobs):
    job_queue.put(job)
print("job queue size:",job_queue.qsize())

# The progress bar's actor is handed to the workers, which call
# actor.update.remote(1) once per job.
pb = ProgressBar(len(uncompleted_jobs))
actor=pb.actor

# Launch the workers, show progress until done, then collect their results.
jop_id_list=[process_jobs.remote(i,job_queue,actor) for i in range(n_thread)]
pb.print_until_done()
result=ray.get(jop_id_list)
print("Done!")