import os
import os.path as osp
import json
from collections import Counter, defaultdict
import numpy as np
import argparse

def read_truth_vectors(file_path):
    """Load the second half of a truth file as tuples of integers.

    Lines that are empty after stripping whitespace are discarded. Of the
    remaining lines, the first ``len // 2`` are skipped. Every later line is
    converted into a tuple holding one ``int`` per character.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of int tuples, one per retained line.
    """
    with open(file_path, 'r') as handle:
        stripped = (raw.strip() for raw in handle)
        kept = [text for text in stripped if text]
    start = len(kept) // 2
    return [tuple(map(int, text)) for text in kept[start:]]

def _is_trivial_pattern(patt):
    """Return True when *patt* consists only of 0s or only of 1s."""
    return all(bit == 0 for bit in patt) or all(bit == 1 for bit in patt)


def compute_proportions_from_counter(pattern_counter):
    """Summarize how many truth vectors in a Counter are non-trivial.

    A pattern is *trivial* when it is all zeros or all ones. Each pattern is
    judged against its own length, so counters that mix vectors of different
    widths are classified correctly. Previously only the width of the first
    key was considered.

    Args:
        pattern_counter: ``Counter`` mapping pattern tuples to occurrence counts.

    Returns:
        dict with
          - ``"unique_ratio"``: number of distinct non-trivial patterns divided
            by the total number of vectors (duplicates included), rounded to
            4 decimals;
          - ``"nontrivial_ratio"``: number of non-trivial vectors (duplicates
            included) divided by the total number of vectors, rounded to
            4 decimals.
        Both values are 0.0 when the counter holds no vectors.

    Side effect:
        Prints ``'all trivial'`` when every vector is trivial.
    """
    total = sum(pattern_counter.values())
    if total == 0:
        return {"unique_ratio": 0.0, "nontrivial_ratio": 0.0}

    # One count per distinct non-trivial pattern.
    nontrivial_counts = [
        count for patt, count in pattern_counter.items() if not _is_trivial_pattern(patt)
    ]
    nontrivial = sum(nontrivial_counts)
    if nontrivial == 0:
        print('all trivial')

    unique_ratio = len(nontrivial_counts) / total
    full_ratio = nontrivial / total

    return {
        "unique_ratio": round(unique_ratio, 4),
        "nontrivial_ratio": round(full_ratio, 4),
    }


def _summarize_truth_dir(dir_path):
    """Aggregate every ``*.truth`` file directly inside *dir_path*.

    Returns:
        ``(proportions, counter_info)`` where *proportions* comes from
        ``compute_proportions_from_counter``. *counter_info* holds the merged
        ``Counter`` plus file/vector/pattern totals and the highest
        single-pattern count. Returns ``None`` when the directory holds no
        ``.truth`` files.
    """
    # Sorted so that the aggregation order is deterministic across platforms.
    truth_files = sorted(f for f in os.listdir(dir_path) if f.endswith(".truth"))
    if not truth_files:
        return None

    pattern_counter = Counter()
    for fname in truth_files:
        pattern_counter.update(read_truth_vectors(osp.join(dir_path, fname)))

    counter_info = {
        "counter": pattern_counter,
        "total_files": len(truth_files),
        "total_vectors": sum(pattern_counter.values()),
        "unique_patterns": len(pattern_counter),
        "max_repeat": max(pattern_counter.values()) if pattern_counter else 0,
    }
    return compute_proportions_from_counter(pattern_counter), counter_info


def compute_truth_proportions(truth_dir, fallback_noise0=False):
    """Compute per-noise-level proportion statistics for a truth directory.

    Args:
        truth_dir: Directory to scan.
        fallback_noise0: When True, the ``.truth`` files directly inside
            *truth_dir* are treated as a single group named ``"noise_0"``.
            Otherwise each immediate subdirectory is its own group, keyed by
            its name.

    Returns:
        ``(noise_proportions, noise_counters)``: two dicts keyed by group
        name. Groups without any ``.truth`` files are omitted.
    """
    if fallback_noise0:
        targets = [("noise_0", truth_dir)]
    else:
        targets = [
            (name, osp.join(truth_dir, name))
            for name in sorted(os.listdir(truth_dir))
            if osp.isdir(osp.join(truth_dir, name))
        ]

    noise_proportions = {}
    noise_counters = {}
    for noise_name, dir_path in targets:
        summary = _summarize_truth_dir(dir_path)
        if summary is None:
            continue
        noise_proportions[noise_name], noise_counters[noise_name] = summary

    return noise_proportions, noise_counters


def save_proportion_scores0(input_dir, output_json, noise_subdir="noise_0.05", subdir_prefix="and"):
    """Legacy report writer: per-size proportions plus per-noise size-level stats.

    For every directory ``<input_dir>/<size_dir>``, each subdirectory whose
    name starts with *subdir_prefix* and that contains *noise_subdir* is
    evaluated with ``compute_truth_proportions``. The results are written to
    *output_json*.

    Args:
        input_dir: Root directory containing size-level subdirectories.
        output_json: Path of the JSON report to write.
        noise_subdir: Name of the truth directory inside each matching
            subdirectory. Defaults to the previously hard-coded
            ``"noise_0.05"``.
        subdir_prefix: Only subdirectories starting with this prefix are
            scanned. Defaults to ``"and"``.

    Notes:
        ``size_level_stats[<noise>]`` is now computed from that noise's own
        counter. The old code passed a stale loop variable. Likewise,
        ``size_level_proportion`` is now computed from the counter merged
        across all noise levels, instead of whichever noise was iterated last.
    """
    proportion_summary = {}

    for size_dir in sorted(os.listdir(input_dir)):
        size_path = osp.join(input_dir, size_dir)
        if not osp.isdir(size_path):
            continue

        size_entry = {
            "subdirs": {},
            "avg_unique_across_subdirs": 0.0,
            "avg_full_across_subdirs": 0.0,
            "size_level_proportion": {},
            "size_level_stats": {},
        }
        proportion_summary[size_dir] = size_entry

        size_noise_counters = defaultdict(Counter)
        size_noise_stats = defaultdict(lambda: {"total_vectors": 0, "total_files": 0})
        sub_unique = []
        sub_full = []

        for subdir in sorted(os.listdir(size_path)):
            and_path = osp.join(size_path, subdir)
            if not osp.isdir(and_path) or not subdir.startswith(subdir_prefix):
                continue

            tt_dir = osp.join(and_path, noise_subdir)
            if not osp.isdir(tt_dir):
                continue

            # Without per-noise subfolders, the .truth files sit directly in tt_dir.
            has_noise_subdirs = any(osp.isdir(osp.join(tt_dir, d)) for d in os.listdir(tt_dir))
            noise_proportions, noise_counters = compute_truth_proportions(
                tt_dir, fallback_noise0=not has_noise_subdirs
            )
            if not noise_counters:
                continue

            for noise_name, counter_info in noise_counters.items():
                size_noise_counters[noise_name] += counter_info["counter"]
                size_noise_stats[noise_name]["total_vectors"] += counter_info["total_vectors"]
                size_noise_stats[noise_name]["total_files"] += counter_info["total_files"]

            # Unweighted mean across the noise levels found in this subdir.
            avg_unique = np.mean([p["unique_ratio"] for p in noise_proportions.values()])
            avg_full = np.mean([p["nontrivial_ratio"] for p in noise_proportions.values()])

            size_entry["subdirs"][subdir] = {
                "avg_unique": round(avg_unique, 4),
                "avg_full": round(avg_full, 4),
                "noise_proportions": noise_proportions,
            }
            sub_unique.append(avg_unique)
            sub_full.append(avg_full)

        if sub_unique:
            size_entry["avg_unique_across_subdirs"] = round(np.mean(sub_unique), 4)
        if sub_full:
            size_entry["avg_full_across_subdirs"] = round(np.mean(sub_full), 4)

        merged_counter = Counter()
        for noise_name, counter in size_noise_counters.items():
            merged_counter += counter

            total_files = size_noise_stats[noise_name]["total_files"]
            total_vectors = size_noise_stats[noise_name]["total_vectors"]
            # Use this noise level's own counter (previously a stale variable).
            props = compute_proportions_from_counter(counter)
            max_count = max(counter.values()) if counter else 0
            max_repeat_ratio = max_count / total_vectors if total_vectors > 0 else 0

            size_entry["size_level_stats"][noise_name] = {
                "unique_ratio": props["unique_ratio"],
                "nontrivial_ratio": props["nontrivial_ratio"],
                "total_vectors": total_vectors,
                "total_files": total_files,
                "unique_patterns": len(counter),
                "max_repeat_count": max_count,
                "max_repeat_ratio": round(max_repeat_ratio, 4),
                "avg_repeat": round(total_vectors / len(counter), 2) if counter else 0,
            }

        if size_noise_counters:
            # Aggregate across all noise levels of this size directory.
            size_entry["size_level_proportion"] = compute_proportions_from_counter(merged_counter)

    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(proportion_summary, f, indent=2)
    print(f"Saved proportion report to {output_json}")

def _scan_and_subdirs(size_path, noise_subdir, subdir_prefix):
    """Evaluate every ``<subdir_prefix>*/<noise_subdir>`` directory under *size_path*.

    Returns:
        ``(subdir_entries, sub_unique, sub_full, noise_counters, noise_stats)``:
          - *subdir_entries*: the per-subdir report dicts;
          - *sub_unique* / *sub_full*: each subdir's unweighted means across
            its noise levels;
          - *noise_counters*: pattern Counters summed per noise level;
          - *noise_stats*: vector/file totals summed per noise level.
    """
    subdir_entries = {}
    sub_unique = []
    sub_full = []
    noise_counters_acc = defaultdict(Counter)
    noise_stats_acc = defaultdict(lambda: {"total_vectors": 0, "total_files": 0})

    for subdir in sorted(os.listdir(size_path)):
        and_path = osp.join(size_path, subdir)
        if not osp.isdir(and_path) or not subdir.startswith(subdir_prefix):
            continue

        tt_dir = osp.join(and_path, noise_subdir)
        if not osp.isdir(tt_dir):
            continue

        # Without per-noise subfolders, the .truth files sit directly in tt_dir.
        has_noise_subdirs = any(osp.isdir(osp.join(tt_dir, d)) for d in os.listdir(tt_dir))
        noise_proportions, noise_counters = compute_truth_proportions(
            tt_dir, fallback_noise0=not has_noise_subdirs
        )
        if not noise_counters:
            continue

        for noise_name, counter_info in noise_counters.items():
            noise_counters_acc[noise_name] += counter_info["counter"]
            noise_stats_acc[noise_name]["total_vectors"] += counter_info["total_vectors"]
            noise_stats_acc[noise_name]["total_files"] += counter_info["total_files"]

        avg_unique = np.mean([p["unique_ratio"] for p in noise_proportions.values()])
        avg_full = np.mean([p["nontrivial_ratio"] for p in noise_proportions.values()])
        subdir_entries[subdir] = {
            "avg_unique": round(avg_unique, 4),
            "avg_full": round(avg_full, 4),
            "noise_proportions": noise_proportions,
        }
        sub_unique.append(avg_unique)
        sub_full.append(avg_full)

    return subdir_entries, sub_unique, sub_full, noise_counters_acc, noise_stats_acc


def _noise_stats_entry(counter, total_vectors, total_files):
    """Build the ``size_level_stats`` entry for one noise level's merged counter."""
    props = compute_proportions_from_counter(counter)
    max_count = max(counter.values()) if counter else 0
    max_repeat_ratio = max_count / total_vectors if total_vectors > 0 else 0
    return {
        "unique_ratio": props["unique_ratio"],
        "nontrivial_ratio": props["nontrivial_ratio"],
        "total_vectors": total_vectors,
        "total_files": total_files,
        "unique_patterns": len(counter),
        "max_repeat_count": max_count,
        "max_repeat_ratio": round(max_repeat_ratio, 4),
        "avg_repeat": round(total_vectors / len(counter), 2) if counter else 0,
    }


def save_proportion_scores(input_dir, output_json, noise_subdir="noise_0.05", subdir_prefix="and"):
    """Write a JSON report of truth-vector proportions for every size directory.

    For each ``<input_dir>/<size_dir>``, the report contains:
      - ``subdirs``: per-subdir averages and per-noise proportions;
      - ``avg_unique_across_subdirs`` / ``avg_full_across_subdirs``: simple
        means of the per-subdir averages;
      - ``global_unique_ratio`` / ``global_nontrivial_ratio``: proportions
        computed from the counter pooled across all subdirs and noise levels;
      - ``size_level_stats``: per-noise proportions and repetition statistics.

    Args:
        input_dir: Root directory containing size-level subdirectories.
        output_json: Path of the JSON report to write.
        noise_subdir: Name of the truth directory inside each matching
            subdirectory. Defaults to the previously hard-coded
            ``"noise_0.05"``.
        subdir_prefix: Only subdirectories starting with this prefix are
            scanned. Defaults to ``"and"``.
    """
    proportion_summary = {}

    for size_dir in sorted(os.listdir(input_dir)):
        size_path = osp.join(input_dir, size_dir)
        if not osp.isdir(size_path):
            continue

        (subdir_entries, sub_unique, sub_full,
         size_noise_counters, size_noise_stats) = _scan_and_subdirs(
            size_path, noise_subdir, subdir_prefix
        )

        entry = {
            "subdirs": subdir_entries,
            "avg_unique_across_subdirs": round(np.mean(sub_unique), 4) if sub_unique else 0.0,  # simple avg
            "avg_full_across_subdirs": round(np.mean(sub_full), 4) if sub_full else 0.0,
            "global_unique_ratio": 0.0,  # pooled across subdirs and noises
            "global_nontrivial_ratio": 0.0,
            "size_level_stats": {},
        }

        merged_counter = Counter()
        for counter in size_noise_counters.values():
            merged_counter += counter
        if merged_counter:
            global_props = compute_proportions_from_counter(merged_counter)
            entry["global_unique_ratio"] = global_props["unique_ratio"]
            entry["global_nontrivial_ratio"] = global_props["nontrivial_ratio"]

        for noise_name, counter in size_noise_counters.items():
            stats = size_noise_stats[noise_name]
            entry["size_level_stats"][noise_name] = _noise_stats_entry(
                counter, stats["total_vectors"], stats["total_files"]
            )

        proportion_summary[size_dir] = entry

    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(proportion_summary, f, indent=2)
    print(f"Saved proportion report to {output_json}")


def main():
    """Parse command-line options and write the truth proportion report."""
    parser = argparse.ArgumentParser(description='Compute truth table proportion metrics.')
    # osp.join avoids the invalid "\A" escape and the Windows-only separator
    # of the previous "..\AN" literal.
    parser.add_argument('--input_dir', default=osp.join("..", "AN"), help='Root directory containing size-level subdirectories.')
    parser.add_argument('--output_json', default='truth_proportion_summary_AN_noise005.json', help='Path to save output JSON with proportion statistics.')
    args = parser.parse_args()

    save_proportion_scores(args.input_dir, args.output_json)


if __name__ == '__main__':
    main()


