import os
import os.path as osp
import json
from collections import Counter, defaultdict
import argparse

def read_truth_vectors(file_path):
    """Load the second half of a truth-table file as tuples of ints.

    Blank (whitespace-only) lines are ignored and every remaining line is
    stripped. Of those N lines, the last ``N - N // 2`` are returned, each
    converted character-by-character with ``int`` into a tuple.
    """
    rows = []
    with open(file_path, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                rows.append(stripped)
    # Skip the first half; only the trailing half of the lines is used.
    start = len(rows) // 2
    return [tuple(map(int, row)) for row in rows[start:]]

def compute_proportions_from_counter(pattern_counter):
    """Summarise how much of a pattern histogram is non-trivial.

    A pattern is *trivial* when it is all zeros or all ones. Each pattern is
    classified against its own width, so counters that mix vectors of
    different lengths are handled correctly.

    Args:
        pattern_counter: mapping (e.g. ``Counter``) from bit-tuple pattern to
            its occurrence count.

    Returns:
        dict with two entries, each rounded to 4 decimal places:
            ``"unique_ratio"``: number of distinct non-trivial patterns divided
                by the total number of vectors (duplicates included).
            ``"nontrivial_ratio"``: number of non-trivial vectors (duplicates
                included) divided by the total number of vectors.
        Both are 0.0 when the counts sum to zero (e.g. an empty counter).
    """
    total = sum(pattern_counter.values())
    if total == 0:
        return {"unique_ratio": 0.0, "nontrivial_ratio": 0.0}

    def _is_trivial(patt):
        # Constant vector of this pattern's own length: all-zeros or all-ones.
        return all(bit == 0 for bit in patt) or all(bit == 1 for bit in patt)

    nontrivial_counts = [
        count for patt, count in pattern_counter.items() if not _is_trivial(patt)
    ]
    nontrivial = sum(nontrivial_counts)
    if nontrivial == 0:
        print('all trivial')

    # NOTE(review): this divides a count of *distinct* patterns by a count of
    # *vectors*; confirm whether len(pattern_counter) was the intended denominator.
    unique_ratio = len(nontrivial_counts) / total
    full_ratio = nontrivial / total

    return {
        "unique_ratio": round(unique_ratio, 4),
        "nontrivial_ratio": round(full_ratio, 4),
    }

def _count_truth_patterns(tt_dir):
    """Return a Counter of every vector read from the ``.truth`` files in *tt_dir*.

    Only entries directly inside *tt_dir* whose names end in ``.truth`` are
    read. Each one contributes the vectors returned by ``read_truth_vectors``.
    """
    counter = Counter()
    for fname in sorted(os.listdir(tt_dir)):
        if fname.endswith(".truth"):
            counter.update(read_truth_vectors(osp.join(tt_dir, fname)))
    return counter


def save_comparison_scores(an_dir, ano_dir, output_json):
    """Compare AN and ANO truth-vector statistics and write them to JSON.

    Walks ``<an_dir>/<size_dir>/<subdir>/tt/*.truth`` and the matching path
    under *ano_dir*. A ``(size_dir, subdir)`` pair is reported only when both
    ``tt`` directories exist and both yield at least one vector. Directory
    entries are visited in sorted order so the report is reproducible, since
    ``os.listdir`` order is arbitrary.

    Args:
        an_dir: root directory of the AN data.
        ano_dir: root directory of the ANO data.
        output_json: path of the JSON report to write (overwritten).

    The report maps ``size_dir -> subdir -> metrics``. Every ``size_dir``
    that is a directory under both roots gets an entry, which may be empty.
    """
    comparison_summary = {}

    # size_dir (in5, in6...)
    for size_dir in sorted(os.listdir(an_dir)):
        size_path_an = osp.join(an_dir, size_dir)
        size_path_ano = osp.join(ano_dir, size_dir)
        if not osp.isdir(size_path_an) or not osp.isdir(size_path_ano):
            continue

        size_summary = {}
        comparison_summary[size_dir] = size_summary

        # and* subdir
        for subdir in sorted(os.listdir(size_path_an)):
            and_path_an = osp.join(size_path_an, subdir)
            and_path_ano = osp.join(size_path_ano, subdir)
            if not osp.isdir(and_path_an) or not osp.isdir(and_path_ano):
                continue

            tt_dir_an = osp.join(and_path_an, "tt")
            tt_dir_ano = osp.join(and_path_ano, "tt")
            if not osp.isdir(tt_dir_an) or not osp.isdir(tt_dir_ano):
                continue

            counter_an = _count_truth_patterns(tt_dir_an)
            counter_ano = _count_truth_patterns(tt_dir_ano)
            if not counter_an or not counter_ano:
                continue

            an_patterns = len(counter_an)
            ano_patterns = len(counter_ano)
            an_vectors = sum(counter_an.values())
            ano_vectors = sum(counter_ano.values())

            # Counter addition sums the counts of shared patterns, so the
            # merged counter's length is the number of distinct patterns in the union.
            merged_counter = counter_an + counter_ano
            merged_patterns = len(merged_counter)
            merged_vectors = sum(merged_counter.values())

            # Below 1.0 when AN and ANO share patterns; exactly 1.0 when disjoint.
            pattern_repeat_ratio = round(merged_patterns / (an_patterns + ano_patterns), 4)
            # NOTE(review): all counts are positive and Counter addition sums
            # them, so merged_vectors == an_vectors + ano_vectors and this ratio
            # is always 1.0. Kept for report compatibility; confirm intent.
            vector_repeat_ratio = round(merged_vectors / (an_vectors + ano_vectors), 4)

            props = compute_proportions_from_counter(merged_counter)

            size_summary[subdir] = {
                "pattern_repeat_ratio": pattern_repeat_ratio,
                "vector_repeat_ratio": vector_repeat_ratio,
                "nontrivial_ratio": props["nontrivial_ratio"],
                "AN_patterns": an_patterns,
                "ANO_patterns": ano_patterns,
                "merged_patterns": merged_patterns,
                "AN_vectors": an_vectors,
                "ANO_vectors": ano_vectors,
                "merged_vectors": merged_vectors,
            }

    with open(output_json, 'w') as f:
        json.dump(comparison_summary, f, indent=2)
    print(f"Saved comparison report to {output_json}")

def main():
    """Command-line entry point: parse the AN/ANO paths and write the comparison report."""
    cli = argparse.ArgumentParser(description='Compare AN and ANO truth table proportions.')
    # (flag, default, help) for each option, registered in the original order.
    option_specs = (
        ('--an_dir', "../AN", 'Directory for AN data.'),
        ('--ano_dir', "../ANO", 'Directory for ANO data.'),
        ('--output_json', 'truth_comparison_AN_ANO.json', 'Output JSON path.'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)
    parsed = cli.parse_args()

    save_comparison_scores(parsed.an_dir, parsed.ano_dir, parsed.output_json)
