
"""
Convenience wrappers around classification datasets
"""
import torch as t

from abc import abstractmethod
from enum import Enum
from datasets import load_dataset
from transformers import AutoTokenizer
from collections import OrderedDict
from torch.utils.data import DataLoader, Dataset

# Short names of the classification datasets this module is meant to expose.
# Entries are exact identifiers: no surrounding whitespace, so membership
# tests such as `"qqp" in dsets` behave as expected.
dsets = [
    "boolq",
    "obqa",
    "arc",
    "winogrande",
    "cqa",
    "cola",
    "mnli",
    "mrpc",
    "qnli",
    "qqp",
    "rte",
    "sst2",
    "wnli",
]


class ClassificationDataset:
    """
    An abstract base dataset for sequence classification problems. Multiple
    choice QA problems could also be made a subclass of this class with an
    appropriate collation / formatting.

    Subclasses provide ``sc_collate_fn`` and ``clm_collate_fn``. The class does
    not use an ABC metaclass, so the ``@abstractmethod`` markers are advisory
    only: instantiation is not blocked, and a collate function that was not
    overridden raises ``NotImplementedError`` when it is called.
    """

    def __init__(
        self,
        dset,
        tokenizer,
        n_labels: int,
        preamble: str = "",
        add_space: bool = False,
        numerical: bool = True,
        boolean: bool = False,
        few_shot: bool = False,
        max_len: int = 1024,
    ):
        """
        Args:
            dset: The loaded Dataset; indexed by split name in ``loader``.
            tokenizer: The model tokenizer
            n_labels: The number of labels / classes for each question
            preamble: Preamble for general pre-trained / 'CausalLM' models
            add_space: Add an explicit space suffix between preamble and answer tokens.
            numerical: whether labels are numerical (0, 1, etc.) or alphabetical (A, B, etc.)
            boolean: whether the labels are boolean (True, False). Because
                ``numerical`` defaults to True, callers must also pass
                ``numerical=False`` when enabling this.
            few_shot: whether to use few-shot prompting (if available)
            max_len: the maximum length of the prompt.

        Raises:
            ValueError: if both ``numerical`` and ``boolean`` are set.
            AssertionError: if two labels map to the same target token id.
        """
        self.dset = dset
        self.n_labels = n_labels
        self.preamble = preamble
        self.add_space = add_space
        self.tokenizer = tokenizer
        self.numerical = numerical
        self.boolean = boolean
        self.few_shot = few_shot
        self.max_len = max_len

        spc = " " if self.add_space else ""

        # 1. Build up the token IDS of the class labels.
        if numerical and boolean:
            raise ValueError("Question type cannot be both numerical and boolean")
        if boolean:
            labels = [f"{spc}True", f"{spc}False"]
        elif numerical:
            labels = [f"{spc}{i}" for i in range(self.n_labels)]
        else:  # alphabetical
            labels = [f"{spc}{chr(ord('A')+i)}" for i in range(self.n_labels)]
        # Keep only the last token of every label (as a column), so that a
        # leading-space or prefix token produced by the tokenizer is dropped.
        self.target_ids = tokenizer(
            labels, return_tensors="pt", add_special_tokens=False
        ).input_ids[:, -1:]
        assert (
            self.target_ids.unique().numel() == self.target_ids.numel()
        ), "Target label IDS are not unique! Try changing add_space or numerical."

        # 2. Get a mapping from the label indices (e.g. 0, 1, 2, etc.) to the
        # target token ids from above (e.g. 345, 346, etc.).
        # The values stay one-element tensors so they can be concatenated
        # directly (e.g. with torch.cat) when building batch targets.
        self.label_idx2target_id = OrderedDict(
            (i, self.target_ids[i]) for i in range(n_labels)
        )
        # The reverse mapping is keyed by plain Python ints. Tensors hash by
        # object identity, so tensor keys could never be found again by a
        # lookup on the token id value.
        self.target_id2label_idx = OrderedDict(
            (int(self.target_ids[i].item()), i) for i in range(n_labels)
        )

    @abstractmethod
    def sc_collate_fn(self, batch):
        """Collate function for sequence classification models"""
        raise NotImplementedError

    def sc_loader(self, dset: Dataset, *args, **kwargs) -> DataLoader:
        """Returns the dataloader for sequence classification models"""
        return DataLoader(dset, *args, collate_fn=self.sc_collate_fn, **kwargs)

    @abstractmethod
    def clm_collate_fn(self, batch):
        """Collate function for causal language models"""
        raise NotImplementedError

    def clm_loader(self, dset: Dataset, *args, **kwargs) -> DataLoader:
        """Returns the dataloader for causal language models"""
        return DataLoader(dset, *args, collate_fn=self.clm_collate_fn, **kwargs)

    def loader(
        self,
        *args,
        is_sc: bool = False,
        split: str = "train",
        subset_size: int = -1,
        subset_seed: int = 42,
        grad_acc_steps: int = 1,
        drop_last: bool = True,
        **kwargs,
    ):
        """
        Build a DataLoader over one split of the dataset.

        Args:
            *args: extra positional arguments forwarded to the DataLoader.
            is_sc: use the sequence classification collate function when
                True, otherwise the causal language model one.
            split: key of the split to read from ``self.dset``.
            subset_size: when positive, shuffle the split with
                ``subset_seed`` and keep at most this many examples; any
                other value uses the whole split unshuffled.
            subset_seed: seed for the subset shuffle.
            grad_acc_steps: number of gradient accumulation steps; the
                requested batch size is divided by this value.
            drop_last: default for the DataLoader's ``drop_last`` option
                (an explicit ``drop_last`` in ``kwargs`` is rejected by Python
                as a duplicate keyword, so use this parameter).
            **kwargs: DataLoader options; ``batch_size`` defaults to 1.

        Returns:
            The DataLoader produced by ``sc_loader`` or ``clm_loader``.

        Raises:
            AssertionError: if the batch size is not divisible by
                ``grad_acc_steps``.
        """
        if subset_size > 0:
            # Never ask for more examples than the split contains.
            subset_size = min(subset_size, len(self.dset[split]))
            dset = self.dset[split].shuffle(seed=subset_seed).select(range(subset_size))
        else:
            dset = self.dset[split]

        # Caller-supplied options take precedence over the defaults.
        kwargs = {"batch_size": 1, "drop_last": drop_last} | kwargs
        assert (
            kwargs["batch_size"] % grad_acc_steps == 0
        ), "batch size must be divisible by gradient accumulation steps"
        # The loader yields micro-batches; grad_acc_steps of them add up to
        # the batch size that was requested.
        kwargs["batch_size"] = kwargs["batch_size"] // grad_acc_steps

        if is_sc:
            return self.sc_loader(dset, *args, **kwargs)
        return self.clm_loader(dset, *args, **kwargs)


class WinograndeSplit(Enum):
    """
    Configuration names of the "winogrande" dataset.

    The value of each member is the string that WinograndeDataset passes to
    ``load_dataset`` as the configuration name.
    """

    # NOTE(review): the suffixes presumably denote increasing training-set
    # sizes (extra small ... extra large) -- confirm against the dataset card.
    XS = "winogrande_xs"
    S = "winogrande_s"
    M = "winogrande_m"
    L = "winogrande_l"
    XL = "winogrande_xl"


class WinograndeDataset(ClassificationDataset):
    """
    Winogrande fill-in-the-blank sentences posed as a two-way multiple choice
    question whose answer is one of the alphabetical labels A or B.
    """

    # Prompt template with two worked examples ahead of the actual question.
    few_shot_preamble = """Return the label of the correct answer for each question below.

Adam put handwash only clothes in the washer but Aaron washed them by hand as _ was lazy.
Choices:
A) Adam
B) Aaron
Answer: A

Steven proudly showed Michael the mangoes he grew himself all this summer. _ is astonished.
Choices:
A) Stephen
B) Michael
Answer: B

{question}
Choices:
{choices}
Answer:"""

    # Prompt template containing only the question itself.
    zero_shot_preamble = """Return the label of the correct answer for the question below.

Question: {question}
Choices:
{choices}
Answer:"""

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        name: WinograndeSplit = WinograndeSplit.S,
        add_space: bool = True,
        few_shot: bool = False,
        max_len: int = 4096,
    ):
        """
        Args:
            tokenizer: The model tokenizer
            name: which Winogrande configuration to load
            add_space: prefix the answer labels with a space
            few_shot: use the few-shot template instead of the zero-shot one
            max_len: the maximum length of the prompt (stored by the base class)
        """
        raw = load_dataset("winogrande", name.value)
        if few_shot:
            template = self.few_shot_preamble
        else:
            template = self.zero_shot_preamble
        super().__init__(
            raw,
            tokenizer,
            n_labels=2,
            preamble=template,
            add_space=add_space,
            numerical=False,
            few_shot=few_shot,
            max_len=max_len,
        )

    @staticmethod
    def _label_indices(batch):
        """Zero-based class index per example: the 'answer' field parsed as an int, minus one."""
        return [int(example["answer"]) - 1 for example in batch]

    def _format_prompts(self, batch):
        """Fill the preamble template with each example's sentence and its two options."""
        return [
            self.preamble.format(
                question=example["sentence"],
                choices=f"A) {example['option1']}\nB) {example['option2']}",
            )
            for example in batch
        ]

    def clm_collate_fn(self, batch):
        """
        Collate function for causal language models.

        Returns a tuple of (prompt strings, class index tensor, target token
        ids). Prompts are returned as raw text; no tokenization or truncation
        to max_len happens here.
        """
        prompts = self._format_prompts(batch)
        indices = self._label_indices(batch)
        classes = t.tensor(indices)
        # One target token id per example, looked up from its class index.
        targets = t.cat([self.label_idx2target_id[idx] for idx in indices])
        return prompts, classes, targets

    def sc_collate_fn(self, batch):
        """
        Collate function for sequence classification models.

        Returns a tuple of (prompt strings, class index tensor, None); the
        last slot mirrors the targets position of clm_collate_fn. Prompts are
        returned as raw text, not tokenized.
        """
        prompts = self._format_prompts(batch)
        classes = t.tensor(self._label_indices(batch))
        return prompts, classes, None


# Short alias so the dataset class can be referred to by its lowercase name.
winogrande = WinograndeDataset



# NOTE(review): `model_id` is not referenced anywhere else in the visible
# source; the tokenizer below is loaded from `model_path` instead.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Placeholder value: replace it with a location that
# AutoTokenizer.from_pretrained can resolve before running this file.
model_path = "YOUR PATH"
# This statement sits at module level, so it executes on import as well as
# when the file is run as a script.
tokenizer = AutoTokenizer.from_pretrained(model_path)

def class_to_label(classes, num_labels):
    """
    Convert integer class indices into float one-hot label vectors.

    Args:
        classes: integer (int64) tensor of class indices; a 1-D tensor of
            shape (batch,) yields a (batch, num_labels) result.
        num_labels: total number of classes, i.e. the width of each one-hot
            vector (for a dataset from this module, presumably its
            ``n_labels`` attribute).

    Returns:
        A float tensor holding the one-hot encoding of ``classes``.
    """
    # A singleton axis is inserted before encoding and summed away
    # afterwards, which leaves a plain one-hot row per class index.
    one_hot = t.nn.functional.one_hot(classes[:, None], num_classes=num_labels)
    labels = t.sum(one_hot, dim=1).to(t.float)
    return labels

import json

# Build the Winogrande dataset (medium configuration) and a causal-LM loader
# over its "train" split. With the defaults of `loader` the batch size is 1
# and no shuffling is requested, so the iteration order is fixed.
dset = winogrande(tokenizer, name=WinograndeSplit.M)
loader = dset.loader(is_sc=False)

# Walk the loader once and keep every example as an input/output record.
# The output text is the decoded target token of the example.
records = []
for batch in loader:
    prompts, classes, targets = batch
    print(prompts, classes, targets)
    records.append({'input': prompts[0], 'output': tokenizer.decode(targets)})

count = len(records)
train_val_idx = int(count // 5 * 4)

# NOTE(review): the example at position `train_val_idx` itself goes to the
# training set, so training holds one example more than four fifths of the
# data. This boundary is kept as it was -- confirm whether it is intended.
inputoutput_json = records[: train_val_idx + 1]
with open('../data/winogrande_m_inputoutput_train.json', 'w') as f:
    json.dump(inputoutput_json, f)

# Everything after the boundary forms the validation set.
inputoutput_json_v = records[train_val_idx + 1 :]
with open('../data/winogrande_m_inputoutput_val.json', 'w') as f:
    json.dump(inputoutput_json_v, f)
