# -*- coding: utf-8 -*-
"""
Compute dataset-averaged cumulative probabilities P(X ≤ r) for the first n_i bits,
where n_i = 8 * len(target_i), N = 598 total positions, k = 0..20 total errors,
and r = 0..5. Input file: testset.csv (must have a 'target' column).
"""

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# --------------------
# Config
# --------------------
# Input CSV; a relative path resolves against the current working directory.
CSV_PATH = "dataset_segno/ver3/data_domain_ver3_mask0_L_generalization/missspelled/testset.csv"
N = 598  # total number of bit positions
K_MAX = 20  # largest total number of flips evaluated (k = 0..K_MAX)
R_MAX = 5  # largest threshold evaluated in P(X <= r) (r = 0..R_MAX)
# Derived from the parameters so the filename can never disagree with them.
OUT_CSV = f"avg_cumprob_dataset_N{N}_r0to{R_MAX}_k0to{K_MAX}.csv"
OUT_NDIST = "nbits_distribution.csv"

# --------------------
# Helpers
# --------------------
def comb(n, r):
    """Return the binomial coefficient C(n, r), taken to be 0 unless 0 <= r <= n."""
    if r < 0 or r > n:
        return 0
    return math.comb(n, r)

def hypergeom_pmf(j, k, n, N=598):
    """Return P(X = j) for X ~ Hypergeom(N, n, k).

    N is the population size, n the number of marked positions, and k the
    number of draws:  P(X = j) = C(n, j) * C(N - n, k - j) / C(N, k).
    Values of j outside the support give 0.0.
    """
    outside_support = j < 0 or j > min(n, k) or k - j > N - n
    if outside_support:
        return 0.0
    # Past the support check, 0 <= j <= n, 0 <= k - j <= N - n and 0 <= k <= N,
    # so every coefficient below is a proper (non-zero) binomial.
    numerator = math.comb(n, j) * math.comb(N - n, k - j)
    return numerator / math.comb(N, k)

def hypergeom_cdf_le_r(n, k, r, N=598):
    """Return P(X <= r) for X ~ Hypergeom(N, n, k).

    X can exceed neither k nor n, so the PMF is summed only up to
    min(r, k, n); a negative r therefore yields 0.
    """
    upper = min(r, k, n)
    terms = (hypergeom_pmf(j, k, n, N) for j in range(upper + 1))
    return sum(terms)

# --------------------
# Load
# --------------------
if not Path(CSV_PATH).exists():
    raise FileNotFoundError(f"File not found: {CSV_PATH}")

# Read every column as text with NA detection disabled. With pandas' defaults,
# targets such as "", "NA", "null" or "None" become NaN (which astype(str) renders
# as the 3-character "nan"), and an all-numeric column is parsed as numbers
# ("007" -> 7), both of which corrupt the lengths computed below.
df = pd.read_csv(CSV_PATH, dtype=str, keep_default_na=False)
if "target" not in df.columns:
    # Message: "The CSV must have a 'target' column."
    raise ValueError("CSVに 'target' 列が必要です。")
if df.empty:
    # Otherwise every weight below is undefined and the outputs would be all zeros.
    raise ValueError(f"No rows found in {CSV_PATH}")

# n_i = 8 * len(target_i), capped at N.
# NOTE(review): len() counts characters, not encoded bytes; for non-ASCII targets
# the UTF-8 byte length is larger — confirm characters are the intended unit.
n_bits = np.minimum(df["target"].map(len).to_numpy() * 8, N).astype(int)

# Distinct bit counts and how many samples share each; weighting by frequency
# makes the later weighted sum equal to the per-sample mean.
unique_n, counts = np.unique(n_bits, return_counts=True)
weights = counts / counts.sum()

# --------------------
# Probability tables and dataset average
# --------------------
k_values = list(range(K_MAX + 1))
r_values = list(range(R_MAX + 1))

# For each distinct bit count n: a table whose entry [ri, ki] is
# P(X <= r_values[ri]) after k_values[ki] flips.
per_n_probs = {
    n: np.array(
        [[hypergeom_cdf_le_r(n=n, k=k, r=r, N=N) for k in k_values] for r in r_values],
        dtype=float,
    )
    for n in unique_n
}

# Frequency-weighted average, accumulated in ascending-n order onto a zero table.
avg_probs = sum(
    (per_n_probs[n] * w for n, w in zip(unique_n, weights)),
    np.zeros((len(r_values), len(k_values))),
)

# One row per k; one column per threshold r.
out = {
    "k": k_values,
    **{f"P_mean(X≤{r})": row.tolist() for r, row in zip(r_values, avg_probs)},
}
df_out = pd.DataFrame(out)
df_out.to_csv(OUT_CSV, index=False)

# Histogram of the (capped) bit counts.
n_dist = pd.DataFrame({"n_bits": unique_n, "count": counts})
n_dist.sort_values("n_bits").to_csv(OUT_NDIST, index=False)

# --------------------
# Plot
# --------------------
# The figure shares OUT_CSV's base name so the PNG and its data stay paired.
out_png = str(Path(OUT_CSV).with_suffix(".png"))

fig = plt.figure()
for ri, r in enumerate(r_values):
    plt.plot(k_values, avg_probs[ri, :], marker="o", label=f"X≤{r}")
plt.ylim(0, 1)
plt.xlabel("number of flip")
plt.ylabel("Average P(X ≤ r)")
# Interpolate N instead of hard-coding it so the title tracks the config.
plt.title(f"Average P(X ≤ r) vs number of flip over dataset\n(N={N}, first n_i = 8*len(target))")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(out_png)
plt.close(fig)  # release the figure once it is on disk

print(f"Saved: {OUT_CSV}")
print(f"Saved: {OUT_NDIST}")
print(f"Saved: {out_png}")
