import copy

import numpy as np
from tqdm import tqdm
from worker import Worker
from optmethods.optimizer import StochasticOptimizer


def get_participation_mask(n_workers, cohort_size, rng=None):
    """Return a shuffled boolean mask marking which workers take part.

    Args:
        n_workers: length of the returned mask.
        cohort_size: number of leading entries set to True before the
            mask is shuffled.
        rng: optional object exposing ``shuffle`` (e.g. a
            ``np.random.Generator``); when None the global ``np.random``
            state is used instead.

    Returns:
        1-D boolean ``np.ndarray`` of length ``n_workers``.
    """
    mask = np.zeros(n_workers, dtype=bool)
    mask[:cohort_size] = True
    # Both shufflers permute the array in place.
    shuffle_in_place = np.random.shuffle if rng is None else rng.shuffle
    shuffle_in_place(mask)
    return mask


def create_random_sampling_pattern(dimension: int, num_workers: int, ones_per_row: int):
    """Build the deterministic 0/1 template of shape (dimension, num_workers).

    Despite the name, nothing here is random: the caller permutes the
    result afterwards. Two layouts are produced:

    * ``dimension >= num_workers / ones_per_row``: the ones are written
      row by row, ``ones_per_row`` per row, with the column index advancing
      by one for every written entry and wrapping around at ``num_workers``.
    * otherwise: the ones are written in ``ones_per_row`` passes over the
      rows, the column index advancing by one per entry without wrapping.

    Args:
        dimension: number of rows.
        num_workers: number of columns.
        ones_per_row: number of ones requested per row.

    Returns:
        Float ``np.ndarray`` of zeros and ones.
    """
    pattern = np.zeros(shape=(dimension, num_workers))
    total_ones = dimension * ones_per_row

    if dimension >= num_workers / ones_per_row:
        column = 0
        for slot in range(total_ones):
            # Consecutive slots fill one row before moving to the next.
            pattern[slot // ones_per_row, column] = 1
            column = (column + 1) % num_workers
    else:
        for slot in range(total_ones):
            # Each pass over the rows uses a fresh block of columns.
            pattern[slot % dimension, slot] = 1

    return pattern


def get_communication_mask(dimension: int, num_workers: int, ones_per_row: int, rng=None):
    """Return a (dimension, num_workers) 0/1 mask with permuted columns.

    When ``ones_per_row`` equals ``num_workers`` the mask is all ones and no
    randomness is consumed. Otherwise the template produced by
    ``create_random_sampling_pattern`` has its columns permuted (the
    permutation is applied to the rows of the transpose).

    Args:
        dimension: number of rows of the mask.
        num_workers: number of columns of the mask.
        ones_per_row: number of ones requested per row.
        rng: optional object exposing ``permutation``; when None the global
            ``np.random`` state is used.

    Returns:
        Float ``np.ndarray`` of zeros and ones.
    """
    if ones_per_row == num_workers:
        return np.ones(shape=(dimension, num_workers))

    permute = np.random.permutation if rng is None else rng.permutation
    template = create_random_sampling_pattern(
        dimension, num_workers, ones_per_row)
    return permute(template.T).T


class GD(StochasticOptimizer):
    """Server loop that averages the vectors returned by ``Worker.run_gd``.

    On every ``step`` the current iterate ``self.x`` is copied to each
    worker, each worker's ``run_gd`` is called, and ``self.x`` is replaced
    by the element-wise mean of the returned values.

    Arguments:
        it_local: forwarded to ``Worker.run_gd`` as ``local_steps``
        n_workers (optional): stored on the instance; not read by this class
        lr (float, optional): forwarded to ``Worker.run_gd``; must be a number
            because it is rounded for the start-up message
        dim (int, optional): length of the zero vector used as initial iterate
        worker_losses (iterable, optional): one ``Worker`` is created per
            entry, so an iterable is required in practice
        it_max (int, optional): total shown by the step progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating the iterate (default: 1e-6)
        pbars (str, optional): 'all' enables the step bar and the per-worker
            bar, 'steps' only the step bar, any other value disables both
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, it_local, n_workers=None, lr=None, dim=None, worker_losses=None, it_max=None, threshold=1e-6, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        self.lr = lr
        self.d = dim
        self.worker_losses = worker_losses
        self.threshold = threshold

        self.x = np.zeros(shape=self.d)

        workers = []
        for worker_loss in self.worker_losses:
            workers.append(Worker(loss=worker_loss))
        self.workers = workers

        self.name = 'GD'

        print(
            f'\n\n{self.name}:  lr:{round(self.lr, 5)}, local_steps: {self.it_local}')

        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            # The iterate is left untouched and the threshold is published as
            # the tolerance; presumably this lets the base-class stopping test
            # fire -- confirm against StochasticOptimizer.
            self.tolerance = self.threshold
            return

        # Broadcast first, then run, exactly in that order.
        for worker in self.workers:
            worker.x = np.copy(self.x)

        local_results = []
        worker_bar = tqdm(self.workers, leave=False, colour='green',
                          disable=self.pbars != 'all')
        for worker in worker_bar:
            local_results.append(worker.run_gd(
                lr=self.lr, local_steps=self.it_local))
        self.x = np.mean(local_results, axis=0)

        # The gap shown was measured before this round's update.
        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


# class ADIANA(StochasticOptimizer):
#     def __init__(self, alpha, beta, compressor, p, eta, theta_1, theta_2, n_workers=None, lr=None, dim=None, worker_losses=None, it_max=None, pbars=None,
#                  *args, **kwargs):
#         super(ADIANA, self).__init__(*args, **kwargs)
#         self.n_workers = n_workers
#         self.compressor = compressor
#         self.lr = lr
#         self.p = p
#         self.eta = eta
#         self.theta_1 = theta_1
#         self.theta_2 = theta_2
#         self.alpha = alpha
#         self.beta = beta
#         self.d = dim
#         self.worker_losses = worker_losses

#         self.h = np.zeros(shape=dim) * 1.
#         self.z = np.zeros(shape=dim) * 1.
#         self.y = np.zeros(shape=dim) * 1.
#         self.w = np.zeros(shape=dim) * 1.
#         self.x = np.zeros(shape=dim) * 1.

#         self.workers = [Worker(loss=worker_loss)
#                         for worker_loss in self.worker_losses]

#         self.name = 'ADIANA'

#         print(
#             f'\n\n{self.name}:  compressor: {compressor.name}, lr: {round(self.lr, 5)}'
#             f'  eta: {self.eta}, theta_1: {round(self.theta_1, 5)}, theta_2: {round(self.theta_2, 5)}'
#             f'  alpha: {self.alpha}, beta: {round(self.beta, 5)}, p: {round(self.p, 5)}'
#         )

#         self.pbars = pbars
#         self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
#                           disable=self.pbars not in ['all', 'steps'])

#     def step(self):
#         loss_gap = self.loss.value(self.x) - self.loss.f_opt
#         if loss_gap < self.threshold:
#             self.tolerance = self.threshold
#             return
        
#         self.z = self.theta_1*self.x + self.theta_2 * \
#             self.w + (1-self.theta_1-self.theta_2)*self.y

#         sum_1 = 0
#         sum_2 = 0
#         for worker in self.workers:
#             m_1, m_2 = worker.run_adiana(
#                 z=self.z, w=self.w, alpha=self.alpha, compressor=self.compressor)
#             sum_1 += m_1
#             sum_2 += m_2

#         g = sum_1/self.n_workers + self.h
#         self.h += self.alpha * sum_2/self.n_workers

#         y_prev = np.copy(self.y)
#         self.y = self.z - self.eta * g
#         self.x = self.beta * self.x + \
#             (1-self.beta)*self.z + (self.lr/self.eta)*(self.y - self.z)
#         random_number = np.random.rand()
#         self.w = y_prev if random_number < self.p else self.w

#         self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
#         self.p_bar.update()

#     def init_run(self, *args, **kwargs):
#         super(ADIANA, self).init_run(*args, **kwargs)

#     def update_trace(self):
#         super(ADIANA, self).update_trace()


class ADIANA(StochasticOptimizer):
    """Server loop combining two per-worker messages from ``Worker.run_adiana``.

    The server keeps five float32 vectors (``h``, ``z``, ``y``, ``w``, ``x``),
    all initialised to zero. Each ``step`` forms ``x`` as a weighted
    combination of ``z``, ``w`` and ``y``, collects two messages per worker,
    and updates ``h``, ``y``, ``z`` and (with probability ``p``) ``w``.

    Arguments:
        alpha: forwarded to the workers and used as the step for ``h``
        beta: mixing weight between the old ``z`` and ``x``
        compressor: forwarded to the workers; must expose ``name``
        p: probability with which ``w`` is replaced by the previous ``y``
        eta: step size of the ``y`` update
        theta_1, theta_2: weights of ``z`` and ``w`` when forming ``x``
        n_workers (optional): divisor used when averaging worker messages
        lr (float, optional): scales the ``(y - x)`` term of the ``z`` update
        dim (int, optional): length of the state vectors
        worker_losses (iterable, optional): one ``Worker`` is created per entry
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): the progress bar is shown for 'all' or 'steps'
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, alpha, beta, compressor, p, eta, theta_1, theta_2, n_workers=None, lr=None, dim=None, worker_losses=None, it_max=None, threshold=1e-6, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_workers = n_workers
        self.compressor = compressor
        self.lr = lr
        self.p = p
        self.eta = eta
        self.theta_1 = theta_1
        self.theta_2 = theta_2
        self.alpha = alpha
        self.beta = beta
        self.d = dim
        self.worker_losses = worker_losses
        self.threshold = threshold

        def zero_state():
            # Every state vector gets its own float32 buffer.
            return np.zeros(shape=dim, dtype=np.float32) * 1.

        self.h = zero_state()
        self.z = zero_state()
        self.y = zero_state()
        self.w = zero_state()
        self.x = zero_state()

        workers = []
        for worker_loss in self.worker_losses:
            workers.append(Worker(loss=worker_loss))
        self.workers = workers

        self.name = 'ADIANA'

        print(
            f'\n\n{self.name}:  compressor: {compressor.name}, lr: {round(self.lr, 5)}'
            f'  eta: {self.eta}, theta_1: {round(self.theta_1, 5)}, theta_2: {round(self.theta_2, 5)}'
            f'  alpha: {self.alpha}, beta: {round(self.beta, 5)}, p: {round(self.p, 5)}'
        )

        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one round, or freeze when the loss gap is tiny or above 100."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold or loss_gap > 100:
            # Stops both on reaching the threshold and on a gap above 100
            # (the latter looks like a divergence guard -- confirm).
            self.tolerance = self.threshold
            return

        theta_rest = 1 - self.theta_1 - self.theta_2
        self.x = self.theta_1*self.z + self.theta_2*self.w + theta_rest*self.y

        first_total = 0
        second_total = 0
        for worker in self.workers:
            first_msg, second_msg = worker.run_adiana(
                z=self.x, w=self.w, alpha=self.alpha, compressor=self.compressor)
            first_total += first_msg
            second_total += second_msg

        # g uses h from before this round's shift update.
        g = self.h + first_total/self.n_workers
        self.h += self.alpha * second_total/self.n_workers

        y_before = np.copy(self.y)
        self.y = self.x - self.eta * g
        self.z = self.beta * self.z + \
            (1-self.beta)*self.x + self.lr/self.eta*(self.y - self.x)

        # One draw from the global NumPy state decides whether w is refreshed.
        if np.random.rand() < self.p:
            self.w = y_before

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class DIANA(StochasticOptimizer):
    """Server loop that averages ``Worker.run_diana`` messages around a shift.

    The server keeps a float32 shift ``h`` (initially zero). Each ``step``
    averages the worker messages into ``m``, moves ``self.x`` in place along
    ``h + m`` and then adds ``dual_lr * m`` to ``h``.

    ``self.x`` is not created here; it is expected to be provided by the base
    class before ``step`` runs (not shown in this file).

    Arguments:
        dual_lr: forwarded to the workers and used as the step for ``h``
        compressor: forwarded to the workers; must expose ``name``
        n_workers (optional): stored on the instance; not read by this class
        lr (float, optional): step size of the in-place update of ``self.x``
        dim (int, optional): length of ``h``
        worker_losses (iterable, optional): one ``Worker`` is created per entry
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): the progress bar is shown for 'all' or 'steps'
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, dual_lr, compressor, n_workers=None, lr=None, dim=None, worker_losses=None, it_max=None, threshold=1e-6, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_workers = n_workers
        self.dual_lr = dual_lr
        self.compressor = compressor
        self.lr = lr
        self.d = dim
        self.worker_losses = worker_losses
        self.threshold = threshold

        self.h = np.zeros(shape=dim, dtype=np.float32)

        workers = []
        for worker_loss in self.worker_losses:
            workers.append(Worker(loss=worker_loss))
        self.workers = workers

        self.name = 'DIANA'

        print(
            f'\n\n{self.name}:  compressor: {compressor.name}, lr: {round(self.lr, 5)}')

        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        messages = []
        for worker in self.workers:
            messages.append(worker.run_diana(
                x=self.x, dual_lr=self.dual_lr, compressor=self.compressor))
        m = np.mean(messages, axis=0)

        # The direction uses h from before this round's shift update.
        g = self.h + m
        self.x -= self.lr * g
        self.h += self.dual_lr * m

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class Scaffnew(StochasticOptimizer):
    """Server loop with a random number of local steps per round.

    Each ``step`` draws the number of local steps from a geometric
    distribution with parameter ``p``, lets every worker call
    ``run_scaffnew`` from a copy of the current iterate, averages the returned
    vectors into ``self.x`` and finally calls ``update_scaffnew_variates`` on
    every worker with the new iterate.

    Arguments:
        p (float, optional): parameter of the geometric draw; also forwarded
            to ``update_scaffnew_variates``
        n_workers (optional): stored on the instance; not read by this class
        lr (float, optional): step size forwarded to the workers
        worker_losses (iterable, optional): one ``Worker`` is created per entry
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): the progress bar is shown for 'all' or 'steps'
        d (int, optional): length of the zero vector used as initial iterate
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, p=None, n_workers=None, lr=None, worker_losses=None, it_max=None, threshold=1e-6, pbars=None, d=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_workers = n_workers
        self.lr = lr
        self.d = d
        self.worker_losses = worker_losses
        self.p = p
        self.threshold = threshold

        self.x = np.zeros(shape=self.d)

        workers = []
        for worker_loss in self.worker_losses:
            workers.append(Worker(loss=worker_loss))
        self.workers = workers

        # Every worker starts with a zero vector ``h`` shaped like the iterate.
        for worker in self.workers:
            worker.h = np.zeros_like(self.x)

        self.name = 'Scaffnew'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}')
        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        # A fresh, unseeded generator is created on every round.
        local_steps = np.random.default_rng().geometric(p=self.p)

        local_results = []
        for worker in self.workers:
            worker.x = np.copy(self.x)
            result = worker.run_scaffnew(lr=self.lr, local_steps=local_steps)
            local_results.append(result)

        self.x = np.mean(local_results, axis=0)

        # The workers see the freshly averaged iterate.
        for worker in self.workers:
            worker.update_scaffnew_variates(
                p=self.p, lr=self.lr, server_x=self.x)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class FiveGCS(StochasticOptimizer):
    """Server loop (labelled '5GCS') with client sampling and local solves.

    Each ``step`` samples ``cohort_size`` workers, computes ``x_hat`` from the
    current iterate and the aggregate ``v``, lets the sampled workers call
    ``local_solve`` starting from ``x_hat``, rebuilds ``v`` as the sum of all
    workers' ``five_gcs_u`` vectors and updates ``self.x`` from the change
    in ``v``.

    Arguments:
        it_local (optional): forwarded to ``Worker.local_solve`` as
            ``local_steps``
        n_workers (int, optional): length of the participation mask; also
            used in the local step size and in the final scaling
        cohort_size (int, optional): number of workers sampled per round
        mu (float, optional): used in ``x_hat`` and in the local step size
            (presumably a strong-convexity constant -- confirm)
        lr (float, optional): server step size
        dual_lr (float, optional): forwarded to the workers and added to the
            local smoothness estimate
        d (int, optional): length of the zero vector used as initial iterate
        worker_losses (iterable, optional): one ``Worker`` is created per entry
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): the progress bar is shown for 'all' or 'steps'
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, it_local=None, n_workers=None, cohort_size=None, mu=None, lr=None, dual_lr=None, d=None,
                 worker_losses=None, it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.dual_lr = dual_lr
        self.d = d
        self.worker_losses = worker_losses
        self.mu = mu

        self.threshold = threshold

        self.x = np.zeros(shape=self.d)
        self.x_hat = np.zeros_like(self.x)
        self.v = np.zeros_like(self.x)

        # Each worker carries its own vector ``five_gcs_u``, initially zero.
        self.workers = []
        for worker_loss in self.worker_losses:
            worker = Worker(loss=worker_loss)
            worker.five_gcs_u = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = '5GCS'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, dual_lr: {round(self.dual_lr, 5)}, '
              f'local_steps: {self.it_local}')

        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            # The iterate is left untouched and the threshold is published as
            # the tolerance; presumably this lets the base-class stopping test
            # fire -- confirm against StochasticOptimizer.
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)
        self.x_hat = (1 / (1 + self.lr * self.mu)) * \
            (self.x - self.lr * self.v)

        # The local step size is the same for every worker, so it is computed
        # once per round instead of once per participating worker.
        L_f = ((self.loss.max_smoothness - self.mu) /
               self.n_workers) + self.dual_lr
        local_solver_lr = 1 / L_f

        for i, worker in enumerate(self.workers):
            if participation_mask[i]:
                worker.x = np.copy(self.x_hat)
                worker.local_solve(local_steps=self.it_local, mu=self.mu, n_workers=self.n_workers,
                                   local_solver_lr=local_solver_lr, dual_lr=self.dual_lr)

        # ``self.v`` is rebound to a fresh array below, so the old array can
        # be kept by reference; no copy is required.
        prev_v = self.v
        new_v = np.zeros_like(self.x)
        for worker in self.workers:
            # Non-participants contribute their unchanged vectors.
            new_v += worker.five_gcs_u
        self.v = new_v

        self.x = self.x_hat - self.lr * \
            (self.n_workers / self.cohort_size) * (self.v - prev_v)

        # The gap shown was measured before this round's update.
        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class FiveGCS_CC(StochasticOptimizer):
    """Server loop (labelled '5GCS-CC') with client sampling and compression.

    Each ``step`` samples ``cohort_size`` workers, computes ``x_hat`` from the
    current iterate and the aggregate ``v``, sums the values returned by the
    sampled workers' ``local_solve_cc``, adds a scaled version of that sum to
    ``v`` in place and updates ``self.x`` from the change in ``v``.

    Arguments:
        compressor: forwarded to the workers; must expose ``name`` and ``w``
        it_local (optional): forwarded to the workers as ``local_steps``
        n_workers (int, optional): length of the participation mask; also
            used in the scaling factors
        cohort_size (int, optional): number of workers sampled per round
        mu (float, optional): used in ``x_hat`` and in the local step size
            (presumably a strong-convexity constant -- confirm)
        lr (float, optional): server step size
        dual_lr (float, optional): forwarded to the workers and added to the
            local smoothness estimate
        d (int, optional): length of the zero vector used as initial iterate
        worker_losses (sequence, optional): one ``Worker`` is created per
            entry; every entry must expose ``smoothness``
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): the progress bar is shown for 'all' or 'steps'
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, compressor, it_local=None, n_workers=None, cohort_size=None, mu=None, lr=None, dual_lr=None, d=None,
                 worker_losses=None, it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.it_local = it_local
        self.compressor = compressor
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.dual_lr = dual_lr
        self.d = d
        self.worker_losses = worker_losses
        self.mu = mu
        self.threshold = threshold

        # Local step size is 1 / L_f, built from the largest worker smoothness.
        L_max = max(worker_loss.smoothness for worker_loss in worker_losses)
        self.L_f = ((L_max - mu) / n_workers) + dual_lr

        self.x = np.zeros(shape=self.d)
        self.x_hat = np.zeros_like(self.x)
        self.v = np.zeros_like(self.x)

        # Each worker carries its own vector ``five_gcs_u``, initially zero.
        self.workers = []
        for worker_loss in self.worker_losses:
            worker = Worker(loss=worker_loss)
            worker.five_gcs_u = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = '5GCS-CC'

        print(f'\n\n{self.name}: compressor: {compressor.name}, lr:{round(self.lr, 5)}, dual_lr: {round(self.dual_lr, 5)}, '
              f'local_steps: {self.it_local}')

        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            # The iterate is left untouched and the threshold is published as
            # the tolerance; presumably this lets the base-class stopping test
            # fire -- confirm against StochasticOptimizer.
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)
        self.x_hat = (1 / (1 + self.lr * self.mu)) * \
            (self.x - self.lr * self.v)

        # Named so that the builtin ``sum`` is not shadowed.
        compressed_total = 0
        for i, worker in enumerate(self.workers):
            if participation_mask[i]:
                worker.x = np.copy(self.x_hat)
                compressed_total += worker.local_solve_cc(local_steps=self.it_local, mu=self.mu, n_workers=self.n_workers, cohort_size=self.cohort_size,
                                                          local_solver_lr=1 / self.L_f, dual_lr=self.dual_lr, compressor=self.compressor)

        # ``self.v`` is updated in place below, so a real copy is needed here.
        prev_v = np.copy(self.v)
        self.v += 1 / (1 + self.compressor.w) * \
            (self.cohort_size / self.n_workers) * compressed_total

        self.x = self.x_hat - self.lr * \
            (self.n_workers / self.cohort_size) * \
            (1 + self.compressor.w) * (self.v - prev_v)

        # The gap shown was measured before this round's update.
        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class Scaffold(StochasticOptimizer):
    """Server loop with client sampling and server/worker control vectors.

    Each ``step`` samples ``cohort_size`` workers, hands them a copy of the
    current iterate, collects the vectors returned by ``run_scaffold``
    together with each worker's ``scaffold_local_c``, and moves ``self.x``
    and ``self.server_c`` in place towards the respective averages.

    Arguments:
        it_local: forwarded to ``Worker.run_scaffold`` as ``local_steps``
        n_workers (int, optional): length of the participation mask; also
            scales the ``server_c`` update
        cohort_size (int, optional): number of workers sampled per round
        lr (float, optional): step size forwarded to the workers
        worker_losses (sequence, optional): one ``Worker`` is created per entry
        d (int, optional): length of the zero vector used as initial iterate
        global_lr (float, optional): server step towards the averaged
            worker result (default: 1)
        it_max (int, optional): total shown by the progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): the progress bar is shown for 'all' or 'steps'
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, it_local, n_workers=None, cohort_size=None, lr=None, worker_losses=None, d=None,
                 global_lr=1, it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.worker_losses = worker_losses
        self.global_lr = global_lr
        self.threshold = threshold

        self.x = np.zeros(d)
        self.server_c = np.zeros_like(self.x)

        # Each worker carries its own control vector, initially zero.
        self.workers = []
        for worker_loss in worker_losses:
            worker = Worker(loss=worker_loss)
            worker.scaffold_local_c = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = 'Scaffold'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, local_steps: {self.it_local}, '
              f'participation: {round(self.cohort_size / self.n_workers, 5)}')
        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=np.random.default_rng())
        participants = [worker for idx, worker in enumerate(self.workers)
                        if participation_mask[idx]]

        # Broadcast to the whole cohort before any worker starts.
        for worker in participants:
            worker.x = np.copy(self.x)

        local_iterates = []
        local_controls = []
        for worker in participants:
            local_iterates.append(worker.run_scaffold(
                lr=self.lr, local_steps=self.it_local, server_c=np.copy(self.server_c)))
            # Read after the local run, so any update made by it is seen.
            local_controls.append(worker.scaffold_local_c)

        assert len(local_iterates) == self.cohort_size

        x_new = np.sum(local_iterates, axis=0) / self.cohort_size
        c_new = np.sum(local_controls, axis=0) / self.cohort_size

        self.x += self.global_lr * (x_new - self.x)
        participation = self.cohort_size / self.n_workers
        self.server_c += participation * (c_new - self.server_c)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)


class Tamuna(StochasticOptimizer):
    """Server loop (labelled 'TAMUNA') with client sampling and masking.

    Each ``step`` samples ``cohort_size`` workers, draws a geometric number
    of local steps and a 0/1 communication mask with one column per
    participant. Every participant's ``run_tamuna_local`` result is
    multiplied element-wise by its column; the masked results are summed and
    divided by ``s`` to form the new iterate, after which the participants'
    ``update_tamuna_variates`` is called.

    Arguments:
        p (float): parameter of the geometric draw of local steps
        s (int): ones per row requested from the communication mask and
            divisor of the aggregated result
        eta: forwarded to ``update_tamuna_variates``
        n_workers (int, optional): length of the participation mask
        cohort_size (int, optional): number of workers sampled per round and
            number of columns of the communication mask
        lr (float, optional): step size forwarded to the workers
        batch_size (optional): stored on the instance; not read by this class
        worker_losses (sequence, optional): one ``Worker`` is created per entry
        d (int, optional): length of the zero vector used as initial iterate
        it_max (int, optional): total shown by the step progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): 'all' enables the step bar and the per-worker
            bar, 'steps' only the step bar, any other value disables both
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, p, s, eta, n_workers=None, cohort_size=None, lr=None, batch_size=None, worker_losses=None, d=None,
                 it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p
        self.s = s
        self.eta = eta
        self.n_workers = n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.d = d
        self.threshold = threshold

        self.x = np.zeros(self.d)

        # Each worker carries its own vector ``h``, initially zero.
        self.workers = []
        for worker_loss in worker_losses:
            worker = Worker(loss=worker_loss)
            worker.h = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = 'TAMUNA'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}, s: {round(self.s, 5)}, '
              f'eta: {round(self.eta, 5)}, participation: {round(self.cohort_size / self.n_workers, 5)}')

        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        participation_mask = get_participation_mask(
            n_workers=self.n_workers, cohort_size=self.cohort_size, rng=rng)
        local_steps = rng.geometric(p=self.p)
        # NOTE(review): the mask is drawn from ``self.rng`` and sized by
        # ``self.dim`` -- both are expected to come from the base class (not
        # shown here) -- while the other draws use the local ``rng``.
        q = get_communication_mask(
            self.dim, self.cohort_size, self.s, rng=self.rng)

        participants = []
        masked_results = []
        indexed_workers = list(enumerate(self.workers))
        for j, worker in tqdm(indexed_workers, leave=False, colour='green', disable=self.pbars != 'all'):
            if not participation_mask[j]:
                continue
            worker.x = np.copy(self.x)
            local_result = worker.run_tamuna_local(
                lr=self.lr, local_steps=local_steps)
            # The k-th participant uses the k-th column of the mask.
            column = q[:, len(participants)]
            masked_results.append(np.multiply(local_result, column))
            participants.append(worker)

        self.x = np.sum(masked_results, axis=0) / self.s

        for slot, worker in enumerate(participants):
            worker.update_tamuna_variates(
                lr=self.lr, eta=self.eta,
                server_compressed=np.multiply(self.x, q[:, slot]),
                worker_compressed=masked_results[slot])

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class CompressedScaffnew(StochasticOptimizer):
    """Server loop with full participation and a per-worker 0/1 mask.

    Each ``step`` draws a geometric number of local steps and a 0/1
    communication mask with one column per worker. Every worker calls
    ``run_compressed_scaffnew_local`` with its column; the returned vectors
    are summed and divided by ``s`` to form the new iterate, after which
    ``update_compressed_scaffnew_variates`` is called on every worker.

    Arguments:
        p (float): parameter of the geometric draw of local steps
        s (int): ones per row requested from the communication mask and
            divisor of the aggregated result
        eta: forwarded to ``update_compressed_scaffnew_variates``
        n_workers (int, optional): number of columns of the mask
        lr (float, optional): step size forwarded to the workers
        worker_losses (sequence, optional): one ``Worker`` is created per entry
        d (int, optional): length of the zero vector used as initial iterate
        it_max (int, optional): total shown by the step progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        pbars (str, optional): 'all' enables the step bar and the per-worker
            bar, 'steps' only the step bar, any other value disables both
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, p, s, eta, n_workers=None, lr=None, worker_losses=None, d=None,
                 it_max=None, threshold=1e-6, pbars=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p
        self.s = s
        self.eta = eta
        self.n_workers = n_workers
        self.lr = lr
        self.worker_losses = worker_losses
        self.d = d
        self.threshold = threshold

        self.x = np.zeros(self.d)

        # Each worker carries its own vector ``h``, initially zero.
        self.workers = []
        for worker_loss in worker_losses:
            worker = Worker(loss=worker_loss)
            worker.h = np.zeros_like(self.x)
            self.workers.append(worker)

        self.name = 'CompressedScaffnew'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}, s: {round(self.s, 5)}, '
              f'eta: {round(self.eta, 5)}')

        self.pbars = pbars
        show_step_bar = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not show_step_bar)

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            print('Conv?')
            return

        rng = np.random.default_rng()
        local_steps = rng.geometric(p=self.p)
        # ``self.dim`` is expected to come from the base class (not shown).
        q = get_communication_mask(self.dim, self.n_workers, self.s, rng=rng)

        local_results = []
        indexed_workers = list(enumerate(self.workers))
        for j, worker in tqdm(indexed_workers, leave=False, colour='green', disable=self.pbars != 'all'):
            worker.x = np.copy(self.x)
            local_results.append(worker.run_compressed_scaffnew_local(
                lr=self.lr, local_steps=local_steps, q=q[:, j]))

        self.x = np.sum(local_results, axis=0) / self.s

        # Worker j gets its own column and its own earlier result back.
        for j, (worker, result) in enumerate(zip(self.workers, local_results)):
            worker.update_compressed_scaffnew_variates(
                x=np.copy(self.x), lr=self.lr, eta=self.eta, q=q[:, j],
                worker_result=result)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate trace bookkeeping to the base class unchanged."""
        super().update_trace()


class GradSkip(StochasticOptimizer):
    """Server loop (labelled 'GradSkip') with a random number of local steps.

    Each ``step`` draws the number of local steps from a geometric
    distribution with parameter ``p``, samples ``cohort_size`` workers, lets
    each of them call ``run_gradskip_local`` on a copy of the current iterate,
    replaces ``self.x`` with the mean of the returned vectors and then calls
    ``update_gradskip_params`` on the same workers.

    ``self.x`` is not created here; it is expected to be provided by the base
    class before ``step`` runs (not shown in this file).

    Arguments:
        p (float): parameter of the geometric draw; also forwarded to the
            workers
        worker_losses (sequence): one ``Worker`` is created per entry; its
            length defines ``n_workers``
        label (str, optional): forwarded to ``StochasticOptimizer``
            (default: 'GradSkip')
        q (sequence, optional): per-worker value passed to ``Worker(q=...)``
            (default: 1 for every worker)
        grad_time (sequence, optional): stored as ``self.grad_time``
            (default: 1 for every worker); not read by this class
        com_time (optional): stored on the instance; not read by this class
            (default: 1)
        cohort_size (int, optional): number of workers sampled per round
            (default: all workers)
        lr (float, optional): step size forwarded to the workers; must be a
            number because it is rounded for the start-up message
        lr_decay_power (optional): stored on the instance; not read by this
            class (default: 1)
        it_start_decay (optional): stored on the instance; not read by this
            class
        it_max (int, optional): total shown by the step progress bar
        threshold (float, optional): loss gap below which ``step`` stops
            updating (default: 1e-6)
        batch_size (optional): forwarded to ``Worker.run_gradskip_local``
        pbars (str, optional): 'all' enables the step bar and the per-worker
            bar, 'steps' only the step bar, any other value disables both
        *args, **kwargs: forwarded to ``StochasticOptimizer``
    """

    def __init__(self, p, worker_losses, label='GradSkip', q=None, grad_time=None, com_time=1, cohort_size=None,
                 lr=None,
                 lr_decay_power=1, it_start_decay=None, it_max=None, threshold=1e-6, 
                 batch_size=None, pbars=None, *args, **kwargs):
        super().__init__(label=label, *args, **kwargs)
        self.p = p
        self.n_workers = len(worker_losses)
        if cohort_size is None:
            # Full participation unless a smaller cohort is requested.
            cohort_size = self.n_workers
        self.cohort_size = cohort_size
        self.lr = lr
        self.lr_decay_power = lr_decay_power
        self.it_start_decay = it_start_decay
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.trace.hs = []
        self.grad_time = [1]*self.n_workers if grad_time is None else grad_time
        self.com_time = com_time
        self.threshold = threshold

        self.xs = {}

        qs = [1] * self.n_workers if q is None else q
        self.workers = [Worker(loss=worker_loss, q=qs[i])
                        for i, worker_loss in enumerate(self.worker_losses)]

        self.name = 'GradSkip'
        print(f'\n\n{self.name}: , p: {round(self.p, 5)}, lr: {round(self.lr, 5)}, '
              f'participation: {round(self.cohort_size / self.n_workers, 5)}'
              )
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one communication round, or freeze once the loss gap is small."""
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            # The iterate is left untouched and the threshold is published as
            # the tolerance; presumably this lets the base-class stopping test
            # fire -- confirm against StochasticOptimizer.
            self.tolerance = self.threshold
            return

        rng = np.random.default_rng()
        local_steps = rng.geometric(p=self.p)
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)

        ret = []
        for j, worker in tqdm(list(enumerate(self.workers)), leave=False, colour='green', disable=self.pbars != 'all'):
            if participation_mask[j]:
                # Each worker receives its own copy of the iterate.
                x = np.copy(self.x)
                worker_result = worker.run_gradskip_local(
                    x=x, p=self.p, local_steps=local_steps, lr=self.lr, batch_size=self.batch_size)
                ret.append(worker_result)

        self.x = np.mean(ret, axis=0)

        # Only the workers that took part see the new average.
        for j, worker in enumerate(self.workers):
            if participation_mask[j]:
                worker.update_gradskip_params(
                    x_av=self.x, p=self.p, lr=self.lr)

        # The gap shown was measured before this round's update.
        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Delegate run initialisation to the base class unchanged."""
        super().init_run(*args, **kwargs)

    def get_xs(self):
        """Return the ``xs`` dictionary created (empty) in ``__init__``."""
        return self.xs

    def get_workers_time(self):
        """Return ``self.workers_time``.

        NOTE(review): ``workers_time`` is not assigned anywhere in this
        class; it must be set elsewhere before this is called. The original
        note describes it as a list of size (n_workers, it_max).
        """
        return self.workers_time

    def get_communication_time(self):
        """Return the maximum of the workers' times along axis 0.

        With the (n_workers, it_max) layout described for
        ``get_workers_time`` this would be one value per communication
        round -- confirm against the code that fills ``workers_time``.
        """
        return np.amax(self.get_workers_time(), axis=0)

    def get_total_time(self):
        """Return the sum of the values from ``get_communication_time``."""
        return sum(self.get_communication_time())

    def get_workers_steps(self):
        """Return ``self.workers_steps``.

        NOTE(review): ``workers_steps`` is not assigned anywhere in this
        class; it must be set elsewhere before this is called. The original
        note describes it as a list of size (n_workers, it_max).
        """
        return self.workers_steps


class LoCoDL(StochasticOptimizer):
    """
    LoCoDL: distributed optimizer with local steps and compressed communication.

    Every call to ``step`` performs one communication round:

    1. the number of local steps is drawn from a geometric distribution with
       success probability ``p``;
    2. ``cohort_size`` workers are selected uniformly at random;
    3. each selected worker runs ``Worker.run_loco_local`` from the current
       iterate and returns a pair of local vectors ``(x_i, y_i)``;
    4. each difference ``x_i - y_i`` is passed through ``compressor.compress``;
       half of the mean of the compressed differences is ``d_bar``;
    5. each selected worker receives its own compressed difference together
       with ``d_bar`` through ``Worker.update_loco_params``.

    Arguments:
        gamma (float): forwarded to the workers as ``gamma``
            (presumably the local step-size -- confirm against ``Worker``)
        rho (float): forwarded to ``Worker.update_loco_params``
        chi (float): forwarded to ``Worker.update_loco_params``
        loss_y: forwarded to every ``Worker`` as ``loss_y``
        compressor: object exposing ``name``, ``w`` and ``compress(vector)``
        d: stored as ``self.d`` (presumably the problem dimension -- confirm)
        p (float): success probability of the geometric distribution the
            number of local steps is drawn from; also forwarded to
            ``Worker.update_loco_params``
        loss_coeff: forwarded to every ``Worker`` as ``loss_coeff``
        label (str, optional): passed to the parent optimizer (default: 'LoCoDL')
        com_time (optional): stored as ``self.com_time``; not used by this class
        cohort_size (int, optional): number of workers selected in each round.
            By default all workers participate
        lr0, lr_max, lr_decay_coef, lr_decay_power, it_start_decay (optional):
            stored on the instance; not used by ``step`` of this class
        it_max (int, optional): total shown by the per-round progress bar
        threshold (float, optional): once the loss gap falls below this value
            ``step`` becomes a no-op (default: 1e-6)
        batch_size (int, optional): forwarded to ``Worker.run_loco_local``
        worker_losses (list): one loss object per worker. Must be provided even
            though the signature defaults to None
        pbars (str, optional): 'all' enables both progress bars, 'steps' only
            the per-round bar, any other value disables both
    """

    def __init__(self,
                 gamma, rho, chi,
                 loss_y,
                 compressor,
                 d,
                 p,
                 loss_coeff,
                 label='LoCoDL',
                 com_time=1, cohort_size=None,
                 lr0=None, lr_max=np.inf, lr_decay_coef=0,
                 lr_decay_power=1, it_start_decay=None, it_max=None, threshold=1e-6,
                 batch_size=None, worker_losses=None, pbars=None, *args, **kwargs):
        super().__init__(label=label, *args, **kwargs)
        self.n_workers = len(worker_losses)
        if cohort_size is None:
            # Full participation unless a smaller cohort is requested.
            cohort_size = self.n_workers
        self.cohort_size = cohort_size
        self.lr0 = lr0
        self.lr_max = lr_max
        self.lr_decay_coef = lr_decay_coef
        self.lr_decay_power = lr_decay_power
        self.it_start_decay = it_start_decay
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.trace.hs = []
        self.com_time = com_time
        self.compressor = compressor
        self.gamma = gamma
        self.rho = rho
        self.chi = chi
        self.threshold = threshold

        self.xs = {}
        self.ys = {}

        self.d = d
        self.p = p
        self.workers = [
            Worker(loss=worker_loss, loss_coeff=loss_coeff, loss_y=loss_y)
            for worker_loss in self.worker_losses
        ]
        self.name = 'LoCoDL'

        print(f'\n\n{self.name}:  compressor: {compressor.name}, gamma: {round(self.gamma, 5)}, p: {round(self.p, 5)}, rho: {round(self.rho, 5)},'
              f'chi: {round(self.chi, 5)}, participation: {round(self.cohort_size / self.n_workers, 5)}'
              )
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Perform one communication round (see the class docstring).

        When the loss gap ``loss(x) - f_opt`` is already below
        ``self.threshold`` the method sets ``self.tolerance`` to the threshold,
        prints a notice and returns without touching the workers.
        """
        loss_gap = self.loss.value(self.x) - self.loss.f_opt
        if loss_gap < self.threshold:
            self.tolerance = self.threshold
            print("CONVERGED??")
            return

        # A fresh, unseeded generator is created on every round.
        rng = np.random.default_rng()
        local_steps = rng.geometric(p=self.p)
        participation_mask = get_participation_mask(
            self.n_workers, self.cohort_size, rng=rng)

        # Both lists hold one entry per *participating* worker, in worker order.
        compressed_diffs = []
        uncompressed_ys = []

        for j, worker in tqdm(list(enumerate(self.workers)), leave=False, colour='green', disable=self.pbars != 'all'):
            if not participation_mask[j]:
                continue
            local_x, local_y = worker.run_loco_local(x=self.x,
                                                     gamma=self.gamma,
                                                     local_steps=local_steps,
                                                     batch_size=self.batch_size,
                                                     )
            uncompressed_ys.append(local_y)
            # todo make sure to not change the original array
            compressed_diffs.append(
                self.compressor.compress(local_x - local_y))

        # For evaluation purposes the iterate is the y-vector of the first
        # participating worker. Earlier averaged variants were overwritten by
        # this assignment and have been removed as dead code.
        self.x = uncompressed_ys[0]

        d_bar = np.mean(compressed_diffs, axis=0) / 2

        # Pair every participant with its own compressed difference.
        # compressed_diffs is indexed by position among the participants, not
        # by the global worker index, so indexing it with the global index is
        # wrong whenever cohort_size < n_workers.
        participants = (worker
                        for worker, takes_part in zip(self.workers, participation_mask)
                        if takes_part)
        for worker, d_i in zip(participants, compressed_diffs):
            worker.update_loco_params(gamma=self.gamma, rho=self.rho, chi=self.chi, p=self.p,
                                      d=d_i, d_bar=d_bar,
                                      w=self.compressor.w)

        self.p_bar.set_description(f'{self.name}: Loss Gap = {loss_gap:.2E}')
        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Forward all arguments unchanged to the parent class's ``init_run``."""
        super().init_run(*args, **kwargs)

    def update_trace(self):
        """Delegate to the parent class's ``update_trace``."""
        super().update_trace()

    def get_xs(self):
        """Return the ``xs`` dict (initialised empty in ``__init__``)."""
        return self.xs

    def get_ys(self):
        """Return the ``ys`` dict (initialised empty in ``__init__``)."""
        return self.ys

    def get_workers_time(self):
        """Return ``self.workers_time``.

        Documented by the original author as a Python list of size
        (n_workers, it_max).
        NOTE(review): this class never assigns ``workers_time`` itself --
        confirm it is set by the parent class or by the caller.
        """
        return self.workers_time

    def get_communication_time(self):
        """Return the per-round maximum over workers of the recorded times.

        The reduction runs along axis 0 of ``get_workers_time()``, giving a
        NumPy array with one entry per communication round.
        """
        return np.amax(self.get_workers_time(), axis=0)

    def get_total_time(self):
        """Return the sum of the values from ``get_communication_time()``."""
        return sum(self.get_communication_time())

    def get_workers_steps(self):
        """Return ``self.workers_steps``.

        Documented by the original author as a Python list of size
        (n_workers, it_max).
        NOTE(review): this class never assigns ``workers_steps`` itself --
        confirm it is set by the parent class or by the caller.
        """
        return self.workers_steps
