import numpy as np
import torch
from scipy.sparse import csr_matrix
from scipy.special import binom, lambertw
from ldpc import bp_decoder
import sys
import galois

from scipy.special import erf
from scipy.linalg import orth

GF = galois.GF(2)

def apply_channel_probs(x, channel_probs):
    e = GF(np.random.binomial(1, channel_probs))
    return x + e

### Given a GF(2) matrix, do row elimination and return the first k rows of A that form an invertible matrix
def boolean_row_reduce(A, print_progress=False):
    n, k = A.shape
    A_rr = A.copy()
    perm = np.arange(n)
    for j in range(k):
        idxs = j + np.nonzero(A_rr[j:, j])[0]
        if idxs.size == 0:
            print("The given matrix is not invertible")
            return None
        A_rr[[j, idxs[0]]] = A_rr[[idxs[0], j]]  # For matrices you have to swap them this way
        (perm[j], perm[idxs[0]]) = (perm[idxs[0]], perm[j])  # Weirdly, this is MUCH faster if you swap this way instead of using perm[[i,j]]=perm[[j,i]]
        A_rr[idxs[1:]] += A_rr[j]
        if print_progress and (j%5==0 or j+1==k):
            sys.stdout.write(f'\rDecoding progress: {j + 1} / {k}')
            sys.stdout.flush()
    if print_progress: print()
    return perm[:k]


def str_to_bin(string):
    bin_str = ''.join(format(i, '08b') for i in bytearray(string, encoding ='utf-8'))
    return [int(b) for b in bin_str]

def bin_to_str(bin_list):
    bin_str = ''.join(map(str, bin_list))
    byte_array = bytearray(int(bin_str[i:i+8], 2) for i in range(0, len(bin_str), 8) if bin_str[i:i+8]!='00000000')
    return byte_array.decode('utf-8')

### Key generation algorithm.
## Inputs:
# n - block length (i.e., length of PRC codeword).
# message_length - length of messages you want to encode
# false_positive_rate - the false positive rate you're willing to tolerate
# t - sparsity of parity checks. larger values help pseudorandomness
# g - dimension of random code used. larger values help pseudorandomness
# r - number of parity checks used. smaller values help pseudorandomness
# noise_rate - amount of noise for Encode to add to codewords. larger values help pseudorandomness
def KeyGen(n, message_length=512, false_positive_rate=1e-9, t=3, g=None, r=None, noise_rate=None):
    # Set basic scheme parameters
    num_test_bits = int(np.ceil(np.log2(1 / false_positive_rate)))
    secpar = int(np.log2(binom(n, t)))
    if g is None: g = secpar
    # if noise_rate is None: noise_rate = np.exp(lambertw(-np.log(2) / secpar, -1)).real
    # if noise_rate is None: noise_rate = 1 - 2**(-(secpar - 3*np.log2(g))/g**2)
    if noise_rate is None: noise_rate = 1 - 2 ** (-secpar / g ** 2)
    k = message_length + g + num_test_bits
    if r is None: r = n - k - secpar

    # Sample n by k generator matrix (all but the first n-r of these will be over-written)
    generator_matrix = GF.Random((n, k))

    # Sample scipy.sparse parity-check matrix together with the last n-r rows of the generator matrix
    row_indices = []
    col_indices = []
    data = []
    for row in range(r):
        chosen_indices = np.random.choice(n - r + row, t - 1, replace=False)
        chosen_indices = np.append(chosen_indices, n - r + row)
        row_indices.extend([row] * t)
        col_indices.extend(chosen_indices)
        data.extend([1] * t)
        generator_matrix[n - r + row] = generator_matrix[chosen_indices[:-1]].sum(axis=0)
    parity_check_matrix = csr_matrix((data, (row_indices, col_indices)))

    # Compute scheme parameters
    max_bp_iter = int(np.log(n) / np.log(t))

    # Sample one-time pad and test bits
    one_time_pad = GF.Random(n)
    test_bits = GF.Random(num_test_bits)

    # Permute bits
    permutation = np.random.permutation(n)
    generator_matrix = generator_matrix[permutation]
    one_time_pad = one_time_pad[permutation]
    parity_check_matrix = parity_check_matrix[:, permutation]

    encoding_key = (generator_matrix, one_time_pad, test_bits, g, noise_rate)
    decoding_key = (generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate, noise_rate, test_bits, g, max_bp_iter, t)

    return encoding_key, decoding_key


### Encoding algorithm
## Inputs:
# encoding_key - Encoding key output by KeyGen.
# message - Message to encode, as an array of k bits. If none is provided a random message is used.
def Encode(encoding_key, message=None):
    generator_matrix, one_time_pad, test_bits, g, noise_rate = encoding_key
    n, k = generator_matrix.shape

    if message is None:
        payload = np.concatenate((test_bits, GF.Random(k - len(test_bits))))
    else:
        assert len(message) <= k-len(test_bits)-g, "Message is too long"
        payload = np.concatenate((test_bits, GF.Random(g), GF(message), GF.Zeros(k-len(test_bits)-g-len(message))))

    error = GF(np.random.binomial(1, noise_rate, n))

    return 1 - 2 * torch.tensor(payload @ generator_matrix.T + one_time_pad + error, dtype=float)


### Detector
## Inputs:
# decoding_key - Decoding key output by KeyGen.
# posteriors - The posterior expectations of sign(z) as a torch.tensor.
## Returns:
# True/False - Detection result.
def Detect(decoding_key, posteriors, false_positive_rate=None):
    generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate_key, noise_rate, test_bits, g, max_bp_iter, t = decoding_key
    if false_positive_rate is not None:
        fpr = false_positive_rate
    else:
        fpr = false_positive_rate_key

    posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True)

    r = parity_check_matrix.shape[0]
    Pi = np.prod(posteriors[parity_check_matrix.indices.reshape(r, t)], axis=1)
    log_plus = np.log((1 + Pi) / 2)
    log_minus = np.log((1 - Pi) / 2)
    log_prod = log_plus + log_minus

    const = 0.5 * np.sum(np.power(log_plus, 2) + np.power(log_minus, 2) - 0.5 * np.power(log_prod, 2))
    threshold = np.sqrt(2 * const * np.log(1 / fpr)) + 0.5 * log_prod.sum()

    return log_plus.sum() >= threshold


### Decoder
## Inputs:
# decoding_key - Decoding key output by KeyGen.
# posteriors - The posterior expectations of sign(z) as a torch.tensor.
## Returns:
# recovered_message - The recovered message. If the test bits are incorrect, outputs None.
def Decode(decoding_key, posteriors, print_progress=False, max_bp_iter=None):
    generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate_key, noise_rate, test_bits, g, max_bp_iter_key, t = decoding_key
    if max_bp_iter is None:
        max_bp_iter = max_bp_iter_key

    posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True)
    channel_probs = (1 - np.abs(posteriors)) / 2
    x_recovered = (1 - np.sign(posteriors)) // 2

    # Apply the belief-propagation decoder.
    if print_progress:
        print("Running belief propagation...")
    bpd = bp_decoder(parity_check_matrix, channel_probs=channel_probs, max_iter=max_bp_iter, bp_method="product_sum")
    x_decoded = bpd.decode(x_recovered)
    x_decoded=np.array(x_decoded,dtype=np.int8)

    # Compute a confidence score.
    bpd_probs = 1 / (1 + np.exp(bpd.log_prob_ratios))
    confidences = 2 * np.abs(0.5 - bpd_probs)

    # Order codeword bits by confidence.
    confidence_order = np.argsort(-confidences)
    ordered_generator_matrix = generator_matrix[confidence_order]
    ordered_x_decoded = x_decoded[confidence_order]

    # Find the first (according to the confidence order) linearly independent set of rows of the generator matrix.
    top_invertible_rows = boolean_row_reduce(ordered_generator_matrix, print_progress=print_progress)
    if top_invertible_rows is None:
        return None

    # Solve the system.
    if print_progress:
        print("Solving linear system...")
    recovered_string = np.linalg.solve(ordered_generator_matrix[top_invertible_rows], 
                                       GF(ordered_x_decoded[top_invertible_rows]))

    if not (recovered_string[:len(test_bits)] == test_bits).all():
        if print_progress:
            print("testbits ist",recovered_string[:len(test_bits)],"testbits soll",test_bits)
        return None
    return np.array(recovered_string[len(test_bits) + g:])



#####################################################pseudogaussians


def sample(codeword, basis=None):
    # pseudogaussian = codeword * torch.abs(torch.randn_like(codeword, dtype=torch.float64))
    codeword_np = codeword.numpy()
    pseudogaussian_np = codeword_np * np.abs(np.random.randn(*codeword_np.shape))
    pseudogaussian = torch.from_numpy(pseudogaussian_np).to(dtype=torch.float64)
    if basis is None:
        return pseudogaussian
    return pseudogaussian @ basis.T


def recover_posteriors(z, basis=None, variances=None):
    if variances is None:
        default_variance = 1.5
        denominators = np.sqrt(2 * default_variance * (1+default_variance)) * torch.ones_like(z)
    elif type(variances) is float:
        denominators = np.sqrt(2 * variances * (1 + variances))
    else:
        denominators = torch.sqrt(2 * variances * (1 + variances))

    if basis is None:
        return erf(z / denominators)
    else:
        return erf((z @ basis) / denominators)
    

def random_basis(n):
    gaussian = torch.randn(n, n, dtype=torch.double)
    return orth(gaussian)


############################### GS++
### Key generation algorithm.
## Inputs:
# n - block length (i.e., length of PRC codeword).
# message_length - length of messages you want to encode
# false_positive_rate - the false positive rate you're willing to tolerate
# t - sparsity of parity checks. larger values help pseudorandomness
# g - dimension of random code used. larger values help pseudorandomness
# r - number of parity checks used. smaller values help pseudorandomness
# noise_rate - amount of noise for Encode to add to codewords. larger values help pseudorandomness
def KeyGen_gspp(n, message_length=32, t=3):
    # Set basic scheme parameters
    secpar = int(np.log2(binom(n, t)))
    g = secpar
    noise_rate = 1 - 2 ** (-secpar / g ** 2)
    k = message_length
    r = n - k - secpar
    # Sample n by k generator matrix (all but the first n-r of these will be over-written)
    generator_matrix = GF.Random((n, k))

    # Sample scipy.sparse parity-check matrix together with the last n-r rows of the generator matrix
    row_indices = []
    col_indices = []
    data = []
    for row in range(r):
        chosen_indices = np.random.choice(n - r + row, t - 1, replace=False)
        chosen_indices = np.append(chosen_indices, n - r + row)
        row_indices.extend([row] * t)
        col_indices.extend(chosen_indices)
        data.extend([1] * t)
        generator_matrix[n - r + row] = generator_matrix[chosen_indices[:-1]].sum(axis=0)
    parity_check_matrix = csr_matrix((data, (row_indices, col_indices)))

    # Compute scheme parameters
    max_bp_iter = int(np.log(n) / np.log(t))

    # Permute bits
    permutation = np.random.permutation(n)
    generator_matrix = generator_matrix[permutation]
    parity_check_matrix = parity_check_matrix[:, permutation]

    encoding_key = (generator_matrix, g, noise_rate)
    decoding_key = (generator_matrix, parity_check_matrix, noise_rate, g, max_bp_iter, t)

    return encoding_key, decoding_key



def Encode_gspp(encoding_key, payload=None):
    '''encode like prc, but just with the pure payload and no extras'''
    generator_matrix, g, noise_rate = encoding_key

    n, k = generator_matrix.shape
    error = GF(np.random.binomial(1, noise_rate, n))
    #assert payload matches size
    payload = GF(payload)

    codeword = payload @ generator_matrix.T + error
    #todo: numpy conversion
    return np.array(codeword,dtype=np.uint8)

def Decode_gspp(decoding_key, posteriors, print_progress=False, max_bp_iter=None):
    generator_matrix, parity_check_matrix, noise_rate, g, max_bp_iter, t = decoding_key
    if max_bp_iter is None:
        max_bp_iter = max_bp_iter_key

    posteriors = (1 - 2 * noise_rate) * posteriors.numpy(force=True)
    channel_probs = (1 - np.abs(posteriors)) / 2
    x_recovered = (1 - np.sign(posteriors)) // 2

    # Apply the belief-propagation decoder.
    if print_progress:
        print("Running belief propagation...")
    bpd = bp_decoder(parity_check_matrix, channel_probs=channel_probs, max_iter=max_bp_iter, bp_method="product_sum")
    x_decoded = bpd.decode(x_recovered)
    x_decoded=np.array(x_decoded,dtype=np.int8)

    # Compute a confidence score.
    bpd_probs = 1 / (1 + np.exp(bpd.log_prob_ratios))
    confidences = 2 * np.abs(0.5 - bpd_probs)

    # Order codeword bits by confidence.
    confidence_order = np.argsort(-confidences)
    ordered_generator_matrix = generator_matrix[confidence_order]
    ordered_x_decoded = x_decoded[confidence_order]

    # Find the first (according to the confidence order) linearly independent set of rows of the generator matrix.
    top_invertible_rows = boolean_row_reduce(ordered_generator_matrix, print_progress=print_progress)
    if top_invertible_rows is None:
        return None

    # Solve the system.
    if print_progress:
        print("Solving linear system...")
    recovered_string = np.linalg.solve(ordered_generator_matrix[top_invertible_rows], 
                                       GF(ordered_x_decoded[top_invertible_rows]))

    return np.array(recovered_string)