import torch
import numpy as np
import time
from ogb.lsc import PCQM4MDataset
import random
from typing import Tuple, Sequence, List, Dict
import os
import pickle

# Symbols that SmilesParser treats as explicit bond tokens.
bond_tokens = tuple('.-=#:')

# Single-character labels '1'..'9'; SmilesParser uses them to link the
# current atom back to an earlier atom carrying the same label.
digits = tuple(map(str, range(1, 10)))

# Atom tokens regarded as aromatic: SmilesParser joins two of these with a
# ':' bond when no explicit bond symbol is written between them.
aromatic_tokens = (
    'c',
    'n',
    'o',
    'nH',
    's',
    '13cH',
    '13c',
    'b',
    'p',
    'pH',
    '14cH',
    'se',
    'n+',
    '14c',
    'c-',
    'n-',
    '15nH',
    'cH-',
    'bH',
)

# Every token that is NOT an atom: bonds, labels and branch parentheses.
non_atom_tokens = (*bond_tokens, *digits, '(', ')')

# Tokens dropped by SmilesParser before the atom/bond pass.
filtered_tokens = ('/', '\\', '0')

def split_smiles(input: str) -> List[str]:
    r"""
    Break a bracket-free SMILES fragment into tokens.

    Leading and trailing whitespace is removed first. Every character becomes
    its own token, except the two-letter element symbols 'Cl' and 'Br', which
    are kept together.

    For example:
        'CCClBr' -> ['C', 'C', 'Cl', 'Br']
    """
    two_letter_atoms = ('Cl', 'Br')
    text = input.strip()
    tokens = []
    pos = 0
    end = len(text)
    while pos < end:
        # Look at the current character together with its successor (if any).
        candidate = text[pos:pos + 2]
        if candidate in two_letter_atoms:
            tokens.append(candidate)
            pos += 2
        else:
            tokens.append(text[pos])
            pos += 1
    return tokens


class SmilesParser:
    r"""
    Parser that converts a SMILES string into atom tokens, bond tokens and
    the pair of atom ids each bond connects.

    Relies on the module-level token tables (``bond_tokens``, ``digits``,
    ``aromatic_tokens``, ``non_atom_tokens``, ``filtered_tokens``) and on
    ``split_smiles``.
    """
    def __init__(self, max_atoms: int = 80):
        # Stored only; nothing in this class reads or enforces the limit.
        self.max_atoms = max_atoms

    @staticmethod
    def _tokenize(s: str) -> List[str]:
        r"""
        Stage 1: split a SMILES string into tokens and drop filtered tokens.

        The content of each '[...]' bracket is kept as a single token; text
        outside brackets is split by ``split_smiles``.

        Note: Temporarily remove '/' and '\', decorators for double bonds.
        """
        tokens = []
        for segment in s.split('['):
            parts = segment.split(']')
            if len(parts) == 2:
                # parts[0] is the inside of a bracket atom, parts[1] the
                # plain text that follows the closing ']'.
                tokens.append(parts[0])
                tokens.extend(split_smiles(parts[1]))
            else:
                tokens.extend(split_smiles(parts[0]))
        return [tok for tok in tokens if tok not in filtered_tokens]

    @staticmethod
    def _ring_label(open_rings, label, explicit_bond, prev_id, prev_tk,
                    bonds, bond_index) -> None:
        r"""
        Handle one occurrence of a ring label attached to the current atom.

        If ``label`` is not open, it is opened at the current atom and any
        bond symbol written in front of it is remembered. Otherwise the label
        is closed: a bond from the current atom to the opening atom is
        appended and the label is released so that it can be reused later.
        """
        opened = open_rings.pop(label, None)
        if opened is None:
            open_rings[label] = (prev_id, prev_tk in aromatic_tokens, explicit_bond)
            return

        open_id, open_is_aromatic, open_bond = opened
        if explicit_bond is not None:
            bond = explicit_bond
        elif open_bond is not None:
            # The bond symbol was written at the opening occurrence.
            bond = open_bond
        elif open_is_aromatic and prev_tk in aromatic_tokens:
            bond = ':'
        else:
            bond = '-'
        bonds.append(bond)
        bond_index.append((prev_id, open_id))

    def __call__(self, s: str) -> Tuple[List[str], List[str], List[Tuple[int, int]]]:
        r"""
        Parse a SMILES string.

        Returns:
            ``(atoms, bonds, bond_index)`` where ``atoms`` is the list of atom
            tokens in order of appearance, ``bonds`` is the list of bond
            tokens, and ``bond_index[k]`` is the pair of 1-based atom ids
            joined by ``bonds[k]``. Bonds without an explicit symbol become
            ':' when both atoms are aromatic tokens and '-' otherwise.
            An input that yields no tokens returns three empty lists.

        Raises:
            ValueError: if a bond symbol is not followed by an atom or a ring
                label, or if a ')' has no matching '('.
        """
        res = self._tokenize(s)
        if not res:
            return [], [], []

        # Stage 2: obtain atom & bond tokens and the bond indices.
        # The first token is taken to be the first atom (id 1).
        atoms = [res[0]]
        bonds, bond_index = [], []

        atom_cnt = 1
        prev_tk = res[0]
        prev_id = 1

        # label -> (opening atom id, opening atom is aromatic, opening bond)
        # for every ring label that is currently open.
        open_rings = {}
        # Saved (atom id, atom token) to return to when a branch closes.
        brackets_stack = []

        pt = 1
        while pt < len(res):
            tk = res[pt]

            if tk not in non_atom_tokens:
                # Atom without an explicit bond symbol in front of it.
                atom_cnt += 1
                atoms.append(tk)
                if prev_tk in aromatic_tokens and tk in aromatic_tokens:
                    bonds.append(':')
                else:
                    bonds.append('-')
                bond_index.append((prev_id, atom_cnt))

                prev_tk = tk
                prev_id = atom_cnt
                pt += 1
                continue

            if tk in bond_tokens:
                if pt + 1 >= len(res):
                    raise ValueError(
                        'bond token {!r} is not followed by an atom or ring label'.format(tk)
                    )
                next_tk = res[pt + 1]
                if next_tk in non_atom_tokens:
                    if next_tk not in digits:
                        raise ValueError(
                            'bond token {!r} is followed by unexpected token {!r}'.format(tk, next_tk)
                        )
                    self._ring_label(open_rings, next_tk, tk, prev_id, prev_tk,
                                     bonds, bond_index)
                else:
                    # Explicit bond to a new atom.
                    bonds.append(tk)
                    atom_cnt += 1
                    atoms.append(next_tk)
                    bond_index.append((prev_id, atom_cnt))

                    prev_tk = next_tk
                    prev_id = atom_cnt

                # Both the bond symbol and the token after it are consumed.
                pt += 2
                continue

            if tk in digits:
                self._ring_label(open_rings, tk, None, prev_id, prev_tk,
                                 bonds, bond_index)
                pt += 1
                continue

            if tk == '(':
                brackets_stack.append((prev_id, prev_tk))
                pt += 1
                continue

            if tk == ')':
                if not brackets_stack:
                    raise ValueError("unbalanced ')' in SMILES string")
                prev_id, prev_tk = brackets_stack.pop()
                pt += 1
                continue

            raise ValueError('unknown token')

        return atoms, bonds, bond_index
    

def get_vocabulary(smiles_data: Sequence[str]) -> Tuple[List[str], Dict[str, int]]:
    r"""
    Get ordered token vocabulary from list of smiles strings.

    Every string is parsed with ``SmilesParser`` and its atom and bond tokens
    are counted.

    Returns:
        ``(sorted_vocab, vocab_cnt)`` where ``vocab_cnt`` maps each token to
        its number of occurrences and ``sorted_vocab`` lists the tokens by
        decreasing count. Tokens with equal counts keep the order in which
        they were first seen, so the result is the same on every run.
    """
    parser = SmilesParser()

    # A dict preserves insertion order, unlike a set of strings whose
    # iteration order changes with hash randomization between runs.
    vocab_cnt = {}
    for s in smiles_data:
        atoms, bonds, _ = parser(s)
        for tk in atoms + bonds:
            vocab_cnt[tk] = vocab_cnt.get(tk, 0) + 1

    # Sort vocabulary by frequency; sorted() is stable, also with reverse=True.
    sorted_vocab = sorted(vocab_cnt, key=vocab_cnt.get, reverse=True)
    return sorted_vocab, vocab_cnt

def get_tokenizers(smiles_data: Sequence[str], dataset_name: str = None) -> Tuple[List[str], Dict[str, int]]:
    r"""
    Get tokenizers (id2tk and tk2id) from list of smiles strings.

    The ogb-lsc tokenizers stored under ``./saved_tokenizers`` are used as the
    starting point; every token of ``smiles_data`` that they do not contain is
    appended, most frequent first, and receives the next free id.

    Args:
        smiles_data: SMILES strings to collect tokens from.
        dataset_name: if given, the tokenizers are cached under
            ``./saved_tokenizers/{dataset_name}_*.pickle``. An existing cache
            is returned without looking at ``smiles_data``.

    Returns:
        ``(id2tk, tk2id)``: the token list indexed by id and the token -> id
        mapping.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file;
    only use tokenizer files from a trusted source.
    """
    print('Generating tokenizers for SMILES data...')

    id2tk_name, tk2id_name = None, None
    if dataset_name is not None:
        id2tk_name = './saved_tokenizers/{}_id2tk.pickle'.format(dataset_name)
        tk2id_name = './saved_tokenizers/{}_tk2id.pickle'.format(dataset_name)
        # Only trust the cache when both halves are present.
        if os.path.exists(id2tk_name) and os.path.exists(tk2id_name):
            with open(id2tk_name, 'rb') as f:
                id2tk = pickle.load(f)
            with open(tk2id_name, 'rb') as f:
                tk2id = pickle.load(f)
            return id2tk, tk2id

    # starting from ogb-lsc tokenizers
    with open('./saved_tokenizers/ogb_lsc_id2tk.pickle', 'rb') as f:
        id2tk = pickle.load(f)
    with open('./saved_tokenizers/ogb_lsc_tk2id.pickle', 'rb') as f:
        tk2id = pickle.load(f)

    sorted_vocab, _ = get_vocabulary(smiles_data)

    # Set copy of id2tk so the membership test is O(1) per token.
    known = set(id2tk)
    for tk in sorted_vocab:
        if tk not in known:
            print('add token {}'.format(tk))
            # The id must equal the index the token gets in id2tk.
            tk2id[tk] = len(id2tk)
            id2tk.append(tk)
            known.add(tk)

    if dataset_name is not None:
        with open(id2tk_name, 'wb') as f:
            pickle.dump(id2tk, f)
        with open(tk2id_name, 'wb') as f:
            pickle.dump(tk2id, f)
    return id2tk, tk2id

def get_maxlen(smiles_data: Sequence[str]) -> int:
    r"""
    Return the largest token count over all given SMILES strings.

    The token count of a string is the number of atom tokens plus the number
    of bond tokens produced by ``SmilesParser``. Returns 0 when
    ``smiles_data`` is empty.
    """
    parser = SmilesParser()
    token_counts = (
        len(atoms) + len(bonds)
        for atoms, bonds, _ in map(parser, smiles_data)
    )
    return max(token_counts, default=0)