import copy
import json

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

# Markers placed before and after the question when the dataset classes in
# this file build a prompt: B_INST + ' ' + question + ' ' + E_INST.
B_INST = "[INST]"
E_INST = "[/INST]"

class QADataset(Dataset):
    """Question/answer dataset yielding tokenized prompt+answer examples.

    The annotation file is a JSON document indexed by integer position. Each
    record must provide the string fields ``"question"`` and ``"answer"``.
    The question is wrapped as ``B_INST + ' ' + question + ' ' + E_INST`` and
    the answer is appended directly after it.
    """

    def __init__(self, dataset_config, tokenizer, partition="train"):
        """Load the annotations for the requested partition.

        Args:
            dataset_config: Object exposing ``train_data_path`` and
                ``valid_data_path`` attributes (paths to JSON files).
            tokenizer: Object providing ``encode(text)``, which must return a
                list of token ids, and an ``eos_token_id`` attribute.
            partition: ``"train"`` reads ``train_data_path``; any other value
                reads ``valid_data_path``.
        """
        if partition == 'train':
            data_path = dataset_config.train_data_path
        else:
            data_path = dataset_config.valid_data_path
        # Use a context manager so the handle is closed as soon as parsing is
        # done, and decode explicitly as UTF-8 instead of the locale default.
        with open(data_path, encoding="utf-8") as f:
            self.ann = json.load(f)
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.ann)

    def __getitem__(self, index):
        """Tokenize record ``index`` and build its training targets.

        Returns:
            A dict of plain Python lists, all of the same length:

            * ``"input_ids"``: token ids of prompt + answer followed by the
              tokenizer's EOS id.
            * ``"labels"``: the same ids, with the leading positions covered
              by the prompt replaced by -100 so they are ignored by the loss.
            * ``"attention_mask"``: booleans, ``True`` where the token id is
              non-negative.
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

        ann = self.ann[index]
        prompt = B_INST + ' ' + ann['question'] + ' ' + E_INST
        # The number of masked label positions is the token count of the
        # prompt when it is encoded on its own.
        prompt_len = len(self.tokenizer.encode(prompt))

        token_ids = self.tokenizer.encode(prompt + ann["answer"])
        token_ids.append(self.tokenizer.eos_token_id)
        example = torch.tensor(token_ids, dtype=torch.int64)

        # Negative ids are treated as "not a real token": they are excluded
        # from attention, ignored in the labels and replaced by 0 in the input.
        example_mask = example.ge(0)
        labels = example.clone()
        labels[:prompt_len] = IGNORE_INDEX
        labels[~example_mask] = IGNORE_INDEX
        example[~example_mask] = 0

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
    
class Contrastive_QADataset(Dataset):
    """QA dataset with one positive and one negative answer per question.

    The annotation file is a JSON document indexed by integer position. Each
    record must provide the string fields ``"question"``, ``"answer"``,
    ``"answer_positive"`` and ``"answer_negative"``. Every answer is appended
    to the same prompt, ``B_INST + ' ' + question + ' ' + E_INST``.
    """

    def __init__(self, dataset_config, tokenizer, partition="train"):
        """Load the annotations for the requested partition.

        Args:
            dataset_config: Object exposing ``train_data_path`` and
                ``valid_data_path`` attributes (paths to JSON files).
            tokenizer: Object providing ``encode(text)``, which must return a
                list of token ids, and an ``eos_token_id`` attribute.
            partition: ``"train"`` reads ``train_data_path``; any other value
                reads ``valid_data_path``.
        """
        if partition == 'train':
            data_path = dataset_config.train_data_path
        else:
            data_path = dataset_config.valid_data_path
        # Use a context manager so the handle is closed as soon as parsing is
        # done, and decode explicitly as UTF-8 instead of the locale default.
        with open(data_path, encoding="utf-8") as f:
            self.ann = json.load(f)
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.ann)

    def _encode_with_eos(self, text):
        """Tokenize ``text``, append the EOS id and return an int64 tensor."""
        token_ids = self.tokenizer.encode(text)
        token_ids.append(self.tokenizer.eos_token_id)
        return torch.tensor(token_ids, dtype=torch.int64)

    def __getitem__(self, index):
        """Tokenize record ``index`` together with its positive/negative answers.

        The three sequences (answer, positive, negative) are right-padded to
        the length of the longest one.

        Returns:
            A dict of plain Python lists, all of the same length:

            * ``"input_ids"``, ``"input_ids_pos"``, ``"input_ids_neg"``:
              token ids, with padding positions set to 0.
            * ``"attention_mask"``, ``"attention_mask_pos"``,
              ``"attention_mask_neg"``: booleans, ``False`` on padding.
            * ``"labels"``: ids of the main example with the prompt positions
              and the padding replaced by -100.
            * ``"labels_pos"``, ``"labels_neg"``: ids of the positive and
              negative examples with only the padding replaced by -100; the
              prompt positions are left unmasked.
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
        PAD_SENTINEL = -1  # Temporary padding value, replaced before returning

        ann = self.ann[index]
        prompt = B_INST + ' ' + ann['question'] + ' ' + E_INST
        # The number of masked label positions is the token count of the
        # prompt when it is encoded on its own.
        prompt_len = len(self.tokenizer.encode(prompt))

        # Rows: 0 = main example, 1 = positive sample, 2 = negative sample.
        padded = pad_sequence(
            [
                self._encode_with_eos(prompt + ann["answer"]),
                self._encode_with_eos(prompt + ann["answer_positive"]),
                self._encode_with_eos(prompt + ann["answer_negative"]),
            ],
            batch_first=True,
            padding_value=PAD_SENTINEL,
        )

        # Mask of real (non-padding) positions, taken while the padding is
        # still marked by the negative sentinel.
        attention = padded.ge(0)

        # Derive the labels BEFORE the padding is overwritten with 0 so that
        # padded positions end up as IGNORE_INDEX rather than as token id 0.
        all_labels = padded.clone()
        all_labels[~attention] = IGNORE_INDEX
        # Only the main example excludes the prompt from the loss.
        all_labels[0, :prompt_len] = IGNORE_INDEX

        padded[~attention] = 0

        example, example_positive, example_negative = padded.tolist()
        labels, labels_pos, labels_neg = all_labels.tolist()
        example_mask, positive_mask, negative_mask = attention.tolist()

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
            "input_ids_pos": example_positive,
            "input_ids_neg": example_negative,
            "labels_pos": labels_pos,
            "labels_neg": labels_neg,
            "attention_mask_pos": positive_mask,
            "attention_mask_neg": negative_mask,
        }

class Contrastive2_QADataset(Dataset):
    """QA dataset with two positive and two negative answers per question.

    The annotation file is a JSON document indexed by integer position. Each
    record must provide the string fields ``"question"`` and ``"answer"``,
    plus ``"answer_positive"`` and ``"answer_negative"`` sequences whose
    elements 0 and 1 are strings. Every answer is appended to the same
    prompt, ``B_INST + ' ' + question + ' ' + E_INST``.
    """

    def __init__(self, dataset_config, tokenizer, partition="train"):
        """Load the annotations for the requested partition.

        Args:
            dataset_config: Object exposing ``train_data_path`` and
                ``valid_data_path`` attributes (paths to JSON files).
            tokenizer: Object providing ``encode(text)``, which must return a
                list of token ids, and an ``eos_token_id`` attribute.
            partition: ``"train"`` reads ``train_data_path``; any other value
                reads ``valid_data_path``.
        """
        if partition == 'train':
            data_path = dataset_config.train_data_path
        else:
            data_path = dataset_config.valid_data_path
        # Use a context manager so the handle is closed as soon as parsing is
        # done, and decode explicitly as UTF-8 instead of the locale default.
        with open(data_path, encoding="utf-8") as f:
            self.ann = json.load(f)
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.ann)

    def _encode_with_eos(self, text):
        """Tokenize ``text``, append the EOS id and return an int64 tensor."""
        token_ids = self.tokenizer.encode(text)
        token_ids.append(self.tokenizer.eos_token_id)
        return torch.tensor(token_ids, dtype=torch.int64)

    def __getitem__(self, index):
        """Tokenize record ``index`` together with its contrastive answers.

        The five sequences (answer, two positives, two negatives) are
        right-padded to the length of the longest one.

        Returns:
            A dict of plain Python lists, all of the same length:

            * ``"input_ids"``, ``"input_ids_pos_1"``, ``"input_ids_neg_1"``,
              ``"input_ids_pos_2"``, ``"input_ids_neg_2"``: token ids, with
              padding positions set to 0.
            * ``"attention_mask"`` and the matching ``"attention_mask_*"``
              entries: booleans, ``False`` on padding.
            * ``"labels"``: ids of the main example with the prompt positions
              and the padding replaced by -100.
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
        PAD_SENTINEL = -1  # Temporary padding value, replaced before returning

        ann = self.ann[index]
        prompt = B_INST + ' ' + ann['question'] + ' ' + E_INST
        positives = ann["answer_positive"]
        negatives = ann["answer_negative"]
        # The number of masked label positions is the token count of the
        # prompt when it is encoded on its own.
        prompt_len = len(self.tokenizer.encode(prompt))

        # Rows: 0 = main example, 1 = positive 1, 2 = negative 1,
        #       3 = positive 2,   4 = negative 2.
        padded = pad_sequence(
            [
                self._encode_with_eos(prompt + ann["answer"]),
                self._encode_with_eos(prompt + positives[0]),
                self._encode_with_eos(prompt + negatives[0]),
                self._encode_with_eos(prompt + positives[1]),
                self._encode_with_eos(prompt + negatives[1]),
            ],
            batch_first=True,
            padding_value=PAD_SENTINEL,
        )

        # Mask of real (non-padding) positions, taken while the padding is
        # still marked by the negative sentinel.
        attention = padded.ge(0)

        # The prompt part and the padding do not take part in the loss.
        labels = padded[0].clone()
        labels[:prompt_len] = IGNORE_INDEX
        labels[~attention[0]] = IGNORE_INDEX

        padded[~attention] = 0

        (example, example_positive_1, example_negative_1,
         example_positive_2, example_negative_2) = padded.tolist()
        (example_mask, positive_mask_1, negative_mask_1,
         positive_mask_2, negative_mask_2) = attention.tolist()

        return {
            "input_ids": example,
            "labels": labels.tolist(),
            "attention_mask": example_mask,
            "input_ids_pos_1": example_positive_1,
            "input_ids_neg_1": example_negative_1,
            "attention_mask_pos_1": positive_mask_1,
            "attention_mask_neg_1": negative_mask_1,
            "input_ids_pos_2": example_positive_2,
            "input_ids_neg_2": example_negative_2,
            "attention_mask_pos_2": positive_mask_2,
            "attention_mask_neg_2": negative_mask_2,
        }
    
class Contrastive3_QADataset(Dataset):
    """Contrastive QA dataset that also returns a per-position ``acc`` value.

    The annotation file is a JSON document indexed by integer position. Each
    record must provide the string fields ``"question"`` and ``"answer"``,
    ``"answer_positive"`` and ``"answer_negative"`` sequences whose elements
    0 and 1 are strings, and an ``"acc"`` value convertible to float32.
    Every answer is appended to the same prompt,
    ``B_INST + ' ' + question + ' ' + E_INST``.
    """

    def __init__(self, dataset_config, tokenizer, partition="train"):
        """Load the annotations for the requested partition.

        Args:
            dataset_config: Object exposing ``train_data_path`` and
                ``valid_data_path`` attributes (paths to JSON files).
            tokenizer: Object providing ``encode(text)``, which must return a
                list of token ids, and an ``eos_token_id`` attribute.
            partition: ``"train"`` reads ``train_data_path``; any other value
                reads ``valid_data_path``.
        """
        if partition == 'train':
            data_path = dataset_config.train_data_path
        else:
            data_path = dataset_config.valid_data_path
        # Use a context manager so the handle is closed as soon as parsing is
        # done, and decode explicitly as UTF-8 instead of the locale default.
        with open(data_path, encoding="utf-8") as f:
            self.ann = json.load(f)
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.ann)

    def _encode_with_eos(self, text):
        """Tokenize ``text``, append the EOS id and return an int64 tensor."""
        token_ids = self.tokenizer.encode(text)
        token_ids.append(self.tokenizer.eos_token_id)
        return torch.tensor(token_ids, dtype=torch.int64)

    def __getitem__(self, index):
        """Tokenize record ``index`` together with its contrastive answers.

        The five sequences (answer, two positives, two negatives) are
        right-padded to the length of the longest one.

        Returns:
            A dict of plain Python lists, all of the same length:

            * ``"input_ids"``, ``"input_ids_pos_1"``, ``"input_ids_neg_1"``,
              ``"input_ids_pos_2"``, ``"input_ids_neg_2"``: token ids, with
              padding positions set to 0.
            * ``"attention_mask"`` and the matching ``"attention_mask_*"``
              entries: booleans, ``False`` on padding.
            * ``"labels"``: ids of the main example with the prompt positions
              and the padding replaced by -100.
            * ``"acc"``: the record's ``"acc"`` value, converted through
              float32 and repeated once per position of the padded sequence.
        """
        IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
        PAD_SENTINEL = -1  # Temporary padding value, replaced before returning

        ann = self.ann[index]
        prompt = B_INST + ' ' + ann['question'] + ' ' + E_INST
        positives = ann["answer_positive"]
        negatives = ann["answer_negative"]
        # The number of masked label positions is the token count of the
        # prompt when it is encoded on its own.
        prompt_len = len(self.tokenizer.encode(prompt))

        # Rows: 0 = main example, 1 = positive 1, 2 = negative 1,
        #       3 = positive 2,   4 = negative 2.
        padded = pad_sequence(
            [
                self._encode_with_eos(prompt + ann["answer"]),
                self._encode_with_eos(prompt + positives[0]),
                self._encode_with_eos(prompt + negatives[0]),
                self._encode_with_eos(prompt + positives[1]),
                self._encode_with_eos(prompt + negatives[1]),
            ],
            batch_first=True,
            padding_value=PAD_SENTINEL,
        )

        # One copy of the record's "acc" per position of the padded example.
        padded_length = len(padded[0])
        acc = torch.tensor([ann["acc"]] * padded_length, dtype=torch.float32)

        # Mask of real (non-padding) positions, taken while the padding is
        # still marked by the negative sentinel.
        attention = padded.ge(0)

        # The prompt part and the padding do not take part in the loss.
        labels = padded[0].clone()
        labels[:prompt_len] = IGNORE_INDEX
        labels[~attention[0]] = IGNORE_INDEX

        padded[~attention] = 0

        (example, example_positive_1, example_negative_1,
         example_positive_2, example_negative_2) = padded.tolist()
        (example_mask, positive_mask_1, negative_mask_1,
         positive_mask_2, negative_mask_2) = attention.tolist()

        return {
            "input_ids": example,
            "labels": labels.tolist(),
            "attention_mask": example_mask,
            "input_ids_pos_1": example_positive_1,
            "input_ids_neg_1": example_negative_1,
            "attention_mask_pos_1": positive_mask_1,
            "attention_mask_neg_1": negative_mask_1,
            "input_ids_pos_2": example_positive_2,
            "input_ids_neg_2": example_negative_2,
            "attention_mask_pos_2": positive_mask_2,
            "attention_mask_neg_2": negative_mask_2,
            "acc": acc.tolist(),
        }