import json
import random
import pandas as pd
import torch
from datasets import load_from_disk, load_dataset
from torch.utils.data import Dataset
import sys
sys.path.append("../../reward")
from utils import load_data_subset

class UtilityScenarioDataset(Dataset):
    """Causal-LM dataset over prompt/label pairs stored with ``save_to_disk``.

    Each example is ``prompt + label``. A single space is inserted between the
    two only when the prompt ends with a period. If the ``train_path`` string
    contains ``"valid"``, only the first 500 examples are kept.

    Args:
        train_path: Directory passed to ``datasets.load_from_disk``.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: Split name indexed on the loaded object.
        max_length: Length every example is truncated/padded to.
    """

    def __init__(self, train_path, tokenizer, split, max_length=550):
        rows = load_from_disk(train_path)[split]

        def _compose(row):
            # Only a trailing period gets a separating space before the label.
            prompt = row["prompt"]
            glue = " " if prompt.endswith('.') else ""
            return prompt + glue + row["label"]

        texts = [_compose(row) for row in rows]
        # NOTE(review): the 500-example cap keys off the *path* string, not ``split``.
        self.post_list = texts[:500] if "valid" in train_path else texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        encoded = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        token_ids = torch.tensor(encoded["input_ids"])
        return {
            "input_ids": token_ids,
            "attention_mask": torch.tensor(encoded["attention_mask"]),
            "labels": token_ids,
        }

class MoralStoriesScenarioDataset(Dataset):
    """Causal-LM dataset over Moral Stories scenarios stored with ``save_to_disk``.

    For every row, ``prompt + " " + label`` is emitted; when ``add_moral`` is
    true, ``prompt + " " + moral_label`` is emitted right after it. If the
    ``train_path`` string contains ``"valid"``, only the first 500 examples
    are kept.

    Args:
        train_path: Directory passed to ``datasets.load_from_disk``.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: Split name indexed on the loaded object.
        max_length: Length every example is truncated/padded to.
        add_moral: Whether to also emit the ``moral_label`` continuation.
    """

    def __init__(self, train_path, tokenizer, split, max_length=550, add_moral=True):
        texts = []
        for row in load_from_disk(train_path)[split]:
            lead = row["prompt"] + " "
            targets = [row["label"], row["moral_label"]] if add_moral else [row["label"]]
            texts.extend(lead + target for target in targets)
        # NOTE(review): the 500-example cap keys off the *path* string, not ``split``.
        if "valid" in train_path:
            texts = texts[:500]
        self.post_list = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        encoded = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        token_ids = torch.tensor(encoded["input_ids"])
        return {
            "input_ids": token_ids,
            "attention_mask": torch.tensor(encoded["attention_mask"]),
            "labels": token_ids,
        }

class HHDataset(Dataset):
    """Causal-LM dataset of the ``chosen`` conversations from Anthropic/hh-rlhf.

    The ``chosen`` texts of ``split`` are shuffled with a fixed seed (this
    reseeds the global ``random`` module). Then ``"train"`` keeps the last
    30000 and ``"test"`` the first 500; any other split is kept whole.

    Args:
        train_path: Unused; kept for signature parity with the other datasets.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: HF split name to load.
        max_length: Length every example is truncated/padded to.
    """

    def __init__(self, train_path, tokenizer, split, max_length=1024):
        hh = load_dataset("Anthropic/hh-rlhf")
        conversations = [row["chosen"] for row in hh[split]]

        # Fixed seed so the window taken below is reproducible across runs.
        random.seed(10)
        random.shuffle(conversations)

        window = {
            "train": slice(-30000, None),  # 160800 in total
            "test": slice(0, 500),
        }.get(split, slice(None))
        self.post_list = conversations[window]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        encoded = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        token_ids = torch.tensor(encoded["input_ids"])
        return {
            "input_ids": token_ids,
            "attention_mask": torch.tensor(encoded["attention_mask"]),
            "labels": token_ids,
        }


class SHPDataset(Dataset):
    """Causal-LM dataset of preferred responses from stanfordnlp/SHP.

    Only rows with ``upvote_ratio > 0.9`` and both comment scores above 50
    are kept. Each kept row becomes ``"\\n\\nHuman: <history>\\n\\nAssistant:
    <preferred ref>"``. The list is then shuffled with a fixed seed (this
    reseeds the global ``random`` module). Then ``"train"`` drops the first
    500 and ``"test"`` keeps only the first 500; any other split is kept whole.

    Args:
        train_path: Unused; kept for signature parity with the other datasets.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: HF split name to load.
        max_length: Length every example is truncated/padded to.
    """

    def __init__(self, train_path, tokenizer, split, max_length=1024):
        shp = load_dataset("stanfordnlp/SHP")
        responses = []
        for row in shp[split]:
            # Skip threads that are not clearly popular on both sides.
            if not (row['upvote_ratio'] > 0.9 and row['score_A'] > 50 and row['score_B'] > 50):
                continue
            prefix = "\n\nHuman: " + row['history'] + "\n\nAssistant: "
            candidates = (prefix + row['human_ref_A'], prefix + row['human_ref_B'])
            # labels == 1 picks candidate A (index 0); labels == 0 picks B (index 1).
            responses.append(candidates[1 - int(row['labels'])])

        # Fixed seed so the window taken below is reproducible across runs.
        random.seed(10)
        random.shuffle(responses)

        window = {
            "train": slice(500, None),  # 20719 in total
            "test": slice(0, 500),
        }.get(split, slice(None))
        self.post_list = responses[window]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        encoded = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        token_ids = torch.tensor(encoded["input_ids"])
        return {
            "input_ids": token_ids,
            "attention_mask": torch.tensor(encoded["attention_mask"]),
            "labels": token_ids,
        }


class OASSTDataset(Dataset):
    """Causal-LM dataset of OASST conversations ending in the preferred response.

    Records from the train, validation and test ``rank_pairs_*.jsonl`` files
    are pooled. Each record becomes its ``context`` messages followed by
    ``good_response``, rendered as ``"<Human|Assistant>: <text>"`` and joined
    by blank lines. The pool is shuffled with a fixed seed (this reseeds the
    global ``random`` module). Then ``"train"`` keeps the last 30000 examples
    and ``"test"`` the first 500; any other split keeps everything.

    Args:
        train_path: Unused; kept for signature parity with the other datasets.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: ``"train"``, ``"test"``, or anything else for the full pool.
        max_length: Length every example is truncated/padded to.
        data_dir: Directory holding the three ``rank_pairs_*.jsonl`` files.

    Raises:
        OSError: If one of the JSONL files cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a message has a role other than prompter/assistant.
    """

    def __init__(self, train_path, tokenizer, split, max_length=1024, data_dir="/data/datasets/oasst"):
        role_to_name = {
            "prompter": "Human",
            "assistant": "Assistant",
        }
        # Same read order as before (train, validation, test) so the seeded
        # shuffle below yields the same permutation.
        file_names = (
            "rank_pairs_train.jsonl",
            "rank_pairs_val.jsonl",
            "rank_pairs_test.jsonl",
        )

        self.post_list = []
        for file_name in file_names:
            # JSON text is UTF-8; don't depend on the platform's default locale
            # (OASST is multilingual).
            with open(f"{data_dir}/{file_name}", "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines
                    data = json.loads(line)
                    good_msgs_list = data['context'] + [data['good_response']]
                    self.post_list.append(
                        "\n\n".join(f"{role_to_name[msg['role']]}: {msg['text']}" for msg in good_msgs_list)
                    )

        # Fixed seed so the windows taken below are reproducible across runs.
        random.seed(10)
        random.shuffle(self.post_list)

        if split == 'train':
            self.post_list = self.post_list[-30000:]  # 44934 in total
        elif split == 'test':
            self.post_list = self.post_list[0:500]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        txt = self.post_list[idx]
        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attn_masks = torch.tensor(encodings_dict["attention_mask"])

        return {
            "input_ids": input_ids,
            "attention_mask": attn_masks,
            "labels": input_ids,
        }


class MergedDataset(Dataset):
    """SFT/RL causal-LM dataset built from the HH preference data.

    NOTE: despite the class name, only the HH subset is loaded here (no SHP,
    OASST or GPT4All). ``train_type`` selects the HH subset: ``"sft"`` ->
    ``"train20"``, ``"rl"`` -> ``"train80"``.

    The examples are shuffled with seed 0 (this reseeds the global ``random``
    module). The first 10% form the ``"test"`` split and the rest the
    ``"train"`` split; any other ``split`` keeps everything. Examples
    containing three consecutive spaces are then dropped.

    Args:
        train_path: Unused; kept for signature parity with the other datasets.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: ``"train"``, ``"test"``, or anything else for the full set.
        max_length: Length every example is truncated/padded to.
        train_type: ``"sft"`` or ``"rl"``.
        positive_samples_only: When true, keep only ``sentences[::2]``.
            Per the original comment, sentences come in pairs with the
            preferred one first.
    """

    def __init__(self, train_path, tokenizer, split, max_length=1024, train_type="sft", positive_samples_only=True):
        assert train_type in ["sft", "rl"]
        subset_type = {
            "sft": "train20",
            "rl": "train80"
        }[train_type]

        self.post_list = []

        # HH -- the complement of this subset is used for reward modelling.
        sentences, _ = load_data_subset(data_dir="hh", dataset="hh", split="train", subset=subset_type, percent=1.0)
        # Reward sentences are in pairs; the first of each pair is the preferred one.
        preferred_examples = sentences[::2] if positive_samples_only else sentences
        print("HH:", len(preferred_examples))  # 32160 for sft, 128640 for rl
        self.post_list.extend(preferred_examples)

        # Fixed seed so the train/test partition is reproducible.
        random.seed(0)
        random.shuffle(self.post_list)
        print("Total:", len(self.post_list))

        n_test = len(self.post_list) // 10
        print("Holding out", n_test, "examples for test")
        if split == 'train':
            self.post_list = self.post_list[n_test:]
        elif split == 'test':
            self.post_list = self.post_list[:n_test]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

        # Drop examples containing a run of three spaces. This runs after the
        # split, so the hold-out count printed above is the pre-filter count.
        self.post_list = [s for s in self.post_list if '   ' not in s]
        print(f"Num examples in {split} set: {len(self.post_list)}")

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` (truncated/padded to ``max_length``).

        ``labels`` is the very same tensor as ``input_ids``; padding positions
        are not replaced with -100 here.
        """
        txt = self.post_list[idx]
        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attn_masks = torch.tensor(encodings_dict["attention_mask"])

        return {
            "input_ids": input_ids,
            "attention_mask": attn_masks,
            "labels": input_ids,
        }

class MergedPreferencePairsDataset(Dataset):
    """Causal-LM dataset pooling HH, SHP and OASST preference data (no GPT4All).

    For each source, the ``"train"`` split of ``subset_type`` is loaded via
    ``load_data_subset``. By default every sentence is kept. With
    ``positive_samples_only``, only ``sentences[::2]`` is kept; per the
    original comment, that is the preferred member of each pair. The pool is
    shuffled with seed 0 (this reseeds the global ``random`` module). The
    first 10% becomes ``"test"`` and the remainder ``"train"``.

    Args:
        train_path: Unused; kept for signature parity with the other datasets.
        tokenizer: Callable tokenizer used by ``__getitem__``.
        split: ``"train"``, ``"test"``, or anything else for the full pool.
        max_length: Length every example is truncated/padded to.
        subset_type: ``"train20"`` or ``"train80"``.
        positive_samples_only: Keep only the even-indexed sentences.
    """

    def __init__(self, train_path, tokenizer, split, max_length=1024, subset_type="train20", positive_samples_only=False):
        assert subset_type in ["train20", "train80"]

        pooled = []
        # (data_dir / dataset name, log prefix); tuple order fixes the pooled order.
        for source, tag in (("hh", "HH:"), ("shp", "SHP:"), ("oasst", "OASST:")):
            # The complement of this subset is used for reward modelling.
            sentences, _ = load_data_subset(data_dir=source, dataset=source, split="train", subset=subset_type, percent=1.0)
            kept = sentences[::2] if positive_samples_only else sentences
            print(tag, len(kept))
            pooled.extend(kept)

        # Fixed seed so the train/test partition is reproducible.
        random.seed(0)
        random.shuffle(pooled)
        print("Total:", len(pooled))

        n_test = len(pooled) // 10
        print("Holding out", n_test, "examples for test")
        if split == 'train':
            pooled = pooled[n_test:]
        elif split == 'test':
            pooled = pooled[:n_test]
        self.post_list = pooled
        print(f"Num examples in {split} set: {len(self.post_list)}")
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Kept for attribute compatibility; never populated in this class.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        encoded = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        token_ids = torch.tensor(encoded["input_ids"])
        return {
            "input_ids": token_ids,
            "attention_mask": torch.tensor(encoded["attention_mask"]),
            "labels": token_ids,
        }


class RolloutPreferencePairsDataset(Dataset):
    """Raw-text rollout prompts pooled from HH, SHP and OASST (no GPT4All).

    Each source's ``"train80"`` subset is loaded and only ``sentences[::2]``
    is kept (the preferred member of each pair, per the original comment).
    The pool is shuffled with seed 0 (this reseeds the global ``random``
    module). The first 10% is ``"test"`` and the remainder ``"train"``; any
    other ``split`` keeps everything. Items are returned as untokenized
    strings.
    """

    def __init__(self, split):
        collected = []
        # (data_dir / dataset name, log prefix); tuple order fixes the pooled order.
        for source, tag in (("hh", "HH:"), ("shp", "SHP:"), ("oasst", "OASST:")):
            # The complement of this subset is used for reward modelling.
            sentences, _ = load_data_subset(data_dir=source, dataset=source, split="train", subset="train80", percent=1.0)
            preferred = sentences[::2]
            print(tag, len(preferred))
            collected.extend(preferred)

        # Fixed seed so the train/test partition is reproducible.
        random.seed(0)
        random.shuffle(collected)
        print("Total:", len(collected))

        n_test = len(collected) // 10
        print("Holding out", n_test, "examples for test")
        partitions = {"train": collected[n_test:], "test": collected[:n_test]}
        self.post_list = partitions.get(split, collected)
        print(f"Num examples in {split} set: {len(self.post_list)}")

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Return the raw text of example ``idx``."""
        return self.post_list[idx]


class RolloutDataset(Dataset):
    """Raw-text rollout prompts taken from the HH ``train80`` subset.

    NOTE: only HH is loaded here (no SHP, OASST or GPT4All). Only
    ``sentences[::2]`` is kept; per the original comment, that is the
    preferred member of each pair. The list is shuffled with seed 0 (this
    reseeds the global ``random`` module). The first 10% is ``"test"`` and the
    remainder ``"train"``; any other ``split`` keeps everything. Items are
    returned as untokenized strings.

    Args:
        split: ``"train"``, ``"test"``, or anything else for the full set.
    """

    def __init__(self, split):
        # The former train_type -> subset lookup was hard-wired to "rl",
        # i.e. always "train80".
        subset_type = "train80"

        self.post_list = []

        # HH -- the complement of this subset is used for reward modelling.
        sentences, _ = load_data_subset(data_dir="hh", dataset="hh", split="train", subset=subset_type, percent=1.0)
        # Reward sentences are in pairs; the first of each pair is the preferred one.
        preferred_examples = sentences[::2]
        print("HH:", len(preferred_examples))  # 128640 for rl
        self.post_list.extend(preferred_examples)

        # Fixed seed so the train/test partition is reproducible.
        random.seed(0)
        random.shuffle(self.post_list)
        print("Total:", len(self.post_list))

        n_test = len(self.post_list) // 10
        print("Holding out", n_test, "examples for test")
        if split == 'train':
            self.post_list = self.post_list[n_test:]
        elif split == 'test':
            self.post_list = self.post_list[:n_test]

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Return the raw text of example ``idx``."""
        return self.post_list[idx]


if __name__ == "__main__":
    # Manual smoke check: build the SFT and RL training splits of
    # MergedDataset and print their sizes. tokenizer is None because only
    # __init__ runs here (no items are fetched).
    runs = [("MergedDataset - SFT", "sft"), ("MergedDataset - RL", "rl")]
    for position, (title, kind) in enumerate(runs):
        print(title)
        dset = MergedDataset(None, None, "train", train_type=kind)
        print(len(dset))
        # A blank separator line goes between runs, not after the last one.
        if position < len(runs) - 1:
            print("")
