from collections import defaultdict
from functools import partial
import json
import logging
import os
import pickle
import uuid
import pandas as pd
from tqdm import tqdm
import numpy as np
import yaml
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from simpletransformers.classification.classification_utils import (
    ClassificationDataset, )
from torch.utils.data import Dataset
from rxnfp.models import SmilesClassificationModel, SmilesTokenizer
import torch
import glob
from multiprocessing import Pool
from simpletransformers.classification.classification_utils import preprocess_data_multiprocessing, preprocess_data
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
logger = logging.getLogger(__name__)

def build_classification_dataset(
    data, tokenizer, args, mode, multi_label, output_mode, no_cache
):
    """Tokenize a classification dataset, reusing an on-disk feature cache when allowed.

    Args:
        data: ``(text_a, labels)`` or, for sentence-pair tasks,
            ``(text_a, text_b, labels)``.
        tokenizer: tokenizer forwarded to ``preprocess_data`` /
            ``preprocess_data_multiprocessing``.
        args: model-args object. Reads ``cache_dir``, ``model_type``,
            ``max_seq_length``, ``labels_list``, ``labels_map``, ``regression``,
            the cache flags (``reprocess_input_data``, ``no_cache``,
            ``use_cached_eval_features``), the multiprocessing settings
            (``use_multiprocessing``, ``use_multiprocessing_for_evaluation``,
            ``multiprocessing_chunksize``, ``process_count``), ``silent`` and
            ``use_temperature``.
        mode: part of the cache name; compared against ``"train"`` / ``"dev"``
            to decide which multiprocessing flag and cache-reuse rule apply.
        multi_label: if True, each label is a list mapped element-wise through
            ``args.labels_map``.
        output_mode: ``"classification"`` (long labels) or ``"regression"``
            (float labels); ignored when ``args.use_temperature`` is set.
        no_cache: if True, freshly built features are not written to disk.

    Returns:
        ``(examples, labels)``. With ``args.use_temperature`` the labels are
        ``(condition_labels, temperature)``: every label column except the last
        as a long tensor, and the last column as a float tensor of shape (n, 1).
    """
    # The cache key must depend on the number of examples; ``len(data)`` alone is
    # only the tuple arity (2 or 3), which made different datasets share a cache
    # file. The arity is still included so single and pair tasks never collide.
    n_examples = len(data[0])
    cached_features_file = os.path.join(
        args.cache_dir,
        "cached_{}_{}_{}_{}_{}_{}".format(
            mode,
            args.model_type,
            args.max_seq_length,
            len(args.labels_list),
            n_examples,
            len(data),
        ),
    )

    if os.path.exists(cached_features_file) and (
        (not args.reprocess_input_data and not args.no_cache)
        or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
    ):
        data = torch.load(cached_features_file)
        logger.info(f" Features loaded from cache at {cached_features_file}")
        return data

    logger.info(" Converting to features started. Cache is not used.")

    if len(data) == 3:
        # Sentence pair task
        text_a, text_b, labels = data
    else:
        text_a, labels = data
        text_b = None

    # If labels_map is defined, then labels need to be replaced with ints
    if args.labels_map and not args.regression:
        if multi_label:
            labels = [[args.labels_map[l] for l in label] for label in labels]
        else:
            labels = [args.labels_map[label] for label in labels]

    if (mode == "train" and args.use_multiprocessing) or (
        mode == "dev" and args.use_multiprocessing_for_evaluation
    ):
        if args.multiprocessing_chunksize == -1:
            # About two chunks per worker, but never fewer than 500 examples per chunk.
            chunksize = max(n_examples // (args.process_count * 2), 500)
        else:
            chunksize = args.multiprocessing_chunksize

        chunks = [
            (
                text_a[i: i + chunksize],
                text_b[i: i + chunksize] if text_b is not None else None,
                tokenizer,
                args.max_seq_length,
            )
            for i in range(0, len(text_a), chunksize)
        ]

        with Pool(args.process_count) as p:
            # imap yields one result per chunk, so progress is counted in chunks.
            examples = list(
                tqdm(
                    p.imap(preprocess_data_multiprocessing, chunks),
                    total=len(chunks),
                    disable=args.silent,
                )
            )

        # Stitch the per-chunk feature dicts back into one dict of tensors.
        examples = {
            key: torch.cat([example[key] for example in examples])
            for key in examples[0]
        }
    else:
        examples = preprocess_data(
            text_a, text_b, labels, tokenizer, args.max_seq_length
        )

    if args.use_temperature:
        # The last label column is split off as the (float) temperature target;
        # the remaining columns are the categorical condition labels.
        label_tensor = torch.tensor(labels)
        condition_labels = label_tensor[:, :-1].long()
        temperature = label_tensor[:, -1:].float()
        data = (examples, (condition_labels, temperature))
    else:
        if output_mode == "classification":
            labels = torch.tensor(labels, dtype=torch.long)
        elif output_mode == "regression":
            labels = torch.tensor(labels, dtype=torch.float)
        data = (examples, labels)

    if not args.no_cache and not no_cache:
        logger.info(" Saving features into cached file %s", cached_features_file)
        torch.save(data, cached_features_file)

    return data

def load_dataset(dataset_root, database_fname, use_temperature=False):
    """Load the reaction database CSV and attach its condition labels.

    Args:
        dataset_root: directory holding the CSV and its two companion pickles.
        database_fname: CSV file name relative to ``dataset_root``. The text
            before its first ``.`` is the prefix of the companion files
            ``<prefix>_alldata_idx.pkl`` and ``<prefix>_condition_labels.pkl``.
        use_temperature: not used here; presumably accepted so callers can pass
            ``**dataset_args`` straight through.

    Returns:
        ``(database, (all_idx2data, all_data2idx))``. ``database`` is the CSV as a
        DataFrame with an added ``condition_labels`` column. The tuple is the
        pair unpickled from ``<prefix>_alldata_idx.pkl``.

    Raises:
        ValueError: if the number of condition labels differs from the number
            of CSV rows.
    """
    csv_fpath = os.path.abspath(os.path.join(dataset_root, database_fname))
    print('Reading database csv from {}...'.format(csv_fpath))
    database = pd.read_csv(csv_fpath)

    # NOTE: everything before the FIRST '.' is used, so "a.b.csv" -> "a".
    fname_prefix = database_fname.split('.')[0]

    # NOTE(security): pickle.load can execute arbitrary code; only point this at
    # trusted dataset files.
    all_idx_mapping_data_fpath = os.path.join(
        dataset_root, '{}_alldata_idx.pkl'.format(fname_prefix))
    with open(all_idx_mapping_data_fpath, 'rb') as f:
        all_idx2data, all_data2idx = pickle.load(f)

    all_condition_labels_fpath = os.path.join(
        dataset_root, '{}_condition_labels.pkl'.format(fname_prefix))
    with open(all_condition_labels_fpath, 'rb') as f:
        all_condition_labels = pickle.load(f)

    # Explicit check (not ``assert``) so it still runs under ``python -O``.
    if len(database) != len(all_condition_labels):
        raise ValueError(
            '{} has {} rows but {} has {} condition labels'.format(
                csv_fpath, len(database),
                all_condition_labels_fpath, len(all_condition_labels)))

    database['condition_labels'] = all_condition_labels
    condition_label_mapping = (all_idx2data, all_data2idx)
    return database, condition_label_mapping

def _load_database_and_model_args(config):
    """Load the database named by ``config['dataset_args']`` and fill in decoder args.

    Mutates ``config['model_args']`` in place. It copies ``use_temperature`` from
    the dataset args when that key is present. It also adds ``tgt_vocab_size``
    and ``condition_label_mapping`` to ``model_args['decoder_args']``.

    Returns:
        ``(database_df, condition_label_mapping, model_args)``.
    """
    model_args = config['model_args']
    dataset_args = config['dataset_args']
    # Looked up before the (potentially slow) database read so a config without
    # decoder_args fails fast.
    decoder_args = model_args['decoder_args']

    try:
        model_args['use_temperature'] = dataset_args['use_temperature']
    except KeyError:
        print('No temperature information is specified!')
    else:
        print('Using Temperature:', model_args['use_temperature'])

    database_df, condition_label_mapping = load_dataset(**dataset_args)

    decoder_args.update({
        # The length of the first mapping element is used as the target vocab size.
        'tgt_vocab_size': len(condition_label_mapping[0]),
        'condition_label_mapping': condition_label_mapping,
    })
    return database_df, condition_label_mapping, model_args


def _prepare_split(database_df, split, use_graph):
    """Select the rows whose ``dataset`` column equals ``split`` and rename columns.

    Output columns are ``text``, ``labels`` and ``corpus_text``, plus
    ``graph_info`` when ``use_graph`` is set. The index is reset to 0..n-1.
    If ``sim_corpus`` is absent, ``canonical_rxn`` is used as the corpus text.
    """
    # .copy() so the column assignments below act on an independent frame rather
    # than a slice of database_df (avoids pandas' SettingWithCopyWarning).
    split_df = database_df.loc[database_df['dataset'] == split].copy()

    if 'sim_corpus' not in split_df.columns:
        print('taking reaction SMILES as reaction text...')
        split_df['sim_corpus'] = split_df['canonical_rxn']

    source_columns = ['canonical_rxn', 'condition_labels', 'sim_corpus']
    target_columns = ['text', 'labels', 'corpus_text']
    if use_graph:
        # NOTE(security): eval() executes arbitrary code read from the CSV; only
        # use trusted data. Not swapped for ast.literal_eval because the
        # serialized graph_info format is not visible here — TODO confirm.
        split_df['graph_info'] = split_df['graph_info'].apply(lambda x: eval(x))
        source_columns.append('graph_info')
        target_columns.append('graph_info')

    split_df = split_df[source_columns]
    split_df.columns = target_columns
    return split_df.reset_index(drop=True)


def data_provider_testing(args, config):
    """Build the test split and model args from ``config``.

    Args:
        args: object whose ``use_graph`` flag adds a parsed ``graph_info`` column.
        config: dict with ``model_args`` (containing ``decoder_args``) and
            ``dataset_args`` (keyword arguments for ``load_dataset``);
            ``config['model_args']`` is updated in place.

    Returns:
        ``(test_df, condition_label_mapping, model_args)``.
    """
    database_df, condition_label_mapping, model_args = _load_database_and_model_args(config)
    test_df = _prepare_split(database_df, 'test', args.use_graph)
    return test_df, condition_label_mapping, model_args


def data_provider_training(args, config):
    """Build the train and validation splits and model args from ``config``.

    Args:
        args: object whose ``use_graph`` flag adds a parsed ``graph_info`` column.
        config: dict with ``model_args`` (containing ``decoder_args``) and
            ``dataset_args`` (keyword arguments for ``load_dataset``);
            ``config['model_args']`` is updated in place.

    Returns:
        ``(train_df, eval_df, condition_label_mapping, model_args)``. The rows
        of ``eval_df`` come from the ``'val'`` split.
    """
    database_df, condition_label_mapping, model_args = _load_database_and_model_args(config)
    train_df = _prepare_split(database_df, 'train', args.use_graph)
    eval_df = _prepare_split(database_df, 'val', args.use_graph)
    return train_df, eval_df, condition_label_mapping, model_args
