import os
import pathlib
import multiprocessing
import logging

import pyarrow as pa
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2TokenizerFast

from workbench.data.detokenizer import wikitext_detokenize
from workbench.data.collators import DataCollator

IGNORE_INDEX = -100


class IndexDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over a sequence of arrow file dataset indices.

    Items are handed out exactly as they are stored in the wrapped sequence;
    the sequence itself is referenced, not copied.
    """

    def __init__(self, dataset_indices):
        # Any object supporting len() and integer indexing works here.
        self.dataset_indices = dataset_indices

    def __len__(self):
        """Number of indices held by this dataset."""
        return len(self.dataset_indices)

    def __getitem__(self, index):
        """Return the stored index at position ``index``."""
        return self.dataset_indices[index]


class PlArrowFileModule():
    """
    Datamodule to perform pretraining
    based on pre-tokenized arrow files: one train arrow file and one
    validation arrow file per rank.
    Missing arrow files are created by prepare_data(), which is called
    from __init__.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name,
        num_cpu_worker,
        num_gpu_worker,
        max_sample_len,
        seed,
        batch_size,
        data_dir,
        cache_dir,
        val_ratio,
        val_split_seed,
        evaluation_sets,
    ):
        """
        Store the configuration, create the data/cache directories, make sure
        the tokenizer has a pad token and finally run prepare_data().

        num_cpu_worker falls back to os.cpu_count() when it is None.
        evaluation_sets may be empty/None; otherwise its names are appended
        to val_sets_name after the training dataset name.
        """
        super().__init__()

        self.logger = logging.getLogger(__name__)

        # --- worker configuration ---
        self.num_gpu_worker = num_gpu_worker
        self.num_cpu_worker = os.cpu_count() if num_cpu_worker is None else num_cpu_worker

        # --- dataset / sampling configuration ---
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.max_sample_len = max_sample_len
        self.seed = seed
        self.resume_index = None  # TODO not implemented yet
        self.global_rank = 0
        self.ignore_index = IGNORE_INDEX

        # --- directories (created on demand) ---
        self.data_dir = pathlib.Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir = pathlib.Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # --- validation / evaluation configuration ---
        self.evaluation_sets = evaluation_sets
        if evaluation_sets:
            self.val_sets_name = [dataset_name, *evaluation_sets]
        else:
            self.val_sets_name = [dataset_name]

        self.splits = ['validation', 'train']
        self.val_ratio = val_ratio
        self.val_split_seed = val_split_seed
        # Only openwebtext encodes the validation split configuration in the
        # arrow file names (decimal digits of val_ratio plus the split seed).
        if self.dataset_name == 'openwebtext':
            ratio_digits = str(val_ratio).split('.')[1]
            self.val_cfg_str = f"val-{ratio_digits}-{val_split_seed}_"
        else:
            self.val_cfg_str = ""

        # --- tokenizer and vocabulary sizes ---
        self.tokenizer = tokenizer
        if not self.tokenizer.pad_token:
            self.tokenizer.add_special_tokens({'pad_token': "<pad>"})

        # Source and target vocab sizes are rounded up to a multiple of 128;
        # vocab_size keeps the exact tokenizer length.
        padded_vocab_size = int(np.ceil(len(self.tokenizer) / 128)) * 128
        self.seq_vocab_size = padded_vocab_size
        self.trg_vocab_size = padded_vocab_size
        self.vocab_size = len(self.tokenizer)

        self.collator = DataCollator(
            src_mask_token_id=self.tokenizer.pad_token_id,
            trg_mask_token_id=self.ignore_index,
        )

        # Pre-call so that preprocessing runs before distributed PyTorch is initialized.
        self.prepare_data()

    def prepare_data(self):
        if not self._exist_preprocessed_data():
            self._preprocess_data()

        if self.evaluation_sets:
            if not self._exist_preprocessed_eval_data():
                self._preprocess_eval_data()

    def _preprocess_eval_data(self):
        """
        Tokenize the validation split of every configured evaluation set and
        store the token id lists as ``evaluation_<name>.pth`` in data_dir.

        Raises UserWarning for an unknown evaluation set name.
        """
        cache_dir = self.cache_dir.as_posix()

        # name -> (load_dataset positional args, text column, optional text normaliser)
        recipes = {
            "wikitext": (("wikitext", "wikitext-2-raw-v1"), "text", wikitext_detokenize),
            "lambada": (("lambada",), "text", None),
            "ptb": (("ptb_text_only",), "sentence", None),
        }

        for dataset in self.evaluation_sets:
            if dataset not in recipes:
                raise UserWarning(f"unknown evaluation dataset {dataset}")
            hub_args, column, normalise = recipes[dataset]

            raw_split = load_dataset(*hub_args, split='validation', cache_dir=cache_dir)
            texts = [text for text in raw_split[column] if len(text) > 0]  # skip empty rows
            if normalise is not None:
                texts = [normalise(text) for text in texts]
            tokenized = [self.tokenizer(text)['input_ids'] for text in texts]

            pre_num_samples = len(tokenized)
            # A sample needs at least two tokens to be shiftable into source/target.
            samples = [ids for ids in tokenized if len(ids) > 1]

            torch.save(samples, self.data_dir / f"evaluation_{dataset}.pth")
            self.logger.info(
                f"load {dataset} with {len(samples)} samples ({pre_num_samples - len(samples)}samples are too short/long)")


    def _exist_preprocessed_eval_data(self):
        all_files_exist = True
        for dataset in self.evaluation_sets:
            file_name = f"evaluation_{dataset}.pth"
            file_exist = os.path.exists(self.data_dir / file_name)
            all_files_exist &= file_exist
            if not file_exist:
                self.logger.info(f"Checked preprocessed eval data: {(self.data_dir / file_name).as_posix()} does not exist.")
        if all_files_exist:
            self.logger.info("Checked preprocessed evaluation data: All file exist.")
        return all_files_exist



    def _exist_preprocessed_data(self):
        all_files_exist = True
        for split in self.splits:
            base_file = f"{self.dataset_name}_{split}_{self.val_cfg_str}{self.max_sample_len +1}_{self.num_gpu_worker}"
            for worker_id in range(self.num_gpu_worker):
                file_name = base_file + f"_{worker_id}.arrow"
                file_exist = os.path.exists(self.data_dir / file_name)
                all_files_exist &= file_exist
                if not file_exist:
                    self.logger.info(f"Checked preprocessed data: {(self.data_dir / file_name).as_posix()} does not exist.")
        if all_files_exist:
            self.logger.info("Checked preprocessed data: All file exist.")
        return all_files_exist

    def _preprocess_data(self):
        """
        Tokenize the raw dataset and write fixed-length token samples into one
        arrow file per split and per GPU worker.

        For every split, ``num_gpu_worker`` writer processes each drain one
        queue into ``<base_file>_<worker>.arrow``, while the tokenizer
        processes are spread round-robin over those queues. Every tokenizer
        process handles a contiguous slice of the documents; the last one also
        takes the remainder so that no document is skipped.

        Raises:
            UserWarning: if the dataset name is unknown or fewer than
                ``2 * num_gpu_worker`` CPUs are available.
        """
        max_sample_len_plus = self.max_sample_len + 1  # Because of source target shift in decoder/teacher-forcing training

        # Each GPU worker needs at least one tokenizer process next to its
        # writer process; with fewer CPUs worker_world_size would become 0.
        cpu_count = multiprocessing.cpu_count()
        if cpu_count < self.num_gpu_worker * 2:
            raise UserWarning(f'preprocess requires at least {self.num_gpu_worker * 2} cpus')

        # A multiple of num_gpu_worker by construction, so every queue is fed
        # by the same number of tokenizer processes.
        worker_world_size = (cpu_count // self.num_gpu_worker - 1) * self.num_gpu_worker

        if self.dataset_name == "pile":
            all_samples = load_dataset('the_pile', name="all",
                                       data_dir=self.data_dir.absolute().as_posix(),
                                       cache_dir=self.cache_dir.absolute().as_posix())
        elif self.dataset_name == "openwebtext":
            all_samples = load_dataset('openwebtext',
                                       data_dir=self.data_dir.absolute().as_posix(),
                                       cache_dir=self.cache_dir.absolute().as_posix())
        elif self.dataset_name == "tinystories":
            all_samples = load_dataset("roneneldan/TinyStories")
        elif self.dataset_name == "wikitext":
            all_samples = load_dataset('wikitext', 'wikitext-103-v1')

            # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse
            # (GPT2-small val ppl after 10 epochs ~22 -> ~25)
            # However, it's useful for zero-shot transfer from Openwebtext,
            # as after detokenization it's closer to Openwebtext's format.
            # https://github.com/stanford-crfm/mistral/issues/12
            all_samples = all_samples.map(
                lambda example: {'text': wikitext_detokenize(example['text'])},
                num_proc=max(self.num_cpu_worker, 1),
                desc='Running detokenizer on dataset')
        else:
            raise UserWarning(f"dataset name unknown: {self.dataset_name}")

        if 'validation' not in all_samples:
            all_samples = all_samples["train"].train_test_split(
                test_size=self.val_ratio, seed=self.val_split_seed,
                shuffle=True  # Otherwise test will be at the end of the dataset
            )
            all_samples['validation'] = all_samples['test']

        # Token ids are stored as uint16 when the vocabulary allows it.
        pa_type = pa.list_(pa.uint16() if self.vocab_size < 65535 else pa.uint32())
        # Template batch; the writer processes only use it for its schema.
        batch = pa.RecordBatch.from_arrays(
            [pa.array([list(range(max_sample_len_plus))], type=pa_type)], names=['text'])

        for split in self.splits:

            samples = all_samples[split]

            base_file = f"{self.dataset_name}_{split}_{self.val_cfg_str}{max_sample_len_plus}_{self.num_gpu_worker}"
            file_name = (self.data_dir / base_file).as_posix()

            numb_samples = len(samples)
            avg_length = np.mean([len(s['text']) for s in samples])
            self.logger.info(f'split {split} load {numb_samples} data samples with avg length of {avg_length}')

            index_step = numb_samples // worker_world_size

            self.logger.info(f'split {split} start parallel {worker_world_size} worker')

            return_queues = [multiprocessing.Queue(maxsize=25) for _ in range(self.num_gpu_worker)]
            processes = []

            # One writer process per GPU worker / output file.
            for worker_idx in range(self.num_gpu_worker):
                self.logger.info(f"start queue2file_writer process {worker_idx}")
                process = multiprocessing.Process(
                    target=self._queue2file_writer,
                    args=(file_name, batch, worker_idx, worker_world_size, return_queues))
                process.daemon = True
                process.start()
                processes.append(process)

            # Tokenizer processes, assigned round-robin to the writer queues.
            for worker_idx in range(worker_world_size):
                self.logger.info(f"start preprocess_samples2queue process {worker_idx}")
                start = worker_idx * index_step
                # The last process also takes the documents left over by the
                # integer division instead of silently dropping them.
                stop = numb_samples if worker_idx == worker_world_size - 1 else start + index_step
                indexes = range(start, stop)
                worker = worker_idx % self.num_gpu_worker
                process = multiprocessing.Process(
                    target=self._preprocess_samples2queue,
                    args=(self.tokenizer, samples, indexes, max_sample_len_plus, worker, return_queues, pa_type))
                process.daemon = True
                process.start()
                processes.append(process)

            for process in processes:
                process.join()

            self.logger.info(f'split {split} preprocess done')

    @staticmethod
    def _queue2file_writer(file_name, batch, worker, worker_world_size, return_queues):
        """
        Drain ``return_queues[worker]`` into ``<file_name>_<worker>.arrow``.

        ``batch`` is only used for its schema. Every item taken from the queue
        is written as one record batch until the expected number of 'END'
        markers (``worker_world_size // len(return_queues)``) has been seen.
        """
        inbox = return_queues[worker]
        target_path = f"{file_name}_{worker}.arrow"
        expected_end_markers = worker_world_size // len(return_queues)

        total_samples = 0
        seen_end_markers = 0
        with pa.OSFile(target_path, 'wb') as sink, pa.ipc.new_file(sink, batch.schema) as writer:
            while seen_end_markers < expected_end_markers:
                item = inbox.get()
                if item == 'END':
                    # One producer feeding this queue has finished.
                    seen_end_markers += 1
                    continue
                writer.write_batch(item)
                total_samples += 1

        inbox.close()
        print(f"queue2file_writer {worker} wrote {total_samples} in {file_name}_{worker}.arrow")

    @staticmethod
    def _preprocess_samples2queue(tokenizer, samples, indexes, max_sample_len_plus, worker, return_queues, pa_type):
        """
        Tokenize the documents selected by ``indexes`` and push fixed-length
        token chunks onto ``return_queues[worker]``.

        The token ids of consecutive documents are concatenated, each document
        followed by ``tokenizer.eos_token_id``, and cut into chunks of exactly
        ``max_sample_len_plus`` tokens. Each chunk is sent as a one-row record
        batch with a single 'text' column of type ``pa_type``. Trailing tokens
        that do not fill a complete chunk are discarded. A final 'END' marker
        signals that this producer is done.

        Args:
            tokenizer: callable returning a mapping with 'input_ids'; must
                expose ``eos_token_id``.
            samples: indexable collection of records with a 'text' field.
            indexes: sequence of positions into ``samples`` to process.
            max_sample_len_plus: number of tokens per emitted chunk.
            worker: index of the target queue in ``return_queues``.
            return_queues: list of queues shared with the writer processes.
            pa_type: arrow list type used for the token ids.
        """
        out_queue = return_queues[worker]
        count_writes = 0
        idx_sample = 0
        token_buffer = []

        for sample_index in indexes:
            raw_text = samples[sample_index]['text']
            token_buffer += tokenizer(raw_text)['input_ids'] + [tokenizer.eos_token_id]
            idx_sample += 1

            # Emit every complete chunk right away. This also covers the last
            # document, whose full chunks were previously left in the buffer
            # and lost when the loop ended.
            offset = 0
            while len(token_buffer) - offset >= max_sample_len_plus:
                chunk = token_buffer[offset:offset + max_sample_len_plus]
                offset += max_sample_len_plus
                arr = pa.array([chunk], type=pa_type)
                out_queue.put(pa.RecordBatch.from_arrays([arr], names=['text']))
                count_writes += 1
            if offset:
                token_buffer = token_buffer[offset:]

        out_queue.put('END')
        print(
            f"preprocess_samples2queue {worker} done: processed {idx_sample} data samples, created {count_writes} training samples")

    def setup(self, global_rank=0):
        """
        Seed the per-rank random generator and open the memory-mapped train
        and validation arrow files that belong to ``global_rank``.
        """
        self.global_rank = global_rank

        # NOTE: the world size is not validated against num_gpu_worker here.

        self.rng = np.random.RandomState(self.seed + self.global_rank)

        def shard_path(split):
            # Path of this rank's arrow file for the given split.
            shard_name = (
                f"{self.dataset_name}_{split}_{self.val_cfg_str}"
                f"{self.max_sample_len + 1}_{self.num_gpu_worker}_{self.global_rank}.arrow"
            )
            return (self.data_dir / shard_name).as_posix()

        self.logger.info("Create memory map\n")
        train_source = pa.memory_map(shard_path('train'))
        self.logger.info("MMAP Read ALL")
        self._train_dataset = pa.ipc.open_file(train_source)

        valid_source = pa.memory_map(shard_path('validation'))
        self._valid_dataset = pa.ipc.open_file(valid_source)


    def batch_to_device(self, batch, device):
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(device)
        return batch


    def train_dataloader(self, current_epoch=0, local_rank=0):
        """
        Build the training DataLoader. This will be run every epoch.

        The record-batch indices of this rank's train file are shuffled with
        the per-rank generator created in setup(). When a process group is
        initialized, the index list is truncated to the smallest train set
        size across all ranks so that every rank yields the same number of
        samples.

        Args:
            current_epoch: epoch number; for epochs > 0 the indices are
                shuffled once more with a seed derived from seed, epoch and
                global rank.
            local_rank: CUDA device index used for the all-reduce tensor when
                CUDA is available.

        Returns:
            A DataLoader over index batches whose collate function reads the
            token lists from the arrow file and applies the collator.
        """
        train_set_size = self._train_dataset.num_record_batches
        train_indexes = self.rng.permutation(np.arange(train_set_size))

        # Only talk to other ranks (and only touch a CUDA device) when a
        # process group exists; a single-process run needs neither.
        min_num_samples = train_set_size
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            reduce_device = f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu'
            size_tensor = torch.tensor(train_set_size, device=reduce_device)
            torch.distributed.all_reduce(size_tensor, op=torch.distributed.ReduceOp.MIN)
            min_num_samples = size_tensor.item()
        train_indexes = train_indexes[:min_num_samples]

        self.logger.info(f"### load train set with size {min_num_samples} from {train_set_size} samples on rank {self.global_rank}")

        # shuffle the indices for every epoch other than 0.
        # the loaded indices are already shuffled
        if current_epoch > 0:
            seed = self.seed + current_epoch + self.global_rank
            tmp_rng = np.random.default_rng(seed)
            train_indexes = tmp_rng.permutation(train_indexes)

        if self.resume_index is not None:
            train_indexes = train_indexes[self.resume_index:]
            self.resume_index = None  # reset to avoid next-epoch issues

        train_index_dataset = IndexDataset(train_indexes)

        def train_pl_collate_fn(indices):
            # Each record batch holds exactly one row: the token ids of one sample.
            raw_samples = [self._train_dataset.get_record_batch(i)['text'].to_pylist()[0] for i in indices]
            return self.collator(raw_samples)

        loader = DataLoader(
            train_index_dataset,
            batch_size=self.batch_size,
            collate_fn=train_pl_collate_fn,
            num_workers=self.num_cpu_worker,
            pin_memory=True,
            drop_last=False,
        )
        self.logger.info("Finished loading training data")
        return loader

    def val_dataloader(self, local_rank=0):
        """
        Build the validation DataLoader(s).

        When a process group is initialized, the validation indices are
        truncated to the smallest validation set size across all ranks.

        Args:
            local_rank: CUDA device index used for the all-reduce tensor when
                CUDA is available.

        Returns:
            A single DataLoader when no evaluation sets are configured,
            otherwise a list whose first entry is the validation loader,
            followed by one loader per evaluation set in configuration order.
        """
        valid_set_size = self._valid_dataset.num_record_batches

        # Only talk to other ranks (and only touch a CUDA device) when a
        # process group exists; a single-process run needs neither.
        min_num_samples = valid_set_size
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            reduce_device = f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu'
            size_tensor = torch.tensor(valid_set_size, device=reduce_device)
            torch.distributed.all_reduce(size_tensor, op=torch.distributed.ReduceOp.MIN)
            min_num_samples = size_tensor.item()
        valid_indexes = list(range(valid_set_size))[:min_num_samples]

        valid_index_dataset = IndexDataset(valid_indexes)

        # Logged (not printed) for consistency with train_dataloader.
        self.logger.info(f"### load valid set with size {min_num_samples} from {valid_set_size} samples on rank {self.global_rank}")

        def val_pl_collate_fn(indices):
            # Each record batch holds exactly one row: the token ids of one sample.
            inputs = [self._valid_dataset.get_record_batch(i)['text'].to_pylist()[0] for i in indices]
            return self.collator(inputs)

        loader = DataLoader(
            valid_index_dataset,
            batch_size=self.batch_size,
            collate_fn=val_pl_collate_fn,
            num_workers=self.num_cpu_worker,
            pin_memory=True,
            drop_last=False,
        )
        self.logger.info("Finished loading validation data")

        if not self.evaluation_sets:
            return loader

        loader_list = [loader]
        for dataset in self.evaluation_sets:
            # Token id lists stored by _preprocess_eval_data via torch.save.
            samples = torch.load(self.data_dir / f"evaluation_{dataset}.pth")
            eval_loader = DataLoader(
                samples,
                batch_size=self.batch_size,
                collate_fn=self.collator,
                num_workers=self.num_cpu_worker,
                pin_memory=True,
                drop_last=False,
            )
            loader_list.append(eval_loader)
            self.logger.info(f"Finished loading evaluation data {dataset}")

        return loader_list


