import numpy as np
import copy

class Worker:
    """Per-client state and local routines for several distributed optimizers.

    An instance stores a local objective ``loss`` (and optionally a second
    objective ``loss_y``), the local iterates ``x`` / ``y`` and the control
    variates the individual routines read and update (``h``, ``u``, ``v``,
    ``scaffold_local_c``, ``five_gcs_u``).

    ``loss`` / ``loss_y`` must provide ``gradient(point)``; the routines that
    take a ``batch_size`` additionally call
    ``stochastic_gradient(point, batch_size=...)`` when it is not ``None``.

    Unless a method states that it initialises a piece of state itself, the
    caller is expected to assign it (e.g. ``worker.x = ...``) before the call:
    all state starts out as ``None``.
    """

    def __init__(self, loss=None, loss_y=None, q=None, loss_coeff=1):
        """Store the objectives and reset all algorithm state to ``None``.

        Args:
            loss: objective whose gradients drive the ``x`` iterate.
            loss_y: objective whose gradients drive the ``y`` iterate in
                ``run_loco_local``.
            q: parameter read by ``run_gradskip_local``; ``1 - q`` is passed
                to ``np.random.geometric`` unless ``q == 1``.
            loss_coeff: weight applied to ``loss`` gradients in
                ``run_loco_local``; ``loss_y`` gradients are weighted by
                ``1`` if ``loss_coeff == 1`` and by ``1 - loss_coeff``
                otherwise.
        """
        self.loss = loss
        self.loss_y = loss_y

        self.scaffold_local_c = None
        self.h = None
        self.q = q
        self.x = None
        self.y = None
        # Declared here (filled in by run_loco_local on its first call) so
        # that every piece of worker state is visible in one place.
        self.u = None
        self.v = None
        self.five_gcs_u = None
        self.loss_coeff = loss_coeff
        self.loss_y_coeff = 1 if loss_coeff == 1 else 1 - loss_coeff  # todo check for r=0

    def run_gradskip_local(self, x,
                           p,
                           local_steps,
                           lr,
                           batch_size,
                           ):
        """Run shifted gradient steps from ``x`` with a randomly cut budget.

        Unless ``self.q == 1``, a geometric sample with success probability
        ``1 - self.q`` is drawn.  If it does not exceed ``local_steps``, only
        ``sample - 1`` steps are run and ``self.h`` is then overwritten with
        the full gradient at the resulting point.

        Args:
            x: starting point; ``self.x`` is set to ``x * 1.``.
            p: divisor applied to the ``lr * h`` correction in the return value.
            local_steps: maximum number of local steps.
            lr: step size.
            batch_size: ``None`` for full gradients, otherwise forwarded to
                ``loss.stochastic_gradient``.

        Returns:
            ``self.x - lr * self.h / p`` after the local steps.
        """
        # For array input the multiplication yields a new array, so the
        # in-place updates below do not touch the caller's ``x``.
        self.x = x * 1.

        if self.h is None:
            # initial step
            self.h = self.x * 0.

        truncated = False
        it_q = np.inf if self.q == 1. else np.random.geometric(1 - self.q)
        if it_q <= local_steps:
            local_steps = it_q - 1
            truncated = True

        for _ in range(local_steps):
            if batch_size is None:
                g = self.loss.gradient(self.x)
            else:
                g = self.loss.stochastic_gradient(
                    self.x, batch_size=batch_size)
            self.x -= lr * (g - self.h)

        if truncated:
            # The shift is replaced by the exact gradient at the current point.
            self.h = self.loss.gradient(self.x)

        return self.x - lr * self.h / p

    def update_gradskip_params(self, x_av, p, lr,):
        """Move ``self.h`` by ``p / lr * (x_av - self.x)`` in place."""
        self.h += p / lr * (x_av - self.x)

    def run_loco_local(self, x,
                       gamma,
                       local_steps,
                       batch_size,
                       ):
        """Run ``local_steps`` shifted gradient steps on both ``x`` and ``y``.

        On the first call (``self.x is None``) ``self.x`` and ``self.y`` are
        set to copies of ``x`` and ``self.u`` / ``self.v`` to zeros; on later
        calls the argument ``x`` is ignored and the stored state is continued.

        Args:
            x: initial point, used only on the first call.
            gamma: step size.
            local_steps: number of local steps.
            batch_size: ``None`` for full gradients, otherwise forwarded to
                ``stochastic_gradient`` of both objectives.

        Returns:
            The tuple ``(self.x, self.y)`` (the stored arrays, not copies).
        """
        if self.x is None:
            # initial step
            self.x = np.copy(x)
            self.y = np.copy(x)
            self.u = self.x * 0.
            self.v = self.x * 0.

        for _ in range(local_steps):
            if batch_size is None:
                g = self.loss_coeff * self.loss.gradient(self.x)
                g_y = self.loss_y_coeff * self.loss_y.gradient(self.y)
            else:
                g = self.loss_coeff * \
                    self.loss.stochastic_gradient(
                        self.x, batch_size=batch_size)
                g_y = self.loss_y_coeff * \
                    self.loss_y.stochastic_gradient(
                        self.y, batch_size=batch_size)

            self.x -= gamma * (g - self.u)
            self.y -= gamma * (g_y - self.v)

        return self.x, self.y

    def update_bicolor_params(self,
                              gamma, rho, rho_y, eta, eta_y,
                              p, k, d,
                              c_worker, c_server,
                              nonzero_mask
                              ):
        """Update ``x``, ``y``, ``u`` and ``v`` in place from exchanged vectors.

        Only the entries of ``self.x`` selected by ``nonzero_mask`` are
        changed; ``y``, ``u`` and ``v`` are updated in every coordinate.

        Args:
            gamma: step size used to scale the ``u`` / ``v`` updates.
            rho: mixing weight for the masked ``x`` update.
            rho_y: weight of ``c_server`` in the ``y`` update.
            eta, eta_y: scaling factors of the ``u`` and ``v`` updates.
            p, k, d: scalars entering the common factor ``p * k / d``.
            c_worker: vector associated with this worker.
            c_server: vector received from the server.
            nonzero_mask: index/mask selecting the entries of ``x`` to update.
        """
        self.x[nonzero_mask] = (1.0 - rho) * self.x[nonzero_mask] + rho * (c_server[nonzero_mask] + self.y[nonzero_mask])
        self.y += rho_y * c_server
        self.u -= p * k * eta / d / gamma * (c_worker - c_server)
        self.v += p * k * eta_y / d / gamma * c_server

    def update_loco_params(self,
                           gamma, rho, chi,
                           p,
                           d, d_bar,
                           w,
                           ):
        """Update ``x``, ``y``, ``u`` and ``v`` from ``d`` and ``d_bar``.

        ``self.x`` is rebound to a new array; ``y``, ``u`` and ``v`` are
        updated in place.  The ``u`` / ``v`` updates share the factor
        ``p * chi / (gamma * (1 + 2 * w))``.

        Args:
            gamma: step size.
            rho: mixing weight for the ``x`` / ``y`` updates.
            chi: scaling of the control-variate updates.
            p: scalar multiplier of the control-variate updates.
            d: vector associated with this worker.
            d_bar: aggregated vector.
            w: scalar entering the ``1 + 2 * w`` denominator.
        """
        self.x = (1.0 - rho) * self.x + rho * (self.y + d_bar)
        self.y += rho * d_bar

        temp = p * chi / (gamma * (1 + 2 * w))
        self.u += temp * (d_bar - d)
        self.v += temp * d_bar

    def run_tamuna_local(self, lr, local_steps):
        """Run ``local_steps`` steps ``x -= lr * (grad(x) - h)``; return ``self.x``.

        ``self.x`` and ``self.h`` must have been assigned beforehand.
        """
        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.h)
        return self.x

    def update_tamuna_variates(self, lr, eta, server_compressed, worker_compressed):
        """Move ``self.h`` by ``eta / lr * (server_compressed - worker_compressed)``."""
        self.h += (eta / lr) * (server_compressed - worker_compressed)

    def run_compressed_scaffnew_local(self, lr, local_steps, q):
        """Run shifted gradient steps and return ``self.x`` multiplied by ``q``.

        ``self.x`` and ``self.h`` must have been assigned beforehand.  ``q``
        is combined with ``self.x`` via elementwise ``np.multiply``.
        """
        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.h)
        return np.multiply(self.x, q)

    def update_compressed_scaffnew_variates(self, x, lr, eta, q, worker_result):
        """Move ``self.h`` by ``eta / lr * (x * q - worker_result)``."""
        self.h += (eta / lr) * (np.multiply(x, q) - worker_result)

    def run_scaffold(self, lr, local_steps, server_c):
        """Run corrected local steps and refresh the local control variate.

        Each step is ``x -= lr * (grad(x) - scaffold_local_c + server_c)``.
        Afterwards ``scaffold_local_c`` is increased by
        ``(x_start - x_end) / (local_steps * lr) - server_c``.

        ``self.x`` must have been assigned beforehand.  If
        ``scaffold_local_c`` has never been set it starts from zeros.

        Args:
            lr: step size.
            local_steps: number of local steps (must be non-zero, it is used
                as a divisor).
            server_c: server-side control variate.

        Returns:
            ``self.x`` after the local steps.
        """
        if self.scaffold_local_c is None:
            # Lazy zero start, same pattern as ``h`` in run_gradskip_local;
            # without it the arithmetic below fails on ``None``.
            self.scaffold_local_c = self.x * 0.

        server_x = np.copy(self.x)

        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.scaffold_local_c + server_c)

        self.scaffold_local_c += (server_x - self.x) / \
            (local_steps * lr) - server_c
        return self.x

    def run_gd(self, lr, local_steps):
        """Run ``local_steps`` plain gradient steps on ``self.x`` and return it."""
        for _ in range(local_steps):
            self.x -= lr * self.loss.gradient(self.x)
        return self.x

    def run_diana(self, x, dual_lr, compressor):
        """Compress the shifted gradient at ``x`` and update the shift.

        ``self.h`` is created as float32 zeros shaped like ``x`` on first use.

        Args:
            x: point at which the gradient is evaluated.
            dual_lr: step size of the shift update ``h += dual_lr * m``.
            compressor: object providing ``compress(vector)``.

        Returns:
            The compressed message ``m = compress(grad(x) - h)`` (computed
            with the shift as it was before the update).
        """
        if self.h is None:
            self.h = np.zeros_like(x, dtype=np.float32)

        grad = self.loss.gradient(x)
        m = compressor.compress(grad - self.h)
        self.h += dual_lr * m
        return m

    def run_adiana(self, z, w, alpha, compressor):
        """Compress shifted gradients at ``z`` and ``w``; update the shift.

        ``self.h`` is created as float32 zeros shaped like ``z`` on first use.
        Both messages use the shift as it was before the update, and only the
        second message moves it (``h += alpha * m_2``).

        Returns:
            The tuple ``(m_1, m_2)`` of compressed messages for ``z`` and ``w``.
        """
        if self.h is None:
            self.h = np.zeros_like(z, dtype=np.float32)

        m_1 = compressor.compress(self.loss.gradient(z) - self.h)
        m_2 = compressor.compress(self.loss.gradient(w) - self.h)
        self.h += alpha * m_2
        return m_1, m_2

    def run_scaffnew(self, lr, local_steps):
        """Run ``local_steps`` steps ``x -= lr * (grad(x) - h)``; return ``self.x``.

        ``self.x`` and ``self.h`` must have been assigned beforehand.
        """
        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.h)
        return self.x

    def update_scaffnew_variates(self, p, lr, server_x):
        """Move ``self.h`` by ``p / lr * (server_x - self.x)`` in place."""
        self.h += p / lr * (server_x - self.x)

    def _solve_local_subproblem(self, local_steps, mu, n_workers, local_solver_lr, dual_lr):
        """Shared inner loop of ``local_solve`` and ``local_solve_cc``.

        Resets ``self.y`` to a copy of ``self.x`` and runs at most
        ``local_steps`` descent steps, stopping early once the step direction
        has norm at most ``1e-10``.

        Returns:
            ``(grad(y) - mu * y) / n_workers`` at the final ``self.y``.
        """
        self.y = np.copy(self.x)
        for _ in range(local_steps):
            grad = (self.loss.gradient(self.y) - mu * self.y) / n_workers + \
                dual_lr * (self.y - self.x - self.five_gcs_u / dual_lr)
            if np.linalg.norm(grad) <= 1e-10:
                break
            self.y -= local_solver_lr * grad
        return (self.loss.gradient(self.y) - mu * self.y) / n_workers

    def local_solve(self, local_steps, mu, n_workers, local_solver_lr, dual_lr):
        """Run the local solver and overwrite ``five_gcs_u`` with its result.

        ``self.x`` and ``self.five_gcs_u`` must have been assigned beforehand.

        Args:
            local_steps: maximum number of inner steps.
            mu: coefficient of the ``mu * y`` term subtracted from the gradient.
            n_workers: divisor applied to the gradient term.
            local_solver_lr: step size of the inner steps.
            dual_lr: weight of the ``y - x`` term (must be non-zero, it is
                also used as a divisor).

        Returns:
            The new ``self.five_gcs_u``.
        """
        self.five_gcs_u = self._solve_local_subproblem(
            local_steps, mu, n_workers, local_solver_lr, dual_lr)
        return self.five_gcs_u

    def local_solve_cc(self, local_steps, mu, n_workers, local_solver_lr, dual_lr, cohort_size, compressor):
        """Run the local solver and move ``five_gcs_u`` by a compressed step.

        Same inner solve as ``local_solve``, but instead of overwriting
        ``five_gcs_u`` the difference to the solver result is compressed and
        added with weight ``1 / (1 + compressor.w) * cohort_size / n_workers``.

        Args:
            local_steps, mu, n_workers, local_solver_lr, dual_lr: as in
                ``local_solve``.
            cohort_size: numerator of the ``cohort_size / n_workers`` factor.
            compressor: object providing ``compress(vector)`` and attribute ``w``.

        Returns:
            The compressed difference that was applied.
        """
        u_bar = self._solve_local_subproblem(
            local_steps, mu, n_workers, local_solver_lr, dual_lr)
        compressed_difference = compressor.compress(u_bar - self.five_gcs_u)
        self.five_gcs_u += 1/(1 + compressor.w) * \
            (cohort_size / n_workers) * compressed_difference
        return compressed_difference



class TwoDirection_Worker:
    """Worker holding the points ``z``, ``w``, ``y`` and a gradient shift ``h``.

    Messages sent by the worker are gradients minus the shift ``_h``, passed
    through ``compressor.compress`` (except ``get_uncompressed_message``).
    ``loss`` must provide ``gradient(point)``; ``compressor`` must provide
    ``compress(vector)`` and an attribute ``w``.

    ``_k`` is not set by the constructor: call ``_init_k`` before the first
    ``_calculate_message_at_point_z``.
    """

    def __init__(self, loss, compressor, point, bar_lipschitz, strongly_convex_constant, sum_gamma):
        """Initialise all points at ``point`` and the shift at its gradient.

        Args:
            loss: local objective.
            compressor: message compressor; ``1 / (compressor.w + 1)`` is
                used as the shift step size ``_beta``.
            point: starting point shared by ``_z``, ``_w`` and ``_y``.
            bar_lipschitz: constant used in ``_gradient_step``.
            strongly_convex_constant: constant used in ``_gradient_step``.
            sum_gamma: initial value of the running sum passed to
                ``_gradient_step``.
        """
        self.loss = loss
        self._compressor = compressor
        self._bar_lipschitz = bar_lipschitz
        self._strongly_convex_constant = strongly_convex_constant
        self._sum_gamma = sum_gamma
        # The three names share one object; the methods below only rebind
        # them, they never modify it in place.
        self._z, self._w, self._y = point, point, point
        self._h = self.calculate_gradient(point)
        self._beta = 1 / (self._compressor.w + 1)

    def _init_k(self, gradient):
        """Set the dual vector ``_k`` used by the next gradient step."""
        self._k = gradient

    def get_uncompressed_message(self, y):
        """Return the shifted gradient ``grad(y) - h`` without compression.

        The previous ``g - - self._h`` evaluated to ``g + h``; every other
        message in this class is ``gradient - h``.
        """
        g = self.loss.gradient(y)
        return g - self._h

    def _calculate_message_at_point_y(self, theta):
        """Set ``y = theta * w + (1 - theta) * z`` and return its message.

        Returns:
            ``compress(grad(y) - h)``; the shift ``_h`` is not modified.
        """
        self._y = theta * self._w + (1 - theta) * self._z
        gradient = self.calculate_gradient(self._y)
        message = self._compressor.compress(gradient - self._h)
        return message

    def calculate_gradient(self, point):
        """Return ``loss.gradient(point)``."""
        return self.loss.gradient(point)

    def _calculate_message_at_point_z(self,
                                      primal_compressed_message,
                                      primal_noncompressed_message,
                                      dual_noncompressed_message,
                                      gamma, next_sum_gamma):
        """Advance ``w`` (and optionally ``z`` / ``k``), then message at ``z``.

        The gradient step is computed first, from the values of ``_k``,
        ``_w`` and ``_y`` held before this call.  If uncompressed vectors are
        supplied, ``_z`` and ``_k`` are then replaced by them.

        Args:
            primal_compressed_message: vector added to the gradient-step
                result to form the new ``_w``.
            primal_noncompressed_message: new ``_z`` or ``None`` to keep it.
            dual_noncompressed_message: new ``_k``; must not be ``None`` when
                ``primal_noncompressed_message`` is given.
            gamma: step parameter forwarded to ``_gradient_step``.
            next_sum_gamma: new value of the running sum ``_sum_gamma``.

        Returns:
            ``compress(grad(z) - h)`` computed with the shift before it is
            moved by ``_beta * message``.
        """
        q = self._gradient_step(
            self._k, self._w, gamma, next_sum_gamma,
            self._sum_gamma, self._bar_lipschitz, self._strongly_convex_constant, self._y)
        if primal_noncompressed_message is not None:
            assert dual_noncompressed_message is not None
            self._z = primal_noncompressed_message
            self._k = dual_noncompressed_message
        # Messages are used directly as vectors (no ``.decompress()`` step).
        self._w = q + primal_compressed_message
        gradient = self.calculate_gradient(self._z)
        message = self._compressor.compress(gradient - self._h)
        # Rebinds ``_h`` instead of updating in place.
        self._h = self._h + self._beta * message
        self._sum_gamma = next_sum_gamma
        return message

    def _gradient_step(self, dual_vector, previous_point, next_gamma, next_sum_gamma,
                       sum_gamma, bar_lipschitz, strongly_convex_constant, y):
        """Return the closed-form step combining ``previous_point``, ``y`` and ``dual_vector``.

        Computes::

            ((L + S * mu) * previous_point + mu * g * y - g * dual_vector)
                / (L + S_next * mu)

        with ``L = bar_lipschitz``, ``mu = strongly_convex_constant``,
        ``g = next_gamma``, ``S = sum_gamma`` and ``S_next = next_sum_gamma``.
        """
        point = ((bar_lipschitz + sum_gamma * strongly_convex_constant) * previous_point
                 + strongly_convex_constant * next_gamma * y
                 - next_gamma * dual_vector)
        point = point / (bar_lipschitz + next_sum_gamma * strongly_convex_constant)
        return point
    
    
class EF21P_DIANA_Worker:
    """Worker that emits compressed shifted gradients and tracks the shift.

    ``loss`` must provide ``gradient(point)`` and ``compressor`` must provide
    ``compress(vector)``.
    """

    def __init__(self, loss, compressor, point, beta):
        """Keep the collaborators and start the shift ``_h`` at zero.

        Args:
            loss: local objective.
            compressor: message compressor.
            point: template for the shift; ``_h`` is float32 zeros shaped
                like it.
            beta: step size of the shift update.
        """
        self._beta = beta
        self._compressor = compressor
        self.loss = loss
        self._h = np.zeros_like(point, dtype=np.float32)

    def _calculate_message(self, w):
        """Return ``compress(grad(w) - h)`` and move ``h`` by ``beta`` times it.

        The message is computed with the shift as it was before the update;
        the shift is updated in place.
        """
        shifted_gradient = self.loss.gradient(w) - self._h
        compressed = self._compressor.compress(shifted_gradient)
        shift_step = self._beta * compressed
        self._h += shift_step
        return compressed
