from omegaconf import DictConfig

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning.pytorch as pl

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_from_disk, Dataset

from .collators import DataCollatorForAIM


def load_prepared_dataset(path: str) -> Dataset:
    """
    Loads a prepared dataset from disk.

    The path is handed to `datasets.load_from_disk` unchanged.

    We expect the dataset to contain the following columns:
      - "teacher_input_ids": list of input ids
      - "teacher_attention_mask": list of attention masks
      - "teacher_word_ids": list of word ids (of the same length as "teacher_input_ids")
      - "student_input_ids": list of input ids
      - "student_attention_mask": list of attention masks
      - "student_word_ids": list of word ids (of the same length as "student_input_ids")

    The presence of these columns is not checked here.

    Args:
        path (str): The path to the prepared dataset.

    Returns:
        Dataset: The prepared dataset.

    Raises:
        TypeError: If the object stored at `path` is not a single `Dataset`
            (for example a `DatasetDict` holding several splits).
    """
    dataset = load_from_disk(path)

    # `load_from_disk` may also return a `DatasetDict`. Such an object cannot be
    # indexed by integer position, so fail here with a clear message instead of
    # letting the DataLoader break later with an unrelated-looking error.
    if not isinstance(dataset, Dataset):
        raise TypeError(
            f"Expected a single `Dataset` at {path!r}, got {type(dataset).__name__}. "
            "Save and load one split instead of the whole dataset dictionary."
        )

    return dataset


# TODO: curriculum learning based on the amount of new tokens in the student input?
class AIMDataModule(pl.LightningDataModule):
    """
    Lightning data module that serves the prepared training dataset.

    The following config entries are read:
      - `cfg.datamodule.train_dataset_path`: passed to `load_prepared_dataset`
      - `cfg.datamodule.batch_size`: batch size of the training dataloader
      - `cfg.datamodule.num_workers`: worker count of the training dataloader
      - `cfg.model.teacher_checkpoint_path`: source of the teacher tokenizer
      - `cfg.model.student_checkpoint_path`: source of the student tokenizer

    Args:
        cfg (DictConfig): The configuration holding the entries listed above.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        # All attributes below are populated by `setup()`. They are declared
        # here so that reading them earlier yields None rather than an
        # AttributeError.
        self.train_dataset: Dataset | None = None
        self.collator: DataCollatorForAIM | None = None
        self.teacher_tokenizer = None
        self.student_tokenizer = None

    def prepare_data(self) -> None:
        """Does nothing; no download or preprocessing step is performed here."""
        pass

    @staticmethod
    def _load_tokenizer(checkpoint_path: str):
        """Loads a tokenizer and falls back to the EOS token id for padding if no pad token is set."""
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        # if pad token is not set, set it to eos token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        return tokenizer

    def setup(self, stage: str | None = None) -> None:
        """
        Loads the training dataset and both tokenizers, then builds the collator.

        Args:
            stage (str | None): Accepted for compatibility with the Lightning
                hook signature; the same work is done regardless of its value.
        """
        self.train_dataset = load_prepared_dataset(self.cfg.datamodule.train_dataset_path)

        self.teacher_tokenizer = self._load_tokenizer(self.cfg.model.teacher_checkpoint_path)
        self.student_tokenizer = self._load_tokenizer(self.cfg.model.student_checkpoint_path)

        self.collator = DataCollatorForAIM(
            teacher_tokenizer=self.teacher_tokenizer,
            student_tokenizer=self.student_tokenizer,
        )

    def train_dataloader(self) -> DataLoader:
        """
        Builds the shuffled training dataloader.

        Returns:
            DataLoader: A dataloader over the training dataset that batches
                examples with the collator created in `setup()`.

        Raises:
            RuntimeError: If `setup()` has not populated the training dataset yet.
        """
        # Without this guard a None dataset reaches the DataLoader and fails
        # there with an error that does not mention the missing `setup()` call.
        if self.train_dataset is None:
            raise RuntimeError(
                "The training dataset is not loaded; call `setup()` before `train_dataloader()`."
            )

        return DataLoader(
            self.train_dataset,
            batch_size=self.cfg.datamodule.batch_size,
            num_workers=self.cfg.datamodule.num_workers,
            collate_fn=self.collator,
            shuffle=True,
        )
