import os
import igl
import scipy
import numpy as np
from tqdm import tqdm
import multiprocessing
from config import get_config
from pyquaternion import Quaternion
from functools import partialmethod
from sklearn.neighbors import BallTree
import geomlib.signatures as signatures
from core.chemistry import res_type_dict, hydrophob_dict
from geomlib.gradient_operator import compute_gradient_operator

class DataProcessor():
    """Turn raw ligand/receptor archives into per-complex feature files.

    Each raw ``.npz`` archive found in the DB5 and RCSB source directories is
    loaded, interface labels are derived from nearest-neighbour distances
    between the two surfaces, and chemical / geometric / spectral features are
    written to ``./processed_data/<DATASET>/<same file name>``.
    """

    def __init__(self, config):
        """Copy the settings used by preprocessing from ``config``.

        Args:
            config: object exposing the attributes read below (dataset
                directories, worker count, interface filter thresholds and
                feature options).
        """
        # dataset
        self.db5_data_dir = config.db5_data_dir
        self.rcsb_data_dir = config.rcsb_data_dir
        self.num_preprocess_workers = config.num_preprocess_workers

        # features
        self.apply_filter = config.apply_filter
        self.min_iface_size = config.min_iface_size
        self.max_iface_size = config.max_iface_size
        self.max_iface_ratio = config.max_iface_ratio
        self.iface_cutoff = config.iface_cutoff
        self.rotation = config.rotation
        self.num_signatures = config.num_signatures
        self.num_lb_basis = config.num_lb_basis
        self.smoothing = config.smoothing

    def preprocess(self):
        """Preprocess DB5 and RCSB unless ``./processed_data/`` already exists.

        NOTE: the existence of the output directory is the only "done" marker,
        so a run that was interrupted half-way is skipped on the next call
        until the directory is removed by hand.
        """
        self.processed_dir = './processed_data/'
        if os.path.isdir(self.processed_dir):
            return
        os.makedirs(self.processed_dir, exist_ok=False)
        self._preprocess_db5(parallel=True)
        self._preprocess_rcsb(parallel=True)
        print('preprocessing done!')

    def _preprocess_db5(self, parallel=True):
        """Preprocess every file of the DB5 directory (never filtered)."""
        self._preprocess_dataset('DB5', self.db5_data_dir, False,
                                 parallel, chunksize=1)

    def _preprocess_rcsb(self, parallel=True):
        """Preprocess every file of the RCSB directory, honouring ``apply_filter``."""
        self._preprocess_dataset('RCSB', self.rcsb_data_dir, self.apply_filter,
                                 parallel, chunksize=10)

    @staticmethod
    def _reseed_worker():
        """Give a pool worker its own ``np.random`` state.

        Workers created by fork start with a copy of the parent's global NumPy
        random state; without reseeding, different workers would draw the same
        sequence of "random" translations (and of anything else that relies on
        ``np.random``).
        """
        np.random.seed()

    def _preprocess_dataset(self, name, src_dir, apply_filter, parallel, chunksize):
        """Run ``_process_serial`` on every file of ``src_dir``.

        Args:
            name: sub-directory of ``self.processed_dir`` receiving the output;
                it must not exist yet.
            src_dir: directory holding the raw archives.
            apply_filter: forwarded to ``_process_serial``.
            parallel: use a process pool when true, a plain loop otherwise.
            chunksize: number of files handed to a worker at a time.
        """
        dst_dir = os.path.join(self.processed_dir, name)
        os.makedirs(dst_dir, exist_ok=False)
        fpaths = [os.path.join(src_dir, fname) for fname in os.listdir(src_dir)]

        if not parallel:
            for fpath in tqdm(fpaths):
                self._process_serial(fpath, dst_dir, apply_filter)
            return

        print(f'parallel preprocessing {name} with {self.num_preprocess_workers} workers..')
        pool_args = [(fpath, dst_dir, apply_filter) for fpath in fpaths]
        # The context manager tears the workers down even when a task raises.
        with multiprocessing.Pool(processes=self.num_preprocess_workers,
                                  initializer=self._reseed_worker) as pool:
            # The bar wraps the argument list, so it advances as arguments are
            # consumed by the pool rather than as results come back.
            pool.starmap(self._process_serial, tqdm(pool_args), chunksize=chunksize)
            pool.close()
            pool.join()

    def _num_basis(self, num_verts):
        """Return how many Laplace-Beltrami basis functions to keep.

        ``num_lb_basis > 1`` is taken as an absolute count, anything else as a
        fraction of ``num_verts``.
        """
        if self.num_lb_basis > 1:
            return int(self.num_lb_basis)
        return int(self.num_lb_basis * num_verts)

    def _load_surface(self, data, prefix):
        """Read the raw arrays of one molecule (``prefix`` is 'lig' or 'rec').

        Returns:
            ``(atom_info, atom_coords, verts, faces, eigen_vals, eigen_vecs,
            eigen_vecs_inv)`` where the spectral arrays are truncated to
            ``_num_basis(len(verts))`` entries.
        """
        atom_info = data[f'{prefix}_atom_info']
        atom_coords = atom_info[:, :3]  # first three columns are the coordinates
        verts = data[f'{prefix}_verts']
        faces = data[f'{prefix}_faces'].astype(int)
        num_basis = self._num_basis(len(verts))
        eigen_vals = data[f'{prefix}_eigen_vals'][:num_basis]
        eigen_vecs = data[f'{prefix}_eigen_vecs'][:, :num_basis]
        # stored as a 0-d object array wrapping the actual matrix
        mass = data[f'{prefix}_mass'].item()
        # presumably the inverse transform for a mass-orthonormal basis -- TODO confirm
        eigen_vecs_inv = eigen_vecs.T @ mass
        return atom_info, atom_coords, verts, faces, eigen_vals, eigen_vecs, eigen_vecs_inv

    def _iface_map(self, query_verts, ref_verts):
        """Label interface vertices of ``query_verts`` against ``ref_verts``.

        Returns:
            ``(iface, mapping)``: indices of the query vertices whose nearest
            reference vertex is closer than ``iface_cutoff``, and a two-column
            array pairing each of them with that nearest reference index.
        """
        dist, ind = BallTree(ref_verts).query(query_verts, k=1)
        iface = np.where(dist < self.iface_cutoff)[0]
        mapping = np.concatenate((iface.reshape(-1, 1), ind[iface]), axis=-1)
        return iface, mapping

    @staticmethod
    def _atom_features(atom_info):
        """Build per-atom chemical features from the raw atom table.

        Each row of ``atom_info`` is unpacked as three coordinates followed by
        residue code, atom type, charge, radius and alpha-carbon flag; the
        hydrophobicity of the residue is inserted after the atom type.
        """
        # Reverse lookup (dict value -> residue name). The two lists are built
        # once here instead of once per atom; first match wins, as with .index.
        res_names = list(res_type_dict.keys())
        res_codes = list(res_type_dict.values())
        feats = []
        for _, _, _, res_code, atom_type, charge, radius, is_alpha_carbon in atom_info:
            residue = res_names[res_codes.index(res_code)]
            feats.append([res_code, atom_type, hydrophob_dict[residue],
                          charge, radius, is_alpha_carbon])
        return np.array(feats, dtype=float)

    def _geom_features(self, verts, faces, eigen_vecs, eigen_vals):
        """Return per-vertex [Gaussian curvature, mean curvature, HKS...] columns."""
        _, _, k1, k2 = igl.principal_curvature(verts, faces)
        gauss_curvs = k1 * k2
        mean_curvs = 0.5 * (k1 + k2)
        hks = signatures.compute_HKS(eigen_vecs, eigen_vals, self.num_signatures)
        return np.concatenate((gauss_curvs.reshape(-1, 1),
                               mean_curvs.reshape(-1, 1),
                               hks), axis=-1)

    @staticmethod
    def _radius_edges(verts, radius=2):
        """List ``[i, j]`` for every vertex ``j`` within ``radius`` of vertex ``i``."""
        nbr_lists = BallTree(verts).query_radius(verts, r=radius)
        return [[idx, nbr] for idx, nbrs in enumerate(nbr_lists) for nbr in nbrs]

    def _process_serial(self, src, dst, apply_filter=False):
        """Preprocess one raw complex and save its features under ``dst``.

        Args:
            src: path of the raw ``.npz`` archive.
            dst: output directory; the file keeps the base name of ``src``.
            apply_filter: when true, complexes whose interface is too small,
                too large, or covers too large a share of either surface are
                skipped and nothing is written.

        Returns:
            None.
        """
        # load raw features
        # SECURITY: allow_pickle=True unpickles objects stored in the archive
        # (the mass matrices are read through .item()); only feed trusted files.
        with np.load(src, allow_pickle=True) as data:
            (lig_atom_info, lig_atom_coords, lig_verts, lig_faces,
             lig_eigen_vals, lig_eigen_vecs, lig_eigen_vecs_inv) = self._load_surface(data, 'lig')
            (rec_atom_info, rec_atom_coords, rec_verts, rec_faces,
             rec_eigen_vals, rec_eigen_vecs, rec_eigen_vecs_inv) = self._load_surface(data, 'rec')

        # ground truth interface labels and point-to-point correspondence,
        # computed on the raw (unsmoothed, unrotated) surfaces
        lig_iface, map_lig2rec = self._iface_map(lig_verts, rec_verts)
        rec_iface, map_rec2lig = self._iface_map(rec_verts, lig_verts)

        # apply additional filters
        if apply_filter and (
            min(len(lig_iface), len(rec_iface)) < self.min_iface_size
            or max(len(lig_iface), len(rec_iface)) > self.max_iface_size
            or max(len(lig_iface) / len(lig_verts),
                   len(rec_iface) / len(rec_verts)) > self.max_iface_ratio
        ):
            return

        if self.smoothing:
            # project the vertex positions onto the truncated eigenbasis
            lig_verts = lig_eigen_vecs @ (lig_eigen_vecs_inv @ lig_verts)
            rec_verts = rec_eigen_vecs @ (rec_eigen_vecs_inv @ rec_verts)

        # co-rotate the system
        co_rotation = Quaternion.random().rotation_matrix
        lig_atom_coords = lig_atom_coords @ co_rotation
        lig_verts = lig_verts @ co_rotation
        rec_atom_coords = rec_atom_coords @ co_rotation
        rec_verts = rec_verts @ co_rotation

        # random rotation + translation, which has no impact on the features
        if self.rotation:
            lig_rot = Quaternion.random().rotation_matrix
            lig_trans = np.random.uniform(-50, 50, (3,))
            lig_atom_coords = lig_atom_coords @ lig_rot + lig_trans
            lig_verts = lig_verts @ lig_rot + lig_trans

        # chemical features
        lig_atom_feats = self._atom_features(lig_atom_info)
        rec_atom_feats = self._atom_features(rec_atom_info)

        # radius-graph edges on the final vertex positions
        lig_edge = self._radius_edges(lig_verts)
        rec_edge = self._radius_edges(rec_verts)

        # coordinate-free geometric features (curvatures + geometric signatures)
        lig_geom_feats = self._geom_features(lig_verts, lig_faces,
                                             lig_eigen_vecs, lig_eigen_vals)
        rec_geom_feats = self._geom_features(rec_verts, rec_faces,
                                             rec_eigen_vecs, rec_eigen_vals)

        # Laplace-Beltrami basis: eigenvalues row, then eigenvectors, then the
        # transposed inverse, stacked along axis 0
        lig_eigs = np.concatenate((lig_eigen_vals.reshape(1, -1),
                                   lig_eigen_vecs,
                                   lig_eigen_vecs_inv.T), axis=0)
        rec_eigs = np.concatenate((rec_eigen_vals.reshape(1, -1),
                                   rec_eigen_vecs,
                                   rec_eigen_vecs_inv.T), axis=0)

        # vertex normals
        lig_vnormals = igl.per_vertex_normals(lig_verts, lig_faces)
        rec_vnormals = igl.per_vertex_normals(rec_verts, rec_faces)

        # compute gradient operator
        lig_grad_op, lig_grad_basis = compute_gradient_operator(lig_verts, lig_faces, lig_vnormals)
        rec_grad_op, rec_grad_basis = compute_gradient_operator(rec_verts, rec_faces, rec_vnormals)

        # features for model; basename() keeps the file name on every platform
        out_fpath = os.path.join(dst, os.path.basename(src))
        np.savez(out_fpath,
                 # ligand
                 map_lig2rec=map_lig2rec.astype(np.float32),
                 lig_atom_coords=lig_atom_coords.astype(np.float32),
                 lig_atom_feats=lig_atom_feats.astype(np.float32),
                 lig_verts=lig_verts.astype(np.float32),
                 lig_edge=lig_edge,
                 lig_vnormals=lig_vnormals.astype(np.float32),
                 lig_geom_feats=lig_geom_feats.astype(np.float32),
                 lig_eigs=lig_eigs.astype(np.float32),
                 lig_grad_op=lig_grad_op.astype(np.csingle),
                 lig_grad_basis=lig_grad_basis.astype(np.float32),
                 # receptor
                 map_rec2lig=map_rec2lig.astype(np.float32),
                 rec_atom_coords=rec_atom_coords.astype(np.float32),
                 rec_atom_feats=rec_atom_feats.astype(np.float32),
                 rec_verts=rec_verts.astype(np.float32),
                 rec_edge=rec_edge,
                 rec_vnormals=rec_vnormals.astype(np.float32),
                 rec_geom_feats=rec_geom_feats.astype(np.float32),
                 rec_eigs=rec_eigs.astype(np.float32),
                 rec_grad_op=rec_grad_op.astype(np.csingle),
                 rec_grad_basis=rec_grad_basis.astype(np.float32))
           

if __name__ == "__main__":
    # Script entry point: read the configuration and run preprocessing once.
    cfg = get_config()

    # Progress bars stay silent unless the configuration asks for them:
    # every tqdm instance is then constructed with disable=True.
    if not cfg.unmute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    DataProcessor(cfg).preprocess()


