import random as r
from util.constants import SCHEMA_KEYS, CLASSIFICATION, COT_SUFFIX

# Seed the shared ``random`` module so that the per-instance choice of
# instruction wording is reproducible from run to run.
seed = 2024
r.seed(a=seed)

# Instruction: swap the correct answer's text with the incorrect one (two-option MCQ only); assumes no options are lists themselves.
flip_binary_classification_text_schema = {
	"instruction_id": "flip_binary_classification_text",
    "instruction_text": ["if there are two options, swap the text associated with the option label of the correct answer to the question, with that of the incorrect answer and print the text. Do not print the option label. ", "if there are two options, print the text corresponding to the incorrect answer choice for the question. Do not print the option label."],
	"operation": [" FLIP_BINARY_TEXT "],
	"python_function_name": "flip_binary_classification_text",
	"python_args": "instruction_instance",
	"target_answer_return_type": "str"
    }

def apply_instruction(candidate_answer_set, input_candidate: str):
    flipped_text = input_candidate
    if len(candidate_answer_set) == 2 and input_candidate in candidate_answer_set:
        flipped_text = [txt for txt in candidate_answer_set if txt != input_candidate][0]
    return flipped_text

def apply_instruction_incorrect(candidate_answer_label_space, input_candidate_label: str):
    flipped_label = input_candidate_label
    if len(candidate_answer_label_space) == 2 and input_candidate_label in candidate_answer_label_space:
        flipped_label = [value for value in candidate_answer_label_space if value != input_candidate_label][0]
    return flipped_label

def flip_binary_classification_text(input_instance: dict):
    """Turn a two-option MCQ instance into a "flip the answer text" instance.

    ``input_instance`` is mutated in place and also returned. In order:

    1. Set ``instruction_id`` and append one randomly chosen wording of the
       instruction (drawn with the module-level ``random`` state) to
       ``task_instructions``.
    2. Build the final prefix instruction (dataset instruction, a space, then
       the chosen wording) and its chain-of-thought variant (prefix plus
       ``COT_SUFFIX``).
    3. Append the ground-truth answer text, flipped via
       ``apply_instruction``, to ``instruction_output``.
    4. For each (label, option text) pair:
       - append the flipped label to the instruction-following error set;
       - if the option equals the ground truth, append its text and label to
         the reasoning error set;
       - otherwise, append its label to the instruction-following error set.
    5. Tag the instance with ``CLASSIFICATION.LABEL_MANIPULATION``.

    Args:
        input_instance: Instance dict. It must have ``task_type == "MCQ"``
            and exactly two entries in its candidate label space.

    Returns:
        The same ``input_instance`` object, updated.

    Raises:
        AssertionError: If the instance is not an MCQ or does not have
            exactly two candidate labels. These are ``assert`` checks, so
            they are skipped under ``python -O``.
    """
    assert input_instance["task_type"] == "MCQ"
    label_space = input_instance[SCHEMA_KEYS.CANDIDATE_ANSWER_LABEL_SPACE]
    assert len(label_space) == 2

    schema = flip_binary_classification_text_schema
    input_instance["instruction_id"] = schema["instruction_id"]
    wordings = schema["instruction_text"]
    chosen_wording = wordings[r.randint(0, len(wordings) - 1)]

    # The flip instruction is added after any instructions already present.
    input_instance["task_instructions"].append(chosen_wording)

    input_instance["final_prefix_task_instruction"] = " ".join(
        (input_instance["dataset_instruction"], chosen_wording)
    )
    # NOTE(review): the prefix is written under a string-literal key but read
    # back through SCHEMA_KEYS.FINAL_PREFIX_TASK_INSTRUCTION. These are
    # presumably the same key; confirm.
    input_instance[SCHEMA_KEYS.COT_INSTRUCTION] = (
        input_instance[SCHEMA_KEYS.FINAL_PREFIX_TASK_INSTRUCTION] + COT_SUFFIX
    )

    answer_set = input_instance[SCHEMA_KEYS.CANDIDATE_ANSWER_SET]
    ground_truth = input_instance["ground_truth_answer_text"]
    input_instance['instruction_output'].append(apply_instruction(answer_set, ground_truth))

    # NOTE(review): options are read here via the literal "candidate_answer_set"
    # key, and above via SCHEMA_KEYS.CANDIDATE_ANSWER_SET. Presumably identical;
    # confirm.
    # NOTE(review): with two labels, the incorrect option's label lands in the
    # instruction-following error set twice, and the correct label lands in
    # both error sets. Confirm this is intended.
    follow_key = SCHEMA_KEYS.INSTRUCTION_FOLLOWING_ERRORS_SET
    reasoning_key = SCHEMA_KEYS.REASONING_ERROR_SET
    options = input_instance["candidate_answer_set"]
    for label, option in zip(label_space, options):
        flipped_label = apply_instruction_incorrect(label_space, label)
        input_instance[follow_key].append(flipped_label)
        if option == ground_truth:
            # The unflipped ground-truth option goes to the reasoning error set.
            reasoning_errors = input_instance[reasoning_key]
            reasoning_errors.append(option)
            reasoning_errors.append(label)
        else:
            input_instance[follow_key].append(label)

    input_instance[CLASSIFICATION.CLASSIFICATION] = CLASSIFICATION.LABEL_MANIPULATION
    return input_instance

