import math
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
from pyzbar.pyzbar import decode
import numpy as np


# set meta information
# Active parameter set. The commented-out presets further down in this section are
# alternatives; only one set should be active at a time.
version = '3-L'  # label shown in the plot title
N_total = 598  # number of bits a flip can land on; equals N_data + N_format + N_remainder
N_data = 560  # bits in the data region; equals 8 * M_codewords
N_format = 30  # bits in the format region
N_remainder = 8  # bits in neither region; p_success applies no success factor to flips here
M_codewords = 70  # number of 8-bit codewords making up the data region
M_ecc_codewords = 15  # p_data tolerates up to floor(M_ecc_codewords / 2) damaged codewords
data_path = "dataset_segno/ver3/data_domain_ver3_mask0_L/damaged/dataset_damaged.csv"  # CSV read by the __main__ section

# version = '3-M'
# N_total = 598
# N_data = 560
# N_format = 30
# N_remainder = 8
# M_codewords = 70
# M_ecc_codewords = 26
# data_path = "dataset_segno/ver3/data_domain_ver3_mask0_M/damaged/dataset_damaged.csv"

# version = '2-L'
# N_total = 390
# N_data = 352
# N_format = 30
# N_remainder = 8
# M_codewords = 44
# M_ecc_codewords = 10
# data_path = "dataset_segno/ver2/data_domain_ver2_mask0_L/damaged/dataset_damaged.csv"

# version = '2-M'
# N_total = 390
# N_data = 352
# N_format = 30
# N_remainder = 8
# M_codewords = 44
# M_ecc_codewords = 16
# data_path = "dataset_segno/ver2/data_domain_ver2_mask0_M/damaged/dataset_damaged.csv"


def p_data(n):
    """Return the fraction of n-bit flip patterns in the data region that stay correctable.

    A pattern is counted as correctable when its flipped bits touch at most
    ``floor(M_ecc_codewords / 2)`` of the ``M_codewords`` 8-bit codewords.

    Args:
        n: Number of flipped bits among the ``N_data`` data-region bits.

    Returns:
        The number of correctable patterns divided by ``math.comb(N_data, n)``.
    """
    denominator = math.comb(N_data, n)
    max_damaged = M_ecc_codewords // 2
    # At least ceil(n / 8) codewords are needed to hold n flipped bits.
    fewest_touched = -(-n // 8)

    correctable = 0
    for touched in range(fewest_touched, max_damaged + 1):
        # Inclusion-exclusion over the codewords left untouched: counts the ways
        # to place n flips so that each of `touched` fixed codewords gets >= 1.
        placements = sum(
            (-1) ** spared
            * math.comb(touched, spared)
            * math.comb(8 * (touched - spared), n)
            for spared in range(touched + 1)
        )
        correctable += math.comb(M_codewords, touched) * placements
    return correctable / denominator

# def p_format(n):
#     success_patterns = 0
#     for i in range(0, n+1):
#         if i <= 3 or (n-i) <= 3:
#             success_patterns += 1
#     return success_patterns / (n+1)

def p_format(n):
    """Return the probability that the format information survives n flipped format bits.

    Model: the format region holds two copies of 15 bits each, and the format is
    readable when at least one copy contains at most 3 flipped bits. The n flips
    are a uniformly random n-subset of the 30 format bits, so the number landing
    in one copy is hypergeometric. The previous closed form ``8 / (n + 1)``
    treated every split between the copies as equally likely, which disagrees
    with the subset counting used in ``p_success`` and gives a non-zero
    probability for n >= 19, where both copies necessarily hold more than 3 errors.

    Args:
        n: Number of flipped bits inside the format region.

    Returns:
        ``1`` for ``n <= 7`` (one copy always has at most 3 errors), otherwise
        the exact probability as a float.
    """
    copy_bits = 15   # bits per format copy (two copies -> 30 format bits)
    max_errors = 3   # errors tolerated within a single copy

    if n <= 2 * max_errors + 1:
        # Pigeonhole: with at most 7 flips one of the two copies has <= 3.
        return 1
    if n > 2 * copy_bits:
        # More flips than format bits: no such pattern exists.
        return 0.0

    # Count subsets with `first` flips in copy A and `n - first` in copy B for
    # which at least one copy is still decodable. math.comb returns 0 when a
    # copy would need more than 15 flips, so impossible splits drop out.
    readable = sum(
        math.comb(copy_bits, first) * math.comb(copy_bits, n - first)
        for first in range(n + 1)
        if first <= max_errors or n - first <= max_errors
    )
    return readable / math.comb(2 * copy_bits, n)


def p_success(n):
    """Return the modelled probability that a code with n flipped bits still decodes.

    The n flips are a uniformly random n-subset of the ``N_total`` bits. Each
    pattern is classified by how many flips land in the data region (i), the
    format region (j) and the remaining bits (k = n - i - j). A pattern succeeds
    with probability ``p_data(i) * p_format(j)``; flips in the remaining bits
    carry no success factor.

    Args:
        n: Total number of flipped bits.

    Returns:
        Weighted success count divided by ``math.comb(N_total, n)``.
    """
    total_patterns = math.comb(N_total, n)
    success_patterns = 0
    for i in range(0, n + 1):
        data_ways = math.comb(N_data, i)
        if data_ways == 0:
            # i exceeds N_data: no such pattern exists, and p_data(i) would
            # divide by zero, so skip the (zero-weight) term entirely.
            continue
        # p_data depends only on i; evaluate it once instead of once per j.
        data_ok = p_data(i)
        for j in range(0, n - i + 1):
            k = n - i - j
            ways = data_ways * math.comb(N_format, j) * math.comb(N_remainder, k)
            success_patterns += ways * data_ok * p_format(j)
    return success_patterns / total_patterns


def _decode_bitstring(bitstring):
    """Render a string of 0/1 modules as an image and decode it with pyzbar.

    The string is reshaped row by row into a square matrix (its length is
    assumed to be a perfect square; otherwise ``reshape`` raises ValueError),
    padded with a one-module border of 0s and inverted so that 1 becomes black
    (0) and 0 becomes white (255).

    Returns:
        The text of the first decoded symbol, or None when nothing is decoded.
    """
    size = math.isqrt(len(bitstring))
    modules = np.array(list(map(int, bitstring))).reshape(size, size)
    modules = np.pad(modules, pad_width=1, mode="constant", constant_values=0)
    image = ((1 - modules) * 255).astype(np.uint8)
    decoded = decode(image)
    return decoded[0].data.decode("utf-8") if decoded else None


def _empirical_accuracy(df, column):
    """Return the fraction of rows whose `column` bitstring decodes to the row's "target"."""
    # Iterate the two columns directly; df.iloc[i][...] would build a full row
    # Series for every single cell access.
    hits = sum(
        _decode_bitstring(bits) == target
        for bits, target in zip(df[column], df["target"])
    )
    return hits / len(df)


if __name__ == "__main__":
    # Theoretical curve for 0..20 flipped bits.
    x_values = range(21)
    y_values = [p_success(n) for n in x_values]

    # Empirical curve: decode every sample at every damage level.
    df = pd.read_csv(data_path)
    if len(df) == 0:
        # Fail with a clear message instead of a ZeroDivisionError later on.
        raise ValueError(f"No rows found in {data_path}")

    accuracy_list = []
    for damage in x_values:
        # The undamaged codes live in "input"; damaged ones in "flip_<n>".
        column = "input" if damage == 0 else f"flip_{damage}"
        print(f"Reading data for: {column}")
        accuracy = _empirical_accuracy(df, column)
        print(f"Accuracy for {damage} flipped bits: {accuracy}")
        accuracy_list.append(accuracy)

    plt.figure(figsize=(8, 5))
    plt.plot(x_values, y_values, marker="o", label="Theoretical Success Probability")
    plt.plot(x_values, accuracy_list, marker="x", label="Empirical Accuracy")
    plt.xlabel("Number of Flipped Bits (n)")
    plt.ylabel("Success Probability / Accuracy")
    plt.title(f"QR Code Decoding Success Probability (Version {version})")
    plt.legend()
    plt.grid(True)
    plt.xticks(x_values)
    plt.savefig("qr_readable_prob.png")
    print('Saved the plot as qr_readable_prob.png')
            

