import collections
import functools
import math
import random
import re

from Crypto.Hash import keccak

def all_digits(s):
    """Return True when every character of `s` is a digit (vacuously True for '')."""
    return not any(not ch.isdigit() for ch in s)
def all_letters(s):
    """Return True when every character of `s` is alphabetic (vacuously True for '')."""
    return not any(not ch.isalpha() for ch in s)
# Matches every character that is not an ASCII letter or digit.
regex_replace = re.compile(r'[^a-zA-Z0-9]')


def looks_random(s):
    """Heuristically decide whether `s` looks like a random token.

    All characters outside [a-zA-Z0-9] are removed first. The remainder is
    rejected if it is all digits or all letters (this includes the empty
    string). Otherwise it is accepted only if its most frequent character
    accounts for less than 30% of its length.
    """
    cleaned = regex_replace.sub('', s)
    if all_digits(cleaned) or all_letters(cleaned):
        return False
    top_count = collections.Counter(cleaned).most_common(1)[0][1]
    return top_count < 0.3 * len(cleaned)


def is_checksum_ethereum_wallet(address):
    """Check an address against its EIP-55 mixed-case checksum.

    Every '0x' substring is removed from `address`. The remainder is
    lowercased and hashed with Keccak-256. For each of the first 40
    characters, the character must be uppercase when the matching hex nibble
    of the digest is 8 or more, and lowercase otherwise (digits satisfy both).
    Among those 40 characters there must also be at least three non-digits.

    Args:
        address: hex address string, optionally '0x'-prefixed. It is expected
            to have at least 40 characters after '0x' removal; shorter input
            may raise IndexError.

    Returns:
        True if the casing matches the checksum and there are at least three
        non-digit characters, False otherwise.
    """
    body = address.replace('0x', '')
    digest = keccak.new(digest_bits=256).update(body.lower().encode('utf-8')).hexdigest()
    non_digit_count = 0
    for idx in range(40):
        ch = body[idx]
        # A digest nibble >= 8 marks a position that must be uppercase.
        expected = ch.upper() if int(digest[idx], 16) >= 8 else ch.lower()
        if ch != expected:
            return False
        if not ch.isdigit():
            non_digit_count += 1
    return non_digit_count >= 3

def is_address_ethereum_wallet(address):
    """Return True if `address` is a checksummed, random-looking Ethereum address.

    The string must consist of an optional lowercase '0x' prefix followed by
    exactly 40 hex digits (either case), with nothing after them. It must then
    pass `is_checksum_ethereum_wallet` (the casing checksum) and
    `looks_random`.

    Args:
        address: candidate address string.

    Returns:
        True if every check passes, False otherwise.
    """
    # fullmatch rather than match + '$': '$' also matches before a trailing
    # newline, which would let "0x...\n" through to the checksum step.
    # The prefix is matched case-sensitively because is_checksum_ethereum_wallet
    # only strips a lowercase '0x'; a '0X' prefix would be hashed as part of
    # the address.
    if not re.fullmatch(r'(0x)?[0-9a-fA-F]{40}', address):
        return False
    return is_checksum_ethereum_wallet(address) and looks_random(address)

def get_random_ethereum_wallet():
    """Generate a random EIP-55-cased address and return it minus its first 3 chars.

    Up to 10 candidates are drawn. Each is 40 random lowercase hex digits,
    re-cased from its Keccak-256 digest (uppercase where the digest nibble is
    > 7) and prefixed with '0x'. The first candidate accepted by
    `is_address_ethereum_wallet` is used. If none is accepted after 10 draws,
    an error line is printed and the last candidate is used anyway.

    Returns:
        The chosen '0x'-prefixed candidate sliced with ``[3:]``. This drops
        '0x' *and* the first hex digit, leaving 39 characters.
    """
    # NOTE(review): [3:] removes one hex digit in addition to '0x';
    # parse_secrets applies the same b[3:] slice (commented "0x-"), so the two
    # stay aligned. Confirm this is intended.
    candidate = ''
    for attempt in range(11):
        if is_address_ethereum_wallet(candidate):
            break
        if attempt == 10:
            print('[get_random_ethereum_wallet] Error: Invalid random id')
            break
        raw = ''.join(random.choice('0123456789abcdef') for _ in range(40))
        digest = keccak.new(digest_bits=256).update(raw.lower().encode('utf-8')).hexdigest()
        # Uppercase wherever the corresponding digest nibble is 8..f.
        candidate = '0x' + ''.join(
            ch.upper() if int(nibble, 16) > 7 else ch.lower()
            for ch, nibble in zip(raw, digest)
        )
    return candidate[3:]

def get_cases(s):
    """Return a per-character case pattern for `s` as a tuple of 'u'/'l'.

    Lowercase characters map to 'l' and uppercase ones to 'u'. Characters
    with no case (digits, punctuation, ...) take whichever of 'u'/'l' occurs
    more often in `s`. Ties, including strings with no cased characters, go
    to 'u'.
    """
    labels = []
    for ch in s:
        if ch.islower():
            labels.append('l')
        elif ch.isupper():
            labels.append('u')
        else:
            labels.append('d')
    fill = 'u' if labels.count('u') >= labels.count('l') else 'l'
    return tuple(fill if label == 'd' else label for label in labels)

def apply_case(s, cases):
    """Re-case `s` character by character following `cases`.

    'l' lowercases a character, 'u' uppercases it, and any other marker
    leaves it unchanged. Characters are paired with markers via zip, so the
    result is truncated to the shorter of `s` and `cases`.
    """
    pieces = []
    for ch, mode in zip(s, cases):
        if mode == 'u':
            pieces.append(ch.upper())
        elif mode == 'l':
            pieces.append(ch.lower())
        else:
            pieces.append(ch)
    return ''.join(pieces)

def is_lowercase_and_digits(s):
    """Return True if `s` is ASCII alphanumeric with lowercase letters only.

    `str.islower` requires at least one cased character, so strings with no
    letters (e.g. '123' or '') return False.
    """
    if not s.islower():
        return False
    if any(not ch.isalnum() for ch in s):
        return False
    return s.isascii()


def _split_around(container, secret):
    """Return ``(prefix, suffix)`` of `container` around the first `secret`, or None if absent."""
    pos = container.find(secret)
    if pos == -1:
        return None
    return container[:pos], container[pos + len(secret):]


def parse_secrets(z_secrets, s_secrets, secret_type):
    """Group secrets by the text that surrounds them in their paired strings.

    `z_secrets` and `s_secrets` are walked in lockstep; extra items in the
    longer one are ignored. Presumably `z_secrets` holds candidate secrets and
    `s_secrets` the texts they were taken from (TODO confirm against callers).
    For each pair, a secret is derived from the `z_secrets` entry according to
    `secret_type` and located in the paired `s_secrets` entry. It is then
    recorded under ``(prefix, suffix, extra)``: the text before and after its
    first occurrence, plus a type-specific tag.

    Per `secret_type`:
      * 'ethereum_wallet': entries rejected by `is_address_ethereum_wallet`
        are skipped. The secret is the entry minus its first three characters,
        and extra is '-'.
      * 'md5', 'sha1', 'sha256', 'sha512': entries rejected by `looks_random`
        are skipped. The secret is the entry itself, and extra is
        `get_cases(secret)`.
      * 'java_serialization': the secret is the first run of digits in the
        entry. It is kept only if its value lies in (1000, 2**64), and extra
        is '-'.

    Pairs whose secret does not occur verbatim in the paired string are
    skipped. Within each group, duplicates are dropped (first-seen order is
    kept) and secrets that also occur inside the group's prefix or suffix are
    removed. Empty groups are discarded.

    Args:
        z_secrets: iterable of strings the secrets are derived from.
        s_secrets: iterable of strings searched for each derived secret.
        secret_type: one of the types listed above.

    Returns:
        dict mapping ``(prefix, suffix, extra)`` to a list of unique secrets.
        Groups are ordered by list length, largest first; equal-sized groups
        keep their insertion order. The dict is empty when nothing qualifies.

    Raises:
        ValueError: if `secret_type` is not supported.
    """
    groups = collections.defaultdict(list)

    def record(context, secret, extra):
        around = _split_around(context, secret)
        if around is None:
            # find() would return -1 here and the slices would fabricate a
            # bogus prefix/suffix, so skip the pair instead.
            return
        prefix, suffix = around
        groups[(prefix, suffix, extra)].append(secret)

    if secret_type == 'ethereum_wallet':
        for candidate, context in zip(z_secrets, s_secrets):
            if is_address_ethereum_wallet(candidate):
                # NOTE(review): originally commented "0x-". For a '0x'-prefixed
                # address this slice also drops the first hex digit, which matches
                # get_random_ethereum_wallet's [3:]. Confirm intent.
                record(context, candidate[3:], '-')
    elif secret_type in ('md5', 'sha1', 'sha256', 'sha512'):
        for candidate, context in zip(z_secrets, s_secrets):
            if looks_random(candidate):
                record(context, candidate, get_cases(candidate))
    elif secret_type == 'java_serialization':
        for candidate, context in zip(z_secrets, s_secrets):
            match = re.search(r'\d+', candidate)
            if match is None:
                # No digits at all; previously an IndexError.
                continue
            serial_id_str = match.group()
            # Exact integer bound: math.log2(x) < 64 rounds values just below
            # 2**64 up to 64.0 and wrongly rejected them.
            if 1000 < int(serial_id_str) < 2**64:
                record(context, serial_id_str, '-')
    else:
        raise ValueError(f"Secret type {secret_type} not found")

    result = {}
    for (prefix, suffix, extra), secrets in groups.items():
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output does not depend on string hash randomisation.
        kept = [
            secret for secret in dict.fromkeys(secrets)
            if secret not in prefix and secret not in suffix
        ]
        if kept:
            result[(prefix, suffix, extra)] = kept

    # Largest groups first; sorted() is stable, including with reverse=True.
    # prefix_qa shape: {(prefix, suffix, extra): [secrets]}
    return dict(sorted(result.items(), key=lambda item: len(item[1]), reverse=True))



######################################################
######################################################
######################################################
######################################################
# Get priors for each secret type
######################################################
######################################################
######################################################
######################################################


def get_prior_ethereum_wallet(extra, n_samples):
    """Return `n_samples` outputs of `get_random_ethereum_wallet`.

    `extra` is ignored; it is accepted so all prior samplers share the
    ``(extra, n_samples)`` signature.
    """
    samples = []
    for _ in range(n_samples):
        samples.append(get_random_ethereum_wallet())
    return samples
_HEX_ALPHABET = "abcdef0123456789"


def _random_cased_hex(length, extra, n_samples):
    """Draw `n_samples` random hex strings of `length` chars, each re-cased by `apply_case(..., extra)`.

    `apply_case` pairs characters with `extra` via zip, so each sample is
    truncated to ``len(extra)`` when that is shorter than `length`.
    """
    samples = []
    for _ in range(n_samples):
        raw = ''.join(random.choices(_HEX_ALPHABET, k=length))
        samples.append(apply_case(raw, extra))
    return samples


def get_prior_md5(extra, n_samples):
    """Return `n_samples` random 32-hex-digit strings cased per `extra`."""
    return _random_cased_hex(32, extra, n_samples)


def get_prior_sha1(extra, n_samples):
    """Return `n_samples` random 40-hex-digit strings cased per `extra`."""
    return _random_cased_hex(40, extra, n_samples)


def get_prior_sha256(extra, n_samples):
    """Return `n_samples` random 64-hex-digit strings cased per `extra`."""
    return _random_cased_hex(64, extra, n_samples)


def get_prior_sha512(extra, n_samples):
    """Return `n_samples` random 128-hex-digit strings cased per `extra`."""
    return _random_cased_hex(128, extra, n_samples)
def get_prior_java_serialization(extra, n_samples):
    """Return `n_samples` random non-negative 63-bit integers as decimal strings.

    Values are uniform over [0, 2**63 - 1], the non-negative range of a signed
    64-bit integer (presumably modelling a Java ``long`` serial id; confirm).
    `extra` is ignored; it is accepted so all prior samplers share the
    ``(extra, n_samples)`` signature.
    """
    # randrange excludes its upper bound, so values never exceed 2**63 - 1.
    # randint(0, 2**63) is inclusive and could return 2**63, one past the
    # largest signed 64-bit value.
    return [str(random.randrange(2**63)) for _ in range(n_samples)]



def remove_duplicates_prior(f, *, max_calls=6):
    """Wrap a prior sampler so it returns up to `n_samples` distinct samples.

    The wrapped sampler is called as ``f(extra, n_samples)``, and its outputs
    are accumulated with duplicates dropped (first-seen order kept). Calls
    stop once at least `n_samples` unique values exist or `max_calls` calls
    have been made. If still short, an error line is printed and the shorter
    unique list is returned.

    Args:
        f: callable ``f(extra, n_samples)`` returning an iterable of hashable
            samples.
        max_calls: maximum number of calls to `f` per request. Keyword-only;
            defaults to 6.

    Returns:
        A function with the same ``(extra, n_samples)`` signature (and `f`'s
        metadata via functools.wraps). It returns a list of at most
        `n_samples` unique samples and never calls `f` when `n_samples` <= 0.
    """
    @functools.wraps(f)
    def new_f(extra, n_samples):
        # An insertion-ordered dict de-duplicates deterministically, unlike
        # list(set(...)), whose order for str depends on the hash seed.
        unique = {}
        calls = 0
        while len(unique) < n_samples:
            if calls == max_calls:
                print(f'Error: required samples: {n_samples}. Unique given: {len(unique)}')
                break
            # De-duplicate after every call, including the last, so the result
            # never contains repeats.
            unique.update(dict.fromkeys(f(extra, n_samples)))
            calls += 1
        return list(unique)[:n_samples]
    return new_f


# Registry of prior samplers keyed by secret type. Each sampler is wrapped
# with remove_duplicates_prior, which re-samples to collect unique values.
# Call as DICT_GET_PRIOR[secret_type](extra, n_samples).
DICT_GET_PRIOR = {
    secret_type: remove_duplicates_prior(sampler)
    for secret_type, sampler in (
        ('ethereum_wallet', get_prior_ethereum_wallet),
        ('md5', get_prior_md5),
        ('sha1', get_prior_sha1),
        ('sha256', get_prior_sha256),
        ('sha512', get_prior_sha512),
        ('java_serialization', get_prior_java_serialization),
    )
}