import base64
import dataclasses
import json
import pprint
import time
from functools import partial
from multiprocessing import Pool

import h5py
import mlxu
import numpy as np
from datasets import load_dataset
from ml_collections import ConfigDict
from ml_collections.config_dict import config_dict
from tqdm import tqdm, trange

import FoT.data_pipeline as data_pipeline


class DatasetFactory(object):
    """Dataset builder class.

    Acts purely as a namespace: configs are produced by
    ``get_default_config`` and datasets by ``load_dataset``. Instantiating
    the class raises ``ValueError``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Build the factory config, optionally overridden by ``updates``.

        The config holds the dataset ``type`` plus the default sub-configs
        of the text processor and of both dataset classes.

        Args:
            updates: Optional overrides; anything ``ConfigDict`` accepts.

        Returns:
            The resulting ``ConfigDict``.
        """
        config = ConfigDict()
        config.type = "huggingface"
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        if updates is None:
            return config

        overrides = ConfigDict(updates).copy_and_resolve_references()
        config.update(overrides)
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        """Construct the dataset selected by ``config.type``.

        Args:
            config: Overrides applied on top of the default factory config.
            tokenizer: Tokenizer handed to the text processor and dataset.
            **kwargs: Forwarded unchanged to the dataset constructor.

        Returns:
            A ``HuggingfaceDataset`` or ``JsonDataset`` instance.

        Raises:
            ValueError: If ``config.type`` is neither ``"huggingface"`` nor
                ``"json"``.
        """
        config = cls.get_default_config(config)
        # The text processor is built before dispatching on the type, so its
        # own config validation runs for every dataset type.
        text_processor = TextProcessor(config.text_processor, tokenizer)

        dataset_type = config.type
        if dataset_type == "huggingface":
            return HuggingfaceDataset(
                config.huggingface_dataset, tokenizer, text_processor, **kwargs
            )
        if dataset_type == "json":
            return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
        raise ValueError(f"Unknown dataset type: {config.type}")

    def __init__(self):
        """Always raise: the factory is not meant to be instantiated."""
        raise ValueError(
            "DatasetFactory is a static class and should not be instantiated."
        )


class TextProcessor(object):
    """Example processor that converts a dictionary of texts into tokens.

    The comma-separated field specification (taken from ``config.fields`` or
    from the example key named by ``config.fields_from_example``) is
    processed left to right. Each entry is one of:

    * ``<|bos|>`` / ``<|eos|>``: the tokenizer's BOS / EOS token id;
    * ``{key}``: ``example[key]`` is base64 text decoded into raw token ids
      of dtype ``config.base64_token_dtype``;
    * ``a+b+...``: the named example values joined by
      ``config.subfield_separator`` and passed to ``tokenizer.encode``.

    Wrapping an entry in ``[...]`` gives its tokens a loss mask of 0.0
    instead of 1.0.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default processor config, overridden by ``updates``."""
        config = ConfigDict()
        config.fields_from_example = ""
        config.fields = ""
        config.subfield_separator = " "
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ""
        config.base64_token_dtype = "i4"
        if updates is None:
            return config
        overrides = ConfigDict(updates).copy_and_resolve_references()
        config.update(overrides)
        return config

    def __init__(self, config, tokenizer):
        """Store the merged config and the tokenizer.

        Raises:
            AssertionError: If both ``fields`` and ``fields_from_example``
                are empty.
        """
        self.config = self.get_default_config(config)
        has_field_spec = (
            self.config.fields != "" or self.config.fields_from_example != ""
        )
        assert (
            has_field_spec
        ), "Either fields or fields_from_example must be specified."
        self.tokenizer = tokenizer

    def __call__(self, example, has_aux=False):
        """Tokenize one example.

        Args:
            example: A dictionary of fields, or, when ``has_aux`` is true, a
                sequence whose first item is that dictionary and whose
                remaining items are passed through untouched.
            has_aux: Whether ``example`` carries auxiliary items.

        Returns:
            A tuple ``(tokens, loss_masks, data_source, *aux)`` where
            ``data_source`` is ``example["source_dataset"]`` or ``"other"``.
        """
        aux = tuple()
        if has_aux:
            example, *aux = example

        source_name = example.get("source_dataset", "other")
        tokens_out = []
        masks_out = []

        if self.config.add_bos_token:
            # The automatic leading BOS never contributes to the loss.
            tokens_out.append(self.tokenizer.bos_token_id)
            masks_out.append(0.0)

        if self.config.fields_from_example != "":
            field_spec = example[self.config.fields_from_example]
        else:
            field_spec = self.config.fields

        for position, field in enumerate(field_spec.split(",")):
            # "[field]" means: emit the tokens but exclude them from the loss.
            no_loss = field.startswith("[") and field.endswith("]")
            if no_loss:
                field = field[1:-1]
            mask = 0.0 if no_loss else 1.0

            if field == "<|bos|>":
                piece = [self.tokenizer.bos_token_id]
            elif field == "<|eos|>":
                piece = [self.tokenizer.eos_token_id]
            elif field.startswith("{") and field.endswith("}"):
                # Base64 encoded raw tokens.
                raw_bytes = base64.b64decode(example[field[1:-1]])
                piece = np.frombuffer(
                    raw_bytes, dtype=self.config.base64_token_dtype
                ).tolist()
            else:
                text = self.config.subfield_separator.join(
                    example[name] for name in field.split("+")
                )
                if position == 0:
                    # prepend_text is applied only to a leading text field.
                    text = self.config.prepend_text + text
                piece = self.tokenizer.encode(text)

            tokens_out.extend(piece)
            masks_out.extend([mask] * len(piece))

        if self.config.add_eos_token:
            # Unlike the automatic BOS, the automatic EOS is trained on.
            tokens_out.append(self.tokenizer.eos_token_id)
            masks_out.append(1.0)

        return tokens_out, masks_out, source_name, *aux


class HuggingfaceDataset(object):
    """Huggingface dataset, where the dataset is loaded using the huggingface
    datasets.load_dataset() function.

    Iterating the object yields ``(data_batch, metrics)`` pairs produced by
    the data pipeline selected through ``config.data_pipeline``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default dataset config, overridden by ``updates``."""
        config = ConfigDict()
        config.path = "c4"
        config.name = "en"
        config.split = "train"
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.shuffle = False
        config.data_pipeline = "linear"
        config.min_example_length = 0  # 0 means no filtering

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @staticmethod
    def _shuffle_dataset(dataset, streaming):
        """Shuffle ``dataset`` with the fixed seed used by this class.

        ``buffer_size`` is only passed for streaming datasets: the
        non-streaming ``datasets.Dataset.shuffle`` does not accept that
        argument and would raise ``TypeError``.
        """
        if streaming:
            return dataset.shuffle(buffer_size=10_000, seed=42)
        return dataset.shuffle(seed=42)

    def __init__(self, config, tokenizer, text_processor):
        """Load the dataset described by ``config``.

        Args:
            config: Overrides applied on top of ``get_default_config()``.
            tokenizer: Tokenizer exposed via ``tokenizer`` / ``vocab_size``.
            text_processor: Callable used to turn examples into tokens.
        """
        self.config = self.get_default_config(config)
        # An empty string means "not set" for load_dataset().
        name = self.config.name if self.config.name != "" else None
        split = self.config.split if self.config.split != "" else None
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._dataset = load_dataset(
            self.config.path, name, split=split, streaming=self.config.streaming
        )
        # Read the flag from the merged config: the raw ``config`` argument
        # may be a plain dict or may not define ``shuffle`` at all.
        if self.config.shuffle:
            self._dataset = self._shuffle_dataset(
                self._dataset, self.config.streaming
            )
        self.data_pipeline_constructor = data_pipeline.get_data_pipeline_constructor(
            data_pipeline=self.config.data_pipeline
        )

    def __iter__(self):
        """Yield ``(data_batch, metrics)`` pairs from the data pipeline.

        Raises:
            AssertionError: If ``config.always_start_with_bos`` is not False.
        """
        assert self.config.always_start_with_bos is False

        def continuous_data_source():
            # Restart from the beginning of the dataset every time the
            # previous pass is exhausted.
            while True:
                token_source = data_pipeline.TextToToken(
                    text_source=iter(self._dataset), text_processor=self.text_processor
                )
                yield from iter(token_source)

        # NOTE(review): the generator *function* is passed uncalled, so
        # TokenFilter presumably invokes it itself; confirm in data_pipeline.
        token_source = data_pipeline.TokenFilter(
            token_source=continuous_data_source,
            min_example_length=self.config.min_example_length,
            num_stat_samples=200,
        )

        data_source = iter(
            self.data_pipeline_constructor(
                token_source=token_source,
                batch_size=self.config.batch_size,
                seq_len=self.config.seq_length,
            )
        )
        # A for-loop ends this generator cleanly if the pipeline ever stops;
        # a bare next() here would surface as RuntimeError (PEP 479).
        for data_batch, metrics in data_source:
            yield data_batch, metrics

    def get_state_dict(self):
        """Return the state to checkpoint (only the config)."""
        return dict(config=self.config)

    def load_state_dict(self, state_dict):
        """Update the config from ``state_dict["config"]`` when present."""
        if "config" in state_dict:
            self.config.update(ConfigDict(state_dict["config"]))

    @property
    def seq_length(self):
        """Sequence length taken from the config."""
        return self.config.seq_length

    @property
    def tokenizer(self):
        """Tokenizer passed to the constructor."""
        return self._tokenizer

    @property
    def text_processor(self):
        """Text processor passed to the constructor."""
        return self._text_processor

    @property
    def dataset(self):
        """The loaded (and possibly shuffled) Huggingface dataset."""
        return self._dataset

    @property
    def vocab_size(self):
        """``len()`` of the tokenizer."""
        return len(self._tokenizer)


class JsonDataset(object):
    """JSON dataset, where each line of the data file contains a JSON
    dictionary with text fields.

    The file is read in an endless loop (rewinding at EOF). Iterating the
    object yields ``(data_batch, metrics)`` pairs produced by the data
    pipeline selected through ``config.data_pipeline``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default dataset config, overridden by ``updates``."""
        config = ConfigDict()
        config.path = ""
        config.seq_length = 1024
        config.min_example_length = 0  # 0 means no filtering
        config.truncate_files = False
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200
        config.data_pipeline = "linear"

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor):
        """Set up reading state from ``config``; the file is opened lazily.

        Args:
            config: Overrides applied on top of ``get_default_config()``.
            tokenizer: Tokenizer exposed via ``tokenizer`` / ``vocab_size``.
            text_processor: Callable used to turn examples into tokens.

        Raises:
            AssertionError: If ``config.path`` is empty.
        """
        self.config = self.get_default_config(config)
        assert self.config.path != ""
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start
        self.epoch = -1

        self.data_pipeline_constructor = data_pipeline.get_data_pipeline_constructor(
            data_pipeline=self.config.data_pipeline,
        )

    def parse_json(self, line):
        """Parse one line of the data file.

        Returns:
            The decoded JSON value, or ``None`` for blank lines and for
            lines that are not valid JSON (the latter are also printed).
        """
        # Whitespace-only lines (including "\r\n") are blank, not errors.
        if not line or not line.strip():
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f"Error parsing json line:\n{line}")
            return None
        return data

    def json_iterator(self):
        """Endlessly yield ``(data, file_loc, index)`` for each valid line.

        Reading starts at ``self._file_loc``. At EOF the file is rewound,
        ``self._index`` is reset and ``self.epoch`` is incremented.

        Raises:
            ValueError: If a complete pass over the file (one that started
                at offset 0) produced no valid example; without this check
                the loop would spin forever.
        """
        with mlxu.open_file(self.config.path, "r") as fin:
            fin.seek(self._file_loc)
            # A pass resumed mid-file may legitimately yield nothing, so
            # only passes that began at offset 0 count as evidence.
            pass_is_complete = self._file_loc == 0
            yielded_in_pass = False
            while True:
                line = fin.readline()
                self._file_loc = fin.tell()
                if not line:  # Reached EOF
                    if pass_is_complete and not yielded_in_pass:
                        raise ValueError(
                            f"No valid JSON examples found in {self.config.path}"
                        )
                    self._index = 0
                    fin.seek(0)
                    self.epoch += 1
                    pass_is_complete = True
                    yielded_in_pass = False
                    continue

                data = self.parse_json(line)
                if data is not None:
                    # JSON parsing succeeded
                    yielded_in_pass = True
                    yield data, self._file_loc, self._index
                self._index += 1

    def batched(self, iterator, batch_size):
        """Group ``iterator`` into lists of ``batch_size`` items.

        The final list may be shorter; an empty iterator yields nothing.
        """
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def parallel_example_iterator(self):
        """Yield processed examples, optionally tokenizing in a process pool.

        With ``tokenizer_processes == 1`` each example is processed inline.
        Otherwise examples are processed in batches by a ``Pool``; the next
        batch is always submitted before the current one is consumed, so
        the file position tracked on ``self`` runs ahead of the examples
        yielded here (each example carries its own location and index).
        """
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                yield self.text_processor((example, loc, index), has_aux=True)
            return

        batched_iterator = self.batched(
            self.json_iterator(), self.config.tokenizer_parallel_batch_size
        )
        map_fn = partial(self.text_processor, has_aux=True)
        chunk_size = self.config.tokenizer_parallel_chunk_size
        with Pool(self.config.tokenizer_processes) as pool:
            next_batch = pool.map_async(
                map_fn, next(batched_iterator), chunksize=chunk_size
            )
            while True:
                current_batch = next_batch
                # Submit the following batch before blocking on this one.
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator), chunksize=chunk_size
                )
                yield from current_batch.get()

    def __iter__(self):
        """Yield ``(data_batch, metrics)`` pairs from the data pipeline.

        ``metrics`` is updated with ``TokenFilter.get_metrics()`` before
        each pair is yielded.
        """
        # NOTE(review): epoch is reset only in single-process mode, so the
        # two modes count epochs from different bases; confirm intent.
        if self.config.tokenizer_processes == 1:
            self.epoch = 0
        # NOTE(review): the bound method is passed uncalled, so TokenFilter
        # presumably invokes it itself; confirm in data_pipeline.
        token_source = data_pipeline.TokenFilter(
            token_source=self.parallel_example_iterator,
            min_example_length=self.config.min_example_length,
            num_stat_samples=self.config.throughput_average_window_size,
        )

        token_source_it = iter(token_source)

        data_source = iter(
            self.data_pipeline_constructor(
                token_source=token_source_it,
                batch_size=self.config.batch_size,
                seq_len=self.config.seq_length,
                truncate_files=self.config.truncate_files,
            )
        )
        # A for-loop ends this generator cleanly if the pipeline ever stops;
        # a bare next() here would surface as RuntimeError (PEP 479).
        for data_batch, metrics in data_source:
            metrics.update(token_source.get_metrics())
            yield data_batch, metrics

    def get_state_dict(self):
        """Return the config plus the current reading position counters."""
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        """Restore config and position; missing keys fall back to config."""
        if "config" in state_dict:
            self.config.update(ConfigDict(state_dict["config"]))
        self._index = state_dict.get("index", self.config.example_index_at_start)
        self._file_loc = state_dict.get("file_loc", self.config.start_seek_loc)
        self._total_tokens = state_dict.get(
            "total_tokens", self.config.tokens_count_at_start
        )

    @property
    def seq_length(self):
        """Sequence length taken from the config."""
        return self.config.seq_length

    @property
    def tokenizer(self):
        """Tokenizer passed to the constructor."""
        return self._tokenizer

    @property
    def text_processor(self):
        """Text processor passed to the constructor."""
        return self._text_processor

    @property
    def vocab_size(self):
        """``len()`` of the tokenizer."""
        return len(self.tokenizer)
