import os
import json
import torch
import random
from torch.utils.data import TensorDataset

import datasets

import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, T5ForConditionalGeneration
from custom_models import ContextPoolerCustomToken, EncoderWithLogitHead, FrozenWrapper
from accelerate.utils import DummyOptim
from sam import SAM

# Model identifiers that load_model loads through T5ForConditionalGeneration and
# wraps in EncoderWithLogitHead; load_process_data also uses this list to pick
# the T5-specific prompt prefixes.
T5_MODEL_NAMES = ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl", "google/ul2", "google/flan-ul2"]
# Model identifiers for which load_model can attach ContextPoolerCustomToken
# (custom tokens) and FrozenWrapper (frozen base).
DEBERTA_MODEL_NAMES = ["microsoft/deberta-v3-xsmall", "microsoft/deberta-v3-small", "microsoft/deberta-v3-base", "microsoft/deberta-v3-large"]
# NOTE(review): not referenced by the code visible in this file; llama/gpt
# handling below keys off substring checks on the model name instead.
LLAMA_MODEL_NAMES = ["llama-7b"]
# Fallback end-of-sequence token assigned by get_tokenizer when a llama/gpt
# tokenizer does not define one.
DEFAULT_EOS_TOKEN = "</s>"

def get_tokenizer(model):
    """Return the pretrained tokenizer for ``model`` configured for this project.

    Every tokenizer truncates from the left. When the model name contains
    "llama" or "gpt", the EOS token is additionally reused as the padding token
    (falling back to ``DEFAULT_EOS_TOKEN`` if the tokenizer has no EOS token)
    and padding is applied on the left.

    Args:
        model: Model name or path accepted by ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer instance.
    """
    tok = AutoTokenizer.from_pretrained(model)

    pads_with_eos = any(family in model for family in ("llama", "gpt"))
    if pads_with_eos:
        if not tok.eos_token:
            tok.eos_token = DEFAULT_EOS_TOKEN
        tok.pad_token = tok.eos_token
        tok.padding_side = "left"

    tok.truncation_side = "left"
    return tok

# Number of output labels requested from the model config; load_model passes
# it as `num_labels` to AutoConfig.from_pretrained.
NUM_LABELS = 1

def load_model(args, load_path=None, cache_dir=None, accelerator=None, sam=False):
    """Build the model described by ``args`` together with its optimizer.

    Args:
        args: Namespace providing ``model``, ``dropout``, ``custom_tokens``,
            ``freeze_base``, ``ngpus``, ``weight_decay`` and ``learning_rate``.
        load_path: Optional path to a ``state_dict`` checkpoint to restore.
        cache_dir: Optional cache directory forwarded to
            ``AutoConfig.from_pretrained``.
            NOTE(review): it is not forwarded to the weight-loading
            ``from_pretrained`` calls - confirm whether that is intended.
        accelerator: Optional ``accelerate`` Accelerator. When given, the model
            is returned unwrapped and not moved to a device here (presumably
            the caller prepares it - confirm), and ``DummyOptim`` is used if
            the DeepSpeed config already declares an optimizer.
        sam: If true, wrap the base optimizer class in ``SAM``.

    Returns:
        ``(model, optimizer)``.

    Raises:
        ValueError: If custom tokens or base freezing are requested for a model
            family that does not support them.
    """
    # The same dropout overrides are applied whether or not a cache directory is
    # supplied, so passing cache_dir cannot silently change the dropout setup.
    config_kwargs = {
        "num_labels": NUM_LABELS,
        "dropout_rate": args.dropout,
        "attention_dropout_prob": args.dropout,
        "hidden_dropout_prob": args.dropout,
    }
    if cache_dir is not None:
        config_kwargs["cache_dir"] = cache_dir
    config = AutoConfig.from_pretrained(args.model, **config_kwargs)

    if args.model in T5_MODEL_NAMES:
        base_model = T5ForConditionalGeneration.from_pretrained(args.model, config=config)
        model = EncoderWithLogitHead(base_model, get_tokenizer(args.model), args.model, args.custom_tokens)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(args.model, config=config)

    # get_tokenizer pads these model families with the EOS token; mirror that in the config.
    if "llama" in args.model or "gpt" in args.model:
        model.config.pad_token_id = model.config.eos_token_id

    if args.custom_tokens:
        if args.model in DEBERTA_MODEL_NAMES:
            model.pooler = ContextPoolerCustomToken(config, args.custom_tokens, dropout=args.dropout)
        elif args.model not in T5_MODEL_NAMES:
            # T5 models already received the custom tokens via EncoderWithLogitHead above.
            raise ValueError("Custom token position (Context Pooler) not supported for this model")

    if args.freeze_base:
        if args.model not in DEBERTA_MODEL_NAMES:
            raise ValueError("Freezing base model not supported for this model")
        model = FrozenWrapper(config, model, args.custom_tokens, dropout=args.dropout)
        for param in model.model.deberta.parameters():
            param.requires_grad = False

    if load_path is not None:
        # The model is still on CPU here; map_location keeps a GPU-saved
        # checkpoint loadable on hosts without that device.
        model.load_state_dict(torch.load(load_path, map_location="cpu"))

    if accelerator is None and args.ngpus > 0:
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpus)))

    print('\nPretrained model "{}" loaded'.format(args.model))

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            undecayed.append(param)
        else:
            decayed.append(param)
    optimizer_grouped_parameters = [
        {'params': decayed, 'weight_decay': args.weight_decay},
        {'params': undecayed, 'weight_decay': 0.0},
    ]

    # DummyOptim is only a placeholder for the case where DeepSpeed's own
    # config already defines the optimizer.
    deepspeed_owns_optimizer = (
        accelerator is not None
        and accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )
    optimizer_cls = DummyOptim if deepspeed_owns_optimizer else torch.optim.AdamW

    if sam:
        optimizer = SAM(optimizer_grouped_parameters, optimizer_cls, lr=args.learning_rate)
    else:
        optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)

    return model, optimizer

def split_data(split, data, nsplits=5):
    """Split ``data`` into a training subset and one held-out contiguous fold.

    The fold size is ``len(data) // nsplits``; fold ``split`` covers positions
    ``[fold_size * split, fold_size * (split + 1))``. Every other position,
    including any remainder at the end, goes to the training subset.

    Args:
        split: Zero-based index of the fold to hold out.
        data: Indexable dataset with a length.
        nsplits: Number of folds.

    Returns:
        ``(train_data, test_data)`` as ``torch.utils.data.Subset`` objects.
    """
    total = len(data)
    fold_size = total // nsplits
    lo, hi = fold_size * split, fold_size * (split + 1)

    held_out = np.zeros(total, dtype=bool)
    held_out[lo:hi] = True

    positions = np.arange(total)
    train_data = torch.utils.data.Subset(data, positions[~held_out])
    test_data = torch.utils.data.Subset(data, positions[held_out])
    return train_data, test_data

def load_cm_sentences(data_dir, split="train"):
    """Load sentences and labels from ``cm_<split>.csv`` in ``data_dir``.

    A split containing "long" (e.g. ``"long_test"``) reads the file named by
    the part after ``"long_"`` and keeps only rows whose ``is_short`` column is
    False. For the ``"ambig"`` split the first column holds the sentences and
    every label is -1; otherwise the first column is the label and the second
    the sentence.

    Returns:
        ``(sentences, labels)`` as two lists of equal length.
    """
    long_only = "long" in split
    file_split = split.split("long_")[1] if long_only else split
    frame = pd.read_csv(os.path.join(data_dir, "cm_{}.csv".format(file_split)))
    if long_only:
        frame = frame[frame["is_short"] == False]

    first_column = list(frame.iloc[:, 0].to_numpy())
    if split == "ambig":
        return first_column, [-1] * len(first_column)

    sentences = list(frame.iloc[:, 1].to_numpy())
    return sentences, first_column

def load_paired_cm_sentences(data_dir, split="train"):
    """Load ``cm_<split>.csv`` and turn each row into a (sentence, "") pair.

    The first CSV column is the label and the second the sentence. A row with
    label 1 contributes ``"", sentence``; a row with label 0 contributes
    ``sentence, ""``. Every emitted entry gets the label vector
    ``[1,0,0,0,0,0]`` (presumably an indicator of the source task - confirm
    against the training code).

    Args:
        data_dir: Directory containing the CSV file.
        split: Split name; splits containing "long" and the "ambig" split are
            not supported in paired form.

    Returns:
        ``(paired_sentences, paired_labels)``; both have two entries per row.

    Raises:
        NotImplementedError: For "long" and "ambig" splits.
        ValueError: If a row label is neither 0 nor 1.
    """
    # Reject unsupported splits before touching the disk, so a missing file or
    # column cannot mask the real reason for the failure.
    if "long" in split or split == "ambig":
        raise NotImplementedError(
            "Paired commonsense data is not implemented for split {!r}".format(split)
        )

    df = pd.read_csv(os.path.join(data_dir, "cm_{}.csv".format(split)))
    labels = list(df.iloc[:, 0].to_numpy())
    sentences = list(df.iloc[:, 1].to_numpy())

    paired_sentences = []
    for sentence, label in zip(sentences, labels):
        if label == 1:
            paired_sentences.extend(("", sentence))
        elif label == 0:
            paired_sentences.extend((sentence, ""))
        else:
            raise ValueError("Unknown label {}".format(label))
    paired_labels = [[1,0,0,0,0,0] for _ in paired_sentences]

    return paired_sentences, paired_labels

def load_paired_moral_stories_sentences(data_dir, split="train"):
    """Load paired sentences from ``moral_stories_dataset.hf`` under ``data_dir``.

    For each sample, the ``prompt`` is joined (with one space) to its ``label``
    text and to its ``moral_label`` text. Four entries are emitted per sample,
    in this order: ``""``, prompt + label, prompt + moral_label, ``""``.
    Every entry gets the label vector ``[1,0,0,0,0,0]``.

    Returns:
        ``(paired_sentences, paired_labels)``.
    """
    dataset_path = os.path.join(data_dir, 'moral_stories_dataset.hf')
    records = datasets.load_from_disk(dataset_path)[split]

    paired_sentences = []
    for record in records:
        prompt = record["prompt"]
        # Named as in the upstream columns: "label" holds the immoral action text.
        immoral, moral = record["label"], record["moral_label"]
        paired_sentences.extend(("", prompt + " " + immoral))
        paired_sentences.extend((prompt + " " + moral, ""))

    paired_labels = [[1,0,0,0,0,0] for _ in paired_sentences]
    return paired_sentences, paired_labels

def load_justice_sentences(data_dir, split="train"):
    """Load ``justice_<split>.csv``: first column labels, second column sentences.

    Returns:
        ``(sentences, labels)`` as two lists of equal length.
    """
    csv_path = os.path.join(data_dir, "justice_{}.csv".format(split))
    frame = pd.read_csv(csv_path)
    label_column = frame.iloc[:, 0].to_numpy()
    sentence_column = frame.iloc[:, 1].to_numpy()
    return list(sentence_column), list(label_column)

def load_paired_justice_sentences(data_dir, split="train"):
    """Load ``justice_<split>.csv`` and pair every sentence with an empty string.

    A row with label 1 contributes ``"", sentence``; a row with label 0
    contributes ``sentence, ""``. Every emitted entry gets the label vector
    ``[0,1,0,0,0,0]``.

    Returns:
        ``(paired_sentences, paired_labels)``; both have two entries per row.

    Raises:
        ValueError: If a row label is neither 0 nor 1.
    """
    frame = pd.read_csv(os.path.join(data_dir, "justice_{}.csv".format(split)))
    row_labels = frame.iloc[:, 0].to_numpy()
    row_sentences = frame.iloc[:, 1].to_numpy()

    paired_sentences = []
    for sentence, label in zip(row_sentences, row_labels):
        if label == 1:
            pair = ("", sentence)
        elif label == 0:
            pair = (sentence, "")
        else:
            raise ValueError("Unknown label {}".format(label))
        paired_sentences.extend(pair)

    paired_labels = [[0,1,0,0,0,0] for _ in paired_sentences]
    return paired_sentences, paired_labels

def load_virtue_sentences(data_dir, split="train"):
    """Load ``virtue_<split>.csv``: first column labels, second column sentences.

    Sentences are returned exactly as stored in the file (no splitting on
    ``" [SEP] "`` happens here).

    Returns:
        ``(sentences, labels)`` as two lists of equal length.
    """
    frame = pd.read_csv(os.path.join(data_dir, "virtue_{}.csv".format(split)))
    columns = [list(frame.iloc[:, position].to_numpy()) for position in (0, 1)]
    labels, sentences = columns
    return sentences, labels

def load_paired_virtue_sentences(data_dir, split="train"):
    """Load ``virtue_<split>.csv`` as pairs of ``[scenario, trait]`` lists.

    Each sentence is split on ``" [SEP] "``. A row is paired with a variant
    whose first element is blanked (``["", trait]``): label 1 emits the blanked
    variant first, label 0 emits the original first. Every emitted entry gets
    the label vector ``[0,0,1,0,0,0]``.

    Returns:
        ``(paired_sentences, paired_labels)``; both have two entries per row.

    Raises:
        ValueError: If a row label is neither 0 nor 1.
    """
    frame = pd.read_csv(os.path.join(data_dir, "virtue_{}.csv".format(split)))
    row_labels = frame.iloc[:, 0].to_numpy()
    row_parts = [text.split(' [SEP] ') for text in frame.iloc[:, 1].to_numpy()]

    paired_sentences = []
    for parts, label in zip(row_parts, row_labels):
        if label == 1:
            paired_sentences += [["", parts[1]], parts]
        elif label == 0:
            paired_sentences += [parts, ["", parts[1]]]
        else:
            raise ValueError("Unknown label {}".format(label))

    paired_labels = [[0,0,1,0,0,0] for _ in paired_sentences]
    return paired_sentences, paired_labels


def load_virtue_dialogues(data_dir, split="train"):
    """Load ``virtue_dialogues.csv``: first column labels, second column sentences.

    ``split`` is accepted for signature compatibility with the other loaders
    but is not used; the same file is read for every split.

    Returns:
        ``(sentences, labels)`` as two lists of equal length.
    """
    csv_path = os.path.join(data_dir, "virtue_dialogues.csv")
    frame = pd.read_csv(csv_path)
    labels = list(frame.iloc[:, 0].to_numpy())
    sentences = list(frame.iloc[:, 1].to_numpy())
    return sentences, labels


def load_paired_virtue_dialogues(data_dir, split="train"):
    """Load ``virtue_dialogues.csv`` as pairs of ``[dialogue, trait]`` lists.

    Each sentence is split on ``" [SEP] "`` and paired with a variant whose
    first element is blanked (``["", trait]``): label 1 emits the blanked
    variant first, label 0 emits the original first. Every emitted entry gets
    the label vector ``[0,0,1,0,0,0]``. ``split`` is not used.

    Returns:
        ``(paired_sentences, paired_labels)``; both have two entries per row.

    Raises:
        ValueError: If a row label is neither 0 nor 1.
    """
    frame = pd.read_csv(os.path.join(data_dir, "virtue_dialogues.csv"))
    labels = frame.iloc[:, 0].to_numpy()
    texts = frame.iloc[:, 1].to_numpy()

    paired_sentences = []
    for text, label in zip(texts, labels):
        parts = text.split(' [SEP] ')
        if label == 1:
            paired_sentences.append(["", parts[1]])
            paired_sentences.append(parts)
        elif label == 0:
            paired_sentences.append(parts)
            paired_sentences.append(["", parts[1]])
        else:
            raise ValueError("Unknown label {}".format(label))

    return paired_sentences, [[0,0,1,0,0,0] for _ in paired_sentences]


def load_deontology_sentences(data_dir, split="train"):
    """Load ``deontology_<split>.csv`` and join scenario and excuse with " [SEP] ".

    The columns are read positionally: label, scenario, excuse.

    Returns:
        ``(sentences, labels)`` as two lists of equal length.
    """
    frame = pd.read_csv(os.path.join(data_dir, "deontology_{}.csv".format(split)))
    labels = list(frame.iloc[:, 0].to_numpy())
    scenario_column = frame.iloc[:, 1].to_numpy()
    excuse_column = frame.iloc[:, 2].to_numpy()

    sentences = []
    for scenario, excuse in zip(scenario_column, excuse_column):
        sentences.append(scenario + " [SEP] " + excuse)
    return sentences, labels

def load_paired_deontology_sentences(data_dir, split="train"):
    """Load ``deontology_<split>.csv`` and pair each example with an empty string.

    Scenario and excuse (second and third columns) are joined with one space.
    A row with label 1 contributes ``"", sentence``; a row with label 0
    contributes ``sentence, ""``. Every emitted entry gets the label vector
    ``[0,0,0,1,0,0]``.

    Returns:
        ``(paired_sentences, paired_labels)``; both have two entries per row.

    Raises:
        ValueError: If a row label is neither 0 nor 1.
    """
    frame = pd.read_csv(os.path.join(data_dir, "deontology_{}.csv".format(split)))
    row_labels = frame.iloc[:, 0].to_numpy()
    row_scenarios = frame.iloc[:, 1].to_numpy()
    row_excuses = frame.iloc[:, 2].to_numpy()
    joined = [scenario + " " + excuse for scenario, excuse in zip(row_scenarios, row_excuses)]

    paired_sentences = []
    for sentence, label in zip(joined, row_labels):
        if label == 1:
            paired_sentences += ["", sentence]
        elif label == 0:
            paired_sentences += [sentence, ""]
        else:
            raise ValueError("Unknown label {}".format(label))

    paired_labels = [[0,0,0,1,0,0] for _ in paired_sentences]
    return paired_sentences, paired_labels

def load_util_sentences(data_dir, split="train"):
    """Load the header-less ``util_<split>.csv`` as interleaved sentence pairs.

    Each row contributes its first column followed by its second column, so
    entries ``2*i`` and ``2*i + 1`` of the result form one pair. Every entry
    gets the label vector ``[0,0,0,0,1,0]``.

    Returns:
        ``(sentences, labels)``, each with two entries per CSV row.
    """
    frame = pd.read_csv(os.path.join(data_dir, "util_{}.csv".format(split)), header=None)
    first_column = frame.iloc[:, 0].to_numpy()
    second_column = frame.iloc[:, 1].to_numpy()

    sentences = [text for pair in zip(first_column, second_column) for text in pair]
    labels = [[0,0,0,0,1,0] for _ in sentences]
    return sentences, labels

def load_pairwise_csv(data_dir, split="train"):
    """Load the header-less ``<split>.csv`` as interleaved sentence pairs.

    Each row contributes its first column followed by its second column.
    Every entry gets the label vector ``[0,0,0,0,1,0]``.

    Returns:
        ``(sentences, labels)``, each with two entries per CSV row.
    """
    csv_path = os.path.join(data_dir, "{}.csv".format(split))
    frame = pd.read_csv(csv_path, header=None)

    sentences = []
    for left, right in zip(frame.iloc[:, 0].to_numpy(), frame.iloc[:, 1].to_numpy()):
        sentences.extend((left, right))

    labels = [[0,0,0,0,1,0] for _ in sentences]
    return sentences, labels

# Names of the power categories.
# NOTE(review): no live code visible in this file reads this list - confirm
# there are no external users before removing it.
types_of_power = ['coercive', 'reward', 'legitimate', 'referent', 'expert', 'informational', 'economic', 'political', 'military', 'personal']
def load_power_sentences(data_dir, split="train", seed=None):
    """Load pairwise power sentences and return one side of a random 90/10 partition.

    Three sources under ``data_dir`` are read:

    * ``power_examples.json`` - per power type, a list of items whose ``res``
      maps ``"1"``..``"5"`` to scenarios; consecutive scenarios ``(i, i+1)``
      form a pair. Items with fewer than five entries are skipped.
    * ``power_examples_2.json`` - same layout, but ``res`` holds a ``context``
      and a ``scenarios`` mapping; the context is prepended to each scenario.
    * ``all_data.csv`` - first column is the index; then context, option A,
      option B and a preference string. ``'A > B'`` emits A first, anything
      else emits B first.

    Each example (all of its pairs together) is sent to the test partition with
    probability 0.1, otherwise to train. Every sentence gets the label vector
    ``[0,0,0,0,0,1]``.

    Args:
        data_dir: Directory containing the three files.
        split: ``"train"`` or ``"test"``.
        seed: Optional seed for the partition. With the default ``None`` the
            global NumPy RNG is used, so the partition changes between calls
            unless the caller seeds NumPy. Pass the same seed for the "train"
            and "test" calls to obtain complementary, non-overlapping sets.

    Returns:
        ``(sentences, labels)`` for the requested partition.

    Raises:
        ValueError: If ``split`` is neither "train" nor "test".
    """
    if split not in ("train", "test"):
        raise ValueError("Unknown split {!r}; expected 'train' or 'test'".format(split))

    # np.random.random_sample is the function np.random.random aliases, so the
    # unseeded path consumes the global RNG exactly as a plain
    # np.random.random() call would.
    rng = np.random if seed is None else np.random.RandomState(seed)

    train_sentences, train_labels, test_sentences, test_labels = [], [], [], []

    def pick_partition():
        """Draw once and return the (sentences, labels) lists for this example."""
        if rng.random_sample() < 0.1:
            return test_sentences, test_labels
        return train_sentences, train_labels

    def add_pair(partition, first, second):
        """Append one ordered sentence pair with fresh label vectors."""
        sentences, labels = partition
        sentences.extend((first, second))
        labels.extend(([0,0,0,0,0,1], [0,0,0,0,0,1]))

    with open(os.path.join(data_dir, 'power_examples.json'), encoding="utf-8") as f:
        data = json.load(f)
    for examples in data.values():
        for d in examples:
            res = d['res']
            if len(res) < 5:
                continue
            partition = pick_partition()
            for i in range(1, 5):
                add_pair(partition, res[str(i)], res[str(i + 1)])

    with open(os.path.join(data_dir, 'power_examples_2.json'), encoding="utf-8") as f:
        data = json.load(f)
    for examples in data.values():
        for d in examples:
            res = d['res']
            res_scenarios = res['scenarios']
            if len(res_scenarios) < 5:
                continue
            context = res['context']
            partition = pick_partition()
            for i in range(1, 5):
                add_pair(
                    partition,
                    context + " " + res_scenarios[str(i)],
                    context + " " + res_scenarios[str(i + 1)],
                )

    df = pd.read_csv(os.path.join(data_dir, "all_data.csv"), header=0, index_col=0)
    for i in range(df.shape[0]):
        context = df.iloc[i, 0]
        partition = pick_partition()
        option_a = context + " " + df.iloc[i, 1]
        option_b = context + " " + df.iloc[i, 2]
        if df.iloc[i, 3] == 'A > B':
            add_pair(partition, option_a, option_b)
        else:
            add_pair(partition, option_b, option_a)

    if split == "train":
        return train_sentences, train_labels
    return test_sentences, test_labels

def load_hh_sentences(data_dir, split="train"):
    """Load the ``Anthropic/hh-rlhf`` dataset as interleaved chosen/rejected pairs.

    Each record contributes its ``chosen`` text followed by its ``rejected``
    text. Every entry gets an all-zero 6-element label vector. ``data_dir`` is
    not used; the dataset is fetched through ``datasets.load_dataset``.

    Returns:
        ``(sentences, labels)``, each with two entries per record.
    """
    hf_dataset = datasets.load_dataset("Anthropic/hh-rlhf")
    sentences = [
        record[field]
        for record in hf_dataset[split]
        for field in ('chosen', 'rejected')
    ]
    labels = [[0,0,0,0,0,0] for _ in sentences]
    return sentences, labels

def load_harmless_sentences(data_dir, split="train"):
    """Load the ``harmless-base`` portion of ``Anthropic/hh-rlhf`` as pairs.

    Each record contributes its ``chosen`` text followed by its ``rejected``
    text. Every entry gets an all-zero 6-element label vector. The function's
    own ``data_dir`` argument is not used.

    Returns:
        ``(sentences, labels)``, each with two entries per record.
    """
    records = datasets.load_dataset("Anthropic/hh-rlhf", data_dir="harmless-base")[split]

    sentences = []
    for record in records:
        sentences.extend((record['chosen'], record['rejected']))

    return sentences, [[0,0,0,0,0,0] for _ in sentences]

def load_shp_sentences(data_dir, split="train"):
    """Load filtered ``stanfordnlp/SHP`` comment pairs in Human/Assistant format.

    Only samples with ``upvote_ratio > 0.9`` and both scores above 50 are kept.
    Each kept sample yields two prompts built from its ``history`` and the two
    reference comments; the one at index ``1 - int(labels)`` (A is index 0,
    B is index 1) is emitted first, the other second. Every entry gets an
    all-zero 6-element label vector. ``data_dir`` is not used.

    Returns:
        ``(sentences, labels)``, each with two entries per kept sample.
    """
    hf_dataset = datasets.load_dataset("stanfordnlp/SHP")

    sentences = []
    for sample in hf_dataset[split]:
        keep = sample['upvote_ratio'] > 0.9 and sample['score_A'] > 50 and sample['score_B'] > 50
        if not keep:
            continue

        candidates = (
            "\n\nHuman: " + sample['history'] + "\n\nAssistant: " + sample['human_ref_A'],
            "\n\nHuman: " + sample['history'] + "\n\nAssistant: " + sample['human_ref_B'],
        )
        first_index = 1 - int(sample['labels'])
        # `not first_index` is a bool, which indexes the tuple as 0 or 1.
        sentences.extend((candidates[first_index], candidates[not first_index]))

    labels = [[0,0,0,0,0,0] for _ in sentences]
    return sentences, labels

def load_cm_dialog(data_dir, split="train"):
    """Load moral/immoral dialogue pairs rendered as Human/Assistant transcripts.

    For every entry whose ``dialogue`` contains ``moral_dialogue``, two prompts
    are emitted in order: the moral dialogue, then the immoral dialogue. Each
    turn is a single-key mapping; ``user`` turns are prefixed with
    ``"\\n\\nHuman: "``, ``chatbot`` turns with ``"\\n\\nAssistant: "``, and any
    other key adds its text without a prefix. Every prompt gets an all-zero
    6-element label vector.

    NOTE(review): the input path is hard-coded; ``data_dir`` and ``split`` are
    accepted for signature compatibility but ignored.

    Returns:
        ``(sentences, labels)``, each with two entries per kept dialogue entry.
    """
    # The context manager closes the file handle once the JSON is parsed.
    with open('/data/[ANONYMIZED]/proxy_gaming/data/generate_data/cm_dialogue_merged.json', encoding="utf-8") as f:
        morality_data = json.load(f)

    speaker_prefix = {'user': "\n\nHuman: ", 'chatbot': "\n\nAssistant: "}

    sentences = []
    for d in morality_data:
        data = d["dialogue"]
        if "moral_dialogue" not in data:
            continue
        for turns in (data["moral_dialogue"], data["immoral_dialogue"]):
            pieces = []
            for turn in turns:
                # Only the first key/value of each turn mapping is used.
                speaker, text = next(iter(turn.items()))
                pieces.append(speaker_prefix.get(speaker, "") + text)
            sentences.append("".join(pieces))

    labels = [[0,0,0,0,0,0] for _ in sentences]
    return sentences, labels

def load_virtue_dialog(data_dir, split="train"):
    """Load good/bad dialogue pairs rendered as Human/Assistant transcripts.

    For every entry whose ``dialogues`` contains ``good_dialogue``, two prompts
    are emitted in order: the good dialogue, then the bad dialogue. Each turn
    is a single-key mapping; ``user`` turns are prefixed with
    ``"\\n\\nHuman: "``, ``chatbot`` turns with ``"\\n\\nAssistant: "``, and any
    other key adds its text without a prefix. Every prompt gets an all-zero
    6-element label vector.

    NOTE(review): the input path is hard-coded; ``data_dir`` and ``split`` are
    accepted for signature compatibility but ignored.

    Returns:
        ``(sentences, labels)``, each with two entries per kept dialogue entry.
    """
    # The context manager closes the file handle once the JSON is parsed.
    with open('/data/[ANONYMIZED]/proxy_gaming/data/generate_data/virtue_dialogue_merged.json', encoding="utf-8") as f:
        virtue_data = json.load(f)

    speaker_prefix = {'user': "\n\nHuman: ", 'chatbot': "\n\nAssistant: "}

    sentences = []
    for d in virtue_data:
        data = d["dialogues"]
        if "good_dialogue" not in data:
            continue
        for turns in (data["good_dialogue"], data["bad_dialogue"]):
            pieces = []
            for turn in turns:
                # Only the first key/value of each turn mapping is used.
                speaker, text = next(iter(turn.items()))
                pieces.append(speaker_prefix.get(speaker, "") + text)
            sentences.append("".join(pieces))

    labels = [[0,0,0,0,0,0] for _ in sentences]
    return sentences, labels

def load_oasst_dialog(data_dir, split="train"):
    """Load OASST ranked response pairs as "Human:/Assistant:" transcripts.

    Reads ``rank_pairs_{train,val,test}.jsonl`` from the fixed directory
    ``/data/datasets/oasst`` (the ``data_dir`` argument is overwritten). Each
    JSON line provides a ``context`` plus a ``good_response`` and a
    ``bad_response``; the transcript ending in the good response is emitted
    first, followed by the one ending in the bad response. Messages are
    rendered as ``"<Name>: <text>"`` and joined with blank lines. Every entry
    gets an all-zero 6-element label vector.

    Returns:
        ``(sentences, labels)``, each with two entries per JSON line.
    """
    data_dir = "/data/datasets/oasst"
    assert split in ["train", "validation", "test"]
    file_tag = {"train": "train", "validation": "val", "test": "test"}[split]
    datapath = f"{data_dir}/rank_pairs_{file_tag}.jsonl"

    role_to_name = {
        "prompter": "Human",
        "assistant": "Assistant",
    }

    def render(messages):
        return "\n\n".join([f"{role_to_name[msg['role']]}: {msg['text']}" for msg in messages])

    sentences = []
    with open(datapath, "r") as f:
        for raw_line in f:
            record = json.loads(raw_line)
            sentences.append(render(record['context'] + [record['good_response']]))
            sentences.append(render(record['context'] + [record['bad_response']]))

    labels = [[0,0,0,0,0,0] for _ in sentences]
    return sentences, labels

def load_oasst_safety_dialog(data_dir, split="train"):
    """Load OASST unsafe-score response pairs as "role: text" transcripts.

    Reads ``unsafe_score_pairs_{train,val,test}.jsonl`` from the fixed directory
    ``/data/datasets/oasst`` (the ``data_dir`` argument is overwritten). Each
    JSON line provides a ``context`` plus a ``good_response`` and a
    ``bad_response``; the transcript ending in the good response is emitted
    first. Messages keep their raw ``role`` value and are joined with
    ``"\\n---\\n"``. Every entry gets an all-zero 6-element label vector.

    Returns:
        ``(sentences, labels)``, each with two entries per JSON line.
    """
    data_dir = "/data/datasets/oasst"
    assert split in ["train", "validation", "test"]
    file_tag = {"train": "train", "validation": "val", "test": "test"}[split]
    datapath = f"{data_dir}/unsafe_score_pairs_{file_tag}.jsonl"

    def render(messages):
        return "\n---\n".join([f"{msg['role']}: {msg['text']}" for msg in messages])

    sentences = []
    with open(datapath, "r") as f:
        for raw_line in f:
            record = json.loads(raw_line)
            sentences.append(render(record['context'] + [record['good_response']]))
            sentences.append(render(record['context'] + [record['bad_response']]))

    labels = [[0,0,0,0,0,0] for _ in sentences]
    return sentences, labels

def load_data_subset(data_dir, dataset, split="train", subset="train80", percent=1.0):
    """
    This loads a deterministic subset of the specific dataset(s) used for the paper.
    The datasets are pairwise, and pairs are kept together.
    `subset` is either "train80" or "train20", referring to the 80% or 20% split of the data.
    `percent` is the percentage of the subset to use, e.g. 0.5 for 50% of the subset.

    Pair indices are shuffled with a private ``random.Random(0)`` generator.
    This produces the same permutation as seeding the module-level generator
    with 0, but leaves the process-wide ``random`` state untouched.

    Returns ``(sentences, labels)`` with two consecutive entries per selected pair.

    Raises:
        NotImplementedError: unknown `dataset`, or `percent < 1.0` with "train20".
        ValueError: unknown `subset`.
    """
    assert split == "train", "We only use the train splits for this dataset"

    # Validate all arguments before loading anything: the loaders can be slow,
    # and a bad argument should not cost a full dataset load.
    if dataset not in ("hh", "shp", "oasst"):
        raise NotImplementedError(f"Subset not implemented for dataset {dataset}")
    if subset not in ("train80", "train20"):
        raise ValueError(f"Unknown subset {subset}")
    if subset == "train20" and percent < 1.0:
        raise NotImplementedError("Cannot use percent < 1.0 with subset=train20")

    if dataset == "hh":
        sentences, labels = load_hh_sentences(data_dir, split=split)
    elif dataset == "shp":
        sentences, labels = load_shp_sentences(data_dir, split=split)
    else:
        sentences, labels = load_oasst_dialog(data_dir, split=split)

    # We want to shuffle the data, but keep the pairs together
    # So we operate on the pair indices
    assert len(sentences) % 2 == 0, "Should be even number of sentences"
    idxs = list(range(len(sentences) // 2))

    # Shuffle and get the relevant data subset
    random.Random(0).shuffle(idxs)
    sep_idx = int(len(idxs)*0.8)
    if subset == "train80":
        idxs = idxs[:sep_idx]
        if percent < 1.0:
            idxs = idxs[:int(len(idxs)*percent)]
    else:
        idxs = idxs[sep_idx:]

    # Get the sentences and labels in pairs
    sentences_subset, labels_subset = [], []
    for i in idxs:
        sentences_subset.extend(sentences[2*i:2*i+2])
        labels_subset.extend(labels[2*i:2*i+2])

    print(f"Loaded {len(sentences_subset)}/{len(sentences)} sentences from {dataset} subset {subset} with percent={percent}")
    return sentences_subset, labels_subset


def load_process_data(args, data_dir, dataset, split="train", model_name: str = None):
    """Load a dataset, optionally add prompt prefixes, tokenize and wrap it as a TensorDataset.

    Args:
        args: Namespace providing ``add_prefix``, ``custom_prefix``, ``model``
            and ``max_length``.
        data_dir: Root data directory. The dataset name, with a leading
            ``paired_`` and a trailing ``_dialogues`` removed, is appended to it.
        dataset: Dataset key. Names containing ``train80``/``train20`` use the
            form ``{dataset}_{subset}_pc{percent}`` and go through
            ``load_data_subset``; ``gm_labeled`` uses ``load_pairwise_csv``;
            everything else is dispatched through the loader table below.
        split: Split name forwarded to the loader.
        model_name: Compared against ``T5_MODEL_NAMES`` to choose the prefix
            style; note that the tokenizer is selected from ``args.model``.

    Returns:
        ``TensorDataset(inputs, masks, labels)``. For pairwise datasets every
        tensor gets an extra axis of size 2 at position 1, holding the even-
        and odd-indexed examples of each pair.

    Raises:
        KeyError: If ``dataset`` is not in the loader table.
        NotImplementedError: If ``args.add_prefix`` is set and no prefix is
            defined for the dataset.
    """
    data_dir = os.path.join(data_dir, dataset.removeprefix("paired_").removesuffix("_dialogues"))
    if "train80" in dataset or "train20" in dataset:
        # Name looks like {dataset}_{subset}_pc{percent}
        dataset_name = dataset.split("_train")[0]
        subset = dataset.split("_")[-2]
        percent = float(dataset.split("_")[-1].split("pc")[-1]) / 100
        print(f"Loading subset {subset} of {dataset_name} with percent={percent}")
        sentences, labels = load_data_subset(dataset=dataset_name, data_dir=data_dir, split=split, subset=subset, percent=percent)
    elif dataset == "gm_labeled":
        sentences, labels = load_pairwise_csv(data_dir, split=split)
    else:
        load_fn = {"cm": load_cm_sentences, "paired_cm": load_paired_cm_sentences, "paired_moral_stories": load_paired_moral_stories_sentences,
                "deontology": load_deontology_sentences, "paired_deontology": load_paired_deontology_sentences,
                "justice": load_justice_sentences, "paired_justice": load_paired_justice_sentences,
                "virtue": load_virtue_sentences, "paired_virtue": load_paired_virtue_sentences, 
                "virtue_dialogues": load_virtue_dialogues, "paired_virtue_dialogues": load_paired_virtue_dialogues,
                "util": load_util_sentences, "power": load_power_sentences,
                "hh": load_hh_sentences, "harmless": load_harmless_sentences,
                "shp": load_shp_sentences, "generated_util": load_util_sentences,
                "oasst": load_oasst_dialog, "oa_safety": load_oasst_safety_dialog,
                "cm_dialog": load_cm_dialog, "virtue_dialog": load_virtue_dialog,
                }[dataset]
        sentences, labels = load_fn(data_dir, split=split)
    labels = torch.tensor(labels)

    def split_trait(item):
        """Return ``[text, trait]`` for a virtue example.

        The paired virtue loaders already return such lists, while the unpaired
        ones return a single "text [SEP] trait" string that must be split
        before it can be unpacked.
        """
        return item.split(' [SEP] ') if isinstance(item, str) else item

    if args.add_prefix:
        if model_name in T5_MODEL_NAMES:
            if dataset == "cm" or "moral_stories" in dataset:
                sentences = [f'Rate the acceptability of the action in the following scenario.\n{s}' for s in sentences]
            elif dataset == "util":
                sentences = [f'Rate the happiness or pleasantness of the following scenario.\n{s}' for s in sentences]
            elif dataset == "power":
                sentences = [f'Rate the amount of power the first-person character has in the following scenario.\n{s}' for s in sentences]
            else:
                raise NotImplementedError
        else:
            # Branch order matters: the checks are substring tests, so the more
            # specific name ("virtue_dialogues") must precede "virtue".
            if "cm" in dataset or "moral_stories" in dataset:
                sentences = [f'Consider the acceptability of the action in the following scenario: "{s}"' for s in sentences]
            elif "justice" in dataset:
                sentences = [f'Consider the justice level of the following scenario: "{s}"' for s in sentences]
            elif "virtue_dialogues" in dataset:
                # NOTE(review): "Cnosider" is a typo, kept byte-for-byte because
                # changing the prompt changes the model inputs that existing
                # checkpoints may have been trained on. Fix deliberately.
                sentences = [f'Cnosider whether the chatbot is {trait} in the following dialogue: "{s}"' for s, trait in map(split_trait, sentences)]
            elif "virtue" in dataset:
                sentences = [f'Consider whether the character is {trait} in the following scenario: "{s}"' for s, trait in map(split_trait, sentences)]
            elif "deontology" in dataset:
                sentences = [f'Consider the reasonableness of the following excuse: "{s}"' for s in sentences]
            elif "util" in dataset:
                sentences = [f'Consider the instantaneous pleasantness of the following scenario: "{s}"' for s in sentences]
            elif "power" in dataset:
                sentences = [f'Consider the amount of power the first-person character has in the following scenario: "{s}"' for s in sentences]
            elif "hh" in dataset:
                sentences = [f'Consider the helpfulness and harmlessness of the following scenario: "{s}"' for s in sentences]
            elif "harmless" in dataset:
                sentences = [f'Consider the harmlessness of the following scenario: "{s}"' for s in sentences]
            elif "shp" in dataset:
                sentences = [f'Consider the helpfulness of the comment to the original post: "{s}"' for s in sentences]
            elif "oasst" in dataset:
                sentences = [f'Consider the quality of the last assistant message in the conversation thread:\n\n"{s}"' for s in sentences]
            else:
                raise NotImplementedError

    if args.custom_prefix:
        sentences = [args.custom_prefix + " " + s for s in sentences]

    tokenizer = get_tokenizer(args.model)
    tokenized_inputs = tokenizer(sentences, padding='max_length', truncation=True, max_length=args.max_length, return_tensors="pt")
    inputs, masks = tokenized_inputs["input_ids"], tokenized_inputs["attention_mask"]

    # An example fits the context window if padding is present at either end
    # (padding can be on the left or the right depending on the tokenizer).
    pad_id = tokenizer.pad_token_id
    within_bounds = (inputs[:, -1] == pad_id) | (inputs[:, 0] == pad_id)
    if not bool(within_bounds.all()):
        fraction_within = within_bounds.double().mean().item()
        print("{} {} fraction of examples within context window ({} tokens): {:.3f}".format(dataset, split, args.max_length, fraction_within))

    pairwise_tags = ("paired", "util", "power", "hh", "harmless", "shp", "dialog", "oasst", "oa_safety", "gm_labeled")
    if any(tag in dataset for tag in pairwise_tags):
        assert len(inputs) % 2 == 0, "Should be even number of sentences for pairwise data"
        # Loaders interleave pairs, so even rows are first items and odd rows second items.
        inputs = torch.stack([inputs[0::2], inputs[1::2]], dim=1)
        labels = torch.stack([labels[0::2], labels[1::2]], dim=1)
        masks = torch.stack([masks[0::2], masks[1::2]], dim=1)

    data = TensorDataset(inputs, masks, labels)
    return data
