import csv
from collections import defaultdict
from itertools import groupby

# CSV of labelled tool runs to evaluate. Despite its name this file is only
# read by this script. Another dataset name seen here: 'composio_labels.csv'.
output_file = 'filtered_langchain.csv'

# Names of the three detectors being scored:
#   'llm'         - fires on a row whose LLMScore is <= LLM_SCORE_THRESHOLD
#   'consistency' - fires on a row whose IOConsist column equals 'TRUE'
#   'tf'          - fires on a row where both of the above fire together
DETECTORS = ('tf', 'llm', 'consistency')

LLM_SCORE_THRESHOLD = 5
NO_TOOL_ARGS = 'No tool arguments found'
# The column name is misspelled in the CSV header itself; it must match exactly.
ERROR_COLUMN = 'Unique errro'
SEPARATOR = '============================================='


def evaluate_rows(rows):
    """Score the three detectors against the 'Is Bug' labels.

    Rows are grouped into runs of *consecutive* rows sharing the same 'id'
    (the input is not sorted, so a repeated id that is not adjacent starts a
    new run). A run counts as a real bug if any of its rows has
    ``'Is Bug' == 'y'``, and a detector counts as having fired for the run if
    it fired on any of its rows. Each run then contributes exactly one
    TP/FP/FN/TN to every detector.

    Args:
        rows: iterable of mappings (e.g. a ``csv.DictReader``) providing the
            keys 'id', 'tool', 'ToolArgs', 'Is Bug', 'LLMScore', 'IOConsist'
            and 'Unique errro'.

    Returns:
        A dict with:
          * 'tf', 'llm', 'consistency': ``{'tp', 'fp', 'fn', 'tn'}`` counts;
          * 'tool_invokations': number of rows whose 'ToolArgs' is not the
            "no arguments" placeholder;
          * 'tools': set of tool names taken from the first row of each run;
          * 'errors': per-detector set of ``'<tool>;<unique error>'`` strings
            collected from rows on which that detector fired.

    Raises:
        KeyError: if a row lacks one of the required columns.
        ValueError: if 'LLMScore' is not an integer literal.
    """
    tools = set()
    tool_invokations = 0
    errors = {name: set() for name in DETECTORS}
    confusion = {
        name: {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0} for name in DETECTORS
    }

    # groupby hands over every row of a run exactly once. The previous
    # hand-rolled nested loop read one row past the end of each run and then
    # dropped it, silently losing the first row of every following run.
    for _, run in groupby(rows, key=lambda row: row['id']):
        is_bug = False
        flagged = dict.fromkeys(DETECTORS, False)

        for position, row in enumerate(run):
            if position == 0:
                # Only the first row of a run contributes to the tool set.
                tools.add(row['tool'])
            if row['ToolArgs'] != NO_TOOL_ARGS:
                tool_invokations += 1
            if row['Is Bug'] == 'y':
                is_bug = True

            low_score = int(row['LLMScore']) <= LLM_SCORE_THRESHOLD
            # NOTE(review): 'TRUE' is treated as a detected problem, which
            # reads inverted relative to the column name -- confirm against
            # the labelling scheme.
            io_flag = row['IOConsist'] == 'TRUE'
            fired = {
                'llm': low_score,
                'consistency': io_flag,
                'tf': low_score and io_flag,
            }

            for name in DETECTORS:
                if not fired[name]:
                    continue
                flagged[name] = True
                if row[ERROR_COLUMN] != '':
                    errors[name].add(row['tool'] + ';' + row[ERROR_COLUMN])

        for name in DETECTORS:
            if is_bug:
                outcome = 'tp' if flagged[name] else 'fn'
            else:
                outcome = 'fp' if flagged[name] else 'tn'
            confusion[name][outcome] += 1

    return {
        'tf': confusion['tf'],
        'llm': confusion['llm'],
        'consistency': confusion['consistency'],
        'tool_invokations': tool_invokations,
        'tools': tools,
        'errors': errors,
    }


def print_report(result, show_errors=False):
    """Print the summary produced by ``evaluate_rows``.

    Args:
        result: the dict returned by ``evaluate_rows``.
        show_errors: when true, additionally list the distinct
            ``'<tool>;<unique error>'`` strings gathered for each detector,
            sorted alphabetically. Off by default.
    """
    tf = result['tf']
    llm = result['llm']
    consist = result['consistency']

    print(f"TF TP: {tf['tp']}, TF FP: {tf['fp']}, TF FN: {tf['fn']}, TF TN: {tf['tn']}")
    print(f"LLM TP: {llm['tp']}, LLM FP: {llm['fp']}, LLM FN: {llm['fn']}, LLM TN: {llm['tn']}")
    print(f"Consistency TP: {consist['tp']}, Consistency FP: {consist['fp']}, Consistency FN: {consist['fn']}, Consistency TN: {consist['tn']}")
    print(f"Tool invokations: {result['tool_invokations']}, {len(result['tools'])} tools")

    if not show_errors:
        return

    labels = (('tf', 'TF'), ('llm', 'LLM'), ('consistency', 'Consistency'))
    for name, label in labels:
        found = sorted(result['errors'][name])
        print(f"{label} Errors: {len(found)}")
        for entry in found:
            print(entry)
        print(SEPARATOR)


def main():
    """Evaluate the CSV named by ``output_file`` and print the summary."""
    with open(output_file, newline='') as csvfile:
        result = evaluate_rows(csv.DictReader(csvfile))
    print_report(result)


if __name__ == '__main__':
    main()