import logging
import os
import numpy as np
import importlib
from typing import List, Optional, Dict, Union
import datasets
import torch.utils.data.dataset
from datasets import Dataset, DatasetInfo, NamedSplit, IterableDataset
from datasets.arrow_dataset import (
    _check_column_names,
    _check_if_features_can_be_aligned,
    _align_features,
    update_metadata_with_features
)
# from datasets.iterable_dataset import (
#     VerticallyConcatenatedMultiSourcesExamplesIterable,
#     HorizontallyConcatenatedMultiSourcesExamplesIterable,
#     iterable_dataset
# )
from datasets.table import concat_tables, InMemoryTable
from datasets.fingerprint import update_fingerprint
# from datasets.features import Features
import pyarrow as pa
import pyarrow.compute as pc
from functools import partial
from tqdm import tqdm
import pickle

import hexa.tasks.constants as CONST

#from memory_profiler import profile
import psutil

def log_memory_usage(prefix=''):
    """Log the resident memory of this process and the system-wide usage percentage.

    :param prefix: label placed in square brackets at the start of the log line
    """
    system_percent = psutil.virtual_memory().percent
    # Resident set size of the current process, converted from bytes to GiB.
    rss_gb = psutil.Process().memory_info().rss / 2 ** 30
    logging.info(f"[{prefix}] Memory usage: {rss_gb:.2f} [GB], {system_percent:.1f} %")


def load_dataset_with_mutators(data_name, config,
                               id='none',
                               mutators=None,
                               tokenize_keys=(CONST.MESSAGE_TEXT, CONST.LABELS),
                               tokenizer=None,
                               use_cache=False,):
    """Build a dataset for the task ``data_name`` through ``mutate_dataset``.

    The module ``hexa.tasks.<data_name>.dataloader`` is imported and its
    ``load_data`` attribute, bound to ``config``, is used as the data source.

    :param data_name: task name; selects the dataloader module to import
    :param config: configuration object; ``datatype`` (default ``'train'``) and
        ``episodic`` (default ``False``) are read from it with ``.get``
    :param id: forwarded to ``mutate_dataset``
    :param mutators: forwarded to ``mutate_dataset``
    :param tokenize_keys: forwarded to ``mutate_dataset``
    :param tokenizer: forwarded to ``mutate_dataset``
    :param use_cache: forwarded to ``mutate_dataset``
    :return: whatever ``mutate_dataset`` returns
    """
    loader_module = importlib.import_module(f'hexa.tasks.{data_name}.dataloader')
    bound_loader = partial(loader_module.load_data, config)
    split_tag = config.get('datatype', 'train')
    is_episodic = config.get('episodic', False)
    return mutate_dataset(
        bound_loader,
        config,
        id=id,
        dataset_name=f'{data_name}_{split_tag}',
        use_cache=use_cache,
        mutators=mutators,
        tokenizer=tokenizer,
        tokenize_keys=tokenize_keys,
        episodic=is_episodic,
    )


def apply_mutators(episode, mutators):
    """Run ``episode`` through each mutator in order and return the resulting episodes.

    Every mutator is called with one episode and returns an iterable of
    episodes; a falsy return value (``None``, empty list) drops that episode.
    The output of one mutator is the input of the next.

    :param episode: the single starting episode
    :param mutators: iterable of callables applied in sequence
    :return: list of episodes left after the last mutator
    """
    pending = [episode]
    for mutate in mutators:
        expanded = []
        for produced in map(mutate, pending):
            if not produced:
                continue
            expanded += produced
        pending = expanded
    return pending


# @profile
def mutate_dataset(
        load_func,
        config,
        id='none',
        dataset_name='',
        mutators=None,
        use_cache=False,
        tokenizer=None,
        tokenize_keys=None,
        episodic=False
):
    """
    Load raw data, group it into episodes, apply mutators and optionally tokenize.

    All data is held in memory and returned wrapped in an ``EpisodicDataset``.
    ``load_func`` is called without arguments and must yield ``(message, aux)``
    pairs: ``message`` is a dict-like record, ``aux`` is a dict-like object that
    may carry ``CONST.EPISODE_BEGIN`` (default ``True``, i.e. the item starts a
    new episode) and ``CONST.KEEP`` (default ``True``, i.e. the item is kept).

    :param load_func: Generator function to fetch data
    :param config: configuration object; ``cachepath``, ``datatype``,
        ``use_inhouse_summarizer``, ``vme`` and ``num_return_episodes`` are read from it
    :param id: id of the Teacher being used; when not ``'none'`` it is used as the cache name
    :param dataset_name: name of the dataset, used for store and load of cache dataset
    :param mutators: (Optional) List of mutators; each is instantiated as ``mutator(config)``
        and the instance is called on an episode (see ``apply_mutators``)
    :param use_cache: (Optional) whether to load the data from / save it to a pickle cache
    :param tokenizer: (Optional) if provided, tokenizes the string data
    :param tokenize_keys: (Optional) message keys to tokenize when a tokenizer is given
    :param episodic: if True the result is a list of episodes, otherwise a flat list of messages
    :return: ``EpisodicDataset`` over the loaded (or cached) data

    NOTE(review): when a tokenizer is given, every message key that is not in
    ``tokenize_keys`` is dropped, which includes the ``'id'`` and
    ``CONST.EPISODE_END`` entries set just before; only the last message of an
    episode gets ``CONST.EPISODE_END`` back. This is kept as-is - confirm it is intended.
    """
    # TODO: this needs to be set to custom directory
    _CACHE_PATH = config.cachepath.format(config.datatype)
    # Normalise up front: the non-cached path below iterates over the mutators
    # even when an explicit id is given, so None must never reach it.
    mutators = [] if mutators is None else mutators
    if id != 'none':
        _name = id
    else:
        _name = '+'.join([dataset_name, *(mutator.__name__ for mutator in mutators)])
    if tokenizer:
        _name += '_tokenized'
    # TODO: add option to save and load as arrow file
    file_path = os.path.join(_CACHE_PATH, _name + '.pkl')
    if config.use_inhouse_summarizer and config.datatype == 'train':
        # Prefer the summarizer-specific cache when it exists, else fall back to the default one.
        summarizer_path = os.path.join(_CACHE_PATH, _name + '_inhouse_summarizer.pkl')
        if os.path.isfile(summarizer_path):
            file_path = summarizer_path

    if use_cache and os.path.isfile(file_path):
        assert dataset_name, "Must provide dataset_name to use cache"
        print(f'Loading data from {file_path}')
        # SECURITY: pickle.load can execute arbitrary code; only point the cache
        # directory at files produced by this function.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        if config.datatype == 'valid' and config.vme > -1:
            data = data[:config.vme]
    else:
        mutators = [mutator(config) for mutator in mutators]
        logging.info(f'Reading data using {load_func} with the name {dataset_name}')
        os.makedirs(_CACHE_PATH, exist_ok=True)
        data = []
        num_excluded = 0
        data_count = 0
        episodes = []
        episode_buffer = []

        log_memory_usage("step1")
        # Iterate through data once to parse dataset into episodes
        for d, aux in tqdm(load_func(), desc='Raw data'):
            # A new episode starts: flush (and mutate) the one collected so far.
            if aux.get(CONST.EPISODE_BEGIN, True) and episode_buffer:
                episodes.extend(apply_mutators(episode_buffer, mutators))
                episode_buffer = []

            data_count += 1
            if aux.get(CONST.KEEP, True):
                episode_buffer.append((d, aux))
            else:
                num_excluded += 1
        # Check for leftovers
        if episode_buffer:
            episodes.extend(apply_mutators(episode_buffer, mutators))

        log_memory_usage("step2")
        # return few episodes for debugging
        if config.get('num_return_episodes', -1) > 0:
            episodes = episodes[:config.num_return_episodes]

        log_memory_usage("step3")

        if tokenizer is not None:
            tokenize_keys = tokenize_keys or []

        # tokenize or copy data
        for epi in episodes:
            new_epi = []
            for msg, _ in epi:
                msg = msg.copy()
                msg['id'] = id
                msg[CONST.EPISODE_END] = False  # just to make sure
                if tokenizer is not None:
                    for k in list(msg.keys()):
                        if k not in tokenize_keys:
                            msg.pop(k)
                            continue
                        v = msg[k]
                        if isinstance(v, str):
                            msg[k] = tokenizer.encode(v)
                        elif isinstance(v, list) and v and isinstance(v[0], str):
                            # Non-empty list of strings: batch-encode. Empty lists are left untouched.
                            msg[k] = tokenizer(v)['input_ids']
                new_epi.append(msg)
            if not new_epi:
                # A mutator may hand back an empty episode; there is no last message to mark.
                continue
            new_epi[-1][CONST.EPISODE_END] = True
            if episodic:
                data.append(new_epi)
            else:
                data += new_epi

        logging.info(f'{data_count} instances read')
        logging.info(f'{num_excluded} instances excluded')
        if use_cache:
            logging.info(f'Saving data to {file_path}')
            with open(file_path, 'wb') as f:
                pickle.dump(data, f)
    # return datasets.Dataset.from_pandas(pd.DataFrame(data=data))
    return EpisodicDataset(data, episodic=episodic)


class EpisodicDataset(torch.utils.data.dataset.Dataset):
    """In-memory dataset over either a flat list of messages or a list of episodes.

    The input is treated as flat when it is empty or its first element is a
    dict. ``episode_lengths`` holds one entry per item: ``1`` for flat data,
    the episode length otherwise.
    """

    def __init__(self, data: Union[List[List[Dict]], List[Dict]], episodic: bool=True):
        self.data = data
        self._episodic = episodic
        self._flat = len(data) == 0 or isinstance(data[0], dict)
        if self._flat:
            self.episode_lengths = [1] * len(data)
        else:
            self.episode_lengths = list(map(len, data))

    def __len__(self):
        """Number of top-level items (messages when flat, episodes otherwise)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return item ``idx``; flat data is wrapped in a one-element list in episodic mode."""
        item = self.data[idx]
        if self._episodic and self._flat:
            return [item]
        return item


class RandomCombinedDataset(Dataset):
    """Concatenation of several datasets that is sampled per task according to weights.

    The rows of the underlying table are the concatenated rows of the source
    datasets (``lengths`` gives the row count of each, in order). Indexing does
    NOT return row ``idx``: the index is only used to pick a source dataset in
    proportion to ``weights``, and a uniformly random row of that dataset is
    returned.

    NOTE(review): ``__getitem__`` divides ``idx`` by a number, so only numeric
    indices work; the str / slice / list keys accepted by the base class will
    raise ``TypeError`` here - confirm no caller relies on them.
    """

    def __init__(self, data, lengths, weights=None, **kwargs):
        """
        :param data: table forwarded to ``Dataset.__init__``
        :param lengths: number of rows contributed by each source dataset, in order
        :param weights: (Optional) list with one sampling weight per source dataset;
            defaults to equal weights
        :param kwargs: forwarded to ``Dataset.__init__``
        """
        super().__init__(data, **kwargs)
        # Make sure length of dataset is as expected
        assert self.__len__() == sum(lengths)
        if not weights:
            weights = [1] * len(lengths)
        else:
            assert isinstance(weights, list), "Must provide list for weights "
            assert len(weights) == len(lengths), "Number of weights must be equal to number of datasets combined"
        self.weights = weights
        # Running totals: cum_weights[i] / cum_lengths[i] cover source datasets 0..i.
        self.cum_weights = []
        self.cum_lengths = []
        weight_total = 0
        length_total = 0
        for weight, length in zip(weights, lengths):
            weight_total += weight
            length_total += length
            self.cum_weights.append(weight_total)
            self.cum_lengths.append(length_total)

    def __getitem__(self, idx):
        """Return a random row from the source dataset that ``idx`` maps to.

        :raises IndexError: if ``idx`` maps beyond the last source dataset
        """
        # Rescale the index from [0, total rows) to [0, total weight).
        random_choice = idx / self.cum_lengths[-1] * self.cum_weights[-1]
        task_idx = -1
        for i, cum_weight in enumerate(self.cum_weights):
            if random_choice < cum_weight:
                task_idx = i
                break
        if task_idx == -1:
            raise IndexError
        index_start = 0 if task_idx == 0 else self.cum_lengths[task_idx - 1]
        # Draw one row index in [index_start, end) directly instead of
        # materialising the whole range on every lookup.
        random_idx = np.random.randint(index_start, self.cum_lengths[task_idx])
        return self._getitem(int(random_idx))


def concatenate_weighted_datasets(
    dsets: List[Dataset],
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    weights: Optional[List[int]] = None,
    axis: int = 0,
):
    """
    Combine a list of :class:`Dataset` objects into one weighted dataset.

    All elements must be of the same kind (all map-style or all iterable).
    Only map-style datasets are supported; they are handed to
    ``_concatenate_weighted_map_style_datasets``.
    Args:
        dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
        split (:class:`NamedSplit`, optional): Name of the dataset split.
        weights (:class:`List[int]`): List of weights for each dataset.
        axis (``{0, 1}``, default ``0``, meaning over rows):
            Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns
            (horizontally).
    Raises:
        ValueError: if ``dsets`` is empty, its first element is neither a Dataset nor an
            IterableDataset, or the elements are of mixed kinds.
        NotImplementedError: if the datasets are iterable.
    Example:
    ```py
    >>> ds3 = concatenate_weighted_datasets([ds1, ds2])
    ```
    """
    if not dsets:
        raise ValueError("Unable to concatenate an empty list of datasets.")

    head = dsets[0]
    map_style = isinstance(head, Dataset)
    iterable = isinstance(head, IterableDataset)
    # Exactly one of the two kinds must match.
    if iterable == map_style:
        raise ValueError(
            f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(dsets[0])}"
        )

    required_type = Dataset if map_style else IterableDataset
    for dataset in dsets[1:]:
        if isinstance(dataset, required_type):
            continue
        raise ValueError(
            f"Unable to concatenate a {type(dsets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
        )

    if not map_style:
        raise NotImplementedError("Multi-task concat for IterableDataset is not supported yet")
    return _concatenate_weighted_map_style_datasets(dsets, info=info, split=split, axis=axis, weights=weights)


def _concatenate_weighted_map_style_datasets(
    dsets: List[Dataset],
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    weights: Optional[List[int]] = None,
    axis: int = 0,
):
    """
    Converts a list of :class:`Dataset` with the same schema into a single :class:`RandomCombinedDataset`.
    When you concatenate on axis 0, missing data are filled with None values.

    Datasets without rows are ignored. When ``weights`` is a list with one entry per
    input dataset, the weights of the ignored datasets are dropped as well so that the
    remaining weights stay aligned with the remaining datasets. If every dataset is
    empty, the first one is returned unchanged.

    NOTE(review): for ``axis=1`` with more than one dataset, the per-dataset lengths
    handed to ``RandomCombinedDataset`` sum to more than the row count of the result,
    which that class asserts against - weighted concatenation looks usable only with
    ``axis=0``; confirm.
    Args:
        dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
        split (:class:`NamedSplit`, optional): Name of the dataset split.
        weights (:class:`List[int]`): List of weights for each dataset.
        axis (``{0, 1}``, default ``0``, meaning over rows):
            Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns
            (horizontally).
    Raises:
        ValueError: for ``axis != 0`` when the datasets differ in their number of rows.
    Example:
    ```py
    >>> ds3 = _concatenate_weighted_map_style_datasets([ds1, ds2], weights=[1, 3])
    ```
    """
    # Ignore datasets with no rows
    kept = [i for i, dset in enumerate(dsets) if dset.num_rows > 0]
    if not kept:
        # Return first dataset if all datasets are empty
        return dsets[0]
    # Keep the weights aligned with the datasets that survive the filter. Weights that
    # do not match the input one-to-one are passed through untouched so the downstream
    # validation behaves as before.
    if isinstance(weights, list) and len(weights) == len(dsets):
        weights = [weights[i] for i in kept]
    dsets = [dsets[i] for i in kept]

    # Perform checks (and a potentional cast if axis=0)
    if axis == 0:
        _check_if_features_can_be_aligned([dset.features for dset in dsets])
    else:
        if not all(dset.num_rows == dsets[0].num_rows for dset in dsets):
            raise ValueError("Number of rows must match for all datasets")
        _check_column_names([col_name for dset in dsets for col_name in dset._data.column_names])

    # Find common format or reset format
    fmt = dsets[0].format
    if any(dset.format != fmt for dset in dsets):
        fmt = {}
        # logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.")

    def apply_offset_to_indices_table(table, offset):
        """Shift every entry of the ``indices`` column by ``offset``."""
        if offset == 0:
            return table
        else:
            array = table["indices"]
            new_array = pc.add(array, pa.scalar(offset, type=pa.uint64()))
            return InMemoryTable.from_arrays([new_array], names=["indices"])

    # Concatenate indices if they exist
    if any(dset._indices is not None for dset in dsets):
        if axis == 0:
            # Datasets with no indices tables are replaced with a dataset with an indices table in memory.
            # Applying an offset to an indices table also brings the table in memory.
            indices_tables = []
            for i in range(len(dsets)):
                if dsets[i]._indices is None:
                    dsets[i] = dsets[i]._select_with_indices_mapping(range(len(dsets[i])))
                indices_tables.append(dsets[i]._indices)

            # An offset needs to be applied to the indices before concatenating
            offset = 0
            for i in range(len(dsets)):
                indices_tables[i] = apply_offset_to_indices_table(indices_tables[i], offset)
                offset += len(dsets[i]._data)

            # Concatenate indices
            indices_tables = [t for t in indices_tables if len(t) > 0]
            if indices_tables:
                indices_table = concat_tables(indices_tables)
            else:
                indices_table = InMemoryTable.from_batches([], schema=pa.schema({"indices": pa.int64()}))
        else:
            if len(dsets) == 1:
                indices_table = dsets[0]._indices
            else:
                for i in range(len(dsets)):
                    dsets[i] = dsets[i].flatten_indices()
                indices_table = None
    else:
        indices_table = None

    # Row count per source dataset; used by RandomCombinedDataset to locate each task's rows.
    dset_lengths = [len(dset) for dset in dsets]

    table = concat_tables([dset._data for dset in dsets], axis=axis)
    if axis == 0:
        features_list = _align_features([dset.features for dset in dsets])
    else:
        features_list = [dset.features for dset in dsets]
    table = update_metadata_with_features(table, {k: v for features in features_list for k, v in features.items()})

    # Concatenate infos
    if info is None:
        info = datasets.DatasetInfo.from_merge([dset.info for dset in dsets])
    fingerprint = update_fingerprint(
        "".join(dset._fingerprint for dset in dsets), _concatenate_weighted_map_style_datasets, {"info": info, "split": split}
    )

    # Make final concatenated dataset
    concatenated_dataset = RandomCombinedDataset(
        table,
        lengths=dset_lengths,
        weights=weights,
        info=info,
        split=split,
        indices_table=indices_table,
        fingerprint=fingerprint,
    )
    concatenated_dataset.set_format(**fmt)
    return concatenated_dataset

