import json
import os
from dataclasses import dataclass
from typing import Any, Optional, Union

import numpy as np
import torch
from datasets import Dataset, DatasetDict
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

# Label value for positions that must not be supervised (the default
# ``ignore_index`` of torch's CrossEntropyLoss).
IGNORE_INDEX = -100


def _truncate(ids, cutoff_len):
    """Return the first ``cutoff_len`` ids, or all of them when ``cutoff_len`` is None."""
    return ids if cutoff_len is None else ids[:cutoff_len]


def preprocess_train_dataset(examples, tokenizer, cutoff_len):
    """Tokenize a batch of instruction/response examples for supervised training.

    Args:
        examples: batched columns (dict of equal-length lists) with keys
            ``instruction``, ``output``, ``lora_B_idx``, ``source_text`` and
            ``target_text``.
        tokenizer: tokenizer providing ``apply_chat_template``, ``encode`` and
            ``eos_token_id``.
        cutoff_len: maximum length of every produced id sequence; ``None``
            disables truncation.

    Returns:
        Dict of lists with keys ``input_ids``, ``labels``, ``lora_B_idx``,
        ``source_input_ids`` and ``target_input_ids``. ``input_ids`` is the
        chat-templated prompt followed by the response ids and EOS. ``labels``
        has the same length, with the prompt part set to ``IGNORE_INDEX``.

    Raises:
        ValueError: if the tokenizer has no ``eos_token_id``.
    """
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        # Appending None would only fail much later, inside the collator.
        raise ValueError("tokenizer has no eos_token_id; cannot terminate responses")

    model_inputs = {
        "input_ids": [],
        "labels": [],
        "lora_B_idx": [],
        "source_input_ids": [],
        "target_input_ids": [],
    }

    rows = zip(
        examples["instruction"],
        examples["output"],
        examples["lora_B_idx"],
        examples["source_text"],
        examples["target_text"],
    )
    for instruction, response, lora_B_idx, source_text, target_text in rows:
        # Prompt for an instruct model: one user turn plus the generation prompt.
        chat_input = [{"role": "user", "content": instruction}]
        # Copy so the list returned by the tokenizer is never mutated.
        prompt_ids = list(tokenizer.apply_chat_template(chat_input, add_generation_prompt=True))
        response_ids = tokenizer.encode(response, add_special_tokens=False)

        input_ids = prompt_ids + response_ids + [eos_id]
        # Prompt tokens are masked out; response tokens and EOS are supervised.
        labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]

        model_inputs["input_ids"].append(_truncate(input_ids, cutoff_len))
        model_inputs["labels"].append(_truncate(labels, cutoff_len))
        model_inputs["lora_B_idx"].append(lora_B_idx)

        # Raw source/target texts are encoded with the tokenizer's default
        # special-token handling (unlike the response above).
        model_inputs["source_input_ids"].append(_truncate(tokenizer.encode(source_text), cutoff_len))
        model_inputs["target_input_ids"].append(_truncate(tokenizer.encode(target_text), cutoff_len))

    return model_inputs

  

@dataclass
class CustomDataCollatorForSeq2SeqForTrain:
    """Collate preprocessed training features into padded batches.

    Each feature must carry ``input_ids``, ``source_input_ids``,
    ``target_input_ids`` and ``lora_B_idx``. It may also carry ``labels`` or
    ``label``.

    The three id streams are padded independently by the tokenizer, and each
    gets its own attention mask. Labels are padded with ``label_pad_token_id``
    on the tokenizer's padding side. ``lora_B_idx`` and ``labels`` are always
    returned as int64 torch tensors, regardless of ``return_tensors``.
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None  # not used by this collator
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def _pad_ids(self, sequences, return_tensors):
        """Pad a list of id sequences with the tokenizer; returns its padded output."""
        # max_length is forwarded so that padding="max_length" pads the inputs
        # to the same length as the labels (see _pad_labels).
        return pad_without_fast_tokenizer_warning(
            self.tokenizer,
            {"input_ids": sequences},
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
            return_attention_mask=True,
        )

    def _pad_labels(self, labels, labels_are_lists):
        """Pad label rows with label_pad_token_id; returns a Python list of rows."""
        no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
        if no_padding:
            if labels_are_lists:
                return list(labels)
            return [np.concatenate([label, []]) for label in labels]

        max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
        max_label_length = self.max_length if max_padding else max(len(label) for label in labels)
        if self.pad_to_multiple_of is not None:
            # Round up to the next multiple of pad_to_multiple_of.
            max_label_length = (
                (max_label_length + self.pad_to_multiple_of - 1)
                // self.pad_to_multiple_of
                * self.pad_to_multiple_of
            )

        pad_right = self.tokenizer.padding_side == "right"
        padded = []
        for label in labels:
            fill = [self.label_pad_token_id] * (max_label_length - len(label))
            if labels_are_lists:
                padded.append(label + fill if pad_right else fill + label)
            else:
                fill = np.array(fill, dtype=np.int64)
                padded.append(np.concatenate([label, fill] if pad_right else [fill, label]))
        return padded

    def __call__(self, features, return_tensors=None):
        """Build a batch dict from a list of feature dicts."""
        if return_tensors is None:
            return_tensors = self.return_tensors

        first = features[0]
        label_name = "label" if "label" in first else "labels"
        labels = [feature[label_name] for feature in features] if label_name in first else None

        # Labels are kept out of the tokenizer's pad call and padded separately.
        input_dict = self._pad_ids([f["input_ids"] for f in features], return_tensors)
        source_dict = self._pad_ids([f["source_input_ids"] for f in features], return_tensors)
        target_dict = self._pad_ids([f["target_input_ids"] for f in features], return_tensors)

        batch = {
            "input_ids": input_dict["input_ids"],
            "attention_mask": input_dict["attention_mask"],
            "source_input_ids": source_dict["input_ids"],
            "source_attention_mask": source_dict["attention_mask"],
            "target_input_ids": target_dict["input_ids"],
            "target_attention_mask": target_dict["attention_mask"],
            "lora_B_idx": torch.tensor([f["lora_B_idx"] for f in features], dtype=torch.int64),
        }

        if labels is not None:
            padded_labels = self._pad_labels(labels, isinstance(first[label_name], list))
            batch["labels"] = torch.tensor(padded_labels, dtype=torch.int64)

        return batch


# Root directory holding the dataset JSON files. It can be overridden with the
# LORAX_DATA_DIR environment variable and defaults to the original location.
_DATA_ROOT = os.environ.get("LORAX_DATA_DIR", "/root/lorax/data")

# Dataset name -> sub-directory of _DATA_ROOT containing train.json / test.json.
_DATASET_DIRS = {
    "elife": "elife",
    "cochrane": "cochrane",
    "genetics": "plos_genetics",
}

# Dataset name -> {"train": <path>, "test": <path>}.
DATASET_INFO = {
    name: {split: os.path.join(_DATA_ROOT, subdir, f"{split}.json") for split in ("train", "test")}
    for name, subdir in _DATASET_DIRS.items()
}

# Dataset name -> integer stored in each example's ``lora_B_idx`` column
# (presumably selects a per-dataset LoRA B matrix; verify against model code).
Dataset2loraB = {
    "elife": 0,
    "cochrane": 1,
    "genetics": 2,
}


# Instruction template; ``{source}`` is replaced with the medical source text.
INST = r"""Simplify the following medical text.
Medical text: {source}"""


def prepare_datasets(dataset_name):
    """Load and concatenate the training splits of one or more datasets.

    Args:
        dataset_name: a key of ``DATASET_INFO``, or an iterable of such keys.

    Returns:
        A ``datasets.Dataset`` with columns ``source_text``, ``target_text``,
        ``lora_B_idx``, ``instruction`` (``INST`` filled with the source) and
        ``output`` (the target text).

    Raises:
        KeyError: if a name is not a key of ``DATASET_INFO``.
    """
    names = [dataset_name] if isinstance(dataset_name, str) else dataset_name

    train_source = []
    train_target = []
    train_lora_B_idx = []
    for ds_name in names:
        if ds_name not in DATASET_INFO:
            raise KeyError(f"unknown dataset {ds_name!r}; expected one of {sorted(DATASET_INFO)}")
        lora_B_idx = Dataset2loraB[ds_name]

        # Close the file promptly. JSON text is UTF-8 (RFC 8259), so don't
        # depend on the locale's default encoding.
        with open(DATASET_INFO[ds_name]["train"], "r", encoding="utf-8") as f:
            train_json = json.load(f)

        for example in train_json:
            train_source.append(example["source"])
            train_target.append(example["target"])
            train_lora_B_idx.append(lora_B_idx)

    return Dataset.from_dict({
        "source_text": train_source,
        "target_text": train_target,
        "lora_B_idx": train_lora_B_idx,
        "instruction": [INST.format(source=source) for source in train_source],
        "output": train_target,
    })









