"""
Visualize molecules with their pharmacophores, surface point clouds and electrostatics using py3Dmol.
"""
from typing import Optional

import numpy as np
import py3Dmol
from matplotlib.colors import to_hex
from rdkit import Chem
# drawing
from rdkit.Chem.Draw import IPythonConsole

from .pharmacophore import feature_colors, get_pharmacophores_dict


def __draw_arrow(view, color, anchor_pos, rel_unit_vec, flip: bool = False, length: float = 2.):
    """
    Add an arrow to ``view`` that starts at ``anchor_pos`` and points along ``rel_unit_vec``.

    Parameters
    ----------
    view : py3Dmol.view
        Viewer the arrow is added to (via ``view.addArrow``).
    color
        Any colour accepted by ``matplotlib.colors.to_hex``.
    anchor_pos : sequence of 3 numbers
        Start point of the arrow (x, y, z).
    rel_unit_vec : sequence of 3 numbers
        Arrow direction. The end point is ``anchor_pos + length * rel_unit_vec``,
        so the drawn arrow has length ``length`` only if this vector has unit norm.
    flip : bool
        If True, the arrow points along ``-rel_unit_vec`` instead.
    length : float
        Scale applied to ``rel_unit_vec``; defaults to 2, the previously fixed value.
    """
    keys = ('x', 'y', 'z')
    sign = -1. if flip else 1.
    # Coerce to built-in floats so the spec handed to py3Dmol/JavaScript holds plain
    # Python numbers (numpy float32 scalars, for instance, are not JSON-serialisable).
    start = {k: float(anchor_pos[i]) for i, k in enumerate(keys)}
    end = {k: start[k] + sign * length * float(rel_unit_vec[i]) for i, k in enumerate(keys)}

    view.addArrow({
        'start': start,
        'end': end,
        'radius': .1,
        'radiusRatio': 2.5,
        'mid': 0.6,
        'color': to_hex(color),
    })


def _esp_to_rgb(esp) -> np.ndarray:
    """
    Map per-point electrostatic potential values onto RGB colours.

    Positive values go to the blue channel and negative values to the red channel,
    both scaled by the largest absolute ESP value so the strongest point gets full
    intensity. An empty, all-zero or NaN-dominated ESP yields all-zero rows instead
    of NaN colours (which ``to_hex`` would reject).

    Parameters
    ----------
    esp : array-like
        One ESP value per point, as an (N,) or (N, 1) array.

    Returns
    -------
    np.ndarray
        (N, 3) array of RGB values in [0, 1].
    """
    values = np.asarray(esp, dtype=float).reshape(-1)
    colors = np.zeros((len(values), 3))
    scale = np.abs(values).max(initial=0.)  # `initial` keeps empty input from raising
    if scale > 0:
        colors[:, 2] = np.clip(values, 0, None) / scale   # blue: positive ESP
        colors[:, 0] = np.clip(-values, 0, None) / scale  # red: negative ESP
    return colors


def draw_general(mol: Chem.rdchem.Mol,
                 feats: Optional[dict] = None,
                 point_cloud = None,
                 esp = None,
                 add_SAS = False,
                 view = None,
                 confId = -1,
                 removeHs = False,
                 width = 800,
                 height = 400):
    """
    Draw a molecule with optional pharmacophore features, a point cloud coloured by
    electrostatic potential, and a solvent accessible surface.

    Parameters
    ----------
    mol : Chem.rdchem.Mol
        Molecule to render.
    feats : dict, optional
        Pharmacophore features keyed by family name. Each value must hold positions
        under ``'P'``; the families ``'Aromatic'``, ``'Donor'`` and ``'Acceptor'`` also
        need direction vectors under ``'V'`` (indexed in parallel with ``'P'``).
        Families missing from ``feature_colors`` are drawn grey.
    point_cloud : sequence of (x, y, z), optional
        Points drawn as small translucent spheres.
    esp : array-like, optional
        One electrostatic potential value per point in ``point_cloud``. Positive
        values are coloured blue, negative red, and points whose colour norm is
        below 0.3 white. Without ``esp`` the points are drawn black.
    add_SAS : bool
        If True, add a translucent solvent accessible surface.
    view : py3Dmol.view, optional
        Existing viewer to draw into (its models are removed first). A new
        ``width`` x ``height`` viewer is created when omitted.
    confId : int
        Conformer id forwarded to ``IPythonConsole.addMolToView``.
    removeHs : bool
        If True, hydrogens are removed before drawing.
    width, height : int
        Size of a newly created viewer.

    Returns
    -------
    The result of ``view.show()``.
    """
    esp_colors = None if esp is None else _esp_to_rgb(esp)

    if view is None:
        view = py3Dmol.view(width=width, height=height)
    view.removeAllModels()
    if removeHs:
        mol = Chem.RemoveHs(mol)
    IPythonConsole.addMolToView(mol, view, confId=confId)
    keys = ['x', 'y', 'z']

    if feats:
        for fam, fam_feats in feats.items():  # cycle through pharmacophore families
            clr = feature_colors.get(fam, (.5, .5, .5))
            hex_clr = to_hex(clr)
            for i, pos in enumerate(fam_feats['P']):
                view.addSphere({'center': {k: float(pos[j]) for j, k in enumerate(keys)},
                                'radius': .5,
                                'color': hex_clr})

                # Only these families get direction arrows (read from 'V').
                if fam not in ('Aromatic', 'Donor', 'Acceptor'):
                    continue

                vec = fam_feats['V'][i]
                __draw_arrow(view, clr, pos, vec, flip=False)
                if fam == 'Aromatic':
                    # Aromatic vectors are drawn both ways (presumably a ring normal,
                    # which has no preferred sign).
                    __draw_arrow(view, clr, pos, vec, flip=True)

    if point_cloud is not None:
        for i, pc in enumerate(point_cloud):
            if esp_colors is None:
                clr = (0., 0., 0.)  # no ESP supplied: black
            elif np.linalg.norm(esp_colors[i]) < 0.3:
                clr = (1., 1., 1.)  # weak potential: white
            else:
                clr = esp_colors[i]
            view.addSphere({'center': {'x': float(pc[0]), 'y': float(pc[1]), 'z': float(pc[2])},
                            'radius': .1,
                            'color': to_hex(clr),
                            'opacity': 0.5})

    if add_SAS:
        view.addSurface(py3Dmol.SAS, {'opacity': 0.5})
    view.zoomTo()
    # Return show()'s result rather than the view object (original note: "to save memory").
    return view.show()


def draw_pharmacophores(mol, view=None, width=800, height=400):
    """
    Compute the pharmacophore features of ``mol`` and draw them with ``draw_general``.

    Parameters
    ----------
    mol : Chem.rdchem.Mol
        Molecule whose features are computed by ``get_pharmacophores_dict``.
    view : py3Dmol.view, optional
        Existing viewer to draw into; a new one is created when omitted.
    width, height : int
        Size of a newly created viewer.

    Returns
    -------
    Whatever ``draw_general`` returns (previously this value was discarded).
    """
    return draw_general(mol,
                        feats=get_pharmacophores_dict(mol),
                        view=view,
                        width=width,
                        height=height)