#!/usr/bin/env python3
import typing

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from scipy.linalg import eigh
from sim import compute_delta_and_L_and_mu, generate_delta_related_matrices

# Loose aliases for NumPy arrays used in signatures; neither constrains shape or dtype.
Vector = npt.NDArray
Matrix = npt.NDArray


# Enlarge the default font so axis labels stay readable in the saved figures.
plt.rcParams["font.size"] = 16


class Loss:
    """Interface for the objectives consumed by the solvers in this module.

    The base class only stores the regularization strength and a placeholder
    dimension. Every other method must be overridden by a subclass; calling
    one on the base class raises ``NotImplementedError``.

    Args:
        l2: regularization strength stored as ``self.l2``.
        *args, **kwargs: accepted and ignored so subclasses can forward freely.
    """

    def __init__(self, l2: float, *args, **kwargs):
        self.l2 = l2
        # Subclasses overwrite this once the problem size is known.
        self.dim = 0

    def grad(self, x: Vector) -> Vector:
        """Return the gradient of the loss at ``x``."""
        raise NotImplementedError("{} must implement grad()".format(type(self).__name__))

    def prox(self, x: Vector, x0: Vector, eta: float) -> Vector:
        """Return the proximal step of the loss around ``x0`` with step ``eta``."""
        raise NotImplementedError("{} must implement prox()".format(type(self).__name__))

    def smoothness(self) -> float:
        """Return the smoothness constant of the loss."""
        raise NotImplementedError(
            "{} must implement smoothness()".format(type(self).__name__)
        )

    def get_solution(self) -> Vector:
        """Return the exact minimizer of the loss."""
        raise NotImplementedError(
            "{} must implement get_solution()".format(type(self).__name__)
        )


class RegularizedLinearRegression(Loss):
    """Quadratic loss f(x) = 1/2 x^T A x - b^T x + l2/2 x^T x.

    ``A`` is assumed symmetric: ``grad`` uses ``A @ x`` as the gradient of the
    quadratic term and ``smoothness`` uses ``scipy.linalg.eigh``.

    Args:
        l2: regularization strength.
        A: square matrix of the quadratic term.
        b: vector of the linear term, length ``A.shape[0]``.

    Raises:
        ValueError: if ``A`` is not a square 2-D array.
    """

    def __init__(self, l2: float, A: Matrix, b: Vector, *args, **kwargs):
        super().__init__(l2, *args, **kwargs)
        self.A = A
        if self.A.ndim != 2 or self.A.shape[0] != self.A.shape[1]:
            raise ValueError(
                "A must be a square matrix, got shape {}".format(self.A.shape)
            )
        self.dim = self.A.shape[0]
        self.b = b

    def grad(self, x: Vector) -> Vector:
        """Return A x - b + l2 x."""
        return np.matmul(self.A, x) - self.b + self.l2 * x

    def prox(self, x: Vector, x0: Vector, eta: float) -> Vector:
        """Return argmin_y f(y) + 1/(2 eta) ||y - x0||^2.

        ``x`` is unused and only kept for interface compatibility. ``eta == 0``
        drops the proximal term, so the result is the exact minimizer of f.
        """
        inv_eta = 1.0 / eta if eta != 0 else 0
        # Optimality condition: (A + (l2 + 1/eta) I) y = b + x0 / eta.
        mtx = self.A + np.identity(self.dim) * (self.l2 + inv_eta)
        rhs = self.b + x0 * inv_eta
        # Solving is cheaper and numerically more stable than forming the inverse.
        return np.linalg.solve(mtx, rhs)

    def smoothness(self):
        """Return the largest eigenvalue of A plus l2."""
        evals = eigh(self.A, eigvals_only=True)
        return np.max(evals) + self.l2

    def get_solution(self):
        """Return the exact minimizer of f (a prox step with eta = 0)."""
        zeros = np.zeros(self.dim)
        return self.prox(zeros, zeros, 0)


class Tracer:
    """Record optimizer iterates and plot their squared distance to the solution.

    Args:
        plot_frequency: keep an iterate only when its position is a multiple
            of this value (1 keeps every iterate).
        loss: object whose ``get_solution()`` supplies the reference point x*.
        plot_max_comm: largest x-axis position drawn by ``plot``.
    """

    def __init__(self, plot_frequency: int, loss: Loss, plot_max_comm: int = 10000):
        self.plot_frequency = plot_frequency
        self.plot_max = plot_max_comm
        self.iterate_idxs = []  # x-axis positions, in recording order
        self.iterate_values = []  # copies of the recorded iterates
        self.loss = loss

    def update_trace(self, x: Vector, it: int):
        """Store a copy of ``x`` at position ``it`` if it lies on the sampling grid."""
        if self.plot_frequency == 1 or it % self.plot_frequency == 0:
            self.iterate_idxs.append(it)
            # Copy so that later in-place updates of ``x`` by the caller do not
            # rewrite the stored history.
            self.iterate_values.append(x.copy())

    def plot(self, label):
        """Draw ||x - x*||^2 for every recorded iterate with position <= plot_max.

        The curve and axis labels are added to the current matplotlib axes.
        """
        # Keep exactly the points inside the window (the previous prefix slice
        # took one extra point beyond ``plot_max``).
        visible = [
            (idx, value)
            for idx, value in zip(self.iterate_idxs, self.iterate_values)
            if idx <= self.plot_max
        ]
        x_star = self.loss.get_solution()
        ids_to_plot = [idx for idx, _ in visible]
        dists_to_plot = [np.linalg.norm(value - x_star) ** 2 for _, value in visible]
        # About 20 markers along the part of the curve that is actually drawn.
        markevery = max(1, len(dists_to_plot) // 20)
        plt.plot(
            ids_to_plot,
            dists_to_plot,
            label=label,
            marker=".",
            markevery=markevery,
        )
        plt.ylabel(r"$\Vert x-x^*\Vert^2$")
        plt.xlabel(r"Communications")


def sanity_check():
    """Print one proximal step of a small, hand-checkable regularized quadratic.

    With A = diag(2, 1), b = (1, 2), l2 = 5, x0 = (-1, 1) and eta = 1 the prox
    solves diag(8, 7) y = (0, 3), so the expected output is y = (0, 3/7).
    """
    loss = RegularizedLinearRegression(
        5, np.array([[2, 0], [0, 1]]), np.array([1, 2])
    )
    point = np.array([-1, 1])
    print(loss.prox(point, point, 1))


def define_losses(
    num_agents: int, dimension: int, l2_reg: float, L_smoothness: float, delta: float
):
    """Build random per-agent quadratic losses and their aggregate.

    Each agent gets one matrix from ``generate_delta_related_matrices`` and a
    linear term drawn uniformly from [0, 100)^dimension. The aggregate loss
    pairs ``center_mtx`` with the mean of the agents' linear terms.
    (NOTE(review): ``center_mtx`` is presumably the mean of the agents'
    matrices — confirm in ``sim``.)

    Returns:
        (losses, total_loss, delta). Here ``delta`` is the value returned by
        ``generate_delta_related_matrices``, which replaces the one passed in.
    """
    center_mtx, matrices, _smoothness, delta, _mu = generate_delta_related_matrices(
        num_agents, dimension, 0, L_smoothness, delta
    )
    losses = []
    linear_terms = []
    for agent_mtx in matrices:
        agent_b = np.random.rand(dimension) * 100
        linear_terms.append(agent_b)
        losses.append(RegularizedLinearRegression(l2_reg, agent_mtx, agent_b))
    # Running sum in agent order, starting from a zero vector.
    mean_b = sum(
        (agent_b / num_agents for agent_b in linear_terms), np.zeros(dimension)
    )
    total_loss = RegularizedLinearRegression(l2_reg, center_mtx, mean_b)
    return losses, total_loss, delta


def gd(loss, step_size, num_steps, num_agents, tracer: Tracer):
    """Full-batch gradient descent from the origin.

    Every step is charged ``num_agents`` communications in the trace.

    Args:
        loss: objective exposing ``dim`` and ``grad``.
        step_size: constant step size.
        num_steps: number of gradient steps.
        num_agents: communications charged per step.
        tracer: receives the initial point and every iterate.

    Returns:
        The final iterate.
    """
    x = np.zeros(loss.dim)
    tracer.update_trace(x, 0)
    for i in range(num_steps):
        x -= step_size * loss.grad(x)
        # After step i, (i + 1) rounds have been communicated. Recording at
        # i * num_agents would place the first step on top of the initial point.
        tracer.update_trace(x, (i + 1) * num_agents)
    return x


def sgd(losses, step_size, num_steps):
    """Plain SGD from the origin, sampling one loss uniformly at random per step.

    Returns the final iterate.
    """
    n_losses = len(losses)
    iterate = np.zeros(losses[0].dim)
    for _ in range(num_steps):
        picked = losses[np.random.randint(low=0, high=n_losses)]
        iterate -= step_size * picked.grad(iterate)
    return iterate


def svrg(losses, total_loss, prob, step_size, num_steps, tracer: Tracer):
    """SVRG with fixed-length epochs.

    Each epoch takes ``round(1 / prob)`` stochastic steps with the estimator
    ``grad_s(x) - grad_s(w) + grad_total(w)``. After the epoch, the anchor ``w``
    is reset to the current iterate and the full gradient is recomputed.

    Communication accounting in the trace: each stochastic step costs 1 and
    each full-gradient recomputation costs ``len(losses)``. The initial full
    gradient is not charged.

    Args:
        losses: per-agent losses exposing ``dim`` and ``grad``.
        total_loss: aggregate loss whose ``grad`` is the full gradient.
        prob: inverse epoch length; must lie in (0, 1].
        step_size: constant step size.
        num_steps: budget of stochastic steps, truncated to whole epochs.
        tracer: receives the initial point and every iterate.

    Returns:
        (x, num_recomm): final iterate and number of full-gradient recomputations.

    Raises:
        ValueError: if ``prob`` is not in (0, 1].
    """
    if not 0 < prob <= 1:
        raise ValueError("prob must be in (0, 1], got {}".format(prob))
    num_recomm = 0
    num_agents = len(losses)
    x = np.zeros(losses[0].dim)
    w = x.copy()
    w_full_grad = total_loss.grad(w)
    tracer.update_trace(x, 0)
    # round() rather than int(): 1.0 / prob can fall just below an integer in
    # floating point, and truncating would drop a step from every epoch.
    eff_length = int(round(1.0 / prob))
    num_epochs = int(num_steps / eff_length)
    comm_ctr = 0
    for _ in range(num_epochs):
        for _ in range(eff_length):
            s = np.random.randint(low=0, high=num_agents)
            g = losses[s].grad(x) - losses[s].grad(w) + w_full_grad
            x = x - step_size * g
            # Charge the step before recording, so the first iterate is not
            # plotted at the same position as the initial point.
            comm_ctr += 1
            tracer.update_trace(x, comm_ctr)
        w = x.copy()
        w_full_grad = total_loss.grad(w)
        num_recomm += 1
        comm_ctr += num_agents
    return x, num_recomm


def sppm(losses, step_size, num_steps):
    """Stochastic proximal point method from the origin.

    Each step applies the prox of one uniformly sampled loss, centred at the
    current iterate. Returns the final iterate.
    """
    n_losses = len(losses)
    iterate = np.zeros(losses[0].dim)
    for _ in range(num_steps):
        picked = losses[np.random.randint(low=0, high=n_losses)]
        iterate = picked.prox(iterate, iterate, step_size)
    return iterate


def svrp(losses, total_loss, prob, step_size, num_steps, tracer: Tracer):
    """Stochastic variance-reduced proximal point method (SVRP).

    Each step samples one agent ``s``. It shifts ``x`` by
    ``-step_size * (grad_total(w) - grad_s(w))`` and then applies agent ``s``'s
    prox with the same step size. With probability ``prob``, the anchor ``w``
    is reset to the new iterate and the full gradient is recomputed.

    Communication accounting in the trace: each step costs 1 and each
    recomputation costs ``len(losses)``. The initial full gradient is not
    charged.

    Args:
        losses: per-agent losses exposing ``dim``, ``grad`` and ``prox``.
        total_loss: aggregate loss whose ``grad`` is the full gradient.
        prob: probability of refreshing the anchor after each step.
        step_size: step size of both the shift and the prox.
        num_steps: number of stochastic steps.
        tracer: receives the initial point and every iterate.

    Returns:
        (x, num_recomm): final iterate and number of full-gradient recomputations.
    """
    num_recomm = 0
    num_agents = len(losses)
    x = np.zeros(losses[0].dim)
    w = x.copy()
    w_full_grad = total_loss.grad(w)
    tracer.update_trace(x, 0)
    for i in range(num_steps):
        s = np.random.randint(low=0, high=num_agents)
        g = w_full_grad - losses[s].grad(w)
        x_start = x - step_size * g
        x = losses[s].prox(x_start, x_start, step_size)
        # i + 1 steps have been taken so far, plus every recomputation before
        # this one (recording at i put the first step on top of the start).
        tracer.update_trace(x, i + 1 + num_recomm * num_agents)
        if np.random.binomial(1, prob) == 1:
            w = x.copy()
            w_full_grad = total_loss.grad(w)
            num_recomm += 1
    return x, num_recomm


def scaffold(
    losses,
    total_loss,
    step_size,
    num_steps,
    num_local_steps,
    local_step_size,
    tracer: Tracer,
):
    """SCAFFOLD with one uniformly sampled client per round.

    Each round proceeds as follows:
      1. The sampled client runs ``num_local_steps`` corrected gradient steps
         ``y -= local_step_size * (grad_s(y) - c_s + c)``.
      2. It refreshes its control variate to
         ``c_s - c + (x - y) / (num_local_steps * local_step_size)``.
      3. The server moves ``x`` by ``step_size * (y - x)`` and ``c`` by
         ``(c_s_new - c_s) / num_agents``.

    Args:
        losses: per-client losses exposing ``dim`` and ``grad``.
        total_loss: unused; kept for signature compatibility.
        step_size: global (server) step size.
        num_steps: number of rounds; each round is charged one communication.
        num_local_steps: local steps per round.
        local_step_size: local step size.
        tracer: receives the initial point and the model after every round.

    Returns:
        The final server model.
    """
    num_agents = len(losses)
    dim = losses[0].dim
    x = np.zeros(dim)
    c = np.zeros(dim)
    c_m = [np.zeros(dim) for _ in range(num_agents)]
    tracer.update_trace(x, 0)
    for i in range(num_steps):
        s = np.random.randint(low=0, high=num_agents)
        y = x.copy()
        for _ in range(num_local_steps):
            y -= local_step_size * (losses[s].grad(y) - c_m[s] + c)
        c_new = c_m[s] - c + (x - y) / (num_local_steps * local_step_size)
        dx = y - x
        dc = c_new - c_m[s]
        x = x + step_size * dx
        c = c + (1 / num_agents) * dc
        c_m[s] = c_new
        # Round i is the (i + 1)-th communication; recording at i placed the
        # first round on top of the initial point.
        tracer.update_trace(x, i + 1)
    return x


def accelerated_extragradient(
    losses: typing.List[Loss],
    total_loss: Loss,
    tau,
    theta,
    eta,
    alpha,
    num_steps,
    tracer: Tracer,
):
    """Accelerated extragradient with the first agent's loss as the prox term.

    The method splits ``total_loss = p + q`` with ``q = losses[0]``. The
    gradient of p is evaluated as ``total_loss.grad - q.grad``. Each
    iteration computes:

        xg = tau * x + (1 - tau) * xf
        xf = prox_{theta q}(xg - theta * grad_p(xg))
        x  = x + eta * alpha * (xf - x) - eta * grad_total(xf)

    Only ``losses[0]`` and ``len(losses)`` are used. Each iteration is charged
    ``len(losses)`` communications in the trace.
    NOTE(review): an iteration evaluates ``total_loss.grad`` twice (at ``xg``
    and ``xf``). Confirm whether it should be charged ``2 * len(losses)``.

    Returns:
        The final iterate ``x``.
    """
    num_agents = len(losses)
    dim = losses[0].dim
    loss_q = losses[0]
    xf = np.zeros(dim)
    x = np.zeros(dim)
    tracer.update_trace(x, 0)
    for i in range(num_steps):
        xg = tau * x + (1 - tau) * xf
        nabla_p = total_loss.grad(xg) - loss_q.grad(xg)
        query_point = xg - theta * nabla_p
        xf = loss_q.prox(query_point, query_point, theta)
        x = x + eta * alpha * (xf - x) - eta * total_loss.grad(xf)
        # After iteration i, (i + 1) rounds have been charged. Recording at
        # i * num_agents put the first iterate on top of the initial point.
        tracer.update_trace(x, (i + 1) * num_agents)
    return x


# def catalyzed_svrp(
#     losses: typing.List[Loss], total_loss: Loss, gamma, eta, num_steps, tracer: Tracer
# ):
#     dim = losses[0].dim
#     x = np.zeros(dim)
#     # bla bla
#     return x


def calculate_max_num_steps(comm, max_comm):
    """Return the smaller of ``comm`` and ``max_comm`` as a NumPy scalar."""
    candidates = np.asarray([comm, max_comm])
    return candidates.min()


def run_algorithms(
    losses: typing.List[Loss],
    total_loss: Loss,
    delta: float,
    file_save_name: str,
    dim: int,
    max_num_communications: int,
    num_agents: int,
):
    """Run SVRG, SCAFFOLD, SVRP and accelerated extragradient, then plot them.

    Hyper-parameters are derived from three quantities:
      - the largest per-agent smoothness;
      - ``total_loss.l2``;
      - ``delta``.

    Each method's final distance to the exact minimizer of ``total_loss`` is
    printed. The squared-distance curves are saved to ``file_save_name``.

    Args:
        losses: per-agent losses.
        total_loss: aggregate loss; its exact minimizer is the reference x*.
        delta: constant used in the SVRP and accelerated-extragradient step sizes.
        file_save_name: path passed to ``plt.savefig``.
        dim: problem dimension.
        max_num_communications: communication budget used to size each run.
        num_agents: number of agents.
    """
    # A prox step with eta = 0 returns the exact minimizer x*.
    x_star = total_loss.prox(np.zeros(dim), np.zeros(dim), 0)
    sigma_star = np.mean([np.linalg.norm(loss.grad(x_star)) ** 2 for loss in losses])
    print("sigma_* = {}".format(sigma_star))
    max_smoothness = np.max([loss.smoothness() for loss in losses])
    l2_reg = total_loss.l2

    # --- SVRG: epochs of ~num_agents stochastic steps, step size 1 / (2 L_max).
    svrg_tracer = Tracer(1, total_loss)
    num_steps = max_num_communications * 2
    x_svrg, num_recomm = svrg(
        losses,
        total_loss,
        1 / num_agents,
        1 / (2 * max_smoothness),
        num_steps,
        svrg_tracer,
    )
    # NOTE: svrg truncates num_steps to whole epochs; the counts printed here
    # are based on the requested budget.
    print(
        "SVRG Accuracy: {} after {} steps and {} communication steps".format(
            np.linalg.norm(x_svrg - x_star),
            num_steps,
            num_steps + num_recomm * num_agents,
        )
    )

    # --- SCAFFOLD: one client per round, local step size capped by both
    # the smoothness and the regularization-based bound.
    scaffold_tracer = Tracer(1, total_loss)
    global_step_size = 1
    num_local_steps = 50
    local_step_size = np.min(
        [
            1 / (max_smoothness * num_local_steps * global_step_size),
            1 / (l2_reg * num_agents * num_local_steps * global_step_size),
        ]
    )
    num_steps = max_num_communications
    x_scaffold = scaffold(
        losses,
        total_loss,
        global_step_size,
        num_steps,
        num_local_steps,
        local_step_size,
        scaffold_tracer,
    )
    print(
        "SCAFFOLD Accuracy: {} after {} steps and {} communication steps".format(
            np.linalg.norm(x_scaffold - x_star),
            num_steps,
            num_steps,
        )
    )

    # --- SVRP: step size l2 / (2 delta^2), anchor refresh probability 1/num_agents.
    svrp_tracer = Tracer(1, total_loss)
    step_size = l2_reg / (2 * delta**2)
    num_steps = max_num_communications
    x_svrp, num_recomm = svrp(
        losses, total_loss, 1 / num_agents, step_size, num_steps, svrp_tracer
    )
    print(
        "SVRP Accuracy: {} after {} steps and {} communication steps".format(
            np.linalg.norm(x_svrp - x_star),
            num_steps,
            num_steps + num_recomm * num_agents,
        )
    )

    # --- Accelerated extragradient: each iteration is charged num_agents
    # communications.
    acc_exg_tracer = Tracer(1, total_loss)
    tau = np.min([1, np.sqrt(l2_reg / delta) / 2])
    theta = 1 / (2 * delta)
    eta = np.min([1 / (2 * l2_reg), 1 / (2 * np.sqrt(l2_reg * delta))])
    alpha = l2_reg
    num_steps = int(max_num_communications / num_agents) + 1
    x_acc_exg = accelerated_extragradient(
        losses, total_loss, tau, theta, eta, alpha, num_steps, acc_exg_tracer
    )
    print(
        "Accelerated Extragradient Accuracy: {} after {} steps and {} communication steps".format(
            np.linalg.norm(x_acc_exg - x_star),
            num_steps,
            num_steps * num_agents,
        )
    )

    # --- Plot all curves on one log-scale figure.
    fig = plt.figure(figsize=(9, 6))
    svrg_tracer.plot(label="SVRG")
    acc_exg_tracer.plot(label="Accelerated Extragradient")
    svrp_tracer.plot(label="SVRP")
    scaffold_tracer.plot(label="SCAFFOLD")
    plt.yscale("log")
    plt.legend()
    plt.savefig(file_save_name)
    # This function runs once per experiment; pyplot keeps every figure alive
    # until it is closed explicitly.
    plt.close(fig)
    print("\n")


def synthetic_data_experiment():
    """Run every algorithm on synthetic quadratics for 1000, 2000 and 3000 agents.

    All three runs share dimension 100, l2 = 1, L_smoothness = 3000, delta = 20
    and a budget of 10000 communications. The figures are written to
    ``sen_exp1.pdf``, ``sen_exp2.pdf`` and ``sen_exp3.pdf``.
    """
    dim = 100
    l2_reg = 1
    max_num_communications = 10000
    runs = (
        (1000, "sen_exp1.pdf"),
        (2000, "sen_exp2.pdf"),
        (3000, "sen_exp3.pdf"),
    )
    for agent_count, figure_path in runs:
        losses, total_loss, delta = define_losses(
            agent_count, dim, l2_reg, L_smoothness=3000, delta=20
        )
        run_algorithms(
            losses,
            total_loss,
            delta,
            figure_path,
            dim,
            max_num_communications,
            agent_count,
        )


import dsdl


def unison_shuffled_copies(a, b):
    """Shuffle ``a`` and ``b`` along their first axis with one shared permutation.

    Row k of ``a`` stays paired with row k of ``b`` after the shuffle.

    Raises:
        ValueError: if ``a`` and ``b`` have different lengths. (An explicit
            check is used because ``assert`` is stripped under ``python -O``.)
    """
    if len(a) != len(b):
        raise ValueError(
            "Arrays must have equal length, got {} and {}".format(len(a), len(b))
        )
    perm = np.random.permutation(len(a))
    return a[perm], b[perm]


def generate_losses_from_real_data(
    l2_reg: float,
    num_agents: int,
    dataset_name: str = "a9a",
    num_data_per_agent: int = 2000,
):
    """Split a real dataset across agents as regularized quadratic losses.

    Steps:
      1. Load the training split from ``dsdl.load(dataset_name)``.
      2. Tile it until it holds at least ``num_agents * num_data_per_agent``
         rows, then shuffle it.
      3. Cut it into consecutive chunks of ``num_data_per_agent`` rows, one per
         agent. Agent i gets ``A_i = X_i^T X_i / n`` and ``b_i = 2 X_i^T y_i / n``.

    Args:
        l2_reg: regularization strength of every loss.
        num_agents: number of agents (must be >= 1).
        dataset_name: name passed to ``dsdl.load``.
        num_data_per_agent: rows given to each agent. The default of 2000 is
            the previously hard-coded value.

    Returns:
        (losses, total_loss, dim, delta, L_smoothness, mu). The last three come
        from ``compute_delta_and_L_and_mu``.

    Raises:
        ValueError: if ``num_agents`` < 1.
    """
    if num_agents < 1:
        raise ValueError("num_agents must be at least 1, got {}".format(num_agents))
    ds = dsdl.load(dataset_name)
    X, y = ds.get_train()
    # ``toarray`` suggests X arrives as a scipy sparse matrix — densify it.
    X = X.toarray()
    num_data_p = X.shape[0]
    # Number of full copies needed so every agent gets its share of rows.
    # np.ceil replaces np.math.ceil: the np.math alias was removed in NumPy 2.0.
    num_stacks = int(np.ceil(num_agents * num_data_per_agent / num_data_p))
    X = np.vstack([X] * num_stacks)
    y = np.concatenate([y] * num_stacks)
    X, y = unison_shuffled_copies(X, y)
    dim = X.shape[1]

    print("Number of data points per agent = {}".format(num_data_per_agent))

    loss_matrices = []
    loss_vecs = []
    losses = []
    for i in range(num_agents):
        rows = slice(i * num_data_per_agent, (i + 1) * num_data_per_agent)
        Xi = X[rows]
        yi = y[rows]
        Ai = np.dot(Xi.T, Xi) / num_data_per_agent
        # NOTE(review): under f(x) = 1/2 x^T A x - b^T x, the factor 2 makes each
        # agent fit targets 2*y rather than y. Confirm this is intended.
        bi = np.dot(Xi.T, yi) * 2 / num_data_per_agent
        losses.append(RegularizedLinearRegression(l2_reg, Ai, bi))
        loss_matrices.append(Ai)
        loss_vecs.append(bi)
    # Average of the per-agent quadratics, accumulated in agent order.
    avg_A = np.zeros_like(loss_matrices[0])
    avg_b = np.zeros_like(loss_vecs[0])
    for Ai, bi in zip(loss_matrices, loss_vecs):
        avg_A += Ai / num_agents
        avg_b += bi / num_agents
    total_loss = RegularizedLinearRegression(l2_reg, avg_A, avg_b)
    delta, L_smoothness, mu = compute_delta_and_L_and_mu(loss_matrices, avg_A, dim)
    return losses, total_loss, dim, delta, L_smoothness, mu


def real_data_experiment():
    """Run every algorithm on a9a split across 20, 40 and 60 agents.

    Each run uses l2 = 0.1 and a budget of 10000 communications. Figures are
    written to ``real_exp1.pdf``, ``real_exp2.pdf`` and ``real_exp3.pdf``, in
    the order of the agent counts.
    """
    l2_reg = 0.1
    dataset_name = "a9a"
    max_num_communications = 10000
    for fig_number, agent_count in enumerate([20, 40, 60], start=1):
        losses, total_loss, dim, delta, L_smoothness, mu = (
            generate_losses_from_real_data(l2_reg, agent_count, dataset_name)
        )
        print(
            "L={}, delta={}, mu={}, lambda={}, L/delta = {}".format(
                L_smoothness, delta, mu, l2_reg, L_smoothness / delta
            )
        )
        run_algorithms(
            losses,
            total_loss,
            delta,
            "real_exp{}.pdf".format(fig_number),
            dim,
            max_num_communications,
            agent_count,
        )
        print("\n")


if __name__ == "__main__":
    # Seed NumPy's global RNG once so that both experiment suites are reproducible.
    np.random.seed(200)
    for experiment in (synthetic_data_experiment, real_data_experiment):
        experiment()
