import pytorch_lightning as pl
from pytorch_lightning.utilities.data import DataLoader
from datasets import load_from_disk
import torch
import transformers
import copy
from typing import Dict, Optional, Sequence


# Label value written over prompt positions and used as the label padding
# value in alpaca_preprocess.
# NOTE(review): presumably chosen to match the default ignore_index of the
# loss function used downstream -- confirm against the training code.
ALPACA_IGNORE_INDEX = -100
# Default special-token strings for the Alpaca setup. Only the EOS token is
# referenced in the code visible in this file (it is appended to every target
# in alpaca_preprocess); the others are not used here.
ALPACA_DEFAULT_PAD_TOKEN = "[PAD]"
ALPACA_DEFAULT_EOS_TOKEN = "</s>"
ALPACA_DEFAULT_BOS_TOKEN = "<s>"
ALPACA_DEFAULT_UNK_TOKEN = "<unk>"


class ValidationDataModule(pl.LightningDataModule):
    """Data module exposing three validation sets loaded from disk.

    The three datasets (IID, benchmark-observed, benchmark-unobserved) are
    loaded eagerly in ``__init__`` with ``datasets.load_from_disk`` and served
    through one dataloader method each. All loaders share the same batch size
    and the same :meth:`collate` function.
    """

    def __init__(self, batch_size, iid_valset_path, benchmark_observed_path, benchmark_unobserved_path, tokenizer, isAlpaca):
        """Load the three validation sets and store the batching configuration.

        Args:
            batch_size: Batch size passed to every dataloader.
            iid_valset_path: Path handed to ``load_from_disk`` for the IID set.
            benchmark_observed_path: Path handed to ``load_from_disk`` for the
                benchmark-observed set.
            benchmark_unobserved_path: Path handed to ``load_from_disk`` for
                the benchmark-unobserved set.
            tokenizer: Tokenizer used by :meth:`collate`.
            isAlpaca: When true, batches are built with ``alpaca_preprocess``;
                otherwise the tokenizer is called directly on inputs/targets.
        """
        # Run the base initializer so the Lightning-managed state of the
        # data module is set up before any attribute is attached.
        super().__init__()
        self.batch_size = batch_size
        self.iid_valset = load_from_disk(iid_valset_path)
        self.benchmark_observed_valset = load_from_disk(benchmark_observed_path)
        self.benchmark_unobserved_valset = load_from_disk(benchmark_unobserved_path)
        self.tokenizer = tokenizer
        self.isAlpaca = isAlpaca

    def collate(self, batch):
        """Turn a list of examples into a tokenized batch.

        Each example must provide the keys ``"input_text"`` and
        ``"output_text"``.

        Returns:
            In Alpaca mode, the dict produced by ``alpaca_preprocess``.
            Otherwise, the tokenizer output for the inputs with the outputs
            passed as ``text_target``, padded to the longest sequence in the
            batch and truncated to 512 tokens.
        """
        input_text = [example["input_text"] for example in batch]
        output_text = [example["output_text"] for example in batch]
        if self.isAlpaca:
            return alpaca_preprocess(input_text, output_text, self.tokenizer)
        return self.tokenizer(
            text=input_text,
            text_target=output_text,
            padding="longest",
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

    def _build_loader(self, dataset):
        """Wrap *dataset* in a DataLoader using the shared batch size and collate function."""
        return DataLoader(dataset, batch_size=self.batch_size, collate_fn=self.collate)

    def iid_dataloader(self):
        """Return a dataloader over the IID validation set."""
        return self._build_loader(self.iid_valset)

    def benchmark_observed_dataloader(self):
        """Return a dataloader over the benchmark-observed validation set."""
        return self._build_loader(self.benchmark_observed_valset)

    def benchmark_unobserved_dataloader(self):
        """Return a dataloader over the benchmark-unobserved validation set."""
        return self._build_loader(self.benchmark_unobserved_valset)
    
    
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string on its own and collect ids and non-pad lengths.

    Every string is passed to the tokenizer individually, truncated to
    ``tokenizer.model_max_length``. The returned dict holds, per string, the
    first row of ``input_ids`` and the number of ids that differ from
    ``tokenizer.pad_token_id``. ``labels`` refers to the same list object as
    ``input_ids``, and ``labels_lens`` to the same list as ``input_ids_lens``.
    """
    encodings = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        encodings.append(encoded)

    token_rows = [encoded.input_ids[0] for encoded in encodings]
    # Count only the positions that are not padding.
    non_pad_counts = [
        encoded.input_ids.ne(tokenizer.pad_token_id).sum().item() for encoded in encodings
    ]
    return {
        "input_ids": token_rows,
        "labels": token_rows,
        "input_ids_lens": non_pad_counts,
        "labels_lens": non_pad_counts,
    }

def alpaca_preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build a padded batch from prompt/target string pairs.

    Each target gets a space and the default EOS token appended and is then
    concatenated to its source. The concatenations and the bare sources are
    tokenized separately; the label positions covered by the source length
    are overwritten with ``ALPACA_IGNORE_INDEX``. Ids are padded with
    ``tokenizer.pad_token_id`` and labels with ``ALPACA_IGNORE_INDEX``.

    Returns:
        A dict with ``input_ids``, ``labels`` and ``attention_mask``, where
        the mask is true wherever ``input_ids`` differs from the pad id.
    """
    completions = [f"{target} {ALPACA_DEFAULT_EOS_TOKEN}" for target in targets]
    full_texts = [prompt + completion for prompt, completion in zip(sources, completions)]

    full_encoding = _tokenize_fn(full_texts, tokenizer)
    prompt_encoding = _tokenize_fn(sources, tokenizer)

    token_rows = full_encoding["input_ids"]
    label_rows = copy.deepcopy(token_rows)
    # Blank out the prompt part of every label row.
    for row, prompt_len in zip(label_rows, prompt_encoding["input_ids_lens"]):
        row[:prompt_len] = ALPACA_IGNORE_INDEX

    padded_ids = torch.nn.utils.rnn.pad_sequence(
        token_rows, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        label_rows, batch_first=True, padding_value=ALPACA_IGNORE_INDEX
    )
    return {
        "input_ids": padded_ids,
        "labels": padded_labels,
        "attention_mask": padded_ids.ne(tokenizer.pad_token_id),
    }