import numpy as np


class Method:
    """String identifiers accepted by ``Worker(method=...)``.

    Each attribute is a plain ``str``, so callers may pass either the
    constant or the literal string it holds.
    """

    Scaffnew = 'scaffnew'
    CompressedScaffnew = 'compressed-scaffnew'
    Gd = 'gd'
    Scaffold = 'scaffold'


class Worker:
    """A single client that runs the local part of one optimisation method.

    The constructor sets exactly one of the boolean flags ``scaffnew``,
    ``compressed_scaffnew``, ``gd`` and ``scaffold``; ``run_local`` dispatches
    on that flag.

    Args:
        method: one of the ``Method`` string constants.
        loss: object exposing ``gradient(x)``; for ``gd``/``scaffold`` with a
            non-``None`` ``batch_size`` it must also expose
            ``stochastic_gradient(x, batch_size=...)``.
        batch_size: mini-batch size used by ``gd`` and ``scaffold``; ``None``
            means the full gradient is used. The Scaffnew variants always call
            ``loss.gradient``.

    Raises:
        ValueError: if ``method`` is not one of the known ``Method`` values.
    """

    def __init__(self, method=None, loss=None, batch_size=1):
        self.loss = loss
        self.scaffnew = False
        self.compressed_scaffnew = False
        self.gd = False
        self.scaffold = False

        if method == Method.Scaffnew:
            self.scaffnew = True
        elif method == Method.CompressedScaffnew:
            self.compressed_scaffnew = True
        elif method == Method.Gd:
            self.gd = True
        elif method == Method.Scaffold:
            self.scaffold = True
        else:
            raise ValueError(f'Unknown method {method}!')

        self.batch_size = batch_size
        self.x = None  # local iterate; (re)set from the server point in run_local
        self.c = None  # SCAFFOLD local control variate, lazily zero-initialised
        self.h = None  # Scaffnew control variate, lazily zero-initialised
        self.old_q = None  # compression vector from the previous compressed round
        self.x_before_compressing = None
        self.x_before_averaging = None

    def run_local(self, x, lr, p=None, local_steps=None, q=None, eta=None, scaffold_c=None):
        """Run one round of local computation starting from the server point ``x``.

        ``x`` is copied first, so the caller's array is never modified.

        Args:
            x: current server iterate.
            lr: local step size.
            p: the ``p`` in the ``p / lr`` control-variate step of the Scaffnew
               variants (presumably the communication probability — confirm).
            local_steps: number of local gradient steps.
            q: compression vector, applied elementwise (compressed Scaffnew only).
            eta: extra scaling of the control-variate step (compressed Scaffnew only).
            scaffold_c: server control variate (SCAFFOLD only); ``None`` is
                treated as the zero vector.

        Returns:
            For compressed Scaffnew, the compressed message ``x_local * q``.
            Otherwise, the worker's local iterate (the internal ``self.x`` array).
        """
        self.x = np.copy(x)
        if self.scaffnew:
            self.run_scaffnew(lr=lr, local_steps=local_steps, p=p)
        elif self.compressed_scaffnew:
            return self.run_compressed_prox_skip(lr=lr, local_steps=local_steps, p=p, q=q, eta=eta)
        elif self.gd:
            self.run_gd(lr, local_steps)
        elif self.scaffold:
            self.run_scaffold(lr, local_steps, scaffold_c)
        return self.x

    def run_compressed_prox_skip(self, lr, local_steps, p, q, eta):
        """Compressed-Scaffnew local round; returns ``x_local * q``.

        On every call after the first, ``h`` is first corrected by
        ``(p / lr) * eta * old_q * (x_server - x_before_compressing)``. Here
        ``old_q`` and ``x_before_compressing`` are the values saved in the
        previous round.
        """
        if self.h is None:
            self.h = np.zeros_like(self.x)
        else:
            self.h += (p / lr) * eta * (np.multiply(self.x, self.old_q) - np.multiply(self.x_before_compressing, self.old_q))

        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.h)
        # Snapshot (not alias) the local point; the next round compares against it.
        self.x_before_compressing = np.copy(self.x)
        self.old_q = np.copy(q)
        return np.multiply(self.x_before_compressing, q)

    def run_scaffnew(self, lr, local_steps, p):
        """Scaffnew local round: correct ``h``, then run gradient steps shifted by ``h``.

        On every call after the first, ``h`` grows by
        ``(p / lr) * (x_server - x_before_averaging)``, where
        ``x_before_averaging`` is the local point saved in the previous round.
        """
        if self.h is None:
            self.h = self.x * 0.
        else:
            self.h += p / lr * (self.x - self.x_before_averaging)

        for _ in range(local_steps):
            g = self.loss.gradient(self.x)
            self.x -= lr * (g - self.h)
        self.x_before_averaging = self.x * 1.  # copy, since self.x is mutated in place

    def run_scaffold(self, lr, local_steps, c):
        """SCAFFOLD local round with the "option II" control-variate update.

        Runs ``local_steps`` steps ``y -= lr * (g(y) - c_i + c)``. It then sets
        ``c_i <- c_i - c + (x_start - y) / (local_steps * lr)``.

        Args:
            lr: local step size.
            local_steps: number of local steps (must be non-zero for the update).
            c: server control variate; ``None`` is treated as zeros.

        Returns:
            The updated local iterate ``self.x``.
        """
        if c is None:
            c = np.zeros_like(self.x)
        # Must be a copy: self.x is updated in place below, and an alias would
        # make (x_start - self.x) identically zero.
        x_start = np.copy(self.x)
        if self.c is None:
            self.c = np.zeros_like(self.x)
        for _ in range(local_steps):
            if self.batch_size is None:
                g = self.loss.gradient(self.x)
            else:
                g = self.loss.stochastic_gradient(self.x, batch_size=self.batch_size)
            self.x -= lr * (g - self.c + c)
        self.c += 1 / (local_steps * lr) * (x_start - self.x) - c
        return self.x

    def run_gd(self, lr, local_steps):
        """Run ``local_steps`` plain (stochastic if ``batch_size`` is set) gradient steps."""
        for _ in range(local_steps):
            if self.batch_size is None:
                self.x -= lr * self.loss.gradient(self.x)
            else:
                self.x -= lr * self.loss.stochastic_gradient(self.x, batch_size=self.batch_size)
