import os.path as osp
import h5py
import numpy as np
import warnings
from tqdm import tqdm

import torch 
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from scipy.spatial import ConvexHull
import sys
from torch_geometric.nn import radius_graph
sys.path.append('/root/workspace/UnitSphere/')
from mol_unit_sphere import Frame
from torch.utils.data import DataLoader

class FOLDdataset(InMemoryDataset):
    """Fold-classification dataset built from per-protein HDF5 files.

    Each kept protein becomes a ``Data`` object holding amino-acid types
    (``x``), backbone coordinates (``coords_n`` / ``coords_ca`` /
    ``coords_c``), backbone and side-chain torsion embeddings, convex-hull
    points (``ch_pos``, ``ch_r``), hull edges (``ch_edge_index``) and, for each
    cutoff in ``process``, hull edges merged with a CA radius graph
    (``ch_cut_{cutoff}_edge_index``). The collated result is cached at
    ``<root>/processed/<self.split>/data.pt``.

    Args:
        root: Dataset root. ``process`` reads ``class_map.txt``,
            ``<split>.txt`` and ``<split>/<name>.hdf5`` from it.
        transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
        split: Split name. For splits other than ``'training'`` and
            ``'validation'`` only proteins inside one size bucket are kept and
            the cache directory is suffixed with that bucket's upper cut-off.
        n_node: Index into ``ct_lst`` selecting the size bucket for
            non-training splits. ``n_node > 0`` keeps
            ``ct_lst[n_node-1] <= n_res < ct_lst[n_node]``; ``n_node <= 0``
            keeps ``n_res < ct_lst[n_node]`` (default ``-1``: last cut-off).
        ct_lst: Residue-count cut-offs (presumably ascending — the bucket
            logic assumes it). Defaults to
            ``[50, 100, 200, 300, 400, 600, 1000, 10000]``.
    """

    # Immutable template for the default cut-offs; copied per instance.
    _DEFAULT_CT_LST = (50, 100, 200, 300, 400, 600, 1000, 10000)

    def __init__(self,
                 root,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 split='train',
                 n_node=-1,
                 ct_lst=None,
                 ):
        # A fresh list per instance avoids the shared mutable-default pitfall.
        if ct_lst is None:
            ct_lst = list(self._DEFAULT_CT_LST)

        self.split_orig = split
        # Size-bucketed splits get their own cache directory, named after the
        # bucket's upper cut-off.
        if split not in ('training', 'validation'):
            self.split = split + '_{}'.format(int(ct_lst[n_node]))
        else:
            self.split = split
        self.root = root
        self.n_node = n_node
        self.ct_lst = ct_lst
        # The base-class constructor is expected to run process() when the
        # cache is missing; process() relies on the attributes set above.
        super().__init__(root, transform, pre_transform, pre_filter)

        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_dir(self):
        """Per-split cache directory: ``<root>/processed/<self.split>``."""
        return osp.join(self.root, 'processed', self.split)

    @property
    def raw_file_names(self):
        """Name of the split list file, ``<self.split>.txt``.

        NOTE(review): ``process`` actually reads ``<split_orig>.txt``; for
        size-bucketed splits the two names differ — confirm whether this
        property is consumed anywhere.
        """
        return self.split + '.txt'

    @property
    def processed_file_names(self):
        """Single cached file holding the collated ``(data, slices)`` pair."""
        return 'data.pt'

    def _normalize(self, tensor, dim=-1):
        '''
        Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
        '''
        return torch.nan_to_num(
            torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))

    @staticmethod
    def _gather_atom_pos(n_res, atom_names, atom_amino_id, atom_pos, names):
        """Scatter coordinates of atoms whose name is in ``names`` by residue.

        Args:
            n_res: Number of residues (rows of the result).
            atom_names: Per-atom byte-string names.
            atom_amino_id: Per-atom residue index.
            atom_pos: Per-atom coordinates, shape ``(n_atom, 3)``.
            names: Non-empty tuple of byte-string atom names to select.

        Returns:
            ``FloatTensor`` of shape ``(n_res, 3)``; rows of residues without
            a matching atom are NaN.
        """
        mask = np.char.equal(atom_names, names[0])
        for name in names[1:]:
            mask = mask | np.char.equal(atom_names, name)
        pos = np.full((n_res, 3), np.nan)
        pos[atom_amino_id[mask]] = atom_pos[mask]
        return torch.FloatTensor(pos)

    def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos):
        """Collect per-residue coordinates of the atoms used for torsion angles.

        Atoms: N, CA, C, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1.

        Returns:
            Tuple ``(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e,
            pos_z, pos_h)`` of ``(len(amino_types), 3)`` float tensors. Missing
            atoms are NaN, except N and C, which fall back to the CA position.
        """
        n_res = len(amino_types)

        def gather(*names):
            return self._gather_atom_pos(n_res, atom_names, atom_amino_id, atom_pos, names)

        pos_n = gather(b'N')
        pos_ca = gather(b'CA')
        pos_c = gather(b'C')

        # If the data only contains CA for a residue, use CA for N and C too.
        missing_n = torch.isnan(pos_n)
        pos_n[missing_n] = pos_ca[missing_n]
        missing_c = torch.isnan(pos_c)
        pos_c[missing_c] = pos_ca[missing_c]

        pos_cb = gather(b'CB')
        pos_g = gather(b'CG', b'SG', b'OG', b'CG1', b'OG1')
        pos_d = gather(b'CD', b'SD', b'CD1', b'OD1', b'ND1')
        pos_e = gather(b'CE', b'NE', b'OE1')
        pos_z = gather(b'CZ', b'NZ')
        pos_h = gather(b'NH1')

        return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h

    def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h):
        """Return sin/cos embeddings of the first four side-chain torsions.

        Only the first four torsions are used: just arginine has a fifth, and
        it is close to 0. Hence ``pos_h`` (and ``pos_c``) are accepted for
        interface compatibility but do not affect the result.

        Returns:
            For ``(n_res, 3)`` inputs, an ``(n_res, 8)`` tensor: ``sin`` of the
            four angles followed by their ``cos``. NaN inputs may propagate;
            callers zero them out.
        """
        v1 = pos_ca - pos_n
        v2 = pos_cb - pos_ca
        v3 = pos_g - pos_cb
        v4 = pos_d - pos_g
        v5 = pos_e - pos_d
        v6 = pos_z - pos_e

        side_chain_angles = torch.stack((
            self.compute_dihedrals(v1, v2, v3),
            self.compute_dihedrals(v2, v3, v4),
            self.compute_dihedrals(v3, v4, v5),
            self.compute_dihedrals(v4, v5, v6),
        ), 1)
        return torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)), 1)

    def bb_embs(self, X):
        """Return backbone torsion embeddings (phi, psi, omega) per residue.

        Adapted from https://github.com/jingraham/neurips19-graph-protein-design.

        Args:
            X: ``(num_residues, 3, 3)`` tensor holding the N, C-alpha and C
                coordinates of each residue, in that order.

        Returns:
            ``(num_residues, 6)`` tensor: ``cos`` of the three angles followed
            by their ``sin``. phi of the first residue and psi/omega of the
            last residue are undefined and set to angle 0.
        """
        # Flatten to the atom chain N0, CA0, C0, N1, ... and take unit bonds.
        atoms = torch.reshape(X, [3 * X.shape[0], 3])
        bonds = self._normalize(atoms[1:] - atoms[:-1], dim=-1)

        angle = self.compute_dihedrals(bonds[:-2], bonds[1:-1], bonds[2:])

        # Add phi[0], psi[-1], omega[-1] with value 0 so the flat sequence
        # reshapes into one (phi, psi, omega) row per residue.
        angle = F.pad(angle, [1, 2])
        angle = torch.reshape(angle, [-1, 3])
        return torch.cat([torch.cos(angle), torch.sin(angle)], 1)

    def compute_dihedrals(self, v1, v2, v3):
        """Return the signed dihedral angle spanned by three bond vectors.

        Args:
            v1, v2, v3: Tensors of shape ``(..., 3)`` (broadcastable).

        Returns:
            Tensor of shape ``(...)`` with angles in radians from ``atan2``;
            NaN results are mapped to 0.
        """
        # Explicit dim=-1: without it torch.cross picks the first dimension of
        # size 3, which silently crosses across rows when there are 3 vectors.
        n1 = torch.cross(v1, v2, dim=-1)
        n2 = torch.cross(v2, v3, dim=-1)
        a = (n1 * n2).sum(dim=-1)
        b = torch.nan_to_num((torch.cross(n1, n2, dim=-1) * v2).sum(dim=-1) / v2.norm(dim=-1))
        return torch.nan_to_num(torch.atan2(b, a))

    def protein_to_graph(self, pFilePath):
        """Build a ``Data`` object for one protein stored in an HDF5 file.

        Reads ``amino_types``, ``atom_amino_id``, ``atom_names`` and the first
        entry of ``atom_pos``. Sets ``x`` (``(n_res, 1)`` amino-acid types,
        with -1 remapped to 25), ``coords_ca`` / ``coords_n`` / ``coords_c``,
        ``side_chain_embs`` and ``bb_embs`` (NaNs replaced with 0).
        """
        # Read everything eagerly so the file is closed even if later steps fail.
        with h5py.File(pFilePath, "r") as h5File:
            amino_types = h5File['amino_types'][()]  # size: (n_amino,)
            atom_amino_id = h5File['atom_amino_id'][()]  # size: (n_atom,)
            atom_names = h5File['atom_names'][()]  # size: (n_atom,)
            atom_pos = h5File['atom_pos'][()][0]  # size: (n_atom, 3)

        # For amino acid types, set the value of -1 to 25.
        amino_types[amino_types == -1] = 25

        data = Data()
        pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(
            amino_types, atom_names, atom_amino_id, atom_pos)

        side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h)
        side_chain_embs[torch.isnan(side_chain_embs)] = 0
        data.side_chain_embs = side_chain_embs

        backbone = torch.stack((pos_n, pos_ca, pos_c), 1)  # (n_res, 3 atoms, 3)
        bb_embs = self.bb_embs(backbone)
        bb_embs[torch.isnan(bb_embs)] = 0
        data.bb_embs = bb_embs

        data.x = torch.unsqueeze(torch.tensor(amino_types), 1)
        data.coords_ca = pos_ca
        data.coords_n = pos_n
        data.coords_c = pos_c

        assert len(data.x) == len(data.coords_ca) == len(data.coords_n) == len(data.coords_c) \
            == len(data.side_chain_embs) == len(data.bb_embs)
        return data

    def merge_cutoff_ch(self, edge_index_hull, curProtein, cutoff):
        """Union hull edges with a CA radius graph.

        Args:
            edge_index_hull: Pair ``(src, dst)`` of equal-length index sequences.
            curProtein: ``Data`` whose ``coords_ca`` supply the node positions.
            cutoff: Radius for ``radius_graph`` (at most 32 neighbors per node).

        Returns:
            ``(edge_index_cut_ch, edge_index_cut)``. The first is
            ``[[neighbor, ...], [node, ...]]`` grouped by node in index order
            with neighbors sorted; hull edges are kept as given and radius
            edges are added only if not already present. The second is the raw
            radius-graph edge list.
        """
        pos = curProtein.coords_ca
        edge_index_cut = radius_graph(pos, r=cutoff, max_num_neighbors=32).tolist()

        neighbors = {k: [] for k in range(len(pos))}
        for src, dst in zip(edge_index_hull[0], edge_index_hull[1]):
            neighbors[src].append(dst)
        # Per-node sets make the duplicate check O(1) instead of a list scan.
        seen = {k: set(nbrs) for k, nbrs in neighbors.items()}
        for src, dst in zip(edge_index_cut[0], edge_index_cut[1]):
            if dst not in seen[src]:
                seen[src].add(dst)
                neighbors[src].append(dst)

        edge_index_cut_ch = [[], []]
        for node, nbrs in neighbors.items():
            nbrs.sort()
            edge_index_cut_ch[0].extend(nbrs)
            edge_index_cut_ch[1].extend([node] * len(nbrs))

        return edge_index_cut_ch, edge_index_cut

    def _in_size_bucket(self, n_res, ct_arr):
        """Decide whether a protein with ``n_res`` residues is kept.

        Training/validation proteins are always kept and counted in the first
        bucket whose cut-off exceeds ``n_res`` (not counted if none does).
        Other splits keep only proteins inside the bucket selected by
        ``self.n_node``, counting them in ``ct_arr[self.n_node]``.
        ``ct_arr`` is updated in place.
        """
        if self.split in ('training', 'validation'):
            for k, cut in enumerate(self.ct_lst):
                if n_res < cut:
                    ct_arr[k] += 1
                    break
            return True

        upper = self.ct_lst[self.n_node]
        lower = self.ct_lst[self.n_node - 1] if self.n_node > 0 else None
        if n_res < upper and (lower is None or n_res >= lower):
            ct_arr[self.n_node] += 1
            return True
        return False

    def process(self):
        """Read every listed protein, filter by size, add hull features, and cache."""
        print('Beginning Processing with n_node>={}...'.format(self.n_node))

        # Load the class-name -> label map.
        classes_ = {}
        with open(self.root + "/class_map.txt", 'r') as mFile:
            for line in mFile:
                fields = line.rstrip().split('\t')
                classes_[fields[0]] = int(fields[1])
        print(self.split)

        # Get the file list and labels (label is the last tab-separated field).
        fileList_ = []
        cathegories_ = []
        with open(self.root + "/" + self.split_orig + ".txt", 'r') as mFile:
            for curLine in mFile:
                splitLine = curLine.rstrip().split('\t')
                curClass = classes_[splitLine[-1]]
                fileList_.append(self.root + "/" + self.split_orig + "/" + splitLine[0])
                cathegories_.append(curClass)

        print("Reading the data")
        frame = Frame()
        ct_arr = np.zeros(len(self.ct_lst))
        cut_lst = [6, 7, 8, 10]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            data_list = []
            for fileIter, curFile in tqdm(enumerate(fileList_)):
                curProtein = self.protein_to_graph(curFile + ".hdf5")
                curProtein.y = torch.tensor(cathegories_[fileIter])
                curProtein.id = curFile.split('/')[-1]

                if not self._in_size_bucket(len(curProtein.coords_ca), ct_arr):
                    continue

                pos = curProtein.coords_ca
                _, shell_data_ch, edge_index_hull = frame.get_frame(pos.numpy())
                ch_pos = torch.tensor(shell_data_ch, dtype=torch.float)
                curProtein['ch_pos'] = ch_pos
                # Distance of each hull point from the mean of the hull points.
                curProtein['ch_r'] = torch.norm(ch_pos - torch.mean(ch_pos, dim=0), dim=-1)

                # NOTE(review): edge indices are stored as int32; PyG usually
                # expects torch.long — confirm downstream code casts them.
                for cutoff in cut_lst:
                    edge_index_cut_ch, _ = self.merge_cutoff_ch(edge_index_hull,
                                                                curProtein,
                                                                cutoff=cutoff)
                    curProtein['ch_cut_{}_edge_index'.format(cutoff)] = torch.tensor(edge_index_cut_ch,
                                                                                     dtype=torch.int)
                curProtein.ch_edge_index = torch.tensor(edge_index_hull, dtype=torch.int)

                if curProtein.x is not None:
                    data_list.append(curProtein)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print("ct_lst: {}".format(self.ct_lst))
        print("ct_arr: {}".format(ct_arr))
        print('Done!')