import time

import numpy as np

import utils


# Acceptance bound for the squared-error sum returned by verificate():
# a result is accepted only when its error is strictly below this value.
THRESHOLD = 0.0001


def verificate(Ae, Be, Ce, where):
    """Measure how far ``Ce`` is from ``Ae @ Be`` using one random projection.

    A random vector ``q`` (uniform in [0, 1)) with one entry per row of
    ``Ae`` is drawn, and ``(q @ Ae) @ Be`` is compared against ``q @ Ce``.

    Args:
        Ae: Left operand; must be 2-D (its shape is unpacked into two values).
        Be: Right operand.
        Ce: Candidate product to be checked.
        where: Label used only in the printed message.

    Returns:
        The sum of squared differences between the two projected vectors.
    """
    n_rows, _ = Ae.shape
    probe = np.random.rand(n_rows)

    # Project the claimed product and the operands with the same vector.
    expected = np.matmul(np.matmul(probe, Ae), Be)
    claimed = np.matmul(probe, Ce)

    residual = expected - claimed
    err = np.square(residual).sum()

    print(f" * Error from {where}: {err}.")
    return err


def check_verification(e, Cenc, Ce, verificated, get):
    """Record a single result as verified when its error is below THRESHOLD.

    Args:
        e: Error value to compare against THRESHOLD.
        Cenc: Collection that receives ``Ce`` on acceptance (appended in place).
        Ce: The result being accepted.
        verificated: Collection that receives ``get`` on acceptance
            (appended in place).
        get: Identifier stored alongside the accepted result.
    """
    if not e < THRESHOLD:
        # Rejected: leave both collections untouched.
        return

    Cenc.append(Ce)
    verificated.append(get)


def check_group_verification(e, Cenc, Crec, verificated, gets):
    """Record every result of a group as verified when ``e`` is below THRESHOLD.

    Args:
        e: Error value of the group check.
        Cenc: List extended in place with each element of ``Crec``.
        Crec: Iterable of the group's results.
        verificated: List extended in place with ``gets``.
        gets: Identifiers of the group's results.
    """
    if not e < THRESHOLD:
        # Group rejected: nothing is recorded.
        return

    Cenc.extend(Crec)
    verificated.extend(gets)


def verificate_each(Ae, Be, Cr, get, Cenc, verificated, **kwargs):
    """Verify one result, record it if accepted, and report the time spent.

    Args:
        Ae, Be: Operands forwarded to ``verificate``.
        Cr: Result to check; appended to ``Cenc`` when accepted.
        get: Identifier of the result; appended to ``verificated`` when
            accepted and used as the label in the printed error message.
        Cenc, verificated: Collections updated in place on acceptance.
        **kwargs: Must contain ``report`` (forwarded to
            ``utils.report_verificate``) and ``c`` (running verification count).

    Returns:
        The verification count after this call, i.e. ``c + 1``.
    """
    report = kwargs["report"]
    counter = kwargs["c"]

    started = time.time()

    err = verificate(Ae=Ae, Be=Be, Ce=Cr, where=get)
    check_verification(e=err, Cenc=Cenc, Ce=Cr, verificated=verificated, get=get)

    counter += 1
    elapsed = time.time() - started
    utils.report_verificate(report=report, spent=elapsed, name=f"verification {counter}")

    return counter


def verificate_group(Ae, Be, Cr, Cenc, x, ids, brokens, verificated, **kwargs):
    """Verify the members ``ids`` of group ``x`` with a single aggregated check.

    The selected blocks ``Ae[ids]`` and ``Cr[ids]`` are each summed along
    axis 0 and the sums are checked with ``verificate``. If the group passes,
    all selected results are recorded as verified; otherwise every selected
    member is added to ``brokens``.

    Args:
        Ae: Indexable by ``ids``; ``len(Ae)`` is used as the group stride when
            computing each member's flat index.
        Be: Right operand forwarded to ``verificate``.
        Cr: Results of the group, indexable by ``ids``.
        Cenc: List extended in place with ``Cr[ids]`` on acceptance.
        x: Group index.
        ids: Member indices within the group.
        brokens: List extended in place with the members of a rejected group.
        verificated: List extended in place with the members of an accepted
            group.
        **kwargs: Must contain ``report`` (forwarded to
            ``utils.report_verificate``) and ``c`` (running verification count).

    Returns:
        The verification count after this call, i.e. ``c + 1``.
    """
    report, c = kwargs["report"], kwargs["c"]

    s = time.time()

    e = verificate(Ae=Ae[ids].sum(0), Be=Be, Ce=Cr[ids].sum(0), where=f"group {x}")

    # Built once: exactly one of the two outcome branches consumes it.
    members = [(y + x * len(Ae), [x, y]) for y in ids]
    check_group_verification(e=e, Cenc=Cenc, Crec=Cr[ids], verificated=verificated, gets=members)

    # Exact complement of the acceptance test (e < THRESHOLD). A plain
    # `e > THRESHOLD` is False for NaN and for e == THRESHOLD, which would
    # leave the group neither verified nor marked as broken.
    if not (e < THRESHOLD):
        brokens += members

    c += 1
    utils.report_verificate(report=report, spent=time.time() - s, name=f"verification {c}")

    return c


def verificate_group_all(Aenc, Benc, Crec, Cenc, verificated, n_required):
    """Check groups in order until enough results have been verified.

    For each group index ``g``, ``Crec[g]`` summed along axis 0 is checked
    against ``Aenc[g]`` and ``Benc[g]``. On acceptance every element of
    ``Crec[g]`` is recorded. Iteration stops as soon as ``verificated`` holds
    at least ``n_required`` entries.

    Args:
        Aenc, Benc: Per-group operands, indexed by group.
        Crec: Per-group results, indexed by group.
        Cenc: List extended in place with accepted results.
        verificated: List extended in place with ``(flat_index, [group, member])``
            entries for accepted results.
        n_required: Number of verified entries at which to stop.
    """
    for group in range(len(Aenc)):
        err = verificate(Ae=Aenc[group], Be=Benc[group], Ce=Crec[group].sum(0), where=f"group {group}")

        size = len(Crec[group])
        entries = [(member + group * size, [group, member]) for member in range(size)]
        check_group_verification(e=err, Cenc=Cenc, Crec=Crec[group], verificated=verificated, gets=entries)

        if len(verificated) >= n_required:
            return


def verificate_group_remains(Aenc, Benc, Crec, Cenc, brokens, verificated, n_required):
    """Individually re-check each entry listed in ``brokens``.

    Args:
        Aenc, Crec: Indexed as ``[group, member]``.
        Benc: Indexed by group.
        Cenc: Receives each accepted result (appended in place).
        brokens: Iterable of ``(flat_index, [group, member])`` entries.
        verificated: Receives each accepted entry (appended in place).
        n_required: Number of verified entries at which to stop.

    Returns:
        1 if ``verificated`` reached ``n_required`` entries after one of the
        checks, otherwise 0.
    """
    for entry in brokens:
        _, (group, member) = entry

        a_block = Aenc[group, member]
        b_block = Benc[group]
        c_block = Crec[group, member]

        err = verificate(Ae=a_block, Be=b_block, Ce=c_block, where=entry)
        check_verification(e=err, Cenc=Cenc, Ce=c_block, verificated=verificated, get=entry)

        if len(verificated) >= n_required:
            return 1

    return 0


def verificate_group_all_remains(Aenc, Benc, Crec, Cenc, verificated, n_required):
    """Individually check every member of each group with no verified entry.

    The groups to visit are computed once, up front, from the current
    contents of ``verificated``.

    Args:
        Aenc, Crec: Indexed as ``[group, member]``; ``len(Aenc)`` gives the
            number of groups and ``len(Aenc[0])`` the members per group.
        Benc: Indexed by group.
        Cenc: Receives each accepted result (appended in place).
        verificated: Receives ``(flat_index, [group, member])`` for each
            accepted result (appended in place).
        n_required: Number of verified entries at which to stop.

    Returns:
        1 if ``verificated`` reached ``n_required`` entries after one of the
        checks, otherwise 0.
    """
    n_groups = len(Aenc)
    m = len(Aenc[0])
    pending = find_failed_groups(verificated=verificated, n_groups=n_groups)

    # Walk (group, member) pairs group by group, members in ascending order.
    candidates = ((group, member) for group in pending for member in range(m))
    for group, member in candidates:
        entry = (group * m + member, [group, member])

        a_block = Aenc[group, member]
        b_block = Benc[group]
        c_block = Crec[group, member]

        err = verificate(Ae=a_block, Be=b_block, Ce=c_block, where=entry)
        check_verification(e=err, Cenc=Cenc, Ce=c_block, verificated=verificated, get=entry)

        if len(verificated) >= n_required:
            return 1

    return 0


def find_failed_groups(verificated, n_groups):
    """Return the group indices in ``range(n_groups)`` with no verified entry.

    Args:
        verificated: Iterable of ``(flat_index, [group, member])`` entries.
        n_groups: Total number of groups.

    Returns:
        Ascending list of group indices that do not appear in ``verificated``.
    """
    seen = {group for _, (group, _member) in verificated}
    return [group for group in range(n_groups) if group not in seen]
