import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import argparse

# ===== Parse command-line arguments =====
# --input  : JSONL file to analyze (mandatory).
# --output : prefix of the saved histogram picture.
parser = argparse.ArgumentParser(
    description="Analyze node counts in JSONL samples.",
)
parser.add_argument(
    "--input",
    required=True,
    type=str,
    help="Path to input JSONL file",
)
parser.add_argument(
    "--output",
    default="node_distribution",
    type=str,
    help="Output picture name prefix (default: node_distribution)",
)
args = parser.parse_args()

# ===== Load data =====
# Prefix for the saved figure; the histogram is written to "<prefix>.png".
pic_name = args.output
# lines=True: the input holds one JSON object per line, one row per sample.
data = pd.read_json(args.input, lines=True)

# ===== Basic variable statistics =====
# Scan every sample once: count target labels and track the smallest and
# largest node-list lengths together with every sample ID that attains them.
min_node = float('inf')  # stays inf if every sample is ignored (or data is empty)
min_idx = []             # IDs of samples whose node count equals min_node
max_node = 0
max_idx = []             # IDs of samples whose node count equals max_node
target_0 = 0             # samples with target == 0
target_1 = 0             # samples with any other target value

# IDs excluded from the min/max search (their targets are still counted).
# A set gives O(1) membership tests; add IDs as needed, e.g. {12, 345}.
ignore_idx = set()

# zip() walks the columns positionally, so this does not rely on the
# DataFrame having a default 0..n-1 index the way data['col'][i] would.
for raw_id, target, nodes in zip(data['id'], data['target'], data['nodes']):
    if int(target) == 0:
        target_0 += 1
    else:
        target_1 += 1

    sample_id = int(raw_id)
    if sample_id in ignore_idx:
        continue

    node_len = len(nodes)

    # A strictly new extreme resets the ID list; a tie extends it.
    if node_len < min_node:
        min_node = node_len
        min_idx = [sample_id]
    elif node_len == min_node:
        min_idx.append(sample_id)

    if node_len > max_node:
        max_node = node_len
        max_idx = [sample_id]
    elif node_len == max_node:
        max_idx.append(sample_id)

# ===== print =====
# Dataset-level summary: size, label balance, and the node-count extremes
# together with the IDs of every sample that reaches each extreme.
print(
    f"data length: {len(data)}",
    f"target = 0: {target_0} | target = 1: {target_1}",
    f"min_node: {min_node} | min_idx: {min_idx} (count: {len(min_idx)})",
    f"max_node: {max_node} | max_idx: {max_idx} (count: {len(max_idx)})",
    sep="\n",
)

# ===== Count the number of nodes =====
# Node-list length of every sample; unlike the min/max scan above, no IDs
# are filtered out here.
node_counts = list(map(len, data["nodes"]))

print("\nNode count statistics:")
# np.std defaults to ddof=0, i.e. the population standard deviation.
print(
    f"Maximum node count: {max(node_counts)}",
    f"Average node count: {np.mean(node_counts):.2f}",
    f"Median node count: {np.median(node_counts)}",
    f"Standard deviation: {np.std(node_counts):.2f}",
    sep="\n",
)

# ===== Plot a histogram =====
# Histogram of per-sample node counts in fixed-width groups. Each non-empty
# bar is annotated with its sample count, and the figure is saved as
# "<pic_name>.png".
bin_width = 100  # group size; also drives the title and the label positions

# Edges are 0, w, 2w, ... up to the first multiple of w strictly ABOVE the
# maximum. numpy/matplotlib close only the last bin on the right, so an edge
# equal to the maximum (as range(0, max + w, w) produces when max is a
# multiple of w) would fold e.g. count 200 into the 100-199 bar. An all-zero
# dataset would also get a single edge, i.e. zero bins.
# default=0 keeps an empty dataset from raising here.
max_count = max(node_counts, default=0)
bins = list(range(0, (max_count // bin_width + 1) * bin_width + 1, bin_width))

plt.figure(figsize=(12, 6))
counts, edges, _ = plt.hist(node_counts, bins=bins, edgecolor='k', alpha=0.7)

plt.title(f'Node Count Distribution ({bin_width} per group)', fontsize=14)
plt.xlabel('Node Count Interval', fontsize=12)
plt.ylabel('Sample Count', fontsize=12)
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Write the count just above the centre of each non-empty bar.
# zip stops at len(counts), one fewer than the number of edges.
for count, left_edge in zip(counts, edges):
    if count > 0:
        plt.text(left_edge + bin_width / 2, count + 0.5, str(int(count)),
                 ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig(f'{pic_name}.png', dpi=300)
# plt.show()  # uncomment to also display the figure interactively
