import torch
from torch import nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import datasets
import os
from utils import get_model_identifiers_from_yaml
import json
import random


def dataset_to_json(dataset, filename, ):
    """Dump an indexable dataset to *filename* in JSON Lines format.

    Rows are fetched by integer position (``dataset[0]`` .. ``dataset[len-1]``),
    so *dataset* only needs ``__len__`` and ``__getitem__``. Every row is
    serialized with ``json.dumps`` and written on its own line.
    """
    total_rows = len(dataset)
    with open(filename, "w") as out_file:
        # Lazily serialize so rows are written one at a time, in order.
        out_file.writelines(
            json.dumps(dataset[position]) + '\n'
            for position in range(total_rows)
        )


# Adapted from TOFU: https://github.com/locuslab/tofu/blob/80159d8ea39edf147fb09cd82aefa08e506e6718/data_module.py#L8
def convert_raw_forget_data_to_model_format(tokenizer, max_length, question, answer, model_configs, mask=True):
    """Tokenize one question/answer pair into fixed-length training tensors.

    The text fed to the tokenizer is
    ``question_start_tag + question + question_end_tag + answer_tag + answer``
    (tags are read from *model_configs*). The encoding is truncated to
    *max_length* and right-padded with ``tokenizer.eos_token_id``.

    Args:
        tokenizer: Callable tokenizer that also exposes ``tokenize`` and
            ``eos_token_id``; its output must provide ``input_ids`` and
            ``attention_mask`` as lists.
        max_length: Target length of every returned tensor (assumes the
            tokenizer honors ``truncation=True``).
        question: Raw question text.
        answer: Raw answer text.
        model_configs: Mapping with the keys ``question_start_tag``,
            ``question_end_tag`` and ``answer_tag``.
        mask: When True, question tokens and padding (except one EOS right
            after the text, if there is room) get the label ``-100``.
            When False, the labels are simply the padded input ids.

    Returns:
        Tuple ``(input_ids, labels, attention_mask)`` of 1-D tensors.
    """
    new_question = model_configs['question_start_tag'] + question + model_configs['question_end_tag']
    new_answer = model_configs['answer_tag'] + answer
    full_text = new_question + new_answer

    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )
    # Work on copies so the tokenizer's own lists are never mutated below.
    input_ids = list(encoded['input_ids'])
    attention_mask = list(encoded['attention_mask'])

    pad_length = max_length - len(input_ids)
    pad_input_ids = input_ids + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = attention_mask + [0] * pad_length

    if not mask:
        label = list(pad_input_ids)
    else:
        if pad_length > 0:
            # Supervise a single EOS after the text; ignore the rest of the padding.
            label = input_ids + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)
        else:
            label = list(input_ids)

        # Change label to -100 for question tokens. The question may have been
        # truncated away (it can be longer than max_length on its own), so the
        # masked span is clamped to the label length instead of indexing past it.
        num_masked = min(num_question_tokens, len(label))
        label[:num_masked] = [-100] * num_masked

    return torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)


def convert_raw_data_to_model_format(tokenizer, max_length,  question, answer, model_configs):
    """Tokenize one question/answer pair with the question masked out of the labels.

    The text fed to the tokenizer is
    ``question_start_tag + question + question_end_tag + answer_tag + answer``
    (tags are read from *model_configs*). The encoding is truncated to
    *max_length* and right-padded with ``tokenizer.eos_token_id``.

    Args:
        tokenizer: Callable tokenizer that also exposes ``tokenize`` and
            ``eos_token_id``; its output must provide ``input_ids`` and
            ``attention_mask`` as lists.
        max_length: Target length of every returned tensor (assumes the
            tokenizer honors ``truncation=True``).
        question: Raw question text.
        answer: Raw answer text.
        model_configs: Mapping with the keys ``question_start_tag``,
            ``question_end_tag`` and ``answer_tag``.

    Returns:
        Tuple ``(input_ids, labels, attention_mask)`` of 1-D tensors. In the
        labels, question tokens and padding are ``-100``, except for a single
        EOS label directly after the text when there is room for it.
    """
    new_question = model_configs['question_start_tag'] + question + model_configs['question_end_tag']
    new_answer = model_configs['answer_tag'] + answer
    full_text = new_question + new_answer
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )
    # Work on copies so the tokenizer's own lists are never mutated below.
    input_ids = list(encoded['input_ids'])
    attention_mask = list(encoded['attention_mask'])

    pad_length = max_length - len(input_ids)
    pad_input_ids = input_ids + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = attention_mask + [0] * pad_length

    if pad_length > 0:
        # Supervise a single EOS after the text; ignore the rest of the padding.
        label = input_ids + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)
    else:
        label = list(input_ids)

    # Change label to -100 for question tokens. The question may have been
    # truncated away (it can be longer than max_length on its own), so the
    # masked span is clamped to the label length instead of indexing past it.
    num_masked = min(num_question_tokens, len(label))
    label[:num_masked] = [-100] * num_masked

    return torch.tensor(pad_input_ids), torch.tensor(label), torch.tensor(pad_attention_mask)


class TextForgetDatasetQA(Dataset):
    """Dataset that yields five tokenized variants for every forget example.

    Item ``idx`` is a list with one ``(input_ids, labels, attention_mask)``
    tuple per entry of ``self.data_types``, in this order:

    * ``forget``          - forget question with its own answer
    * ``retain``          - a sampled retain question with its own answer
    * ``forget_idk``      - forget question answered with a line from the idk file
    * ``retain_idk``      - the sampled retain question answered with that idk line
    * ``forget_mismatch`` - forget question with the sampled retain answer

    The retain example and the idk line are chosen pseudo-randomly but
    deterministically from ``idx``.
    """

    def __init__(self, tokenizer, model_family, forget_data, retain_data, max_length=512, mask=False):
        """Store the data sources and load the "I don't know" answers.

        Args:
            tokenizer: Tokenizer forwarded to the conversion helper.
            model_family: Key passed to ``get_model_identifiers_from_yaml`` to
                obtain the prompt tags.
            forget_data: Indexable collection of rows with ``question`` and
                ``answer`` keys; defines the dataset length.
            retain_data: Indexable, non-empty collection with the same keys.
            max_length: Sequence length of the produced tensors.
            mask: Label masking flag applied to the ``forget`` variant only;
                all other variants always use masked labels.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.forget_data = forget_data
        self.retain_data = retain_data
        self.mask = mask

        self.model_configs = get_model_identifiers_from_yaml(model_family)

        # One candidate "I don't know" answer per line; read with an explicit
        # encoding so the result does not depend on the platform locale.
        self.idontknowfile = "data/idontknow.jsonl"
        with open(self.idontknowfile, "r", encoding="utf-8") as f:
            self.idk = f.readlines()

        self.data_types = ["forget", "retain", "forget_idk", "retain_idk", "forget_mismatch"]

    def __len__(self):
        """Return the number of forget examples."""
        return len(self.forget_data)

    def _sample_indices(self, idx):
        """Return ``(retain_idx, idk_pos)`` drawn deterministically from *idx*.

        A private generator seeded with *idx* is used instead of
        ``torch.manual_seed`` so that fetching an item does not reset the
        process-wide RNG state used elsewhere (e.g. by the training loop).
        """
        generator = torch.Generator()
        generator.manual_seed(int(idx))
        num_retain = len(self.retain_data)
        offset = torch.randint(0, num_retain, (1,), generator=generator).item()
        retain_idx = (idx + offset) % num_retain
        rand_pos = torch.randint(0, len(self.idk), (1,), generator=generator).item()
        return retain_idx, rand_pos

    def __getitem__(self, idx):
        """Build the five tokenized variants for forget example *idx*."""
        retain_idx, rand_pos = self._sample_indices(idx)

        rets = []
        for data_type in self.data_types:
            if "retain" in data_type:
                example = self.retain_data[retain_idx]
            else:
                example = self.forget_data[idx]
            question = example['question']
            answer = example['answer']

            if "idk" in data_type:
                answer = self.idk[rand_pos].strip()
            elif "mismatch" in data_type:
                answer = self.retain_data[retain_idx]['answer']

            # Only the plain forget loss honors the configurable mask flag.
            mask = self.mask if data_type == 'forget' else True
            converted_data = convert_raw_forget_data_to_model_format(
                self.tokenizer, self.max_length, question, answer, self.model_configs, mask=mask)
            rets.append(converted_data)

        return rets

class TextDatasetQA(Dataset):
    """Question/answer dataset backed by ``<data_path>/<split>.json``.

    Each item is a tuple ``(input_ids, labels, attention_mask)``. When the
    answer field holds several answers, one row is produced per answer and the
    rows are stacked; singleton dimensions are squeezed away afterwards.
    """

    def __init__(self, data_path, tokenizer, model_family, max_length=512, split=None, question_key='question', answer_key='answer'):
        """Load the JSON split and remember the tokenization settings.

        NOTE: despite its ``None`` default, *split* has to be a string; the
        file name is built as ``split + '.json'``, which raises ``TypeError``
        for ``None``.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        json_file = os.path.join(data_path, split + '.json')
        loaded = datasets.load_dataset('json', data_files=json_file)
        self.data = loaded['train']

        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.qk = question_key
        self.ak = answer_key

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row *idx* against each of its answers."""
        question = self.data[idx][self.qk]
        answers = self.data[idx][self.ak]

        # A single answer string is handled like a one-element answer list.
        if isinstance(answers, str):
            answers = [answers]

        converted = [
            convert_raw_data_to_model_format(
                self.tokenizer, self.max_length, question, candidate, self.model_configs)
            for candidate in answers
        ]

        # Field order of every converted triple: input ids, labels, attention mask.
        stacked = [
            torch.stack([triple[field] for triple in converted]).squeeze()
            for field in range(3)
        ]
        return tuple(stacked)


def collate_fn(batch):
    """Right-pad a batch of ``(input_ids, attention_mask)`` pairs to equal length.

    Input ids are padded with ``-100`` and attention masks with ``0``; both
    results are batch-first tensors.

    NOTE(review): ``-100`` is the usual "ignore" label value rather than a
    token id - confirm that consumers of this collator expect it in input_ids.
    """
    id_seqs, mask_seqs = (list(column) for column in zip(*batch))
    padded_ids = pad_sequence(id_seqs, batch_first=True, padding_value=-100)
    padded_masks = pad_sequence(mask_seqs, batch_first=True, padding_value=0)
    return padded_ids, padded_masks


def custom_data_collator(samples):
    """Stack ``(input_ids, labels, attention_mask)`` samples into batch tensors.

    Only the first three entries of every sample are read. Returns a tuple of
    three tensors, each with a new leading batch dimension.
    """
    # Gather each field across the batch first, then stack field by field.
    columns = [[sample[field] for sample in samples] for field in range(3)]
    return tuple(torch.stack(column) for column in columns)



def custom_data_collator_forget(samples):
    """Collate multi-variant samples into one stacked batch per data type.

    Every sample is expected to hold one ``(input_ids, labels, attention_mask)``
    triple per data type, positioned in the order of ``data_types`` below.

    Returns:
        A list with one ``(input_ids, labels, attention_mask)`` tuple of
        stacked tensors per data type, in the same order.
    """
    data_types = ["forget", "retain", "forget_idk", "retain_idk", "forget_mismatch"]

    # Regroup the batch: data type -> list of that type's triples.
    grouped = {}
    for position, type_name in enumerate(data_types):
        grouped[type_name] = [sample[position] for sample in samples]

    batches = []
    for type_name in data_types:
        triples = grouped[type_name]
        # Field order inside a triple: input ids, labels, attention mask.
        columns = [[triple[field] for triple in triples] for field in range(3)]
        batches.append(tuple(torch.stack(column) for column in columns))

    return batches

def get_batch_loss(output, labels):
    """Return the summed next-token cross-entropy of every sequence.

    Position ``t`` of *output* is scored against the label at ``t + 1``;
    labels equal to ``-100`` contribute nothing to the sum.

    Args:
        output: Logits; assumed to be ``(..., seq_len, vocab)`` - the last
            two axes are swapped so the class axis is where the loss expects it.
        labels: Integer targets; assumed to be ``(..., seq_len)``.

    Returns:
        Tensor with the per-sequence loss (summed over the sequence axis).
    """
    # Drop the first label and the last prediction to align them.
    targets = labels[..., 1:].contiguous()
    logits = output[..., :-1, :].contiguous()

    token_losses = nn.functional.cross_entropy(
        logits.transpose(-1, -2),
        targets,
        ignore_index=-100,
        reduction='none',
    )
    # Sum the token losses of each sequence in the batch.
    return token_losses.sum(dim=-1)
