from datasets import Dataset, concatenate_datasets
from tqdm import tqdm
import random
import json
import re
import os
import importlib
import argparse
import pandas as pd
from run_experiment import DATASET2CONFIGS
from datasets import concatenate_datasets
import copy


# Shared sentence that both Alpaca templates contain verbatim.
_RESPONSE_REQUEST = "Write a response that appropriately completes the request.\n\n"

# Alpaca prompt templates (``str.format`` strings). ``prompt_input`` is used
# when a record carries a non-empty ``input`` field, ``prompt_no_input`` otherwise.
PROMPT_DICT = {
    "prompt_input": "".join([
        "Below is an instruction that describes a task, ",
        "paired with an input that provides further context. ",
        _RESPONSE_REQUEST,
        "### Instruction:\n{instruction}\n\n",
        "### Input:\n{input}\n\n",
        "### Response:",
    ]),
    "prompt_no_input": "".join([
        "Below is an instruction that describes a task. ",
        _RESPONSE_REQUEST,
        "### Instruction:\n{instruction}\n\n",
        "### Response:",
    ]),
}


def alpaca_add_template(item):
    """Format one Alpaca record into its prompt and target strings.

    Args:
        item: mapping with ``instruction`` and ``output`` keys and an optional
            ``input`` key; every key is passed to ``str.format``.

    Returns:
        ``(source, targets)``: the rendered prompt and ``item["output"]``.
        The ``prompt_input`` template is used when ``input`` is present and
        non-empty; otherwise ``prompt_no_input`` is used.
    """
    template_key = "prompt_no_input" if item.get("input", "") == "" else "prompt_input"
    source = PROMPT_DICT[template_key].format(**item)
    return source, item["output"]


def load_openai_requests(json_file):
    """Return the message contents stored in a saved chat-style response file.

    The file is expected to hold a JSON object with a ``choices`` list whose
    entries look like ``{"message": {"content": ...}}``. This presumably is an
    OpenAI chat-completion response; confirm against the generating script.

    Args:
        json_file: path to the JSON file.

    Returns:
        A list with ``choice["message"]["content"]`` for every choice, or an
        empty list when the object has no ``choices`` key (e.g. a failed
        request that stored an error payload instead).
    """
    # Use a context manager so the handle is closed even on decode errors,
    # and read as UTF-8, which is the JSON interchange encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        items = json.load(f)
    if "choices" not in items:
        return []
    return [choice["message"]["content"] for choice in items["choices"]]

def load_alpaca_paraphrases(args):
    """Collect paraphrased Alpaca instructions, keyed by integer task index.

    Reads ``<args.alpaca_input_dir>/alpaca-paraphrases/<idx>/``. Entries whose
    name contains no ``"."`` are treated as per-task folders named by an
    integer index. Every file inside a folder is parsed with
    ``load_openai_requests`` and the results are concatenated.

    Directory listings are sorted so that two things are reproducible across
    filesystems: the paraphrase order, which decides what the caller's
    ``[:instructions_per_task]`` slice keeps, and the key order that is later
    fed to ``random.sample``. ``os.listdir`` order is arbitrary.

    Args:
        args: namespace providing ``alpaca_input_dir``.

    Returns:
        dict mapping task index (int) to a list of paraphrase strings.

    Raises:
        ValueError: if a folder name is not an integer.
    """
    instruction_paraphrases = dict()
    root_dir = os.path.join(args.alpaca_input_dir, "alpaca-paraphrases")
    all_paraphrase_folders = sorted(
        (f for f in os.listdir(root_dir) if "." not in f), key=int
    )
    print("Loading Alpaca instruction paraphrases...")
    for folder in tqdm(all_paraphrase_folders):
        folder_path = os.path.join(root_dir, folder)
        paraphrases = []
        for file in sorted(os.listdir(folder_path)):
            paraphrases.extend(load_openai_requests(os.path.join(folder_path, file)))
        instruction_paraphrases[int(folder)] = paraphrases
    return instruction_paraphrases

def load_flan_paraphrases(args, submix):
    """Collect paraphrased Flan inputs for one submix, keyed by integer task index.

    Reads ``<args.flan_input_dir>/<submix>/<idx>/``. Entries whose name contains
    no ``"."`` are treated as per-task folders named by an integer index. Every
    file inside a folder is parsed with ``load_openai_requests`` and the
    results are concatenated.

    Directory listings are sorted so that two things are reproducible across
    filesystems: the paraphrase order, which decides what the caller's
    ``[:instructions_per_task]`` slice keeps, and the key order that is later
    fed to ``random.sample``. ``os.listdir`` order is arbitrary.

    Args:
        args: namespace providing ``flan_input_dir``.
        submix: submix folder name (e.g. ``"flan2021"``, ``"t0"``).

    Returns:
        dict mapping task index (int) to a list of paraphrase strings.

    Raises:
        ValueError: if a folder name is not an integer.
    """
    instruction_paraphrases = dict()
    root_dir = os.path.join(args.flan_input_dir, submix)
    all_paraphrase_folders = sorted(
        (f for f in os.listdir(root_dir) if "." not in f), key=int
    )
    print("Loading Flan instruction paraphrases...")
    for folder in tqdm(all_paraphrase_folders):
        folder_path = os.path.join(root_dir, folder)
        paraphrases = []
        for file in sorted(os.listdir(folder_path)):
            paraphrases.extend(load_openai_requests(os.path.join(folder_path, file)))
        instruction_paraphrases[int(folder)] = paraphrases
    return instruction_paraphrases


def _to_data(input_text, output_text, prototype, task_id, instance_id, do_alignment=True):
    """Build one training row as a dict.

    The row always has ``input``/``output``. When ``do_alignment`` is true it
    also gets ``prototype`` (True for the original instruction, False for a
    paraphrase), ``task_id`` and ``instance_id``. A prototype and its
    paraphrases share the same ``(task_id, instance_id)`` pair.
    """
    data = {
        "input": input_text,
        "output": output_text,
    }
    if do_alignment:
        data["prototype"] = prototype
        data["task_id"] = task_id
        data["instance_id"] = instance_id
    return data


def _process_flan_train(args):
    """Sample Flan tasks per submix, append their paraphrases, and save to disk.

    Raises:
        ValueError: if a submix has fewer eligible tasks than its requested count.
    """
    submixes = ["flan2021", "t0", "niv2", "cot"]
    n_paraphrases = args.flan_instructions_per_task

    training_samples = []
    for submix in submixes:
        with open(os.path.join(args.flan_input_dir, f"{submix}.json"), encoding="utf-8") as f:
            training_samples.append(json.load(f))
    training_partition = [args.flan2021_num, args.t0_num, args.niv2_num, args.cot_num]

    if n_paraphrases > 0:
        paraphrased_samples = [load_flan_paraphrases(args, submix) for submix in submixes]
        # Only tasks with at least `n_paraphrases` paraphrases can be sampled.
        idxs_pools = [
            [task_id for task_id, paraphrases in submix_paraphrases.items()
             if len(paraphrases) >= n_paraphrases]
            for submix_paraphrases in paraphrased_samples
        ]
    else:
        paraphrased_samples = None
        idxs_pools = [range(len(submix_samples)) for submix_samples in training_samples]

    used_task_ids = []
    for idxs_pool, partition, name in zip(idxs_pools, training_partition, submixes):
        if len(idxs_pool) < partition:
            raise ValueError("Not enough samples ({}) in {} for the partition ({})".format(len(idxs_pool), name, partition))
        used_task_ids.append(random.sample(idxs_pool, partition))

    datasets = []
    for submix_idx, (submix_training_samples, submix_task_ids) in enumerate(zip(training_samples, used_task_ids)):
        rows = []
        for task_id in submix_task_ids:
            task_dict = submix_training_samples[task_id]
            input_text, target_text = task_dict["input_text"], task_dict["target_text"]
            # NOTE(review): Flan passes the submix index as `task_id` and the task
            # index as `instance_id`, while Alpaca passes (task index, 0). The
            # grouping pair is still unique per task; confirm downstream expects this.
            rows.append(_to_data(input_text, target_text, True, submix_idx, task_id, args.alignment))
            if n_paraphrases > 0:
                for paraphrase in paraphrased_samples[submix_idx][task_id][:n_paraphrases]:
                    rows.append(_to_data(paraphrase, target_text, False, submix_idx, task_id, args.alignment))
        datasets.append(Dataset.from_list(rows))

    dataset = concatenate_datasets(datasets)
    # The suffix is the Flan paraphrase count (was wrongly the Alpaca count).
    save_dir = os.path.join(args.output_dir, "flan_train_{}_{}".format(sum(training_partition), n_paraphrases))
    if args.alignment:
        save_dir += "_alignment"
    dataset.save_to_disk(save_dir)


def _process_alpaca_train(args):
    """Sample Alpaca records, append instruction paraphrases, and save to disk.

    Raises:
        ValueError: if fewer eligible records exist than ``args.alpaca_samples``.
    """
    n_paraphrases = args.alpaca_instructions_per_task
    with open(os.path.join(args.alpaca_input_dir, "alpaca_data.json"), "r", encoding="utf-8") as f:
        alpaca_data = json.load(f)

    if n_paraphrases > 0:
        alpaca_paraphrases = load_alpaca_paraphrases(args)
        # Keep tasks that can fill the `[:n_paraphrases]` slice below. This is the
        # same `>=` rule as Flan; the old `<=` drop also discarded tasks with exactly N.
        alpaca_paraphrases = {
            task: paraphrases for task, paraphrases in alpaca_paraphrases.items()
            if len(paraphrases) >= n_paraphrases
        }
        idxs_pool = list(alpaca_paraphrases.keys())
    else:
        alpaca_paraphrases = {}
        idxs_pool = range(len(alpaca_data))

    if len(idxs_pool) < args.alpaca_samples:
        raise ValueError("Not enough Alpaca tasks ({}) for alpaca_samples ({})".format(len(idxs_pool), args.alpaca_samples))
    tasks_ids = random.sample(idxs_pool, args.alpaca_samples)

    train_set = []
    for task_id in tasks_ids:
        input_text, output_text = alpaca_add_template(alpaca_data[task_id])
        train_set.append(_to_data(input_text, output_text, True, task_id, 0, args.alignment))
        if n_paraphrases > 0:
            for instruction_paraphrase in alpaca_paraphrases[task_id][:n_paraphrases]:
                # Copy so the shared record is not mutated; only the instruction changes.
                item = alpaca_data[task_id].copy()
                item["instruction"] = instruction_paraphrase
                input_text, output_text = alpaca_add_template(item)
                train_set.append(_to_data(input_text, output_text, False, task_id, 0, args.alignment))

    train_set = Dataset.from_list(train_set)
    # The suffix is the Alpaca paraphrase count (was wrongly the Flan count).
    save_dir = os.path.join(args.output_dir, "alpaca_train_{}_{}".format(len(tasks_ids), n_paraphrases))
    if args.alignment:
        save_dir += "_alignment"
    train_set.save_to_disk(save_dir)


def process_train(args):
    """Build and save the training set selected by ``args.model_type``.

    ``"flan"`` samples tasks from the flan2021/t0/niv2/cot submixes. Any other
    value builds an Alpaca training set. Each sampled example is written once
    with its original instruction. It is then written again for up to
    ``<model>_instructions_per_task`` paraphrases.

    Raises:
        ValueError: if ``args.alignment`` is set but the selected model type has
            no paraphrases requested, or if there are too few eligible tasks.
    """
    if args.model_type == "flan":
        instructions_per_task = args.flan_instructions_per_task
    else:
        instructions_per_task = args.alpaca_instructions_per_task
    # Alignment rows pair a prototype with its paraphrases, so paraphrases are required.
    if args.alignment and instructions_per_task <= 0:
        raise ValueError("--alignment requires a positive instructions_per_task for the selected model type")

    if args.model_type == "flan":
        _process_flan_train(args)
    else:
        _process_alpaca_train(args)



def process_test(args):
    """Build and save observed/unobserved test splits for every evaluation dataset.

    Instructions are read from the ``Instructions.csv`` for ``args.model_type``.
    Each row names a ``Dataset`` and an instruction ``Collection/Type/ID``.

    For every dataset:
      * its config module (path from ``DATASET2CONFIGS``) is loaded;
      * ``load_data_testing`` is called once per instruction;
      * instructions containing ``"Unobserved"`` or ``"Default"`` go to the
        ``unobserved`` split, all others go to ``observed``;
      * each split is concatenated and saved under
        ``<output_dir>/<model_type>_test/<dataset>/{observed,unobserved}``.

    A split with no instructions is skipped with a message, since
    ``concatenate_datasets`` rejects an empty list.
    """
    # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded.
    import importlib.util

    def load_test_instructions(df: pd.DataFrame):
        """Group ``Collection/Type/ID`` instruction keys by the ``Dataset`` column."""
        instructions = {}
        columns = ["Dataset", "Collection", "Type", "ID"]
        for dataset, collection, type_, id_ in df[columns].itertuples(index=False, name=None):
            instructions.setdefault(dataset, []).append(f"{collection}/{type_}/{id_}")
        return instructions

    instruction_csv = (
        "./results_csv/Improvement/Alpaca/Instructions.csv"
        if args.model_type == "alpaca"
        else "./results_csv/Improvement/Flan/Instructions.csv"
    )
    test_instructions = load_test_instructions(pd.read_csv(instruction_csv, index_col=None))

    output_dir = os.path.join(args.output_dir, "{}_test".format(args.model_type))

    for dataset, instructions in test_instructions.items():
        split_dirs = {
            "observed": os.path.join(output_dir, dataset, "observed"),
            "unobserved": os.path.join(output_dir, dataset, "unobserved"),
        }
        for split_dir in split_dirs.values():
            os.makedirs(split_dir, exist_ok=True)

        input_dir, config_dir = DATASET2CONFIGS[dataset]
        spec = importlib.util.spec_from_file_location("config", config_dir)
        config = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config)

        split_sets = {"observed": [], "unobserved": []}
        for instruction in instructions:
            test_set = config.load_data_testing(input_dir, instruction, args)
            if "Unobserved" in instruction or "Default" in instruction:
                split_sets["unobserved"].append(test_set)
            else:
                split_sets["observed"].append(test_set)

        for split, split_datasets in split_sets.items():
            if not split_datasets:
                print("No {} instructions for {}; skipping.".format(split, dataset))
                continue
            concatenate_datasets(split_datasets).save_to_disk(split_dirs[split])
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    
    parser.add_argument('--model_type', choices=["flan", "alpaca"], type=str, default="flan")
    parser.add_argument("--target", type=str, choices=["train", "test"], default="train")

    parser.add_argument("--output_dir", type=str, default="./negation_data")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument('--alignment', default=True, type=bool)
    
    # For Training Flan Models
    parser.add_argument('--flan_input_dir', default="./data/Negation", type=str)
    parser.add_argument('--flan_instructions_per_task', default=3, type=int)

    parser.add_argument('--flan2021_num', default=455, type=int)
    parser.add_argument('--t0_num', default=273, type=int)
    parser.add_argument('--niv2_num', default=256, type=int)
    parser.add_argument('--cot_num', default=0, type=int)

    # For Training Alpaca Models
    parser.add_argument('--alpaca_input_dir', default="./data/Alpaca", type=str)
    parser.add_argument('--alpaca_instructions_per_task', default=3, type=int)
    parser.add_argument('--alpaca_samples', default=1000, type=int)

    # For Testing
    parser.add_argument('--maximum_test_samples', default=500, type=int)

    args = parser.parse_args()
    if args.target == "train":
        process_train(args)
    elif args.target == "test":
        process_test(args)
