import re
import json
from pathlib import Path
from typing import Dict, Sequence, Tuple
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset, DataLoader
from transformers.tokenization_utils import PreTrainedTokenizer

from .build_linear import build_data
from .contrastive import get_negative
from .utils import encode, build_inputs


# Label value the collator in this file writes over prompt and padding positions.
# NOTE(review): -100 matches the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss -- confirm against the loss actually used.
IGNORE_INDEX = -100


class SupervisedDataset(Dataset):
    """Chat-formatted examples for supervised fine-tuning.

    Indexing yields the encoded user/assistant exchange together with the
    number of tokens occupied by the prompt part (the user turn plus the
    generation prompt).
    """

    def __init__(self, args, tokenizer):
        super().__init__()

        self.split_digit = args.split_digit

        # ``build_data`` returns something indexable by split name; only the
        # split selected by ``args.data_split`` is kept.
        all_splits = build_data(seed=args.seed, path=args.data_path)
        self.data = all_splits[args.data_split]

        self.tokenizer = tokenizer
        self.max_length = args.max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i) -> Tuple[torch.Tensor, int]:
        conversation = self.data[i]["conv"]
        user_turn = {"role": "user", "content": conversation[0]}
        assistant_turn = {"role": "assistant", "content": conversation[1]}

        # Full exchange rendered as text: this is what gets encoded and returned.
        full_text = self.tokenizer.apply_chat_template(
            [user_turn, assistant_turn],
            tokenize=False,
        )
        # Prompt-only rendering, used solely to measure the prefix length.
        prompt_text = self.tokenizer.apply_chat_template(
            [dict(user_turn)],
            add_generation_prompt=True,
            tokenize=False,
        )

        full_ids = encode(self.tokenizer, full_text, split_digit=self.split_digit)
        prompt_ids = encode(
            self.tokenizer, prompt_text, split_digit=self.split_digit
        )
        return full_ids, len(prompt_ids)


def is_number_regex(input_str):
    """Return True if ``input_str`` is a canonical non-negative decimal integer.

    Canonical means ASCII digits only, with no sign, no surrounding
    whitespace (a trailing newline included) and no leading zeros. So "0"
    and "42" qualify, while "", "007", "-3" and "1.5" do not.
    """
    # Matching the canonical form directly avoids converting through int(),
    # which raises ValueError for very long digit strings under Python's
    # integer string conversion length limit.
    return re.fullmatch(r"0|[1-9][0-9]*", input_str) is not None


@dataclass
class DataCollatorForSupervisedDataset:
    """Collate ``(input_ids, prefix_len)`` examples into a padded batch.

    Padding positions are identified from each sequence's length rather than
    by comparing token ids with ``pad_token_id``. Real tokens that share the
    pad id (for example when the pad token is set to EOS) therefore stay
    attended to and supervised.
    """

    # Tokenizer supplying ``pad_token_id``; also passed to ``get_negative``.
    tokenizer: PreTrainedTokenizer
    # Every sequence is cut to at most this many tokens.
    max_length: int = 2048
    # When True, one negative sample per example is appended to the batch.
    use_cont_loss: bool = False
    # Must be True whenever negative samples are built.
    split_digits: bool = True

    def build_neg_samples(self, instances: Sequence[Tuple]) -> Sequence[Tuple]:
        """Return ``instances`` followed by one negative sample per instance.

        Each negative is produced by ``get_negative`` from the example's
        token ids and reuses the example's prefix length.

        Raises:
            AssertionError: if ``split_digits`` is disabled.
        """
        assert self.split_digits, "Split digits must be enabled for negative samples"

        negatives = [
            (get_negative(self.tokenizer, input_ids, prefix_len), prefix_len)
            for input_ids, prefix_len in instances
        ]
        return [*instances, *negatives]

    def __call__(self, instances: Sequence[Tuple]) -> Dict[str, torch.Tensor]:
        """Pad, truncate and mask a list of ``(input_ids, prefix_len)`` pairs.

        Returns:
            A dict with ``input_ids`` (right-padded with ``pad_token_id``),
            ``labels`` (a copy of ``input_ids`` with the prompt prefix and the
            padding set to ``IGNORE_INDEX``) and a boolean ``attention_mask``
            that is True on real tokens.

        Raises:
            ValueError: if the tokenizer has no ``pad_token_id``.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("tokenizer.pad_token_id must be set to pad a batch")

        if self.use_cont_loss:
            instances = self.build_neg_samples(instances)

        sequences, prefix_lens = zip(*instances)
        # Truncating before padding gives the same batch width as padding
        # first and slicing afterwards, without padding the discarded tail.
        sequences = [seq[: self.max_length] for seq in sequences]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=pad_id
        )

        labels = input_ids.clone()
        attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for row, (seq, prefix_len) in enumerate(zip(sequences, prefix_lens)):
            seq_len = seq.shape[0]
            attention_mask[row, :seq_len] = True
            # Everything past the real sequence is padding: no loss there.
            labels[row, seq_len:] = IGNORE_INDEX
            # The prompt prefix is context only: no loss there either.
            labels[row, :prefix_len] = IGNORE_INDEX

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )


def load_data(args, tokenizer):
    """Build the fine-tuning dataset and its matching batch collator.

    Returns:
        A ``(dataset, collator)`` pair, both configured from ``args`` and
        sharing ``tokenizer``.
    """
    dataset = SupervisedDataset(args, tokenizer)
    collator_options = dict(
        max_length=args.max_length,
        use_cont_loss=args.use_cont_loss,
        split_digits=args.split_digit,
    )
    return dataset, DataCollatorForSupervisedDataset(tokenizer, **collator_options)
