import pymol
from pymol import cmd
import pickle
import networkx as nx
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


"""
Nice camera views for proteins with corresponding pdb ids
"""
views = {}
views["1cc8_A"] = """\
     0.763499558,    0.309624940,   -0.566744685,\
    -0.299389571,   -0.607875705,   -0.735424161,\
    -0.572217107,    0.731172562,   -0.371415883,\
     0.000000004,    0.000004972,  -74.506790161,\
     0.361687273,    1.107958794,   -0.115405403,\
    51.589096069,   97.424530029,  -19.999998093 """
views["1id0_A"] = """\
    -0.254355341,    0.718892097,   -0.646913648,\
     0.966919482,    0.175733104,   -0.184888989,\
    -0.019231033,   -0.672540784,   -0.739809811,\
     0.000000000,    0.000000000, -124.400863647,\
     0.408094406,    0.228492737,    2.062408447,\
    88.815368652,  159.986389160,  -20.000000000 """
views["1l3p_A"] = """\
     0.402022630,    0.062729187,    0.913478553,\
     0.825484276,    0.406828821,   -0.391234815,\
    -0.396172196,    0.911347568,    0.111771576,\
     0.000000000,    0.000000000, -128.573135376,\
    -0.117021561,    0.512550354,    0.588439941,\
   103.617271423,  153.528976440,  -20.000000000 """
views["2eaq_A"] = """\
     0.478335142,    0.870104671,   -0.118802272,\
    -0.375687540,    0.325029463,    0.867880940,\
     0.793760002,   -0.370503217,    0.482359707,\
     0.000000551,    0.000002749,  -96.189186096,\
    -0.231350377,   -1.982213736,    1.682779193,\
    70.790626526,  121.587432861,  -20.000000000 """


def put_H0_graph_on_pdb(pdb_id,
                        G,
                        con,
                        layer_id):
    """Draw one H0 graph on its protein structure in PyMOL and save a PNG.

    The PyMOL session is cleared and ``./ConSuf10k_graphs/pdbs/<pdb_id>.pdb``
    is loaded as object ``prot`` (hetero atoms removed). Residues are then
    coloured by ``con`` through their B-factors. Each edge of ``G`` is drawn
    as a distance line between the CA atoms of its two residues. Each edge
    endpoint is shown as a CA sphere sized by its neighbour count.

    Parameters
    ----------
    pdb_id : str
        Structure identifier such as ``"1cc8_A"``. It is used for the input
        path, the camera lookup in ``views`` and the output filename.
    G : networkx graph
        A single H0 graph. Node ``i`` is drawn on residue ``resi i+1``.
    con : sequence
        Per-residue labels. ``con[i]`` is written to the B-factor of
        ``resi i+1`` and drives the final ``spectrum b`` colouring.
    layer_id : int
        Layer index. It is only used in the output filename.

    Writes ``./ConSuf10k_graphs/pngs/<pdb_id>_<layer_id>.png``.
    """
    pdb_path = f"./ConSuf10k_graphs/pdbs/{pdb_id}.pdb"

    # Interactive key bindings: F1 zooms (animated) on everything within
    # 5.0 of the current selection; F2 zooms back out to the whole scene.
    cmd.set_key('F1', cmd.zoom, ['all within 5.0 of (sele)'], {'animate': 1})
    cmd.set_key('F2', cmd.zoom, [], {'animate': 1})

    ### delete everything and load protein
    cmd.do("delete *")
    cmd.do(f"load {pdb_path}, prot")
    cmd.do("remove heta")
    cmd.do("color gray, elem C")
    cmd.do("set ray_shadows,off")

    ### change B factor according to input labels
    # NOTE(review): assumes residues are numbered 1..len(con) without gaps,
    # since label i is written to ``resi i+1``.
    for resi, label in enumerate(con, start=1):
        cmd.do(f"alter resi {resi}, b={label}")

    ### plot graph in PyMOL
    # One distance line per edge. Every endpoint is recorded (a dict used
    # as an insertion-ordered set) so that each node's sphere is configured
    # once, rather than once per incident edge.
    endpoints = {}
    for i1, i2 in G.edges():
        cmd.do(f'distance resi {i1 + 1} & name CA, resi {i2 + 1} & name CA')
        endpoints[i1] = None
        endpoints[i2] = None

    for node in endpoints:
        ca_atom = f'resi {node + 1} & name CA'
        # Sphere size grows with log10 of the neighbour count. It is clamped
        # to 0.1 so low-degree nodes stay visible; a zero count (possible for
        # directed graphs) skips log10(0) and uses the minimum directly.
        degree = len(G[node])
        scale = max(0.5 * np.log10(degree), 0.1) if degree else 0.1
        cmd.do(f'show spheres, {ca_atom}')
        cmd.do(f'set sphere_scale, {scale}, {ca_atom}')

    # Colour all distance objects (matched by ``dist*``) once, after they
    # have all been created. This is skipped when the graph has no edges.
    if endpoints:
        cmd.do('color black, dist*')

    ### color protein based on conservation/labels
    cmd.do('spectrum b')

    cmd.do("bg_color white")
    cmd.do('hide labels')

    cmd.do("set cartoon_transparency, 0.5")

    ### change the camera view (only proteins with a hand-tuned view)
    if pdb_id in views:
        cmd.set_view(views[pdb_id])

    ### save images to folder
    Path("./ConSuf10k_graphs/pngs").mkdir(parents=True, exist_ok=True)
    cmd.do(f'png ./ConSuf10k_graphs/pngs/{pdb_id}_{layer_id}.png, dpi=300, width=1500')



def combine_figs(pdb_id,
                 cut=0.05,
                 show=False,
                 layer_ids=(0, 4, 8, 12, 16, 20, 24, 31),
                 ncols=4):
    """Tile the per-layer PNGs of one protein into a single figure and save it.

    For each layer in ``layer_ids`` this reads
    ``./ConSuf10k_graphs/pngs/<pdb_id>_<layer_id>.png`` (written by
    ``put_H0_graph_on_pdb``). It trims a fraction ``cut`` of the width from
    both the left and the right edge, and places the images row by row on a
    grid with ``ncols`` columns. Each panel is titled ``"Layer <layer_id + 1>"``.

    Parameters
    ----------
    pdb_id : str
        Structure identifier used in the input and output filenames.
    cut : float
        Fraction of the image width removed from EACH side. It must satisfy
        ``0 <= cut < 0.5``; ``cut=0`` keeps the full image.
    show : bool
        If True, call ``plt.show()`` after saving.
    layer_ids : iterable of int
        Layers to include, in display order. The default reproduces the
        original 2 x 4 grid.
    ncols : int
        Number of grid columns.

    Writes ``./ConSuf10k_graphs/combined_png/pdb_<pdb_id>.png``.

    Raises
    ------
    ValueError
        If ``cut`` is outside ``[0, 0.5)``.
    """
    if not 0 <= cut < 0.5:
        raise ValueError(f"cut must be in [0, 0.5), got {cut!r}")

    layer_ids = list(layer_ids)
    nrows = max(1, -(-len(layer_ids) // ncols))  # ceil division
    # squeeze=False keeps ``axes`` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                             squeeze=False)

    for ax, layer_id in zip(axes.flat, layer_ids):
        img = mpimg.imread(f"./ConSuf10k_graphs/pngs/{pdb_id}_{layer_id}.png")
        width = img.shape[1]  # works for both grayscale (H, W) and colour (H, W, C)
        crop_width = int(cut * width)
        # Use an explicit end index: ``img[:, 0:-0]`` would be empty when
        # crop_width is 0.
        ax.imshow(img[:, crop_width:width - crop_width])
        ax.set_title(f"Layer {int(layer_id) + 1}")

    # Hide the axes of every grid cell, including any left unused.
    for ax in axes.flat:
        ax.axis('off')

    Path("./ConSuf10k_graphs/combined_png/").mkdir(parents=True, exist_ok=True)
    fig.savefig(f"./ConSuf10k_graphs/combined_png/pdb_{pdb_id}.png", dpi=300)
    if show:
        plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)



    
def main():
    """Render every selected H0 layer graph for each protein, then tile them.

    For each protein, ``./ConSuf10k_graphs/graphs/<pdb_id>.pkl`` is loaded
    once. Its ``"graphs"`` entry is indexed by layer id and its
    ``"labels_int"`` entry supplies the per-residue labels. One PNG is
    rendered per layer, and afterwards ``combine_figs`` builds the combined
    figure for each protein.
    """
    pdb_ids = ["1cc8_A", "2eaq_A", "1id0_A", "1l3p_A"]
    # NOTE: layer 28 is rendered but is not part of combine_figs' default
    # 8-panel grid.
    layer_ids = [0, 4, 8, 12, 16, 20, 24, 28, 31]

    for pdb_id in pdb_ids:
        # pickle.load can execute arbitrary code; only load trusted,
        # locally generated files here.
        with open(f"./ConSuf10k_graphs/graphs/{pdb_id}.pkl", 'rb') as fh:
            precalculated_data = pickle.load(fh)

        ### precalculated H0 graphs, indexed by layer id
        graphs = precalculated_data["graphs"]

        ### conservation labels for ConSuf10k dataset
        labels = precalculated_data["labels_int"]

        for layer_id in layer_ids:
            put_H0_graph_on_pdb(pdb_id,
                                G=graphs[layer_id],
                                con=labels,
                                layer_id=layer_id)

    for pdb_id in pdb_ids:
        combine_figs(pdb_id)

if __name__ == "__main__":
    main()
