import json
import torch
from datasets import Dataset
from transformers import (
    DataCollatorWithPadding, 
    PreTrainedTokenizerFast,
)
from typing import Any, Dict, List, Union


def load_data(file_path: str) -> Dataset:
    """Load instruction-tuning records from a JSON file and format them into a Dataset.

    The file must contain a JSON array of objects, each with ``instruction``,
    ``input`` and ``response`` fields. Each field may be either a string or a
    list of string fragments; fragments are joined with single spaces.

    Each record becomes a row with:
        - ``query``: ``"<instruction> <input>"``, the prompt.
        - ``full_text``: ``"<query> <response> <eos>"``, the prompt plus the
          target. Because it is built from ``query``, ``query`` is always a
          textual prefix of ``full_text``.
        - ``response``: the response text alone.

    Args:
        file_path (str): Path to the JSON file containing data.

    Returns:
        Dataset: A Hugging Face Dataset object with formatted data.

    Raises:
        KeyError: If a record lacks ``instruction``, ``input`` or ``response``.
    """

    def _as_text(value: Union[str, List[str]]) -> str:
        # ' '.join on a plain string would put a space between every character,
        # so only join genuine lists of fragments.
        return value if isinstance(value, str) else ' '.join(value)

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    formatted_data = []
    for item in data:
        instruction = _as_text(item['instruction'])
        input_text = _as_text(item['input'])
        response = _as_text(item['response'])
        query = f"{instruction} {input_text}"
        formatted_data.append({
            # NOTE(review): "<eos>" is a literal marker string; confirm it matches
            # the tokenizer's EOS token, otherwise it is tokenized as plain text.
            "full_text": f"{query} {response} <eos>",
            "query": query,
            "response": response,
        })
    return Dataset.from_list(formatted_data)


class CustomDataCollator(DataCollatorWithPadding):
    """Custom data collator to preprocess and batch data for training.

    Each feature must provide a ``"full_text"`` string (prompt + response) and a
    ``"query"`` string (the prompt alone). The collator tokenizes both, sets
    the labels of the prompt tokens to ``-100`` so only response tokens
    contribute to the loss, and pads inputs and labels into aligned tensors.

    Args:
        tokenizer (PreTrainedTokenizerFast): Pre-trained tokenizer.
        padding (Union[bool, str], optional): Padding strategy passed to
            ``tokenizer.pad``. Defaults to True (pad to the longest sequence).
        max_length (int, optional): Maximum sequence length used for truncation
            (and as the pad target when ``padding="max_length"``). Defaults to None.
        pad_to_multiple_of (int, optional): Pad sequence lengths to a multiple
            of this value. Defaults to None.
    """

    # Label value ignored by the loss (PyTorch CrossEntropyLoss default
    # ignore_index); used both to mask prompt tokens and to pad labels.
    label_pad_token_id = -100

    tokenizer: PreTrainedTokenizerFast
    padding: Union[bool, str] = True
    max_length: Union[int, None] = None
    pad_to_multiple_of: Union[int, None] = None

    def _encode(self, text: str):
        """Tokenize one string without padding, truncating to ``max_length``."""
        return self.tokenizer(
            text,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

    def _mask_prompt(self, input_ids: List[int], query_length: int) -> List[int]:
        """Copy ``input_ids`` as labels with the first ``query_length`` positions masked.

        The mask is clamped to ``len(input_ids)``. Otherwise labels would be longer
        than the inputs when truncation cuts ``full_text`` below the query length.
        """
        n_masked = min(query_length, len(input_ids))
        return [self.label_pad_token_id] * n_masked + list(input_ids[n_masked:])

    def _pad_labels(self, labels: List[List[int]], target_length: int) -> torch.Tensor:
        """Pad each label row with ``-100`` to ``target_length`` on the tokenizer's padding side."""
        pad_left = self.tokenizer.padding_side == "left"
        rows = []
        for row in labels:
            filler = [self.label_pad_token_id] * (target_length - len(row))
            rows.append(filler + row if pad_left else row + filler)
        return torch.tensor(rows, dtype=torch.long)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Prepare batched data with input IDs, attention masks, and labels.

        Args:
            features (List[Dict[str, Any]]): List of feature dictionaries from the dataset.

        Returns:
            Dict[str, torch.Tensor]: Batched tensors for input IDs, attention masks, and labels.
        """
        batch, batch_labels = [], []
        for feature in features:
            full_encoding = self._encode(feature["full_text"])
            # NOTE(review): if the tokenizer appends a trailing special token
            # (e.g. EOS/SEP) to every encoding, query_length counts it as well,
            # so the first response token also gets masked. Confirm for the
            # tokenizer in use.
            query_length = len(self._encode(feature["query"])["input_ids"])

            batch.append({
                "input_ids": full_encoding["input_ids"],
                "attention_mask": full_encoding["attention_mask"],
            })
            batch_labels.append(
                self._mask_prompt(full_encoding["input_ids"], query_length)
            )

        # Pad the batched inputs using the configured strategy (same call shape
        # as the parent DataCollatorWithPadding).
        padded_batch = self.tokenizer.pad(
            batch,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Labels are padded by hand: tokenizer.pad would fill them with
        # pad_token_id, and the loss would then be trained on padding.
        padded_batch["labels"] = self._pad_labels(
            batch_labels, padded_batch["input_ids"].shape[1]
        )

        return padded_batch

