import os
import torch
import numpy as np
import horovod.torch as hvd
from sklearn.neighbors import BallTree
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.distributed import DistributedSampler

def _collate_side(samples):
    """Merge the per-sample feature dicts of one side (ligand or receptor).

    'vnormals', 'chem_feats', 'geom_feats' and 'grad_basis' are
    concatenated along dim 0. 'iface_p2p', 'eigs' and 'grad_op' are kept as
    per-sample lists. Each sample's 'edge' tensor is shifted by the number
    of vertices that precede it in the batch, concatenated along dim 0 and
    then transposed.
    """
    # Vertex count of a sample is the leading dimension of its geom_feats.
    vert_counts = [sample['geom_feats'].size(0) for sample in samples]

    # Shift edge indices so they address rows of the concatenated tensors.
    shifted_edges = []
    offset = 0
    for sample, count in zip(samples, vert_counts):
        shifted_edges.append(sample['edge'] + offset)
        offset += count

    def as_list(key):
        return [sample[key] for sample in samples]

    def concatenated(key):
        return torch.cat(as_list(key), dim=0)

    return {
        'iface_p2p': as_list('iface_p2p'),
        'vnormals': concatenated('vnormals'),
        'chem_feats': concatenated('chem_feats'),
        'geom_feats': concatenated('geom_feats'),
        'edge': torch.cat(shifted_edges, dim=0).transpose(0, 1),
        'eigs': as_list('eigs'),
        'grad_op': as_list('grad_op'),
        'grad_basis': concatenated('grad_basis'),
        'num_verts': vert_counts,
    }


def collate_fnc(unbatched_list):
    """Collate (lig_dict, rec_dict, fpath) samples into a single batch.

    Args:
        unbatched_list: iterable of ``(lig_dict, rec_dict, fpath)`` triples
            as returned by the dataset's ``__getitem__``.

    Returns:
        dict with keys 'lig_dict' and 'rec_dict' (batched feature dicts,
        see ``_collate_side``) and 'batch_fpath' (list of file paths in
        sample order).
    """
    lig_samples = []
    rec_samples = []
    b_fpath = []
    # Single pass, so a one-shot iterable is consumed exactly once.
    for lig_dict, rec_dict, fpath in unbatched_list:
        lig_samples.append(lig_dict)
        rec_samples.append(rec_dict)
        b_fpath.append(fpath)

    return {
        'lig_dict': _collate_side(lig_samples),
        'rec_dict': _collate_side(rec_samples),
        'batch_fpath': b_fpath,
    }

def batch_to_device(batch):
    """Call ``.cuda()`` on every tensor held in a collated batch, in place.

    Args:
        batch: dict as produced by ``collate_fnc``. The 'batch_fpath' entry
            is left untouched; every other entry must be a dict of
            features.

    Returns:
        The same ``batch`` object. In each feature dict, list values are
        replaced by a list of ``.cuda()`` results and other values by their
        ``.cuda()`` result; 'num_verts' is left as is.

    Raises:
        AssertionError: if an entry other than 'batch_fpath' is not a dict.
    """
    for key, val in batch.items():
        # Exact match on purpose: `key in ('batch_fpath')` is a substring
        # test on a str (the parentheses do not make a tuple), so keys such
        # as 'fpath' would be skipped by accident.
        if key == 'batch_fpath':
            continue
        assert isinstance(val, dict)
        for ikey, ival in val.items():
            # Same reasoning: 'verts' or 'num' must not match 'num_verts'.
            if ikey == 'num_verts':
                continue
            # Re-assigning an existing key does not resize the dict, so it
            # is safe while iterating over items().
            if isinstance(ival, list):
                val[ikey] = [v.cuda() for v in ival]
            else:
                val[ikey] = ival.cuda()
    return batch


def get_dataloaders(config, dataset_name):
    """Build the train, validation and test data loaders.

    Args:
        config: object providing ``serial``, ``batch_size`` and
            ``num_data_workers`` (plus whatever ``PPDockDataset`` reads).
        dataset_name: forwarded to ``PPDockDataset``.

    Returns:
        ``(train_loader, valid_loader, test_loader, train_sampler)``.
    """
    # Datasets are created in train / valid / test order.
    datasets = [
        PPDockDataset(config, split=split_name, dataset_name=dataset_name)
        for split_name in ('train', 'valid', 'test')
    ]

    # Serial runs sample randomly over the file list; otherwise every
    # Horovod rank gets its own shard through a DistributedSampler.
    if config.serial:
        samplers = [RandomSampler(dataset.fpaths) for dataset in datasets]
    else:
        samplers = [
            DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
            for dataset in datasets
        ]

    train_loader, valid_loader, test_loader = [
        DataLoader(dataset,
                   batch_size=config.batch_size,
                   sampler=sampler,
                   collate_fn=collate_fnc,
                   num_workers=config.num_data_workers,
                   pin_memory=True)
        for dataset, sampler in zip(datasets, samplers)
    ]

    train_sampler = samplers[0]
    return train_loader, valid_loader, test_loader, train_sampler


class PPDockDataset(Dataset):
    """Dataset of preprocessed protein-protein complexes.

    Files are taken from ``./processed_data/<DATASET>/`` and kept when the
    first four characters of the file name are listed in the split file
    under ``./misc/``. Each item is a ``(lig_dict, rec_dict, fpath)``
    triple of per-surface feature dicts plus the source file path.
    """

    def __init__(self, config, split, dataset_name):
        """Select the files that belong to ``split`` of ``dataset_name``.

        Args:
            config: object providing ``cv``, ``swap_pairs``,
                ``vert_nbr_atoms`` and ``num_gdf``.
            split: split name. For 'rcsb', 'train' and 'valid' use the RCSB
                split files and any other value uses all of DB5.
            dataset_name: 'rcsb' or another name; any other name is read
                with the DB5 split files.

        Raises:
            FileNotFoundError: if ``./processed_data/`` does not exist.
        """
        # dataset
        self.cv = config.cv
        # Ligand/receptor swapping is never applied to the test split.
        self.swap_pairs = config.swap_pairs and split != 'test'
        # features
        self.vert_nbr_atoms = config.vert_nbr_atoms
        self.gauss_curv_gdf = GaussianDistance(start=-0.1, stop=0.1, num_centers=config.num_gdf)
        self.mean_curv_gdf = GaussianDistance(start=-0.5, stop=0.5, num_centers=config.num_gdf)
        self.dist_gdf = GaussianDistance(start=0., stop=8., num_centers=config.num_gdf)
        self.angular_gdf = GaussianDistance(start=-1., stop=1., num_centers=config.num_gdf)

        # prefilter and preprocess
        self.processed_dir = './processed_data/'
        if not os.path.isdir(self.processed_dir):
            # FileNotFoundError is still caught by `except Exception`.
            raise FileNotFoundError('Cannot find processed data, run data processor first')

        # train-valid-test split
        data_dir = os.path.join(self.processed_dir, dataset_name.upper())
        if dataset_name == 'rcsb':
            if split in ['train', 'valid']:
                split_file = f'./misc/rcsb_{split}_30_cv{self.cv}.txt'
            else:
                # RCSB is evaluated on the whole DB5 set, so only the DB5
                # directory is listed in this case.
                data_dir = os.path.join(self.processed_dir, 'DB5')
                split_file = './misc/db5_all.txt'
        else:
            split_file = f'./misc/db5_{split}_cv{self.cv}.txt'

        fnames = os.listdir(data_dir)
        with open(split_file, 'r') as f:
            # A set gives O(1) membership tests in the filter below.
            pdb_ids = {line.strip() for line in f}

        # The id is the first four characters of the file name. Slicing the
        # name itself (not the joined path) keeps this independent of the
        # platform's path separator.
        self.fpaths = [os.path.join(data_dir, fname)
                       for fname in fnames if fname[:4] in pdb_ids]
        print(f'size of {dataset_name.upper()} {split} set: {len(self.fpaths)}')

    def _surface_features(self, data, side, iface_map):
        """Build the feature dict of one surface ('lig' or 'rec').

        Args:
            data: mapping of arrays loaded from the processed file.
            side: key prefix, 'lig' or 'rec'.
            iface_map: point-to-point interface map from this side to the
                other side.

        Returns:
            dict of torch tensors with keys 'iface_p2p', 'vnormals',
            'chem_feats', 'geom_feats', 'edge', 'eigs', 'grad_op' and
            'grad_basis'.
        """
        verts = data[f'{side}_verts']
        vnormals = data[f'{side}_vnormals']
        atom_coords = data[f'{side}_atom_coords']
        atom_feats = data[f'{side}_atom_feats']

        # chemical features: the vert_nbr_atoms nearest atoms of each vertex
        nbr_dist, nbr_ind = BallTree(atom_coords).query(verts, k=self.vert_nbr_atoms)
        nbr_dist = np.vstack(nbr_dist)
        nbr_ind = np.vstack(nbr_ind)
        nbr_dist_gdf = self.dist_gdf.expand(nbr_dist)
        # Dot product between the vertex normal and the unit vector from the
        # vertex to each neighbouring atom.
        nbr_vec = atom_coords[nbr_ind] - verts.reshape(-1, 1, 3)
        nbr_angular = np.einsum('vkj,vj->vk',
                                nbr_vec / np.linalg.norm(nbr_vec, axis=-1, keepdims=True),
                                vnormals)
        nbr_angular_gdf = self.angular_gdf.expand(nbr_angular)
        # (num_verts, vert_nbr_atoms, 6 + 2*num_gdf)
        chem_feats = np.concatenate((atom_feats[nbr_ind],
                                     nbr_dist_gdf,
                                     nbr_angular_gdf), axis=-1)

        # geometric features: columns 0 and 1 (Gaussian and mean curvature)
        # are expanded on their Gaussian bases, the rest is passed through.
        geom_feats_in = data[f'{side}_geom_feats']
        geom_feats = np.concatenate((self.gauss_curv_gdf.expand(geom_feats_in[:, 0]),
                                     self.mean_curv_gdf.expand(geom_feats_in[:, 1]),
                                     geom_feats_in[:, 2:]), axis=-1)

        # directional gradient: the sparse operator is flattened to four
        # rows (real part, imaginary part, row index, column index).
        grad_op = data[f'{side}_grad_op'].item().tocoo()
        grad_op_dense = np.concatenate((grad_op.data.real.reshape(1, -1),
                                        grad_op.data.imag.reshape(1, -1),
                                        grad_op.row.reshape(1, -1),
                                        grad_op.col.reshape(1, -1)), axis=0)

        return {
            'iface_p2p': torch.tensor(iface_map, dtype=torch.int64),
            'vnormals': torch.tensor(vnormals, dtype=torch.float32),
            'chem_feats': torch.tensor(chem_feats, dtype=torch.float32),
            'geom_feats': torch.tensor(geom_feats, dtype=torch.float32),
            'edge': torch.tensor(data[f'{side}_edge'], dtype=torch.long),
            # Laplace-Beltrami basis
            'eigs': torch.tensor(data[f'{side}_eigs'], dtype=torch.float32),
            'grad_op': torch.tensor(grad_op_dense, dtype=torch.float32),
            'grad_basis': torch.tensor(data[f'{side}_grad_basis'], dtype=torch.float32),
        }

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(lig_dict, rec_dict, fpath)``.

        When ``swap_pairs`` is set, the ligand and receptor roles are
        exchanged with probability 0.5.
        """
        fpath = self.fpaths[idx]

        # Neighboring atomic environment information is too large to store
        # on HDD, so it is computed on-the-fly for the moment.
        #
        # NOTE: allow_pickle=True unpickles object arrays; only load files
        # written by the project's own data processor.
        # The `with` block closes the archive's file handle once every
        # array has been read.
        with np.load(fpath, allow_pickle=True) as data:
            lig, rec = 'lig', 'rec'
            # NOTE(review): this draws from numpy's global RNG; confirm
            # DataLoader workers are seeded differently for numpy.
            if self.swap_pairs and np.random.binomial(n=1, p=0.5):
                lig, rec = 'rec', 'lig'

            # interface and point-to-point correspondence
            map_lig2rec = data[f'map_{lig}2{rec}'].astype(int)
            map_rec2lig = data[f'map_{rec}2{lig}'].astype(int)

            lig_dict = self._surface_features(data, lig, map_lig2rec)
            rec_dict = self._surface_features(data, rec, map_rec2lig)

        return lig_dict, rec_dict, fpath

    def __len__(self):
        """Return the number of files selected for this split."""
        return len(self.fpaths)


class GaussianDistance:
    """Expand scalar values on a set of evenly spaced Gaussian filters.

    ``filters`` holds ``num_centers`` centers evenly spaced from ``start``
    to ``stop``. ``var`` is the spacing between neighbouring centers; it is
    squared in ``expand``, so it acts as the width of each Gaussian.
    """

    def __init__(self, start, stop, num_centers):
        span = stop - start
        self.filters = np.linspace(start, stop, num_centers)
        self.var = span / (num_centers - 1)

    def expand(self, d):
        """Return the filter responses of ``d`` with a new trailing axis.

        The result has shape ``d.shape + (num_centers,)``.
        """
        # Broadcast every value against every filter center.
        offsets = d[..., None] - self.filters
        exponent = -0.5 * offsets**2 / self.var**2
        return np.exp(exponent)


