import json
from collections import Counter
import sys
import argparse


def load_json(path):
    """Read the UTF-8 text file at *path* and return its decoded JSON content."""
    with open(path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def get_verdict(obj, key):
    """Return the ``'verdict'`` value found under ``obj[key]``, or ``None``.

    ``None`` is returned when *key* is absent, when the value stored under
    *key* is not a dict, or when that dict has no ``'verdict'`` entry.
    """
    candidate = obj.get(key, {})
    # Not every section is a nested mapping; some hold only reasoning,
    # in which case there is no verdict to report.
    if not isinstance(candidate, dict):
        return None
    return candidate.get('verdict')


def parse_joint_conditions(joint_args):
    """Parse ``key=verdict`` command-line tokens into a conditions mapping.

    Args:
        joint_args: Iterable of strings, each expected in ``key=verdict`` form.

    Returns:
        A dict mapping each key (whitespace-stripped) to its verdict
        (whitespace-stripped and lower-cased). When a key is repeated, the
        last occurrence wins.

    Raises:
        SystemExit: With exit status 1 when a token contains no ``=``. The
            diagnostic is written to stderr so it does not mix with the
            regular output on stdout.
    """
    conditions = {}
    for cond in joint_args:
        # partition() splits on the first '=' only, so the verdict part may
        # itself contain '=' characters; an empty separator means no '='.
        key, separator, value = cond.partition('=')
        if not separator:
            print(
                f"Invalid joint condition: {cond}. Use key=verdict format.",
                file=sys.stderr,
            )
            sys.exit(1)
        conditions[key.strip()] = value.strip().lower()
    return conditions


def count_joint_matches(data, conditions):
    """Count the entries in *data* that satisfy every condition.

    *conditions* maps a section key to the expected verdict. An entry
    satisfies a condition when ``get_verdict(entry, key)`` is not ``None``
    and its lower-cased form equals the expected verdict. With no
    conditions, every entry counts as a match.
    """
    def _meets_all(entry):
        # Stop at the first failing condition.
        for key, expected in conditions.items():
            found = get_verdict(entry, key)
            if found is None or found.lower() != expected:
                return False
        return True

    return sum(1 for entry in data if _meets_all(entry))


def get_joint_pids(data, conditions):
    """Return the ``'pid'`` of each entry in *data* that satisfies every condition.

    *conditions* maps a section key to the expected verdict; an entry
    qualifies when, for each key, ``get_verdict(entry, key)`` is not ``None``
    and its lower-cased form equals the expected verdict. Entries without a
    ``'pid'`` key contribute ``None``. Input order is preserved.
    """
    matching = []
    for record in data:
        for key, expected in conditions.items():
            actual = get_verdict(record, key)
            if actual is None or actual.lower() != expected:
                break  # one failed condition disqualifies the record
        else:
            # Reached only when no condition failed (including when there
            # are no conditions at all).
            matching.append(record.get('pid'))
    return matching


def main():
    """Command-line entry point: print verdict statistics for a JSON file.

    Parses the command line, loads the JSON document at ``json_path`` and
    iterates it as a sequence of entries (each queried with ``.get``). Then:

    * with ``--category``, keeps only entries whose ``'malicious_category'``
      equals the given value and prints how many remain;
    * unless ``--no-stats`` is given, prints per-key counts of the
      lower-cased verdicts for a fixed list of verdict keys;
    * with ``--joint``, prints how many entries satisfy all ``key=verdict``
      conditions, and with ``--print-pids`` also prints their pids.

    ``--print-pids`` has no effect unless ``--joint`` is also given.
    """
    parser = argparse.ArgumentParser(description="Show statistics for judge evaluation JSON file.")
    parser.add_argument('json_path', help='Path to the JSON file')
    parser.add_argument('--joint', nargs='+', help='Joint conditions, e.g. successful_attack=bad runtime_error_free=yes')
    parser.add_argument('--no-stats', action='store_true', help='Do not print all statistics, only joint results if specified')
    parser.add_argument('--print-pids', action='store_true', help='Print all pids that meet the joint condition(s)')
    parser.add_argument('--category', type=str, help='Filter stats by malicious_category (e.g., Viruses, Worms)')
    args = parser.parse_args()

    data = load_json(args.json_path)
    if args.category:
        data = [entry for entry in data if entry.get('malicious_category') == args.category]
        print("-" * 40)
        print(f"Number of samples in {args.category}: {len(data)}")
        print("-" * 40)

    # Suffix shared by the statistics header and the joint summary lines.
    cat_str = f" (category: {args.category})" if args.category else ""

    if not args.no_stats:
        verdict_keys = [
            'successful_attack',
            'refusal',
            'parseability',
            'syntax_error_free',
            'runtime_error_free',
        ]
        # Tally only when the table is going to be printed; with --no-stats
        # the counts would be discarded.
        stats = {k: Counter() for k in verdict_keys}
        for entry in data:
            for key in verdict_keys:
                verdict = get_verdict(entry, key)
                if verdict is not None:
                    stats[key][verdict.lower()] += 1

        print(f"Statistics for {args.json_path}{cat_str}:")
        for key in verdict_keys:
            print(f"\n[{key}]")
            for verdict, count in stats[key].items():
                print(f"  {verdict}: {count}")

    if args.joint:
        conditions = parse_joint_conditions(args.joint)
        joint_count = count_joint_matches(data, conditions)
        cond_str = ', '.join(f'{k}={v}' for k, v in conditions.items())
        print(f"\nNumber of cases where {cond_str}{cat_str}: {joint_count}")
        if args.print_pids:
            pids = get_joint_pids(data, conditions)
            print(f"PIDs matching {cond_str}{cat_str}: {pids if pids else 'None'}")
        

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
