import polars as pl
import torch
import const_prompts as cp
from datasets import Dataset, DatasetDict

from utilities import generate_text_prompt, refine_directory_path, get_files_from_directory, preprocess_dataset, flatten_images
import const_configs as ccf
from tqdm import tqdm
import argparse

# Prompt configuration per dataset format. Each entry holds a system-style
# "prompt", the task "instruction" prepended to every example, an optional
# "template", and (for the norobots-style formats) the category names.
DATASET_PROMPT = {
    "alpaca": dict(
        prompt="Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.",
        instruction="Identify the validity of the following analogy.",
        # Positional slots: {0} prompt, {1} instruction, {2} input, {3} response.
        template="{0}\n\n### Instruction:\n{1}\n\n### Input:\n{2}\n\n### Response:\n{3}",
    ),
    "norobots": dict(
        prompt="",
        instruction="Classify the following analogy as valid or invalid. Provide no further explanation to your answer.",
        template="",
        categories=["Summarize", "Generation", "Rewrite", "Open QA", "Closed QA", "Chat", "Brainstorm", "Coding", "Classify", "Extract"],
    ),
    "gen_norobots": dict(
        prompt="",
        instruction="Complete the analogy below by generating the last element.",
        template="",
        categories=["Summarize", "Generation", "Rewrite", "Open QA", "Closed QA", "Chat", "Brainstorm", "Coding", "Classify", "Extract"],
    ),
    "gen_1hot_fixsize": dict(
        prompt="",
        # {0} is filled with the expected output length by the caller.
        instruction="Complete the analogy below by generating the last element. The output must not exceed {0} tokens.",
        template="",
        categories="",
    ),
}

def create_generative_sft_dataset(text_analogies: list, labels: list, sep="\n", *args, **kwargs):
    """Build a prompt/completion dataset asking the model to produce an analogy's last element.

    Each text analogy is split on ``sep``. Every piece except the last becomes the
    context that follows the instruction; the last piece becomes the completion.
    Examples whose label equals 0 are dropped.

    Args:
        text_analogies: textual analogy descriptions.
        labels: labels aligned with ``text_analogies``.
        sep: separator between analogy elements inside each description.

    Returns:
        A ``datasets.Dataset`` with ``prompt``, ``prompt_id`` (always ``""``)
        and ``completion`` columns.
    """
    instruction = DATASET_PROMPT["gen_norobots"]["instruction"]
    records = []
    # Materialise the pairs so tqdm can report a total.
    for analogy, label in tqdm(list(zip(text_analogies, labels))):
        if label == 0:
            continue
        # str.split always yields at least one piece, so this unpacking is safe.
        *context, target = analogy.split(sep)
        records.append({
            "prompt": f"{instruction} {sep.join(context)}",
            "prompt_id": "",
            "completion": target,
        })
    return Dataset.from_list(records)


def create_generative_1hot_dataset(_analogies: list, labels: list, *args, **kwargs):
    """Build a prompt/completion dataset from encoded (one-hot style) analogies.

    The last quarter of each analogy sequence (``len // 4`` items, measured on
    the first analogy) is the completion. The rest, joined with ``separator``,
    follows the instruction in the prompt. Examples whose label equals 0 are
    dropped.

    Args:
        _analogies: indexable sequence of analogy sequences. The length of the
            first one sets the split point for all of them.
        labels: labels aligned with ``_analogies``.
        separator (kwarg, optional): string used to join items; defaults to ``""``.

    Returns:
        A ``datasets.Dataset`` with ``prompt``, ``prompt_id`` and ``completion``
        columns. It is empty when ``_analogies`` is empty.

    Raises:
        ValueError: if the analogies are too short (fewer than 4 items) to split
            off a non-empty final quarter.
    """
    configs     = DATASET_PROMPT["gen_1hot_fixsize"]
    instruction = configs["instruction"]
    separator   = kwargs.get("separator", "")
    if len(_analogies) == 0:
        # Same result as when every label is 0: nothing to emit.
        return Dataset.from_list([])
    offset      = len(_analogies[0]) // 4
    if offset < 1:
        # With offset == 0, analogy[:-0] is empty and analogy[-0:] is the whole
        # sequence, which would silently produce nonsense records.
        raise ValueError(
            f"analogies must contain at least 4 items to split off the final quarter; "
            f"got {len(_analogies[0])}"
        )
    item_labels = list(zip(_analogies, labels))
    all_records = []
    for analogy, label in tqdm(item_labels):
        if label == 0:
            continue
        _input  = separator.join(str(item) for item in analogy[:-offset])
        output  = separator.join(str(item) for item in analogy[-offset:])
        # NOTE(review): len(output) counts characters of the joined string
        # (separators included), while the instruction says "tokens" — confirm intent.
        prompt  = f"{instruction.format(len(output))} {_input}.\n"
        all_records.append({
            "prompt"    : prompt,
            "prompt_id" : "",
            "completion": output,
        })
    return Dataset.from_list(all_records)


def create_sft_dataset(analogies: list, labels: list, *args, **kwargs):
    """Build a supervised fine-tuning dataset for analogy validity classification.

    Each input is the instruction followed by the analogy text. The target
    output is ``"Valid"`` when the label equals 1 and ``"Invalid"`` otherwise.

    Args:
        analogies: textual analogy descriptions.
        labels: labels aligned with ``analogies``.
        instruction (kwarg, optional): instruction prefix; defaults to
            ``cp.INSTRUCTIONS[3]``.

    Returns:
        A ``datasets.Dataset`` with ``input``, ``output`` and ``label`` columns.
    """
    # Explicit membership test instead of a bare ``except:``. It also keeps the
    # fallback lazy: cp.INSTRUCTIONS is only indexed when no instruction is given.
    instruction = kwargs["instruction"] if "instruction" in kwargs else cp.INSTRUCTIONS[3]
    valid_responses = ["Valid" if item == 1 else "Invalid" for item in labels]
    irs         = list(zip(analogies, valid_responses, labels))
    all_records = []
    for _input, valid, label in tqdm(irs):
        all_records.append({
            "input" : f"{instruction} {_input}",
            "output": valid,
            "label" : label,
        })
    return Dataset.from_list(all_records)

def create_1hot_vector_dataset(analogies: list, labels: list, *args, **kwargs):
    """Build a classification dataset from encoded analogy vectors.

    Each analogy's items are stringified and joined with ``separator`` to form
    the ``input`` column. The label is copied unchanged.

    Args:
        analogies: sequence of analogy item sequences.
        labels: labels aligned with ``analogies``.
        separator (kwarg, optional): join string; defaults to ``""``.

    Returns:
        A ``datasets.Dataset`` with ``input`` and ``label`` columns.
    """
    # dict.get replaces the former bare ``except:`` around the lookup.
    separator = kwargs.get("separator", "")
    all_records = [
        {"input": separator.join(str(item) for item in x), "label": y}
        for x, y in zip(analogies, labels)
    ]
    return Dataset.from_list(all_records)

def create_generative_1hot_simple_dataset(analogies: list, labels: list, *args, **kwargs):
    """Build an input/label dataset whose label identifies an analogy's final element.

    The last quarter of each analogy (``len // 4`` items, measured on the first
    analogy) is the target; the rest, joined with ``separator``, is the input.
    The label is chosen as follows:
      * If ``df_dict`` is passed, the label is the ``index`` value of the row in
        ``df_dict`` whose ``encoded_image`` equals the target slice.
      * Otherwise, the label is the target items joined with ``separator``.
    Examples whose label equals 0 are dropped.

    Args:
        analogies: indexable sequence of analogy sequences.
        labels: labels aligned with ``analogies``.
        separator (kwarg, optional): join string; defaults to ``""``.
        df_dict (kwarg, optional): polars DataFrame with ``encoded_image`` and
            ``index`` columns.

    Returns:
        A ``datasets.Dataset`` with ``input`` and ``label`` columns. It is empty
        when ``analogies`` is empty.

    Raises:
        ValueError: if the analogies have fewer than 4 items (empty target slice).
    """
    separator = kwargs.get("separator", "")
    # Presence of the key (even with a None value) selects the lookup path,
    # exactly as the former try/except did.
    dict_flag = "df_dict" in kwargs
    df_dict   = kwargs.get("df_dict")
    if len(analogies) == 0:
        return Dataset.from_list([])
    offset = len(analogies[0]) // 4
    if offset < 1:
        # offset == 0 would make analogy[-0:] the whole sequence and analogy[:-0] empty.
        raise ValueError(
            f"analogies must contain at least 4 items to split off the final quarter; "
            f"got {len(analogies[0])}"
        )
    all_records = []
    for analogy, label in zip(analogies, labels):
        if label == 0:
            continue
        _input = separator.join(str(item) for item in analogy[:-offset])
        tail   = analogy[-offset:]
        if dict_flag:
            # .item() raises unless exactly one row of df_dict matches the target.
            y = df_dict.filter(pl.col("encoded_image") == tail).select("index").item()
        else:
            y = separator.join(str(item) for item in tail)
        all_records.append({
            "input" : _input,
            "label" : y,
        })
    return Dataset.from_list(all_records)


def generate_text_prompts_(analogies, analogy_template, image_template, shapes_dict, colours_dict):
    """Render each analogy as a text description.

    Each analogy is a sequence of images, and each image is a sequence of
    ``(shape, colour)`` cells. A cell renders as ``"<colour name> <shape name>"``.
    Cells fill ``image_template`` positionally, and the rendered images then
    fill ``analogy_template`` positionally. Every occurrence of the colour-0 /
    shape-0 combination is then replaced by the word ``"empty"``.

    Returns:
        A list of rendered analogy strings, one per input analogy.
    """
    placeholder = f"{colours_dict[0]} {shapes_dict[0]}"

    def _describe_image(cells):
        # Cells are unpacked as (shape, colour) pairs.
        names = [f"{colours_dict[colour]} {shapes_dict[shape]}" for shape, colour in cells]
        return image_template.format(*names)

    def _describe_analogy(images):
        rendered = analogy_template.format(*[_describe_image(cells) for cells in images])
        return rendered.replace(placeholder, "empty")

    return [_describe_analogy(images) for images in analogies]

# Maps the (lower-cased) --type CLI value to the function that builds that
# dataset format. Every creator accepts (analogies, labels, **kwargs).
GENERATORS = {
    "sft": create_sft_dataset,  # WORKS - tested
    "gen_sft": create_generative_sft_dataset,  # WORKS - tested
    "gen_1hot": create_generative_1hot_dataset,  # WORKS - tested
    "gen_1hot_mcls": create_generative_1hot_simple_dataset,  # WORKS - tested
    "1hot": create_1hot_vector_dataset,  # WORKS - tested many times
}

if __name__ == "__main__":
    # Build train/dev/test datasets of the requested format from the split JSON
    # files in input_dir and optionally save them as a DatasetDict.
    import os

    parser = argparse.ArgumentParser(description="Finetuning the model")
    # nargs="?" is required for the positional default to ever take effect.
    parser.add_argument("input_dir", type=str, nargs="?", default="data/12/", help="")
    parser.add_argument("--type", type=str, default="sft", help="")
    parser.add_argument("--export", type=int, default=1, help="Export")
    parser.add_argument("--export_path", type=str, default="data/finetune/", help="Export path")

    args    = parser.parse_args()
    arg_input_dir   = args.input_dir
    arg_type        = args.type
    arg_export      = args.export
    arg_export_path = args.export_path
    # Choose the dataset format; unknown types fall back to the plain 1hot creator.
    type_key = arg_type.lower()
    if type_key not in GENERATORS:
        print(f"Unknown dataset type '{arg_type}', falling back to '1hot'")
        type_key = "1hot"
    creator = GENERATORS[type_key]
    # Decide the input representation from the *resolved* creator, so a fallback
    # to 1hot gets encoded vectors rather than text prompts.
    is_prompt   = "1hot" not in type_key
    # Choose the interested datasets
    target_datasets = ["train", "dev", "test"]
    # Choose the template
    analogy_template    = cp.ANALOGY_TEMPLATES[3]
    image_template      = cp.IMAGE_TEMPLATES[2]
    instruction         = cp.INSTRUCTIONS[3]
    # Load df dict TODO: fix this later, this is just a quick and dirty fix
    # os.path.join tolerates input_dir with or without a trailing slash.
    df_dict     = pl.read_json(os.path.join(arg_input_dir, "dict.json")).with_row_index()
    # Load the dataframes
    data_files  = get_files_from_directory(arg_input_dir, "*.json")
    ds_dict     = {}
    for path in data_files:
        # File stem up to the first dot, e.g. ".../train.json" -> "train".
        part_name   = os.path.basename(path).split(".")[0]
        if part_name not in target_datasets:
            continue
        print("Processing ", path)
        df  = pl.read_json(path)
        df  = preprocess_dataset(df)
        y   = df.select("is_valid").to_series()
        # Create datasets
        if is_prompt:
            X   = df.select("original_analogy").to_series()
            X   = generate_text_prompts_(X, analogy_template, image_template, cp.SHAPES, cp.COLOURS)
        else:
            X   = flatten_images(df.select("encoded_analogy").to_series())
        ds  = creator(X, y, instruction=instruction, separator=",", df_dict=df_dict)
        ds_dict[part_name]  = ds
    if not ds_dict:
        print(f"No {target_datasets} files found in {arg_input_dir}")
    # Save them all in to a dataset dict
    dataset     = DatasetDict(ds_dict)
    # And save it to disk
    if arg_export:
        # Trailing "" keeps the trailing separator of the original f-string path.
        dataset.save_to_disk(os.path.join(arg_export_path, arg_type, ""))
        print("Write to disk successfully")