# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

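# NOTE(review): this link points at samsum, but the loaders below read XLOGO JSON files from ./data — confirm and update.
# For dataset details visit: https://huggingface.co/datasets/samsum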
import json
import os
import random
import datasets
import itertools
from src.xlogominiprog.translator import taskjs2ascii, taskjs2nl, codejs2python
from src.xlogominiprog.prompts.prompt_template import *
from src.xlogomini.components.task import Task
from src.xlogomini.components.code.xlogo_code import Code
from src.xlogominiprog.utils.remove_nulls import remove_nulls_from_json

# Opening / closing markers that `tokenize_dialog` places around every prompt turn.
B_INST, E_INST = "[INST]", "[/INST]"


def print_task_and_code(dataset, dataset_tok, tokenizer, sample_idx):
    """Print dataset statistics and one sample (task, code, decoded tokens).

    Args:
        dataset: The untokenized dataset; must expose ``column_names`` and be
            indexable by row and by column name.
        dataset_tok: The tokenized counterpart; row ``sample_idx`` must hold
            an ``"input_ids"`` entry.
        tokenizer: Object providing ``decode`` for the token ids.
        sample_idx: Index of the row to display.
    """
    print("\n\n\n============ Dataset-INFO ============")
    total = len(dataset)
    print(f"Dataset Size: {total}")

    # Success statistics are only available when the column is present.
    if 'success' in dataset.column_names:
        positives = sum(dataset['success'])
        print(f"Positive samples: {positives}")
        print(f"Negative samples: {total - positives}")
        condition = dataset['success'][sample_idx]
        print(f"\n\n--- DEBUG-INFO: Sample index (condition: {condition}): ", sample_idx)

    # Rebuild the task description from the row, with null entries stripped.
    task_json = remove_nulls_from_json(dataset[sample_idx]["task_json"])
    task_json['constraints'] = remove_nulls_from_json(dataset[sample_idx]['constraints'])
    if 'goal' not in task_json:
        # Task.init_from_json is handed an explicit (empty) goal list.
        task_json['goal'] = []
    code_json = remove_nulls_from_json(dataset[sample_idx]["code_json"])

    print("Task:\n", Task.init_from_json(task_json))
    print("Code:\n", Code(code_json))

    token_ids = dataset_tok[sample_idx]["input_ids"]
    print(tokenizer.decode(token_ids))
    print("==============================\n\n\n")


def tokenize_dialog(dialog, tokenizer):
    """Tokenize an alternating prompt/answer dialog into training features.

    Even-indexed entries of ``dialog`` are treated as prompts and odd-indexed
    entries as answers; each entry must provide a ``'content'`` string. A
    trailing prompt without an answer is encoded but left out of the result.

    Args:
        dialog: Sequence of turns, each a mapping with a ``'content'`` key.
        tokenizer: Object providing ``encode``, ``bos_token`` and ``eos_token``.

    Returns:
        dict with ``input_ids``, ``labels`` (prompt positions set to -100 so
        the loss ignores them) and an all-ones ``attention_mask``.
    """
    encoded_prompts = []
    for turn in dialog[::2]:
        text = (turn['content']).strip()
        encoded_prompts.append(
            tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {text} {E_INST}"))

    encoded_answers = []
    for turn in dialog[1::2]:
        text = turn['content'].strip()
        encoded_answers.append(tokenizer.encode(f"{text} {tokenizer.eos_token}"))

    input_ids = []
    labels = []
    # zip stops at the shorter list, so only complete prompt/answer pairs are kept.
    for prompt_ids, answer_ids in zip(encoded_prompts, encoded_answers):
        input_ids.extend(prompt_ids)
        labels.extend([-100] * len(prompt_ids))  # mask the prompt in the loss
        input_ids.extend(answer_ids)
        labels.extend(answer_ids)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }


def _why_buggy(sample):
    """Classify why a failed sample is wrong.

    Checks are applied in priority order: code format error, crash (wall,
    forbidden cell, out of grid), then goal / constraint failures.

    Args:
        sample: Mapping with ``'success'`` and ``'exec_info'``; ``exec_info``
            must hold ``'code_format_error'`` and ``'exec_res'``.

    Returns:
        Tuple ``(reason, instr4reason)`` taken from the ``COND_BUGGY_*`` /
        ``INSTR_BUGGY_*`` prompt constants.

    Raises:
        ValueError: If no known failure category matches.
    """
    exec_info = sample['exec_info']
    exec_res = exec_info['exec_res']

    if sample['success']:
        # None of the failure categories applies to a successful sample.
        raise ValueError("Unknown reason")

    # Error 1: format error
    if exec_info['code_format_error']:
        return COND_BUGGY_FORMAT_WRONG, INSTR_BUGGY_FORMAT_WRONG

    # Error 2: execution crash
    if exec_res['crashed']:
        crash = exec_res['crashed_msg']
        crash_type = crash['crash_type']
        if crash_type == 'WALL':
            pos = {'x': crash['pos'][0], 'y': crash['pos'][1]}
            return (COND_BUGGY_CRASH_INTO_WALL.format(**pos),
                    INSTR_BUGGY_CRASH_INTO_WALL.format(**pos))
        if crash_type == 'FORBIDDEN_AREA':
            pos = {'x': crash['pos'][0], 'y': crash['pos'][1]}
            return (COND_BUGGY_CRASH_FORBIDDEN_CELL.format(**pos),
                    INSTR_BUGGY_CRASH_FORBIDDEN_CELL.format(**pos))
        if crash_type == 'OUT_OF_WORLD':
            return COND_BUGGY_OUT_OF_GRID, INSTR_BUGGY_OUT_OF_GRID
        # Any other crash type falls through to the goal / constraint checks.

    # Error 3: goal / constraints fail
    goal_ok = exec_res['goal_ok']
    cons_ok = exec_res['cons_ok']
    if goal_ok and not cons_ok:
        return COND_BUGGY_CODE_CONSTRAINTS, INSTR_BUGGY_CODE_CONSTRAINTS
    if not goal_ok and cons_ok:
        return COND_BUGGY_GOAL_FAIL, INSTR_BUGGY_GOAL_FAIL
    if not goal_ok and not cons_ok:
        return COND_BUGGY_GOAL_AND_CONSTRAINTS, INSTR_BUGGY_GOAL_AND_CONSTRAINTS

    raise ValueError("Unknown reason")


def get_condition_and_instruction_template(sample, dataset_config, eval_mode=False):
    """Choose the condition text, instruction text and target code for a sample.

    For successful samples, and always in ``eval_mode``, the positive
    condition is paired with the reference code from ``code_json``. For failed
    samples the outcome depends on ``dataset_config.enhance_type``:

    * ``'buggy-binary'``: generic "wrong" condition with the buggy code.
    * ``'buggy-multi-correct'``: specific failure reason; buggy code followed
      by the correct code.
    * ``'buggy-multi'``: specific failure reason with the buggy code only.
    * ``'buggy-correct-multi'``: positive condition; correct code followed by
      the buggy code labelled with its failure reason.
    * ``'buggy-correct'``: positive condition with the correct code only.

    The failure reason is only computed for the modes that use it, so the
    other modes do not fail on samples whose failure cannot be classified.

    Args:
        sample: Mapping with ``'code_json'`` and, outside ``eval_mode``,
            ``'success'``; failed samples also need ``'model_output'`` and
            ``'exec_info'``.
        dataset_config: Object whose ``enhance_type`` attribute is read for
            failed samples only.
        eval_mode: When True, always use the positive condition.

    Returns:
        Tuple ``(condition, instruction, code)``.

    Raises:
        ValueError: For an unknown ``enhance_type`` or an unclassifiable
            failure in a mode that needs the reason.
    """
    code_template = "```python\n{code}\n```"

    # `sample["success"]` is not consulted in eval mode (short-circuit).
    if eval_mode or sample["success"]:
        code = code_template.format(code=codejs2python(sample["code_json"]))
        return COND_CORRECT, INSTR_CORRECT, code

    assert sample["model_output"] is not None
    buggy_code = sample["model_output"]
    correct_code = code_template.format(code=codejs2python(sample["code_json"]))
    enhance_type = dataset_config.enhance_type

    if enhance_type == 'buggy-binary':
        # for buggy code, only say it's buggy
        return COND_WRONG, INSTR_WRONG, buggy_code

    if enhance_type == 'buggy-correct':
        # for buggy code, we use the correct code instead of the buggy code
        return COND_CORRECT, INSTR_CORRECT, correct_code

    if enhance_type not in ('buggy-multi-correct', 'buggy-multi', 'buggy-correct-multi'):
        raise ValueError(f"Unknown enhance_type: {dataset_config.enhance_type}")

    # Only the remaining modes need the specific failure diagnosis.
    reason, instr4reason = _why_buggy(sample)

    if enhance_type == 'buggy-multi-correct':
        # first show buggy code, then append correct code
        code = f"{buggy_code}\n\n### {COND_CORRECT}:\n{correct_code}"
        return reason, instr4reason, code

    if enhance_type == 'buggy-multi':
        # only show buggy code
        return reason, instr4reason, f"{buggy_code}"

    # 'buggy-correct-multi': first show the correct code, then show the buggy code
    code = f"{correct_code}\n\n### {reason}:\n{buggy_code}"
    return COND_CORRECT, INSTR_CORRECT, code


def apply_prompt_template(sample, dataset_config, eval_mode=False):
    """Render the prompt text and target code for one sample.

    ``dataset_config.prompt_template`` selects the task representation:
    ``'ascii'`` or ``'nl'``. If `eval_mode` is True, the positive condition
    is used.

    Returns:
        dict with the keys ``"prompt"`` and ``"code"``.

    Raises:
        ValueError: If ``dataset_config.prompt_template`` is not supported.
    """
    template_kind = dataset_config.prompt_template
    if template_kind != 'ascii' and template_kind != 'nl':
        raise ValueError(f"Unknown template: {dataset_config.prompt_template}")

    # Pick the template and the matching task rendering.
    task_json = sample["task_json"]
    if template_kind == 'ascii':
        chosen_template = PROMPT_TEMPLATE_ASCII
        rendered_task = taskjs2ascii(task_json)
    else:
        chosen_template = PROMPT_TEMPLATE_NL
        rendered_task = taskjs2nl(task_json)

    parts = get_condition_and_instruction_template(sample, dataset_config, eval_mode)
    condition, instruction, code = parts

    fields = {
        "description": sample["task_json"]["description"],
        "task": rendered_task,
        "condition": condition,
        "instruction": instruction,
    }
    return {"prompt": chosen_template.format(**fields), "code": code}


def get_custom_dataset(dataset_config, tokenizer, split, **kwargs):
    """Load an XLOGO split, optionally add enhanced samples, and tokenize it.

    The base samples are read from ``./data/xlogomini-dataset-<split>.json``.
    For the train split, samples from ``dataset_config.enhanced_dataset_path``
    (if set) are appended. Entries without a ``'success'`` key are treated as
    correct reference samples; entries with ``success == False`` are kept with
    their execution info; entries with ``success == True`` are dropped.

    Args:
        dataset_config: Object providing ``enhanced_dataset_path``, ``debug``,
            ``enhance_type`` and ``prompt_template``.
        tokenizer: Object providing ``encode`` and ``eos_token``.
        split: ``'train'`` or ``'validation'``.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        A ``datasets.Dataset`` with the columns ``input_ids``,
        ``attention_mask`` and ``labels`` (prompt positions are -100).

    Raises:
        ValueError: For an unknown split or a missing enhanced dataset file.
    """
    print("--> Loading XLOGO dataset")

    def load_json(path):
        # Use a context manager so the file handle is closed right away.
        with open(path, 'r') as handle:
            return json.load(handle)

    if split == 'train':
        dataset = load_json('./data/xlogomini-dataset-train.json')
    elif split == 'validation':
        dataset = load_json('./data/xlogomini-dataset-validation.json')
    else:
        raise ValueError(f"Unknown split: {split}, only support 'train' and 'validation'")

    if dataset_config.enhanced_dataset_path is not None and split == 'train':
        # if enhanced dataset is provided, we need to append it to the original dataset
        if not os.path.exists(dataset_config.enhanced_dataset_path):
            raise ValueError(f"Enhanced dataset file does not exist: {dataset_config.enhanced_dataset_path}")
        dataset.extend(load_json(dataset_config.enhanced_dataset_path))

    enhanced_dataset = {"task_json"   : [],
                        "code_json"   : [],
                        "model_codejs": [],
                        "model_output": [],
                        "constraints" : [],
                        "success"     : [],
                        "exec_info"   : []}

    for sample in dataset:
        if 'success' not in sample:
            # Sample from the original dataset: its code is the reference solution.
            enhanced_dataset["task_json"].append(sample["task_json"])
            enhanced_dataset["constraints"].append(sample["constraints"])
            enhanced_dataset["code_json"].append(sample["code_json"])
            enhanced_dataset["model_codejs"].append(None)
            enhanced_dataset["model_output"].append(None)
            enhanced_dataset["success"].append(True)
            enhanced_dataset["exec_info"].append({})
        elif not sample['success']:
            # Buggy sample from the enhanced dataset: keep the model output
            # together with the information about how it failed.
            enhanced_dataset["task_json"].append(sample["task_json"])
            enhanced_dataset["constraints"].append(sample["constraints"])
            enhanced_dataset["code_json"].append(sample["code_json"])
            enhanced_dataset["model_codejs"].append(sample['model_codejs'])
            enhanced_dataset["model_output"].append(sample['model_output'])
            enhanced_dataset["success"].append(sample['success'])
            enhanced_dataset["exec_info"].append({
                'code_format_error'    : sample['code_format_error'],
                'code_format_error_msg': sample['code_format_error_msg'],
                'execution_error'      : sample['execution_error'],
                'execution_error_msg'  : sample['execution_error_msg'],
                'exec_res'             : sample.get('exec_res'),
            })
        # else: the model already produced correct code for this sample,
        # so it is skipped.

    dataset = enhanced_dataset

    assert (len(dataset["task_json"]) == len(dataset["code_json"]) ==
            len(dataset["constraints"]) == len(dataset["success"]))
    dataset = datasets.Dataset.from_dict(dataset)

    if dataset_config.debug:
        # Cap at the dataset size so small datasets do not cause an
        # out-of-range selection.
        dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))

    if split == 'validation':
        # Deterministically shuffle the validation set (no subsampling is done here).
        dataset = dataset.shuffle(seed=42)

    def tokenize_add_label(sample):
        # NOTE(review): both encode calls use the tokenizer's default
        # special-token handling — confirm no unwanted BOS is inserted in
        # front of the code ids.
        prompt_ids = tokenizer.encode(sample["prompt"])
        code_ids = tokenizer.encode(sample["code"] + tokenizer.eos_token)

        # Only the code tokens contribute to the loss; prompt positions are -100.
        return {
            "input_ids"     : prompt_ids + code_ids,
            "attention_mask": [1] * (len(prompt_ids) + len(code_ids)),
            "labels"        : [-100] * len(prompt_ids) + code_ids,
        }

    print("=== DEBUG: dataset_config.enhance_type: ", dataset_config.enhance_type)

    dataset = dataset.map(lambda x: apply_prompt_template(x, dataset_config))
    # Drop every pre-existing column so only the tokenized features remain.
    dataset_tok = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

    return dataset_tok


def get_custom_dataset_eval(template, dataset_path):
    """
    Load a JSON dataset for inference with vllm.

    When ``template`` is given, every row is extended with the ``prompt`` and
    ``code`` fields produced by ``apply_prompt_template`` in eval mode;
    otherwise the loaded dataset is returned unchanged.
    """
    print("--> Loading XLOGO dataset")
    loaded = datasets.load_dataset('json', data_files={'test': dataset_path}, split='test')

    if template is None:
        return loaded

    # Minimal config object: only `prompt_template` is read in eval mode.
    config_cls = type('DatasetConfig', (object,), {"prompt_template": template})
    dataset_config = config_cls()

    def add_prompt(row):
        return apply_prompt_template(row, dataset_config, eval_mode=True)

    return loaded.map(add_prompt)
