import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.hyp_layers import LorentzLinear, LorentzCentroidDistance
from .mol_tree import Vocab, MolTree
from .nnutils import create_var, flatten_tensor, avg_pool
from .hyperbolic_jtnn_enc import HyperbolicJTNNEncoder
from .hyperbolic_jtnn_dec import HyperbolicJTNNDecoder
from .hyperbolic_mpn import HyperbolicMPN
from .hyperbolic_jtmpn import HyperbolicJTMPN
from .hyperbolic_embedding import HyperbolicEmbedding
from .datautils import HyperbolicTensorize
import manifolds

from .chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
import rdkit
import rdkit.Chem as Chem
import copy, math

class HyperbolicJTNNAE(nn.Module):
    """Junction-tree autoencoder whose projections are hyperbolic (Lorentz) layers.

    The model encodes a molecule twice: once as a junction tree (``self.jtnn``)
    and once as a molecular graph (``self.mpn``).
    ``self.T`` and ``self.G`` map the two encodings to the latent space.
    The tree decoder reconstructs the junction tree.
    The assembly scorer (``self.A`` -> ``self.A_o``) ranks candidate ways of
    attaching neighbouring clusters.

    This is a deterministic autoencoder: ``forward`` applies ``T``/``G``
    directly, with no sampling and no KL term.
    """

    # Multiplier applied to the assembly loss in ``forward``.
    ASSM_LOSS_WEIGHT = 10

    def __init__(self, args, vocab):
        super(HyperbolicJTNNAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = args.dim
        self.latent_size = args.latent_dim
        # Manifold class is looked up by name on the `manifolds` package.
        self.manifold = getattr(manifolds, args.manifold)()

        # Encoder and decoder each own a separate vocabulary embedding.
        self.jtnn = HyperbolicJTNNEncoder(args, HyperbolicEmbedding(vocab.size(), self.hidden_size, self.hidden_size, self.manifold))
        self.decoder = HyperbolicJTNNDecoder(args, vocab, HyperbolicEmbedding(vocab.size(), self.hidden_size, self.hidden_size, self.manifold))

        self.jtmpn = HyperbolicJTMPN(args)
        self.mpn = HyperbolicMPN(args)

        # Assembly scorer input is manifold.Concat(candidate_vec, mol_latent_vec).
        # The "- 1" presumably reflects Concat merging the two Lorentz time
        # coordinates into one -- NOTE(review): confirm against manifolds.Concat.
        self.A = LorentzLinear(
                in_features = self.hidden_size + self.latent_size - 1,
                out_features = self.hidden_size,
                manifold = self.manifold,
                bias = args.bias,
                dropout = args.dropout
            )
        # One score per candidate (n_classes=1); output assumed (N, 1) -- TODO confirm.
        self.A_o = LorentzCentroidDistance(
                dim = self.hidden_size,
                n_classes = 1,
                manifold = self.manifold,
                bias = args.bias
            )
        self.assm_loss = nn.CrossEntropyLoss(reduction = 'sum')

        # Tree / graph projections into the latent space.
        self.T = LorentzLinear(self.manifold, self.hidden_size, self.latent_size, bias=args.bias, dropout=args.dropout)
        self.G = LorentzLinear(self.manifold, self.hidden_size, self.latent_size, bias=args.bias, dropout=args.dropout)

    def encode(self, jtenc_holder, mpn_holder):
        """Run both encoders.

        Returns:
            ``(tree_vecs, tree_mess, mol_vecs)``: the junction-tree encoder's
            two outputs followed by the graph encoder's output.
        """
        tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
        mol_vecs = self.mpn(*mpn_holder)
        return tree_vecs, tree_mess, mol_vecs

    def encode_from_smiles(self, smiles_list):
        """Encode SMILES strings; return tree and graph vectors concatenated on the last dim."""
        tree_batch = [MolTree(s) for s in smiles_list]
        _, jtenc_holder, mpn_holder = HyperbolicTensorize(tree_batch, self.vocab, assm=False)
        tree_vecs, _, mol_vecs = self.encode(jtenc_holder, mpn_holder)
        return torch.cat([tree_vecs, mol_vecs], dim=-1)

    def forward(self, x_batch):
        """Compute reconstruction losses and accuracies for one batch.

        Args:
            x_batch: 4-tuple ``(mol_batch, jtenc_holder, mpn_holder, jtmpn_holder)``.

        Returns:
            ``(word_loss, topo_loss, assm_loss, word_acc, topo_acc, assm_acc)``.
            ``assm_loss`` is multiplied by ``ASSM_LOSS_WEIGHT``.
        """
        mol_batch, jtenc_holder, mpn_holder, jtmpn_holder = x_batch
        x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(jtenc_holder, mpn_holder)
        z_tree_vecs = self.T(x_tree_vecs)
        z_mol_vecs = self.G(x_mol_vecs)

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, z_tree_vecs)
        assm_loss, assm_acc = self.assm(mol_batch, jtmpn_holder, z_mol_vecs, x_tree_mess)
        # Out-of-place so the autograd output of `assm` is never modified in place.
        assm_loss = assm_loss * self.ASSM_LOSS_WEIGHT

        return word_loss, topo_loss, assm_loss, word_acc, topo_acc, assm_acc

    def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
        """Score candidate attachments and compute the assembly loss.

        Only tree nodes that are not leaves and have more than one candidate
        are scored. Their candidates are expected to appear consecutively in
        the candidate tensor, in batch/node order (that is how they are
        consumed below).

        Args:
            mol_batch: trees whose ``nodes`` carry ``cands``, ``label`` and ``is_leaf``.
            jtmpn_holder: ``((adj, features, scope), batch_idx)``.
                ``batch_idx[k]`` is the batch row of candidate ``k``.
            x_mol_vecs: per-molecule latent vectors, indexed by ``batch_idx``.
            x_tree_mess: tree messages passed to ``self.jtmpn``.

        Returns:
            ``(loss, acc)``. ``loss`` is the summed cross-entropy divided by
            ``len(mol_batch)``. ``acc`` is the fraction of scored nodes whose
            gold candidate has the top score (ties count as correct).
            If no node is scored, returns a zero tensor and ``0.0``.
        """
        (adj, features, scope), batch_idx = jtmpn_holder
        # index_select needs the index on the same device as the data.
        batch_idx = torch.as_tensor(batch_idx, dtype=torch.long, device=x_mol_vecs.device)

        cand_vecs = self.jtmpn(adj, features, x_tree_mess, scope)
        # Pair every candidate with the latent vector of its own molecule.
        x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
        h = self.A(self.manifold.Concat(cand_vecs, x_mol_vecs))
        # reshape(-1) (not squeeze) keeps a 1-D vector even with a single row.
        scores = self.A_o(h).reshape(-1)

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for mol_tree in mol_batch:
            comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1

                target = torch.tensor([label], dtype=torch.long, device=cur_score.device)
                all_loss.append(self.assm_loss(cur_score.view(1, -1), target))

        if not all_loss:
            # Nothing to score in this batch: avoid dividing by cnt == 0.
            return scores.new_zeros(()), 0.0
        return sum(all_loss) / len(mol_batch), acc / cnt

    def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
        """Decode one molecule to a SMILES string.

        Args:
            x_tree_vecs, x_mol_vecs: encoder-space vectors with batch size 1.
                They are projected by ``self.T`` / ``self.G`` here, as in ``forward``.
            prob_decode: sample (True) or take the argmax (False) at every choice.

        Returns:
            The SMILES string, or ``None`` if decoding or sanitisation fails.
        """
        # Batch decoding is not supported.
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        x_tree_vecs = self.T(x_tree_vecs)
        x_mol_vecs = self.G(x_mol_vecs)

        pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
        if len(pred_nodes) == 0:
            return None
        if len(pred_nodes) == 1:
            return pred_root.smiles

        # Mark nid (1-based) & is_leaf & atommap.
        for i, node in enumerate(pred_nodes):
            node.nid = i + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        scope = [(0, len(pred_nodes))]
        jtenc_holder, adj = HyperbolicJTNNEncoder.tensorize_nodes(pred_nodes, scope)
        _, tree_mess = self.jtnn(*jtenc_holder)
        tree_mess = (tree_mess, adj)

        def fresh_start():
            # Editable copy of the root and a per-nid atom map (slot 0 unused;
            # root slot 1 maps each of its atoms to itself).
            mol = copy_edit_mol(pred_root.mol)
            amap = [{} for _ in range(len(pred_nodes) + 1)]
            amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in mol.GetAtoms()}
            return mol, amap

        cur_mol, global_amap = fresh_start()
        cur_mol, _ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
        if cur_mol is None:
            # Retry without the aromaticity filter; fall back to the best partial molecule.
            cur_mol, global_amap = fresh_start()
            cur_mol, pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
            if cur_mol is None:
                cur_mol = pre_mol

        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None

    def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
        """Attach ``cur_node``'s children depth-first, backtracking over ranked candidates.

        Returns:
            ``(mol, pre_mol)``. ``mol`` is the assembled molecule, or ``None``
            on failure. ``pre_mol`` is the best partial molecule to fall back on.
        """
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        # Children ordered: singletons first, then multi-atom clusters by size (descending).
        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
        cands, aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
            return None, cur_mol

        cand_smiles, cand_amap = zip(*cands)
        cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]

        if len(cands) > 1:
            adj, features, scope = HyperbolicJTMPN.tensorize(cands, y_tree_mess[1])
            cand_vecs = self.jtmpn(adj, features, y_tree_mess[0], scope)
            # x_mol_vecs has one row; repeat it per candidate (as `assm` does
            # via index_select) so the concat shapes agree.
            mol_vecs = x_mol_vecs.expand(cand_vecs.size(0), -1)
            h = self.A(self.manifold.Concat(cand_vecs, mol_vecs))
            raw_scores = self.A_o(h).reshape(-1)
            # Aroma bonus built on the scores' own device/dtype (no hard-coded .cuda()).
            scores = raw_scores + raw_scores.new_tensor(aroma_score)
        else:
            scores = torch.Tensor([1.0])

        if prob_decode:
            # view(-1), not squeeze(): multinomial needs 1-D even for one candidate.
            probs = F.softmax(scores.view(1, -1), dim=1).view(-1) + 1e-7  # prevent prob = 0
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        pre_mol = cur_mol
        for i in range(cand_idx.numel()):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  # father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            has_error = False
            for nei_node in children:
                if nei_node.is_leaf:
                    continue
                tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
                if tmp_mol is None:
                    has_error = True
                    # Keep the partial result from the top-ranked candidate as fallback.
                    if i == 0:
                        pre_mol = tmp_mol2
                    break
                cur_mol = tmp_mol

            if not has_error:
                return cur_mol, cur_mol

        return None, pre_mol

