import base64
import json
import os
import re
import shutil
import subprocess
import time

import requests
import torch
import yaml

class SingleQuoted(str):
    """Marker string type for values that should be emitted single-quoted in YAML.

    Instances behave exactly like ``str``; the subclass exists only so that a
    YAML representer can be attached to this specific type.
    """

def single_quoted_representer(dumper, data):
    """Represent *data* as a YAML string scalar using single-quote style.

    Intended for use as a PyYAML representer: *dumper* is the active dumper
    and the return value is whatever its ``represent_scalar`` produces.
    """
    str_tag = 'tag:yaml.org,2002:str'
    node = dumper.represent_scalar(str_tag, data, style="'")
    return node

# Register the representer with PyYAML (default Dumper, which yaml.dump uses
# unless told otherwise) so SingleQuoted values are dumped single-quoted,
# e.g. SMILES strings written into the Boltz input YAML.
yaml.add_representer(SingleQuoted, single_quoted_representer)
    
def calculate_boltz(protein_name, ligand):
    if protein_name == "c-met":
        protein_sequence = "HIDLSALNPELVQAVQHVVIGPSSLIVHFNEVIGRGHFGCVYHGTLLDNDGKKIHCAVKSLNRITDIGEVSQFLTEGIIMKDFSXPNVLSLLGICLRSEGSPLVVLPYMKHGDLRNFIRNETHNPTVKDLIGFGLQVAKGMKYLASKKFVXRDLAARNCMLDEKFTVKVAXFGLARDMYDKEYYSVXNKTGAKLPVKWMALESLQTQKFTTKSDVWSFGVLLWELMTRGAPPYPDVNTFDITVYLLQGRRLLQPEYCPDPLYEVMLKCWXPKAEMRPSFSELVSRISAIFSTFIG"
    elif protein_name == "brd4":
        protein_sequence = "SHMEQLKCCSGILKEMFAKKHAAYAWPFYKPVDVEALGLHDYCDIIKHPMDMSTIKSKLEAREYRDAQEFGADVRLMFSNCYKYNPPDHEVVAMARKLQDVFEMRFAKM"
    else:
        print("Uknown protein!")
        return
    
    msa_exists = os.path.isfile(f"./boltz_cache/msa/{protein_name}.csv")
    if msa_exists:
        data = {
            "version": 1,
            "sequences": [
                {
                    "protein": {
                        "id": "A",
                        "sequence": protein_sequence,
                        "msa": f"./boltz_cache/msa/{protein_name}.csv"
                    }
                },
                {
                    "ligand": {
                        "id": "B",
                        "smiles": SingleQuoted(ligand)
                    }
                }
            ],
            "properties": [
                {
                    "affinity": {
                        "binder": "B"
                    }
                }
            ]
        }
    else:
        data = {
            "version": 1,
            "sequences": [
                {
                    "protein": {
                        "id": "A",
                        "sequence": protein_sequence,
                    }
                },
                {
                    "ligand": {
                        "id": "B",
                        "smiles": SingleQuoted(ligand)
                    }
                }
            ],
            "properties": [
                {
                    "affinity": {
                        "binder": "B"
                    }
                }
            ]
        }
        
    try:
        print(protein_name)
        ligand = re.sub(r'[\\/:\*\?"<>\|]', '_', ligand)
        name = f"{protein_name}_{ligand}"
        output_file = f"./boltz_cache/inputs/{name}.yaml" 
        if os.path.isfile(output_file): os.remove(output_file)
        with open(output_file, "w") as outfile:
            yaml.dump(data, outfile, sort_keys=False)
        
        all_gpu_env = os.environ.copy()
        all_gpu_env['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7"
        gpu = subprocess.run("python3 ./main/utils/find_gpu.py".split(), env=all_gpu_env, capture_output=True, text=True).stdout
        gpu = gpu.replace('\n', '')
        print("\nBoltz running on GPU " + gpu, flush=True)
        new_env = os.environ.copy()
        new_env['CUDA_VISIBLE_DEVICES'] = gpu
        
        boltz_command = f"boltz predict {output_file} --use_msa_server --output_format pdb --out_dir ./boltz_cache/results"
        subprocess.run(boltz_command.split(), env=new_env)
        
        if not msa_exists:
            move_command = f"cp ./boltz_cache/results/boltz_results_{name}/msa/{name}_0.csv ./boltz_cache/msa/{protein_name}.csv"
            subprocess.run(move_command.split())
        
        if not os.path.isfile(f"./boltz_cache/results/boltz_results_{name}/predictions/{name}/affinity_{name}.json"):
            affinity = 0
        else:
            with open(f"./boltz_cache/results/boltz_results_{name}/predictions/{name}/affinity_{name}.json", "r") as file:
                result = json.load(file)
                affinity = result["affinity_pred_value"]
                affinity = (6 - float(affinity)) * -1.364
                affinity = round(affinity, 2)
        return affinity
    except Exception as e:
        print(e)
        return 0