import os
from collections import deque
from selfies.exceptions import SMILESParserError
from selfies.mol_graph import MolecularGraph, Attribution
import torch
from selfies.utils.smiles_utils import (
    tokenize_smiles, SMILESTokenTypes, smiles_to_atom, smiles_to_bond, SMILES_STEREO_BONDS
)


class SmilesTokenizer:
    """Tokenize SMILES strings into vocabulary ids plus a molecular graph.

    Each SMILES string is parsed with the ``selfies`` SMILES utilities into a
    ``MolecularGraph`` while its surface tokens (bond characters, atoms,
    branch brackets, ring numbers and ``.``) are recorded in order.
    ``batch_encode`` turns those tokens into padded id tensors and also
    returns, per molecule, the atom elements, a bidirectional bond edge list,
    the bond orders and an atom-position -> token-index alignment.
    """

    def __init__(self, vocab_file: str = ''):
        """Load the vocabulary from ``vocab_file`` (one token per line).

        Raises:
            ValueError: if ``vocab_file`` is not an existing file.
        """
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocab file at path '{}'.".format(vocab_file))
        self.vocab = self.load_vocab(vocab_file)

    @staticmethod
    def load_vocab(vocab_file):
        """Map each line of ``vocab_file`` (newline stripped) to its 0-based line number.

        If the same token appears on several lines, the last line number wins.
        """
        vocab = {}
        with open(vocab_file, "r", encoding="utf-8") as reader:
            for index, line in enumerate(reader):
                vocab[line.rstrip("\n")] = index
        return vocab

    def batch_encode(self, smiles_batch, max_len=None):
        """Encode a batch of SMILES strings.

        Args:
            smiles_batch: iterable of SMILES strings.
            max_len: length every id sequence is padded/truncated to; ``None``
                pads to the longest sequence in the batch.

        Returns:
            dict with
              - ``input_ids``: ``torch.long`` tensor of padded token ids,
              - ``attention_mask``: ``torch.long`` tensor, 1 for real tokens
                and 0 for padding,
              - ``bond_edges``, ``bond_types``, ``alignments``, ``atoms``:
                per-molecule lists as produced by :meth:`get_graph`.

        Tokens missing from the vocabulary map to the id of ``"[UNK]"``; a
        ``KeyError`` is raised if such a token occurs and the vocabulary has no
        ``"[UNK]"`` entry.
        """
        all_token_ids = []
        atoms = []
        batch_bond_types = []
        batch_bond_edges = []
        batch_alignment = []
        for smiles in smiles_batch:
            mol, tokens = self.smiles_to_mol(smiles, False)
            elements, bond_edges, bond_types, alignment = self.get_graph(mol)
            atoms.append(elements)
            batch_bond_edges.append(bond_edges)
            batch_bond_types.append(bond_types)
            batch_alignment.append(alignment)
            all_token_ids.append([
                self.vocab[t] if t in self.vocab else self.vocab["[UNK]"]
                for t in tokens
            ])

        # NOTE: sequences longer than max_len are truncated, but the alignment
        # indices are not adjusted, so they may point past the end of input_ids.
        padded_sequences, attention_masks = self.pad_sequences_with_attention_masks(
            all_token_ids, max_len)
        # Explicit dtype: torch would infer float32 for an empty batch.
        return {
            'input_ids': torch.tensor(padded_sequences, dtype=torch.long),
            'attention_mask': torch.tensor(attention_masks, dtype=torch.long),
            'bond_edges': batch_bond_edges,
            'bond_types': batch_bond_types,
            'alignments': batch_alignment,
            'atoms': atoms,
        }

    def smiles_to_mol(self, smiles: str, attributable: bool):
        """Parse ``smiles`` into a ``MolecularGraph`` and its surface tokens.

        Returns:
            ``(mol, words)`` where ``words`` lists, in order, every bond
            character and symbol (atom, branch bracket, ring number or ``.``)
            of the string. Each atom's ``tok_idx`` is the position of its
            symbol in ``words``.

        Raises:
            SMILESParserError: for an empty or malformed SMILES string.
        """
        if smiles == "":
            raise SMILESParserError(smiles, "empty SMILES", 0)

        mol = MolecularGraph(attributable=attributable)
        tokens = deque(tokenize_smiles(smiles))
        i = 0
        tokenized_words = []
        while tokens:
            # Each call parses one dot-separated fragment. Passing the index of
            # the last word recorded so far keeps atom tok_idx values pointing
            # into tokenized_words for every fragment, not just the first one.
            i = self.derive_mol_from_tokens(
                mol, smiles, tokens, tokenized_words, len(tokenized_words) - 1, i)
        return mol, tokenized_words

    @staticmethod
    def get_graph(mol):
        """Extract plain-Python graph data from a parsed ``MolecularGraph``.

        Returns:
            elements: element symbol of each atom, in ``mol.get_atoms()`` order.
            bond_edges: ``[src, dst]`` pairs; each bond is emitted once in each
                direction.
            bond_types: bond order for each entry of ``bond_edges``.
            alignment: ``{atom position: atom.tok_idx}``.
        """
        atoms = mol.get_atoms()
        elements = [a.element for a in atoms]
        alignment = {pos: a.tok_idx for pos, a in enumerate(atoms)}

        bond_edges = []
        bond_types = []
        seen_pairs = set()
        # NOTE(review): this reads MolecularGraph's private _adj_list. Ring
        # bonds appear to be stored under both endpoints there, which used to
        # produce duplicate edges; bonds are therefore de-duplicated by their
        # unordered atom pair. A pair never carries two bonds, because
        # make_ring_bonds rejects ring bonds between already-bonded atoms.
        for src, out_bonds in enumerate(mol._adj_list):
            for b in out_bonds:
                pair = (min(src, b.dst), max(src, b.dst))
                if pair in seen_pairs:
                    continue
                seen_pairs.add(pair)
                bond_edges.extend(([src, b.dst], [b.dst, src]))
                bond_types.extend((b.order, b.order))
        return elements, bond_edges, bond_types, alignment

    def derive_mol_from_tokens(self, mol, smiles, tokens, words, tok_idx, i):
        """Parse one dot-separated fragment of ``smiles`` into ``mol``.

        Tokens are consumed from the left of ``tokens`` until a ``.`` token
        (which is consumed and recorded in ``words``) or the end of input.

        Args:
            mol: ``MolecularGraph`` being built; atoms/bonds are added in place.
            smiles: the full SMILES string (used for symbol extraction and
                error positions).
            tokens: deque of tokens from ``tokenize_smiles``; consumed in place.
            words: list to which every bond character and symbol is appended.
            tok_idx: index of the last entry already in ``words`` (-1 when
                empty); each atom's ``tok_idx`` is set to its symbol's index.
            i: running attribution counter passed to ``Attribution``.

        Returns:
            The updated attribution counter.

        Raises:
            SMILESParserError: for an invalid atom symbol, a chain or branch
                starting with a non-atom, unbalanced brackets, a hanging ring
                number, an invalid ring closure, or if ``mol`` is still empty.
        """
        tok = None
        prev_stack = deque()  # keep track of previous atom on the current chain
        branch_stack = deque()  # keep track of open branches
        ring_log = dict()  # keep track of hanging ring numbers
        chain_start = True

        # None marks "no previous atom": the next atom becomes a root.
        prev_stack.append(tok)
        while tokens:
            tok = tokens.popleft()
            bond_char = tok.extract_bond_char(smiles)
            symbol, symbol_type = tok.extract_symbol(smiles), tok.token_type
            # The bond character (if any) is recorded as its own word.
            if bond_char:
                words.append(bond_char)
                tok_idx += 1
            words.append(symbol)
            tok_idx += 1
            prev_atom = prev_stack[-1]

            if symbol_type == SMILESTokenTypes.DOT:
                break

            elif symbol_type == SMILESTokenTypes.ATOM:
                curr = smiles_to_atom(symbol)
                if curr is None:
                    err_msg = "invalid atom symbol '{}'".format(symbol)
                    raise SMILESParserError(smiles, err_msg, tok.start_idx)
                curr.start_idx = tok.start_idx
                curr.end_idx = tok.end_idx
                curr.tok_idx = tok_idx  # index of this atom's symbol in words
                curr, i = self.attach_atom(mol, bond_char, curr, prev_atom, i, tok)
                prev_stack.pop()
                prev_stack.append(curr)
                chain_start = False

            elif chain_start:
                err_msg = "SMILES chain begins with non-atom"
                raise SMILESParserError(smiles, err_msg, tok.start_idx)

            elif symbol_type == SMILESTokenTypes.BRANCH:
                if symbol == "(":
                    # A branch starts from the current atom; it is restored
                    # when the matching ')' pops the stack.
                    branch_stack.append(tok)
                    prev_stack.append(prev_atom)
                    chain_start = True
                else:
                    if not branch_stack:
                        err_msg = "hanging ')' bracket"
                        raise SMILESParserError(smiles, err_msg, tok.start_idx)
                    branch_stack.pop()
                    prev_stack.pop()

            elif symbol_type == SMILESTokenTypes.RING:
                if symbol not in ring_log:
                    # Opening ring number: reserve a bond slot on the atom now,
                    # filled in when the ring closes.
                    lpos = mol.add_placeholder_bond(src=prev_atom.index)
                    ring_log[symbol] = (tok, prev_atom, lpos)
                else:
                    ltoken, latom, lpos = ring_log.pop(symbol)
                    self.make_ring_bonds(
                        mol=mol, smiles=smiles,
                        ltoken=ltoken, latom=latom, lpos=lpos,
                        rtoken=tok, ratom=prev_atom
                    )

            else:
                # should not happen
                raise Exception("invalid symbol type")
            i += 1

        if len(mol) == 0:
            err_idx = (len(smiles) if (tok is None) else tok.start_idx) - 1
            raise SMILESParserError(smiles, "empty SMILES fragment", err_idx)

        if branch_stack:
            err_idx = branch_stack[-1].start_idx
            raise SMILESParserError(smiles, "hanging '(' bracket", err_idx)

        if ring_log:
            rnum, (tok, _, _) = list(ring_log.items())[-1]
            err_msg = "hanging ring number '{}'".format(rnum)
            raise SMILESParserError(smiles, err_msg, tok.start_idx)
        return i

    @staticmethod
    def attach_atom(mol, bond_char, atom, prev_atom, i, tok):
        """Add ``atom`` to ``mol`` and bond it to ``prev_atom`` if there is one.

        The atom is marked as a root when ``prev_atom`` is ``None`` (start of a
        fragment). The bond order/stereo come from ``smiles_to_bond(bond_char)``.
        An implicit bond between two aromatic atoms gets order 1.5.

        Both the atom and the bond are attributed to ``tok`` at counter ``i``.
        The counter is advanced by one first when an explicit bond character
        precedes the atom.

        Returns:
            ``(atom, i)`` with the possibly-advanced counter.
        """
        is_root = (prev_atom is None)
        if bond_char:
            i += 1
        o = mol.add_atom(atom, mark_root=is_root)
        mol.add_attribution(o, [Attribution(i, str(tok))])
        if not is_root:
            src, dst = prev_atom.index, atom.index
            order, stereo = smiles_to_bond(bond_char)
            if prev_atom.is_aromatic and atom.is_aromatic and (bond_char is None):
                order = 1.5  # handle implicit aromatic bonds, e.g. cc
            o = mol.add_bond(src=src, dst=dst, order=order, stereo=stereo)
            mol.add_attribution(o, [Attribution(i, str(tok))])
        return atom, i

    @staticmethod
    def make_ring_bonds(mol, smiles, ltoken, latom, lpos, rtoken, ratom):
        """Close a ring between ``latom`` (opened by ``ltoken``) and ``ratom`` (``rtoken``).

        The two ends are accepted when:
          - both specify the same bond character,
          - only one of them specifies one, or
          - both bond characters are in ``SMILES_STEREO_BONDS``.

        An unmarked closure between two aromatic atoms gets order 1.5;
        otherwise the higher of the two orders is used. ``lpos`` is the slot
        returned by ``mol.add_placeholder_bond`` when the ring was opened.

        Raises:
            SMILESParserError: if the atoms are already bonded or the ring
                bond characters do not match.
        """
        if mol.has_bond(latom.index, ratom.index):
            err_msg = "ring bond specified between already-bonded atoms"
            raise SMILESParserError(smiles, err_msg, ltoken.start_idx)

        lbond_char = ltoken.extract_bond_char(smiles)
        rbond_char = rtoken.extract_bond_char(smiles)

        # Check that the ring bonds match. Swap so that if bonds[0] is None,
        # then bonds[1] is None as well.
        bonds = (lbond_char, rbond_char)
        if bonds[0] is None:
            bonds = (bonds[1], bonds[0])

        bonds_match = (
            bonds[0] == bonds[1]
            or bonds[1] is None
            or all(x in SMILES_STEREO_BONDS for x in bonds)
        )
        if not bonds_match:
            err_msg = "mismatched ring bonds"
            raise SMILESParserError(smiles, err_msg, ltoken.start_idx)

        lorder, lstereo = smiles_to_bond(lbond_char)
        rorder, rstereo = smiles_to_bond(rbond_char)
        if latom.is_aromatic and ratom.is_aromatic and (bonds == (None, None)):
            lorder = rorder = 1.5  # handle implicit aromatic bonds, e.g. c1ccccc1

        mol.add_ring_bond(
            a=latom.index, a_stereo=lstereo, a_pos=lpos,
            b=ratom.index, b_stereo=rstereo,
            order=max(lorder, rorder)
        )

    @staticmethod
    def pad_sequences_with_attention_masks(sequences, max_len=None, padding_value=0):
        """Pad or truncate every sequence to ``max_len`` and build attention masks.

        Args:
            sequences: list of token-id sequences.
            max_len: target length; ``None`` uses the longest sequence
                (0 for an empty batch).
            padding_value: value appended to sequences shorter than ``max_len``.

        Returns:
            ``(padded_sequences, attention_masks)``: lists of lists, where each
            mask has 1 for a kept token and 0 for padding.
        """
        if max_len is None:
            max_len = max((len(seq) for seq in sequences), default=0)

        padded_sequences = []
        attention_masks = []
        for seq in sequences:
            kept = list(seq[:max_len])
            n_pad = max_len - len(kept)
            padded_sequences.append(kept + [padding_value] * n_pad)
            attention_masks.append([1] * len(kept) + [0] * n_pad)
        return padded_sequences, attention_masks
