import math
import matplotlib.pyplot as plt
import matplotlib

# Configuration for saving images without GUI: select the non-interactive
# "Agg" backend so plt.savefig() below works on a machine without a display.
# NOTE(review): conventionally this call precedes `import matplotlib.pyplot`;
# presumably fine here because recent matplotlib switches the pyplot backend
# when use() is called after pyplot is imported — confirm on the target version.
matplotlib.use("Agg")

# QR Code (Version 3-L) parameters
# Modules of the analysed region: data + format information + remainder.
TOTAL_MODULES = 598
# Modules carrying codeword bits (CODEWORDS * 8).
DATA_MODULES = 560
CODEWORDS = 70
ECC_CODEWORDS = 15  # Number of error correction codewords (actually can correct up to floor(15/2)=7 codewords)
# NOTE(review): QR version 3 defines 7 remainder bits; presumably the 8th
# counted here is the dark module — confirm against the dataset generator.
REMAINDER_MODULES = 8
# Format-information modules (2 copies x 15 bits = 30 here). This constant was
# used by the probability model but previously never defined (NameError).
FORMAT_MODULES = TOTAL_MODULES - DATA_MODULES - REMAINDER_MODULES
data_path = "dataset_segno/ver3/data_domain_ver3_mask0_L/damaged/dataset_damaged.csv"

# #QR Code (Version 3-M) parameters
# TOTAL_MODULES = 598
# DATA_MODULES = 560
# CODEWORDS = 70
# ECC_CODEWORDS = 26  # Number of error correction codewords (actually can correct up to floor(26/2)=13 codewords)
# REMAINDER_MODULES = 8
# data_path = "dataset_segno/ver3/data_domain_ver3_mask0_M/damaged/dataset_damaged.csv"

# # QR Code (Version 2-L) parameters
# TOTAL_MODULES = 390
# DATA_MODULES = 352
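# CODEWORDS = 44
# ECC_CODEWORDS = 10  # Number of error correction codewords (the model assumes floor(10/2)=5 correctable codewords)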
# REMAINDER_MODULES = 8
# data_path = "dataset_segno/ver2/data_domain_ver2_mask0_L/damaged/dataset_damaged.csv"

# # QR Code (Version 2-M) parameters
# TOTAL_MODULES = 390
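# DATA_MODULES = 352  # CODEWORDS * 8 (359 would also count the 7 remainder bits and break TOTAL = DATA + FORMAT + REMAINDER)
# CODEWORDS = 44
# ECC_CODEWORDS = 16  # Number of error correction codewords (the model assumes floor(16/2)=8 correctable codewords)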
# REMAINDER_MODULES = 8
# data_path = "dataset_segno/ver2/data_domain_ver2_mask0_M/damaged/dataset_damaged.csv"




def success_probability_data(n: int, *, data_modules=None, codewords=None,
                             ecc_codewords=None, bits_per_codeword: int = 8) -> float:
    """
    Return the probability that Reed-Solomon decoding of the data area
    succeeds when exactly n of its modules are flipped.

    Model: the data area is ``codewords`` codewords of ``bits_per_codeword``
    bits each, and every one of the C(data_modules, n) placements of the n
    errors is equally likely. Decoding succeeds when the errors touch at most
    t = ecc_codewords // 2 distinct codewords.

    Args:
        n: Number of flipped data modules.
        data_modules: Size of the data area; defaults to DATA_MODULES.
        codewords: Number of codewords; defaults to CODEWORDS.
        ecc_codewords: Number of error-correction codewords; defaults to
            ECC_CODEWORDS.
        bits_per_codeword: Bits per codeword (8 for QR codes).

    Returns:
        Success probability in [0, 1]; 0.0 when n is outside [0, data_modules].

    NOTE: the counting assumes data_modules == codewords * bits_per_codeword;
    errors in any extra modules are not modelled as harmless.
    """
    if data_modules is None:
        data_modules = DATA_MODULES
    if codewords is None:
        codewords = CODEWORDS
    if ecc_codewords is None:
        ecc_codewords = ECC_CODEWORDS
    t = ecc_codewords // 2  # Error correction capability t (in codewords)

    # Range check
    if n < 0 or n > data_modules:
        return 0.0

    # Total number of patterns (>= 1 because 0 <= n <= data_modules)
    total = math.comb(data_modules, n)

    numerator = 0
    # k: number of distinct codewords containing at least one error. Each of
    # them needs >= 1 error, so k > n contributes nothing and is skipped.
    for k in range(min(t, codewords, n) + 1):
        # Inclusion-exclusion: ways to place n errors inside k chosen codewords
        # such that none of the k codewords is error-free. math.comb returns 0
        # when fewer than n bits are available.
        ways_bits = sum(
            (-1) ** j * math.comb(k, j) * math.comb((k - j) * bits_per_codeword, n)
            for j in range(k + 1)
        )
        # Choose which k codewords are hit, times the placements within them.
        numerator += math.comb(codewords, k) * ways_bits

    return numerator / total

def success_probability_format(n, *, copy_modules=15, max_correctable=3):
    """
    Return the probability that the format information can be recovered when
    exactly n of its modules are flipped.

    The format information is modelled as two copies of a ``copy_modules``-bit
    BCH(15, 5) codeword (2 * 15 = 30 modules in total). Decoding succeeds if
    at least one copy contains at most ``max_correctable`` bit errors.

    Every placement of the n errors among the 2 * copy_modules modules is
    equally likely. The number i of errors in the first copy is therefore
    hypergeometric, not uniform over splits:
        P(i) = C(m, i) * C(m, n - i) / C(2m, n)

    Args:
        n: Number of flipped format-information modules.
        copy_modules: Modules in one copy (15 for QR codes).
        max_correctable: Errors one copy tolerates (3 for BCH(15, 5)).

    Returns:
        Success probability in [0, 1]; 0.0 if n is negative or exceeds the
        number of format modules (no such error pattern exists).
    """
    total_modules = 2 * copy_modules
    if n < 0 or n > total_modules:
        return 0.0
    # Total number of ways to place n errors among both copies
    total = math.comb(total_modules, n)
    # Count error patterns where either copy stays within the correction
    # capacity. Impossible splits (more than copy_modules errors in one copy)
    # contribute 0 because math.comb(m, k) == 0 for k > m.
    favourable = sum(
        math.comb(copy_modules, i) * math.comb(copy_modules, n - i)
        for i in range(n + 1)
        if i <= max_correctable or (n - i) <= max_correctable
    )
    return favourable / total

def success_probability_remainder(n):
    """
    Success probability for n flipped modules in the remainder area.

    The remainder area is treated as carrying no information, so the result
    is always 1.0 for any valid n.

    Raises:
        ValueError: if n is outside [0, REMAINDER_MODULES].
    """
    out_of_range = n < 0 or n > REMAINDER_MODULES
    if out_of_range:
        raise ValueError("Invalid number of remainder bits")
    return 1.0

def total_success_probability(n):
    """
    Returns the overall decoding success probability when there are a total of
    n module errors among the TOTAL_MODULES modules of the analysed region.

    The region is split into the data area (DATA_MODULES), the format
    information (TOTAL_MODULES - DATA_MODULES - REMAINDER_MODULES modules) and
    the remainder area (REMAINDER_MODULES). For every split (i, j, k) of the n
    errors over these areas, the number of error patterns is weighted by the
    product of the per-area success probabilities. This treats the areas as
    decoded independently.

    Raises:
        ValueError: if DATA_MODULES + REMAINDER_MODULES exceeds TOTAL_MODULES,
            or if n is negative (raised by math.comb).
    """
    # Derived locally so this function does not rely on a FORMAT_MODULES
    # global (which the original code referenced without defining).
    format_modules = TOTAL_MODULES - DATA_MODULES - REMAINDER_MODULES
    if format_modules < 0:
        raise ValueError("DATA_MODULES + REMAINDER_MODULES exceeds TOTAL_MODULES")

    # Total error patterns
    total_patterns = math.comb(TOTAL_MODULES, n)
    if total_patterns == 0:
        return 0.0

    # The format-area probability depends only on j; compute each value once.
    format_probs = [success_probability_format(j) for j in range(format_modules + 1)]

    success_patterns = 0.0
    # Data area gets i errors, format info j, remainder area k = n - i - j
    for i in range(max(0, n - (format_modules + REMAINDER_MODULES)), min(n, DATA_MODULES) + 1):
        p_data = success_probability_data(i)
        if p_data == 0.0:
            continue  # every split with i data errors contributes 0
        data_ways = math.comb(DATA_MODULES, i)
        for j in range(max(0, n - i - REMAINDER_MODULES), min(n - i, format_modules) + 1):
            k = n - i - j
            # Error patterns for each area
            ways = data_ways * math.comb(format_modules, j) * math.comb(REMAINDER_MODULES, k)
            # Multiply success probabilities for each area
            p = p_data * format_probs[j] * success_probability_remainder(k)
            success_patterns += ways * p

    return success_patterns / total_patterns



if __name__ == "__main__":
    # Theoretical success probability for 0..20 flipped modules
    x_values = range(21)
    y_values = [total_success_probability(n) for n in x_values]

    # Measure reading accuracy for actual damage data
    import pandas as pd
    from pyzbar.pyzbar import decode
    import numpy as np

    # Read every column as text. Otherwise pandas may parse purely numeric
    # targets (or bit strings) as numbers, and the decoded str would never
    # compare equal to them. keep_default_na=False keeps empty cells as "".
    df = pd.read_csv(data_path, dtype=str, keep_default_na=False)
    data_num = len(df)
    accuracy_list = []
    for damage in x_values:
        column = "input" if damage == 0 else f"flip_{damage}"
        print(f"Reading data for: {column}")
        num_hits = 0
        for bitstring, target in zip(df[column], df["target"]):
            # Convert bit sequence to image (assuming square)
            size = int(np.sqrt(len(bitstring)))
            arr = np.array(list(map(int, bitstring))).reshape(size, size)
            # Add a 1-module quiet zone of light (0) modules
            arr = np.pad(arr, pad_width=1, mode="constant", constant_values=0)
            # Dark modules (1) become black pixels (0), light modules white (255)
            image = ((1 - arr) * 255).astype(np.uint8)
            decoded = decode(image)
            read_text = decoded[0].data.decode("utf-8") if decoded else None
            if read_text == target:
                num_hits += 1
        # Avoid ZeroDivisionError on an empty CSV
        accuracy = num_hits / data_num if data_num else float("nan")
        print(f"Accuracy for {damage} flipped bits: {accuracy}")
        accuracy_list.append(accuracy)

    print('Theoretical Success Probability:', y_values)
    print('Empirical Accuracy:', accuracy_list)

    # Plot results
    plt.figure(figsize=(8, 5))
    plt.plot(x_values, y_values, marker="o", label="Theoretical Success Probability")
    plt.plot(x_values, accuracy_list, marker="x", label="Empirical Accuracy")
    plt.xlabel("Number of Flipped Bits (n)")
    plt.ylabel("Success Probability / Accuracy")
    # Derived from the active configuration so the title cannot drift out of
    # sync with it (it was hard-coded as "Version 2-L" while 3-L was active).
    plt.title(
        f"QR Code Decoding Success Probability "
        f"({TOTAL_MODULES} modules, {ECC_CODEWORDS} ECC codewords)"
    )
    plt.legend()
    plt.grid(True)
    plt.xticks(x_values)
    plt.savefig("qr_readable_prob.png")
    print('Saved the plot as qr_readable_prob.png')
