import datasets
import pandas
import pandas as pd
import os
import random
import json

from pyarrow.dataset import dataset
from transformers import AutoTokenizer

# Template for one input-output ICL example. With input_prefix="Input: ", separator_first=" -> ",
# output_prefix="Output: " and separator_second="\n" it renders e.g.:
#   Input: 2 -> Output: prime\n
ICL_EXAMPLE_TEMPLATE = """{input_prefix}{input}{separator_first}{output_prefix}{output}{separator_second}"""

# Template for a full prompt: the instruction, then the concatenated ICL examples, then the query
# (an example whose output is left empty), e.g.:
#   Map a number to its parity.\nInput: 3 -> Output: odd\nInput: 4 -> Output:
PROMPT_TEMPLATE = """{instruction}{ICL_examples}{query}"""


class DatasetConfig:
    """
    Configuration class for dataset construction.

    Attributes:
    - n_shot (int): The number of ICL examples to include in each prompt.
    - data_size (int): The number of prompts to generate.
    - is_save (bool): Whether to save the dataset to disk.
    - is_corrupted (bool | str): Falsy for clean examples; otherwise the corruption mode used by the
      dataset builders: "shuffle", "random_select_y" or "random_select_x".
    - instruction (str): The instruction string placed before the examples.
    - input_prefix (str): The prefix for the input string.
    - output_prefix (str): The prefix for the output string.
    - separator_first (str): The separator between input and output.
    - separator_second (str): The separator after the output.
    - corrupted_list (str | None): Path of the corruption-candidate file; only read by the
      "random_select_y" / "random_select_x" modes, so it defaults to None.
    """

    def __init__(self, n_shot, data_size, is_save, is_corrupted, instruction, input_prefix,
                 output_prefix,
                 separator_first,
                 separator_second,
                 corrupted_list=None):
        self.n_shot = n_shot
        self.data_size = data_size
        self.is_save = is_save
        self.is_corrupted = is_corrupted
        self.instruction = instruction
        self.input_prefix = input_prefix
        self.output_prefix = output_prefix
        self.separator_first = separator_first
        self.separator_second = separator_second
        self.corrupted_list = corrupted_list

    def __str__(self):
        # Adjacent f-strings inside parentheses concatenate implicitly; every field is separated
        # by ", " (no stray "/" continuation characters).
        return (f"n_shot:{self.n_shot}, data_size:{self.data_size}, is_save:{self.is_save}, "
                f"is_corrupted:{self.is_corrupted}, instruction: {self.instruction}, "
                f"input_prefix: {self.input_prefix}, output_prefix: {self.output_prefix}, "
                f"separator_first: {self.separator_first}, separator_second: {self.separator_second}, "
                f"corrupted_list: {self.corrupted_list}")


def construct_ICL_example_str(input_prefix: str, input: str, output_prefix: str, output: str, separator_first: str,
                              separator_second: str) -> str:
    """
    Render a single ICL example as one string through ``ICL_EXAMPLE_TEMPLATE``.

    Parameters:
    - input_prefix (str): Text placed before the input.
    - input (str): The example input.
    - output_prefix (str): Text placed before the output.
    - output (str): The example output.
    - separator_first (str): Text between the input and the output prefix.
    - separator_second (str): Text appended after the output.

    Returns:
    - str: The filled-in template.
    """
    fields = {
        "input_prefix": input_prefix,
        "input": input,
        "separator_first": separator_first,
        "output_prefix": output_prefix,
        "output": output,
        "separator_second": separator_second,
    }
    return ICL_EXAMPLE_TEMPLATE.format_map(fields)


def construct_prompt_str(ICL_examples: str, query: str, instruction: str) -> str:
    """
    Render a whole prompt string through ``PROMPT_TEMPLATE``.

    Parameters:
    - ICL_examples (str): The already-rendered ICL examples.
    - query (str): The query, i.e. the question with an empty answer.
    - instruction (str): The task description.

    Returns:
    - str: The filled-in prompt template.
    """
    slots = dict(instruction=instruction, ICL_examples=ICL_examples, query=query)
    return PROMPT_TEMPLATE.format(**slots)


def construct_ICL_example(input_prefix: str, input: str, output_prefix: str, output: str, separator_first: str,
                          separator_second: str) -> list:
    """
    Split a single ICL example into its ordered, non-empty pieces.

    Parameters:
    - input_prefix (str): Text placed before the input.
    - input (str): The example input.
    - output_prefix (str): Text placed before the output.
    - output (str): The example output (pass "" for a query).
    - separator_first (str): Text between the input and the output prefix.
    - separator_second (str): Text appended after the output.

    Returns:
    - list: The pieces in the order input_prefix, input, separator_first, output_prefix, output,
      separator_second, with every falsy (e.g. empty) piece dropped.
    """
    ordered_parts = (input_prefix, input, separator_first, output_prefix, output, separator_second)
    return [part for part in ordered_parts if part]


def construct_prompt(ICL_examples: list, query: list, instruction: str) -> list:
    """
    Flatten an instruction, ICL examples and a query into a single list of prompt pieces.

    Parameters:
    - ICL_examples (list): One piece list per example (as produced by ``construct_ICL_example``);
      may be empty or None.
    - query (list): Pieces of the query, i.e. the question with its answer left empty.
    - instruction (str): Task description; left out when empty.

    Returns:
    - list: ``[instruction]`` (when non-empty), then every example's pieces in order, then the
      query's pieces.
    """
    pieces = [instruction] if instruction else []
    for example in ICL_examples or ():
        pieces.extend(example)
    pieces.extend(query)
    return pieces


def construct_n_shot_dataset(dataset: pandas.DataFrame, dataset_config: DatasetConfig) -> datasets.Dataset:
    """
    Construct a dataset for n-shot learning.

    Each of the ``data_size`` rows pairs a prompt with the query's clean output. The prompt is a
    piece list from ``construct_prompt``: the instruction, ``n_shot`` ICL examples, and a query
    whose output is left empty.

    Example and query indices are drawn from the global ``random`` module. Example corruption
    (``dataset_config.is_corrupted``) uses a separate ``random.Random`` seeded with 1, so enabling
    a corruption mode does not change which examples and queries are selected.

    Parameters:
    - dataset (pandas.DataFrame): The input dataset; needs 'input' and 'output' columns.
    - dataset_config (DatasetConfig): The configuration for the prompt.

    Returns:
    - datasets.Dataset: Columns 'prompt' (list of pieces) and 'answer' (str).

    Raises:
    - ValueError: If ``dataset`` is empty while ``data_size`` is positive.
    """
    n_shot = dataset_config.n_shot
    data_size = dataset_config.data_size
    is_corrupted = dataset_config.is_corrupted

    # Get all possible indices from the dataset
    all_indices = list(range(len(dataset)))
    if data_size > 0 and not all_indices:
        raise ValueError("Cannot build an n-shot dataset from an empty dataset")

    # Corruption draws from its own seeded generator so it never perturbs index sampling.
    corruption_task = None
    corrupted_candidates = None
    if is_corrupted:
        corruption_task = random.Random()
        corruption_task.seed(1)
        # Load the candidate file once, not once per generated prompt.
        if is_corrupted == "random_select_y":
            # NOTE(review): assumes a dataset-style JSON whose 'output' column holds scalar
            # outputs — confirm against the file referenced by corrupted_list.
            corrupted_candidates = pd.read_json(dataset_config.corrupted_list)
        elif is_corrupted == "random_select_x":
            # The "input" entry is iterated as buckets of candidate inputs (keyed by token length).
            with open(dataset_config.corrupted_list, 'r', encoding='utf-8') as f:
                corrupted_candidates = json.load(f)["input"]

    prompt_list = []
    answer_list = []
    query_set = set()

    for _ in range(data_size):
        # Randomly sample n_shot + 1 unique indices: n_shot examples and 1 query.
        try:
            selected_indices = random.sample(all_indices, k=n_shot + 1)
        except ValueError:
            # n_shot + 1 exceeds the dataset size: sample with replacement instead.
            selected_indices = random.choices(all_indices, k=n_shot + 1)

        example_indices = selected_indices[:-1]
        query_index = selected_indices[-1]

        if query_index in query_set:
            # Prefer an index that is neither a previous query nor one of this prompt's examples.
            # Relax those constraints once they cannot be met, so the loop always terminates
            # even when data_size exceeds the number of rows.
            excluded = query_set.union(example_indices)
            candidates = [i for i in all_indices if i not in excluded]
            if not candidates:
                candidates = [i for i in all_indices if i not in query_set] or all_indices
            query_index = random.choice(candidates)

        # Record the query index to avoid duplicates
        query_set.add(query_index)

        input_list = [str(dataset.iloc[idx]['input']) for idx in example_indices]
        output_list = [str(dataset.iloc[idx]['output']) for idx in example_indices]

        if is_corrupted == "shuffle":
            # Shuffle inputs and outputs independently so the pairs no longer match.
            corruption_task.shuffle(input_list)
            corruption_task.shuffle(output_list)
        elif is_corrupted == "random_select_y":
            # Replace every output with the output of a random candidate row (no length matching).
            output_list = [
                str(corrupted_candidates.iloc[corruption_task.randrange(len(corrupted_candidates))]['output'])
                for _ in output_list
            ]
        elif is_corrupted == "random_select_x":
            # Replace each input with a different member of the bucket that contains it.
            for i, current_input in enumerate(input_list):
                for bucket in corrupted_candidates.values():
                    if current_input in bucket:
                        alternatives = [cand for cand in bucket if cand != current_input]
                        if alternatives:  # a singleton bucket has nothing to swap in
                            input_list[i] = corruption_task.choice(alternatives)
                        break

        ICL_examples = [
            construct_ICL_example(input_prefix=dataset_config.input_prefix, input=example_input,
                                  separator_first=dataset_config.separator_first,
                                  output_prefix=dataset_config.output_prefix, output=example_output,
                                  separator_second=dataset_config.separator_second)
            for example_input, example_output in zip(input_list, output_list)
        ]

        # Create the query (its output is omitted; it becomes the answer instead)
        query_input = str(dataset.iloc[query_index]['input'])
        query_output = str(dataset.iloc[query_index]['output'])
        query = construct_ICL_example(input_prefix=dataset_config.input_prefix, input=query_input,
                                      separator_first=dataset_config.separator_first,
                                      output_prefix=dataset_config.output_prefix, output="",
                                      separator_second="")

        prompt_list.append(construct_prompt(instruction=dataset_config.instruction,
                                            ICL_examples=ICL_examples, query=query))
        answer_list.append(query_output)

    # Create a dataset from the prompt and answer pairs
    dataset_dict = {
        'prompt': prompt_list,
        'answer': answer_list
    }

    print(dataset_dict)
    return datasets.Dataset.from_dict(dataset_dict)


def load_raw_dataset(dataset_name: str, dataset_config: DatasetConfig, save_name: str = None) -> datasets.Dataset:
    """
    Read a raw JSON dataset from ../datasets, build n-shot prompts from it, and optionally save it.

    Parameters:
    - dataset_name (str): File name of the raw dataset under ../datasets (relative to this module).
    - dataset_config (DatasetConfig): The configuration for the dataset.
    - save_name (str, optional): Directory name under ../datasets/processed to save to. When
      omitted, the name is "<dataset_name>_<n_shot>_<data_size>", with "_instruct" appended if an
      instruction is set and "_corrupted" appended if corruption is enabled.

    Returns:
    - datasets.Dataset: The constructed prompt/answer dataset. Every row is also printed.
    """
    base_dir = os.path.dirname(__file__)
    raw_dataset = pd.read_json(os.path.join(base_dir, f"../datasets/{dataset_name}"))

    dataset = construct_n_shot_dataset(raw_dataset, dataset_config=dataset_config)

    if dataset_config.is_save:
        if save_name:
            relative_path = f"../datasets/processed/{save_name}"
        else:
            relative_path = (f"../datasets/processed/"
                             f"{dataset_name}_{dataset_config.n_shot}_{dataset_config.data_size}")
            if dataset_config.instruction:
                relative_path += "_instruct"
            if dataset_config.is_corrupted:
                relative_path += "_corrupted"

        target_path = os.path.join(base_dir, relative_path)
        dataset.save_to_disk(target_path)
        print(f"Dataset saved to {target_path}")

    for row_number in range(len(dataset)):
        row = dataset[row_number]
        print(row["prompt"])
        print(row["answer"])
        print()

    return dataset


def load_dataset(dataset_name: str) -> datasets.Dataset:
    """
    Load a processed dataset previously saved under ../datasets/processed.

    Every row's prompt and answer is printed, followed by a blank line.

    Parameters:
    - dataset_name (str): Directory name of the saved dataset.

    Returns:
    - datasets.Dataset: The result of ``datasets.load_from_disk``.
    """
    processed_dir = os.path.join(os.path.dirname(__file__), f"../datasets/processed/{dataset_name}")
    dataset = datasets.load_from_disk(processed_dir)

    for row_number in range(len(dataset)):
        row = dataset[row_number]
        print(row["prompt"])
        print(row["answer"])
        print()

    return dataset


def load_combined_dataset(dataset_name: str, type: str) -> datasets.Dataset:
    """
    Load a combined two-answer dataset from ../datasets/<type>/<dataset_name>.

    Every row's prompt, answer and answer2 is printed, followed by a blank line.

    NOTE(review): generate_ambiguous_dataset saves under ../datasets/processed, not
    ../datasets/<type> — confirm where combined datasets are expected to live.

    Parameters:
    - dataset_name (str): Directory name of the saved dataset.
    - type (str): Sub-directory of ../datasets: "ambiguous" or "conflicting".

    Returns:
    - datasets.Dataset: The result of ``datasets.load_from_disk``.
    """
    location = os.path.join(os.path.dirname(__file__), f"../datasets/{type}/{dataset_name}")
    dataset = datasets.load_from_disk(location)

    for row_number in range(len(dataset)):
        row = dataset[row_number]
        for column in ("prompt", "answer", "answer2"):
            print(row[column])
        print()

    return dataset


def generate_conflicting_dataset(dataset1_name, dataset2_name, dataset_config, data1_num, data2_num,
                                 save_name, fix_position=None):
    """
    Placeholder for generating a conflicting combined dataset.

    Its parameters mirror ``generate_ambiguous_dataset`` (minus ``ambiguity_mode``).

    Raises:
    - NotImplementedError: Always. Failing loudly keeps callers from continuing with ``None``
      where a dataset was expected.
    """
    raise NotImplementedError("generate_conflicting_dataset is not implemented yet")


def generate_ambiguous_dataset(dataset1_name, dataset2_name, dataset_config, ambiguity_mode, data1_num, data2_num,
                               save_name, fix_position=None):
    """
    Construct ambiguous dataset from normal dataset and ambiguous dataset.

    Each prompt mixes ``data1_num`` ICL examples from dataset1 with ``data2_num`` examples from
    dataset2. It ends with a query from dataset1 that is not one of that prompt's dataset1
    examples. The result is printed row by row and saved to ../datasets/processed/<save_name>.

    :param dataset1_name: normal dataset without ambiguous examples (JSON file under ../datasets)
    :param dataset2_name: ambiguous dataset with examples causing ambiguity (JSON file under ../datasets)
    :param dataset_config: DatasetConfig object containing dataset configuration
    :param ambiguity_mode: how answer2 is derived:
        - "copy": the query input
        - "constant": the output of dataset2's first row
        - anything else: the query row's 'output_fr' column in dataset1
    :param data1_num: number of examples from dataset1 in one prompt
    :param data2_num: number of examples from dataset2 in one prompt
    :param save_name: name to save the generated dataset under ../datasets/processed
    :param fix_position [optional]: sequence of exactly data1_num distinct prompt positions.
        The dataset1 examples (shuffled) go to these positions and dataset2 examples fill the
        remaining slots. If falsy, all examples are shuffled together.
    :return: combined ambiguous dataset ICL prompt and answer pairs (answer from dataset1, answer2 per ambiguity_mode)
    :raises ValueError: if fix_position does not name exactly data1_num distinct, in-range positions
    """
    if fix_position:
        total_examples = data1_num + data2_num
        # Negative positions are allowed (list indexing); normalize them to detect collisions.
        distinct_slots = {pos % total_examples for pos in fix_position
                          if -total_examples <= pos < total_examples}
        if len(fix_position) != data1_num or len(distinct_slots) != data1_num:
            raise ValueError(f"fix_position must contain exactly {data1_num} distinct positions in a "
                             f"prompt of {data1_num + data2_num} examples, got {fix_position!r}")

    datasize = dataset_config.data_size
    BASE_DIR = os.path.dirname(__file__)
    DATA_DIR_1 = os.path.join(BASE_DIR, f"../datasets/{dataset1_name}")
    DATA_DIR_2 = os.path.join(BASE_DIR, f"../datasets/{dataset2_name}")

    dataset1 = pd.read_json(DATA_DIR_1)
    dataset2 = pd.read_json(DATA_DIR_2)

    prompt_list = []
    answer1_list = []
    answer2_list = []

    # Get all possible indices from the datasets
    all_indices_dataset1 = list(range(len(dataset1)))
    all_indices_dataset2 = list(range(len(dataset2)))
    query_set = set()

    # Example/query selection uses its own seeded generator.
    select_task = random.Random()
    select_task.seed(42)

    is_corrupted = dataset_config.is_corrupted
    corruption_task = None
    candidate_buckets = None
    token_length_of = None
    if is_corrupted:
        corruption_task = random.Random()
        corruption_task.seed(1)
        if is_corrupted in ("random_select_y", "random_select_x"):
            # Read the candidate file once (and close it) instead of once per prompt.
            with open(dataset_config.corrupted_list, 'r', encoding='utf-8') as f:
                corrupted_candidates = json.load(f)
            if is_corrupted == "random_select_y":
                candidate_buckets = corrupted_candidates["output"]
                # Only this mode measures token lengths, so only it needs the tokenizer.
                tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
                token_length_of = lambda text: len(tokenizer.tokenize(text))
            else:
                candidate_buckets = corrupted_candidates["input"]

    def _swap_within_bucket(values, length_fn=None):
        """
        Replace each string in ``values`` (in place) with a random candidate from its bucket.

        Buckets (``candidate_buckets``, keyed by token length) are scanned in order. The first
        bucket that contains the value wins, and the value is swapped for a different member of
        it. A singleton bucket leaves the value unchanged. When ``length_fn`` is given, a bucket
        whose key equals the value's token length also wins; any member of it may be chosen.
        """
        for i, current in enumerate(values):
            current_length = length_fn(current) if length_fn is not None else None
            for token_length, bucket in candidate_buckets.items():
                if current in bucket:
                    alternatives = [cand for cand in bucket if cand != current]
                    if alternatives:
                        values[i] = corruption_task.choice(alternatives)
                    break
                if current_length is not None and current_length == int(token_length):
                    values[i] = corruption_task.choice(bucket)
                    break

    def _build_examples(inputs, outputs):
        """Render paired inputs/outputs as ICL example piece lists."""
        return [construct_ICL_example(input_prefix=dataset_config.input_prefix, input=example_input,
                                      separator_first=dataset_config.separator_first,
                                      output_prefix=dataset_config.output_prefix, output=example_output,
                                      separator_second=dataset_config.separator_second)
                for example_input, example_output in zip(inputs, outputs)]

    for _ in range(datasize):
        # data1_num indices from dataset1 and data2_num from dataset2 for the ICL examples.
        dataset1_example_indices = select_task.sample(all_indices_dataset1, k=data1_num)
        dataset2_example_indices = select_task.sample(all_indices_dataset2, k=data2_num)

        # The query must not be one of this prompt's dataset1 examples (its answer would appear
        # verbatim in the prompt). Also avoid reusing an earlier query while possible.
        example_set = set(dataset1_example_indices)
        available_indices_dataset1 = [i for i in all_indices_dataset1
                                      if i not in query_set and i not in example_set]
        if not available_indices_dataset1:
            available_indices_dataset1 = ([i for i in all_indices_dataset1 if i not in example_set]
                                          or all_indices_dataset1)
        query_index = select_task.choice(available_indices_dataset1)

        # Record the query index to avoid duplicates
        query_set.add(query_index)

        input_list1 = [str(dataset1.iloc[idx]['input']) for idx in dataset1_example_indices]
        output_list1 = [str(dataset1.iloc[idx]['output']) for idx in dataset1_example_indices]
        input_list2 = [str(dataset2.iloc[idx]['input']) for idx in dataset2_example_indices]
        output_list2 = [str(dataset2.iloc[idx]['output']) for idx in dataset2_example_indices]

        if is_corrupted == "shuffle":
            # Shuffle inputs and outputs independently so the pairs no longer match.
            corruption_task.shuffle(input_list1)
            corruption_task.shuffle(output_list1)
            corruption_task.shuffle(input_list2)
            corruption_task.shuffle(output_list2)
        elif is_corrupted == "random_select_y":
            _swap_within_bucket(output_list1, token_length_of)
            _swap_within_bucket(output_list2, token_length_of)
        elif is_corrupted == "random_select_x":
            _swap_within_bucket(input_list1)
            _swap_within_bucket(input_list2)

        ICL_examples_dataset1 = _build_examples(input_list1, output_list1)
        ICL_examples_dataset2 = _build_examples(input_list2, output_list2)

        if fix_position:
            ICL_examples = [None] * (len(ICL_examples_dataset1) + len(ICL_examples_dataset2))
            # Place dataset1 examples (shuffled) at the fixed positions
            random.shuffle(ICL_examples_dataset1)
            for pos, example in zip(fix_position, ICL_examples_dataset1):
                ICL_examples[pos] = example

            # Fill the remaining slots with the shuffled dataset2 examples
            available_positions = [i for i, x in enumerate(ICL_examples) if x is None]
            random.shuffle(ICL_examples_dataset2)
            for pos, example in zip(available_positions, ICL_examples_dataset2):
                ICL_examples[pos] = example
        else:
            ICL_examples = ICL_examples_dataset1 + ICL_examples_dataset2
            random.shuffle(ICL_examples)

        # Create the query (its output is omitted; it becomes the answer instead)
        query_input = str(dataset1.iloc[query_index]['input'])
        query_output = str(dataset1.iloc[query_index]['output'])
        query = construct_ICL_example(input_prefix=dataset_config.input_prefix, input=query_input,
                                      separator_first=dataset_config.separator_first,
                                      output_prefix=dataset_config.output_prefix, output="",
                                      separator_second="")

        prompt_list.append(construct_prompt(instruction=dataset_config.instruction,
                                            ICL_examples=ICL_examples, query=query))
        answer1_list.append(query_output)

        if ambiguity_mode == "copy":
            # The alternative answer is the query input itself
            answer2_list.append(query_input)
        elif ambiguity_mode == "constant":
            # The alternative answer is dataset2's (first) output
            answer2_list.append(str(dataset2.iloc[0]['output']))
        else:
            answer2_list.append(str(dataset1.iloc[query_index]['output_fr']))

    # Create a dataset from the prompt and answer pairs
    dataset_dict = {
        'prompt': prompt_list,
        'answer': answer1_list,
        'answer2': answer2_list
    }

    dataset = datasets.Dataset.from_dict(dataset_dict)

    for i in range(len(dataset)):
        print(f"{dataset[i]['prompt']}")
        print(f"{dataset[i]['answer']}")
        print(f"{dataset[i]['answer2']}")
        print()

    save_path = f"../datasets/processed/{save_name}"

    dataset.save_to_disk(os.path.join(BASE_DIR, save_path))
    print(f"Dataset saved to {os.path.join(BASE_DIR, save_path)}")

    return dataset


if __name__ == "__main__":
    # Build a processed n-shot dataset from a raw JSON file, save it, then reload it.
    script_dir = os.path.dirname(__file__)
    dataset_name = "english-japanese-chinese-ambiguous"
    dataset_path = f"{dataset_name}.json"

    dataset_config = DatasetConfig(
        n_shot=5,
        data_size=100,
        is_save=True,
        # is_corrupted="random_select_y",  # random_select_y, random_select_x, shuffle
        is_corrupted=False,
        instruction="",
        input_prefix="",
        output_prefix="",
        separator_first="\t",
        separator_second="\n",
        # Only read when is_corrupted is "random_select_y" or "random_select_x". Anchored to this
        # module's directory like the other ../datasets paths, so the script works from any CWD.
        corrupted_list=os.path.join(script_dir, f"../datasets/{dataset_name}-token-length.json"),
    )

    random.seed(42)

    load_raw_dataset(dataset_path, dataset_config, save_name="japanese-ambiguous.json_5_100")

    # Reload the processed dataset that was just saved
    processed_name = "japanese-ambiguous.json_5_100"
    dataset = load_dataset(processed_name)
