from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch
import jsonlines
import random
import os
import numpy as np


class VIRTDataset(Dataset):
    """Dataset of response vectors loaded from a JSON-lines file.

    The file is looked up inside ``data_path`` and is named
    ``<type>_filtered_<split>.jsonl`` when ``filtered`` is truthy, otherwise
    ``<type>_<split>.jsonl``. Every record is kept in memory in ``self.data``;
    only its ``'response'`` entry is read when an item is fetched.
    """

    def __init__(self, type, split, filtered, data_path=''):
        """Read every record of the selected file into ``self.data``.

        Args:
            type: First component of the file name.
            split: Last component of the file name (before ``.jsonl``).
            filtered: When truthy, the ``_filtered_`` variant of the file is read.
            data_path: Directory that contains the file.
        """
        super().__init__()
        self.data_path = data_path
        self.type = type
        self.filtered = filtered
        self.split = split

        # The two variants differ only in the file name, so pick the name first
        # and share a single read path.
        if self.filtered:
            file_name = f'{self.type}_filtered_{self.split}.jsonl'
        else:
            file_name = f'{self.type}_{self.split}.jsonl'

        with jsonlines.open(os.path.join(self.data_path, file_name), 'r') as reader:
            self.data = list(reader)

    def __getitem__(self, i):
        """Return ``{'y': tensor}`` built from the ``'response'`` entry of record ``i``."""
        record = self.data[i]
        return {'y': torch.tensor(record['response'])}

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)
    

def shuffle_examinee_collate_fn(batch):
    """Stack the ``'y'`` tensors of a batch and shuffle them along dimension 1.

    Args:
        batch: Sequence of samples, each a mapping with a tensor under ``'y'``.

    Returns:
        ``{'y': tensor}`` where the tensor is the samples stacked along a new
        leading dimension and then re-ordered along dimension 1.
    """
    responses = [item['y'] for item in batch]
    stacked = torch.stack(responses, dim=0)
    # A single permutation is drawn per batch, so every row is re-ordered the
    # same way (presumably dimension 1 indexes examinees -- confirm with caller).
    order = np.random.permutation(stacked.size(1))
    shuffled = stacked[:, order]
    return {'y': shuffled}


class IRTGenDataset(Dataset):
    """Tokenized prompts paired with numeric parameters.

    Records are taken from the ``data`` argument or, when that is ``None``,
    read from ``<data_path>/<type>_gen_<split>.jsonl``. Each record is read for
    its ``'prompt'`` and ``'parameters'`` entries.
    """

    def __init__(self, type, split, tokenizer, max_length=64, data=None, data_path=''):
        """Store the configuration and load the records if none were supplied.

        Args:
            type: First component of the file name.
            split: Last component of the file name (before ``.jsonl``).
            tokenizer: Callable tokenizer exposing ``bos_token`` and
                ``eos_token`` and returning ``'input_ids'`` and
                ``'attention_mask'`` when called.
            max_length: Length every prompt is padded or truncated to.
            data: Optional pre-loaded sequence of records. When given (even if
                empty) it is used as-is and no file is read.
            data_path: Directory that contains the file.
        """
        super().__init__()
        self.type = type
        self.split = split
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = data
        self.data_path = data_path
        # Compare against None rather than truthiness: an explicitly supplied
        # empty dataset must stay empty instead of triggering a file read.
        if self.data is None:
            with jsonlines.open(os.path.join(self.data_path, f'{self.type}_gen_{self.split}.jsonl'), 'r') as f:
                self.data = list(f)

    def __getitem__(self, i):
        """Return the tokenized sample for record ``i``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` from the
            tokenizer, ``'parameters'`` as floats rounded to 4 decimals, and
            ``'labels'`` equal to ``'input_ids'`` with every position whose
            attention mask is 0 replaced by -100.
        """
        record = self.data[i]
        # <s>prompt to be generated.</s>
        # Special tokens are written into the text by hand, so the tokenizer is
        # told not to add its own.
        encoded = self.tokenizer(
            self.tokenizer.bos_token + record['prompt'] + self.tokenizer.eos_token,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']
        # Padding positions get -100 (presumably so a loss with
        # ignore_index=-100 skips them -- confirm against the training loop).
        labels = [
            token_id if attended else -100
            for token_id, attended in zip(input_ids, attention_mask)
        ]
        return {
            'input_ids': input_ids,
            'parameters': [round(float(value), 4) for value in record['parameters']],
            'attention_mask': attention_mask,
            'labels': labels,
        }

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
