import numpy as np


class Worker:
    """Per-client state and local routines for several distributed optimizers.

    One instance holds the local iterates and auxiliary vectors of a single
    worker. The ``run_*`` / ``local_solve*`` methods perform the worker-side
    computation of one round; the ``update_*`` methods apply the correction
    that follows once the server-side quantities are known.

    Attributes that start as ``None`` (``x``, ``y``, ``h``, ``u``, ``v``,
    ``scaffold_local_c``, ``five_gcs_u``) are either initialised lazily by the
    method that documents doing so, or must be assigned by the caller before
    the corresponding method is used.
    """

    # Gradient-norm threshold below which the inner solver used by
    # ``local_solve`` / ``local_solve_cc`` stops early.
    _INNER_SOLVE_TOL = 1e-10

    def __init__(self, loss=None, loss_y=None, q=None, loss_coeff=1):
        """Create a worker.

        Args:
            loss: object exposing ``gradient(x)`` and
                ``stochastic_gradient(x, batch_size=...)``.
            loss_y: second loss object with the same interface; only used by
                ``run_loco_local``.
            q: parameter of the geometric draw in ``run_gradskip_local``;
                ``q == 1.`` disables the draw.
            loss_coeff: multiplier applied to gradients of ``loss`` in
                ``run_loco_local``.
        """
        self.loss = loss
        self.loss_y = loss_y

        self.scaffold_local_c = None
        self.h = None
        self.q = q
        self.x = None
        self.y = None
        # Shift vectors of the loco routines; declared here so that every
        # state attribute exists on a fresh instance.
        self.u = None
        self.v = None
        self.five_gcs_u = None
        self.loss_coeff = loss_coeff
        # TODO: check for r=0
        self.loss_y_coeff = 1 if loss_coeff == 1 else 1 - loss_coeff

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def _shifted_gradient_steps(self, lr, local_steps):
        """Run ``local_steps`` full-gradient steps on ``self.x`` shifted by ``self.h``."""
        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.h)

    def _inner_solve(self, local_steps, mu, n_workers, local_solver_lr, dual_lr):
        """Run the inner gradient loop on ``self.y`` and return the new dual candidate.

        ``self.y`` is restarted from a copy of ``self.x``. The loop stops early
        once the gradient norm is at most ``_INNER_SOLVE_TOL``.

        Returns:
            ``(loss.gradient(y) - mu * y) / n_workers`` evaluated at the final ``y``.
        """
        self.y = np.copy(self.x)
        for _ in range(local_steps):
            grad = (self.loss.gradient(self.y) - mu * self.y) / n_workers + \
                dual_lr * (self.y - self.x - self.five_gcs_u / dual_lr)
            if np.linalg.norm(grad) <= self._INNER_SOLVE_TOL:
                break
            self.y -= local_solver_lr * grad
        return (self.loss.gradient(self.y) - mu * self.y) / n_workers

    # ------------------------------------------------------------------ #
    # gradskip
    # ------------------------------------------------------------------ #
    def run_gradskip_local(self, x,
                           p,
                           local_steps,
                           lr,
                           batch_size,
                           ):
        """Run the local phase of gradskip starting from ``x``.

        ``self.x`` is reset to a float copy of ``x`` and ``self.h`` is created
        as zeros on first use. Unless ``self.q == 1.``, a geometric draw with
        success probability ``1 - q`` is taken; if it does not exceed
        ``local_steps`` only ``draw - 1`` steps are run and ``self.h`` is then
        replaced by the full gradient at the resulting point.

        Args:
            x: starting point.
            p: divisor of the shift term in the returned vector.
            local_steps: maximum number of local steps.
            lr: step size.
            batch_size: ``None`` for full gradients, otherwise forwarded to
                ``loss.stochastic_gradient``.

        Returns:
            ``self.x - lr * self.h / p``.
        """
        self.x = x * 1.

        if self.h is None:
            # initial step
            self.h = self.x * 0.

        refresh_shift = False
        it_q = np.inf if self.q == 1. else np.random.geometric(1 - self.q)
        if it_q <= local_steps:
            local_steps = it_q - 1
            refresh_shift = True

        for _ in range(local_steps):
            if batch_size is None:
                g = self.loss.gradient(self.x)
            else:
                g = self.loss.stochastic_gradient(
                    self.x, batch_size=batch_size)
            self.x -= lr * (g - self.h)

        if refresh_shift:
            self.h = self.loss.gradient(self.x)

        return self.x - lr * self.h / p

    def update_gradskip_params(self, x_av, p, lr,):
        """Move ``self.h`` by ``p / lr * (x_av - self.x)`` in place."""
        self.h += p / lr * (x_av - self.x)

    # ------------------------------------------------------------------ #
    # loco
    # ------------------------------------------------------------------ #
    def run_loco_local(self, x,
                       gamma,
                       local_steps,
                       batch_size,
                       ):
        """Run the local phase of loco on the pair ``(self.x, self.y)``.

        On the first call (``self.x is None``) both iterates are initialised
        as copies of ``x`` and the shifts ``u`` and ``v`` as zeros; on later
        calls the argument ``x`` is not read.

        Args:
            x: initial point, used only on the first call.
            gamma: step size.
            local_steps: number of local steps.
            batch_size: ``None`` for full gradients, otherwise forwarded to
                the ``stochastic_gradient`` of both losses.

        Returns:
            The tuple ``(self.x, self.y)``; these are the worker's own arrays,
            not copies.
        """
        if self.x is None:
            # initial step
            self.x = np.copy(x)
            self.y = np.copy(x)
            self.u = self.x * 0.
            self.v = self.x * 0.

        for _ in range(local_steps):
            if batch_size is None:
                g = self.loss_coeff * self.loss.gradient(self.x)
                g_y = self.loss_y_coeff * self.loss_y.gradient(self.y)
            else:
                g = self.loss_coeff * \
                    self.loss.stochastic_gradient(
                        self.x, batch_size=batch_size)
                g_y = self.loss_y_coeff * \
                    self.loss_y.stochastic_gradient(
                        self.y, batch_size=batch_size)

            self.x -= gamma * (g - self.u)
            self.y -= gamma * (g_y - self.v)

        return self.x, self.y

    def update_loco_params(self,
                           gamma, rho, chi,
                           p,
                           d, d_bar,
                           w,
                           ):
        """Apply the post-communication update to ``x``, ``y``, ``u`` and ``v``.

        ``x`` is rebound to a new array; ``y``, ``u`` and ``v`` are updated in
        place. The new ``x`` is computed from the value of ``y`` *before*
        ``y`` itself is updated.
        """
        self.x = (1.0 - rho) * self.x + rho * (self.y + d_bar)
        self.y += rho * d_bar

        temp = p * chi / (gamma * (1 + 2 * w))
        self.u += temp * (d_bar - d)
        self.v += temp * d_bar

    # ------------------------------------------------------------------ #
    # tamuna / compressed scaffnew / scaffnew
    # ------------------------------------------------------------------ #
    def run_tamuna_local(self, lr, local_steps):
        """Run ``local_steps`` shifted gradient steps and return ``self.x``."""
        self._shifted_gradient_steps(lr, local_steps)
        return self.x

    def update_tamuna_variates(self, lr, eta, server_compressed, worker_compressed):
        """Move ``self.h`` by ``eta / lr * (server_compressed - worker_compressed)``."""
        self.h += (eta / lr) * (server_compressed - worker_compressed)

    def run_compressed_scaffnew_local(self, lr, local_steps, q):
        """Run shifted gradient steps and return ``self.x`` multiplied elementwise by ``q``."""
        self._shifted_gradient_steps(lr, local_steps)
        return np.multiply(self.x, q)

    def update_compressed_scaffnew_variates(self, x, lr, eta, q, worker_result):
        """Move ``self.h`` by ``eta / lr * (x * q - worker_result)``."""
        self.h += (eta / lr) * (np.multiply(x, q) - worker_result)

    def run_scaffnew(self, lr, local_steps):
        """Run ``local_steps`` shifted gradient steps and return ``self.x``."""
        self._shifted_gradient_steps(lr, local_steps)
        return self.x

    def update_scaffnew_variates(self, p, lr, server_x):
        """Move ``self.h`` by ``p / lr * (server_x - self.x)`` in place."""
        self.h += p / lr * (server_x - self.x)

    # ------------------------------------------------------------------ #
    # scaffold / plain gradient descent
    # ------------------------------------------------------------------ #
    def run_scaffold(self, lr, local_steps, server_c):
        """Run the scaffold local phase and refresh the local control variate.

        ``self.x`` and ``self.scaffold_local_c`` must already be arrays.

        Returns:
            ``self.x`` after the local steps.
        """
        # Snapshot of the starting point; needed for the control-variate
        # update after self.x has been modified in place.
        server_x = np.copy(self.x)

        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.scaffold_local_c + server_c)

        self.scaffold_local_c += (server_x - self.x) / \
            (local_steps * lr) - server_c
        return self.x

    def run_gd(self, lr, local_steps):
        """Run ``local_steps`` plain gradient steps on ``self.x`` and return it."""
        for _ in range(local_steps):
            self.x -= lr * self.loss.gradient(self.x)
        return self.x

    # ------------------------------------------------------------------ #
    # diana / adiana
    # ------------------------------------------------------------------ #
    def run_diana(self, x, dual_lr, compressor):
        """Compress the shifted gradient at ``x`` and update the shift.

        Args:
            x: point at which the gradient is evaluated.
            dual_lr: step size of the shift update.
            compressor: object exposing ``compress(vector)``.

        Returns:
            The compressed message ``compressor.compress(grad - h)``.
        """
        if self.h is None:
            # NOTE(review): h is allocated as float32 regardless of x's dtype,
            # so the in-place update below is stored in single precision.
            # Confirm this is intended.
            self.h = np.zeros_like(x, dtype=np.float32)

        grad = self.loss.gradient(x)
        m = compressor.compress(grad - self.h)
        self.h += dual_lr * m
        return m

    def run_adiana(self, z, w, alpha, compressor):
        """Compress the shifted gradients at ``z`` and ``w``; update the shift with the latter.

        Both messages are computed against the shift value from before the
        update.

        Returns:
            Tuple ``(m_1, m_2)`` of compressed messages for ``z`` and ``w``.
        """
        if self.h is None:
            # NOTE(review): float32 storage for h, see run_diana — confirm.
            self.h = np.zeros_like(z, dtype=np.float32)

        m_1 = compressor.compress(self.loss.gradient(z) - self.h)
        m_2 = compressor.compress(self.loss.gradient(w) - self.h)
        self.h += alpha * m_2
        return m_1, m_2

    # ------------------------------------------------------------------ #
    # five_gcs
    # ------------------------------------------------------------------ #
    def local_solve(self, local_steps, mu, n_workers, local_solver_lr, dual_lr):
        """Run the inner solver and replace ``self.five_gcs_u`` with its result.

        ``self.x`` and ``self.five_gcs_u`` must already be arrays.

        Returns:
            The new ``self.five_gcs_u``.
        """
        self.five_gcs_u = self._inner_solve(
            local_steps, mu, n_workers, local_solver_lr, dual_lr)
        return self.five_gcs_u

    def local_solve_cc(self, local_steps, mu, n_workers, local_solver_lr, dual_lr, cohort_size, compressor):
        """Run the inner solver and move ``self.five_gcs_u`` by a compressed difference.

        Args:
            local_steps, mu, n_workers, local_solver_lr, dual_lr: as in
                ``local_solve``.
            cohort_size: numerator of the ``cohort_size / n_workers`` factor
                in the update.
            compressor: object exposing ``compress(vector)`` and an attribute
                ``w`` used in the ``1 / (1 + w)`` damping factor.

        Returns:
            The compressed difference between the solver result and the
            previous ``self.five_gcs_u``.
        """
        u_bar = self._inner_solve(
            local_steps, mu, n_workers, local_solver_lr, dual_lr)
        compressed_difference = compressor.compress(u_bar - self.five_gcs_u)
        self.five_gcs_u += 1/(1 + compressor.w) * \
            (cohort_size / n_workers) * compressed_difference
        return compressed_difference
