import random as r
from util.constants import SCHEMA_KEYS, CLASSIFICATION, COT_SUFFIX

# Fixed seed applied to the process-wide `random` generator at import time, so the
# instruction phrasing sampled via `r` below is reproducible across runs.
seed: int = 2024
r.seed(seed)

# reverse correct answer and alternate case
reverse_correct_answer_alternate_case_schema={
	"instruction_id": "reverse_correct_answer_alternate_case",
	"instruction_text": ["reverse the text associated with the answer label that correctly answers the question. Print this reversed text in alternate case starting with upper case.  Do not print the option label."],#, "Change the right answer to its reverse and be certain that it's in an alternating case format starting with upper case in the reversed string.", "Alter the correct answer to its reverse and confirm that it's presented in an alternating case manner starting with upper case in the reversed string.", "Reformat the correct response to its reverse and ascertain that it's displayed in a mixed-case style, alternating between upper and lower case starting with upper case in the reversed string."],
	"operation": [" REVERSE_CORRECT_ANSWER_ALT_CASE "],
	"python_function_name": "reverse_correct_answer_alternate_case",
	"python_args": "instruction_instance",
	"target_answer_return_type": "text"
    }

def _alternate_case(text: str) -> str:
    """Return *text* with case alternating by position, starting with upper case.

    Even indices are upper-cased and odd indices lower-cased. Every character
    counts toward the position, including spaces and punctuation, e.g.
    ``"ab c"`` -> ``"Ab c"``.
    """
    return "".join(ch.upper() if idx % 2 == 0 else ch.lower() for idx, ch in enumerate(text))


def reverse_correct_answer_alternate_case(input_instance: dict):
    """Convert an MCQ instance into a "reverse the answer, then alternate case" task.

    The instance is mutated in place and also returned:

    * ``instruction_id`` is set from the schema; one phrasing is sampled from the
      schema's ``instruction_text`` and appended to ``task_instructions``.
    * ``final_prefix_task_instruction`` is rebuilt as ``dataset_instruction`` plus
      the sampled phrasing, and the CoT instruction as that prefix plus
      ``COT_SUFFIX``.
    * ``instruction_output`` receives two entries: the reversed correct answer
      and then its alternate-case form.
    * For each candidate, derived strings (alternate-cased, reversed, untouched,
      and the option label) are appended to the instruction-following and
      reasoning error lists. Reasoning-error entries are only added for
      candidates that differ from the ground-truth answer.

    Args:
        input_instance: Task dict with ``task_type == "MCQ"``. The list fields
            written to above are expected to exist already. As used here,
            ``ground_truth_answer_text`` and each candidate may be a string or a
            list of strings.

    Returns:
        The same ``input_instance`` dict.

    Raises:
        AssertionError: if ``task_type`` is not ``"MCQ"``.
    """
    assert input_instance["task_type"] == "MCQ"
    schema = reverse_correct_answer_alternate_case_schema
    input_instance["instruction_id"] = schema["instruction_id"]
    # In CPython 3, choice() draws from the seeded RNG exactly like
    # randint(0, len - 1), so sampled instructions stay reproducible.
    new_instruction = r.choice(schema["instruction_text"])

    # Register the sampled instruction with the instance's instruction list.
    input_instance["task_instructions"].append(new_instruction)

    # Final prompt prefix = dataset instruction + sampled instruction; the CoT
    # instruction is that prefix followed by COT_SUFFIX.
    input_instance["final_prefix_task_instruction"] = input_instance["dataset_instruction"] + " " + new_instruction
    input_instance[SCHEMA_KEYS.COT_INSTRUCTION] = input_instance[SCHEMA_KEYS.FINAL_PREFIX_TASK_INSTRUCTION] + COT_SUFFIX

    ground_truth = input_instance["ground_truth_answer_text"]
    # NOTE(review): the target is built from the LAST ground-truth entry while the
    # candidate comparison below uses the FIRST; these agree for single-entry
    # lists — confirm which one is intended.
    target_source = ground_truth[-1] if isinstance(ground_truth, list) else ground_truth
    gold_text = ground_truth[0] if isinstance(ground_truth, list) else ground_truth

    # Target output: the reversed answer first, then its alternate-case form.
    reversed_answer = target_source[::-1]
    input_instance["instruction_output"].append(reversed_answer)
    input_instance["instruction_output"].append(_alternate_case(reversed_answer))

    # For an even-length answer, alt-case(reverse(s)) and reverse(alt-case(s))
    # get opposite case at every position, so doing the steps in the wrong order
    # is a distinguishable mistake. Computed once from the correct answer, before
    # the candidate loop, so no candidate can influence it.
    order_sensitive = len(reversed_answer) % 2 == 0

    labels = input_instance[SCHEMA_KEYS.CANDIDATE_ANSWER_LABEL_SPACE]
    for candidate_label, candidate in zip(labels, input_instance["candidate_answer_set"]):
        if isinstance(candidate, list):
            candidate = candidate[0]

        # `candidate` is never rebound below: it must remain the original text so
        # the ground-truth comparison is made against the untransformed option.
        reversed_candidate = candidate[::-1]
        alt_case_candidate = _alternate_case(candidate)
        is_alphabetic = all(ch.isalpha() or ch.isspace() for ch in candidate)

        if order_sensitive and is_alphabetic:
            # Case alternated without reversing, and both steps in the wrong order.
            input_instance[SCHEMA_KEYS.INSTRUCTION_FOLLOWING_ERRORS_SET].extend(
                [alt_case_candidate, alt_case_candidate[::-1]]
            )

        if is_alphabetic and len(candidate.strip()) > 1:
            # Reversed without case change, the untouched text, and the bare label.
            input_instance[SCHEMA_KEYS.INSTRUCTION_FOLLOWING_ERRORS_SET].extend(
                [reversed_candidate, candidate, candidate_label]
            )

        if candidate != gold_text:
            # A wrong option with the full transformation, a partial one, or none.
            input_instance[SCHEMA_KEYS.REASONING_ERROR_SET].extend(
                [
                    _alternate_case(reversed_candidate),
                    alt_case_candidate,
                    alt_case_candidate[::-1],
                    candidate,
                    reversed_candidate,
                    candidate_label,
                ]
            )

    input_instance[CLASSIFICATION.CLASSIFICATION] = CLASSIFICATION.STRING_MANIPULATION
    return input_instance