import random as r
from util.constants import SCHEMA_KEYS, CLASSIFICATION, COT_SUFFIX

# Fixed seed for the shared module-level `random` generator, so the random
# pick among the instruction phrasings made in this module is reproducible.
# NOTE(review): this reseeds the process-wide generator at import time, which
# also affects any other module drawing from `random` -- confirm intended.
seed=2024
r.seed(seed)

# sort only incorrect answers; assuming no options are lists themselves
# Static description of the "flip binary classification labels" instruction:
# its identifier, the alternative natural-language phrasings, the operation
# tag, and the name/argument/return-type metadata of the implementing function.
flip_binary_classification_labels_schema = {
    "instruction_id": "flip_binary_classification_labels",
    # Two interchangeable phrasings of the same instruction.
    "instruction_text": [
        (
            "if there are two options, swap the option label of the correct "
            "answer to the question, with that of the incorrect answer and "
            "print only that label.  Do not print the text associated with "
            "the option."
        ),
        (
            "if there are two options, print the option label corresponding to "
            "the incorrect answer to the question.  Do not print the text "
            "associated with the option."
        ),
    ],
    # The surrounding spaces are part of the operation tag.
    "operation": [" FLIP_BINARY_LABELS "],
    "python_function_name": "flip_binary_classification_labels",
    "python_args": "instruction_instance",
    "target_answer_return_type": "str",
}

def apply_instruction(candidate_answer_label_space, input_candidate_label: str):
    """Return the other label of a two-entry label space.

    Args:
        candidate_answer_label_space: Sized container of option labels.
        input_candidate_label: The label to flip.

    Returns:
        The entry of ``candidate_answer_label_space`` that differs from
        ``input_candidate_label`` when the space has exactly two entries and
        contains the input label. In every other case the input label is
        returned unchanged: the space is not binary, the label is not part of
        it, or both entries are identical so there is nothing to flip to.
    """
    is_binary = len(candidate_answer_label_space) == 2
    if not is_binary or input_candidate_label not in candidate_answer_label_space:
        return input_candidate_label

    # `next` with a default avoids an IndexError when the two entries are
    # equal (e.g. ["A", "A"]) and no differing label exists.
    return next(
        (label for label in candidate_answer_label_space if label != input_candidate_label),
        input_candidate_label,
    )

def apply_instruction_incorrect(candidate_answer_set, input_candidate_answer: str):
    """Return the other answer text of a two-entry answer set.

    Args:
        candidate_answer_set: Sized container of candidate answer texts.
        input_candidate_answer: The answer text to swap.

    Returns:
        The entry of ``candidate_answer_set`` that differs from
        ``input_candidate_answer`` when the set has exactly two entries and
        contains the input text. Otherwise the input text is returned
        unchanged; this includes the degenerate case where both entries are
        identical.
    """
    if len(candidate_answer_set) != 2:
        return input_candidate_answer
    if input_candidate_answer not in candidate_answer_set:
        return input_candidate_answer

    for answer_text in candidate_answer_set:
        if answer_text != input_candidate_answer:
            return answer_text

    # Both entries equal the input: there is no alternative to swap to, so
    # keep the input instead of failing with an IndexError.
    return input_candidate_answer

def flip_binary_classification_labels(input_instance: dict):
    """Attach the "flip binary labels" instruction to an MCQ instance.

    The instance dict is modified in place and also returned. In order, the
    function:

    * checks (via ``assert``) that the task type is ``"MCQ"`` and that the
      candidate label space has exactly two entries;
    * records the instruction id and picks one instruction phrasing at random;
    * appends that phrasing to ``task_instructions`` and stores the dataset
      instruction followed by it as ``final_prefix_task_instruction``, plus a
      chain-of-thought variant with ``COT_SUFFIX`` appended;
    * appends the flipped ground-truth label to ``instruction_output``;
    * fills the instruction-following and reasoning error sets from the
      (label, answer text) pairs;
    * tags the instance with the label-manipulation classification.

    Args:
        input_instance: Instance dict carrying the keys read below.

    Returns:
        The same ``input_instance`` object after the updates.

    Raises:
        AssertionError: If the task type is not ``"MCQ"`` or the label space
            does not have two entries (only while assertions are enabled).
    """
    assert input_instance["task_type"] == "MCQ"
    assert len(input_instance[SCHEMA_KEYS.CANDIDATE_ANSWER_LABEL_SPACE]) == 2

    schema = flip_binary_classification_labels_schema
    input_instance["instruction_id"] = schema["instruction_id"]

    # Draw one of the available phrasings from the seeded module-level RNG.
    phrasings = schema["instruction_text"]
    picked_index = r.randint(0, len(phrasings) - 1)
    new_instruction = phrasings[picked_index]

    # The flip instruction is added to the list of task instructions.
    input_instance["task_instructions"].append(new_instruction)

    dataset_instruction = input_instance["dataset_instruction"]
    input_instance["final_prefix_task_instruction"] = dataset_instruction + " " + new_instruction
    # Presumably SCHEMA_KEYS.FINAL_PREFIX_TASK_INSTRUCTION names the key that
    # was just written above -- verify against util.constants.
    prefix_instruction = input_instance[SCHEMA_KEYS.FINAL_PREFIX_TASK_INSTRUCTION]
    input_instance[SCHEMA_KEYS.COT_INSTRUCTION] = prefix_instruction + COT_SUFFIX

    # Expected output under the instruction: the ground-truth label, flipped.
    label_space = input_instance[SCHEMA_KEYS.CANDIDATE_ANSWER_LABEL_SPACE]
    gold_label = input_instance["ground_truth_answer_label"]
    expected_label = apply_instruction(label_space, gold_label)
    input_instance['instruction_output'].append(expected_label)

    answer_texts = input_instance["candidate_answer_set"]
    for option_label, option_text in zip(label_space, answer_texts):
        # Every option contributes its swapped answer text as an
        # instruction-following error.
        following_errors = input_instance[SCHEMA_KEYS.INSTRUCTION_FOLLOWING_ERRORS_SET]
        following_errors.append(
            apply_instruction_incorrect(input_instance[SCHEMA_KEYS.CANDIDATE_ANSWER_SET], option_text)
        )

        if option_label == gold_label:
            # The ground-truth option's label and text go to the reasoning
            # error set.
            reasoning_errors = input_instance[SCHEMA_KEYS.REASONING_ERROR_SET]
            reasoning_errors.append(option_label)
            reasoning_errors.append(option_text)
            continue

        # The other option's own text is a further instruction-following error.
        following_errors.append(option_text)

    input_instance[CLASSIFICATION.CLASSIFICATION] = CLASSIFICATION.LABEL_MANIPULATION
    return input_instance

