# --- algorithms.py ---
import numpy as np
from objective import obj_function
from gossip import FastMix, AccGossip, random_sample_index

def run_gteg(a, b, W, z_init, z_star, run_type, run_max, n):
    """Run GT-EG (as named in its progress messages) on a decentralized problem.

    Node ``j`` owns the objective ``obj_function(a[j], b[j])`` and row ``j``
    of the ``(m, dim)`` iterate matrix.  Every iteration evaluates full local
    gradients twice (at an extrapolated point and at the new iterate), mixes
    the iterates and the tracked direction once each with ``W``, and records
    the progress counters.

    The step size is ``(1 - rho) ** 2`` where ``rho`` is the spectral norm of
    ``W - 11^T / m``.

    Args:
        a, b: per-node data used to build each node's objective.
        W: ``(m, m)`` mixing matrix.
        z_init: ``(m, dim)`` starting iterates, one row per node.
        z_star: reference solution; the loss is the mean Euclidean distance
            from the node rows to it.
        run_type: ``'sample'``, ``'time'`` or ``'communication'`` selects the
            counter that ``run_max`` bounds.  With any other value no early
            stop is applied at all (not even the 1e-10 loss test).
        run_max: budget for the selected counter.
        n: total sample count; one full local gradient is charged
            ``int(n/m)`` samples per node.

    Returns:
        ``(loss, sample_gradients, time, communication)`` lists, each with an
        initial entry followed by one entry per iteration.
    """
    nodes, dim = z_init.shape
    centered = W - np.ones((nodes, nodes)) / nodes
    rho = np.linalg.norm(centered, 2)
    step = (1 - rho) ** 2
    max_iters = 100000
    per_node = int(n / nodes)

    def mean_distance(points):
        # Mean Euclidean distance of the node rows to z_star.
        total = 0
        for j in range(nodes):
            total += np.linalg.norm(points[j] - z_star) / nodes
        return total

    def local_gradients(points):
        # Row j = gradient of node j's full local objective at points[j].
        out = np.zeros((nodes, dim))
        for j in range(nodes):
            out[j] = obj_function(a[j], b[j]).grad(points[j])
        return out

    z = z_init.copy()
    tracker = np.zeros((nodes, dim))     # tracked direction (zero before the first iteration)
    last_grads = np.zeros((nodes, dim))  # local gradients from the previous iteration

    losses = [mean_distance(z)]
    lifo = [0]
    clock = [0]
    comms = [0]

    # The counter that run_max bounds; None disables early stopping entirely.
    watched = None
    if run_type == 'sample':
        watched = lifo
    elif run_type == 'time':
        watched = clock
    elif run_type == 'communication':
        watched = comms

    for it in range(max_iters):
        # Extrapolation: step along the tracked direction, then correct it locally.
        probe = z - step * tracker
        direction = tracker + local_gradients(probe) - last_grads
        # Update: mix the iterates and step along the corrected direction.
        z_new = W @ z - step * direction
        grads_new = local_gradients(z_new)
        # Tracking: mix the tracker and add the change in local gradients.
        tracker = W @ tracker + grads_new - last_grads
        z, last_grads = z_new, grads_new

        loss = mean_distance(z)
        losses.append(loss)
        # Two full local gradients and two mixing rounds per iteration.
        lifo.append(lifo[-1] + 2 * per_node * nodes)
        clock.append(clock[-1] + 2 * per_node)
        comms.append(comms[-1] + 2)

        if watched is not None and (loss < 1e-10 or watched[-1] > run_max):
            break
        if it % 10 == 0:
            print('GT-EG iters', it, 'Loss', loss)

    return losses, lifo, clock, comms


def run_mcsvre(a, b, W, z_init, z_star, run_type, run_max, n):
    """Run MC-SVRE (as named in its progress messages) on a decentralized problem.

    Node ``j`` owns ``obj_function(a[j], b[j])`` and row ``j`` of the
    ``(m, dim)`` iterate matrix.  Each iteration forms the convex combination
    ``z_prime = alpha*z + (1-alpha)*w`` of the iterate and the anchor ``w``.
    It then takes an extrapolation step and an update step from ``z_prime``
    along ``FastMix``-ed directions.  A single-sample correction
    ``grad_i(z_mid) - grad_i(w)`` is added to ``v``.  With probability ``p``
    the anchor moves to the new iterate and the full local gradients there
    are recomputed.

    Args:
        a, b: per-node data; ``a[j]``/``b[j]`` build node j's objective and
            are indexed by sample (``a[j][idx]``).
        W: mixing matrix passed to ``FastMix`` (each call is charged as its
            ``K``/``K0`` rounds of communication).
        z_init: ``(m, dim)`` initial iterates, one row per node.
        z_star: reference point; the loss is the mean Euclidean distance of
            the node rows to it.
        run_type: ``'sample'``, ``'time'`` or ``'communication'`` selects the
            counter bounded by ``run_max``.  Any other value disables early
            stopping entirely (including the 1e-10 loss test).
        run_max: budget for the selected counter.
        n: total sample count; a full local gradient is charged ``int(n/m)``
            samples per node.

    Returns:
        ``(loss, sample_gradients, time, communication)`` lists with one entry
        for the initial state plus one per iteration.
    """
    import numpy as np
    from objective import obj_function
    from gossip import FastMix, random_sample_index

    m, dim = z_init.shape
    iters = 50000
    eta = 0.1    # step size
    p = 0.02     # per-iteration probability of refreshing the anchor w
    alpha = 0.9  # weight of z in z_prime = alpha*z + (1-alpha)*w
    K0 = 20      # gossip rounds for the initial mixing of the full gradients
    K = 20       # gossip rounds per FastMix call inside the loop

    w = z_init.copy()  # anchor point
    z = z_init.copy()
    v_old = np.zeros((m, dim))
    s_old = np.zeros((m, dim))

    # Full local gradients at the starting point.  Note that v starts as the
    # FastMix-ed gradients, whereas on an anchor refresh below it is set to the
    # unmixed local gradients.
    grad = np.zeros((m, dim))
    for j in range(m):
        func = obj_function(a[j], b[j])
        grad[j] = func.grad(z[j])
    v = FastMix(grad, W, K0)
    s = v.copy()

    # Loss = mean Euclidean distance of the node iterates to z_star.
    loss0 = 0
    for j in range(m):
        loss0 += np.linalg.norm(z[j] - z_star) / m
    loss_mcsvre = [loss0]
    # The initial full gradient costs int(n/m) samples per node and K0 gossip rounds.
    LIFO_mcsvre = [int(n/m) * m]
    time_mcsvre = [int(n/m)]
    communication_mcsvre = [K0]

    for i in range(iters):
        # Convex combination of the current iterate and the anchor.
        z_prime = alpha * z + (1 - alpha) * w

        # Tracking update of s from the change in v, then mix.  When v did not
        # change last iteration (v == v_old) this just re-mixes s.
        s = s_old + (v - v_old)
        s = FastMix(s, W, K)

        # Extrapolation step.
        z_mid = z_prime - eta * s
        z_mid = FastMix(z_mid, W, K)

        # Single-sample correction per node: v[j] + grad_i(z_mid[j]) - grad_i(w[j]),
        # with both gradients on the same sampled index.
        v_half = np.zeros((m, dim))
        for j in range(m):
            idx = random_sample_index(int(n/m), 1)
            func = obj_function(a[j][idx], b[j][idx])
            v_half[j] = v[j] + (func.grad(z_mid[j]) - func.grad(w[j]))

        s_half = s + (v_half - v)
        s_half = FastMix(s_half, W, K)

        # Update step, taken from z_prime (not z_mid).
        z_next = z_prime - eta * s_half
        z_next = FastMix(z_next, W, K)

        loss_val = 0
        for j in range(m):
            loss_val += np.linalg.norm(z_next[j] - z_star) / m
        loss_mcsvre.append(loss_val)

        # With probability p: move the anchor to z_next and recompute full local gradients there.
        if np.random.binomial(1, p) == 1:
            w = z_next.copy()
            grad_full = np.zeros((m, dim))
            for j in range(m):
                func = obj_function(a[j], b[j])
                grad_full[j] = func.grad(z_next[j])
            v_next = grad_full.copy()

            # NOTE(review): s (= s_next) is overwritten at the top of the next
            # iteration by `s = s_old + (v - v_old)` before it is read, so this
            # FastMix call does not influence the iterates, yet its K rounds are
            # charged (5*K below vs 4*K in the other branch).  Confirm which of
            # the two s updates is the intended one.
            s_next = s + (v_next - v)
            s_next = FastMix(s_next, W, K)

            z = z_next.copy()
            v_old = v.copy()
            s_old = s.copy()
            v = v_next.copy()
            s = s_next.copy()

            # Full pass (int(n/m) per node) + 2 single-sample gradients per node.
            LIFO_mcsvre.append(LIFO_mcsvre[-1] + int(n/m) * m + 2 * m)
            time_mcsvre.append(time_mcsvre[-1] + int(n/m) + 2)
            communication_mcsvre.append(communication_mcsvre[-1] + 5 * K)
        else:
            z = z_next.copy()
            v_old = v.copy()
            s_old = s.copy()

            # 2 single-sample gradients per node; 4 FastMix calls of K rounds.
            LIFO_mcsvre.append(LIFO_mcsvre[-1] + 2 * m)
            time_mcsvre.append(time_mcsvre[-1] + 2)
            communication_mcsvre.append(communication_mcsvre[-1] + 4 * K)

        if run_type == 'sample' and (loss_val < 1e-10 or LIFO_mcsvre[-1] > run_max):
            break
        if run_type == 'time' and (loss_val < 1e-10 or time_mcsvre[-1] > run_max):
            break
        if run_type == 'communication' and (loss_val < 1e-10 or communication_mcsvre[-1] > run_max):
            break

        if i % 100 == 0:
            print('MC-SVRE iters', i, 'Loss', loss_val)

    return loss_mcsvre, LIFO_mcsvre, time_mcsvre, communication_mcsvre


def run_oadsvi(a, b, L, z_init, z_star, run_type, run_max, n):
    """Run OADSVI (as named in its progress messages) on a decentralized problem.

    Node ``j`` owns ``obj_function(a[j], b[j])`` and row ``j`` of the
    ``(m, dim)`` iterate matrix ``z``.  Per iteration every node draws two
    independent minibatches of size ``batch``:

    * The first builds an estimate at ``z`` relative to the previous anchor
      ``w_old``, extrapolated with weight ``alpha`` using the previous
      iterate ``z_old``.
    * The second builds an estimate at the new point ``z_next`` relative to
      the anchor ``w``.

    The auxiliary variable ``y`` is updated through one ``AccGossip`` call
    with ``K`` rounds.  With probability ``p`` the anchor moves to the current
    iterate and its full local gradients are recomputed.

    The one-off full pass over the local data at the initial anchor is
    charged to the sample/time counters before the first iteration.

    Args:
        a, b: per-node data; ``a[j]``/``b[j]`` build node j's objective and
            are indexed by sample.
        L: operator passed to ``AccGossip`` (presumably the graph Laplacian
            or gossip matrix — confirm in ``gossip.AccGossip``).
        z_init: ``(m, dim)`` initial iterates, one row per node.
        z_star: reference point; the loss is the mean Euclidean distance of
            the node rows to it.
        run_type: ``'sample'``, ``'time'`` or ``'communication'`` selects the
            counter bounded by ``run_max``.  Any other value disables early
            stopping entirely (including the 1e-10 loss test).
        run_max: budget for the selected counter.
        n: total sample count; a full local gradient is charged ``int(n/m)``
            samples per node.

    Returns:
        ``(loss, sample_gradients, time, communication)`` lists with one entry
        for the initial state plus one per iteration.
    """
    import numpy as np
    from objective import obj_function
    from gossip import AccGossip, random_sample_index

    m, dim = z_init.shape

    iters = 50000
    batch = 3      # minibatch size per node for each stochastic estimate
    eta = 0.05     # step size on Delta
    p = 0.01       # per-iteration probability of refreshing the anchor w
    theta = 2      # step size of the y update
    alpha = 0.9    # extrapolation weight on the change since the previous iterate
    beta = 0.25    # weight of (Delta_half - y) inside the gossiped quantity
    gamma = 0.01   # pull of z toward the anchor w
    K = 20         # rounds passed to AccGossip per iteration

    y = np.zeros((m, dim))
    y_old = y.copy()
    z = z_init.copy()
    z_old = z.copy()
    w = z.copy()
    w_old = w.copy()
    # Full local gradients at the initial anchor.
    grad_w_old = np.zeros((m, dim))
    for j in range(m):
        func = obj_function(a[j], b[j])
        grad_w_old[j, :] = func.grad(w_old[j, :])
    grad_w = grad_w_old.copy()

    loss_oadsvi = []
    LIFO_oadsvi = []
    time_oadsvi = []
    communication_oadsvi = []

    loss0 = 0
    for j in range(m):
        loss0 += np.linalg.norm(z[j] - z_star) / m
    loss_oadsvi.append(loss0)
    # Charge the full pass computed just above (int(n/m) samples per node);
    # it uses no gossip, so communication starts at 0.
    LIFO_oadsvi.append(int(n/m) * m)
    time_oadsvi.append(int(n/m))
    communication_oadsvi.append(0)

    # Set at the end of each iteration: 1 means the anchor w was just refreshed,
    # so grad_w_old must be replaced once it has been used with the old anchor.
    zeta = 0

    for i in range(iters):
        # Estimate at z relative to the previous anchor w_old, plus alpha times the
        # change since the previous iterate z_old (one minibatch for all terms).
        grad_delta = np.zeros((m, dim))
        for j in range(m):
            samples = random_sample_index(int(n/m), batch)
            func = obj_function(a[j][samples], b[j][samples])
            grad_delta[j, :] = (
                func.grad(z[j, :])
                - func.grad(w_old[j, :])
                + alpha * (func.grad(z[j, :]) - func.grad(z_old[j, :]))
            )
        Delta = grad_delta + grad_w_old - (y + alpha * (y - y_old))

        # The old-anchor gradient has now been consumed; align grad_w_old with w,
        # which becomes w_old at the end of this iteration.
        if zeta == 1:
            grad_w_old = grad_w.copy()

        # Pull toward the anchor and step along Delta.
        z_next = z + gamma * (w - z) - eta * Delta

        # Second, independent minibatch: estimate at z_next relative to the anchor w.
        grad_delta_half = np.zeros((m, dim))
        for j in range(m):
            samples = random_sample_index(int(n/m), batch)
            func = obj_function(a[j][samples], b[j][samples])
            grad_delta_half[j, :] = func.grad(z_next[j, :]) - func.grad(w[j, :])
        Delta_half = grad_w + grad_delta_half

        # y update through one AccGossip call with K rounds.
        y_next = y - theta * AccGossip(
            z_next - beta * (Delta_half - y),
            L,
            K
        )

        zeta = np.random.binomial(1, p)
        if zeta == 1:
            # Refresh the anchor to the iterate at the start of this iteration.
            # NOTE(review): this uses z rather than z_next — confirm intended.
            w_next = z.copy()
            grad_w = np.zeros((m, dim))
            for j in range(m):
                func = obj_function(a[j], b[j])
                grad_w[j, :] = func.grad(z[j, :])
            # Full pass plus 5 minibatch gradients: grad_delta touches z, w_old and
            # z_old; grad_delta_half touches z_next and w.
            LIFO_oadsvi.append(LIFO_oadsvi[-1] + int(n/m) * m + 5 * batch * m)
            time_oadsvi.append(time_oadsvi[-1] + int(n/m) + 5 * batch)
        else:
            w_next = w.copy()
            LIFO_oadsvi.append(LIFO_oadsvi[-1] + 5 * batch * m)
            time_oadsvi.append(time_oadsvi[-1] + 5 * batch)

        z_old = z.copy()
        w_old = w.copy()
        y_old = y.copy()
        z = z_next.copy()
        w = w_next.copy()
        y = y_next.copy()

        loss_val = 0
        for j in range(m):
            loss_val += np.linalg.norm(z[j] - z_star) / m
        loss_oadsvi.append(loss_val)
        communication_oadsvi.append(communication_oadsvi[-1] + K)

        if run_type == 'sample':
            if loss_oadsvi[-1] < 1e-10 or LIFO_oadsvi[-1] > run_max:
                break
        if run_type == 'time':
            if loss_oadsvi[-1] < 1e-10 or time_oadsvi[-1] > run_max:
                break
        if run_type == 'communication':
            if loss_oadsvi[-1] < 1e-10 or communication_oadsvi[-1] > run_max:
                break

        if i % 100 == 0:
            print('OADSVI iters', i, 'Loss', loss_oadsvi[-1])

    return loss_oadsvi, LIFO_oadsvi, time_oadsvi, communication_oadsvi


def _diverse_mean_distance(z, z_star, m):
    """Return the mean Euclidean distance from the first ``m`` rows of ``z`` to ``z_star``."""
    total = 0
    for j in range(m):
        total += np.linalg.norm(z[j, :] - z_star) / m
    return total


def _diverse_vr_gradients(a, b, z, z_old, v_anchor, grad_anchor, per_node, q, alpha, batch, m, dim):
    """Draw one Bernoulli minibatch per node and return variance-reduced gradient estimates.

    For node ``j`` a batch size ``bs ~ Binomial(per_node, q)`` is drawn and
    ``random_sample_index(per_node, bs)`` picks the sample subset.  With ``g``
    the gradient of ``obj_function`` on that subset, the estimate is::

        grad_anchor[j] + (g(z[j]) - g(v_anchor[j])
                          + alpha * (g(z[j]) - g(z_old[j]))) * bs * m / batch

    Returns:
        ``(estimates, batch_sizes)``: an ``(m, dim)`` array and the list of
        per-node batch sizes that were drawn.
    """
    estimates = np.zeros((m, dim))
    batch_sizes = []
    for j in range(m):
        bs = np.random.binomial(per_node, q)
        batch_sizes.append(bs)
        # Called even for bs == 0, so the global random stream advances the same
        # way regardless of the shortcut below (random_sample_index may draw from it).
        samples = random_sample_index(per_node, bs)
        if bs == 0:
            # The correction is scaled by bs and is therefore exactly zero; skip
            # evaluating the objective on an empty subset.  NOTE(review):
            # obj_function is not visible here, but if its grad averages over the
            # samples, an empty subset yields NaN and NaN * 0 would poison the estimate.
            estimates[j, :] = grad_anchor[j, :]
            continue
        func = obj_function(a[j][samples], b[j][samples])
        g_z = func.grad(z[j, :])
        correction = g_z - func.grad(v_anchor[j, :]) + alpha * (g_z - func.grad(z_old[j, :]))
        estimates[j, :] = grad_anchor[j, :] + correction * bs * m / batch
    return estimates, batch_sizes


def run_diverse(a, b, W, z_init, z_star, run_type, run_max, n):
    """Run DIVERSE (as named in its progress messages) on a decentralized problem.

    Node ``j`` owns ``obj_function(a[j], b[j])`` and row ``j`` of the
    ``(m, dim)`` iterate matrix.

    Iteration 0 computes full local gradients, initializes the tracking
    variable ``s`` with ``FastMix`` and takes one mixed step.  Every later
    iteration then does the following:

    * Forms a variance-reduced estimate from a per-node Bernoulli minibatch.
      It is taken against the anchor ``v_old`` with its stored full gradients
      ``grad_v``, plus an ``alpha``-weighted extrapolation term.
    * Updates ``s`` by the change in the estimates and mixes it.
    * Mixes ``(1-beta)*z + beta*v - eta*s`` into the new iterate.

    With probability ``p`` per iteration the anchor ``v`` is refreshed to a
    ``FastMix`` of the pre-step iterate.  The matching full local gradients
    are computed on the following iteration.

    Every ``FastMix`` call is charged its rounds to the communication
    counter, including the anchor refresh (``Kv``).  The stopping test runs
    after every iteration, including iteration 0.

    Args:
        a, b: per-node data; ``a[j]``/``b[j]`` build node j's objective and
            are indexed by sample.
        W: mixing matrix passed to ``FastMix``.
        z_init: ``(m, dim)`` initial iterates, one row per node.
        z_star: reference point; the loss is the mean Euclidean distance of
            the node rows to it.
        run_type: ``'sample'`` or ``'time'`` selects the corresponding
            counter; any other value bounds the communication counter.
        run_max: budget for the selected counter.
        n: total sample count; a full local gradient is charged ``int(n/m)``
            samples per node.

    Returns:
        ``(loss, sample_gradients, time, communication)`` lists with one entry
        for the initial state plus one per iteration.
    """
    m, dim = z_init.shape

    loss0 = _diverse_mean_distance(z_init, z_star, m)

    iters = 10000
    batch = 128      # approximate expected total minibatch size over all nodes
    eta = 0.1        # step size
    p = 0.01         # per-iteration probability of refreshing the anchor v
    q = batch / n    # per-sample inclusion probability
    alpha = 0.9      # weight of the extrapolation term g(z) - g(z_old)
    beta = p         # weight of the anchor v in the z update
    K0 = 20          # gossip rounds for initializing s
    Kz = 20          # gossip rounds for mixing z
    Ks = 20          # gossip rounds for mixing s
    Kv = 20          # gossip rounds for mixing the refreshed anchor v
    per_node = int(n/m)

    grad_m = np.zeros((m, dim))  # current gradient estimates, one row per node
    grad_v = np.zeros((m, dim))  # full local gradients paired with the anchor
    s = np.zeros((m, dim))       # tracking variable
    z = np.copy(z_init)
    z_old = np.copy(z_init)
    v = np.copy(z_init)
    v_old = np.copy(z_init)

    loss_diverse = [loss0]
    LIFO_diverse = [0]
    time_diverse = [0]
    communication_diverse = [0]

    for i in range(iters):
        if i == 0:
            # Initialization: full local gradients, one tracking mix and one z step.
            for j in range(m):
                func = obj_function(a[j], b[j])
                grad_m[j, :] = func.grad(z[j, :])
            grad_v = np.copy(grad_m)

            s = FastMix(grad_m, W, K0)
            z = FastMix(z - eta * s, W, Kz)
            rounds = K0 + Kz
            zeta = np.random.binomial(1, p)
            if zeta == 1:
                # z_old still holds the pre-step iterate here.
                v = FastMix(z_old, W, Kv)
                rounds += Kv

            LIFO_diverse.append(LIFO_diverse[-1] + per_node * m)
            time_diverse.append(time_diverse[-1] + per_node)
        else:
            if zeta == 1:
                # The anchor was refreshed last iteration.  Compute the new full local
                # gradients at z_old, but build this step's estimate against the
                # previous anchor (v_old, grad_v) first.
                # NOTE(review): grad_v_new is evaluated at z_old while the anchor v is
                # FastMix(z_old); confirm this pairing is intended.
                grad_v_new = np.zeros((m, dim))
                for j in range(m):
                    func = obj_function(a[j], b[j])
                    grad_v_new[j, :] = func.grad(z_old[j, :])
                grad_m_new, batch_sizes = _diverse_vr_gradients(
                    a, b, z, z_old, v_old, grad_v, per_node, q, alpha, batch, m, dim
                )
                grad_v = grad_v_new
                full_samples = per_node * m
                full_time = per_node
            else:
                grad_m_new, batch_sizes = _diverse_vr_gradients(
                    a, b, z, z_old, v_old, grad_v, per_node, q, alpha, batch, m, dim
                )
                full_samples = 0
                full_time = 0

            # Tracking: add the change in the local estimates to s.
            s += grad_m_new - grad_m
            grad_m = grad_m_new

            # Each sampled index is charged three gradient evaluations (z, v_old, z_old).
            LIFO_diverse.append(LIFO_diverse[-1] + full_samples + 3 * sum(batch_sizes))
            time_diverse.append(time_diverse[-1] + full_time + 3 * max(batch_sizes))

            s = FastMix(s, W, Ks)
            z_new = FastMix((1 - beta) * z + beta * v - eta * s, W, Kz)
            z_old = np.copy(z)
            v_old = np.copy(v)
            rounds = Ks + Kz
            zeta = np.random.binomial(1, p)
            if zeta == 1:
                # Refresh the anchor from the pre-step iterate (z is replaced below).
                v = FastMix(z, W, Kv)
                rounds += Kv
            z = np.copy(z_new)

        loss_diverse.append(_diverse_mean_distance(z, z_star, m))
        communication_diverse.append(communication_diverse[-1] + rounds)

        if run_type == 'sample':
            if loss_diverse[-1] < 1e-10 or LIFO_diverse[-1] > run_max:
                break
        elif run_type == 'time':
            if loss_diverse[-1] < 1e-10 or time_diverse[-1] > run_max:
                break
        else:
            if loss_diverse[-1] < 1e-10 or communication_diverse[-1] > run_max:
                break

        if i % 100 == 0:
            print('DIVERSE iters', i, 'Loss', loss_diverse[-1])

    return loss_diverse, LIFO_diverse, time_diverse, communication_diverse

