from dataclasses import dataclass
import itertools
import random
import re

import torch
from torch.utils.data import IterableDataset as TorchIterableDataset

from megatron.training import get_args
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from megatron_datasets.utils import print_rank_0
from gpatch.training.v3.ppo_actor import iter_to_ppo_epoch_step

from gdataset import GDatasetV4


def tokenize_text(tokenizer, prompt_seq_len, prompt):
    """Tokenize ``prompt`` and keep at most its last ``prompt_seq_len`` tokens.

    Args:
        tokenizer: Callable tokenizer exposing ``pad_token``; it is invoked as
            ``tokenizer(prompt, add_special_tokens=False)`` and the result's
            ``input_ids`` attribute is used.
        prompt_seq_len: Maximum number of prompt tokens to keep.
        prompt: Text to tokenize.

    Returns:
        A ``(input_ids, unpadded_lens)`` pair, where ``unpadded_lens`` is the
        number of token ids returned.

    Raises:
        AssertionError: If the tokenizer has no pad token, or if the resulting
            token sequence is empty.
    """
    assert tokenizer.pad_token is not None
    token_ids = tokenizer(prompt, add_special_tokens=False).input_ids

    if len(token_ids) > prompt_seq_len:
        # Over budget: drop tokens from the front so the tail of the prompt
        # is what survives.
        token_ids = token_ids[-prompt_seq_len:]
        kept = prompt_seq_len
    else:
        kept = len(token_ids)

    assert kept > 0 and kept == len(token_ids)
    return token_ids, kept


def extract_gt_answer(text):
    """Extract the numeric ground-truth answer that follows a ``####`` marker.

    The first ``####`` followed (after optional whitespace) by a number is
    used. The number may be negative, may have a fractional part, and may use
    well-formed thousands separators (``1,234,567``); separators are removed
    before conversion, so ``#### 1,234`` yields ``1234`` rather than ``1``.

    Args:
        text: String to search.

    Returns:
        A ``float`` if the matched number has a decimal point, an ``int``
        otherwise, or ``None`` when no ``####``-prefixed number is found.
    """
    match = re.search(
        # Either comma-grouped digits (groups of exactly three, not followed
        # by a further digit) or a plain digit run, then an optional fraction.
        r'####\s*(-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?)',
        text,
    )
    if match is None:
        return None

    number = match.group(1).replace(',', '')
    return float(number) if '.' in number else int(number)


class PpoActorDataset(TorchIterableDataset):
    """Endless stream of tokenized prompts with ground-truth labels.

    Examples are read from a ``GDatasetV4`` built from ``metadata_file``. Each
    example's ``'tokenized_prompt'`` is tokenized and truncated to
    ``seq_len - resp_seq_len`` tokens, and its ``'answer'`` is passed through
    as the label. When the underlying dataset is exhausted it is rebuilt for
    the next epoch, so iteration never terminates on its own.

    This is intentionally a plain class rather than a dataclass: it declares
    no fields, and instances keep identity-based equality and hashing.
    """

    def __init__(
        self,
        tokenizer,
        seq_len,
        max_position_embeddings,
        resp_seq_len,
        rollout_micro_batch_size,
        rollout_global_batch_size,
        domain_probabilities,
        domain_names,
        train=False,
        metadata_file=None,
        dp_rank=0,
        dp_size=1,
        shuffling_buffer_size=1000,
        seed=0,
        eos_token=None,
        prompt_format=None,
    ):
        """Store the configuration and build the underlying dataset.

        Args:
            tokenizer: Wrapper whose ``_tokenizer`` attribute is the tokenizer
                used for prompts (and for the default EOS token).
            seq_len: Total sequence budget; prompts are limited to
                ``seq_len - resp_seq_len`` tokens.
            max_position_embeddings: Stored on the instance; not otherwise
                used by this class.
            resp_seq_len: Number of tokens reserved for the response.
            rollout_micro_batch_size: Stored on the instance; not otherwise
                used by this class.
            rollout_global_batch_size: Passed to ``GDatasetV4`` as ``gbs`` and
                used to compute the resume offset.
            domain_probabilities: Stored on the instance only.
            domain_names: Stored on the instance only.
            train: Whether this is the training split; copied into every
                yielded example.
            metadata_file: Passed to ``GDatasetV4``.
            dp_rank: Passed to ``GDatasetV4``.
            dp_size: Passed to ``GDatasetV4``.
            shuffling_buffer_size: Passed to ``GDatasetV4``; must be ``0``
                when ``train`` is false.
            seed: Passed to ``GDatasetV4``.
            eos_token: EOS token override; defaults to
                ``tokenizer._tokenizer.eos_token``.
            prompt_format: Accepted for interface compatibility; see the note
                in the body.

        Raises:
            AssertionError: If ``train`` is false and
                ``shuffling_buffer_size`` is not ``0``.
        """
        super().__init__()

        self.seq_len = seq_len
        self.max_position_embeddings = max_position_embeddings
        self.resp_seq_len = resp_seq_len
        self.rollout_micro_batch_size = rollout_micro_batch_size
        self.rollout_global_batch_size = rollout_global_batch_size
        self.tokenizer = tokenizer
        self.train = train

        self.domain_probabilities = domain_probabilities
        self.domain_names = domain_names
        self.metadata_file = metadata_file
        self.dp_rank = dp_rank
        self.dp_size = dp_size
        self.shuffling_buffer_size = shuffling_buffer_size
        self.seed = seed
        # Guards against this dataset being iterated more than once.
        self.in_iter = False

        if eos_token is None:
            self.eos_token = self.tokenizer._tokenizer.eos_token
        else:
            self.eos_token = eos_token

        # NOTE(review): the ``prompt_format`` argument is currently ignored
        # and this fixed template is stored instead; the attribute is not
        # read anywhere in this class. Confirm whether that is intended.
        self.prompt_format = "{tokenized_prompt}"

        if not self.train:
            assert self.shuffling_buffer_size == 0

        # Resume from the current training iteration: derive the starting
        # epoch and the step within it, then convert the step into a
        # ``consumed`` count for the underlying dataset (presumably the
        # number of samples to skip -- confirm against GDatasetV4).
        args = get_args()
        self.start_epoch, ppo_step = iter_to_ppo_epoch_step(args.iteration)
        consumed = ppo_step * rollout_global_batch_size
        self.make_underlying(consumed=consumed, epoch=self.start_epoch)

    def make_underlying(self, consumed, epoch):
        """(Re)create ``self.underlying`` and set its epoch.

        Args:
            consumed: Forwarded to ``GDatasetV4`` as ``consumed``.
            epoch: Epoch passed to ``GDatasetV4.set_epoch``.
        """
        self.underlying = GDatasetV4(
            metadata_file=self.metadata_file,
            dp_rank=self.dp_rank,
            dp_size=self.dp_size,
            gbs=self.rollout_global_batch_size,
            shuffling_buffer_size=self.shuffling_buffer_size,
            seed=self.seed,
            consumed=consumed,
        )
        self.underlying.set_epoch(epoch)

    def iter_in_epoch(self, epoch):
        """Yield one processed example per item of the underlying dataset.

        Args:
            epoch: Epoch number; used only for the rank-0 log line.

        Yields:
            Dicts with keys ``'input_ids'`` (list of token ids),
            ``'unpadded_lens'`` (its length), ``'train'`` and ``'gt_label'``.

        Raises:
            ValueError: If an example has no ``'answer'`` field.
        """
        if torch.distributed.get_rank() == 0:
            print(
                f'PpoActorDataset.iter_in_epoch epoch {epoch} train {self.train}')

        # The prompt may use whatever is left of the sequence budget after
        # reserving room for the response.
        prompt_budget = self.seq_len - self.resp_seq_len

        for example in self.underlying:
            input_ids, unpadded_lens = tokenize_text(
                self.tokenizer._tokenizer,
                prompt_budget,
                example['tokenized_prompt'],
            )

            if 'answer' not in example:
                raise ValueError("ppo actor dataset must have answer field")

            yield {
                'input_ids': input_ids,
                'unpadded_lens': unpadded_lens,
                'train': self.train,
                'gt_label': example['answer'],
            }

    def __iter__(self):
        """Iterate forever, epoch after epoch, starting at ``start_epoch``.

        Raises:
            AssertionError: If the dataset has already been iterated (raised
                when the second iterator is first advanced).
        """
        assert not self.in_iter
        self.in_iter = True
        for epoch in itertools.count(start=self.start_epoch):
            yield from self.iter_in_epoch(epoch)
            # Epoch exhausted: start the next one from its beginning.
            self.make_underlying(consumed=0, epoch=epoch + 1)
        assert False, 'never reachable'


def build_train_valid_test_datasets(
    args, tokenizer, dp_rank=0, dp_size=1, prompt_format=None, eos_token=None
):
    """Build the PPO actor train and (optional) eval datasets.

    Args:
        args: Parsed arguments; the ``gdatasetv4_*``, ``px_*``, ``ppo_*``,
            ``seq_length``, ``max_position_embeddings``, ``num_workers`` and
            ``seed`` attributes are read.
        tokenizer: Tokenizer forwarded to each dataset.
        dp_rank: Forwarded to each dataset.
        dp_size: Forwarded to each dataset.
        prompt_format: Forwarded to each dataset.
        eos_token: Forwarded to each dataset.

    Returns:
        ``(train_ds, eval_ds, test_ds)``. ``eval_ds`` is ``None`` when
        ``args.gdatasetv4_eval_metadata_file`` is ``None``; ``test_ds`` is
        always ``None``.

    Raises:
        AssertionError: If more than one worker is requested, any retention
            rate differs from ``1.0``, or there is not exactly one domain
            name and one domain probability.
    """
    train_metadata_file = args.gdatasetv4_train_metadata_file
    eval_metadata_file = args.gdatasetv4_eval_metadata_file
    domain_probabilities = args.px_domain_probabilities
    domain_names = args.px_train_data_domain_names

    assert args.num_workers <= 1
    assert all(rate == 1.0 for rate in args.px_retention_rates_per_domain)
    assert len(domain_names) == 1
    assert len(domain_probabilities) == 1

    # Settings that are the same for the train and eval datasets.
    shared_kwargs = dict(
        tokenizer=tokenizer,
        seq_len=args.seq_length,
        max_position_embeddings=args.max_position_embeddings,
        resp_seq_len=args.ppo_resp_seq_len,
        domain_names=domain_names,
        dp_rank=dp_rank,
        dp_size=dp_size,
        eos_token=eos_token,
        prompt_format=prompt_format,
    )

    train_ds = PpoActorDataset(
        rollout_micro_batch_size=args.ppo_rollout_micro_batch_size,
        rollout_global_batch_size=args.ppo_rollout_global_batch_size,
        domain_probabilities=domain_probabilities,
        train=True,
        metadata_file=train_metadata_file,
        shuffling_buffer_size=args.px_shuffle_buffer_size,
        seed=args.seed,
        **shared_kwargs,
    )

    if eval_metadata_file is None:
        eval_ds = None
    else:
        # Evaluation uses its own batch sizes, no shuffling and a fixed seed.
        eval_ds = PpoActorDataset(
            rollout_micro_batch_size=args.ppo_eval_rollout_micro_batch_size,
            rollout_global_batch_size=args.ppo_eval_rollout_global_batch_size,
            domain_probabilities=None,
            train=False,
            metadata_file=eval_metadata_file,
            shuffling_buffer_size=0,
            seed=0,
            **shared_kwargs,
        )

    return train_ds, eval_ds, None


@dataclass
class DataCollator(object):
    """Pad a list of dataset examples to a common length and tensorize them.

    Each example is a dict with ``'input_ids'`` (a list of token ids),
    ``'unpadded_lens'`` (the length of that list), ``'train'`` and
    ``'gt_label'``. The examples passed in are not modified.
    """

    # Wrapper whose ``_tokenizer.pad_token_id`` supplies the padding id.
    tokenizer: MegatronTokenizer
    # Upper bound on the padded prompt length.
    seq_len: int
    # Not read by ``__call__``.
    resp_seq_len: int
    # Pad on the left when true, on the right otherwise.
    gen_left_pad: bool

    def __call__(self, batch):
        """Collate ``batch`` into a dict of tensors plus the raw labels.

        Args:
            batch: Non-empty list of example dicts (see the class docstring).

        Returns:
            A dict with:
              * ``input_ids``: int64 tensor, every row padded to the longest
                example in the batch.
              * ``unpadded_lens``: int64 tensor of the original lengths.
              * ``num_lpads``: int64 tensor with the number of left-pad
                tokens per row (all zeros when padding on the right).
              * ``lpad_lens``: int64 tensor holding the padded length per row
                when padding on the left, or the unpadded length per row
                when padding on the right.
              * ``train``: bool tensor of the per-example ``'train'`` flags.
              * ``gt_label``: plain list of the per-example labels.

        Raises:
            AssertionError: If an example's ``'unpadded_lens'`` disagrees
                with the length of its ``'input_ids'``, if ``'input_ids'`` is
                not a list, or if the longest example exceeds ``seq_len``.
            ValueError: If ``batch`` is empty.
        """
        pad_token_id = self.tokenizer._tokenizer.pad_token_id

        batch_lens = [item['unpadded_lens'] for item in batch]
        input_ids_lens = [len(item['input_ids']) for item in batch]
        assert batch_lens == input_ids_lens, f'wtf batch_lens {batch_lens} input_ids_lens {input_ids_lens}'

        batch_max_len = max(batch_lens)
        assert batch_max_len <= self.seq_len

        padded_rows = []
        lpad_lens = []
        num_lpads = []
        for item in batch:
            input_ids = item['input_ids']
            assert isinstance(input_ids, list)
            to_pad = batch_max_len - len(input_ids)
            padding = [pad_token_id] * to_pad

            # Build a new list for each row instead of padding the example
            # in place, so the same examples can be collated again.
            if self.gen_left_pad:
                lpad_lens.append(batch_max_len)
                num_lpads.append(to_pad)
                padded_rows.append(padding + input_ids)
            else:
                lpad_lens.append(len(input_ids))
                num_lpads.append(0)
                padded_rows.append(input_ids + padding)

        return dict(
            input_ids=torch.as_tensor(padded_rows, dtype=torch.int64),
            unpadded_lens=torch.as_tensor(batch_lens, dtype=torch.int64),
            num_lpads=torch.as_tensor(num_lpads, dtype=torch.int64),
            lpad_lens=torch.as_tensor(lpad_lens, dtype=torch.int64),
            train=torch.as_tensor(
                [item['train'] for item in batch], dtype=torch.bool
            ),
            gt_label=[item['gt_label'] for item in batch],
        )
