import csv
from collections import Counter, defaultdict
from itertools import groupby

# input_file = 'composio_labels.csv'
output_file = 'filtered_composio_labels.csv'

# Dictionary to keep count of each id
id_counts = defaultdict(int)
filtered_rows = []

tool_invokations = 0

tools = set()

errors_tf = list()
errors_llm = list()
errors_consistency = list()

with open(output_file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    header = reader.fieldnames  # Save header for writing out later
    
    iterator = iter(reader)


    tf_tp, tf_fp, tf_fn, tf_tn = 0, 0, 0, 0
    llm_tp, llm_fp, llm_fn, llm_tn = 0, 0, 0, 0
    consist_tp, consist_fp, consist_fn, consist_tn = 0, 0, 0, 0

    while True:
        item = next(iterator, None)
        if item is None:
            break
        tools.add(item['tool'])
        id_value = item['id']

        is_bug, is_llm_bug, is_consisteny_bug, is_consisteny_bug, is_tf_bug = False, False, False, False, False
        while id_value == item['id']:
            # Now we do the labeling :@
            if item['ToolArgs'] != 'None':
                tool_invokations += 1
            if int(item['LLMScore']) <= 5 and item["just llm"] == "tp":
                is_bug = True
            elif int(item['LLMScore']) > 5 and item["just llm"] == "fn":
                is_bug = True
            if int(item['LLMScore']) <= 5:
                if item['unque error'] != '':
                    errors_llm.append(item['tool'] + ';'+ item['unque error'])
                is_llm_bug = True
            if item['IOConsist'] == 'TRUE':
                if item['unque error'] != '':
                    errors_consistency.append(item['tool'] + ';'+ item['unque error'])
                is_consisteny_bug = True
            if item['IOConsist'] == 'TRUE' and int(item['LLMScore']) <= 5:
                if item['unque error'] != '':
                    errors_tf.append(item['tool'] + ';'+ item['unque error'])
                is_tf_bug = True
            item = next(iterator, None)
            if item is None:
                break

        if is_bug:
            if is_llm_bug:
                llm_tp += 1
            else:
                llm_fn += 1
            if is_consisteny_bug:
                consist_tp += 1
            else:
                consist_fn += 1
            if is_tf_bug:
                tf_tp += 1
            else:
                tf_fn += 1
        else:
            if is_llm_bug:
                llm_fp += 1
            else:
                llm_tn += 1
            if is_consisteny_bug:
                consist_fp += 1
            else:
                consist_tn += 1
            if is_tf_bug:
                tf_fp += 1
            else:
                tf_tn += 1

    print(f"TF TP: {tf_tp}, TF FP: {tf_fp}, TF FN: {tf_fn}, TF TN: {tf_tn}")
    print(f"LLM TP: {llm_tp}, LLM FP: {llm_fp}, LLM FN: {llm_fn}, LLM TN: {llm_tn}")
    print(f"Consistency TP: {consist_tp}, Consistency FP: {consist_fp}, Consistency FN: {consist_fn}, Consistency TN: {consist_tn}")
    print(f"Tool invokations: {tool_invokations}, {len(tools)} tools")

    tf_c = Counter(errors_tf)    
    count_of_ones = sum(1 for count in tf_c.values() if count == 1)
    count_of_more_than_ones = sum(1 for count in tf_c.values() if count == 2)
    total_error_hits = sum(tf_c.values())
    print(f"TF Total Errors: {total_error_hits}")
    print(f"TF Errors: {len(tf_c)}, {count_of_ones}, {count_of_more_than_ones}")
    llm_c = Counter(errors_llm)
    count_of_ones = sum(1 for count in llm_c.values() if count == 1)
    count_of_more_than_ones = sum(1 for count in llm_c.values() if count == 2)
    total_error_hits = sum(llm_c.values())
    print(f"LLM Total Errors: {total_error_hits}")
    print(f"LLM Errors: {len(llm_c)}, {count_of_ones}, {count_of_more_than_ones}")
    consist_c = Counter(errors_consistency)
    count_of_ones = sum(1 for count in consist_c.values() if count == 1)
    count_of_more_than_ones = sum(1 for count in consist_c.values() if count == 2)
    total_error_hits = sum(consist_c.values())
    print(f"Consistency Errors: {len(consist_c)}, {count_of_ones}, {count_of_more_than_ones}")
    print(f"Consistency Total Errors: {total_error_hits}")