import argparse
import random
import threading
import time

import numpy as np
from mpi4py import MPI

import decode
import encode
import utils
import van_inv
import verificate


def main(args):
    """Run one MPI rank of the coded, Byzantine-tolerant matrix product C = A * B.

    Rank 0 acts as the master (see ``_run_master``); every other rank is a
    worker (see ``_run_worker``).

    Args:
        args: options parsed by ``get_args``. Mutated in place:
            - ``n_required`` is set on every rank.
            - ``N`` is replaced by ``size - 1`` when more than one MPI
              process is running.
            - Rank 0 also stores its run logs on it.

    Raises:
        ValueError: if ``n_byzantine`` is not in ``[0, N]``. Every rank runs
            this check on identical inputs, so all ranks fail together and
            no worker is left blocked on a receive from the master.
    """
    # prepare
    args.n_required = utils.get_n_required(m=args.m, n=args.n)
    A_enc_shape, B_enc_shape, C_rec_shape = utils.get_shape(dim1=args.dim1, dim2=args.dim2, dim3=args.dim3, m=args.m, n=args.n)

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # use the number of workers for mpirun instead of argument passed by parser.
    if size > 1:
        args.N = size - 1

    # Otherwise the byzantine list would not have exactly N entries and rank 0
    # would try to message ranks that do not exist.
    if not 0 <= args.n_byzantine <= args.N:
        raise ValueError(f"n_byzantine must be in [0, N={args.N}], got {args.n_byzantine}")

    if rank == 0:
        _run_master(args, comm, C_rec_shape)
    else:
        _run_worker(args, comm, rank, A_enc_shape, B_enc_shape)


def _run_master(args, comm, C_rec_shape):
    """Master (rank 0): assign roles, encode and scatter inputs, collect, verify and decode results."""
    # set path
    if args.save:
        args.path = utils.set_path(experiment=args.experiment, id=args.id, t_id=args.t_id)

    # send byzantine information: a random subset of exactly n_byzantine workers
    byzantines = [True] * args.n_byzantine + [False] * (args.N - args.n_byzantine)
    random.shuffle(byzantines)
    args.byzantine = []
    for i, byzantine in enumerate(byzantines):
        comm.send(byzantine, dest=i + 1, tag=3)
        if byzantine:
            args.byzantine.append(f"worker {i + 1}")

    # send straggle information: each worker straggles independently with probability straggle_prob
    stragglers = list(np.random.uniform(0, 1, args.N) <= args.straggle_prob)
    args.stragglers = []
    for i, straggler in enumerate(stragglers):
        comm.send(straggler, dest=i + 1, tag=4)
        if straggler:
            args.stragglers.append(f"Worker {i + 1}")

    # set dictionary for logging
    args.computation_time = {}

    # get codebook (original note: m * (N // m)).
    # NOTE(review): below it is indexed as [i // m, i % m] for worker i + 1 -- confirm axis order.
    s = time.time()
    codebook = utils.codebook(N=args.N, m=args.m)
    utils.report(report=args.computation_time, spent=time.time() - s, name="codebook")

    # make inputs, A: dim1 * dim2, B: dim2 * dim3 -> Compute C = A * B: dim1 * dim3
    A, B = utils.make_inputs(dim1=args.dim1, dim2=args.dim2, dim3=args.dim3)

    # split matrix
    Ap, Bp = utils.split_matrix(A=A, B=B, m=args.m, n=args.n)

    # encode input matrices
    s = time.time()
    Aenc, Benc = encode.encode_matrices(Ap=Ap, Bp=Bp, codebook=codebook)
    Aenc_s = Aenc.sum(1)
    utils.report(report=args.computation_time, spent=time.time() - s, name="encode")

    # send: worker i + 1 receives Aenc[x, y] and Benc[x] with (x, y) = divmod(i, m)
    s = time.time()
    Cpre = [np.zeros(C_rec_shape) for _ in range(args.N)]
    reqA, reqB, reqC = [], [], []
    for i in range(args.N):
        x, y = divmod(i, args.m)
        reqA.append(comm.Isend(Aenc[x, y], dest=i + 1, tag=5))
        reqB.append(comm.Isend(Benc[x], dest=i + 1, tag=6))
        # Post every result receive up front so products can be consumed in arrival order.
        reqC.append(comm.Irecv(Cpre[i], source=i + 1, tag=10))

    MPI.Request.Waitall(reqA)
    MPI.Request.Waitall(reqB)

    comm.Barrier()

    utils.report(report=args.computation_time, spent=time.time() - s, name="send")

    # receive
    s = time.time()
    args.computation_time["verification"] = 0
    Crec = np.zeros(list(codebook.shape) + C_rec_shape)
    Cenc = []
    args.gets = []
    args.verificated = []
    checklist = [[] for _ in range(len(codebook))]
    c = 0
    # Bound before the loop: verificate_group_remains below reads it even when
    # fewer than n_required results ever arrive.
    brokens = []
    for _ in range(args.N):
        i = MPI.Request.Waitany(reqC)
        x, y = divmod(i, args.m)

        Crec[x, y] = Cpre[i]
        args.gets.append((i, [x, y]))
        checklist[x].append(y)

        if not args.verificate_all:
            if len(args.gets) == args.n_required:
                # First time enough results are in: verify every group that has received at least one result.
                for g in range(len(codebook)):
                    if len(checklist[g]) > 0:
                        c = verificate.verificate_group(Ae=Aenc[g], Be=Benc[g], Cr=Crec[g], Cenc=Cenc, x=g, ids=checklist[g], brokens=brokens, verificated=args.verificated, report=args.computation_time, c=c)
            elif len(args.gets) > args.n_required:
                # Later arrivals are verified one at a time.
                c = verificate.verificate_each(Ae=Aenc[x, y], Be=Benc[x], Cr=Crec[x, y], get=args.gets[-1], Cenc=Cenc, verificated=args.verificated, report=args.computation_time, c=c)

        if len(args.verificated) >= args.n_required:
            break

    if (not args.verificate_all) and (len(args.verificated) < args.n_required):
        # Separate timer: reusing `s` would clobber the start of the "receive" measurement below.
        s_remain = time.time()
        verificate.verificate_group_remains(Aenc=Aenc, Benc=Benc, Crec=Crec, Cenc=Cenc, brokens=brokens, verificated=args.verificated, n_required=args.n_required)
        utils.report_verificate(report=args.computation_time, spent=time.time() - s_remain, name="verification remains")

    utils.report(report=args.computation_time, spent=time.time() - s, name="receive")

    if args.verificate_all:
        s = time.time()
        verificate.verificate_group_all(Aenc=Aenc_s, Benc=Benc, Crec=Crec, Cenc=Cenc, verificated=args.verificated, n_required=args.n_required)
        if len(args.verificated) < args.n_required:
            verificate.verificate_group_all_remains(Aenc=Aenc, Benc=Benc, Crec=Crec, Cenc=Cenc, verificated=args.verificated, n_required=args.n_required)

        utils.report(report=args.computation_time, spent=time.time() - s, name="verification")

    print(f" *** Get {args.gets} from workers.")
    print(f" *** {args.verificated} are verificated.")

    # get van_inv
    s = time.time()
    Van_inv = van_inv.get_van_inv(codebook, gets=args.verificated, m=args.m, n=args.n)
    utils.report(report=args.computation_time, spent=time.time() - s, name="van inv")

    # decode received matrices (result is only timed here, not stored)
    s = time.time()
    C = decode.decode_matrix(Cenc, Van_inv, m=args.m, n=args.n)
    utils.report(report=args.computation_time, spent=time.time() - s, name="decode")

    # Drain result receives still outstanding after an early break.
    MPI.Request.Waitall(reqC)

    for i in range(args.N):
        spent = comm.recv(source=i + 1, tag=8)
        args.computation_time[f"Worker {i + 1}"] = float(spent)

    # save logs
    if args.save:
        utils.save_logs(vars=vars(args), path=args.path)

    print(args)


def _run_worker(args, comm, rank, A_enc_shape, B_enc_shape):
    """Worker (rank >= 1): multiply the received encoded shares and send the product to rank 0."""
    # receive byzantine information
    byzantine = comm.recv(source=0, tag=3)
    if byzantine:
        print(f"Worker {rank} is byzantine.")

    # receive straggler information
    straggler = comm.recv(source=0, tag=4)
    if straggler:
        print(f"Worker {rank} is straggler.")

    # receive encoded inputs
    Aenc, Benc = np.zeros(A_enc_shape), np.zeros(B_enc_shape)
    reqA = comm.Irecv(Aenc, source=0, tag=5)
    reqB = comm.Irecv(Benc, source=0, tag=6)

    reqA.wait()
    reqB.wait()

    comm.Barrier()

    # compute and send
    s = time.time()

    if straggler:
        # utils.straggle runs in a background thread alongside the compute.
        # NOTE(review): the thread is never joined; its effect depends on utils.straggle.
        threading.Thread(target=utils.straggle).start()

    Cenc = utils.compute(a=Aenc, b=Benc)

    if byzantine:
        # Simulated Byzantine fault: perturb one entry of the returned product.
        Cenc[0][0] = Cenc[0][0] + 1.0

    send_req = comm.Isend(Cenc, dest=0, tag=10)

    spent = format(time.time() - s, ".5f")
    print(f"Worker {rank} of {args.N} computation time: {spent} seconds.")

    send_req.Wait()

    comm.send(spent, dest=0, tag=8)


def get_args(argv=None):
    """Build the command-line parser and parse the run options.

    Args:
        argv: list of argument strings to parse. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, which was the previous
            behavior. Passing an explicit list allows programmatic and test
            use.

    Returns:
        argparse.Namespace with the options defined below.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--experiment", type=str, default="test", help="experiment id, one from [0], adjust N, m and n")
    parser.add_argument("--id", type=str, default="test", help="run id, adjust among dims and n_byzantine, etc")
    parser.add_argument("--t_id", type=str, default="test", help="run trial id for repetition")

    parser.add_argument("--dim1", type=int, default=600, help="height of input matrix A")
    parser.add_argument("--dim2", type=int, default=500, help="width of input matrix A, height of input matrix B")
    parser.add_argument("--dim3", type=int, default=600, help="width of input matrix B")

    # NOTE: under mpirun with more than one process, main() overrides N with size - 1.
    parser.add_argument("--N", type=int, default=18, help="number of workers")
    parser.add_argument("--m", type=int, default=3, help="A division")
    parser.add_argument("--n", type=int, default=4, help="B division")

    parser.add_argument("--n_byzantine", type=int, default=1, help="number of byzantine workers")
    parser.add_argument("--straggle_prob", type=float, default=0.0, help="probability of straggle for each worker, 0 <= prob <= 1, 0 if there is no straggle")

    parser.add_argument("--verificate_all", action="store_true", help="whether to verificate all")

    parser.add_argument("--save", action="store_true", help="whether to save results")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Every MPI rank parses the same command line, then enters the shared entry point.
    args = get_args()
    main(args=args)
