# -*- coding: utf-8 -*-
# @date: 20220429

"""
An implementation of SPL and SGD with momentum for solving the blind deconvolution problem.

"""
import time
from utils import *
from opt import mpOpt
from mpi4py import MPI
import argparse


# Rank that acts as the master: it owns the optimizer state and is the root
# of every collective call in this file.
ROOT = 0

# Tags the master attaches to the iterate it sends back to a worker in the
# asynchronous scheme: NOT_DONE keeps the worker looping, DONE makes it stop.
# NOTE(review): the MPI standard only guarantees MPI_TAG_UB >= 32767, so the
# value 999999 relies on the MPI implementation allowing larger tags.
NOT_DONE = 1
DONE = 999999

# World communicator plus this process' rank and the total process count.
COMM = MPI.COMM_WORLD
RANK = COMM.Get_rank()
SIZE = COMM.Get_size()


class blOpt(mpOpt):
    """
    Momentum optimizer for the blind deconvolution problem.

    Supports two update rules selected by ``algorithm``:
    OPT_METHOD_SGD and OPT_METHOD_SPL. The state attributes used here
    (``_x``, ``_x_old``, ``_y``, ``_beta``, ``_gamma``) are expected to be
    provided by the ``mpOpt`` base class.
    """

    def __init__(self, n_dim: int, gamma: numeric,
                 momentum: numeric = 0.0,
                 algorithm=OPT_METHOD_SGD):
        """
        :param n_dim: forwarded to mpOpt as ``n_dim``
        :param gamma: forwarded to mpOpt as ``gamma``
        :param momentum: forwarded to mpOpt as ``momentum``
        :param algorithm: update rule used by :meth:`iterate`
        """
        super().__init__(n_dim=n_dim, gamma=gamma, momentum=momentum)
        self.algorithm = algorithm

    def _spl_correction(self, g: np.ndarray, step: np.ndarray) -> np.ndarray:
        """
        Compute the SPL displacement that is added to the extrapolated point.

        :param g: g[0] is read as a scalar, g[1:] as the gradient part
        :param step: momentum step ``beta * (x - x_old)`` of this iteration
        :return: clipped coefficient times the gamma-scaled gradient
        """
        scaled_grad = g[1:] / self._gamma
        delta = (g[0] + np.dot(scaled_grad, step)) / self._gamma
        squared_norm = np.linalg.norm(scaled_grad) ** 2
        coeff = proj_onto_unit_box(- delta / squared_norm)
        # The cast happens only after delta/coeff were computed, so those
        # are evaluated in the gradient's original precision.
        if scaled_grad.dtype != numeric:
            scaled_grad = scaled_grad.astype(numeric)
        return coeff * scaled_grad

    def iterate(self, g: np.ndarray) -> None:
        """
        Perform one momentum step followed by the SGD or SPL update.

        :param g: Gradient components.
        For SGD method, g is of size  n
        For SPL method, g is of size  n + 1
        :return: None
        :raises NotImplementedError: if ``algorithm`` is neither SPL nor SGD
            (the momentum state has already been advanced at that point)
        """

        # Extrapolate with momentum and remember the previous iterate
        step = self._beta * (self._x - self._x_old)
        self._y = self._x + step
        self._x_old = self._x

        method = self.algorithm
        if method == OPT_METHOD_SPL:
            self._x = self._y + self._spl_correction(g, step)
        elif method == OPT_METHOD_SGD:
            self._x = self._y - g / self._gamma
        else:
            raise NotImplementedError("Algorithm is not implemented")

    def sync_opt(self, epoch, A_data, b_data, alg):
        """Placeholder; intentionally does nothing."""
        pass

    def get_x(self) -> np.ndarray:
        """Return the current iterate (the stored array itself, not a copy)."""
        return self._x


def train_sync_master(n_iter: int, opt: blOpt):
    """
    Master side of the synchronous scheme.

    In every round the master gathers one message from each rank, averages
    the rows contributed by the workers, applies one optimizer step and
    broadcasts the new iterate.

    :param n_iter: number of gather / step / broadcast rounds
    :param opt: optimizer holding the iterate; updated in place
    :return: None
    """
    # SPL messages carry one extra leading entry in addition to the
    # gradient, same as in train_async_master.
    d = opt.n if opt.algorithm == OPT_METHOD_SGD else opt.n + 1

    # Every rank except the master is a worker. (Removing SIZE here, as the
    # code used to do, always raised ValueError: SIZE is not in range(SIZE).)
    peers = list(range(SIZE))
    peers.remove(ROOT)

    gs = np.empty((SIZE, d), dtype=numeric)
    # The master has to take part in the Gather; its own row is a dummy
    # that is excluded from the average through `peers`.
    g = np.zeros(d, dtype=numeric)

    for _ in range(n_iter):
        COMM.Gather(g, gs, root=ROOT)
        opt.iterate(gs[peers].mean(axis=0))
        COMM.Bcast(opt.get_x(), root=ROOT)


def train_sync_worker(n_iter: int, UVs: np.ndarray, bs: np.ndarray,
                      x: np.ndarray, alg=OPT_METHOD_SGD) -> np.ndarray:
    """
    Worker side of the synchronous scheme.

    In every round the worker evaluates one sample at the current iterate,
    contributes the result to the master's Gather and receives the updated
    iterate through Bcast.

    :param n_iter: number of rounds; must match the master's count
    :param UVs: local measurement rows, one sample per row
    :param bs: local right-hand sides, aligned with the rows of ``UVs``
    :param x: current iterate; overwritten in place by every Bcast
    :param alg: OPT_METHOD_SGD sends the sub-gradient, anything else sends
        the output of ``assemble_obj_grad_blind``
    :return: the iterate received in the last round
    """
    m = UVs.shape[0]

    for i in np.random.permutation(n_iter):
        # n_iter may exceed the number of local samples; wrap around instead
        # of indexing past the end. Indices below m are unaffected.
        row = i % m
        uv = UVs[row]
        b = bs[row]
        if alg == OPT_METHOD_SGD:
            g = assemble_sub_gradient_blind(uv, b, x)
        else:
            g = assemble_obj_grad_blind(uv, b, x)

        # The master gathers into a `numeric` buffer, so the send buffer
        # must use the same dtype (same cast as in train_async_worker).
        if g.dtype != numeric:
            g = g.astype(numeric)

        COMM.Gather(g, None, root=ROOT)
        COMM.Bcast(x, root=ROOT)

    return x


def train_async_master(n_iter: int, opt: blOpt):
    """
    Master side of the asynchronous scheme.

    A non-blocking receive is kept posted for every worker. Whenever a
    message arrives it is applied through ``opt.iterate`` and the fresh
    iterate is sent back to that worker. Once ``n_iter`` messages have been
    processed, every worker that reports afterwards (its message is still
    applied) gets the iterate tagged DONE and is not listened to again.

    :param n_iter: number of processed messages after which workers are
        told to stop
    :param opt: optimizer holding the iterate; updated in place
    :return: None
    """
    # SPL messages are one entry longer than SGD messages.
    width = opt.n if opt.algorithm == OPT_METHOD_SGD else opt.n + 1
    workers = [rank for rank in range(SIZE) if rank != ROOT]
    n_workers = len(workers)

    if RANK == ROOT:
        # One receive slot per worker.
        gg = np.empty((n_workers, width), dtype=numeric)

    # Post the initial receive for every worker, in rank order.
    pending = [COMM.Irecv(gg[slot], source=rank)
               for slot, rank in enumerate(workers)]

    processed = 0
    still_running = n_workers

    while still_running > 0:
        for slot in MPI.Request.Waitsome(pending):
            opt.iterate(gg[slot])
            processed += 1
            rank = workers[slot]

            if processed >= n_iter:
                # Budget exhausted: release this worker.
                COMM.Send(opt.get_x(), dest=rank, tag=DONE)
                still_running -= 1
                continue

            COMM.Send(opt.get_x(), dest=rank, tag=NOT_DONE)
            # Re-arm the slot for the worker's next message.
            pending[slot] = COMM.Irecv(gg[slot], source=rank)


def train_async_worker(n_iter: int, Uvs: np.ndarray, bs: np.ndarray,
                       x: np.ndarray, alg=OPT_METHOD_SGD) -> np.ndarray:
    """
    Worker side of the asynchronous scheme.

    The worker keeps sampling local rows (a fresh random permutation per
    pass over the data), sends the resulting message to the master and
    waits for the updated iterate. It stops as soon as the reply carries a
    tag other than NOT_DONE.

    :param n_iter: unused here; the master decides when to stop
    :param Uvs: local measurement rows, one sample per row
    :param bs: local right-hand sides, aligned with the rows of ``Uvs``
    :param x: current iterate; overwritten in place by every Recv
    :param alg: OPT_METHOD_SGD sends the sub-gradient, anything else sends
        the output of ``assemble_obj_grad_blind``
    :return: the last iterate received from the master
    """
    m = Uvs.shape[0]

    # The algorithm does not change during the run, so pick the evaluation
    # routine once instead of branching in every iteration.
    if alg == OPT_METHOD_SGD:
        evaluate = assemble_sub_gradient_blind
    else:
        evaluate = assemble_obj_grad_blind

    info = MPI.Status()
    info.tag = NOT_DONE
    perm = np.random.permutation(m)
    idx = 0

    while info.tag == NOT_DONE:
        if idx >= m:
            # Local data exhausted: reshuffle and start another pass.
            perm = np.random.permutation(m)
            idx = 0

        row = perm[idx]
        idx += 1

        g = evaluate(Uvs[row], bs[row], x)
        # The master receives into a `numeric` buffer.
        if g.dtype != numeric:
            g = g.astype(numeric)

        COMM.Send(g, dest=ROOT)
        # The tag of the reply (stored in `info`) tells us whether to go on.
        COMM.Recv(x, source=ROOT, tag=MPI.ANY_TAG, status=info)

    return x


def get_bl_data(m: int, n: int):
    """
    Generate a synthetic blind deconvolution instance.

    The root draws the planted solution ``opt_z`` (two random vectors, each
    scaled to unit norm, concatenated) and broadcasts it. Every rank that
    holds data then draws its own Gaussian measurement matrix ``UV`` and
    computes ``b = (UV[:, :n_x] @ opt_z[:n_x]) * (UV[:, n_x:] @ opt_z[n_x:])``.

    :param m: total number of measurements, split evenly among the workers
    :param n: length of ``opt_z``; the first ``n // 2`` entries form the
        first factor, the remaining entries the second
    :return: ``(UV, b, opt_z)``; ``UV`` and ``b`` are None on the root when
        more than one process is running
    """
    # With a single process the root keeps all samples itself.
    m_sample = m if SIZE == 1 else m // (SIZE - 1)

    n_x = n // 2
    if RANK == ROOT:
        opt_x = np.random.rand(n_x)
        opt_x /= np.linalg.norm(opt_x)
        # Give the second factor the remaining n - n_x entries so that
        # opt_z always has exactly n entries, matching the receive buffers
        # of the other ranks even when n is odd.
        opt_y = np.random.rand(n - n_x)
        opt_y /= np.linalg.norm(opt_y)
        opt_z = np.concatenate([opt_x, opt_y]).astype(numeric)
    else:
        opt_z = np.empty(n, dtype=numeric)

    COMM.Bcast(opt_z, root=ROOT)

    if RANK == ROOT and SIZE != 1:
        # The master holds no samples when workers exist.
        return None, None, opt_z

    UV = np.random.randn(m_sample, n).astype(numeric)
    b = np.multiply(UV[:, :n_x] @ opt_z[:n_x],
                    UV[:, n_x:] @ opt_z[n_x:]).astype(numeric)

    return UV, b, opt_z


# Command line options as (flag, type, default):
#   --m      total number of measurements
#   --n      dimension of the unknown vector
#   --alpha  step size scale
#   --beta   momentum coefficient
#   --alg    name of the optimization algorithm
_CLI_OPTIONS = (
    ("--m", int, 50000),
    ("--n", int, 20000),
    ("--alpha", float, 1.0),
    ("--beta", float, 0.0),
    ("--alg", str, "SGD"),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
del _flag, _type, _default


def _dist_to_solution(z, opt_z):
    """Distance from z to the planted solution, up to a global sign flip."""
    return np.minimum(np.linalg.norm(z - opt_z),
                      np.linalg.norm(z + opt_z))


def main():
    """
    Run the asynchronous blind deconvolution experiment.

    All ranks parse the command line, build the problem data and agree on a
    common starting point. Afterwards the root acts as the master and the
    other ranks as workers for 400 epochs. Before the first epoch and after
    each epoch the workers evaluate the objective on their local data; the
    root averages those values and logs them, together with the distance to
    the planted solution and the elapsed training time, to stdout and to a
    results file.
    """
    # Parse arguments
    args = parser.parse_args()
    m = args.m
    n = args.n
    alpha_0 = args.alpha
    beta = args.beta
    alg = args.alg

    is_root = RANK == ROOT

    if is_root:
        print("Test begins on {0}".format(time.asctime()))
        print("Parameters m: %d  n: %d  alpha: %f  beta: %f  alg: %s" %
              (m, n, alpha_0, beta, alg))

    # Set parameters (every rank uses its own random stream)
    np.random.seed(RANK + 20220510)
    n_epoch = 400
    n_iter_epoch = m
    n_iter_all = n_epoch * m
    gamma = np.sqrt(n_iter_all) / alpha_0

    # Generate data
    UV_data, b_data, opt_z = get_bl_data(m, n)

    COMM.Barrier()

    # Initial point: every rank builds one, then the root's copy wins
    z = init_bl(n, alg, is_root)
    if z.dtype != numeric:
        z = z.astype(numeric)
    COMM.Bcast(z, root=ROOT)

    bl_opt = blOpt(n, gamma, numeric(beta), alg)
    bl_opt.initialize(z)

    if not is_root:
        # Worker: report the local objective once up front and once after
        # every epoch, mirroring the collectives issued by the root below.
        COMM.Gather(numeric(bl_obj(UV_data, b_data, z)), None, root=ROOT)
        for _ in range(n_epoch):
            z = train_async_worker(n_iter_epoch, UV_data, b_data, z, alg)
            COMM.Bcast(z, root=ROOT)
            COMM.Gather(numeric(bl_obj(UV_data, b_data, z)), None, root=ROOT)
        return

    # ---- master only from here on ----
    peers = list(range(SIZE))
    peers.remove(ROOT)
    # The root has to contribute to the Gather; its entry is a dummy that
    # is excluded from the average through `peers`.
    obj_gap = numeric(0.0)
    obj_gaps = np.empty(SIZE, dtype=numeric)

    output_stream = "res_job_{6}_epoch_{0}_alpha_{1}_beta_{2}_m_{3}_n_{4}_{5}.txt"\
                    "".format(n_epoch, alpha_0, beta * 10, m, n, alg, SIZE - 1)

    # The context manager makes sure the results file is flushed and closed
    # even if an epoch fails.
    with open(output_stream, "a") as f:
        output_str = "%12s  %12s  %12s  %12s" % ("Epoch", "||x-x^*||", "f - f^*", "Time")
        print(output_str, flush=True)
        f.write(output_str + "\n")

        total_time = 0.0  # time spent on evaluation, excluded from "Time"
        time_start = time.time()

        # Epoch 0 is the evaluation of the starting point.
        for epoch in range(n_epoch + 1):
            if epoch > 0:
                train_async_master(n_iter_epoch, bl_opt)
                z = bl_opt.get_x()
                COMM.Bcast(z, root=ROOT)

            test_time = time.time()
            COMM.Gather(obj_gap, obj_gaps, root=ROOT)
            obj_gaps_avg = obj_gaps[peers].mean(axis=0)
            # z and -z describe the same solution, so both signs must be
            # checked (the per-epoch report used to test z - opt_z twice).
            nrm = _dist_to_solution(z, opt_z)
            total_time += time.time() - test_time

            output_str = "%12d  %12e  %12e  %12e" % (epoch,
                                                     nrm,
                                                     obj_gaps_avg,
                                                     time.time() - time_start - total_time)
            f.write(output_str + "\n")
            print(output_str, flush=True)


# Run the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
