import random
import numpy as np
import torch
import json
import os
from datasets import load_dataset

def set_seed(seed):
    """
    Seed every random number generator this project relies on, for reproducibility.

    The same seed is applied, in order, to Python's `random`, NumPy's global
    generator, torch's CPU generator, and all torch CUDA generators.

    Args:
        seed (`int`): The seed to set.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def save_dict(dict_input, filename):
    """
    Merge `dict_input` into the JSON object stored at `filename` and write it back.

    If `filename` does not exist, `dict_input` is written as-is. Otherwise the
    existing JSON content is loaded and combined with `dict_input`. On key
    collisions the values from `dict_input` win. The file is always rewritten in full.

    Args:
        dict_input (`dict`): Entries to add or overwrite.
        filename (`str`): Path of the JSON file to update.
    """
    if not os.path.exists(filename):
        merged = dict_input
    else:
        with open(filename, 'r') as src:
            previous = json.load(src)
        merged = {**previous, **dict_input}

    with open(filename, 'w') as dst:
        json.dump(merged, dst)


def find_all_indices(text, substring):
    """
    Return the start offset of every occurrence of `substring` in `text`.

    Overlapping occurrences are included, because the search resumes one
    character after each match. The offsets are returned in increasing order.

    Args:
        text (`str`): String to search.
        substring (`str`): Pattern to look for.

    Returns:
        `list[int]`: Start offsets (empty if there is no match).
    """
    positions = []
    position = text.find(substring)
    while position != -1:
        positions.append(position)
        position = text.find(substring, position + 1)
    return positions

def load_hhrlhf_template(template_path):
    """
    Read the HH-RLHF prompt template files from `template_path`.

    Files are read in this order and stored under these keys:
    f1.txt -> 'fschat', f2.txt -> 'fsred', r1.txt -> 'redA', r2.txt -> 'redB'.

    Args:
        template_path (`str`): Directory containing the template files.

    Returns:
        `dict[str, str]`: Raw file contents keyed as above.
    """
    key_to_file = (
        ('fschat', 'f1.txt'),
        ('fsred', 'f2.txt'),
        ('redA', 'r1.txt'),
        ('redB', 'r2.txt'),
    )
    template = {}
    for key, file_name in key_to_file:
        with open(os.path.join(template_path, file_name)) as handle:
            template[key] = handle.read()
    return template

def load_truthfulqa_template(template_path):
    """
    Read the TruthfulQA few-shot prompt (f1.txt) from `template_path`.

    Args:
        template_path (`str`): Directory containing f1.txt.

    Returns:
        `dict[str, str]`: `{'fschat': <contents of f1.txt>}`.
    """
    prompt_file = os.path.join(template_path, 'f1.txt')
    with open(prompt_file) as handle:
        return {'fschat': handle.read()}


def load_truthfulqa_mcq_template(template_path):
    """
    Read the TruthfulQA multiple-choice prompt (f2.txt) from `template_path`.

    The contents are stored under the 'fschat' key, the same key used by
    `load_truthfulqa_template`, so callers can use either template interchangeably.

    Args:
        template_path (`str`): Directory containing f2.txt.

    Returns:
        `dict[str, str]`: `{'fschat': <contents of f2.txt>}`.
    """
    prompt_file = os.path.join(template_path, 'f2.txt')
    with open(prompt_file) as handle:
        return {'fschat': handle.read()}


def build_hhrlhf_dataset(
        tokenizer, fschat,
        start=0, end=38961, split="train", select=None, max_query_tokens=512,
):
    """
    Build a tokenized prompt dataset from the HH-RLHF "red-team-attempts" data.

    The split is shuffled with a fixed seed (42) and rows [start, end) are kept.
    Each transcript is cut right after its first "Assistant:" marker, so the
    prompt ends on an open assistant turn. Double newlines are then collapsed
    to single newlines.

    Args:
        tokenizer: Called as `tokenizer(text, truncation=True)`; its result is
            indexed with "input_ids".
        fschat (`str`): Few-shot prefix, joined to each query with a newline to
            form "queryf".
        start (`int`): First row (inclusive) of the shuffled split to keep.
        end (`int`): Last row (exclusive) of the shuffled split to keep.
        split (`str`): Dataset split to load.
        select: Unused; kept for signature compatibility.
        max_query_tokens (`int`): Rows whose tokenized query has this many
            tokens or more are dropped (default 512).

    Returns:
        The processed dataset with columns "query", "input_ids" and "queryf",
        formatted for torch.

    Raises:
        ValueError: If a transcript contains no "Assistant: " marker.
    """
    marker = 'Assistant: '

    ds = load_dataset("Anthropic/hh-rlhf", data_dir="red-team-attempts", split=split)
    ds = ds.shuffle(seed=42)
    ds1 = ds.select(range(start, end))
    original_columns1 = ds1.column_names

    def preprocess_function(examples):
        new_examples = {
            "query": [],
            "input_ids": [],
            "queryf": [],
        }
        for transcript in examples['transcript']:
            # Only the first assistant turn matters, so a single find suffices.
            ind = transcript.find(marker)
            if ind == -1:
                raise ValueError(
                    f"transcript has no {marker!r} marker: {transcript[:80]!r}"
                )
            # Keep "Assistant:" but drop its trailing space.
            text = transcript[:ind + len(marker) - 1]
            text = text.replace("\n\n", "\n")
            tokenized_question = tokenizer(text, truncation=True)
            new_examples["query"].append(text)
            new_examples["input_ids"].append(tokenized_question["input_ids"])
            new_examples["queryf"].append(fschat + '\n' + text)

        return new_examples

    ds1 = ds1.map(
        preprocess_function,
        batched=True,
        remove_columns=original_columns1,
    )

    ds1 = ds1.filter(lambda x: len(x["input_ids"]) < max_query_tokens, batched=False)

    ds1.set_format(type="torch")

    return ds1

def build_truthfulqa_dataset(
        tokenizer, fschat,
        start=0, end=817, split="validation", select=None,
):
    """
    Build a tokenized prompt dataset from TruthfulQA ("generation" config).

    Each question has double newlines collapsed to single newlines. It is then
    wrapped as "\\nHuman: <question>\\nAssistant:" and tokenized.

    Args:
        tokenizer: Called as `tokenizer(text, truncation=True)`; its result is
            indexed with "input_ids".
        fschat (`str`): Few-shot prefix, joined to each query with a newline to
            form "queryf".
        start (`int`): First row (inclusive) to keep.
        end (`int`): Last row (exclusive) to keep.
        split (`str`): Dataset split to load (default "validation"). This was
            previously ignored and hard-coded.
        select: Unused; kept for signature compatibility.

    Returns:
        The processed dataset with columns "query", "input_ids" and "queryf",
        formatted for torch.
    """
    ds = load_dataset("truthful_qa", "generation", split=split)
    ds1 = ds.select(range(start, end))
    original_columns1 = ds1.column_names

    def preprocess_function(examples):
        new_examples = {
            "query": [],
            "input_ids": [],
            "queryf": [],
        }
        for question in examples['question']:
            text = question.replace("\n\n", "\n")
            text = f"\nHuman: {text}\nAssistant:"
            tokenized_question = tokenizer(text, truncation=True)
            new_examples["query"].append(text)
            new_examples["input_ids"].append(tokenized_question["input_ids"])
            new_examples["queryf"].append(fschat + '\n' + text)

        return new_examples

    ds1 = ds1.map(
        preprocess_function,
        batched=True,
        remove_columns=original_columns1,
    )

    ds1.set_format(type="torch")

    return ds1

def get_truthfulqa_correct_reference(start, end):
    """
    Collect the acceptable answers for TruthfulQA validation rows [start, end).

    Each entry is a list whose first element is the row's 'best_answer',
    followed by the row's 'correct_answers'.

    Args:
        start (`int`): First row (inclusive).
        end (`int`): Last row (exclusive).

    Returns:
        `list[list]`: One list of reference answers per selected row.
    """
    subset = load_dataset("truthful_qa", "generation", split='validation').select(range(start, end))
    return [[row['best_answer']] + row['correct_answers'] for row in subset]

def get_truthfulqa_incorrect_reference(start, end):
    """
    Collect the incorrect answers for TruthfulQA validation rows [start, end).

    Args:
        start (`int`): First row (inclusive).
        end (`int`): Last row (exclusive).

    Returns:
        `list`: The 'incorrect_answers' field of each selected row, in order.
    """
    subset = load_dataset("truthful_qa", "generation", split='validation').select(range(start, end))
    return [row['incorrect_answers'] for row in subset]