import copy

import numpy as np
from tqdm import tqdm
from worker import Worker, TwoDirection_Worker, EF21P_DIANA_Worker
from optmethods.optimizer import StochasticOptimizer


def get_participation_mask(n_workers, cohort_size, rng=None):
    """Return a random boolean mask selecting ``cohort_size`` of ``n_workers``.

    Arguments:
        n_workers (int): length of the returned mask
        cohort_size (int): number of ``True`` entries (capped at ``n_workers``
            by the slice assignment)
        rng (numpy.random.Generator, optional): generator used for shuffling.
            When omitted, the global ``np.random`` state is used, as before.

    Returns:
        1-D boolean array of length ``n_workers``.
    """
    mask = np.full(n_workers, False)
    mask[:cohort_size] = True
    # Previously `rng` was accepted but silently ignored, so seeding the
    # generator passed by the caller had no effect on the sampled cohort.
    if rng is None:
        np.random.shuffle(mask)
    else:
        rng.shuffle(mask)
    return mask


def create_random_sampling_pattern(dimension: int, num_workers: int, ones_per_row: int):
    """Build the deterministic 0/1 template of shape ``(dimension, num_workers)``.

    When ``dimension >= num_workers / ones_per_row`` the ones are laid out row
    by row, ``ones_per_row`` per row, with the column index advancing
    cyclically. Otherwise the ones are laid out in ``ones_per_row`` passes over
    the rows, each one taking the next unused column.
    """
    pattern = np.zeros(shape=(dimension, num_workers))
    wraps_around = dimension >= num_workers / ones_per_row
    col = 0

    if wraps_around:
        for row in range(dimension):
            for _ in range(ones_per_row):
                pattern[row, col] = 1
                col = (col + 1) % num_workers
        return pattern

    # Fewer than `num_workers` ones in total: the column index never wraps.
    for _ in range(ones_per_row):
        for row in range(dimension):
            pattern[row, col] = 1
            col += 1
    return pattern


def get_communication_mask(dimension: int, num_workers: int, ones_per_row: int, rng=None):
    """Return a ``(dimension, num_workers)`` 0/1 mask with shuffled columns.

    If every worker is selected for every coordinate
    (``ones_per_row == num_workers``) an all-ones matrix is returned without
    drawing random numbers. Otherwise the columns of the template from
    ``create_random_sampling_pattern`` are permuted, using ``rng`` when given
    and the global ``np.random`` state otherwise.
    """
    if ones_per_row == num_workers:
        return np.ones(shape=(dimension, num_workers))

    template = create_random_sampling_pattern(dimension, num_workers, ones_per_row)
    permute = np.random.permutation if rng is None else rng.permutation
    # `permutation` shuffles along the first axis, so transpose to shuffle
    # columns and transpose back.
    return permute(template.T).T


class GD(StochasticOptimizer):
    """Distributed gradient descent that averages the workers' local results.

    Every ``step`` copies the server iterate to each worker, calls
    ``Worker.run_gd(lr, local_steps=it_local)`` on all of them and replaces the
    server iterate by the mean of the returned vectors.

    Arguments:
        it_local: number of local steps forwarded to ``Worker.run_gd``
        n_workers: stored on the instance; not read by this class
        lr: step size forwarded to the workers
        dim: length of the zero vector used as initial iterate
        worker_losses: one ``Worker`` is created per entry
        it_max: total of the outer progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the outer progress bar,
            ``'all'`` additionally enables the per-worker bar
    """

    def __init__(self, it_local, n_workers=None, lr=None, dim=None, worker_losses=None, it_max=None,
                 threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        self.lr = lr
        self.d = dim
        self.worker_losses = worker_losses
        self.threshold = threshold

        self.x = np.zeros(shape=self.d)
        self.workers = [Worker(loss=local_loss) for local_loss in self.worker_losses]

        self.name = 'GD'
        banner = f'\n\n{self.name}:  lr:{round(self.lr, 5)}, local_steps: {self.it_local}'
        print(banner)

        self.pbars = pbars
        show_outer_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_outer_bar)

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            # presumably `tolerance` is read by the base class's stopping
            # check — TODO confirm
            self.tolerance = self.threshold
            return

        # Broadcast: every worker starts from its own copy of the server point.
        for worker in self.workers:
            worker.x = np.copy(self.x)

        worker_iter = tqdm(self.workers, leave=False, colour='green',
                           disable=self.pbars != 'all')
        local_results = [worker.run_gd(lr=self.lr, local_steps=self.it_local)
                         for worker in worker_iter]
        self.x = np.mean(local_results, axis=0)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super().update_trace()


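# NOTE: the block below is a commented-out (inactive) ADIANA variant; the live
# implementation is the `ADIANA` class defined right after it.
# class ADIANA(StochasticOptimizer):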
#     def __init__(self, alpha, beta, compressor, p, eta, theta_1, theta_2, n_workers=None, lr=None, dim=None, worker_losses=None, it_max=None, pbars=None,
#                  *args, **kwargs):
#         super(ADIANA, self).__init__(*args, **kwargs)
#         self.n_workers = n_workers
#         self.compressor = compressor
#         self.lr = lr
#         self.p = p
#         self.eta = eta
#         self.theta_1 = theta_1
#         self.theta_2 = theta_2
#         self.alpha = alpha
#         self.beta = beta
#         self.d = dim
#         self.worker_losses = worker_losses

#         self.h = np.zeros(shape=dim) * 1.
#         self.z = np.zeros(shape=dim) * 1.
#         self.y = np.zeros(shape=dim) * 1.
#         self.w = np.zeros(shape=dim) * 1.
#         self.x = np.zeros(shape=dim) * 1.

#         self.workers = [Worker(loss=worker_loss)
#                         for worker_loss in self.worker_losses]

#         self.name = 'ADIANA'

#         print(
#             f'\n\n{self.name}:  compressor: {compressor.name}, lr: {round(self.lr, 5)}'
#             f'  eta: {self.eta}, theta_1: {round(self.theta_1, 5)}, theta_2: {round(self.theta_2, 5)}'
#             f'  alpha: {self.alpha}, beta: {round(self.beta, 5)}, p: {round(self.p, 5)}'
#         )

#         self.pbars = pbars
#         self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
#                           disable=self.pbars not in ['all', 'steps'])

#     def step(self):
#         loss_gap = self.loss.value(self.x) - self.loss.f_opt
#         if loss_gap < self.threshold:
#             self.tolerance = self.threshold
#             return
        
#         self.z = self.theta_1*self.x + self.theta_2 * \
#             self.w + (1-self.theta_1-self.theta_2)*self.y

#         sum_1 = 0
#         sum_2 = 0
#         for worker in self.workers:
#             m_1, m_2 = worker.run_adiana(
#                 z=self.z, w=self.w, alpha=self.alpha, compressor=self.compressor)
#             sum_1 += m_1
#             sum_2 += m_2

#         g = sum_1/self.n_workers + self.h
#         self.h += self.alpha * sum_2/self.n_workers

#         y_prev = np.copy(self.y)
#         self.y = self.z - self.eta * g
#         self.x = self.beta * self.x + \
#             (1-self.beta)*self.z + (self.lr/self.eta)*(self.y - self.z)
#         random_number = np.random.rand()
#         self.w = y_prev if random_number < self.p else self.w

#         self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
#         self.p_bar.update()

#     def init_run(self, *args, **kwargs):
#         super(ADIANA, self).init_run(*args, **kwargs)

#     def update_trace(self):
#         super(ADIANA, self).update_trace()


class ADIANA(StochasticOptimizer):
    """ADIANA server loop driven by ``Worker.run_adiana`` messages.

    The server keeps five float32 vectors of length ``dim`` (``h``, ``z``,
    ``y``, ``w``, ``x``), all starting at zero. Each ``step`` forms ``x`` as a
    combination of ``z``, ``w`` and ``y``, collects two messages per worker,
    and updates ``h``, ``y``, ``z`` and (with probability ``p``) ``w``.

    Arguments:
        alpha: scale of the shift update; also forwarded to the workers
        beta: mixing weight between ``z`` and ``x`` in the ``z`` update
        compressor: forwarded to ``Worker.run_adiana``; its ``name`` is printed
        p: probability of replacing ``w`` by the previous ``y``
        eta: step size of the ``y`` update
        theta_1, theta_2: weights of ``z`` and ``w`` when forming ``x``
        n_workers: divisor used to average the worker messages
        lr: enters the ``z`` update through the factor ``lr / eta``
        dim: length of the state vectors
        worker_losses: one ``Worker`` is created per entry
        it_max: total of the progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the progress bar
    """

    def __init__(self, alpha, beta, compressor, p, eta, theta_1, theta_2, n_workers=None, lr=None,
                 dim=None, worker_losses=None, it_max=None, threshold=1e-6, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_workers = n_workers
        self.compressor = compressor
        self.lr = lr
        self.p = p
        self.eta = eta
        self.theta_1 = theta_1
        self.theta_2 = theta_2
        self.alpha = alpha
        self.beta = beta
        self.d = dim
        self.worker_losses = worker_losses
        self.threshold = threshold

        # Five independent float32 zero vectors, assigned in the order h, z, y, w, x.
        for state_name in ('h', 'z', 'y', 'w', 'x'):
            setattr(self, state_name, np.zeros(shape=dim, dtype=np.float32))

        self.workers = [Worker(loss=local_loss) for local_loss in self.worker_losses]

        self.name = 'ADIANA'

        banner = (
            f'\n\n{self.name}:  compressor: {compressor.name}, lr: {round(self.lr, 5)}'
            f'  eta: {self.eta}, theta_1: {round(self.theta_1, 5)}, theta_2: {round(self.theta_2, 5)}'
            f'  alpha: {self.alpha}, beta: {round(self.beta, 5)}, p: {round(self.p, 5)}'
        )
        print(banner)

        self.pbars = pbars
        show_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_bar)

    def step(self):
        """Run one ADIANA round unless the loss gap is tiny or above 100."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        converged = loss_gap < self.threshold
        # NOTE(review): a gap above 100 is handled exactly like convergence;
        # looks like a divergence guard — confirm that is intended.
        blown_up = loss_gap > 100
        if converged or blown_up:
            self.tolerance = self.threshold
            return

        theta_3 = 1 - self.theta_1 - self.theta_2
        self.x = self.theta_1 * self.z + self.theta_2 * self.w + theta_3 * self.y

        first_total = 0
        second_total = 0
        for worker in self.workers:
            first_msg, second_msg = worker.run_adiana(
                z=self.x, w=self.w, alpha=self.alpha, compressor=self.compressor)
            first_total += first_msg
            second_total += second_msg

        # The estimator uses the shift *before* it is updated below.
        g = self.h + first_total / self.n_workers
        # In-place update keeps `h` a float32 array.
        self.h += self.alpha * second_total / self.n_workers

        previous_y = np.copy(self.y)
        self.y = self.x - self.eta * g
        self.z = self.beta * self.z + \
            (1 - self.beta) * self.x + self.lr / self.eta * (self.y - self.x)

        # One uniform draw from the global NumPy RNG per step.
        if np.random.rand() < self.p:
            self.w = previous_y

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super().update_trace()


class DIANA(StochasticOptimizer):
    """DIANA server loop driven by ``Worker.run_diana`` messages.

    The server keeps a float32 shift ``h`` (initially zero). Each ``step``
    averages the worker messages, takes a step along ``h + mean`` and then
    moves ``h`` by ``dual_lr * mean``.

    Arguments:
        dual_lr: step size of the shift update; also forwarded to the workers
        compressor: forwarded to ``Worker.run_diana``; its ``name`` is printed
        n_workers: stored on the instance; not read by this class
        lr: step size of the iterate update
        dim: length of the shift vector
        worker_losses: one ``Worker`` is created per entry
        it_max: total of the progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the progress bar
    """

    def __init__(self, dual_lr, compressor, n_workers=None, lr=None, dim=None, worker_losses=None,
                 it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_workers = n_workers
        self.dual_lr = dual_lr
        self.compressor = compressor
        self.lr = lr
        self.d = dim
        self.worker_losses = worker_losses
        self.threshold = threshold

        # `self.x` is not created here; presumably the base class's init_run
        # provides it — TODO confirm.
        self.h = np.zeros(shape=dim, dtype=np.float32)

        self.workers = [Worker(loss=local_loss) for local_loss in self.worker_losses]

        self.name = 'DIANA'

        banner = f'\n\n{self.name}:  compressor: {compressor.name}, lr: {round(self.lr, 5)}'
        print(banner)

        self.pbars = pbars
        show_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_bar)

    def step(self):
        """Run one DIANA round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        messages = [
            worker.run_diana(x=self.x, dual_lr=self.dual_lr, compressor=self.compressor)
            for worker in self.workers
        ]
        mean_message = np.mean(messages, axis=0)

        direction = self.h + mean_message
        # Both updates are in place, so the arrays keep their identity and dtype.
        self.x -= self.lr * direction
        self.h += self.dual_lr * mean_message

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super().update_trace()


class Scaffnew(StochasticOptimizer):
    """Scaffnew server loop with a geometric number of local steps per round.

    Each ``step`` draws ``local_steps ~ Geometric(p)``, lets every worker run
    ``Worker.run_scaffnew`` from the server point, averages the results and
    then calls ``Worker.update_scaffnew_variates`` on every worker.

    Arguments:
        p: parameter of the geometric distribution; also forwarded to the
            variate update
        n_workers: stored on the instance; not read by this class
        lr: step size forwarded to the workers
        worker_losses: one ``Worker`` is created per entry
        it_max: total of the progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the progress bar
        d: length of the zero initial iterate and of each worker's ``h``
    """

    def __init__(self, p=None, n_workers=None, lr=None, worker_losses=None, it_max=None,
                 threshold=1e-6, pbars=None, d=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_workers = n_workers
        self.lr = lr
        self.d = d
        self.worker_losses = worker_losses
        self.p = p
        self.threshold = threshold

        self.x = np.zeros(shape=self.d)
        self.workers = [Worker(loss=local_loss) for local_loss in self.worker_losses]

        # Every worker starts with its own zero vector `h`.
        for worker in self.workers:
            worker.h = np.zeros_like(self.x)

        self.name = 'Scaffnew'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}')
        self.pbars = pbars
        show_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_bar)

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        # A fresh, unseeded generator is created on every call.
        generator = np.random.default_rng()
        n_local = generator.geometric(p=self.p)

        local_results = []
        for worker in self.workers:
            worker.x = np.copy(self.x)
            result = worker.run_scaffnew(lr=self.lr, local_steps=n_local)
            local_results.append(result)

        self.x = np.mean(local_results, axis=0)

        # The variates are updated against the freshly averaged server point.
        for worker in self.workers:
            worker.update_scaffnew_variates(p=self.p, lr=self.lr, server_x=self.x)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super().update_trace()


class FiveGCS(StochasticOptimizer):
    """5GCS server loop with client sampling and local solves.

    Each ``step`` samples ``cohort_size`` workers, forms the intermediate
    point ``x_hat``, lets the sampled workers run ``Worker.local_solve`` from
    it, rebuilds ``v`` as the sum of all workers' ``five_gcs_u`` vectors and
    corrects ``x_hat`` by the scaled change of ``v``.

    Arguments:
        it_local: number of local steps forwarded to ``Worker.local_solve``
        n_workers: number of workers; used for sampling and scaling
        cohort_size: number of workers sampled per round
        mu: constant used in the ``x_hat`` scaling and the local step size;
            also forwarded to the workers
        lr: server step size
        dual_lr: forwarded to the workers; also enters the local step size
        d: length of the zero vectors ``x``, ``x_hat``, ``v`` and each
            worker's ``five_gcs_u``
        worker_losses: one ``Worker`` is created per entry
        it_max: total of the progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the progress bar
    """

    def __init__(self, it_local=None, n_workers=None, cohort_size=None, mu=None, lr=None, dual_lr=None, d=None,
                 worker_losses=None, it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super(FiveGCS, self).__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.dual_lr = dual_lr
        self.d = d
        self.worker_losses = worker_losses
        self.mu = mu

        self.threshold = threshold

        self.x = np.zeros(shape=self.d)
        self.x_hat = np.zeros_like(self.x)

        self.workers = []
        self.v = np.zeros_like(self.x)

        for worker_loss in self.worker_losses:
            worker = Worker(loss=worker_loss)
            worker.five_gcs_u = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = '5GCS'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, dual_lr: {round(self.dual_lr, 5)}, '
              f'local_steps: {self.it_local}')

        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)
        self.x_hat = (1 / (1 + self.lr * self.mu)) * \
            (self.x - self.lr * self.v)

        # The local-solver step size does not depend on the worker, so it is
        # computed once per round instead of once per participating worker.
        L_f = ((self.loss.max_smoothness - self.mu) /
               self.n_workers) + self.dual_lr
        local_solver_lr = 1 / L_f

        for i, worker in enumerate(self.workers):
            if participation_mask[i]:
                worker.x = np.copy(self.x_hat)
                worker.local_solve(local_steps=self.it_local, mu=self.mu, n_workers=self.n_workers,
                                   local_solver_lr=local_solver_lr, dual_lr=self.dual_lr)

        # `self.v` is rebound to a fresh array below, never mutated in place,
        # so the old array can be kept by reference without copying it.
        prev_v = self.v
        self.v = np.zeros_like(self.x)
        # Non-participating workers contribute their unchanged `five_gcs_u`.
        for worker in self.workers:
            self.v += worker.five_gcs_u

        self.x = self.x_hat - self.lr * \
            (self.n_workers / self.cohort_size) * (self.v - prev_v)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super(FiveGCS, self).init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super(FiveGCS, self).update_trace()


class FiveGCS_CC(StochasticOptimizer):
    """5GCS-CC server loop: client sampling with compressor-scaled updates.

    Each ``step`` samples ``cohort_size`` workers, forms ``x_hat``, sums the
    vectors returned by ``Worker.local_solve_cc`` for the sampled workers,
    moves ``v`` by a scaled version of that sum and corrects ``x_hat`` by the
    scaled change of ``v``.

    Arguments:
        compressor: forwarded to the workers; its ``w`` attribute scales the
            server updates and its ``name`` is printed
        it_local: number of local steps forwarded to the workers
        n_workers: number of workers; used for sampling and scaling
        cohort_size: number of workers sampled per round
        mu: constant used in the ``x_hat`` scaling and in ``L_f``
        lr: server step size
        dual_lr: forwarded to the workers; also enters ``L_f``
        d: length of the zero vectors ``x``, ``x_hat``, ``v`` and each
            worker's ``five_gcs_u``
        worker_losses: one ``Worker`` is created per entry; each entry must
            expose ``smoothness``
        it_max: total of the progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the progress bar
    """

    def __init__(self, compressor, it_local=None, n_workers=None, cohort_size=None, mu=None, lr=None, dual_lr=None, d=None,
                 worker_losses=None, it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super(FiveGCS_CC, self).__init__(*args, **kwargs)
        self.it_local = it_local
        self.compressor = compressor
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.dual_lr = dual_lr
        self.d = d
        self.worker_losses = worker_losses
        self.mu = mu
        self.threshold = threshold

        # The reciprocal of L_f is the step size handed to the local solver.
        L_max = max(worker_loss.smoothness for worker_loss in worker_losses)
        self.L_f = ((L_max - mu) / n_workers) + dual_lr

        self.x = np.zeros(shape=self.d)
        self.x_hat = np.zeros_like(self.x)

        self.workers = []
        self.v = np.zeros_like(self.x)  # todo check

        for worker_loss in self.worker_losses:
            worker = Worker(loss=worker_loss)
            worker.five_gcs_u = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = '5GCS-CC'

        print(f'\n\n{self.name}: compressor: {compressor.name}, lr:{round(self.lr, 5)}, dual_lr: {round(self.dual_lr, 5)}, '
              f'local_steps: {self.it_local}')

        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)
        self.x_hat = (1 / (1 + self.lr * self.mu)) * \
            (self.x - self.lr * self.v)

        # Named `message_sum` rather than `sum` so the builtin is not shadowed.
        message_sum = 0
        for i, worker in enumerate(self.workers):
            if participation_mask[i]:
                worker.x = np.copy(self.x_hat)
                message_sum += worker.local_solve_cc(local_steps=self.it_local, mu=self.mu, n_workers=self.n_workers, cohort_size=self.cohort_size,
                                                     local_solver_lr=1 / self.L_f, dual_lr=self.dual_lr, compressor=self.compressor)

        # `v` is updated in place below, so a real copy is required here.
        prev_v = np.copy(self.v)
        self.v += 1 / (1 + self.compressor.w) * \
            (self.cohort_size / self.n_workers) * message_sum

        self.x = self.x_hat - self.lr * \
            (self.n_workers / self.cohort_size) * \
            (1 + self.compressor.w) * (self.v - prev_v)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super(FiveGCS_CC, self).init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super(FiveGCS_CC, self).update_trace()


class Scaffold(StochasticOptimizer):
    """Scaffold server loop with client sampling and control variates.

    Each ``step`` samples ``cohort_size`` workers, lets them run
    ``Worker.run_scaffold`` from the server point, then moves the server
    iterate towards the mean of the returned vectors and the server control
    variate towards the mean of the sampled workers' ``scaffold_local_c``.

    Arguments:
        it_local: number of local steps forwarded to the workers
        n_workers: number of workers; used for sampling and scaling
        cohort_size: number of workers sampled per round
        lr: step size forwarded to the workers
        worker_losses: one ``Worker`` is created per entry
        d: length of the zero initial iterate and control variates
        global_lr: weight of the server iterate update
        it_max: total of the progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the progress bar
    """

    def __init__(self, it_local, n_workers=None, cohort_size=None, lr=None, worker_losses=None, d=None,
                 global_lr=1, it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.worker_losses = worker_losses
        self.global_lr = global_lr
        self.threshold = threshold

        self.x = np.zeros(d)
        self.server_c = np.zeros_like(self.x)

        self.workers = []
        for local_loss in self.worker_losses:
            new_worker = Worker(loss=local_loss)
            new_worker.scaffold_local_c = np.zeros_like(self.x)
            self.workers.append(new_worker)

        self.name = 'Scaffold'

        participation = round(self.cohort_size / self.n_workers, 5)
        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, local_steps: {self.it_local}, '
              f'participation: {participation}')
        self.pbars = pbars
        show_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_bar)

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)
        cohort = [worker for idx, worker in enumerate(self.workers)
                  if participation_mask[idx]]

        # Broadcast to the whole cohort before any of its members runs.
        for worker in cohort:
            worker.x = np.copy(self.x)

        local_points = []
        local_variates = []
        for worker in cohort:
            local_points.append(worker.run_scaffold(
                lr=self.lr, local_steps=self.it_local, server_c=np.copy(self.server_c)))
            # Read right after the worker's own run.
            local_variates.append(worker.scaffold_local_c)

        assert len(local_points) == self.cohort_size

        mean_point = np.sum(local_points, axis=0) / self.cohort_size
        mean_variate = np.sum(local_variates, axis=0) / self.cohort_size

        # In-place server updates.
        self.x += self.global_lr * (mean_point - self.x)
        self.server_c += (self.cohort_size / self.n_workers) * \
            (mean_variate - self.server_c)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)


class Tamuna(StochasticOptimizer):
    """TAMUNA server loop: client sampling plus masked (sparsified) uploads.

    Each ``step`` samples ``cohort_size`` workers, draws a geometric number of
    local steps and a 0/1 communication mask with one column per sampled
    worker. Every sampled worker runs ``Worker.run_tamuna_local`` and its
    result is multiplied elementwise by its mask column. The server iterate is
    the sum of the masked vectors divided by ``s``; afterwards each sampled
    worker's ``update_tamuna_variates`` is called.

    Arguments:
        p: parameter of the geometric distribution of local steps
        s: number of ones per mask row; also the divisor of the aggregate
        eta: forwarded to ``Worker.update_tamuna_variates``
        n_workers: number of workers; used for sampling
        cohort_size: number of workers sampled per round
        lr: step size forwarded to the workers
        batch_size: stored on the instance; not read by this class
        worker_losses: one ``Worker`` is created per entry
        d: length of the zero initial iterate and of each worker's ``h``
        it_max: total of the outer progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the outer progress bar,
            ``'all'`` additionally enables the per-worker bar
    """

    def __init__(self, p, s, eta, n_workers=None, cohort_size=None, lr=None, batch_size=None,
                 worker_losses=None, d=None, it_max=None, threshold=1e-6, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p
        self.s = s
        self.eta = eta
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.d = d
        self.threshold = threshold

        self.x = np.zeros(self.d)

        self.workers = []
        for local_loss in self.worker_losses:
            new_worker = Worker(loss=local_loss)
            new_worker.h = np.zeros_like(self.x)
            self.workers.append(new_worker)

        self.name = 'TAMUNA'

        participation = round(self.cohort_size / self.n_workers, 5)
        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}, s: {round(self.s, 5)}, '
              f'eta: {round(self.eta, 5)}, participation: {participation}')

        self.pbars = pbars
        show_outer_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_outer_bar)

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            n_workers=self.n_workers, cohort_size=self.cohort_size, rng=rng)
        local_steps = rng.geometric(p=self.p)
        # NOTE(review): the mask is drawn from `self.rng` while the lines above
        # use the local `rng`; `self.dim` and `self.rng` are not set in this
        # class — presumably provided by the base class, confirm.
        comm_mask = get_communication_mask(
            self.dim, self.cohort_size, self.s, rng=self.rng)

        cohort = []
        masked_results = []
        indexed_workers = list(enumerate(self.workers))
        for idx, worker in tqdm(indexed_workers, leave=False, colour='green',
                                disable=self.pbars != 'all'):
            if not participation_mask[idx]:
                continue
            # The k-th sampled worker uses the k-th mask column.
            column = comm_mask[:, len(masked_results)]
            worker.x = np.copy(self.x)
            raw_result = worker.run_tamuna_local(
                lr=self.lr, local_steps=local_steps)
            masked_results.append(np.multiply(raw_result, column))
            cohort.append(worker)

        self.x = np.sum(masked_results, axis=0) / self.s

        for slot, worker in enumerate(cohort):
            worker.update_tamuna_variates(
                lr=self.lr, eta=self.eta,
                server_compressed=np.multiply(self.x, comm_mask[:, slot]),
                worker_compressed=masked_results[slot])

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super().update_trace()


class CompressedScaffnew(StochasticOptimizer):
    """CompressedScaffnew server loop: full participation with masked uploads.

    Each ``step`` draws a geometric number of local steps and a 0/1
    communication mask with one column per worker. Every worker runs
    ``Worker.run_compressed_scaffnew_local`` with its mask column, the server
    iterate becomes the sum of the results divided by ``s``, and each worker's
    ``update_compressed_scaffnew_variates`` is then called.

    Arguments:
        p: parameter of the geometric distribution of local steps
        s: number of ones per mask row; also the divisor of the aggregate
        eta: forwarded to the variate update
        n_workers: number of mask columns
        lr: step size forwarded to the workers
        worker_losses: one ``Worker`` is created per entry
        d: length of the zero initial iterate and of each worker's ``h``
        it_max: total of the outer progress bar
        threshold: loss gap below which ``step`` stops updating
        pbars: ``'all'`` or ``'steps'`` enables the outer progress bar,
            ``'all'`` additionally enables the per-worker bar
    """

    def __init__(self, p, s, eta, n_workers=None, lr=None, worker_losses=None, d=None,
                 it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p
        self.s = s
        self.eta = eta
        self.n_workers = n_workers
        self.lr = lr
        self.worker_losses = worker_losses
        self.d = d
        self.threshold = threshold

        self.x = np.zeros(self.d)

        self.workers = []
        for local_loss in self.worker_losses:
            new_worker = Worker(loss=local_loss)
            new_worker.h = np.zeros_like(self.x)
            self.workers.append(new_worker)

        self.name = 'CompressedScaffnew'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}, s: {round(self.s, 5)}, '
              f'eta: {round(self.eta, 5)}')

        self.pbars = pbars
        show_outer_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_outer_bar)

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            # NOTE(review): debug print, absent from the sibling optimizers.
            print('Conv?')
            return

        rng = np.random.default_rng()
        local_steps = rng.geometric(p=self.p)
        # `self.dim` is not set in this class — presumably provided by the
        # base class, confirm.
        comm_mask = get_communication_mask(self.dim, self.n_workers, self.s, rng=rng)

        uploads = []
        indexed_workers = list(enumerate(self.workers))
        for idx, worker in tqdm(indexed_workers, leave=False, colour='green',
                                disable=self.pbars != 'all'):
            worker.x = np.copy(self.x)
            uploads.append(worker.run_compressed_scaffnew_local(
                lr=self.lr, local_steps=local_steps, q=comm_mask[:, idx]))

        self.x = np.sum(uploads, axis=0) / self.s

        for idx, worker in enumerate(self.workers):
            worker.update_compressed_scaffnew_variates(
                x=np.copy(self.x), lr=self.lr, eta=self.eta, q=comm_mask[:, idx],
                worker_result=uploads[idx])

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the base-class ``update_trace`` unchanged."""
        super().update_trace()


class GradSkip(StochasticOptimizer):
    """GradSkip server loop with optional client sampling.

    Each ``step`` draws a geometric number of local steps, samples
    ``cohort_size`` workers, lets each of them run
    ``Worker.run_gradskip_local`` from a copy of the server point, averages
    the results and then calls ``Worker.update_gradskip_params`` on the
    sampled workers.

    Arguments:
        p (float): parameter of the geometric distribution of local steps;
            also forwarded to the workers
        worker_losses (sequence): one ``Worker`` is created per entry
        label (str, optional): forwarded to the base-class constructor
        q (sequence, optional): per-worker value passed as ``Worker(q=...)``;
            defaults to 1 for every worker
        grad_time (sequence, optional): stored on the instance; defaults to a
            list of ones, one per worker
        com_time (optional): stored on the instance (default: 1)
        cohort_size (int, optional): number of workers sampled per round;
            defaults to all workers
        lr (float, optional): step size forwarded to the workers
        lr_decay_power, it_start_decay (optional): stored on the instance;
            not read by this class
        it_max (int, optional): total of the outer progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating
        batch_size (int, optional): forwarded to ``Worker.run_gradskip_local``
        pbars (str, optional): ``'all'`` or ``'steps'`` enables the outer
            progress bar, ``'all'`` additionally enables the per-worker bar
    """

    def __init__(self, p, worker_losses, label='GradSkip', q=None, grad_time=None, com_time=1, cohort_size=None,
                 lr=None,
                 lr_decay_power=1, it_start_decay=None, it_max=None, threshold=1e-6,
                 batch_size=None, pbars=None, *args, **kwargs):
        super(GradSkip, self).__init__(label=label, *args, **kwargs)
        self.p = p
        self.n_workers = len(worker_losses)
        if cohort_size is None:
            cohort_size = self.n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.lr_decay_power = lr_decay_power
        self.it_start_decay = it_start_decay
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.trace.hs = []
        self.grad_time = [1]*self.n_workers if grad_time is None else grad_time
        self.com_time = com_time
        self.threshold = threshold

        # Returned by get_xs(); nothing in this class ever fills it.
        self.xs = {}

        qs = [1] * self.n_workers if q is None else q
        self.workers = [Worker(loss=l, q=qs[i])
                        for i, l in enumerate(self.worker_losses)]

        self.name = 'GradSkip'
        print(f'\n\n{self.name}: , p: {round(self.p, 5)}, lr: {round(self.lr, 5)}, '
              f'participation: {round(self.cohort_size / self.n_workers, 5)}'
              )
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one communication round, or do nothing once the gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        local_steps = rng.geometric(p=self.p)
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)

        ret = []
        for j, worker in tqdm(list(enumerate(self.workers)), leave=False, colour='green', disable=self.pbars != 'all'):
            if participation_mask[j]:
                # Each worker gets its own copy of the server point.
                worker_result = worker.run_gradskip_local(
                    x=np.copy(self.x), p=self.p, local_steps=local_steps, lr=self.lr, batch_size=self.batch_size)
                ret.append(worker_result)

        self.x = np.mean(ret, axis=0)

        # Only the workers that took part in this round are updated.
        for j, worker in enumerate(self.workers):
            if participation_mask[j]:
                worker.update_gradskip_params(
                    x_av=self.x, p=self.p, lr=self.lr)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate to the base-class ``init_run`` unchanged."""
        super(GradSkip, self).init_run(*args, **kwargs)

    def get_xs(self):
        """Return the ``xs`` dictionary created in ``__init__``."""
        return self.xs

    def get_workers_time(self):
        '''python list of size (n_workers, it_max)'''
        # NOTE(review): `workers_time` is never assigned in this class; unless
        # it is set elsewhere this raises AttributeError — confirm.
        return self.workers_time

    def get_communication_time(self):
        '''python list of size it_max = #communications'''
        return np.amax(self.get_workers_time(), axis=0)

    def get_total_time(self):
        """Return the sum of the per-communication times."""
        return sum(self.get_communication_time())

    def get_workers_steps(self):
        '''python list of size (n_workers, it_max)'''
        # NOTE(review): `workers_steps` is never assigned in this class —
        # confirm it is set elsewhere.
        return self.workers_steps


class LoCoDL(StochasticOptimizer):
    """
    Local training with compressed communication between workers and server.

    In every round a random number of local steps is drawn from a geometric
    distribution, each participating worker runs ``run_loco_local`` and
    returns a pair ``(x_i, y_i)``, the difference ``x_i - y_i`` is compressed,
    and every participating worker is then updated with its own compressed
    difference and with half of the mean of all compressed differences.

    Arguments:
        gamma (float): step size forwarded to the workers' local phase and parameter update
        rho (float): forwarded to ``Worker.update_loco_params``
        chi (float): forwarded to ``Worker.update_loco_params``
        loss_y: loss object handed to every ``Worker`` as ``loss_y``
        compressor: object exposing ``name``, ``w`` and ``compress(vector)``
        d: stored as ``self.d``; not used by the code in this class
            (presumably the problem dimension -- confirm)
        p (float): success probability of the geometric distribution from which the
            number of local steps of a round is drawn
        loss_coeff: forwarded to every ``Worker``
        label (str, optional): forwarded to the base optimizer (default: 'LoCoDL')
        com_time, lr0, lr_max, lr_decay_coef, lr_decay_power, it_start_decay:
            stored on the instance; not used by the code in this class
        cohort_size (int, optional): number of workers that take part in a round
            (default: all workers)
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): once the loss gap is below this value ``step``
            no longer updates anything (default: 1e-6)
        batch_size (int, optional): forwarded to the workers' local phase
        worker_losses (list): one loss object per worker
        pbars (str, optional): 'all' or 'steps' enables the per-step progress bar;
            'all' additionally enables the per-worker bar inside ``step``
    """

    def __init__(self,
                 gamma, rho, chi,
                 loss_y,
                 compressor,
                 d,
                 p,
                 loss_coeff,
                 label='LoCoDL',
                 com_time=1, cohort_size=None,
                 lr0=None, lr_max=np.inf, lr_decay_coef=0,
                 lr_decay_power=1, it_start_decay=None, it_max=None, threshold=1e-6,
                 batch_size=None, worker_losses=None, pbars=None, *args, **kwargs):
        super(LoCoDL, self).__init__(label=label, *args, **kwargs)
        self.n_workers = len(worker_losses)
        if cohort_size is None:
            # Full participation unless a smaller cohort is requested.
            cohort_size = self.n_workers
        self.cohort_size = cohort_size
        self.lr0 = lr0
        self.lr_max = lr_max
        self.lr_decay_coef = lr_decay_coef
        self.lr_decay_power = lr_decay_power
        self.it_start_decay = it_start_decay
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.trace.hs = []
        self.com_time = com_time
        self.compressor = compressor
        self.gamma = gamma
        self.rho = rho
        self.chi = chi
        self.threshold = threshold

        self.xs = {}
        self.ys = {}

        self.d = d
        self.p = p
        self.workers = [Worker(loss=l, loss_coeff=loss_coeff, loss_y=loss_y)
                        for l in self.worker_losses
                        ]
        self.name = 'LoCoDL'

        print(f'\n\n{self.name}:  compressor: {compressor.name}, gamma: {round(self.gamma, 5)}, p: {round(self.p, 5)}, rho: {round(self.rho, 5)},'
              f'chi: {round(self.chi, 5)}, participation: {round(self.cohort_size / self.n_workers, 5)}'
              )
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """
        Run one communication round.

        Only records ``self.tolerance`` and returns when the loss gap is
        already below ``self.threshold``.
        """
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        local_steps = rng.geometric(p=self.p)
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)

        # `participants`, `d` and `uncompressed_ys` are filled in lockstep, so
        # position i in each of them refers to the same worker.
        participants = []
        d = []
        uncompressed_ys = []

        for j, worker in tqdm(list(enumerate(self.workers)), leave=False, colour='green', disable=self.pbars != 'all'):
            if not participation_mask[j]:
                continue
            uncompressed_x, uncompressed_y = worker.run_loco_local(x=self.x,
                                                                   gamma=self.gamma,
                                                                   local_steps=local_steps,
                                                                   batch_size=self.batch_size,
                                                                   )
            participants.append(worker)
            uncompressed_ys.append(uncompressed_y)
            # TODO: make sure compress() does not change the original array
            d.append(self.compressor.compress(
                uncompressed_x - uncompressed_y))

        # The new server iterate is the y returned by the first participant.
        # NOTE(review): presumably every participant returns the same y -- confirm.
        self.x = uncompressed_ys[0]

        d_bar = np.mean(d, axis=0) / 2
        # Pair each participant with its OWN compressed difference. `d` holds
        # one entry per participant, so it must not be indexed with the global
        # worker index (that breaks as soon as cohort_size < n_workers).
        for worker, d_own in zip(participants, d):
            worker.update_loco_params(gamma=self.gamma, rho=self.rho, chi=self.chi, p=self.p,
                                      d=d_own, d_bar=d_bar,
                                      w=self.compressor.w)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Forward run initialisation unchanged to the base optimizer."""
        super(LoCoDL, self).init_run(*args, **kwargs)

    def update_trace(self):
        """Forward trace recording unchanged to the base optimizer."""
        super(LoCoDL, self).update_trace()

    def get_xs(self):
        """Return the ``xs`` dictionary created in ``__init__``."""
        return self.xs

    def get_ys(self):
        """Return the ``ys`` dictionary created in ``__init__``."""
        return self.ys

    def get_workers_time(self):
        '''python list of size (n_workers, it_max)'''
        # NOTE(review): `workers_time` is not assigned anywhere in this class;
        # presumably it is set by the base class or by the caller -- confirm.
        return self.workers_time

    def get_communication_time(self):
        '''python list of size it_max = #communications'''
        return np.amax(self.get_workers_time(), axis=0)

    def get_total_time(self):
        """Return the sum of the per-round communication times."""
        return sum(self.get_communication_time())

    def get_workers_steps(self):
        '''python list of size (n_workers, it_max)'''
        # NOTE(review): `workers_steps` is not assigned anywhere in this class;
        # presumably it is set by the base class or by the caller -- confirm.
        return self.workers_steps


class BiCoLoR(StochasticOptimizer):
    """
    Local training with compression on both the worker and the server side.

    In every round a random subset of ``k`` out of ``d`` coordinates is
    selected. Participating workers run ``run_loco_local`` and return a pair
    ``(x_i, y_i)``; the difference ``x_i - y_i`` restricted to the selected
    coordinates is compressed with ``compressor``. The server performs the
    same number of local steps on its own pair ``(x, y)``, compresses
    ``x - y`` on the same coordinates with ``compressor_server``, and then
    workers and server update their parameters from these messages.

    Arguments:
        gamma (float): step size of the local steps (workers and server)
        eta, eta_y, rho, rho_y (float): coefficients used in the parameter updates
            of the server and forwarded to ``Worker.update_bicolor_params``
        p (float): success probability of the geometric number of local steps;
            must be < 1. Overwritten on every coin flip when ``convex`` is True
        d (int): length of the coordinate mask
        k (int): number of coordinates kept by the mask in each round
        loss_y: loss used for the ``y`` sequence (workers and server)
        loss_server: loss used for the server's ``x`` sequence
        loss_server_coeff: multiplier of the server loss gradient
        compressor: worker-side compressor exposing ``name`` and ``compress(vector)``
        compressor_server: server-side compressor exposing ``compress(vector)``
        loss_coeff: forwarded to every ``Worker``; also multiplies the gradient of
            ``loss_y`` on the server
        p_a, p_b (float, optional): parameters of the probability schedule
            ``sqrt(p_b / (p_a + coin_flips))`` used when ``convex`` is True
        convex (bool, optional): selects the decreasing-probability schedule in
            ``find_local_steps`` (default: False)
        label (str, optional): name of the method; also used for the progress bar
        com_time, lr0, lr_max, lr_decay_coef, lr_decay_power, it_start_decay:
            stored on the instance; not used by the code in this class
        cohort_size (int, optional): number of workers that take part in a round
            (default: all workers)
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): once the loss gap is below this value ``step``
            no longer updates anything (default: 1e-6)
        batch_size (int, optional): if given, stochastic gradients with this batch
            size are used instead of full gradients
        worker_losses (list): one loss object per worker
        pbars (str, optional): 'all' or 'steps' enables the per-step progress bar;
            'all' additionally enables the per-worker bar inside ``step``
    """

    def __init__(self,
                 gamma, eta, eta_y, rho, rho_y,
                 p, d, k,
                 loss_y,
                 loss_server,
                 loss_server_coeff,
                 compressor,
                 compressor_server,
                 loss_coeff,
                 p_a=None, p_b=None,
                 convex=False,
                 label='BiCoLoR',
                 com_time=1, cohort_size=None,
                 lr0=None, lr_max=np.inf, lr_decay_coef=0,
                 lr_decay_power=1, it_start_decay=None, it_max=None, threshold=1e-6,
                 batch_size=None, worker_losses=None, pbars=None, *args, **kwargs):
        super(BiCoLoR, self).__init__(label=label, *args, **kwargs)
        self.n_workers = len(worker_losses)
        if cohort_size is None:
            # Full participation unless a smaller cohort is requested.
            cohort_size = self.n_workers
        self.cohort_size = cohort_size
        self.lr0 = lr0
        self.lr_max = lr_max
        self.lr_decay_coef = lr_decay_coef
        self.lr_decay_power = lr_decay_power
        self.it_start_decay = it_start_decay
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.trace.hs = []
        self.com_time = com_time

        self.loss_y = loss_y
        self.loss_server = loss_server
        self.loss_server_coeff = loss_server_coeff
        self.loss_coeff = loss_coeff
        self.compressor = compressor
        self.compressor_server = compressor_server
        self.gamma = gamma
        self.eta = eta
        self.eta_y = eta_y
        self.rho = rho
        self.rho_y = rho_y
        self.k = k
        # The server-side y (and u, v) are created lazily on the first call of
        # server_local_steps, once self.x is available.
        self.y = None

        self.threshold = threshold

        self.xs = {}
        self.ys = {}

        self.d = d
        self.p = p
        self.p_a = p_a
        self.p_b = p_b
        self.convex = convex
        self.coin_flips = 1
        self.total_local_steps = 0

        print('K', k)
        assert p < 1
        self.workers = [Worker(loss=l, loss_coeff=loss_coeff, loss_y=loss_y)
                        for l in self.worker_losses
                        ]
        self.name = label

        print(f'\n\n{self.name}:  compressor: {compressor.name}, p: {round(self.p, 5)}, gamma: {round(self.gamma, 5)}, rho: {round(self.rho, 5)}, '
              f'participation: {round(self.cohort_size / self.n_workers, 5)}'
              )
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """
        Run one communication round.

        Only records ``self.tolerance`` and returns when the loss gap is
        already below ``self.threshold``.
        """
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size)

        # Pick a uniformly random subset of k coordinates for this round.
        nonzero_mask = np.zeros(self.d, dtype=bool)
        indices = np.arange(self.d)
        np.random.shuffle(indices)
        nonzero_mask[indices[:self.k]] = True

        # `participants` and `c_worker` are filled in lockstep, so position i
        # in both lists refers to the same worker.
        participants = []
        c_worker = []

        local_steps = self.find_local_steps()
        self.total_local_steps += local_steps
        for j, worker in tqdm(list(enumerate(self.workers)), leave=False, colour='green', disable=self.pbars != 'all'):
            if not participation_mask[j]:
                continue
            x, y = worker.run_loco_local(x=self.x,
                                         gamma=self.gamma,
                                         local_steps=local_steps,
                                         batch_size=self.batch_size,
                                         )
            # TODO: make sure compress() does not change the original array
            c = x - y
            c[~nonzero_mask] = 0
            participants.append(worker)
            c_worker.append(self.compressor.compress(c))

        self.server_local_steps(local_steps=local_steps)
        c_server = self.x - self.y
        c_server[~nonzero_mask] = 0
        c_server = self.compressor_server.compress(c_server)

        # Pair each participant with its OWN compressed message. `c_worker`
        # holds one entry per participant, so it must not be indexed with the
        # global worker index (that breaks as soon as cohort_size < n_workers).
        for worker, c_own in zip(participants, c_worker):
            worker.update_bicolor_params(gamma=self.gamma, rho=self.rho, rho_y=self.rho_y, eta=self.eta, eta_y=self.eta_y,
                                         p=self.p, k=self.k, d=self.d,
                                         c_worker=c_own, c_server=c_server,
                                         nonzero_mask=nonzero_mask)

        c_worker_bar = np.mean(c_worker, axis=0)
        self.update_server_params(c_worker_bar=c_worker_bar, nonzero_mask=nonzero_mask, c_server=c_server)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.4E}')
        self.p_bar.update()

    def get_total_local_steps(self):
        """Return the number of local steps accumulated over all rounds so far."""
        return self.total_local_steps

    def find_local_steps(self):
        """
        Draw the number of local steps of the next round.

        With ``convex`` False this is one draw from a geometric distribution
        with success probability ``self.p``. Otherwise coins are flipped until
        the first success; before every flip ``self.p`` is set to
        ``sqrt(p_b / (p_a + coin_flips))`` and ``coin_flips`` keeps counting
        across rounds.
        """
        if not self.convex:
            return np.random.geometric(p=self.p)

        local_steps = 0
        while True:
            self.p = np.sqrt(self.p_b / (self.p_a + self.coin_flips))
            coin = np.random.binomial(1, self.p)
            local_steps += 1
            self.coin_flips += 1
            if coin == 1:
                break

        return local_steps

    def server_local_steps(self, local_steps):
        """
        Run ``local_steps`` gradient steps on the server's ``x`` and ``y``.

        On the first call ``y`` is initialised as a copy of ``x`` and the
        shifts ``u`` and ``v`` as zeros. ``self.x`` and ``self.y`` are updated
        in place.
        """
        if self.y is None:
            self.y = np.copy(self.x)
            self.u = self.x * 0.
            self.v = self.x * 0.

        for _ in range(local_steps):
            if self.batch_size is None:
                g = self.loss_server_coeff * self.loss_server.gradient(self.x)
                g_y = self.loss_coeff * self.loss_y.gradient(self.y)
            else:
                g = self.loss_server_coeff * \
                    self.loss_server.stochastic_gradient(
                        self.x, batch_size=self.batch_size)
                g_y = self.loss_coeff * \
                    self.loss_y.stochastic_gradient(
                        self.y, batch_size=self.batch_size)

            self.x -= self.gamma * (g - self.u)
            self.y -= self.gamma * (g_y - self.v)

    def update_server_params(self, c_worker_bar, nonzero_mask, c_server):
        """
        Update the server's ``x``, ``y``, ``u`` and ``v`` after a round.

        Arguments:
            c_worker_bar: mean of the compressed worker messages
            nonzero_mask: boolean mask of the coordinates selected this round;
                only these coordinates of ``x`` are changed
            c_server: compressed server message
        """
        mix = (self.rho + self.rho_y) / 2
        self.x[nonzero_mask] = (1 - mix) * self.x[nonzero_mask] \
            + mix * self.y[nonzero_mask] \
            + self.rho / 2 * c_worker_bar[nonzero_mask]
        self.y += self.rho_y * c_server
        self.u += self.p * self.k * self.eta / 2 / self.d / self.gamma * c_worker_bar - \
            self.p * self.k * (self.eta_y + self.eta) / 2 / self.d / self.gamma * c_server
        self.v += self.p * self.k * self.eta_y / self.d / self.gamma * c_server

    def init_run(self, *args, **kwargs):
        """Forward run initialisation unchanged to the base optimizer."""
        super(BiCoLoR, self).init_run(*args, **kwargs)

    def update_trace(self):
        """Forward trace recording unchanged to the base optimizer."""
        super(BiCoLoR, self).update_trace()


class TwoDirection(StochasticOptimizer):
    """
    Optimizer with compressed server-to-worker and worker-to-server messages.

    The server keeps the points ``_w``, ``_z``, ``_u``, ``_x`` and the dual
    vectors ``_h``, ``_v``, ``_k``; the per-round learning rates come from
    ``_calculate_learning_rates_for_accelerated_bidirectional_diana``.

    Arguments:
        point: initial value shared by ``_w``, ``_z``, ``_u`` and ``_x`` and handed to every worker
        compressor: server-side compressor exposing ``alpha``, ``uplink_cost`` and ``compress(vector)``
        worker_compressor: compressor handed to every worker; its ``w`` attribute is read as omega
        strongly_convex_constant (float): used in the gradient steps and the learning-rate rule
        threshold (float): once the loss gap is below this value ``step`` no longer updates anything
        gamma (float): its inverse is used as ``bar_lipschitz``
        gamma_multiply (float, optional): factor applied to ``gamma`` when given
        seed (optional): seed of the generator used for the synchronisation coin
        scale_prob (float, optional): factor applied to the synchronisation probability
            before it is capped at 1 (default: 1.0)
        label (str, optional): name of the method; also used for the progress bar
        it_max (int, optional): total shown by the progress bar
        cohort_size (int, optional): stored on the instance (default: number of workers)
        worker_losses (list): one loss object per worker
        pbars (str, optional): 'all' or 'steps' enables the progress bar
    """

    def __init__(self, point, compressor, worker_compressor, strongly_convex_constant, threshold,
                 gamma=None,
                 gamma_multiply=None, seed=None,
                 scale_prob=1.0,
                 label='2Direction', it_max=None,
                 cohort_size=None, worker_losses=None, pbars=None, *args, **kwargs):
        super().__init__(label=label, *args, **kwargs)

        self.n_workers = len(worker_losses)
        self.cohort_size = self.n_workers if cohort_size is None else cohort_size
        self.threshold = threshold
        self._compressor = compressor
        self._worker_compressor = worker_compressor
        self._worker_losses = worker_losses

        self._generator = np.random.default_rng(seed)
        self._gamma = gamma
        if gamma_multiply is not None:
            self._gamma *= gamma_multiply
        self._bar_lipschitz = 1. / self._gamma
        print("Gamma: {}".format(self._gamma))

        # All four points start out as the very same object.
        self._w = self._z = self._u = self._x = point
        self.x = self._z
        self._compressor_manager = None
        self._strongly_convex_constant = strongly_convex_constant

        self._downlink_costs = [0]

        self._alpha = compressor.alpha
        self._omega = worker_compressor.w
        self._sum_gamma = 1.

        self.workers = [
            TwoDirection_Worker(loss=worker_loss,
                                compressor=worker_compressor,
                                point=point,
                                bar_lipschitz=self._bar_lipschitz,
                                strongly_convex_constant=self._strongly_convex_constant,
                                sum_gamma=self._sum_gamma)
            for worker_loss in self._worker_losses
        ]

        # The server's dual vector starts as the mean of the workers' `_h`.
        self._h = np.mean([worker._h for worker in self.workers], axis=0)
        self._v = self._k = self._h

        for worker in self.workers:
            worker._init_k(self._h)

        omega = self._omega
        self._beta = 1. / (omega + 1)

        # Synchronisation probability, scaled by `scale_prob` and capped at 1.
        if omega == 0:
            prob = 1.
        else:
            prob = max(1. / (omega + 1),
                       np.sqrt(1. / (self._alpha * omega * (omega + 1)**2)))
        prob = prob * scale_prob
        self._prob = min(1., prob)
        assert self._prob >= 0.

        if omega == 0:
            self._tau = 1.
        else:
            self._tau = min(1., 1. / (omega ** (1 / 3.) * (omega + 1) ** (2 / 3.)))

        self.name = label
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """
        Run one round of the method.

        ``self.x`` is first refreshed with a deep copy of ``_x``. When the loss
        gap is already below ``self.threshold`` the method only records
        ``self.tolerance`` and returns.
        """
        self.x = copy.deepcopy(self._x)
        gap = self.loss.value(self.x) - self.loss.f_opt
        if gap < self.threshold:
            self.tolerance = self.threshold
            return

        next_sum_gamma, gamma, theta = \
            _calculate_learning_rates_for_accelerated_bidirectional_diana(
                self._sum_gamma, self._bar_lipschitz, self._strongly_convex_constant,
                self._prob, self._alpha, self._tau, self._beta)

        y = theta * self._w + (1 - theta) * self._z
        y_messages = [worker._calculate_message_at_point_y(theta=theta)
                      for worker in self.workers]
        gradient_estimator = self._h + np.mean(y_messages, axis=0)

        self._u = self._gradient_step(gradient_estimator, self._u, gamma, next_sum_gamma,
                                      self._sum_gamma, self._bar_lipschitz,
                                      self._strongly_convex_constant, y)
        q = self._gradient_step(self._k, self._w, gamma, next_sum_gamma,
                                self._sum_gamma, self._bar_lipschitz,
                                self._strongly_convex_constant, y)
        primal_compressed_message = self._compressor.compress(self._u - q)
        self._w = q + primal_compressed_message
        self._x = theta * self._u + (1 - theta) * self._z

        # With probability `_prob` the uncompressed point and dual vector are
        # sent as well; that round is charged 2 * len(x) on top.
        if bernoulli_sample(self._generator, self._prob):
            additional_cost = 2*len(self.x)
            self._k = self._v
            self._z = self._x
            primal_noncompressed_message = self._z
            dual_noncompressed_message = self._k
        else:
            additional_cost = 0
            primal_noncompressed_message = None
            dual_noncompressed_message = None

        z_messages = [
            worker._calculate_message_at_point_z(
                primal_compressed_message=primal_compressed_message,
                primal_noncompressed_message=primal_noncompressed_message,
                dual_noncompressed_message=dual_noncompressed_message,
                gamma=gamma, next_sum_gamma=next_sum_gamma)
            for worker in self.workers
        ]
        z_message = np.mean(z_messages, axis=0)
        self._v = (1 - self._tau) * self._v + self._tau * (self._h + z_message)
        self._h = self._h + self._beta * z_message
        self._sum_gamma = next_sum_gamma

        self._downlink_costs.append(self._compressor.uplink_cost + additional_cost)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {gap:.4E}')
        self.p_bar.update()

    @staticmethod
    def _gradient_step(dual_vector, previous_point, next_gamma, next_sum_gamma,
                       sum_gamma, bar_lipschitz, strongly_convex_constant, y):
        """
        Return the point obtained from one weighted gradient step.

        Computes
        ``((L + sum_gamma * mu) * previous_point + mu * next_gamma * y
        - next_gamma * dual_vector) / (L + next_sum_gamma * mu)``
        with ``L = bar_lipschitz`` and ``mu = strongly_convex_constant``.
        """
        previous_weight = bar_lipschitz + sum_gamma * strongly_convex_constant
        numerator = (previous_weight * previous_point
                     + strongly_convex_constant * next_gamma * y
                     - next_gamma * dual_vector)
        denominator = bar_lipschitz + next_sum_gamma * strongly_convex_constant
        return numerator / denominator

    def init_run(self, *args, **kwargs):
        """Forward run initialisation unchanged to the base optimizer."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Forward trace recording unchanged to the base optimizer."""
        super().update_trace()

    def get_downlink_cost(self):
        """Return the recorded per-round downlink costs (leading entry 0) as an array."""
        return np.array(self._downlink_costs)
    

def _calculate_learning_rates_for_accelerated_bidirectional_diana(sum_gamma, bar_lipschitz, mu, p, alpha, tau, beta):
    a, b, c = (p * bar_lipschitz * sum_gamma, p * (bar_lipschitz + sum_gamma * mu), -(bar_lipschitz + sum_gamma * mu))
    theta_bar = (2 * c) / (-b - np.sqrt(b**2 - 4 * a * c))
    theta = min(theta_bar, 1 / 4. * np.min([1., alpha / p, tau / p, beta / p]))
    gamma = p * theta * sum_gamma / (1 - p * theta)
    next_sum_gamma = sum_gamma + gamma
    return next_sum_gamma, gamma, theta


def bernoulli_sample(random_generator, prob):
    """
    Draw a Bernoulli sample that is true with probability ``prob``.

    For ``prob == 0.0`` the result is ``False`` and the generator is not
    touched; otherwise one value from ``random_generator.random()`` is
    compared against ``prob``.
    """
    if prob != 0.0:
        draw = random_generator.random()
        return draw < prob
    return False



class EF21P_DIANA(StochasticOptimizer):
    """
    Optimizer combining a compressed model shift ``_w`` with a gradient shift ``h``.

    Each round the workers report messages computed at the shared point
    ``_w``; the server forms the gradient estimator ``h + mean(messages)``,
    moves ``h`` by ``beta * mean(messages)``, takes a gradient step on ``x``
    and moves ``_w`` towards ``x`` by the compressed difference ``x - _w``.

    Arguments:
        point: initial value of ``_w`` (deep-copied) and the point handed to every worker
        compressor: server-side compressor exposing ``compress(vector)``
        worker_compressor: compressor handed to every worker
        threshold (float): once the loss gap is below this value ``step`` no longer updates anything
        gamma (float): step size of the update of ``x``
        beta (float): step size of the update of ``h``; also handed to every worker
        label (str, optional): name of the method; also used for the progress bar
        it_max (int, optional): total shown by the progress bar
        cohort_size (int, optional): stored on the instance (default: number of workers)
        worker_losses (list): one loss object per worker
        pbars (str, optional): 'all' or 'steps' enables the progress bar
    """

    def __init__(self, point, compressor, worker_compressor, threshold,
                 gamma=None, beta=None,
                 label='EF21-P+DIANA', it_max=None,
                 cohort_size=None, worker_losses=None, pbars=None, *args, **kwargs):

        super(EF21P_DIANA, self).__init__(label=label, *args, **kwargs)
        self.n_workers = len(worker_losses)
        if cohort_size is None:
            cohort_size = self.n_workers
        self.cohort_size = cohort_size
        self.threshold = threshold
        self._compressor = compressor
        self._worker_compressor = worker_compressor
        self._worker_losses = worker_losses

        self._gamma = gamma
        print("Gamma: {}".format(self._gamma))
        self.beta = beta
        # Own copy: `_w` is updated in place in `step`.
        self._w = copy.deepcopy(point)
        # `h` is updated in place with `+=`; a float32 buffer would silently
        # round every update down to single precision, so use float64.
        self.h = np.zeros_like(point, dtype=np.float64)

        self.workers = [EF21P_DIANA_Worker(loss=worker_loss, compressor=worker_compressor, point=point, beta=self.beta)
                        for worker_loss in self._worker_losses]

        self.name = label
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """
        Run one round of the method.

        Only records ``self.tolerance`` and returns when the loss gap is
        already below ``self.threshold``.
        """
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        messages = [worker._calculate_message(w=self._w) for worker in self.workers]
        message = np.mean(messages, axis=0)
        # The estimator uses the shift from BEFORE this round's update of h.
        gradient_estimator = self.h + message
        self.h += self.beta * message
        self.x = self.x - self._gamma * gradient_estimator
        p = self._compressor.compress(self.x - self._w)
        self._w += p

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.4E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Forward run initialisation unchanged to the base optimizer."""
        super(EF21P_DIANA, self).init_run(*args, **kwargs)

    def update_trace(self):
        """Forward trace recording unchanged to the base optimizer."""
        super(EF21P_DIANA, self).update_trace()
    