from typing import Union, Tuple, Dict
from omegaconf import DictConfig

import random

import torch
from torch.utils.data import DataLoader

from transformers import AutoTokenizer

from data.base import BaseDataset


class FEVERDataset(BaseDataset):
    """FEVER claims packaged as (prompt, binary label) tuples for model editing.

    Every item yields three tokenized groups: the edit prompt, one randomly
    sampled paraphrase of it, and an unrelated prompt. Each group carries a
    ``labels`` tensor of shape ``(1, 1)`` holding 1.0 when its answer string is
    ``"SUPPORTS"`` and 0.0 otherwise.

    NOTE(review): ``self.data``, ``self.tok`` and the ``collate_fn`` used by
    ``make_loader`` are presumably provided by ``BaseDataset`` — confirm there.
    """

    def __getitem__(self, idx) -> Dict[str, Dict[str, torch.Tensor]]:
        """Return the tokenized edit / equivalent / unrelated tuples for row ``idx``.

        The row is expected to provide the keys ``prompt``, ``equiv_prompt``
        (a non-empty sequence of paraphrases), ``unrel_prompt``, ``alt`` and
        ``unrel_ans``. ``alt`` labels both the edit prompt and its paraphrase;
        ``unrel_ans`` labels the unrelated prompt.
        """
        row = self.data[idx]

        prompt = row["prompt"]
        # A new paraphrase is drawn on every access, so the same row can pair
        # with different equivalent prompts across epochs.
        equiv_prompt = random.choice(row["equiv_prompt"])
        unrel_prompt = row["unrel_prompt"]
        alt = row["alt"]
        ans = row["unrel_ans"]

        return {
            "edit_tuples": self.tok_tuples(prompt, alt),
            "equiv_tuples": self.tok_tuples(equiv_prompt, alt),
            "unrel_tuples": self.tok_tuples(unrel_prompt, ans)
        }

    def tok_tuples(
        self,
        prompt: str,
        answer: str,
        max_length: int = 512
    ) -> Dict[str, torch.Tensor]:
        """Tokenize ``prompt`` and attach a binary label derived from ``answer``.

        Args:
            prompt: Text to tokenize; truncated to ``max_length`` tokens.
            answer: Verdict string; only ``"SUPPORTS"`` maps to a positive label.
            max_length: Token budget passed to the tokenizer (default 512).

        Returns:
            The tokenizer's output (PyTorch tensors) with an added ``labels``
            entry: a float32 tensor of shape ``(1, 1)``.
        """
        encoded = self.tok(
            prompt,
            max_length = max_length,
            return_tensors = "pt",
            truncation = True
        )

        # Float target (not an integer class id). The dtype is pinned to
        # float32 to match the original legacy ``torch.FloatTensor``
        # constructor, regardless of torch's default dtype.
        encoded["labels"] = torch.tensor(
            [[float(answer == "SUPPORTS")]],
            dtype = torch.float32
        )

        return encoded

def make_loader(
    config: DictConfig,
    device: Union[str, int, torch.device]
) -> Tuple[DataLoader, DataLoader]:
    """Build the shuffled FEVER train and validation loaders.

    Args:
        config: Run configuration. Reads ``config.model.name_or_path`` (the
            tokenizer checkpoint) and ``config.data`` (passed to each dataset,
            plus ``train_path``, ``valid_path`` and ``n_edits``, the batch size).
        device: Forwarded to ``FEVERDataset``.

    Returns:
        ``(train_loader, valid_loader)``. Both loaders shuffle and drop the
        final incomplete batch, so every batch holds exactly ``n_edits`` rows.
    """
    tok = AutoTokenizer.from_pretrained(config.model.name_or_path)

    # Both datasets are built before either loader, preserving the
    # original construction order.
    train_set, valid_set = (
        FEVERDataset(config.data, path, tok, device)
        for path in (config.data.train_path, config.data.valid_path)
    )

    def to_loader(dataset: FEVERDataset) -> DataLoader:
        # The validation loader is also shuffled; drop_last keeps every batch
        # at exactly n_edits rows.
        return DataLoader(
            dataset,
            batch_size = config.data.n_edits,
            shuffle = True,
            collate_fn = dataset.collate_fn,
            drop_last = True
        )

    return to_loader(train_set), to_loader(valid_set)