"""
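CoH dataset wrapper around the human-feedback data (coh.data.hf_data) and, optionally, the pretraining data (coh.data.pt_data), plus a padding data collator.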
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch
from ml_collections import ConfigDict
from torch.utils.data import IterableDataset
from transformers import DefaultDataCollator

from coh.data.hf_data import HumanFeedbackDataset
from coh.data.pt_data import PretrainDataset


@dataclass
class CoHDataArgs:
    """Data-related arguments for CoH training.

    The ``metadata={"help": ...}`` entries follow the argument-parser dataclass
    convention (presumably consumed by ``transformers.HfArgumentParser`` --
    verify against the caller). Field names and their order are part of the
    public interface (keyword/positional constructor arguments), so they are
    kept unchanged.
    """

    # NOTE: the help text hard-codes 32, which only matches the default value.
    seq_length: int = field(
        default=32,
        metadata={"help": "only use the first 32 tokens of documents (including title)"})
    hf_weights: str = field(
        default="",
        metadata={
            "help": "comma-separated weights for sampling from each dataset. Length should be 3."
        })
    cache_dir: str = field(default='./')

    # Comma-separated dataset names; presumably forwarded to CoHDataset's
    # config under the same key (``data_contain``).
    data_contain: str = field(default="rlhf,summary,webgpt")
    # NOTE(review): spelled ``pretain_task`` here, while CoHDataset's config
    # reads ``pretrain_task``. Unless the caller maps one to the other, this
    # flag does not reach the dataset. Not renamed to keep callers working.
    pretain_task: bool = field(default=False)
    dataset_size: int = field(default=1000)
    factor: int = field(default=1)
    # Defaults to None, hence Optional (the previous ``str`` annotation was wrong).
    data_method: Optional[str] = field(default=None)
    set_seed: int = field(default=42)
    test_data_path: str = field(default="")


class CoHDataset(IterableDataset):
    """Iterable dataset that serves CoH training examples.

    Always wraps a ``HumanFeedbackDataset``. When ``config.pretrain_task`` is
    truthy it also builds a ``PretrainDataset``, but iteration does not yet mix
    pretraining examples in (see ``__iter__``).
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default ``ConfigDict``, merged with ``updates``.

        Top-level keys are copied into two sub-configs: ``config.hf`` (passed
        to ``HumanFeedbackDataset``) and ``config.pt`` (passed to
        ``PretrainDataset``).

        Args:
            updates: optional mapping / ``ConfigDict`` of overrides. Keys not
                present in the defaults are added.

        Returns:
            The merged ``ConfigDict`` with ``hf`` and ``pt`` sub-configs.
        """
        config = ConfigDict()
        config.seq_length = 512
        config.split = 'train'
        config.batch_size = 1  # fixed: control this outside
        config.hf_weights = ""
        config.pretrain_task = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())

        # The keys below are read when building ``config.hf`` but previously had
        # no default, so omitting any of them raised. They are filled only when
        # absent (values mirror CoHDataArgs' defaults), so caller-supplied values
        # pass through unchanged and never hit ConfigDict's type-safety check
        # against a pre-set default.
        fallbacks = {
            'set_seed': 42,
            'data_contain': "rlhf,summary,webgpt",
            'dataset_size': 1000,
            'data_method': None,
            'factor': 1,
            'test_data_path': "",
        }
        for key, value in fallbacks.items():
            if key not in config:
                config[key] = value

        ############## hf ##################
        config.hf = ConfigDict()
        config.hf.set_seed = config.set_seed
        config.hf.seq_length = config.seq_length
        config.hf.split = config.split
        config.hf.batch_size = config.batch_size
        config.hf.data_contain = config.data_contain
        config.hf.dataset_size = config.dataset_size
        # Note the rename: top-level ``data_method`` becomes ``hf.train_method``.
        config.hf.train_method = config.data_method
        config.hf.factor = config.factor
        config.hf.test_data_path = config.test_data_path
        # specific
        config.hf.weight = config.hf_weights
        ############## pt ##################
        config.pt = ConfigDict()
        config.pt.seq_length = config.seq_length
        config.pt.split = config.split
        config.pt.batch_size = config.batch_size
        # specific
        config.pt.path = 'c4'
        config.pt.name = 'en'
        config.pt.field = 'text'
        config.pt.streaming = True

        return config

    @staticmethod
    def load_webgpt_dataset(test_size=0.1):
        """Delegate to ``HumanFeedbackDataset.make_webgpt_test_set``."""
        return HumanFeedbackDataset.make_webgpt_test_set(test_size=test_size)

    def __init__(self, config, tokenizer, webgpt_data):
        super().__init__()
        self.config = self.get_default_config(config)

        self._tokenizer = tokenizer
        # Kept so __setstate__ can rebuild the dataset after unpickling.
        self._webgpt_data = webgpt_data
        self._hf_datset = HumanFeedbackDataset(self.config.hf, tokenizer, webgpt_data)
        # None unless the pretraining task is enabled, so the ``pt_dataset``
        # property returns None instead of raising AttributeError.
        self._pt_datset = None
        if self.config.pretrain_task:
            self._pt_datset = PretrainDataset(self.config.pt, tokenizer)

    def __iter__(self):
        """Yield examples from the human-feedback dataset.

        NOTE(review): when ``config.pretrain_task`` is truthy this yields
        nothing -- mixing in ``PretrainDataset`` is not implemented, so a
        consumer receives an empty stream. Confirm the intended behavior.
        """
        if self.config.pretrain_task:
            return
        yield from self._hf_datset

    def __getstate__(self):
        # Pickle only the constructor arguments; the wrapped datasets are
        # rebuilt in __setstate__. ``webgpt_data`` must be included because
        # __init__ requires it (it was previously missing, so unpickling failed).
        return self.config, self.tokenizer, self._webgpt_data

    def __len__(self):
        # Length is delegated to the human-feedback dataset only.
        return self._hf_datset.dataset_size

    def __setstate__(self, state):
        config, tokenizer, webgpt_data = state
        self.__init__(config, tokenizer, webgpt_data)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def hf_dataset(self):
        return self._hf_datset

    @property
    def pt_dataset(self):
        """The ``PretrainDataset``, or None when ``pretrain_task`` is disabled."""
        return self._pt_datset

    @property
    def vocab_size(self):
        return len(self._tokenizer)

class CoHDataCollator(DefaultDataCollator):
    """Pad variable-length examples into a batch of tensors.

    NOTE(review): this collator repurposes the inherited ``return_tensors``
    attribute as an indexable sequence rather than the usual string: element
    ``[0]`` must expose ``pad_token_id`` (presumably the tokenizer) and element
    ``[-1]`` must expose ``method`` (presumably the training args). Confirm
    with the code that constructs it.
    """

    def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
        """Collate ``features`` into padded ``(batch, max_len)`` tensors.

        Args:
            features: per-example dicts of tensors. Each needed tensor is
                flattened with ``view(-1)``; other keys are ignored, and
                features missing a needed key are skipped for that key.
            return_tensors: accepted for interface compatibility; unused.

        Returns:
            ``{"input_ids", "masks"}`` when ``method == "coh"``, otherwise
            ``{"input_ids", "good_loss_masks", "bad_loss_masks"}``.
            ``input_ids`` is padded with ``pad_token_id`` and every mask with 0.
            NOTE(review): if ``pad_token_id`` is None, ``pad_sequence`` will
            presumably fail -- verify the tokenizer defines one.
        """
        pad_token_id = self.return_tensors[0].pad_token_id
        if self.return_tensors[-1].method != "coh":
            keys = ["input_ids", "good_loss_masks", "bad_loss_masks"]
        else:
            keys = ["input_ids", "masks"]

        collated = {}
        for key in keys:
            sequences = [x[key].view(-1) for x in features if key in x]
            collated[key] = torch.nn.utils.rnn.pad_sequence(
                sequences,
                batch_first=True,
                padding_value=pad_token_id if key == "input_ids" else 0,
            )
        return collated



