from collections import defaultdict
from functools import partial
import json
import logging
import os
import pickle
import uuid

import numpy as np
from matplotlib.pyplot import MultipleLocator
import torch
import seaborn as sns
from IPython.core.display import display, HTML, Javascript
from torch.utils.data import Dataset
# from bertviz.util import format_attention, num_layers
import pkg_resources
from tqdm import tqdm
from simpletransformers.classification.classification_utils import preprocess_data_multiprocessing, preprocess_data
from multiprocessing import Pool
from typing import List, Tuple
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer
from rdkit import Chem
from rdkit.Chem.Draw import SimilarityMaps
from matplotlib import pyplot as plt
import pandas as pd
from rxnfp.tokenization import RegexTokenizer

# Module-level logger used by the dataset builders in this file.
logger = logging.getLogger(name=__name__)

# Tokenizer special tokens that must never be counted as atoms; this is the
# default ``special_tokens`` for ``is_atom`` and ``number_tokens``.
BAD_TOKS = [
    "[CLS]",
    "[SEP]",
]
# Condition slot names (presumably catalyst / solvents / reagents — confirm).
CONDITION_TYPE = [
    'c1',
    's1',
    's2',
    'r1',
    'r2',
]

def canonicalize_smiles(smi, clear_map=False):
    """Return the RDKit canonical SMILES for ``smi``, or ``''`` if it cannot be used.

    Args:
        smi: SMILES string. Missing values (anything ``pd.isna`` flags, e.g.
            ``None``/``NaN``) yield ``''``.
        clear_map: If True, clear each atom's ``molAtomMapNumber`` property
            before canonicalizing.

    Returns:
        Canonical SMILES, or ``''`` for missing or unparsable input.
    """
    if pd.isna(smi):
        return ''
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return ''
    if clear_map:
        # Plain loop: this runs only for its side effect on each atom.
        for atom in mol.GetAtoms():
            atom.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(mol)



def caonicalize_rxn_smiles(rxn_smiles):
    """Canonicalize a ``reactants>agents>products`` SMILES to ``reactants>>products``.

    The (misspelled) name is kept for backward compatibility. The middle
    (agents) field is discarded.

    Returns:
        ``'{reactants}>>{products}'`` with both sides canonicalized, or ``''``
        when the input is not a three-field ``>``-separated string or either
        side fails to canonicalize.
    """
    try:
        react, _, prod = rxn_smiles.split('>')
        react, prod = [canonicalize_smiles(x) for x in [react, prod]]
    except Exception:
        # Best-effort: malformed input yields ''. Narrower than a bare
        # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
        return ''
    if '' in [react, prod]:
        return ''
    return f'{react}>>{prod}'

def build_classification_dataset(
    data, tokenizer, args, mode, multi_label, output_mode, no_cache
):
    """Tokenize a classification dataset, loading from / saving to a feature cache.

    Args:
        data: ``(text_a, labels)`` for single-sentence tasks or
            ``(text_a, text_b, labels)`` for sentence-pair tasks.
        tokenizer: Tokenizer handed to simpletransformers' ``preprocess_data*``.
        args: Model args; reads ``cache_dir``, ``model_type``, ``max_seq_length``,
            ``labels_list``, ``labels_map``, ``regression``,
            ``reprocess_input_data``, ``no_cache``, ``use_cached_eval_features``,
            ``use_multiprocessing``, ``use_multiprocessing_for_evaluation``,
            ``multiprocessing_chunksize``, ``process_count``, ``silent`` and
            ``use_temperature``.
        mode: ``"train"`` or ``"dev"``; selects the multiprocessing flag and is
            part of the cache file name.
        multi_label: Whether each label is a list of labels (mapped element-wise
            through ``args.labels_map``).
        output_mode: ``"classification"`` (long labels) or ``"regression"``
            (float labels); ignored when ``args.use_temperature`` is set.
        no_cache: When True, the features are not written to the cache.

    Returns:
        ``(examples, labels)``. With ``args.use_temperature`` the labels are a
        ``(condition_labels, temperature)`` pair: every label column but the
        last as ``long``, and the last column (kept 2-D) as ``float``.
    """
    # ``data`` is a 2- or 3-tuple of columns, so ``len(data)`` is always 2 or 3.
    # Key the cache on the number of examples so different datasets don't collide.
    cached_features_file = os.path.join(
        args.cache_dir,
        "cached_{}_{}_{}_{}_{}".format(
            mode,
            args.model_type,
            args.max_seq_length,
            len(args.labels_list),
            len(data[0]),
        ),
    )

    if os.path.exists(cached_features_file) and (
        (not args.reprocess_input_data and not args.no_cache)
        or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
    ):
        data = torch.load(cached_features_file)
        logger.info(f" Features loaded from cache at {cached_features_file}")
        examples, labels = data
    else:
        logger.info(" Converting to features started. Cache is not used.")

        if len(data) == 3:
            # Sentence pair task
            text_a, text_b, labels = data
        else:
            text_a, labels = data
            text_b = None

        # If labels_map is defined, then labels need to be replaced with ints
        if args.labels_map and not args.regression:
            if multi_label:
                labels = [[args.labels_map[l] for l in label]
                          for label in labels]
            else:
                labels = [args.labels_map[label] for label in labels]

        if (mode == "train" and args.use_multiprocessing) or (
            mode == "dev" and args.use_multiprocessing_for_evaluation
        ):
            if args.multiprocessing_chunksize == -1:
                # Size chunks by the number of texts, not the number of columns.
                chunksize = max(len(text_a) // (args.process_count * 2), 500)
            else:
                chunksize = args.multiprocessing_chunksize

            chunks = [
                (
                    text_a[i: i + chunksize],
                    text_b[i: i + chunksize] if text_b is not None else None,
                    tokenizer,
                    args.max_seq_length,
                )
                for i in range(0, len(text_a), chunksize)
            ]

            with Pool(args.process_count) as p:
                examples = list(
                    tqdm(
                        p.imap(preprocess_data_multiprocessing, chunks),
                        # imap yields one result per chunk, not per text.
                        total=len(chunks),
                        disable=args.silent,
                    )
                )

            # Re-join the per-chunk tokenizer outputs field by field.
            examples = {
                key: torch.cat([example[key] for example in examples])
                for key in examples[0]
            }
        else:
            examples = preprocess_data(
                text_a, text_b, labels, tokenizer, args.max_seq_length
            )

        if not args.use_temperature:
            if output_mode == "classification":
                labels = torch.tensor(labels, dtype=torch.long)
            elif output_mode == "regression":
                labels = torch.tensor(labels, dtype=torch.float)
            data = (examples, labels)
        else:
            # Last label column is the temperature; the rest are condition ids.
            labels = torch.tensor(labels)
            condition_labels = labels[:, :-1].long()
            temperature = labels[:, -1:].float()

            data = (examples, (condition_labels, temperature))

        if not args.no_cache and not no_cache:
            logger.info(" Saving features into cached file %s",
                        cached_features_file)
            torch.save(data, cached_features_file)

    return data


class ConditionWithTempDataset(Dataset):
    """Dataset of tokenized features paired with condition labels and temperature.

    Features and labels come from ``build_classification_dataset``. ``labels``
    is indexed as ``labels[0]`` (condition labels) and ``labels[1]``
    (temperature), matching the pair that builder returns when
    ``args.use_temperature`` is set.
    """

    def __init__(self, data, tokenizer, args, mode, multi_label, output_mode, no_cache):
        features, targets = build_classification_dataset(
            data, tokenizer, args, mode, multi_label, output_mode, no_cache
        )
        self.examples = features
        self.labels = targets

    def __len__(self):
        input_ids = self.examples["input_ids"]
        return len(input_ids)

    def __getitem__(self, index):
        features = {name: column[index] for name, column in self.examples.items()}
        condition_row = self.labels[0][index]
        temperature_row = self.labels[1][index]
        return features, (condition_row, temperature_row)

class ConditionWithTextDataset(Dataset):
    """Dataset of tokenized features plus extra per-row columns passed as kwargs.

    Each item merges the tokenized feature columns with every keyword column
    (a keyword column overrides a feature column of the same name) and pairs the
    result with ``labels[index]``.
    """

    def __init__(self, data, tokenizer, args, mode, multi_label, output_mode, no_cache, **kwargs):
        features, targets = build_classification_dataset(
            data, tokenizer, args, mode, multi_label, output_mode, no_cache
        )
        self.examples = features
        self.labels = targets
        self.kwargs = kwargs

    def __len__(self):
        input_ids = self.examples["input_ids"]
        return len(input_ids)

    def __getitem__(self, index):
        item = {name: column[index] for name, column in self.examples.items()}
        item.update({name: values[index] for name, values in self.kwargs.items()})
        return item, self.labels[index]




def encode(data):
    """Tokenize one line padded to a fixed length and truncate every field to it.

    Args:
        data: ``(tokenizer, line, max_seq_length)``; packed into one tuple so the
            function can be used directly with ``Pool.imap``.

    Returns:
        The tokenizer's output mapping with each value sliced to
        ``max_seq_length``. NOTE(review): slicing (rather than tokenizer
        truncation) means an over-long input loses its trailing tokens,
        including any end-of-sequence special token.
    """
    tokenizer, text, limit = data
    batch = tokenizer(text, padding="max_length", max_length=limit)
    truncated = {}
    for field, ids in batch.items():
        truncated[field] = ids[:limit]
    return truncated


def encode_sliding_window(data):
    """Tokenize one line into fixed-size windows of token ids.

    Args:
        data: ``(tokenizer, line, max_seq_length, special_tokens_count, stride,
            no_padding)``. ``stride`` is a fraction of ``max_seq_length``.

    Returns:
        A list of id lists. A line with more than
        ``max_seq_length - special_tokens_count`` tokens is split into windows of
        that many tokens, starting every ``int(max_seq_length * stride)`` tokens
        (at least 1). Unless ``no_padding`` is set, each window is wrapped as
        ``cls_token_id + ids + sep_token_id`` and right-padded with
        ``pad_token_id`` to ``max_seq_length``.
    """
    tokenizer, line, max_seq_length, special_tokens_count, stride, no_padding = data

    tokens = tokenizer.tokenize(line)
    window = max_seq_length - special_tokens_count
    # A small stride fraction can round down to 0, which range() rejects.
    step = max(1, int(max_seq_length * stride))
    if len(tokens) > window:
        token_sets = [tokens[i: i + window] for i in range(0, len(tokens), step)]
    else:
        token_sets = [tokens]

    if no_padding:
        return [tokenizer.convert_tokens_to_ids(window_tokens) for window_tokens in token_sets]

    sep_token = tokenizer.sep_token_id
    cls_token = tokenizer.cls_token_id
    pad_token = tokenizer.pad_token_id

    features = []
    for window_tokens in token_sets:
        # Convert the token strings first: the CLS/SEP values are already ids and
        # must not be passed through convert_tokens_to_ids.
        input_ids = (
            [cls_token] + tokenizer.convert_tokens_to_ids(window_tokens) + [sep_token]
        )
        input_ids = input_ids + [pad_token] * (max_seq_length - len(input_ids))

        assert len(input_ids) == max_seq_length

        features.append(input_ids)

    return features


def preprocess_batch_for_hf_dataset(dataset, tokenizer, max_seq_length):
    """Tokenize the ``"text"`` column of a batch, truncating/padding to a fixed length.

    Intended as a batched ``Dataset.map`` callback; returns the tokenizer output.
    """
    tokenizer_kwargs = dict(
        truncation=True,
        padding="max_length",
        max_length=max_seq_length,
    )
    return tokenizer(text=dataset["text"], **tokenizer_kwargs)


def load_hf_dataset(data, tokenizer, args):
    """Load plain-text file(s) with the ``"text"`` builder and tokenize them.

    Args:
        data: A single path, or whatever ``data_files`` accepts for several splits.
        tokenizer: Tokenizer used by ``preprocess_batch_for_hf_dataset``.
        args: Reads ``reprocess_input_data`` (forces a re-download/rebuild) and
            ``max_seq_length``.

    Returns:
        The tokenized split when ``data`` is a single path, otherwise the whole
        dataset dict; only ``input_ids`` is exposed, formatted as PyTorch tensors.

    NOTE(review): ``load_dataset`` (presumably ``datasets.load_dataset``) is not
    among this file's visible imports, so calling this raises NameError unless
    the name is provided elsewhere — add the import.
    """
    if args.reprocess_input_data:
        download_mode = "force_redownload"
    else:
        download_mode = "reuse_dataset_if_exists"
    raw = load_dataset("text", data_files=data, download_mode=download_mode)

    def _tokenize_batch(batch):
        return preprocess_batch_for_hf_dataset(
            batch, tokenizer=tokenizer, max_seq_length=args.max_seq_length
        )

    tokenized = raw.map(_tokenize_batch, batched=True)
    tokenized.set_format(type="pt", columns=["input_ids"])

    if not isinstance(data, str):
        return tokenized
    # A single file becomes one split that the datasets library always names
    # "train", even when the file holds evaluation data.
    return tokenized["train"]


def _encode_line_to_ids(data):
    """Encode one ``(tokenizer, line)`` pair with ``tokenizer.encode``.

    Module-level so it can be pickled for ``multiprocessing.Pool.imap``.
    NOTE(review): ``tokenizer.encode`` is assumed to be the intended per-line
    encoder here. This file's ``encode`` needs a max length and returns a padded
    dict, which does not fit block concatenation — confirm.
    """
    tokenizer, line = data
    return tokenizer.encode(line)


class SimpleCenterDataset(Dataset):
    """Block-based LM dataset paired with a per-token reaction-center mask.

    Two files are used: ``file_path``, and a companion path obtained by replacing
    ``'rxn'`` with ``'rxn_center'``. Each file's non-blank lines are encoded,
    concatenated and cut into blocks the same way. A token is flagged as a
    reaction-center token (``1``) wherever the two encodings differ.
    Items are ``(input_ids_tensor, center_mask_row)``.
    """

    def __init__(
        self,
        tokenizer,
        args,
        file_path,
        mode,
        block_size=512,
        special_tokens_count=2,
        sliding_window=False,
    ):
        assert os.path.isfile(file_path)
        block_size = block_size - special_tokens_count
        filename = os.path.basename(file_path)
        # NOTE(review): str.replace rewrites every 'rxn' in the path, directories included.
        rxn_center_file_path = file_path.replace('rxn', 'rxn_center')
        cached_features_file = os.path.join(
            args.cache_dir,
            args.model_type + "_cached_lm_" + str(block_size) + "_" + filename,
        )

        use_cache = os.path.exists(cached_features_file) and (
            (not args.reprocess_input_data and not args.no_cache)
            or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
        )
        if use_cache and self._load_cache(cached_features_file):
            return

        if sliding_window:
            raise NotImplementedError(
                "SimpleCenterDataset does not support sliding_window=True: "
                "reaction-center masks are only built for block-based encoding."
            )

        logger.info(" Creating features from dataset file at %s", file_path)

        token_ids = self._encode_file(tokenizer, args, file_path)
        center_token_ids = self._encode_file(tokenizer, args, rxn_center_file_path)
        self.examples = self._split_into_blocks(tokenizer, token_ids, block_size)
        self.examples_with_center = self._split_into_blocks(
            tokenizer, center_token_ids, block_size
        )

        examples_tensor = torch.tensor(self.examples)
        center_tensor = torch.tensor(self.examples_with_center)
        # Differing shapes could broadcast into a meaningless mask, so fail loudly.
        if examples_tensor.shape != center_tensor.shape:
            raise ValueError(
                f"{file_path} and {rxn_center_file_path} encode to different shapes "
                f"({tuple(examples_tensor.shape)} vs {tuple(center_tensor.shape)}); "
                "both files must tokenize to the same number of tokens."
            )
        self.is_rxn_center_tokens = (examples_tensor != center_tensor).long()

        logger.info(" Saving features into cached file %s",
                    cached_features_file)
        with open(cached_features_file, "wb") as handle:
            # Cache the mask together with the examples so a cache hit is usable.
            pickle.dump(
                (self.examples, self.is_rxn_center_tokens),
                handle,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

    def _load_cache(self, cached_features_file):
        """Populate ``examples``/``is_rxn_center_tokens`` from the cache file.

        Returns False, so that the caller rebuilds the features, when the cached
        object is not an ``(examples, mask)`` pair (e.g. a plain list of examples).
        """
        logger.info(" Loading features from cached file %s",
                    cached_features_file)
        # NOTE: pickle.load can execute arbitrary code; only trust your own cache_dir.
        with open(cached_features_file, "rb") as handle:
            cached = pickle.load(handle)
        if isinstance(cached, tuple) and len(cached) == 2:
            self.examples, self.is_rxn_center_tokens = cached
            return True
        logger.warning(
            " Cached file %s has no reaction-center mask; rebuilding features.",
            cached_features_file,
        )
        return False

    @staticmethod
    def _encode_file(tokenizer, args, path):
        """Encode every non-blank line of ``path``; return one flat list of token ids."""
        with open(path, encoding="utf-8") as f:
            lines = [
                (tokenizer, line)
                for line in f.read().splitlines()
                if (len(line) > 0 and not line.isspace())
            ]

        if args.use_multiprocessing:
            if args.multiprocessing_chunksize == -1:
                chunksize = max(len(lines) // (args.process_count * 2), 500)
            else:
                chunksize = args.multiprocessing_chunksize
            with Pool(args.process_count) as p:
                encoded = list(
                    tqdm(
                        p.imap(_encode_line_to_ids, lines, chunksize=chunksize),
                        total=len(lines),
                    )
                )
        else:
            encoded = [_encode_line_to_ids(line) for line in lines]

        return [token for tokens in encoded for token in tokens]

    @staticmethod
    def _split_into_blocks(tokenizer, token_ids, block_size):
        """Cut ``token_ids`` into ``block_size`` chunks wrapped with special tokens.

        When more than ``block_size`` ids are present, the trailing partial block
        is dropped; otherwise the whole sequence becomes a single block.
        """
        if len(token_ids) > block_size:
            return [
                tokenizer.build_inputs_with_special_tokens(
                    token_ids[i: i + block_size]
                )
                for i in tqdm(range(0, len(token_ids) - block_size + 1, block_size))
            ]
        return [tokenizer.build_inputs_with_special_tokens(token_ids)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item], dtype=torch.long), self.is_rxn_center_tokens[item]

class SimpleCenterIdxDataset(Dataset):
    """Dataset of encoded reaction texts with reaction-center masks and template labels.

    ``dataset_df`` is read via ``dataset_df['text']`` and ``dataset_df['labels']``.
    Each ``labels`` entry must unpack into ``(center_mask, template_label)``.
    Rows whose text is empty or whitespace-only are dropped, and the labels are
    dropped with them so every column stays aligned.

    Items are ``(encoded_inputs, center_mask_row, template_label)``. The center
    mask is framed with one ``False`` entry on each side (presumably the
    [CLS]/[SEP] positions — confirm against the tokenizer) and then padded or
    truncated to ``args.max_seq_length``.

    ``mode``, ``block_size`` and ``special_tokens_count`` are accepted for
    signature compatibility with the other dataset classes but are unused.
    """

    def __init__(
        self,
        tokenizer,
        args,
        dataset_df,
        mode,
        block_size=512,
        special_tokens_count=2,
        sliding_window=False,
    ):
        if sliding_window:
            raise NotImplementedError(
                "SimpleCenterIdxDataset does not support sliding_window=True: "
                "reaction-center masks and template labels are only built for "
                "per-row encoding."
            )

        texts, center_masks_labels, template_labels = self._non_blank_rows(dataset_df)
        lines = [(tokenizer, line, args.max_seq_length) for line in texts]

        mask_labels = [torch.from_numpy(np.array(x)) for x in tqdm(center_masks_labels)]
        mask_labels_pad = torch.stack(
            [self._pad_mask(x, args.max_seq_length) for x in tqdm(mask_labels)]
        )

        self.examples = self._encode_lines(lines, args)
        self._check_alignment(self.examples, mask_labels_pad)

        self.is_rxn_center_tokens = mask_labels_pad
        self.template_labels = template_labels

    @staticmethod
    def _non_blank_rows(dataset_df):
        """Return texts, center masks and template labels for rows with non-blank text.

        Filtering all three together keeps them aligned row for row.
        """
        texts, center_masks_labels, template_labels = [], [], []
        for text, label in zip(dataset_df['text'].tolist(), dataset_df['labels'].tolist()):
            if len(text) > 0 and not text.isspace():
                center_mask, template_label = label
                texts.append(text)
                center_masks_labels.append(center_mask)
                template_labels.append(template_label)
        return texts, center_masks_labels, template_labels

    @staticmethod
    def _pad_mask(mask_label, max_seq_length):
        """Frame a 1-D mask with a False entry on each side; pad/truncate to max_seq_length."""
        assert isinstance(mask_label, torch.Tensor)
        pad_tensor = torch.zeros(1).bool()
        framed = torch.cat([pad_tensor, mask_label, pad_tensor], dim=-1)
        pad_length = max_seq_length - framed.shape[0]
        return torch.cat([framed] + [pad_tensor] * pad_length, dim=-1)[:max_seq_length]

    @staticmethod
    def _encode_lines(lines, args):
        """Run ``encode`` over ``(tokenizer, line, max_seq_length)`` triples, optionally in a Pool."""
        if args.use_multiprocessing:
            if args.multiprocessing_chunksize == -1:
                chunksize = max(len(lines) // (args.process_count * 2), 500)
            else:
                chunksize = args.multiprocessing_chunksize
            with Pool(args.process_count) as p:
                return list(
                    tqdm(
                        p.imap(encode, lines, chunksize=chunksize),
                        total=len(lines),
                    )
                )
        return [encode(line) for line in tqdm(lines, total=len(lines))]

    @staticmethod
    def _check_alignment(examples, mask_labels):
        """Raise ValueError unless each encoded row has a mask row of the same length."""
        if len(examples) != len(mask_labels):
            raise ValueError(
                f"{len(examples)} encoded rows but {len(mask_labels)} "
                "reaction-center masks."
            )
        for row, (tokens, mask_label) in enumerate(zip(examples, mask_labels)):
            n_tokens = len(tokens['input_ids'])
            n_mask = mask_label.shape[0]
            if n_tokens != n_mask:
                raise ValueError(
                    f"Row {row}: {n_tokens} token ids but the reaction-center "
                    f"mask has {n_mask} entries."
                )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):

        data = {k: torch.tensor(v, dtype=torch.long) for k, v in self.examples[item].items()}

        return data, self.is_rxn_center_tokens[item], self.template_labels[item]


def mask_tokens_with_rxn(
    inputs, tokenizer: "PreTrainedTokenizer", args
) -> Tuple[dict, torch.Tensor, object]:
    """Prepare masked inputs/labels for MLM with extra masking of reaction-center tokens.

    Non-special tokens are selected with probability ``args.mlm_probability``.
    Positions flagged in ``is_center_marks`` use ``args.mrc_probability``
    instead. That rule is applied after special tokens are zeroed, so it also
    covers a flagged special token. Padding is never selected. Selected tokens
    become the MLM targets:
    - ELECTRA: every selected token is replaced with the mask token.
    - Otherwise: 80% are masked, 10% replaced with a random token, and 10% kept.

    Args:
        inputs: ``(inputs_dict, is_center_marks, template_label)`` as a tuple or
            list (PyTorch's default collate turns tuple samples into a list).
            The ``inputs_dict['input_ids']`` tensor is modified in place.
        tokenizer: Tokenizer that defines a mask token.
        args: Reads ``mlm_probability``, ``mrc_probability`` and ``model_type``.

    Returns:
        ``(inputs_dict, labels, template_label)``; ``labels`` holds ``-100`` at
        every position that was not selected.

    Raises:
        ValueError: If the tokenizer has no mask token.
        TypeError: If ``inputs`` is not a tuple or list.
    """
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling."
            "Set 'mlm' to False in args if you want to use this tokenizer."
        )
    if not isinstance(inputs, (tuple, list)):
        raise TypeError(
            "mask_tokens_with_rxn expects (inputs_dict, is_center_marks, "
            f"template_label), got {type(inputs).__name__}."
        )
    inputs_dict, is_center_marks, template_label = inputs

    inputs = inputs_dict['input_ids']
    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training
    # (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
        for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0
    )
    # We sample a few tokens in each sequence for masked reaction center modeling training
    # (with probability args.mrc_probability defaults to 0.5)
    probability_matrix.masked_fill_(
        is_center_marks.bool(), value=args.mrc_probability
    )
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    if args.model_type == "electra":
        # For ELECTRA, we replace all masked input tokens with tokenizer.mask_token
        inputs[masked_indices] = mask_token_id
    else:
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = mask_token_id

        # 10% of the time (half of the remaining 20%), replace with a random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    inputs_dict['input_ids'] = inputs
    return inputs_dict, labels, template_label

def is_atom(token: str, special_tokens: List[str] = BAD_TOKS) -> bool:
    """Determine whether a token is an atom.

    A token is an atom when it starts with a letter or ``[`` (bracket atom) and
    is not one of ``special_tokens``. The empty string is never an atom.

    Args:
        token: Token fed into the transformer model.
        special_tokens: Tokens to treat as non-atoms (often introduced by the
            tokenizer).

    Returns:
        bool: True if atom, False if not.
    """
    # Guard before token[0] so an empty token returns False instead of raising.
    if not token or token in special_tokens:
        return False
    first = token[0]
    return first.isalpha() or first == "["


def get_mask_for_tokens(tokens: List[str],
                        special_tokens: List[str] = []) -> List[int]:
    """Return a 0/1 mask over ``tokens`` marking which ones are atoms.

    e.g. the tokens of ``c1ccncc1`` give ``[1, 0, 1, 1, 1, 1, 1, 0]``.

    Args:
        tokens: Tokenized SMILES (e.g. of a reaction).
        special_tokens: Tokens never to call an atom, e.g. ``"[CLS]"`` or
            ``"[SEP]"``. Defaults to none.

    Returns:
        List with ``1`` for atom tokens and ``0`` for everything else.
    """
    return [int(is_atom(tok, special_tokens=special_tokens)) for tok in tokens]


def number_tokens(tokens: List[str],
                  special_tokens: List[str] = BAD_TOKS) -> List[int]:
    """Number the atom tokens in order, marking every other token with ``-1``.

    Args:
        tokens: Tokenized SMILES.
        special_tokens: Tokens never to count as atoms.

    Example:
        >>> number_tokens(['[CLS]', 'C', '.', 'C', 'C', 'C', 'C', 'C', 'C','[SEP]'])
                #=> [-1, 0, -1, 1, 2, 3, 4, 5, 6, -1]
    """
    numbered = []
    next_atom = 0
    for tok in tokens:
        if is_atom(tok, special_tokens=special_tokens):
            numbered.append(next_atom)
            next_atom += 1
        else:
            numbered.append(-1)
    return numbered

def identify_attention_token_idx_for_rxn_component(src_tokens):
    """Locate atom-token positions on each side of a tokenized reaction.

    Args:
        src_tokens: Token list containing a ``'>>'`` token that separates
            reactants from products.

    Returns:
        ``(reactants_token_idx, product_token_idx, atom_token_mask)``:
        - 1-D ``long`` tensors with the positions of the atom tokens before and
          after ``'>>'`` (empty, but still ``long``, when a side has no atoms);
        - a bool mask over all tokens, where ``[CLS]``/``[SEP]`` never count as
          atoms.

    Raises:
        ValueError: If ``'>>'`` is not present.
    """
    try:
        split_ind = src_tokens.index(">>")  # separates products from reactants
    except ValueError:
        raise ValueError(
            "rxn smiles is not a complete reaction. Can't find the '>>' to separate the products"
        ) from None

    n_tokens = len(src_tokens)
    reactant_inds = slice(0, split_ind)
    product_inds = slice(split_ind + 1, n_tokens)

    atom_token_mask = torch.tensor(
        get_mask_for_tokens(src_tokens, ["[CLS]", "[SEP]"])
    ).bool()
    # Position of each atom token on a side = token positions filtered by the mask.
    token_positions = torch.arange(n_tokens)
    reactants_token_idx = token_positions[reactant_inds][atom_token_mask[reactant_inds]]
    product_token_idx = token_positions[product_inds][atom_token_mask[product_inds]]

    return reactants_token_idx, product_token_idx, atom_token_mask

def generate_vocab(rxn_smiles, vocab_path):
    """Build a vocabulary from reaction SMILES and write it to ``vocab_path``.

    The general/special tokens are written first, in their fixed order (0-based
    line 0 is ``[PAD]``, line 11 ``[UNK]``, 12 ``[CLS]``, 13 ``[SEP]``,
    14 ``[MASK]``). Every other token produced by ``RegexTokenizer`` follows in
    sorted order, one token per line. The output is therefore identical across
    runs, unlike writing a ``set`` in hash order.

    Args:
        rxn_smiles: Iterable of reaction SMILES strings.
        vocab_path: Destination file (UTF-8).
    """
    general_vocab = [
        '[PAD]',
        '[unused1]',
        '[unused2]',
        '[unused3]',
        '[unused4]',
        '[unused5]',
        '[unused6]',
        '[unused7]',
        '[unused8]',
        '[unused9]',
        '[unused10]',
        '[UNK]',
        '[CLS]',
        '[SEP]',
        '[MASK]',
    ]
    basic_tokenizer = RegexTokenizer()
    seen_tokens = set()
    for rxn in tqdm(rxn_smiles):
        seen_tokens.update(basic_tokenizer.tokenize(rxn))
    vocab = general_vocab + sorted(seen_tokens - set(general_vocab))
    print('A total of {} vacabs were obtained.'.format(len(vocab)))
    print('Write vocabs to {}'.format(os.path.abspath(vocab_path)))

    with open(vocab_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(vocab))