import json
import pickle
import warnings
from collections import defaultdict
import math

import smart_open
import numpy as np
from rdkit import Chem
import torch
from tqdm import tqdm
from joblib import Parallel, delayed

from coarsebind_public.mol_encoder.util.periodic_table import ATOMIC_NUM_X, ATOMIC_NUM_Y
from coarsebind_public.mol_encoder.util.s3.s3_io import cache_read


def remove_H_isotope(mol):
    """
    Return ``mol`` with explicit hydrogens added and all hydrogen isotope labels cleared.

    ``Chem.AddHs`` is applied first, so the returned molecule is the
    hydrogen-expanded one produced by that call. Every atom with atomic
    number 1 then gets its isotope set to 0 (i.e. no isotopic label).
    """
    with_hs = Chem.AddHs(mol)

    hydrogens = (a for a in with_hs.GetAtoms() if a.GetAtomicNum() == 1)
    for hydrogen in hydrogens:
        # Isotope 0 means "unspecified", which strips D / T style labels.
        hydrogen.SetIsotope(0)

    return with_hs


def dense_to_sparse(dense_edges):
    """
    Unpack dense, upper-triangle-packed edge labels into index tensors.

    Edges here are packed as batch X maxnodes*(maxnodes-1)/2: each column is
    one unordered node pair, enumerated row-major over the strict upper
    triangle ((0,1), (0,2), ..., (1,2), ...).  Only the first two index
    columns of ``torch.nonzero`` are used, so a 2-D tensor is expected.

    Returns:
        A 4-tuple of ``torch.long`` tensors with one entry per packed slot
        whose label is > 0: (batch index, first node index, second node
        index, edge label).
    """
    packed_len = dense_edges.shape[-1]
    # Invert packed_len == n * (n - 1) / 2 to recover the node count n.
    n = int(math.sqrt(2 * packed_len)) + 1
    total_pairs = n * (n - 1) / 2

    hits = torch.nonzero(dense_edges > 0)
    batch_idx = hits[:, 0]
    flat_idx = hits[:, 1]

    # Closed-form inverse of the row-major upper-triangle enumeration:
    # first the row (smaller node index), then the column within that row.
    root = torch.sqrt(-8 * flat_idx + 4 * n * (n - 1) - 7)
    row = n - 2 - torch.floor(root / 2.0 - 0.5)
    rows_left = n - row
    col = flat_idx + row + 1 - total_pairs + rows_left * (rows_left - 1) / 2

    labels = dense_edges[batch_idx, flat_idx]
    return (
        batch_idx.to(torch.long),
        row.to(torch.long),
        col.to(torch.long),
        labels.to(torch.long),
    )


def sparse_to_dense(nodes, edge_labels, edges):
    """
    Pack an edge list into a dense upper-triangle label vector.

    The result has ``n * (n - 1) / 2`` float entries (``n = len(nodes)``), one
    per unordered node pair enumerated row-major over the strict upper
    triangle.  Pairs without an edge keep the fill value -2; the slot of
    ``edges[I]`` receives ``edge_labels[I]``.

    Returns:
        ``(nodes, dense)`` where ``nodes`` is passed through unchanged.
    """
    n = len(nodes)
    total_pairs = n * (n - 1) / 2
    dense = np.full(int(total_pairs), -2.0)

    for position, edge in enumerate(edges):
        # Order the endpoints so the pair addresses the upper triangle.
        lo, hi = sorted(edge)
        rows_left = n - lo
        slot = int(total_pairs - rows_left * (rows_left - 1) / 2 + hi - lo - 1)
        dense[slot] = edge_labels[position]

    return nodes, dense


class MissingNodeException(Exception):
    """
    Raised when a sparse atom description has no token in the atom vocabulary.

    The offending description is available as ``node``.  It is also forwarded
    to ``Exception.__init__`` so that it appears in ``str(exc)`` / tracebacks
    and so the exception can be rebuilt from ``args`` (required for pickling
    and copying, e.g. when it propagates out of a worker process).
    """

    def __init__(self, node):
        super().__init__(node)
        self.node = node


class MissingEdgeException(Exception):
    """
    Raised when a sparse bond description has no token in the bond vocabulary.

    For backward compatibility the offending description is stored on the
    ``node`` attribute (the parameter keeps its historical name even though it
    holds a bond).  It is also forwarded to ``Exception.__init__`` so that it
    appears in ``str(exc)`` / tracebacks and so the exception can be rebuilt
    from ``args`` (required for pickling and copying).
    """

    def __init__(self, node):
        super().__init__(node)
        self.node = node


class GraphTokenizer:
    """
    Convert RDKit molecules to and from integer graph tokens.

    An atom (bond) is first described in "sparse" form: a list holding one
    value per entry of ``ATOM_BYTES`` (``BOND_BYTES``), read through the RDKit
    accessor ``Get<Name>()``.  Sparse descriptions are then mapped to integer
    tokens through a vocabulary that ``set_vocab`` lazily unpickles from
    ``vocab_path``.  Bonds are always accompanied by a parallel list of
    ``[begin_idx, end_idx]`` atom-index pairs.
    """

    # Per-atom RDKit properties, in sparse-vector order.
    ATOM_BYTES = [
        # req
        "AtomicNum",
        # optional
        "FormalCharge",
        "ChiralTag",
        "NumExplicitHs",
        "NumRadicalElectrons",
        "Isotope",
    ]
    # Atom properties stored as the integer value of an RDKit enum.
    ATOM_TYPE_MAP = {
        "ChiralTag": Chem.ChiralType.values,
    }
    # Per-bond RDKit properties, in sparse-vector order.
    BOND_BYTES = [
        # req
        "BondType",
        # optional
        "BondDir",  # this handles double bond stereochemistry information in a kind of roundabout way.
    ]
    # Bond properties stored as the integer value of an RDKit enum.
    BOND_TYPE_MAP = {
        "BondType": Chem.BondType.values,
        "BondDir": Chem.BondDir.values,
    }

    # Allowed raw values for each discrete atom feature emitted by
    # ``discrete_from_mol``.  The feature value is the POSITION of the raw
    # value in its list, so the order of these lists is part of the encoding.
    # 1_1 / 1_2: ATOMIC_NUM_X / ATOMIC_NUM_Y of the atomic number.
    CATEGORY_1_1_VALS = [
        -1,
        1,
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
        10,
        11,
        12,
        13,
        14,
        15,
        16,
        17,
        18,
    ]
    CATEGORY_1_2_VALS = [-1, 1, 2, 3, 4, 5, 6, 7, 9, 10]
    # FormalCharge
    CATEGORY_2_VALS = [0, 1, 2, 3, 4, -2, -5, -4, -3, -1]
    # ChiralTag (integer enum value)
    CATEGORY_3_VALS = [0, 1, 2, 6, 7, 8]
    # NumExplicitHs
    CATEGORY_4_VALS = [0, 1]
    # NumRadicalElectrons
    CATEGORY_5_VALS = [0, 1, 2, 3]
    # Isotope
    CATEGORY_6_VALS = [
        0,
        1,
        2,
        3,
        131,
        10,
        11,
        12,
        13,
        14,
        15,
        16,
        17,
        18,
        19,
        32,
        35,
        68,
        76,
        77,
        80,
        211,
        99,
        123,
        124,
        125,
        127,
    ]

    def __init__(
        self,
        vocab_path: str,
        chiral=True,
    ):
        """
        Args:
            vocab_path: location of the pickled vocabulary; it is only read on
                the first call to ``set_vocab``.
            chiral: when False, the ``ChiralTag`` atom entry and the
                ``BondDir`` bond entry are overwritten with 0 before token
                lookup (see ``default_atom_bytes`` / ``default_bond_bytes``).
        """
        self._vocab_is_set = False
        self.vocab_path = vocab_path
        self.chiral = chiral

        # sparse-vector position -> value forced before vocabulary lookup
        self.default_atom_bytes = {}
        self.default_bond_bytes = {}
        if not self.chiral:
            if "ChiralTag" in self.ATOM_BYTES:
                self.default_atom_bytes[self.ATOM_BYTES.index("ChiralTag")] = 0
            if "BondDir" in self.BOND_BYTES:
                self.default_bond_bytes[self.BOND_BYTES.index("BondDir")] = 0

    @classmethod
    def sparse_from_atom(cls, atom):
        """Return the sparse description of ``atom``: one value per ``ATOM_BYTES`` entry."""
        to_return = []
        for a_b in cls.ATOM_BYTES:
            value = getattr(atom, f"Get{a_b}")()
            # enum-valued properties are stored as plain ints
            to_return.append(int(value) if a_b in cls.ATOM_TYPE_MAP else value)

        # handle aromatic amine edge case: an aromatic N that carries any
        # hydrogen is always recorded with exactly one explicit H.
        if (
            atom.GetSymbol() == "N"
            and atom.GetIsAromatic()
            and atom.GetTotalNumHs(includeNeighbors=True) > 0
        ):
            to_return[cls.ATOM_BYTES.index("NumExplicitHs")] = 1

        return to_return

    @classmethod
    def sparse_from_bond(cls, bond):
        """Return the sparse description of ``bond``: one int per ``BOND_BYTES`` entry."""
        return [int(getattr(bond, f"Get{b_b}")()) for b_b in cls.BOND_BYTES]

    @classmethod
    def sparse_bonds(cls, mol):
        """Return ``[low_idx, high_idx]`` atom-index pairs, one per bond, in ``mol.GetBonds()`` order."""
        return [sorted([b.GetBeginAtomIdx(), b.GetEndAtomIdx()]) for b in mol.GetBonds()]

    @classmethod
    def mol_coords(cls, mol):
        """Return the positions of ``mol.GetConformer()`` as a nested list (one entry per atom)."""
        return mol.GetConformer().GetPositions().tolist()

    @classmethod
    def sparse_from_mol(cls, mol, return_coords=False):
        """
        Describe ``mol`` in sparse form.

        Returns:
            ``(sparse_atoms, sparse_bonds, bond_index_pairs)``, with the
            conformer coordinates appended as a fourth element when
            ``return_coords`` is true.
        """
        base_ret = [
            [cls.sparse_from_atom(a) for a in mol.GetAtoms()],
            [cls.sparse_from_bond(b) for b in mol.GetBonds()],
            cls.sparse_bonds(mol),
        ]

        if return_coords:
            base_ret.append(cls.mol_coords(mol))

        return tuple(base_ret)

    @classmethod
    def sparse_to_mol(cls, sparse_atoms, sparse_bonds, bonds, coords=None):
        """
        Build an RDKit molecule from sparse atom / bond descriptions.

        Args:
            sparse_atoms: one sparse atom description per atom.
            sparse_bonds: one sparse bond description per bond.
            bonds: atom-index pair for each entry of ``sparse_bonds``.
            coords: optional per-atom positions; when given (and non-empty) a
                conformer with these positions is attached.

        Raises:
            AssertionError: if an atom index drifts while adding atoms, or if
                ``coords`` does not have one entry per atom.
        """
        mol = Chem.EditableMol(Chem.Mol())
        for atom_idx, sparse_atom in enumerate(sparse_atoms):
            atom = Chem.Atom(sparse_atom[0])

            for pos, prop_name in enumerate(cls.ATOM_BYTES):
                if prop_name == "AtomicNum":
                    continue  # already consumed by the Atom constructor
                value = sparse_atom[pos]
                if prop_name in cls.ATOM_TYPE_MAP:
                    # translate the stored int back into the RDKit enum member
                    value = cls.ATOM_TYPE_MAP[prop_name][value]
                getattr(atom, f"Set{prop_name}")(value)

            new_atom_idx = mol.AddAtom(atom)
            assert new_atom_idx == atom_idx, "new atom idx did not match idx of sparse atom!!!"

        for bond_idx, sparse_bond in enumerate(sparse_bonds):
            mol.AddBond(
                bonds[bond_idx][0], bonds[bond_idx][1], Chem.BondType.values[sparse_bond[0]]
            )

        # Remaining bond properties can only be set on the finished Mol.
        mol = mol.GetMol()
        for bond_idx, sparse_bond in enumerate(sparse_bonds):
            bond = mol.GetBondBetweenAtoms(bonds[bond_idx][0], bonds[bond_idx][1])
            for pos, prop_name in enumerate(cls.BOND_BYTES):
                if prop_name == "BondType":
                    continue  # already set by AddBond
                value = sparse_bond[pos]
                if prop_name in cls.BOND_TYPE_MAP:
                    value = cls.BOND_TYPE_MAP[prop_name][value]
                getattr(bond, f"Set{prop_name}")(value)

        mol.UpdatePropertyCache(strict=False)
        Chem.AssignStereochemistry(mol, cleanIt=False, force=True)

        # NOTE: could assign stereochem from 3d conf here, but then
        # returned mol wouldn't represent the tokens.
        # Explicit None / length test: plain truthiness is ambiguous for arrays.
        if coords is not None and len(coords) > 0:
            assert len(coords) == len(sparse_atoms), "coords array length must match atom length"

            num_atoms = mol.GetNumAtoms()
            conformer = Chem.Conformer(num_atoms)
            # Set the coordinates for each atom
            for atom_idx in range(num_atoms):
                conformer.SetAtomPosition(atom_idx, coords[atom_idx])
            mol.AddConformer(conformer)

        return mol

    def get_atom_tok(self, atom):
        """
        Return the vocabulary token for an RDKit atom.

        Uses ``AtomVocab`` item lookup, so an unknown description falls back to
        the plain element entry and yields ``None`` if that is missing too.
        """
        self.set_vocab()
        return self.atom_vocab[tuple(self.sparse_from_atom(atom))]

    def get_bond_tok(self, bond):
        """
        Return the vocabulary token for an RDKit bond.

        Raises:
            KeyError: if the bond description is not in the vocabulary.
        """
        self.set_vocab()
        return self.bond_vocab[tuple(self.sparse_from_bond(bond))]

    def _discrete_atom(self, sparse_atom):
        """Encode one sparse atom as 7 category indices (all zeros for pad / mask atoms)."""
        if sparse_atom[0] in (0, 100):
            # 0 is a pad atom and 100 a mask atom; mask handling is not
            # implemented yet, so both become an all-zero row.
            return [0 for _ in range(7)]
        return [
            self.CATEGORY_1_1_VALS.index(ATOMIC_NUM_X(sparse_atom[0])),
            self.CATEGORY_1_2_VALS.index(ATOMIC_NUM_Y(sparse_atom[0])),
            self.CATEGORY_2_VALS.index(sparse_atom[1]),
            self.CATEGORY_3_VALS.index(sparse_atom[2]),
            self.CATEGORY_4_VALS.index(sparse_atom[3]),
            self.CATEGORY_5_VALS.index(sparse_atom[4]),
            self.CATEGORY_6_VALS.index(sparse_atom[5]),
        ]

    def _bond_tokens(self, bond_dat):
        """Apply ``default_bond_bytes`` in place, then map each sparse bond to its token."""
        if self.default_bond_bytes:
            for sparse_bond in bond_dat:
                for k in self.default_bond_bytes:
                    sparse_bond[k] = self.default_bond_bytes[k]

        bond_toks = []
        for sparse_bond in bond_dat:
            tok = self.bond_vocab.get(tuple(sparse_bond), None)
            if tok is None:
                raise MissingEdgeException(sparse_bond)
            bond_toks.append(tok)
        return bond_toks

    @staticmethod
    def _drop_atoms(atoms, bond_toks, bonds, drop_idxs):
        """
        Remove the atoms at ``drop_idxs`` together with every bond touching them.

        Surviving bonds are re-indexed so they address the compacted atom list.
        """
        new_index = {}
        kept_atoms = []
        for old_idx, atom in enumerate(atoms):
            if old_idx in drop_idxs:
                continue
            new_index[old_idx] = len(kept_atoms)
            kept_atoms.append(atom)

        kept_toks = []
        kept_bonds = []
        for tok, bond in zip(bond_toks, bonds):
            if bond[0] in drop_idxs or bond[1] in drop_idxs:
                continue
            kept_toks.append(tok)
            kept_bonds.append([new_index[bond[0]], new_index[bond[1]]])
        return kept_atoms, kept_toks, kept_bonds

    def _light_token_indices(self, atom_toks):
        """Return the positions in ``atom_toks`` holding a plain H, deuterium or tritium token."""
        light_toks = (
            self.atom_vocab[(1, 0, 0, 0, 0, 0)],
            self.atom_vocab[(1, 0, 0, 0, 0, 2)],
            self.atom_vocab[(1, 0, 0, 0, 0, 3)],
        )
        return {i for i, tok in enumerate(atom_toks) if tok in light_toks}

    def discrete_from_mol(self, mol, return_coords=False, ignore_light=False):
        """
        like tokens_from_mol but returns atom X 7 integers 0 => n cat for the atoms.

        Each atom becomes the list of positions of its raw features in the
        ``CATEGORY_*_VALS`` tables; bonds are still vocabulary tokens.

        Args:
            mol: RDKit molecule.
            return_coords: also return the conformer coordinates.
            ignore_light: drop every atom with atomic number 1 (and its bonds,
                and its coordinates); remaining bonds are re-indexed.

        Returns:
            ``(atom_rows, bond_toks, bonds)`` plus ``coords`` when requested.

        Raises:
            ValueError: if an atom feature is not listed in its category table.
            MissingEdgeException: if a bond has no vocabulary token.
        """
        self.set_vocab()
        coords = None
        if return_coords:
            atom_dat_, bond_dat, bonds, coords = self.sparse_from_mol(mol, return_coords=True)
        else:
            atom_dat_, bond_dat, bonds = self.sparse_from_mol(mol)

        # NOTE(review): unlike tokens_from_mol, default_atom_bytes (chiral=False)
        # is not applied to the atom features here - confirm this is intended.
        atom_dat = [self._discrete_atom(a) for a in atom_dat_]
        bond_toks = self._bond_tokens(bond_dat)

        if ignore_light:
            # Category rows are not vocabulary tokens, so hydrogens are
            # identified from the raw atomic number instead of via remove_light.
            light_idxs = {i for i, a in enumerate(atom_dat_) if a[0] == 1}
            atom_dat, bond_toks, bonds = self._drop_atoms(atom_dat, bond_toks, bonds, light_idxs)
            if return_coords:
                coords = [xyz for i, xyz in enumerate(coords) if i not in light_idxs]

        if return_coords:
            return atom_dat, bond_toks, bonds, coords
        return atom_dat, bond_toks, bonds

    @staticmethod
    def _free_reserved_ids(inv_vocab):
        """Move any entry keyed 0 or 1 to a fresh id past the current maximum (in place)."""
        for reserved in (0, 1):
            if reserved in inv_vocab:
                new_key = max(inv_vocab.keys()) + 1
                inv_vocab[new_key] = inv_vocab.pop(reserved)

    def set_vocab(self):
        """
        Load the vocabulary from ``vocab_path`` on first use (no-op afterwards).

        MUST have [pad] = 0 and [mask] = 1
        As a hack I patch that here if there are tokens keyd below 2: such
        tokens are moved to new ids above the current maximum.

        Sets ``atom_vocab`` / ``bond_vocab`` (sparse tuple -> token) and their
        inverses ``inv_atom_vocab`` / ``inv_bond_vocab``.

        Raises:
            ValueError: if the vocabulary is flagged ``exclusive_idx`` but atom
                and bond tokens share ids (the shared ids are the argument).
        """
        if self._vocab_is_set:
            return

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # vocab_path at trusted files.
        with cache_read(self.vocab_path, "rb") as f_in:
            vocab = pickle.load(f_in)

        # remap 0=> pad 1=>Mask
        self.inv_atom_vocab = {v: k for k, v in vocab["atom_vocab"].items()}
        self.inv_bond_vocab = {v: k for k, v in vocab["bond_vocab"].items()}
        self._free_reserved_ids(self.inv_atom_vocab)
        self._free_reserved_ids(self.inv_bond_vocab)

        self.atom_vocab = AtomVocab({v: k for k, v in self.inv_atom_vocab.items()})
        self.bond_vocab = {v: k for k, v in self.inv_bond_vocab.items()}

        if vocab["exclusive_idx"]:
            shared_indices = set(self.inv_atom_vocab).intersection(set(self.inv_bond_vocab))
            if shared_indices:
                raise ValueError(shared_indices)

        self._vocab_is_set = True

    def remove_light(self, atom_toks, bond_toks, bonds):
        """
        Removes all hydrogen tokens from the graph.

        Atoms whose token equals the vocabulary token of plain hydrogen,
        deuterium or tritium are dropped together with every bond touching
        them.  Remaining bonds are re-indexed to the compacted atom list
        (a no-op when all hydrogens come after the heavy atoms).

        Returns:
            ``(atom_toks, bond_toks, bonds)`` without the hydrogens.
        """
        self.set_vocab()
        light_idxs = self._light_token_indices(atom_toks)
        return self._drop_atoms(atom_toks, bond_toks, bonds, light_idxs)

    def tokens_from_mol(self, mol, return_coords=False, ignore_light=False):
        """
        Tokenize ``mol`` into atom tokens, bond tokens and bond index pairs.

        Args:
            mol: RDKit molecule.
            return_coords: also return the conformer coordinates.
            ignore_light: drop hydrogen tokens via ``remove_light``; the
                matching coordinates are dropped as well.

        Returns:
            ``(atom_toks, bond_toks, bonds)`` plus ``coords`` when requested.

        Raises:
            MissingNodeException: if an atom has no exact vocabulary entry.
            MissingEdgeException: if a bond has no vocabulary entry.
        """
        self.set_vocab()
        coords = None
        if return_coords:
            atom_dat, bond_dat, bonds, coords = self.sparse_from_mol(mol, return_coords=True)
        else:
            atom_dat, bond_dat, bonds = self.sparse_from_mol(mol)

        if self.default_atom_bytes:
            for sparse_atom in atom_dat:
                for k in self.default_atom_bytes:
                    sparse_atom[k] = self.default_atom_bytes[k]

        atom_toks = []
        for sparse_atom in atom_dat:
            # dict.get bypasses the AtomVocab fallback: an exact match is required
            tok = self.atom_vocab.get(tuple(sparse_atom), None)
            if tok is None:
                raise MissingNodeException(sparse_atom)
            atom_toks.append(tok)

        bond_toks = self._bond_tokens(bond_dat)

        if ignore_light:
            if return_coords:
                # computed before removal so indices still match ``coords``
                light_idxs = self._light_token_indices(atom_toks)
                coords = [xyz for i, xyz in enumerate(coords) if i not in light_idxs]
            atom_toks, bond_toks, bonds = self.remove_light(atom_toks, bond_toks, bonds)

        if return_coords:
            return atom_toks, bond_toks, bonds, coords
        return atom_toks, bond_toks, bonds

    def tokens_to_mol(self, atom_toks, bond_toks, bonds, coords=None):
        """
        Rebuild an RDKit molecule from tokens.

        Returns:
            The molecule, or ``None`` (after emitting a warning per unknown
            token) if any atom or bond token is not in the vocabulary.
        """
        self.set_vocab()

        atom_data = []
        bond_data = []
        failed_mol = False
        for tok in atom_toks:
            if tok in self.inv_atom_vocab:
                atom_data.append(self.inv_atom_vocab[tok])
            else:
                warnings.warn(f"No atom representation for token: {tok}")
                failed_mol = True
        for tok in bond_toks:
            if tok in self.inv_bond_vocab:
                bond_data.append(self.inv_bond_vocab[tok])
            else:
                warnings.warn(f"No bond representation for token: {tok}")
                failed_mol = True

        if failed_mol:
            return None
        return self.sparse_to_mol(atom_data, bond_data, bonds, coords=coords)


class AtomVocab(dict):
    """
    Mapping from sparse atom tuples to tokens with a lenient item lookup.

    ``vocab[key]`` returns the stored token for ``key``; if ``key`` is absent
    it retries with every entry after the first zeroed out (i.e. just the
    atomic number with nothing special), and returns ``None`` instead of
    raising when that is absent too.  Only ``[]`` lookup is overridden; other
    ``dict`` methods such as ``get`` keep their exact-match behaviour.
    """

    def __getitem__(self, key_tuple):
        if key_tuple not in self:
            # default to just atomic num with nothing special
            key_tuple = (key_tuple[0],) + (0,) * (len(key_tuple) - 1)
            if key_tuple not in self:
                return None
        return super().__getitem__(key_tuple)
