import numpy as np
from tqdm import tqdm
from worker import Worker, Method
from optmethods.optimizer import StochasticOptimizer


class DistributedSgd(StochasticOptimizer):
    """Server loop that averages the outputs of per-worker local runs.

    Every call to ``step`` asks each worker to ``run_local`` from the current
    iterate with ``it_local`` local steps and replaces ``self.x`` with the
    element-wise mean of the returned values. The cumulative uplink and
    downlink counters both grow by ``self.dim`` per step.

    Args:
        it_local: Value forwarded to each worker as ``local_steps``.
        n_workers: Stored on the instance; not read anywhere in this class.
        lr: Learning rate forwarded to each worker. It is passed to ``round``
            when the banner is printed, so it must be numeric.
        batch_size: Forwarded to every ``Worker`` built in ``init_run``.
        worker_losses: Iterable of losses; one ``Worker`` is created per item.
        it_max: Used only as the ``total`` of the per-step progress bar.
        pbars: ``'all'`` enables the step bar and the per-worker bar,
            ``'steps'`` enables only the step bar, anything else disables both.
        *args: Forwarded to ``StochasticOptimizer.__init__``.
        **kwargs: Forwarded to ``StochasticOptimizer.__init__``.
    """

    def __init__(self, it_local, n_workers=None, lr=None, batch_size=None, worker_losses=None, it_max=None, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Hyper-parameters and problem description.
        self.it_local, self.n_workers = it_local, n_workers
        self.lr, self.batch_size = lr, batch_size
        self.worker_losses = worker_losses

        # Communication bookkeeping: a running total plus its per-step history.
        self.uplink_communicated_up_to_now = 0
        self.uplink_communicated_numbers = [0]
        self.downlink_communicated_up_to_now = 0
        self.downlink_communicated_numbers = [0]

        # ``workers`` is built in init_run; ``x`` is presumably set by the base
        # class's init_run (not visible in this file).
        self.workers = None
        self.x = None

        self.name = 'GD'
        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, local_steps: {self.it_local}')

        self.pbars = pbars
        step_bar_enabled = self.pbars in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=not step_bar_enabled)

    def step(self):
        """Run one round: local work on every worker, then average the results."""
        worker_iter = tqdm(self.workers, leave=False, colour='green', disable=self.pbars != 'all')
        local_results = []
        for worker in worker_iter:
            local_results.append(worker.run_local(x=self.x, lr=self.lr, local_steps=self.it_local))
        self.x = np.mean(local_results, axis=0)

        # Both directions are charged ``self.dim`` numbers per round.
        uplink_total = self.uplink_communicated_up_to_now + self.dim
        self.uplink_communicated_numbers.append(uplink_total)
        self.uplink_communicated_up_to_now = uplink_total

        downlink_total = self.downlink_communicated_up_to_now + self.dim
        self.downlink_communicated_numbers.append(downlink_total)
        self.downlink_communicated_up_to_now = downlink_total

        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Initialise the base optimizer, then create one ``Worker`` per loss."""
        super().init_run(*args, **kwargs)
        workers = []
        for loss in self.worker_losses:
            workers.append(Worker(method=Method.Gd, loss=loss, batch_size=self.batch_size))
        self.workers = workers

    def update_trace(self):
        """Delegate trace recording to the base class unchanged."""
        super().update_trace()


class Scaffold(StochasticOptimizer):
    """Server loop for the method labelled ``'Scaffold'``.

    Each ``step`` lets every worker run ``it_local`` local steps from the
    current iterate, passing the server control variate ``self.c`` as
    ``scaffold_c``. The server then moves ``self.x`` towards the mean of the
    workers' results (scaled by ``global_lr``) and moves ``self.c`` towards the
    mean of the workers' ``c`` attributes (scaled by
    ``cohort_size / n_workers``). Both communication counters grow by
    ``2 * self.dim`` per step.

    Args:
        it_local: Value forwarded to each worker as ``local_steps``.
        n_workers: Total number of workers. When left as ``None`` it is
            inferred in ``init_run`` from the number of workers created.
        lr: Learning rate forwarded to each worker. It is passed to ``round``
            when the banner is printed, so it must be numeric.
        batch_size: Forwarded to every ``Worker`` built in ``init_run``.
        worker_losses: Iterable of losses; one ``Worker`` is created per item.
        global_lr: Server step size applied to ``x_new - self.x``.
        it_max: Used only as the ``total`` of the per-step progress bar.
        pbars: ``'all'`` enables the step bar and the per-worker bar,
            ``'steps'`` enables only the step bar, anything else disables both.
        *args: Forwarded to ``StochasticOptimizer.__init__``.
        **kwargs: Forwarded to ``StochasticOptimizer.__init__``.
    """

    def __init__(self, it_local, n_workers=None, lr=None, batch_size=None, worker_losses=None, global_lr=1.,
                 it_max=None, pbars=None, *args, **kwargs):
        super(Scaffold, self).__init__(*args, **kwargs)
        self.it_local = it_local
        self.n_workers = n_workers
        # Every worker participates in every round, so the cohort is the full
        # set of workers. May still be None here; resolved in init_run.
        self.cohort_size = n_workers
        self.lr = lr
        self.batch_size = batch_size
        self.worker_losses = worker_losses
        self.global_lr = global_lr

        # Communication bookkeeping: a running total plus its per-step history.
        self.uplink_communicated_numbers = [0]
        self.uplink_communicated_up_to_now = 0
        self.downlink_communicated_numbers = [0]
        self.downlink_communicated_up_to_now = 0

        # Both are created in init_run.
        self.c = None
        self.workers = None

        self.name = 'Scaffold'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, local_steps: {self.it_local}')
        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    def step(self):
        """Run one round of local work and update ``self.x`` and ``self.c``."""
        # The per-worker bar is shown for pbars == 'all', matching the other
        # optimizers in this module.
        x_new = np.mean([worker.run_local(x=self.x, lr=self.lr, local_steps=self.it_local, scaffold_c=self.c)
                         for worker in tqdm(self.workers, leave=False, colour='green',
                                            disable=self.pbars != 'all')], axis=0)
        # Read after run_local; presumably each worker refreshes its own ``c``
        # during the local run (Worker is not visible here).
        c_new = np.mean([worker.c for worker in self.workers], axis=0)

        self.x += self.global_lr * (x_new - self.x)
        self.c += self.cohort_size / self.n_workers * (c_new - self.c)

        # Iterate and control variate both travel in each direction.
        self.uplink_communicated_numbers.append(self.uplink_communicated_up_to_now + 2 * self.dim)
        self.uplink_communicated_up_to_now = self.uplink_communicated_numbers[-1]
        self.downlink_communicated_numbers.append(self.downlink_communicated_up_to_now + 2 * self.dim)
        self.downlink_communicated_up_to_now = self.downlink_communicated_numbers[-1]

        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Initialise the base optimizer, the control variate and the workers."""
        super(Scaffold, self).init_run(*args, **kwargs)
        self.c = np.zeros_like(self.x)
        self.workers = [Worker(method=Method.Scaffold, loss=worker_loss, batch_size=self.batch_size)
                        for worker_loss in self.worker_losses]

        # n_workers defaults to None; without this, step() would evaluate
        # None / None and raise TypeError on the first round.
        if self.n_workers is None:
            self.n_workers = len(self.workers)
        if self.cohort_size is None:
            self.cohort_size = self.n_workers


class Scaffnew(StochasticOptimizer):
    """Server loop for the method labelled ``'Scaffnew'``.

    On each ``step`` the number of local steps is drawn from a geometric
    distribution with success probability ``p``. Every worker then runs
    ``run_local`` with that count, and ``self.x`` becomes the element-wise
    mean of the returned values. Both communication counters grow by
    ``self.dim`` per step.

    Args:
        p: Success probability of the geometric draw; also forwarded to each
            worker. It is passed to ``round`` when the banner is printed, so
            it must be numeric.
        n_workers: Stored on the instance; not read anywhere in this class.
        lr: Learning rate forwarded to each worker; must be numeric for the
            same reason as ``p``.
        batch_size: Forwarded to every ``Worker`` built in ``init_run``.
        worker_losses: Iterable of losses; one ``Worker`` is created per item.
        it_max: Used only as the ``total`` of the per-step progress bar.
        pbars: ``'all'`` enables the step bar and the per-worker bar,
            ``'steps'`` enables only the step bar, anything else disables both.
        *args: Forwarded to ``StochasticOptimizer.__init__``.
        **kwargs: Forwarded to ``StochasticOptimizer.__init__``.
    """

    def __init__(self, p=None, n_workers=None, lr=None, batch_size=None, worker_losses=None, it_max=None, pbars=None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Hyper-parameters and problem description.
        self.p = p
        self.lr = lr
        self.n_workers = n_workers
        self.batch_size = batch_size
        self.worker_losses = worker_losses

        # Communication bookkeeping: a running total plus its per-step history.
        self.uplink_communicated_up_to_now = 0
        self.downlink_communicated_up_to_now = 0
        self.uplink_communicated_numbers = [0]
        self.downlink_communicated_numbers = [0]

        # ``workers`` is built in init_run; ``x`` is presumably set by the base
        # class's init_run (not visible in this file).
        self.workers = None
        self.x = None

        self.name = 'Scaffnew'
        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}')

        self.pbars = pbars
        hide_step_bar = self.pbars not in ['all', 'steps']
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue', disable=hide_step_bar)

    def step(self):
        """Run one round with a randomly drawn number of local steps."""
        # One draw per round, shared by all workers.
        n_local = np.random.geometric(p=self.p)

        worker_outputs = []
        for worker in tqdm(self.workers, leave=False, colour='green', disable=self.pbars != 'all'):
            worker_outputs.append(worker.run_local(x=self.x, lr=self.lr, p=self.p, local_steps=n_local))
        self.x = np.mean(worker_outputs, axis=0)

        # Both directions are charged ``self.dim`` numbers per round.
        sent_up = self.uplink_communicated_up_to_now + self.dim
        self.uplink_communicated_numbers.append(sent_up)
        self.uplink_communicated_up_to_now = sent_up

        sent_down = self.downlink_communicated_up_to_now + self.dim
        self.downlink_communicated_numbers.append(sent_down)
        self.downlink_communicated_up_to_now = sent_down

        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Initialise the base optimizer, then create one ``Worker`` per loss."""
        super().init_run(*args, **kwargs)
        created = []
        for loss in self.worker_losses:
            created.append(Worker(method=Method.Scaffnew, loss=loss, batch_size=self.batch_size))
        self.workers = created

    def update_trace(self):
        """Delegate trace recording to the base class unchanged."""
        super().update_trace()


class CompressedProxSkip(StochasticOptimizer):
    """Server loop for the method labelled ``'CompressedScaffnew'``.

    On each ``step`` the number of local steps is drawn from a geometric
    distribution with success probability ``p`` and a 0/1 sampling pattern
    ``q`` of shape ``(dim, n_workers)`` with ``s`` ones per row is generated.
    Worker ``j`` runs ``run_local`` with column ``q[:, j]``; the new iterate is
    the sum of the workers' returns divided by ``s``.

    Uplink cost per step is the largest column sum of ``q`` (the most any
    single worker is asked for); downlink cost per step is ``self.dim``.

    Args:
        p: Success probability of the geometric draw; also forwarded to each
            worker. Must be numeric (it is rounded for the banner).
        s: Number of ones per row of the sampling pattern. Must be numeric
            for the banner and an integer for pattern construction.
        eta: Forwarded to each worker's ``run_local``; must be numeric for
            the banner.
        n_workers: Number of columns of the sampling pattern. When left as
            ``None`` it is inferred in ``init_run`` from the workers created.
        lr: Learning rate forwarded to each worker; must be numeric.
        batch_size: Forwarded to every ``Worker`` built in ``init_run``.
        worker_losses: Iterable of losses; one ``Worker`` is created per item.
        it_max: Used only as the ``total`` of the per-step progress bar.
        pbars: ``'all'`` enables the step bar and the per-worker bar,
            ``'steps'`` enables only the step bar, anything else disables both.
        *args: Forwarded to ``StochasticOptimizer.__init__``.
        **kwargs: Forwarded to ``StochasticOptimizer.__init__``.
    """

    def __init__(self, p, s, eta, n_workers=None, lr=None, batch_size=None, worker_losses=None, it_max=None, pbars=None,
                 *args, **kwargs):
        super(CompressedProxSkip, self).__init__(*args, **kwargs)
        self.p = p
        self.s = s
        self.eta = eta
        self.n_workers = n_workers
        self.lr = lr
        self.batch_size = batch_size
        self.worker_losses = worker_losses

        # Communication bookkeeping: a running total plus its per-step history.
        self.uplink_communicated_numbers = [0]
        self.uplink_communicated_up_to_now = 0
        self.downlink_communicated_numbers = [0]
        self.downlink_communicated_up_to_now = 0

        # ``workers`` is built in init_run; ``x`` is presumably set by the base
        # class's init_run (not visible in this file).
        self.x = None
        self.workers = None

        self.name = 'CompressedScaffnew'

        print(f'\n\n{self.name}:  lr:{round(self.lr, 5)}, p: {round(self.p, 5)}, s: {round(self.s, 5)}, '
              f'eta: {round(self.eta, 5)}')

        self.pbars = pbars
        self.p_bar = tqdm(total=it_max, desc=self.name, leave=False, colour='blue',
                          disable=self.pbars not in ['all', 'steps'])

    @staticmethod
    def create_random_sampling_pattern(dimension: int, num_workers: int, ones_per_row: int):
        """Build the deterministic base 0/1 pattern of shape ``(dimension, num_workers)``.

        Despite the name, no randomness is used here; shuffling is done by
        ``create_permutated_pattern``.

        When ``dimension >= num_workers / ones_per_row`` the ones are laid out
        row by row, ``ones_per_row`` per row, with the column index advancing
        cyclically modulo ``num_workers``. Otherwise all ones fit without
        wrapping: pass ``j`` gives row ``i`` a one in column
        ``j * dimension + i``, and trailing columns stay all-zero.

        Args:
            dimension: Number of rows.
            num_workers: Number of columns.
            ones_per_row: Ones placed in each row; must be a non-zero integer.

        Returns:
            A float ``numpy.ndarray`` of zeros and ones.

        Raises:
            ZeroDivisionError: If ``ones_per_row`` is zero.
        """
        q = np.zeros(shape=(dimension, num_workers))
        # Position t in the sequence of ones that the nested loops would emit.
        flat = np.arange(dimension * ones_per_row)
        if dimension >= num_workers / ones_per_row:
            # Row-major fill: t = i * ones_per_row + j, column cycles with t.
            rows = flat // ones_per_row
            cols = flat % num_workers
        else:
            # Here dimension * ones_per_row < num_workers, so the column index
            # simply equals t and never wraps; rows repeat 0..dimension-1.
            rows = np.tile(np.arange(dimension), ones_per_row)
            cols = flat
        q[rows, cols] = 1
        return q

    @staticmethod
    def create_permutated_pattern(dimension: int, num_workers: int, ones_per_row: int):
        """Return the base pattern with its columns randomly permuted.

        If ``ones_per_row == num_workers`` every entry is one and the all-ones
        matrix is returned directly without consuming random numbers.
        """
        if ones_per_row == num_workers:
            return np.ones(shape=(dimension, num_workers))
        # np.random.permutation shuffles along the first axis, so transposing
        # before and after permutes the columns of the pattern.
        return np.random.permutation(
            CompressedProxSkip.create_random_sampling_pattern(dimension, num_workers, ones_per_row).T).T

    def step(self):
        """Run one round with a fresh sampling pattern and aggregate the results."""
        # Draw order (local steps first, then the pattern) is kept fixed so
        # seeded runs stay reproducible.
        local_steps = np.random.geometric(p=self.p)
        q = CompressedProxSkip.create_permutated_pattern(self.dim, self.n_workers, self.s)

        # list(...) lets tqdm know the total number of workers.
        ret = [worker.run_local(x=self.x, lr=self.lr, p=self.p, local_steps=local_steps, q=q[:, j], eta=self.eta)
               for j, worker in tqdm(list(enumerate(self.workers)), leave=False, colour='green',
                                     disable=self.pbars != 'all')]

        # Each row of q has s ones; presumably each worker's return is masked
        # by its column, so sum / s averages over the s contributors per
        # coordinate -- confirm against Worker.run_local.
        self.x = np.sum(ret, axis=0) / self.s

        # Uplink is charged at the heaviest worker's load.
        max_communicated_by_any_worker = np.max(np.sum(q, axis=0))
        assert max_communicated_by_any_worker == np.ceil((self.s * self.dim) / self.n_workers)
        self.uplink_communicated_numbers.append(self.uplink_communicated_up_to_now + max_communicated_by_any_worker)
        self.uplink_communicated_up_to_now = self.uplink_communicated_numbers[-1]
        self.downlink_communicated_numbers.append(self.downlink_communicated_up_to_now + self.dim)
        self.downlink_communicated_up_to_now = self.downlink_communicated_numbers[-1]

        self.p_bar.update()

    def init_run(self, *args, **kwargs):
        """Initialise the base optimizer, then create one ``Worker`` per loss."""
        super(CompressedProxSkip, self).init_run(*args, **kwargs)
        self.workers = [
            Worker(method=Method.CompressedScaffnew, loss=worker_loss, batch_size=self.batch_size)
            for worker_loss in self.worker_losses]

        # n_workers defaults to None; without this, step() would try to build
        # a pattern with a None column count and raise TypeError.
        if self.n_workers is None:
            self.n_workers = len(self.workers)

    def update_trace(self):
        """Delegate trace recording to the base class unchanged."""
        super(CompressedProxSkip, self).update_trace()
