import json


def get_input_ehr(data, output_dir=None):
    """Build the agent prompt and the log path for one EHR sample.

    Args:
        data: Mapping providing the keys ``'identity'``, ``'template'`` and
            ``'id'``.
        output_dir: String prefix the log file name is appended to verbatim;
            no path separator is inserted.

    Returns:
        A ``(agent_input, input_id, identity, output_path)`` tuple.
    """
    identity, question, input_id = data['identity'], data['template'], data['id']
    prompt_template = 'Identity: {}\nQuestion: {}'
    file_name = '{}.txt'.format(input_id)
    return (
        prompt_template.format(identity, question),
        input_id,
        identity,
        output_dir + file_name,
    )


def get_output_ehr(output_log_path, idx):
    """Summarise an EHR agent log as knowledge, generated code and answer.

    Only the text after the first ``(END OF EXAMPLES)`` marker is used. The
    part before the first ``Question:`` is taken as the knowledge, and the
    part after the last ``Solution:`` is cut into sections on the dashed
    separator line.

    Args:
        output_log_path: Path of the log file to read.
        idx: Unused; kept so the signature matches the other ``get_output_*``
            helpers.

    Returns:
        ``knowledge + '\\nGenerated code:\\n' + code + '\\nAnswer:\\n' + answer``.

    Raises:
        IndexError: If the log has no ``(END OF EXAMPLES)`` marker.
    """
    # Context manager so the handle is closed even if the parsing below raises.
    with open(output_log_path, "r") as log_handle:
        log_text = log_handle.read()
    log_text = log_text.split('(END OF EXAMPLES)\n')[1]
    knowledge = log_text.split('Question:')[0]
    solution = log_text.split('Solution:')[-1]
    solution_sections = solution.split('\n----------------------------------------------------------\n')

    # The answer is the section right before the last one mentioning TERMINATE.
    terminate_idx = None
    for i, section in enumerate(solution_sections):
        if 'TERMINATE' in section:
            terminate_idx = i
    if terminate_idx is None:
        # No TERMINATE marker: fall back to the final section instead of
        # failing on ``None - 1``. This is the same section the wrap-around
        # index -1 selects when TERMINATE sits in the very first section.
        answer = solution_sections[-1]
    else:
        answer = solution_sections[terminate_idx - 1]

    # Prefer the code embedded in the last section carrying a "cell" payload.
    # Testing for the exact marker that is split on also accepts a marker at
    # position 0 and can no longer raise IndexError on a near-miss.
    code = None
    for section in solution_sections:
        if '\"cell\": \"' in section:
            code = section.split('\"cell\": \"')[1]
            code = code.split('\"\n}')[0]
    if code is None:
        # No "cell" payload anywhere: use the first longest section.
        code = max(solution_sections, key=len)

    agent_output = knowledge + '\nGenerated code:\n' + code + '\nAnswer:\n' + answer

    return agent_output


def get_input_seeact(data, output_dir=None):
    """Build the agent prompt and the log path for one SeeAct sample.

    Args:
        data: Mapping providing ``'annotation_id'``, ``'user_info'`` (itself a
            mapping) and ``'confirmed_task'``.
        output_dir: String prefix that ``'/sample_labeled_all.json'`` is
            appended to.

    Returns:
        A ``(agent_input, input_id, identity, output_path)`` tuple; the
        identity is always the literal ``'User'``.
    """
    input_id = data['annotation_id']

    # One "key: value" line per user attribute.
    info_lines = []
    for field, value in data['user_info'].items():
        info_lines.append(f"{field}: {value}")
    user_info = "\n  ".join(info_lines)

    task = data['confirmed_task']
    agent_input = f"\nTask: {task} \n\nuser_info:\n{user_info}"

    return agent_input, input_id, 'User', output_dir + '/sample_labeled_all.json'


def get_output_seeact(output_log_path, idx):
    """Format the SeeAct output stored for sample ``idx`` in a JSON log.

    Args:
        output_log_path: Path of a JSON file whose top-level value is indexed
            by ``idx``.
        idx: Index of the record to format.

    Returns:
        A string combining the second-to-last ``'gpt_output'`` entry, the last
        ``'prompt'`` entry and the last ``'gpt_output'`` entry of the record.
    """
    with open(output_log_path, 'r') as log_handle:
        records = json.load(log_handle)

    record = records[idx]
    action_choices = record['prompt'][-1]
    gpt_turns = record['gpt_output']
    next_action = gpt_turns[-2]
    final_answer = gpt_turns[-1]

    return f"\n(Next Action Based on Webpage and Analysis)\n{next_action}\n\n{action_choices}\n\n(Final Answer)\n{final_answer}"


def check_inaccessibility(dbs1, dbs2):
    """Tell whether two table-to-columns mappings describe the same schema.

    Args:
        dbs1: Mapping from a key to an iterable of hashable items.
        dbs2: Mapping of the same shape.

    Returns:
        ``True`` when both mappings have exactly the same keys and, for every
        key, the same set of items (order and duplicates are ignored);
        ``False`` otherwise.
    """

    def _matches(reference, other):
        # Every entry of `other` must exist in `reference` with equal columns.
        # All entries are visited; there is no early exit.
        agree = True
        for table, columns in other.items():
            if table not in reference or set(columns) != set(reference[table]):
                agree = False
        return agree

    forward = _matches(dbs1, dbs2)
    backward = _matches(dbs2, dbs1)
    return forward and backward

def check_violation(rule1: str, violation: str):
    """Tell whether ``violation`` names (part of) ``rule1``.

    Both texts are stripped, have newlines and dots removed and are
    lower-cased; the violation then matches when it is a substring of the
    rule.

    Args:
        rule1: Rule text; ``None`` is treated as empty.
        violation: Reported violation text; ``None`` is treated as empty.

    Returns:
        ``True`` if the normalised violation occurs in the normalised rule,
        ``False`` otherwise or when either text is empty after normalisation.
    """
    # Guard before touching the values: calling .strip() on None would raise.
    if rule1 is None or violation is None:
        return False

    rule1 = rule1.strip().replace("\n", "").replace(".", "").lower()
    violation = violation.strip().replace("\n", "").replace(".", "").lower()

    # Test emptiness AFTER normalising: a violation such as "." collapses to
    # "" and the empty string is a substring of every rule.
    if not rule1.strip() or not violation.strip():
        return False

    # Substring containment already covers the exact-equality case.
    return violation in rule1


def judge(pred, ans):
    old_flag = True
    if not ans in pred:
        old_flag = False
    if "True" in pred:
        pred = pred.replace("True", "1")
    else:
        pred = pred.replace("False", "0")
    if ans == "False" or ans == "false":
        ans = "0"
    if ans == "True" or ans == "true":
        ans = "1"
    if ans == "No" or ans == "no":
        ans = "0"
    if ans == "Yes" or ans == "yes":
        ans = "1"
    if ans == "None" or ans == "none":
        ans = "0"
    if ", " in ans:
        ans = ans.split(', ')
    if ans[-2:] == ".0":
        ans = ans[:-2]
    if not type(ans) == list:
        ans = [ans]
    new_flag = True
    for i in range(len(ans)):
        if not ans[i] in pred:
            new_flag = False
            break
    return (old_flag or new_flag)


def judge_seeact(pred, ans):
    """Decide whether a predicted SeeAct action matches the reference action.

    The two strings match when they are equal, or when the choice and the
    operation parsed from ``pred`` equal those parsed from ``ans``.

    NOTE(review): the formats are inferred from the parsing code and the
    ``pre_``/``gt_`` variable names; confirm against the callers. ``pred``
    is read as ``"<choice>. <text>, <operation>[, ...]"`` and ``ans`` as
    ``"ELEMENT: <choice>\\nACTION: <operation>\\nVALUE: ..."``.

    Args:
        pred: Predicted action text.
        ans: Reference action text.

    Returns:
        ``True`` on a match; ``False`` otherwise, including when either
        string lacks the markers needed to parse it.
    """
    if pred == ans:
        return True

    # Prediction side: the choice precedes the first ".", the operation is
    # the second ", "-separated field.
    pred_fields = pred.split(", ")
    if len(pred_fields) < 2:
        return False
    pre_choice = pred.split(".")[0]
    pre_operation = pred_fields[1]

    # Reference side: parsed from `ans`, not from `pred`.
    if "ELEMENT: " not in ans or "\nACTION: " not in ans:
        return False
    gt_choice = ans.split("ELEMENT: ")[1]
    gt_choice = gt_choice.split("\nACTION")[0].replace("\n", "")
    gt_operation = ans.split("\nACTION: ")[1]
    # Trim the operation itself; it must not overwrite gt_choice.
    gt_operation = gt_operation.split("\nVALUE")[0].replace("\n", "")

    return pre_choice == gt_choice and pre_operation == gt_operation


def get_input_agentalign(sample):
    """Extract the agent input (the user request) from an AgentAlign sample.

    Args:
        sample: Mapping that may hold ``'messages'`` (a list of message
            dicts), ``'prompt'`` or ``'user_request'``.

    Returns:
        The content of the first message whose role is ``'user'`` (``''`` if
        that message has no ``'content'`` key); otherwise ``sample['prompt']``,
        then ``sample['user_request']``; ``None`` if none of these exist.
    """
    if 'messages' in sample:
        # Take the first message sent with the user role, if any.
        user_turns = (msg for msg in sample['messages'] if msg.get('role') == 'user')
        first_user = next(user_turns, None)
        if first_user is not None:
            return first_user.get('content', '')

    # No user message found: fall back to the flat fields, in priority order.
    for field in ('prompt', 'user_request'):
        if field in sample:
            return sample[field]
    return None


def get_output_agentalign(sample):
    """Extract the agent output (assistant responses) from an AgentAlign sample.

    Args:
        sample: Mapping that may hold ``'messages'`` (a list of message
            dicts), ``'response'`` or ``'assistant_output'``.

    Returns:
        The non-empty contents of all assistant messages, each followed by
        one ``"[Tool Call]: <json>"`` line per tool call, joined with
        newlines. If that yields nothing, ``sample['response']``, then
        ``sample['assistant_output']``; ``None`` if none of these exist.
    """
    if 'messages' in sample:
        output_parts = []
        for msg in sample['messages']:
            if msg.get('role') != 'assistant':
                continue
            content = msg.get('content', '')
            if content:
                output_parts.append(content)
            # Include tool calls as well. The key may be present with a null
            # value, which must be skipped rather than iterated.
            for tool_call in msg.get('tool_calls') or []:
                tool_str = json.dumps(tool_call.get('function', {}), ensure_ascii=False)
                output_parts.append(f"[Tool Call]: {tool_str}")
        if output_parts:
            return '\n'.join(output_parts)

    # No assistant output in the messages: fall back to the flat fields.
    for field in ('response', 'assistant_output'):
        if field in sample:
            return sample[field]
    return None
    
