"""
This Memformer Dataloader ignores sentence boundary

It is hard to debug inside the dataloader and processor.
Therefore, please make sure everything works before feeding into FlyDataloader
"""
import os
import json
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset
from torch.nn.utils.rnn import pad_sequence

from transformers import RobertaTokenizerFast
from nlp import Dataset

from torchfly.rl.env import Env
from torchfly.flydata import FlyDataLoader
from torchfly.flyconfig import GlobalFlyConfig
from torchfly.rl.vector import AsyncVectorEnv
from torchfly.common import set_random_seed

from typing import Iterator, Tuple, List

# pylint:disable=no-member


class CollateFunc:
    """Turn per-environment rollouts into one padded batch per time step.

    The processing options are read from ``config.processing``. The
    RoBERTa tokenizer is loaded only to obtain the padding token id.
    ``memory_reset_dropout`` is stored but not used by this class.
    """

    def __init__(self, config):
        processing = config.processing
        self.memory_reset_dropout = processing.memory_reset_dropout
        self.time_horizon = processing.time_horizon
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        self.pad_token_id = self.tokenizer.pad_token_id

    def _pad_batch(self, sequences):
        """Right-pad a list of token-id lists into a single 2-D long tensor."""
        as_tensors = [torch.LongTensor(token_ids) for token_ids in sequences]
        return pad_sequence(as_tensors, batch_first=True, padding_value=self.pad_token_id)

    def collate_func(self, observations: List, infos: List, dones: List):
        """Collate a rollout of ``time_horizon`` steps.

        Args:
            observations: one entry per environment; ``observations[b][t]``
                is a ``(source_ids, target_ids)`` pair of token-id lists.
            infos: one entry per environment; ``infos[b][t]`` is the
                memory-reset signal for that step.
            dones: accepted for interface compatibility; not used here.

        Returns:
            ``(rollout, memory_reset_signals)`` where ``rollout[t]`` is a dict
            with padded ``"source_ids"`` and ``"target_ids"`` tensors, and
            ``memory_reset_signals[t]`` is a float tensor with one value per
            environment.
        """
        rollout, memory_reset_signals = [], []

        for step_idx in range(self.time_horizon):
            # Slice out this time step across every environment.
            sources = [env_obs[step_idx][0] for env_obs in observations]
            targets = [env_obs[step_idx][1] for env_obs in observations]
            reset_signal = torch.FloatTensor([env_info[step_idx] for env_info in infos])

            rollout.append(
                {
                    "source_ids": self._pad_batch(sources),
                    "target_ids": self._pad_batch(targets),
                }
            )
            memory_reset_signals.append(reset_signal)

        return rollout, memory_reset_signals


class Wiki103NoSentBoundDataLoader:
    """Factory for the training and validation ``FlyDataLoader`` objects.

    One ``CollateFunc`` is created per split at construction time, from
    ``config.flydata.training`` and ``config.flydata.validation``.
    """

    def __init__(self, config):
        self.config = config
        flydata = config.flydata
        self.collator = CollateFunc(flydata.training)
        self.eval_collator = CollateFunc(flydata.validation)

    @staticmethod
    def _build_dataloader(split_config, collate_func):
        """Build a vectorized environment for one split and wrap it in a loader.

        ``split_config.dataloader.batch_size`` environment factories are
        handed to ``AsyncVectorEnv``; each factory takes a rank and builds a
        ``TimeHorizonWiki103NoSentBound`` for that rank.
        """

        def _make_env(rank):
            return TimeHorizonWiki103NoSentBound(rank, split_config)

        loader_options = split_config.dataloader
        in_series = loader_options.in_series
        env_factories = [_make_env for _ in range(loader_options.batch_size)]
        vec_env = AsyncVectorEnv(env_factories, in_series=in_series)

        return FlyDataLoader(split_config, vec_env, collate_func=collate_func)

    def train_dataloader_fn(self, config):
        """Return the training dataloader built from ``config.flydata.training``."""
        return self._build_dataloader(config.flydata.training, self.collator.collate_func)

    def valid_dataloader_fn(self, config):
        """Return the validation dataloader built from ``config.flydata.validation``."""
        return self._build_dataloader(config.flydata.validation, self.eval_collator.collate_func)


class Wiki103NoSentBound(Env):
    """Environment that streams next-token-prediction segments from one data shard.

    The dataset at ``config.datapath`` is split into ``batch_size`` contiguous
    shards and this environment keeps the shard selected by ``rank``. Each
    document is tokenized with the RoBERTa tokenizer, wrapped in CLS/SEP
    tokens, and cut into segments without regard to sentence boundaries.
    """

    def __init__(self, rank, config):
        super().__init__()
        self.rank = rank
        self.config = config
        self.filename = config.datapath
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

        processing = config.processing
        self.max_seq_len = processing.max_seq_len
        self.time_horizon = processing.time_horizon
        self.batch_size = config.dataloader.batch_size

        assert isinstance(self.time_horizon, int)

        # Cache the special token ids used while segmenting.
        tokenizer = self.tokenizer
        self.sep_token_id = tokenizer.sep_token_id
        self.cls_token_id = tokenizer.cls_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = tokenizer.mask_token_id

        # Keep only the contiguous slice of documents that belongs to this rank.
        dataset = Dataset.from_file(self.filename)
        shard_size = max(1, len(dataset) // self.batch_size)
        shard_start = self.rank * shard_size
        self.data = dataset[shard_start:shard_start + shard_size]["document"]

        self.buffer = []
        self.iterator = iter(self)

    def step(self, actions=None):
        """Advance by one segment.

        Returns:
            ``(observation, info, done)`` where ``observation`` is the
            ``(source, target)`` pair, ``info`` is the last-segment flag and
            ``done`` is True once every document has been consumed.
        """
        source, target, is_last_segment, done = next(self.iterator)
        return (source, target), is_last_segment, done

    def reset(self):
        """Drop any partially built segment and restart iteration over the shard."""
        self.buffer = []
        self.iterator = iter(self)

    def __iter__(self):
        """Yield ``(source, target, is_last_segment, done)`` tuples; never terminates.

        The shard is shuffled in place when iteration starts. After the last
        document, empty segments with ``done=True`` are yielded forever.
        """
        random.shuffle(self.data)

        for text in self.data:
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            for source, target, is_last_segment in self.pre_process(token_ids):
                yield source, target, is_last_segment, False

        # Every document has been processed: keep signalling completion.
        while True:
            yield [], [], 0, True

    def pre_process(self, document):
        """Cut one tokenized document into ``(source, target, is_last_segment)`` triples.

        ``target`` is ``source`` shifted left by one token. A segment is
        emitted when the buffer holds ``max_seq_len`` tokens (flag 0) or when
        a SEP token is reached (flag 1). The SEP check comes first, so a
        closing segment can contain ``max_seq_len + 1`` tokens.
        """
        sep = self.sep_token_id
        tokens = [self.cls_token_id] + document + [sep]

        for token_id in tokens:
            segment = None
            if token_id == sep:
                # There is no following token, so SEP is repeated in the target.
                segment = (self.buffer + [sep], self.buffer[1:] + [sep, sep], 1)
            elif len(self.buffer) == self.max_seq_len:
                # The incoming token is the prediction target for the last position.
                segment = (self.buffer, self.buffer[1:] + [token_id], 0)

            if segment is not None:
                yield segment
                self.buffer = []

            self.buffer.append(token_id)

        # Discard the trailing SEP so the next document starts with an empty buffer.
        self.buffer = []


class TimeHorizonWiki103NoSentBound(Wiki103NoSentBound):
    """Variant whose ``step`` returns ``time_horizon`` consecutive segments at once."""

    def __init__(self, rank, config):
        super().__init__(rank, config)
        self.time_horizon = config.processing.time_horizon

    def step(self, actions=None):
        """Collect the next ``time_horizon`` segments from the underlying iterator.

        Args:
            actions: accepted for interface compatibility; not used.

        Returns:
            ``(observations, infos, done)`` where ``observations`` is a list of
            ``(source, target)`` pairs, ``infos`` is the matching list of
            last-segment flags, and ``done`` is the flag of the final segment
            fetched. ``done`` is False when ``time_horizon`` is 0 and nothing
            is fetched.
        """
        observations = []
        infos = []
        # Bound up front so a zero-length horizon cannot raise UnboundLocalError.
        done = False

        for _ in range(self.time_horizon):
            source, target, is_last_segment, done = next(self.iterator)
            observations.append((source, target))
            infos.append(is_last_segment)

        return observations, infos, done
