import os
import pickle
import numpy as np
import mdtraj as md
import py3Dmol

from myopenfold.np import protein


def _gradient_style(prop, lo, hi, show_sidechains, show_sphere):
    """Build a py3Dmol style dict colouring atoms by ``prop`` with a 'roygb' gradient.

    The 'cartoon' representation is always present; 'stick' and 'sphere' are
    added (in that order) when requested. Each representation gets its own
    freshly built dict so entries never alias one another.
    """
    representations = ['cartoon']
    if show_sidechains:
        representations.append('stick')
    if show_sphere:
        representations.append('sphere')
    return {
        rep: {'colorscheme': {'prop': prop, 'gradient': 'roygb', 'min': lo, 'max': hi}}
        for rep in representations
    }


def visualize_prot(prot=None, pdb=None, pdb_string=None,
                   show_sidechains=True,
                   b_factors=None,
                   show_sphere=False,
                   slice=(None, None)):
    """Render a protein in an 800x600 py3Dmol view with a 'roygb' colour gradient.

    The structure source is the first of these that is not None:
    ``prot``, then ``pdb``, then ``pdb_string``.

    Args:
        prot: Protein object, used as-is. It must expose ``atom_mask``, and
            ``protein.to_pdb`` must accept it.
        pdb: Path to a PDB file. Its text is parsed with
            ``protein.from_pdb_string``.
        pdb_string: PDB-format text, parsed with ``protein.from_pdb_string``.
        show_sidechains: Also apply the colour scheme to a 'stick'
            representation.
        b_factors: When given, colour by the PDB 'b' column. The gradient
            range is ``b_factors.min()`` .. ``b_factors.max()``. Only the range
            comes from this argument. The per-atom values are whatever B-factor
            column ``protein.to_pdb`` writes. NOTE(review): presumably
            ``prot.b_factors``; confirm callers keep the two in sync.
        show_sphere: Also apply the colour scheme to a 'sphere'
            representation.
        slice: Forwarded unchanged to ``protein.to_pdb``. It shadows the
            builtin ``slice``; the name is kept for keyword callers.

    Returns:
        Whatever ``view.zoomTo()`` returns (presumably the py3Dmol view;
        not confirmable from here).

    Raises:
        NotImplementedError: If none of ``prot``, ``pdb``, ``pdb_string`` is
            given.
        ValueError: If colouring by residue (``b_factors`` is None) and
            ``prot.atom_mask`` marks no atoms at all.
    """
    if prot is None:
        if pdb is not None:
            # Context manager so the file handle is closed promptly.
            with open(pdb) as handle:
                prot = protein.from_pdb_string(handle.read())
        elif pdb_string is not None:
            prot = protein.from_pdb_string(pdb_string)
        else:
            raise NotImplementedError(
                "visualize_prot requires one of `prot`, `pdb` or `pdb_string`")

    if b_factors is not None:
        style = _gradient_style('b', b_factors.min().item(), b_factors.max().item(),
                                show_sidechains, show_sphere)
    else:
        # Row indices of residues with at least one atom present in atom_mask.
        res_mask = np.where(np.any(prot.atom_mask, axis=1))[0]
        if res_mask.size == 0:
            raise ValueError(
                "prot.atom_mask marks no atoms; cannot derive a residue colour range")
        # NOTE(review): 'resi' reads residue numbers from the PDB text, whereas
        # these bounds are 0-based row indices into atom_mask. Confirm they line
        # up with the numbering protein.to_pdb writes (and with non-default `slice`).
        style = _gradient_style('resi', res_mask.min().item(), res_mask.max().item(),
                                show_sidechains, show_sphere)

    view = py3Dmol.view(width=800, height=600)
    view.addModelsAsFrames(protein.to_pdb(prot, slice=slice))
    # NOTE(review): looks like leftover debug output; kept to preserve behaviour.
    print(style)
    view.setStyle({'model': -1}, style)
    return view.zoomTo()