import random
import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from .utils import read_jsonl_file


class PretrainData(Dataset):
    """Text-chunk dataset for the "repeat the previous content" pretraining task.

    Each item pairs an encoder sequence (chunk tokens followed by ``num_mem``
    copies of ``mem_id`` and EOS padding) with a decoder sequence (BOS, a fixed
    instruction prompt, the same chunk tokens, EOS padding) and already-shifted
    next-token labels for the decoder.
    """

    def __init__(self, text_files: list[str], llama_path: str, max_length: int, num_mem: int, mem_id: int):
        """Load the tokenizer and collect the unique chunks of all input files.

        Args:
            text_files: JSONL files; every record must provide a ``"chunks"``
                list of strings.
            llama_path: Name or path handed to ``AutoTokenizer.from_pretrained``.
            max_length: Number of chunk tokens kept per sample (longer chunks
                are truncated, shorter ones padded with EOS).
            num_mem: Number of ``mem_id`` tokens appended to the encoder input.
            mem_id: Token id used for the memory slots.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.bos_token_id == 128000:
            # NOTE(review): presumably the Llama-3 vocabulary, where the
            # end-of-text id is pinned explicitly -- confirm.
            self.bos_token_id = 128000
            self.eos_token_id = 128001
        else:
            self.bos_token_id = self.tokenizer.bos_token_id
            self.eos_token_id = self.tokenizer.eos_token_id
        self.max_length = max_length
        self.num_mem = num_mem
        self.mem_id = mem_id

        # Raw records stay available as a public attribute, as before.
        self.json = []
        for text_file in text_files:
            self.json += read_jsonl_file(text_file)

        kept = []
        for one in self.json:
            for chunk in one["chunks"]:
                # Only short chunks are checked; the ones that tokenize to a
                # single token are dropped.
                if len(chunk) < 100:
                    test = self.tokenizer(
                        chunk,
                        return_tensors="pt",
                        add_special_tokens=False,
                    ).input_ids
                    if test.numel() == 1:
                        continue
                kept.append(chunk)
        # Order-preserving de-duplication.  ``list(set(...))`` would order the
        # strings by their randomized hashes, so the index -> sample mapping
        # would differ between runs and between worker processes.
        self.text = list(dict.fromkeys(kept))

    def __len__(self):
        """Return the number of unique chunks."""
        return len(self.text)

    def _tokenize(self, text):
        """Tokenize ``text`` without special tokens into a 1-D long tensor."""
        ids = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids
        # reshape(-1) keeps a one-token result 1-D; squeeze() would collapse it
        # to a 0-dim tensor on which ``size(0)`` raises.
        return ids.reshape(-1).to(torch.long)

    def __getitem__(self, idx):
        """Build the encoder input, decoder input and decoder labels for ``idx``.

        Returns:
            A dict of long tensors:
            ``encoder_input_ids`` with ``max_length + num_mem`` entries,
            ``decoder_input_ids`` and ``decoder_label`` with
            ``1 + len(prompt) + max_length`` entries each.  Labels are the
            decoder input shifted left by one; prompt and padding positions
            are set to ``-100``.
        """
        # Truncate once so that encoder, decoder and labels all refer to the
        # same tokens; the label following the last kept token is then EOS
        # instead of a token the encoder never saw.
        input_ids = self._tokenize(self.text[idx])[:self.max_length]
        prompt_ids = self._tokenize("Please repeat the previous content to me: ")

        padding_length = self.max_length - input_ids.size(0)
        padding = torch.full((padding_length,), self.eos_token_id, dtype=torch.long)
        mem_tokens = torch.full((self.num_mem,), self.mem_id, dtype=torch.long)
        bos = torch.full((1,), self.bos_token_id, dtype=torch.long)
        eos = torch.full((1,), self.eos_token_id, dtype=torch.long)

        encoder_input_ids = torch.cat([input_ids, mem_tokens, padding])
        decoder_input_ids = torch.cat([bos, prompt_ids, input_ids, padding])
        decoder_label = torch.cat(
            [torch.full((prompt_ids.size(0),), -100, dtype=torch.long),
             input_ids,
             eos,
             torch.full((padding_length,), -100, dtype=torch.long)])

        return {
            "encoder_input_ids": encoder_input_ids,
            "decoder_input_ids": decoder_input_ids,
            "decoder_label": decoder_label,
        }


class FTData(Dataset):
    """Fine-tuning dataset: a fixed-size stack of document chunks plus a QA pair.

    Every record must provide ``"chunks"`` (list of strings), ``"question"``
    and ``"response"``.  Chunks are tokenized, extended with ``num_mem`` copies
    of ``mem_id`` and left-padded with dummy chunks up to ``max_chunk_num``.
    The question is rendered through the tokenizer's chat template and split
    around the ``<chunks>`` placeholder.
    """

    def __init__(
            self,
            text_files: list[str],
            llama_path: str,
            max_chunk_length: int,
            max_chunk_num: int,
            max_qa_length: int,
            num_mem: int,
            mem_id: int):
        """Read all records and load the tokenizer.

        Args:
            text_files: JSONL files with the records described on the class.
            llama_path: Name or path handed to ``AutoTokenizer.from_pretrained``.
            max_chunk_length: Tokens kept per chunk (truncate / pad with EOS).
            max_chunk_num: Number of chunk rows returned per sample.
            max_qa_length: Length of the answer-side input and label tensors.
            num_mem: Number of ``mem_id`` tokens appended to each chunk.
            mem_id: Token id used for the memory slots.
        """
        self.text = []
        for text_file in text_files:
            self.text += read_jsonl_file(text_file)

        self.tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.bos_token_id == 128000:
            # NOTE(review): presumably the Llama-3 vocabulary, where 128009 is
            # the end-of-turn id used by the chat template -- confirm.
            self.bos_token_id = 128000
            self.eos_token_id = 128009
        else:
            self.bos_token_id = self.tokenizer.bos_token_id
            self.eos_token_id = self.tokenizer.eos_token_id

        self.max_chunk_length = max_chunk_length
        self.max_chunk_num = max_chunk_num
        self.max_qa_length = max_qa_length
        self.num_mem = num_mem
        self.mem_id = mem_id
        self.system_prompt = "You are an accurate and reliable AI assistant capable of answering questions by referencing external documents. Please note that the external documents may not always be related to the question. If the information in the documents contain the correct answer, you will provide an accurate response. If the documents do not contain the answer, you will refuse to answer."
        self.user_prompt = """The documents are as follows:
<chunks>

Question: """

    def __len__(self):
        """Return the number of records."""
        return len(self.text)

    def _tokenize(self, text, **kwargs):
        """Tokenize ``text`` without special tokens into a 1-D long tensor."""
        ids = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
            **kwargs,
        ).input_ids
        # reshape(-1) keeps one-token results 1-D (squeeze() would not).
        return ids.reshape(-1).to(torch.long)

    def _build_chunks(self, chunks):
        """Return ``(chunk_input_ids, chunk_mask)`` for a list of chunk strings."""
        mem_tokens = torch.full((self.num_mem,), self.mem_id, dtype=torch.long)
        rows = []
        for chunk in chunks:
            # Only the first ``max_chunk_num`` usable chunks are kept, so
            # there is no need to tokenize the remainder.
            if len(rows) >= self.max_chunk_num:
                break
            chunk_tokens = self._tokenize(
                chunk, truncation=True, max_length=self.max_chunk_length)
            # Chunks that tokenize to a single token are skipped.
            if chunk_tokens.numel() == 1:
                continue
            chunk_tokens = chunk_tokens[:self.max_chunk_length]
            pad_len = self.max_chunk_length - chunk_tokens.size(0)
            rows.append(torch.cat(
                [chunk_tokens,
                 mem_tokens,
                 torch.full((pad_len,), self.eos_token_id, dtype=torch.long)]))

        # Real chunks occupy the LAST rows; dummy chunks are prepended.  The
        # start index is computed explicitly: slicing with ``-len(rows)``
        # would select the whole mask when there are no real chunks (-0 == 0).
        chunk_pad_num = self.max_chunk_num - len(rows)
        chunk_mask = torch.zeros(self.max_chunk_num, dtype=torch.bool)
        chunk_mask[chunk_pad_num:] = True
        if chunk_pad_num > 0:
            dummy = torch.cat(
                [torch.full((self.max_chunk_length,), self.eos_token_id, dtype=torch.long),
                 mem_tokens])
            rows = chunk_pad_num * [dummy] + rows
        return torch.stack(rows, dim=0), chunk_mask

    def _build_qa(self, question, response):
        """Return ``(pre_prompt_tokens, a_input_ids, a_labels)`` for one QA pair."""
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.user_prompt + question},
        ]
        qa_inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        # Split on the FIRST placeholder only, so a question that itself
        # contains "<chunks>" is not cut short.
        pre_prompt, _, last_prompt = qa_inputs.partition("<chunks>")

        pre_prompt_tokens = self._tokenize(pre_prompt)
        last_prompt_tokens = self._tokenize(last_prompt)
        a_tokens = self._tokenize(response)

        pad_len = max(
            self.max_qa_length - last_prompt_tokens.size(0) - a_tokens.size(0), 0)
        a_input_ids = torch.cat(
            [last_prompt_tokens,
             a_tokens,
             torch.full((pad_len,), self.eos_token_id, dtype=torch.long)])[:self.max_qa_length]
        # Labels are shifted left by one relative to the inputs: the prompt
        # contributes one masked position fewer and EOS follows the answer.
        a_labels = torch.cat(
            [torch.full((last_prompt_tokens.size(0) - 1,), -100, dtype=torch.long),
             a_tokens,
             torch.full((1,), self.eos_token_id, dtype=torch.long),
             torch.full((pad_len,), -100, dtype=torch.long)])[:self.max_qa_length]
        return pre_prompt_tokens, a_input_ids, a_labels

    def __getitem__(self, idx):
        """Return the chunk stack, its validity mask and the QA tensors for ``idx``.

        Returns:
            A dict with ``chunk_input_ids`` (``max_chunk_num`` rows of
            ``max_chunk_length + num_mem`` ids), ``chunk_mask`` (bool, True for
            real chunks), ``pre_prompt_tokens`` (prompt text before the
            placeholder), ``a_input_ids`` and ``a_labels``.
        """
        sample = self.text[idx]
        chunk_input_ids, chunk_mask = self._build_chunks(sample["chunks"])
        pre_prompt_tokens, a_input_ids, a_labels = self._build_qa(
            sample["question"], sample["response"])
        return {
            "chunk_input_ids": chunk_input_ids.to(torch.long),
            "chunk_mask": chunk_mask.to(torch.bool),
            "pre_prompt_tokens": pre_prompt_tokens.to(torch.long),
            "a_input_ids": a_input_ids.to(torch.long),
            "a_labels": a_labels.to(torch.long)}


class EditorData(Dataset):
    """Dataset pairing a context text with an edit ("new facts") and a QA pair.

    Every record must provide ``"text"``, ``"new_facts"``, ``"question"`` and
    ``"response"``.  Only records whose ``new_facts`` encoding has between 6
    and 40 tokens (inclusive) are kept.
    """

    def __init__(
            self,
            text_files: list[str],
            llama_path: str,
            max_length: int,
            max_qa_length: int,
            num_mem: int,
            mem_id: int,
            num_edit: int = 4):
        """Read and filter the records and load the tokenizer.

        Args:
            text_files: JSONL files with the records described on the class.
            llama_path: Name or path handed to ``AutoTokenizer.from_pretrained``.
            max_length: Tokens kept for the context and for the edit text
                (truncate / pad with EOS).
            max_qa_length: Length of the answer-side input and label tensors.
            num_mem: Number of ``mem_id`` tokens appended to the context.
            mem_id: Token id used for both the memory and the edit slots.
            num_edit: Number of ``mem_id`` tokens appended to the edit text.
        """
        self.text = []
        # Raw, unfiltered records stay available as a public attribute.
        self.json = []
        for text_file in text_files:
            self.json += read_jsonl_file(text_file)
        self.tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.bos_token_id == 128000:
            # NOTE(review): presumably the Llama-3 vocabulary, where 128009 is
            # the end-of-turn id used by the chat template -- confirm.
            self.bos_token_id = 128000
            self.eos_token_id = 128009
        else:
            self.bos_token_id = self.tokenizer.bos_token_id
            self.eos_token_id = self.tokenizer.eos_token_id

        for one in tqdm(self.json):
            fact_tokens = self.tokenizer(one["new_facts"], add_special_tokens=False)
            # NOTE(review): integer indexing into the tokenizer output is
            # expected to require a "fast" tokenizer -- confirm.
            if 6 <= len(fact_tokens[0]) <= 40:
                self.text.append({
                    "text": one["text"],
                    "new_facts": one["new_facts"],
                    "question": one["question"],
                    "response": one["response"],
                })

        self.max_length = max_length
        self.max_qa_length = max_qa_length
        self.num_mem = num_mem
        self.num_edit = num_edit
        self.mem_id = mem_id
        self.system_prompt = "You are an accurate and reliable AI assistant capable of answering questions by referencing external documents. Please note that the external documents may not always be related to the question. If the information in the documents contain the correct answer, you will provide an accurate response. If the documents do not contain the answer, you will refuse to answer."
        self.user_prompt = """The documents are as follows:
<chunks>

Question: """

    def __len__(self):
        """Return the number of records that passed the filter."""
        return len(self.text)

    def _tokenize(self, text):
        """Tokenize ``text`` without special tokens into a 1-D long tensor."""
        ids = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids
        # reshape(-1) keeps a one-token result 1-D; squeeze() would collapse it
        # to a 0-dim tensor on which ``size(0)`` raises.
        return ids.reshape(-1).to(torch.long)

    def _encode_with_slots(self, text, num_slots):
        """Return ``text`` tokens (at most ``max_length``) + slots + EOS padding.

        The result always has ``max_length + num_slots`` entries.
        """
        tokens = self._tokenize(text)[:self.max_length]
        pad_len = self.max_length - tokens.size(0)
        return torch.cat(
            [tokens,
             torch.full((num_slots,), self.mem_id, dtype=torch.long),
             torch.full((pad_len,), self.eos_token_id, dtype=torch.long)])

    def _build_qa(self, question, response):
        """Return ``(pre_prompt_tokens, a_input_ids, a_labels)`` for one QA pair."""
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.user_prompt + question},
        ]
        qa_inputs = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        # Split on the FIRST placeholder only, so a question that itself
        # contains "<chunks>" is not cut short.
        pre_prompt, _, last_prompt = qa_inputs.partition("<chunks>")

        pre_prompt_tokens = self._tokenize(pre_prompt)
        last_prompt_tokens = self._tokenize(last_prompt)
        a_tokens = self._tokenize(response)

        pad_len = max(
            self.max_qa_length - last_prompt_tokens.size(0) - a_tokens.size(0), 0)
        a_input_ids = torch.cat(
            [last_prompt_tokens,
             a_tokens,
             torch.full((pad_len,), self.eos_token_id, dtype=torch.long)])[:self.max_qa_length]
        # Labels are shifted left by one relative to the inputs: the prompt
        # contributes one masked position fewer and EOS follows the answer.
        a_labels = torch.cat(
            [torch.full((last_prompt_tokens.size(0) - 1,), -100, dtype=torch.long),
             a_tokens,
             torch.full((1,), self.eos_token_id, dtype=torch.long),
             torch.full((pad_len,), -100, dtype=torch.long)])[:self.max_qa_length]
        return pre_prompt_tokens, a_input_ids, a_labels

    def __getitem__(self, idx):
        """Return the context, edit and QA tensors for record ``idx``.

        Returns:
            A dict of long tensors: ``c_input_ids`` (``max_length + num_mem``),
            ``e_input_ids`` (``max_length + num_edit``), ``pre_prompt_tokens``
            (prompt text before the placeholder), ``a_input_ids`` and
            ``a_labels``.
        """
        sample = self.text[idx]
        c_input_ids = self._encode_with_slots(sample["text"], self.num_mem)
        e_input_ids = self._encode_with_slots(sample["new_facts"], self.num_edit)
        pre_prompt_tokens, a_input_ids, a_labels = self._build_qa(
            sample["question"], sample["response"])
        return {
            "c_input_ids": c_input_ids.to(torch.long),
            "e_input_ids": e_input_ids.to(torch.long),
            "pre_prompt_tokens": pre_prompt_tokens.to(torch.long),
            "a_input_ids": a_input_ids.to(torch.long),
            "a_labels": a_labels.to(torch.long)}