import torch
import string
from torch.utils.data import Dataset
from datasets import load_dataset
import copy
from tqdm import tqdm

# Upper-case letters used to label multiple-choice options ("A. ...", "B. ...").
ALPHBET = [*string.ascii_uppercase]
# Label value excluded from the loss; -100 is the default ``ignore_index`` of
# ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100

class SQUADDataset(Dataset):
    """SQuAD-style extractive QA records formatted as causal-LM training examples.

    Each item wraps ``context`` and ``question`` in a ``USER: ... ASSISTANT:``
    prompt and uses the first reference answer as the target. Prompt positions
    in ``labels`` are set to ``IGNORE_INDEX``, so only the answer tokens and the
    trailing EOS token contribute to the loss.
    """

    def __init__(self, dataset_name, tokenizer, partition='train'):
        """
        Args:
            dataset_name: Hugging Face dataset name passed to ``load_dataset``.
                Records must provide ``context``, ``question`` and
                ``answers['text']`` (a list of reference answers).
            tokenizer: Hugging Face-style tokenizer exposing
                ``encode(text, add_special_tokens=...)`` and ``eos_token_id``.
            partition: name of the split to select from the loaded dataset.
        """
        self.dataset = load_dataset(dataset_name)[partition]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def _build_example(self, prompt, answer):
        """Tokenize ``prompt`` and ``answer`` into (input_ids, labels, attention_mask).

        The two texts are tokenized separately and their ids concatenated, so
        the prompt ids are always an exact prefix of ``input_ids``. Tokenizing
        ``prompt + answer`` as one string lets the tokenizer merge characters
        across the boundary (e.g. ``'ASSISTANT:'`` fusing with the first answer
        word). Masking the first ``len(prompt_ids)`` labels would then hide part
        of the answer. Separate tokenization also matches inference, where only
        the prompt is encoded and the model generates the rest.
        """
        prompt_ids = list(self.tokenizer.encode(prompt))
        # No special tokens for the answer, so tokenizers that prepend BOS do not
        # insert one in the middle of the sequence.
        answer_ids = list(self.tokenizer.encode(answer, add_special_tokens=False))
        input_ids = torch.tensor(
            prompt_ids + answer_ids + [self.tokenizer.eos_token_id], dtype=torch.int64
        )
        labels = input_ids.clone()
        labels[: len(prompt_ids)] = IGNORE_INDEX  # loss only on answer + EOS
        # Every position is a real token (nothing is padded here), so attend to all.
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        return input_ids, labels, attention_mask

    def __getitem__(self, index):
        """Return the tokenized example plus its raw ``prompt``/``example`` strings."""
        instance = self.dataset[index]

        context = instance['context']
        question = instance['question']
        # NOTE(review): a record with no reference answer (e.g. an unanswerable
        # SQuAD v2 question) raises IndexError here.
        answer = instance['answers']['text'][0]

        prompt = f'USER: {context}\n{question} Answer the question by using a single word or a single phrase.\nASSISTANT:'
        input_ids, labels, attention_mask = self._build_example(prompt, answer)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "prompt": prompt,
            "example": prompt + answer,
        }

class MMLUDataset(Dataset):
    """Multiple-choice QA records (MMLU schema) formatted as causal-LM examples.

    Each item renders the question and its lettered choices into a prompt that
    ends in ``'Answer: '``. The target is the correct choice line (e.g.
    ``'B. 4'``). Prompt positions in ``labels`` are set to ``IGNORE_INDEX``.
    """

    def __init__(self, dataset_name, tokenizer, config_name='all', partition='train'):
        '''
            dataset_name: Hugging Face dataset name whose records provide
                ``question``, ``subject``, ``choices`` (list of strings) and an
                integer ``answer`` index, e.g. ``cais/mmlu``. Datasets with other
                fields (such as SQuAD, handled by SQUADDataset) are not supported.
            tokenizer: Hugging Face-style tokenizer exposing
                ``encode(text, add_special_tokens=...)`` and ``eos_token_id``.
            config_name: configuration name passed to ``load_dataset``.
            partition: name of the split to select from the loaded dataset.
                NOTE(review): cais/mmlu does not appear to ship a 'train' split
                (its splits look like 'test'/'validation'/'dev'/'auxiliary_train'),
                so callers may need to override this default -- confirm.
        '''
        self.dataset = load_dataset(dataset_name, config_name)[partition]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def _build_example(self, prompt, answer):
        """Tokenize ``prompt`` and ``answer`` into (input_ids, labels, attention_mask).

        The two texts are tokenized separately and their ids concatenated, so
        the prompt ids are always an exact prefix of ``input_ids``. Tokenizing
        ``prompt + answer`` as one string lets the tokenizer merge across the
        boundary. Here the prompt's trailing space can fuse with the answer
        letter (``' B.'``). Masking the first ``len(prompt_ids)`` labels would
        then hide the answer letter itself.
        """
        prompt_ids = list(self.tokenizer.encode(prompt))
        # No special tokens for the answer, so tokenizers that prepend BOS do not
        # insert one in the middle of the sequence.
        answer_ids = list(self.tokenizer.encode(answer, add_special_tokens=False))
        input_ids = torch.tensor(
            prompt_ids + answer_ids + [self.tokenizer.eos_token_id], dtype=torch.int64
        )
        labels = input_ids.clone()
        labels[: len(prompt_ids)] = IGNORE_INDEX  # loss only on answer + EOS
        # Every position is a real token (nothing is padded here), so attend to all.
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        return input_ids, labels, attention_mask

    def __getitem__(self, index):
        """Return the tokenized example plus its raw ``prompt``/``example`` strings."""
        instance = self.dataset[index]
        subject = instance['subject'].replace('_', ' ')

        # NOTE(review): "alphbet" is a typo in the model-facing prompt; fixing it
        # changes every prompt, so decide deliberately before editing.
        instruct = f"Please give your answer for this {subject} multiple choice question. Remember to give only the answer alphbet index and do not give any explanation."
        question = f"Question:\n{instance['question']}"
        # Label options "A. ...", "B. ...", ...; zip stops at the shorter sequence.
        choices = [f"{letter}. {choice}" for letter, choice in zip(ALPHBET, instance['choices'])]

        prompt = "\n\n".join([instruct, question, "\n".join(['Choices:'] + choices), 'Answer: '])
        answer = choices[instance['answer']]
        input_ids, labels, attention_mask = self._build_example(prompt, answer)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "prompt": prompt,
            "example": prompt + answer,
        }
        
        
class ConcatDataset(Dataset):
    """Pack tokenized samples into fixed-length chunks for causal-LM training.

    ``input_ids``, ``attention_mask`` and ``labels`` of consecutive samples are
    concatenated end-to-end (in lockstep) and cut into chunks of exactly
    ``chunk_size`` tokens. Tokens left over after the last full chunk are
    dropped. Each packed sample is a dict of plain Python lists.
    """

    def __init__(self, dataset, chunk_size=4096):
        """
        Args:
            dataset: iterable of dicts whose ``input_ids``, ``attention_mask``
                and ``labels`` entries are token sequences (tensors or lists).
                The dicts returned by SQUADDataset/MMLUDataset have this shape.
            chunk_size: number of tokens per packed sample; must be positive.

        Raises:
            ValueError: if ``chunk_size`` is not positive (it would never shrink
                the buffer, so packing would loop forever).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.dataset = dataset
        self.chunk_size = chunk_size

        self.samples = []

        keys = ("input_ids", "attention_mask", "labels")
        buffer = {k: [] for k in keys}

        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
            for k in keys:
                # reshape(-1) instead of squeeze(): a 1-token sample stays a list
                # (squeeze would give a 0-d tensor whose tolist() is a scalar).
                # extend() avoids re-copying the whole buffer for every sample.
                buffer[k].extend(torch.as_tensor(sample[k]).reshape(-1).tolist())

            # >= so a buffer holding exactly one full chunk is emitted as well.
            while len(buffer["input_ids"]) >= self.chunk_size:
                self.samples.append({k: buffer[k][:self.chunk_size] for k in keys})
                for k in keys:
                    del buffer[k][:self.chunk_size]

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
# Example usage, kept inside a bare string literal so nothing below runs on import.
# It loads the MMLU 'test' partition, runs the first item's prompt through
# Llama-3.2-1B-Instruct with `generate`, and prints the decoded output.
# It assumes this module is importable as `data` (see the import inside) and that
# `transformers` is installed.
'''
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
from data import MMLUDataset
dataset = MMLUDataset('cais/mmlu', tokenizer, partition='test')
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
for instance in dataset:
    prompt = instance['prompt']
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.int64).to(device).unsqueeze(0)
    attention_mask = torch.ones_like(tokens)
    outputs = model.generate(input_ids=tokens, attention_mask=attention_mask,)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(output_text)
    break
'''