import copy
import os
import json
import pickle
import uuid
import numpy as np
from utils import run, load_runs
from optmethods.loss import LogisticRegression
from optimizers import LoCoDL, GradSkip, Scaffold, Tamuna, CompressedScaffnew, GD, FiveGCS, FiveGCS_CC, DIANA, ADIANA, Scaffnew
from optmethods.first_order import RestNest
from matplotlib import pyplot as plt
from compressors import Rand_k, Sign_1, Identity, Natural, Natural_Rand_k

# Extra slots added to every optimizer's trace on top of its iteration budget
# (the run_* helpers pass trace_len=it_max + trace_len_buffer); presumably
# headroom for trace entries recorded beyond it_max.
trace_len_buffer = 4
# Import-time side effect: every matplotlib figure created afterwards renders at 300 dpi.
plt.rcParams['figure.dpi'] = 300


def compute_worker_losses(a, b, args, mu=0):
    """Shuffle the samples and build one logistic-regression loss per worker.

    Rows of ``a``/``b`` are permuted after ``np.random.seed(0)`` (which also
    reseeds NumPy's global RNG for any later code) and cut into
    ``args['n_workers']`` contiguous shards of ``num_rows // n_workers`` rows;
    leftover rows that do not fill a whole shard are not used.

    Args:
        a: 2-D feature matrix, one sample per row.
        b: labels aligned with the rows of ``a``.
        args: experiment config; reads ``'n_workers'`` and, when ``mu == 0``,
            ``'regularization_factor'``.
        mu: l2 coefficient of every worker loss. When it is 0, every loss gets
            ``args['regularization_factor'] * L_max`` instead, where ``L_max``
            is the largest smoothness among the unregularized worker losses.

    Returns:
        list of ``LogisticRegression`` objects, one per worker.

    Raises:
        AssertionError: if some worker loss has smoothness <= its l2 coefficient.
    """
    np.random.seed(0)
    n_rows, _ = a.shape
    order = np.random.permutation(range(n_rows))
    a_shuffled = a[order]
    b_shuffled = b[order]

    n_workers = args['n_workers']
    shard = n_rows // n_workers
    losses = [
        LogisticRegression(a_shuffled[w * shard:(w + 1) * shard],
                           b_shuffled[w * shard:(w + 1) * shard],
                           l1=0, l2=mu)
        for w in range(n_workers)
    ]
    assert all(loss.smoothness > loss.l2 for loss in losses)

    if mu == 0:
        l_max = max(loss.smoothness for loss in losses)
        shared_l2 = args['regularization_factor'] * l_max
        for loss in losses:
            loss.l2 = shared_l2
            # Presumably clears LogisticRegression's cached smoothness so that
            # the next ``.smoothness`` read reflects the new l2.
            loss._smoothness = None
        assert all(loss.smoothness > loss.l2 for loss in losses)

    return losses


def loco_get_losses(a, b, mu, args):
    """Split the samples between LoCoDL's worker losses and its extra loss ``loss_y``.

    ``args['r']`` is the fraction of rows handed to the workers:

    * ``r == 1``: every row goes to the workers (via ``compute_worker_losses``);
      ``loss_y`` is a placeholder built on two all-zero rows.
    * ``r == 0``: every row goes to ``loss_y``; each worker slot gets its own
      placeholder loss built on two all-zero rows.
    * otherwise: the first ``rM = int(num_rows * r)`` rows, rounded down to a
      multiple of ``args['n_workers']``, go to the workers and all remaining
      rows go to ``loss_y``. The l2 coefficients are rescaled by
      ``num_rows / rM`` (workers) and ``num_rows / (num_rows - rM)`` (``loss_y``).

    Side effect: sets ``args['loss_coeff']`` (1 when r is 0 or 1,
    ``rM / num_rows`` otherwise).

    Returns:
        (worker_losses_loco, loss_y)

    Raises:
        ValueError: if the fractional split leaves the workers or ``loss_y``
            without any rows (previously a ZeroDivisionError / negative l2).
    """
    num_rows, dim = a.shape
    r = args['r']

    if r == 1:
        args['loss_coeff'] = 1
        worker_losses_loco = compute_worker_losses(a=a, b=b, mu=mu, args=args)
        loss_y = LogisticRegression(
            np.zeros((2, dim)), np.zeros(2), l1=0, l2=mu)
    elif r == 0:
        args['loss_coeff'] = 1
        # One distinct placeholder per slot: ``[loss] * n`` would make every
        # slot the same object, so a change to one worker loss would hit all.
        # NOTE(review): sized by cohort_size rather than n_workers (the
        # original carried a "todo check" here) — confirm against LoCoDL.
        worker_losses_loco = [
            LogisticRegression(np.zeros((2, dim)), np.zeros(2), l1=0, l2=mu)
            for _ in range(args['cohort_size'])
        ]
        loss_y = LogisticRegression(a, b, l1=0, l2=mu)
    else:
        rM = int(num_rows * r)
        # Round down so every worker gets the same number of rows. No data is
        # dropped: the up to n_workers - 1 rows cut off here end up in loss_y.
        rM = rM // args['n_workers'] * args['n_workers']
        if not 0 < rM < num_rows:
            raise ValueError(
                f"r={r} assigns {rM} of {num_rows} rows to the workers; both the "
                f"workers and loss_y need at least one row")
        args['loss_coeff'] = rM / num_rows
        worker_losses_loco = compute_worker_losses(
            a=a[:rM], b=b[:rM], mu=mu * (num_rows / rM), args=args)
        loss_y = LogisticRegression(
            a[rM:], b[rM:], l1=0, l2=mu * (num_rows / (num_rows - rM)))

    return worker_losses_loco, loss_y


def load(res_path: str):
    """Return the object unpickled from the file at ``res_path``.

    Only use this on result files written by this project: ``pickle.load``
    can execute arbitrary code embedded in a crafted file.
    """
    with open(res_path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj


def get_save_dir_name(args):
    """Return the results-directory name that encodes the experiment setup.

    Concatenates dataset, worker count, downlink factor, cohort-size factor,
    regularization factor and ``r``, e.g. ``"a9a_n10_alpha0.5_c1_reg0.1_r0.5"``.
    """
    # TODO (original note): maybe add a different g to the name?
    pieces = [
        f"{args['dataset']}",
        f"_n{args['n_workers']}",
        f"_alpha{args['downlink_factor']}",
        f"_c{args['cohort_size_factor']}",
        f"_reg{args['regularization_factor']}",
        f"_r{args['r']}",
    ]
    return "".join(pieces)


def run_gd(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run distributed GD with ``args['gd_args']``, pickle the result and return it.

    Reads ``lr``, ``it_local`` and ``it_max`` from ``args['gd_args']`` plus
    ``args['n_workers']``, ``args['dim']`` and ``args['progress_bars']``.
    ``index`` and ``kwargs`` are unused; they keep the signature uniform with
    the other ``run_*`` helpers. The pickle ``(trace, uplink, downlink)`` is
    written to ``<save dir>/gd_res.<6 random hex chars>.bin``.

    Returns:
        (trace, name, uplink_cost, downlink_cost); both costs are ``args['dim']``.
    """
    cfg = args['gd_args']
    n_iters = cfg['it_max']
    gd = GD(loss=loss, it_local=cfg['it_local'], n_workers=args['n_workers'],
            lr=cfg['lr'], worker_losses=worker_losses,
            trace_len=n_iters + trace_len_buffer,
            threshold=threshold, it_max=n_iters, pbars=args['progress_bars'])
    gd.run(x0=x0, it_max=n_iters)
    gd.trace.compute_loss_of_iterates()

    uplink = downlink = args['dim']

    out_path = f'{get_save_dir_name(args)}/gd_res.{str(uuid.uuid4())[:6]}.bin'
    with open(out_path, 'wb') as fh:
        pickle.dump((gd.trace, uplink, downlink), fh)

    return gd.trace, gd.name, uplink, downlink


def run_scaffnew(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run Scaffnew with ``args['scaffnew_args']``, pickle the result and return it.

    Reads ``lr``, ``p`` and ``it_max`` from ``args['scaffnew_args']``.
    ``index`` and ``kwargs`` are unused (uniform ``run_*`` signature). The
    pickle ``(trace, uplink, downlink)`` goes to
    ``<save dir>/scaffnew_res.<6 random hex chars>.bin``.

    Returns:
        (trace, name, uplink_cost, downlink_cost); both costs are ``args['dim']``.
    """
    cfg = args['scaffnew_args']
    n_iters = cfg['it_max']
    method = Scaffnew(loss=loss, worker_losses=worker_losses,
                      n_workers=args['n_workers'], lr=cfg['lr'], d=args['dim'],
                      p=cfg['p'], it_max=n_iters, pbars=args['progress_bars'],
                      threshold=threshold, trace_len=n_iters + trace_len_buffer)
    method.run(x0=x0, it_max=n_iters)
    method.trace.compute_loss_of_iterates()

    uplink = downlink = args['dim']

    out_path = f'{get_save_dir_name(args)}/scaffnew_res.{str(uuid.uuid4())[:6]}.bin'
    with open(out_path, 'wb') as fh:
        pickle.dump((method.trace, uplink, downlink), fh)

    return method.trace, method.name, uplink, downlink


def run_scaffold(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run Scaffold with ``args['scaffold_args']``, pickle the result and return it.

    Reads ``it_local``, ``local_lr`` and ``it_max`` from ``args['scaffold_args']``
    and ``args['cohort_size']``. ``index`` and ``kwargs`` are unused. The pickle
    ``(trace, uplink, downlink)`` goes to
    ``<save dir>/scaffold_res.<6 random hex chars>.bin``.

    Returns:
        (trace, name, uplink_cost, downlink_cost); both costs are
        ``2 * args['dim']`` (two d-dimensional vectors per direction and round).
    """
    cfg = args['scaffold_args']
    n_iters = cfg['it_max']
    method = Scaffold(loss=loss, n_workers=args['n_workers'],
                      cohort_size=args['cohort_size'], it_local=cfg['it_local'],
                      trace_len=n_iters + trace_len_buffer, threshold=threshold,
                      lr=cfg['local_lr'], worker_losses=worker_losses,
                      d=args['dim'], it_max=n_iters, pbars=args['progress_bars'])
    method.run(x0=x0, it_max=n_iters)
    method.trace.compute_loss_of_iterates()

    uplink = downlink = 2 * args['dim']

    out_path = f'{get_save_dir_name(args)}/scaffold_res.{str(uuid.uuid4())[:6]}.bin'
    with open(out_path, 'wb') as fh:
        pickle.dump((method.trace, uplink, downlink), fh)

    return method.trace, method.name, uplink, downlink


def run_gradskip(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run GradSkip with ``args['gradskip_args']``, pickle the result and return it.

    Reads ``p``, ``lr`` and ``it_max`` from ``args['gradskip_args']`` — the
    same keys ``gradskip_set_args`` resolves — plus ``args['cohort_size']``.
    ``index`` and ``kwargs`` are unused. The pickle ``(trace, uplink, downlink)``
    goes to ``<save dir>/gradskip_res.<6 random hex chars>.bin``.

    Returns:
        (trace, name, uplink_cost, downlink_cost); both costs are ``args['dim']``.
    """
    # Use GradSkip's own hyper-parameters; reading scaffnew_args here would
    # silently ignore everything gradskip_set_args computed.
    cfg = args['gradskip_args']
    n_iters = cfg['it_max']
    gradskip = GradSkip(loss=loss, p=cfg['p'],
                        cohort_size=args['cohort_size'], lr=cfg['lr'],
                        worker_losses=worker_losses,
                        trace_len=n_iters + trace_len_buffer,
                        threshold=threshold,
                        it_max=n_iters, pbars=args['progress_bars'])
    gradskip.run(x0=x0, it_max=n_iters)
    gradskip.trace.compute_loss_of_iterates()

    gradskip_uplink_cost = args['dim']
    gradskip_downlink_cost = args['dim']

    with open(f'{get_save_dir_name(args)}/gradskip_res.{str(uuid.uuid4())[:6]}.bin', 'wb') as f:
        pickle.dump((gradskip.trace, gradskip_uplink_cost,
                     gradskip_downlink_cost), f)

    return gradskip.trace, gradskip.name, gradskip_uplink_cost, gradskip_downlink_cost


def run_locodl(x0, loss, worker_losses, args, index, threshold=None, **kwargs):
    """Run LoCoDL with ``args['loco_args'][index]``, pickle the result and return it.

    ``worker_losses`` is not used (open question from the original author):
    LoCoDL is built from ``kwargs['worker_losses_loco']`` and ``kwargs['loss_y']``
    (presumably the pair returned by ``loco_get_losses``) and from
    ``args['loss_coeff']``.

    Optional ``kwargs['lr_coeffs']`` / ``kwargs['p_coeffs']`` (sequences indexed
    by ``index``) are appended to the label when present and not None. The
    label is also the file stem under the save directory.

    Returns:
        (trace, label, uplink_cost, downlink_cost) with uplink
        ``compressor.uplink_cost`` and downlink ``args['dim']``.
    """
    compressor = set_compressor(args, 'loco_args', index)
    cfg = args['loco_args'][index]
    n_iters = cfg['it_max']

    loco = LoCoDL(loss=loss, p=cfg['p'],
                  gamma=cfg['gamma'], rho=cfg['rho'], chi=cfg['chi'],
                  worker_losses=kwargs['worker_losses_loco'], loss_y=kwargs['loss_y'],
                  compressor=compressor,
                  d=args['dim'],
                  loss_coeff=args['loss_coeff'],
                  cohort_size=args['cohort_size'],
                  trace_len=n_iters + trace_len_buffer,
                  threshold=threshold,
                  it_max=n_iters, pbars=args['progress_bars'])

    loco.run(x0=x0, it_max=n_iters)
    loco.trace.compute_loss_of_iterates()

    loco_uplink_cost = compressor.uplink_cost
    loco_downlink_cost = args['dim']

    # Test the values explicitly: a coefficient list passed as None is
    # treated as absent instead of crashing on ``None[index]``.
    lr_coeffs = kwargs.get('lr_coeffs')
    p_coeffs = kwargs.get('p_coeffs')
    lr_coef = f" lr_coeff = {lr_coeffs[index]}" if lr_coeffs is not None else ''
    p_coeff = f" p_coeffs = {p_coeffs[index]}, p={cfg['p']}" if p_coeffs is not None else ''
    label = f"{loco.name}: {compressor.name}" + lr_coef + p_coeff

    with open(f'{get_save_dir_name(args)}/{label}.{str(uuid.uuid4())[:6]}.bin', 'wb') as f:
        pickle.dump((loco.trace, loco_uplink_cost, loco_downlink_cost), f)

    return loco.trace, label, loco_uplink_cost, loco_downlink_cost


def run_adiana(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run ADIANA with ``args['adiana_args'][index]``, pickle the result and return it.

    Communication per round is counted as ``2 * compressor.uplink_cost`` up
    and ``2 * args['dim']`` down. If ``kwargs['lr_coeffs']`` is present and not
    None, ``lr_coeffs[index]`` is appended to the label (also the file stem).

    Returns:
        (trace, label, uplink_cost, downlink_cost).
    """
    compressor = set_compressor(args, 'adiana_args', index)
    cfg = args['adiana_args'][index]
    n_iters = cfg['it_max']

    adiana = ADIANA(loss=loss,
                    compressor=compressor,
                    n_workers=args['n_workers'],
                    lr=cfg['lr'], alpha=cfg['alpha'],
                    beta=cfg['beta'],
                    theta_1=cfg['theta_1'],
                    theta_2=cfg['theta_2'],
                    p=cfg['p'],
                    eta=cfg['eta'],
                    worker_losses=worker_losses,
                    trace_len=n_iters + trace_len_buffer,
                    dim=args['dim'],
                    threshold=threshold,
                    it_max=n_iters, pbars=args['progress_bars'])
    adiana.run(x0=x0, it_max=n_iters)
    adiana.trace.compute_loss_of_iterates()

    adiana_uplink_cost = 2*compressor.uplink_cost
    adiana_downlink_cost = 2*args['dim']

    # Explicit value check: lr_coeffs=None means "no coefficient in the label".
    lr_coeffs = kwargs.get('lr_coeffs')
    lr_coef = f" lr_coeff = {lr_coeffs[index]}" if lr_coeffs is not None else ''
    label = f"{adiana.name}: {compressor.name}" + lr_coef

    with open(f'{get_save_dir_name(args)}/{label}.{str(uuid.uuid4())[:6]}.bin', 'wb') as f:
        pickle.dump((adiana.trace, adiana_uplink_cost,
                     adiana_downlink_cost), f)

    return adiana.trace, label, adiana_uplink_cost, adiana_downlink_cost


def run_diana(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run DIANA with ``args['diana_args'][index]``, pickle the result and return it.

    Communication per round: ``compressor.uplink_cost`` up, ``args['dim']``
    down. If ``kwargs['lr_coeffs']`` is present and not None,
    ``lr_coeffs[index]`` is appended to the label (also the file stem).

    Returns:
        (trace, label, uplink_cost, downlink_cost).
    """
    compressor = set_compressor(args, 'diana_args', index)
    cfg = args['diana_args'][index]
    n_iters = cfg['it_max']

    diana = DIANA(loss=loss,
                  compressor=compressor,
                  n_workers=args['n_workers'],
                  lr=cfg['lr'], dual_lr=cfg['dual_lr'],
                  worker_losses=worker_losses,
                  trace_len=n_iters + trace_len_buffer,
                  dim=args['dim'],
                  threshold=threshold,
                  it_max=n_iters, pbars=args['progress_bars'])
    diana.run(x0=x0, it_max=n_iters)
    diana.trace.compute_loss_of_iterates()

    diana_uplink_cost = compressor.uplink_cost
    diana_downlink_cost = args['dim']

    # Look up the real key 'lr_coeffs' (the old ' lr_coeffs' key, with a
    # leading space, never matched, so DIANA labels never showed the coefficient).
    lr_coeffs = kwargs.get('lr_coeffs')
    lr_coef = f" lr_coeff = {lr_coeffs[index]}" if lr_coeffs is not None else ''
    label = f"{diana.name}: {compressor.name}" + lr_coef

    with open(f'{get_save_dir_name(args)}/{label}.{str(uuid.uuid4())[:6]}.bin', 'wb') as f:
        pickle.dump((diana.trace, diana_uplink_cost,
                     diana_downlink_cost), f)

    return diana.trace, label, diana_uplink_cost, diana_downlink_cost


def run_5gcs_cc(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run 5GCS-CC with ``args['five_gcs_cc_args'][index]``, pickle and return the result.

    ``mu`` is taken from ``loss.l2``. Communication per round:
    ``compressor.uplink_cost`` up, ``args['dim']`` down. If
    ``kwargs['lr_coeffs']`` is present and not None, ``lr_coeffs[index]`` is
    appended to the label (also the file stem).

    Returns:
        (trace, label, uplink_cost, downlink_cost).
    """
    compressor = set_compressor(args, 'five_gcs_cc_args', index)
    cfg = args['five_gcs_cc_args'][index]
    n_iters = cfg['it_max']

    five_gcs = FiveGCS_CC(loss=loss, it_local=cfg['it_local'],
                          compressor=compressor,
                          n_workers=args['n_workers'], cohort_size=args['cohort_size'], mu=loss.l2,
                          lr=cfg['lr'], dual_lr=cfg['dual_lr'], d=args['dim'],
                          worker_losses=worker_losses,
                          trace_len=n_iters + trace_len_buffer,
                          threshold=threshold,
                          it_max=n_iters, pbars=args['progress_bars'])
    five_gcs.run(x0=x0, it_max=n_iters)
    five_gcs.trace.compute_loss_of_iterates()

    five_gcs_uplink_cost = compressor.uplink_cost
    five_gcs_downlink_cost = args['dim']

    # Explicit value check: lr_coeffs=None means "no coefficient in the label".
    lr_coeffs = kwargs.get('lr_coeffs')
    lr_coef = f" lr_coeff = {lr_coeffs[index]}" if lr_coeffs is not None else ''
    label = f"{five_gcs.name}: {compressor.name}" + lr_coef

    with open(f'{get_save_dir_name(args)}/{label}.{str(uuid.uuid4())[:6]}.bin', 'wb') as f:
        pickle.dump((five_gcs.trace, five_gcs_uplink_cost,
                     five_gcs_downlink_cost), f)

    return five_gcs.trace, label, five_gcs_uplink_cost, five_gcs_downlink_cost


def run_tamuna(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run TAMUNA with ``args['tamuna_args']``, pickle the result and return it.

    Reads ``s``, ``p``, ``eta``, ``lr`` and ``it_max`` from ``args['tamuna_args']``.
    ``index`` and ``kwargs`` are unused. The uplink cost is
    ``ceil(s * dim / cohort_size)`` and the downlink cost ``args['dim']``. The
    label ``"<name>: s=<s>"`` is also the file stem under the save directory.

    Returns:
        (trace, label, uplink_cost, downlink_cost).
    """
    cfg = args['tamuna_args']
    n_iters = cfg['it_max']
    method = Tamuna(loss=loss, s=cfg['s'], p=cfg['p'], eta=cfg['eta'],
                    n_workers=args['n_workers'], cohort_size=args['cohort_size'],
                    lr=cfg['lr'], worker_losses=worker_losses, d=args['dim'],
                    trace_len=n_iters + trace_len_buffer, threshold=threshold,
                    it_max=n_iters, pbars=args['progress_bars'])
    method.run(x0=x0, it_max=n_iters)
    method.trace.compute_loss_of_iterates()

    uplink = np.ceil((cfg['s'] * args['dim']) / args['cohort_size'])
    downlink = args['dim']

    label = f"{method.name}: s={cfg['s']}"
    out_path = f'{get_save_dir_name(args)}/{label}.{str(uuid.uuid4())[:6]}.bin'
    with open(out_path, 'wb') as fh:
        pickle.dump((method.trace, uplink, downlink), fh)

    return method.trace, label, uplink, downlink


def run_compressed_scaffnew(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run CompressedScaffnew with ``args['compressed_scaffnew_args'][index]``.

    Pickles ``(trace, uplink, downlink)`` under the save directory and returns
    it. The uplink cost is ``ceil(s * dim / n_workers)`` and the downlink cost
    ``args['dim']``. Optional ``kwargs['lr_coeffs']`` / ``kwargs['p_coeffs']``
    (indexed by ``index``) are appended to the label when present and not None.

    Returns:
        (trace, label, uplink_cost, downlink_cost).
    """
    cfg = args['compressed_scaffnew_args'][index]
    n_iters = cfg['it_max']
    compressed_scaffnew = CompressedScaffnew(loss=loss, s=cfg['s'],
                                             p=cfg['p'],
                                             eta=cfg['eta'],
                                             n_workers=args['n_workers'], lr=cfg['lr'],
                                             worker_losses=worker_losses, d=args['dim'],
                                             trace_len=n_iters + trace_len_buffer,
                                             threshold=threshold,
                                             it_max=n_iters,
                                             pbars=args['progress_bars'])
    compressed_scaffnew.run(x0=x0, it_max=n_iters)
    compressed_scaffnew.trace.compute_loss_of_iterates()

    compressed_scaffnew_uplink_cost = np.ceil((cfg['s'] * args['dim']) / args['n_workers'])
    compressed_scaffnew_downlink_cost = args['dim']

    # Explicit value checks: a coefficient list passed as None is treated as absent.
    lr_coeffs = kwargs.get('lr_coeffs')
    p_coeffs = kwargs.get('p_coeffs')
    lr_coef = f" lr_coeff = {lr_coeffs[index]}" if lr_coeffs is not None else ''
    p_coeff = f" p_coeffs = {p_coeffs[index]}" if p_coeffs is not None else ''
    label = f"{compressed_scaffnew.name}: s={cfg['s']}" + lr_coef + p_coeff

    with open(f'{get_save_dir_name(args)}/{label}.{str(uuid.uuid4())[:6]}.bin', 'wb') as f:
        pickle.dump((compressed_scaffnew.trace, compressed_scaffnew_uplink_cost,
                     compressed_scaffnew_downlink_cost), f)

    return compressed_scaffnew.trace, label, compressed_scaffnew_uplink_cost, compressed_scaffnew_downlink_cost


def run_5gcs(x0, loss, worker_losses, args, index=None, threshold=None, **kwargs):
    """Run 5GCS with ``args['five_gcs_args']``, pickle the result and return it.

    ``mu`` is taken from ``loss.l2``. ``index`` and ``kwargs`` are unused. The
    pickle ``(trace, uplink, downlink)`` goes to
    ``<save dir>/five_gcs_res.<6 random hex chars>.bin``.

    Returns:
        (trace, name, uplink_cost, downlink_cost); both costs are ``args['dim']``.
    """
    cfg = args['five_gcs_args']
    n_iters = cfg['it_max']
    method = FiveGCS(loss=loss, it_local=cfg['it_local'],
                     n_workers=args['n_workers'], cohort_size=args['cohort_size'],
                     mu=loss.l2, lr=cfg['lr'], dual_lr=cfg['dual_lr'], d=args['dim'],
                     worker_losses=worker_losses,
                     trace_len=n_iters + trace_len_buffer, threshold=threshold,
                     it_max=n_iters, pbars=args['progress_bars'])
    method.run(x0=x0, it_max=n_iters)
    method.trace.compute_loss_of_iterates()

    uplink = downlink = args['dim']

    out_path = f'{get_save_dir_name(args)}/five_gcs_res.{str(uuid.uuid4())[:6]}.bin'
    with open(out_path, 'wb') as fh:
        pickle.dump((method.trace, uplink, downlink), fh)

    return method.trace, method.name, uplink, downlink


def set_compressor(args, alg_name, index):
    """Instantiate the compressor named in ``args[alg_name][index]['compressor']``.

    Names are matched case-insensitively: ``'identity'``, ``'rand-k'``,
    ``'sign-1'``, ``'natural'`` and ``'natural_rand-k'``. For the two rand-k
    variants, ``k`` is resolved by ``set_k``, which also writes the resolved
    value back into ``args[alg_name][index]['k']``.

    Raises:
        ValueError: for an unknown compressor name. It subclasses
            ``Exception``, so existing ``except Exception`` handlers still apply.
    """
    name = args[alg_name][index]['compressor'].lower()
    d = args['dim']

    if name == 'identity':
        return Identity(d=d)
    if name == 'rand-k':
        return Rand_k(k=set_k(args=args, alg_name=alg_name, index=index), d=d)
    if name == 'sign-1':
        return Sign_1(d=d)
    if name == 'natural':
        return Natural(d=d)
    if name == 'natural_rand-k':
        return Natural_Rand_k(d=d, k=set_k(args=args, alg_name=alg_name, index=index))

    raise ValueError(f"Invalid compressor name - {name}")


def set_k(args, alg_name, index):
    """Resolve the sparsity level ``k`` of a rand-k style compressor.

    ``args[alg_name][index]['k']`` may be:

    * ``'tamuna'``: match TAMUNA's per-worker payload, ``int(d * s / cohort_size)``;
    * ``'best_theoretical'``: a per-algorithm default (1 for 5GCS-CC and DIANA,
      ``ceil(d / cohort_size)`` for LoCoDL, ``int(d / 4)`` for ADIANA);
    * anything ``int()`` accepts.

    The result is clamped to at least 1, written back into
    ``args[alg_name][index]['k']`` and returned as a Python ``int``. Because the
    final integer is stored, repeated calls for the same entry return the same
    ``k``. This matters because ``set_compressor`` runs once when
    hyper-parameters are derived and again when the algorithm is launched.

    Raises:
        ValueError: if ``'best_theoretical'`` is requested for an algorithm
            without a default, or if the resolved ``k`` exceeds ``args['dim']``.
    """
    d = args['dim']
    k = args[alg_name][index]['k']
    if k == 'tamuna':
        k = int(d * args['tamuna_args']['s'] / args['cohort_size'])
    elif k == 'best_theoretical':
        if alg_name in ("five_gcs_cc_args", "diana_args"):
            k = 1
        elif alg_name == "loco_args":
            k = int(np.ceil(d / args['cohort_size']))
        elif alg_name == "adiana_args":
            # Integer from the start, so the first call agrees with later calls
            # that re-read the stored value.
            k = int(d / 4)
        else:
            raise ValueError(f"No 'best_theoretical' k defined for {alg_name}")

    k = max(int(k), 1)
    if k > d:
        raise ValueError(f"k={k} exceeds the dimension d={d}")
    args[alg_name][index]['k'] = k  # TODO (original note): do this in the comparing function
    return k


def loco_set_args(args, index, L_max, mu, compressor):
    """Replace the placeholders in ``args['loco_args'][index]`` with concrete values.

    * ``it_max == "tamuna"``: spend the same total uplink as TAMUNA
      (TAMUNA's ``it_max * ceil(s * dim / cohort_size)`` divided by
      ``compressor.uplink_cost``).
    * ``gamma == "best_theoretical"``: ``2 / (L_max + mu)``.
    * ``p == "best_theoretical"``: ``min(sqrt((1 + w) (1 + w_av) / kappa), 1)``.
    * ``rho`` / ``chi == "best_theoretical"``: ``1 / (1 + w_av)``.

    Here ``w = compressor.w``, ``w_av = w / cohort_size`` and
    ``kappa = L_max / mu``. The theoretical total communication is printed.
    Returns the same (mutated) ``args``.
    """
    cfg = args['loco_args'][index]

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        cfg['it_max'] = int(budget / compressor.uplink_cost)

    if cfg['gamma'] == "best_theoretical":
        cfg['gamma'] = 2 / (L_max + mu)

    w_av = compressor.w / args['cohort_size']
    kappa = L_max / mu

    if cfg['p'] == "best_theoretical":
        cfg['p'] = min(np.sqrt((1 + compressor.w) * (1 + w_av) / kappa), 1)

    for key in ('rho', 'chi'):
        if cfg[key] == "best_theoretical":
            cfg[key] = 1 / (1 + w_av)

    p = cfg['p']
    total_com = compressor.uplink_cost * p * (kappa + (1+w_av)*(1+compressor.w)/p**2)
    print(f'LoCoDL with {compressor.name} theoretical TotalCom: {total_com}')

    return args


def adiana_set_args(args, index, L_max, mu, compressor, tuned=False):
    """Replace the placeholders in ``args['adiana_args'][index]`` with concrete values.

    With ``w = compressor.w`` and ``n = args['n_workers']``:

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by ADIANA's
      per-round uplink (``run_adiana`` counts ``2 * compressor.uplink_cost``).
    * ``p``: ``min(max(sqrt(n / 32 / w) - 1, 1) / 2 / (1 + w), 1)``.
    * ``eta``: ``min(1 / (2 L_max), n / (64 w (2 p (w + 1) + 1)^2 L_max))``.
    * ``theta_1``: ``min(1/4, sqrt(eta mu / p))``; ``theta_2``: ``1/2``;
      ``alpha``: ``1 / (w + 1)``.
    * ``lr``: ``3 * eta / (2 (theta_1 + eta mu))``, tripled again when ``tuned``.
    * ``beta``: ``1 - lr mu``.

    For an exact compressor (``w == 0``), the ``p`` and ``eta`` formulas divide
    by ``w``. Their limits as ``w -> 0`` are used instead: ``p = 1`` and
    ``eta = 1 / (2 L_max)``.

    Returns the same (mutated) ``args``.
    """
    w = compressor.w
    n = args['n_workers']
    cfg = args['adiana_args'][index]

    if cfg['it_max'] == "tamuna":
        tamuna_uplink_cost = np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        # The trailing /2: ADIANA uploads two compressed vectors per round.
        cfg['it_max'] = int(
            args['tamuna_args']['it_max'] * tamuna_uplink_cost / compressor.uplink_cost/2)

    if cfg['p'] == 'best_theoretical':
        if w == 0:
            cfg['p'] = 1
        else:
            cfg['p'] = min(max(np.sqrt(n/32/w)-1, 1) / 2/(1+w), 1)

    if cfg['eta'] == 'best_theoretical':
        if w == 0:
            cfg['eta'] = 1 / (2*L_max)
        else:
            cfg['eta'] = min(
                1 / (2*L_max), n/(64*w*(2*cfg['p']*(w+1) + 1)**2 * L_max))

    if cfg['theta_1'] == 'best_theoretical':
        cfg['theta_1'] = min(1/4, np.sqrt(cfg['eta'] * mu / cfg['p']))

    if cfg['theta_2'] == 'best_theoretical':
        cfg['theta_2'] = 1/2

    if cfg['alpha'] == 'best_theoretical':
        cfg['alpha'] = 1/(w+1)

    if cfg['lr'] == 'best_theoretical':
        # NOTE(review): the "* 3" is applied on top of the theoretical step and
        # "tuned" triples it again — confirm both factors are intended.
        cfg['lr'] = cfg['eta'] / (2*(cfg['theta_1'] + cfg['eta']*mu)) * 3
        if tuned:
            cfg['lr'] *= 3

    if cfg['beta'] == 'best_theoretical':
        cfg['beta'] = 1 - cfg['lr'] * mu

    return args


def diana_set_args(args, index, L_max, mu, compressor, tuned=False):
    """Replace the placeholders in ``args['diana_args'][index]`` with concrete values.

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by
      ``compressor.uplink_cost``.
    * ``lr == 'best_theoretical'``: ``1 / (L_max (1 + 6 w / n_workers))``.
    * ``dual_lr == 'best_theoretical'``: ``1 / (1 + w)``.

    ``w = compressor.w``; ``mu`` is accepted for signature symmetry but unused.
    With ``tuned=True`` the final ``lr`` is tripled, whether it was derived here
    or supplied. Returns the same (mutated) ``args``.
    """
    omega = compressor.w
    n_workers = args['n_workers']
    cfg = args['diana_args'][index]

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        cfg['it_max'] = int(budget / compressor.uplink_cost)

    if cfg['lr'] == 'best_theoretical':
        cfg['lr'] = 1 / (L_max * (1 + 6*omega/n_workers))

    if cfg['dual_lr'] == 'best_theoretical':
        cfg['dual_lr'] = 1 / (1 + omega)

    if tuned:
        cfg['lr'] *= 3

    return args


def five_gcs_cc_set_args(args, index, L_max, mu, compressor, tuned=False):
    """Replace the placeholders in ``args['five_gcs_cc_args'][index]`` with concrete values.

    With ``w = compressor.w``, ``C = args['cohort_size']``, ``M = args['n_workers']``:

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by
      ``compressor.uplink_cost``.
    * ``dual_lr == "best_theoretical"``:
      ``(8/3) sqrt(mu L_max (w + 1) / C / (M (1 + w / C)))``.
    * ``lr == "same_as_tamuna"`` copies TAMUNA's lr; ``"best_theoretical"`` gives
      ``1 / (2 dual_lr M (1 + w / C))``.
    * ``it_local == "best_theoretical"``:
      ``int((3/4 sqrt(C L_max / (M mu)) + 2) log(4 L_max / mu))``.

    With ``tuned=True`` the final ``lr`` is tripled. Returns the mutated ``args``.
    """
    omega = compressor.w
    cohort = args['cohort_size']
    n_clients = args['n_workers']
    cfg = args['five_gcs_cc_args'][index]

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / cohort)
        cfg['it_max'] = int(budget / compressor.uplink_cost)

    if cfg["dual_lr"] == "best_theoretical":
        cfg["dual_lr"] = (8 / 3) * np.sqrt(
            mu * L_max * (omega+1)/cohort*1/(n_clients*(1 + omega/cohort)))

    lr_rule = cfg["lr"]
    if lr_rule == "same_as_tamuna":
        cfg["lr"] = args["tamuna_args"]['lr']
    elif lr_rule == "best_theoretical":
        cfg["lr"] = 1 / (2*cfg["dual_lr"] * n_clients * (1+omega/cohort))

    if cfg["it_local"] == "best_theoretical":
        cfg["it_local"] = int(((3 / 4) * np.sqrt((cohort * L_max) / (n_clients * mu)) + 2)
                              * np.log(4 * L_max / mu))

    if tuned:
        cfg["lr"] *= 3

    return args


def tamuna_set_args(args, L_max, l2, kappa):
    """Replace the placeholders in ``args['tamuna_args']`` with concrete values.

    With ``n = args['n_workers']`` and ``c = args['cohort_size']``:

    * ``s``: ``int(max(2, floor(c / dim), floor(downlink_factor * c)))``.
    * ``p``: ``min(sqrt(n / (s kappa)), 1)``.
    * ``lr``: ``2 / (L_max + l2)``.
    * ``eta``: ``p * n (s - 1) / (s (n - 1))``.

    The theoretical total communication is printed, using uplink cost
    ``ceil(s * dim / c)``. Returns the same (mutated) ``args``.
    """
    cfg = args['tamuna_args']
    n = args['n_workers']
    cohort = args['cohort_size']

    if cfg['s'] == "best_theoretical":
        cfg['s'] = int(np.max([2, np.floor(cohort / args['dim']),
                               np.floor(args['downlink_factor'] * cohort)]))

    if cfg['p'] == "best_theoretical":
        cfg['p'] = np.min([np.sqrt(n / (cfg['s'] * kappa)), 1])

    if cfg['lr'] == "best_theoretical":
        cfg['lr'] = 2 / (L_max + l2)

    if cfg['eta'] == "best_theoretical":
        cfg['eta'] = cfg['p'] * ((n * (cfg['s'] - 1)) / (cfg['s'] * (n - 1)))

    s, p = cfg['s'], cfg['p']
    uplink = np.ceil((s * args['dim']) / cohort)
    print(f'tamuna theoretical TotalCom: {uplink * p * (kappa + n / (s*p**2))}')

    return args


def scaffold_set_args(args, L_max):
    """Replace the placeholders in ``args['scaffold_args']`` with concrete values.

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by Scaffold's
      per-round uplink of ``2 * dim`` (two d-dimensional vectors).
    * ``local_lr == 'same_as_tamuna'`` copies TAMUNA's lr;
      ``'best_theoretical'`` gives ``1 / (L_max * it_local)``.

    Returns the same (mutated) ``args``.
    """
    cfg = args['scaffold_args']

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        cfg['it_max'] = int(budget / (2 * args['dim']))

    lr_rule = cfg['local_lr']
    if lr_rule == 'same_as_tamuna':
        cfg['local_lr'] = args['tamuna_args']['lr']
    elif lr_rule == 'best_theoretical':
        cfg['local_lr'] = 1 / (L_max * cfg['it_local'])

    return args


def compressed_scaffnew_set_args(args, L_max, mu, kappa, index=0, tuned=False):
    """Replace the placeholders in ``args['compressed_scaffnew_args'][index]``.

    With ``n = args['n_workers']`` and ``d = args['dim']``:

    * ``s``: ``int(max(2, floor(n / d), floor(downlink_factor * n)))``.
    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by this method's
      uplink ``ceil(s * d / n)``.
    * ``p``: ``min(sqrt(n / (s kappa)), 1)``; ``lr``: ``2 / (L_max + mu)``;
      ``eta``: ``p * n (s - 1) / (s (n - 1))``.

    With ``tuned=True`` the final ``lr`` is halved. Returns the mutated ``args``.
    """
    cfg = args['compressed_scaffnew_args'][index]
    n = args['n_workers']
    d = args['dim']

    if cfg['s'] == "best_theoretical":
        cfg['s'] = int(np.max([2, np.floor(n / d),
                               np.floor(args['downlink_factor'] * n)]))

    if cfg['it_max'] == "tamuna":
        tamuna_cost = np.ceil((args['tamuna_args']['s'] * d) / args['cohort_size'])
        own_cost = np.ceil((cfg['s'] * d) / n)
        cfg['it_max'] = int(args['tamuna_args']['it_max'] * tamuna_cost / own_cost)

    if cfg['p'] == "best_theoretical":
        cfg['p'] = np.min([np.sqrt(n / (cfg['s'] * kappa)), 1])

    if cfg['lr'] == "best_theoretical":
        cfg['lr'] = 2 / (L_max + mu)

    if cfg['eta'] == "best_theoretical":
        cfg['eta'] = cfg['p'] * ((n * (cfg['s'] - 1)) / (cfg['s'] * (n - 1)))

    if tuned:
        cfg['lr'] /= 2

    return args


def five_gcs_set_args(args, L_max, mu):
    """Replace the placeholders in ``args['five_gcs_args']`` with concrete values.

    With ``C = args['cohort_size']`` and ``M = args['n_workers']``:

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by ``dim``.
    * ``lr == "same_as_tamuna"`` copies TAMUNA's lr; ``"best_theoretical"``
      gives ``(3/16) sqrt(C / (L_max mu M))``.
    * ``dual_lr == "best_theoretical"``: ``1 / (2 lr M)``.
    * ``it_local == "best_theoretical"``:
      ``int((3/4 sqrt(C L_max / (M mu)) + 2) log(4 L_max / mu))``.

    Returns the same (mutated) ``args``.
    """
    cfg = args['five_gcs_args']

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        cfg['it_max'] = int(budget / args['dim'])

    lr_rule = cfg["lr"]
    if lr_rule == "same_as_tamuna":
        cfg["lr"] = args["tamuna_args"]['lr']
    elif lr_rule == "best_theoretical":
        cfg["lr"] = (3 / 16) * np.sqrt(
            args['cohort_size'] / (L_max * mu * args['n_workers']))

    if cfg["dual_lr"] == "best_theoretical":
        cfg["dual_lr"] = 1 / (2 * cfg["lr"] * args['n_workers'])

    if cfg["it_local"] == "best_theoretical":
        cfg["it_local"] = int(((3 / 4) * np.sqrt((args['cohort_size'] * L_max)
                                                 / (args['n_workers'] * mu)) + 2)
                              * np.log(4 * L_max / mu))

    return args


def gd_set_args(args, L_max, mu):
    """Resolve GD's step size: ``'best_theoretical'`` becomes ``2 / (L_max + mu)``.

    Any other value of ``args['gd_args']['lr']`` is left untouched.
    Returns the same (mutated) ``args``.
    """
    gd_cfg = args['gd_args']
    if gd_cfg['lr'] != 'best_theoretical':
        return args
    gd_cfg['lr'] = 2 / (L_max + mu)
    return args


def gradskip_set_args(args, L_max, mu, kappa):
    """Replace the placeholders in ``args['gradskip_args']`` with concrete values.

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by ``dim``
      (a full d-dimensional vector per round, as for Scaffnew).
    * ``lr == 'best_theoretical'``: ``2 / (L_max + mu)``.
    * ``p == 'best_theoretical'``: ``min(1 / sqrt(kappa), 1)``.

    Returns the same (mutated) ``args``.
    """
    cfg = args['gradskip_args']

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        cfg['it_max'] = int(budget / args['dim'])

    if cfg['lr'] == 'best_theoretical':
        cfg['lr'] = 2 / (L_max + mu)

    if cfg['p'] == 'best_theoretical':
        cfg['p'] = min(1 / np.sqrt(kappa), 1)

    return args


def scaffnew_set_args(args, L_max, mu, kappa):
    """Replace the placeholders in ``args['scaffnew_args']`` with concrete values.

    * ``it_max == "tamuna"``: TAMUNA's total uplink divided by ``dim``
      (Scaffnew uploads a full d-dimensional vector per round).
    * ``lr == 'best_theoretical'``: ``2 / (L_max + mu)``.
    * ``p == 'best_theoretical'``: ``min(1 / sqrt(kappa), 1)``.

    Returns the same (mutated) ``args``.
    """
    cfg = args['scaffnew_args']

    if cfg['it_max'] == "tamuna":
        budget = args['tamuna_args']['it_max'] * np.ceil(
            (args['tamuna_args']['s'] * args['dim']) / args['cohort_size'])
        cfg['it_max'] = int(budget / args['dim'])

    if cfg['lr'] == 'best_theoretical':
        cfg['lr'] = 2 / (L_max + mu)

    if cfg['p'] == 'best_theoretical':
        cfg['p'] = min(1 / np.sqrt(kappa), 1)

    return args


def grid_search(a, b, args):
    """Grid-search LoCoDL hyper-parameters around one base configuration.

    If ``args['load_directory']`` is set, previously saved runs are loaded with
    ``load_runs`` and the function returns.  Otherwise it:

    1. splits ``(a, b)`` over ``args['n_workers']`` workers and builds the
       global regularized logistic-regression loss;
    2. resolves the placeholder hyper-parameters of every method via the
       ``*_set_args`` helpers;
    3. writes ``args`` to ``<save_dir>/args.json`` and computes (or reads the
       cached) reference optimum from ``<save_dir>/f_star.txt``;
    4. takes ``args['loco_args'][index]`` (``index = 1``) as the base LoCoDL
       configuration and, for every ``(k, lr_coeff, p_coeff)`` in the grid,
       inserts a deep copy with ``k`` overridden and ``gamma``/``p`` scaled;
    5. executes the base run plus all grid runs with ``run``, passing the
       per-run ``lr_coeffs``/``p_coeffs`` maps.

    Args:
        a: Feature matrix; ``a.shape[1]`` is used as the problem dimension.
        b: Labels matching the rows of ``a``.
        args: Experiment configuration dict.  Mutated in place; in particular
            ``args['loco_args']`` grows by one entry per grid point.
    """
    if args['load_directory'] is not None:
        load_runs(args)
        return

    # num_rows is not used below.
    num_rows, dim = a.shape

    # With mu=0, compute_worker_losses sets every worker's l2 to
    # regularization_factor * L_max; the global loss reuses that same l2.
    worker_losses = compute_worker_losses(a=a, b=b, args=args, mu=0)
    loss = LogisticRegression(a, b, l1=0, l2=worker_losses[0].l2)

    args['dim'] = dim
    args['cohort_size'] = int(args['cohort_size_factor'] * args['n_workers'])

    L_max = max([l.smoothness for l in worker_losses])
    kappa = L_max / loss.l2

    print(f"L_max: {L_max}")
    print(f"mu: {loss.l2}")
    print(f'Kappa: {kappa}')

    # Resolve placeholder hyper-parameters of the baseline methods.
    args = tamuna_set_args(args=args, L_max=L_max, l2=loss.l2, kappa=kappa)
    args = scaffold_set_args(args=args, L_max=L_max)
    args = gradskip_set_args(args=args, L_max=L_max, mu=loss.l2, kappa=kappa)
    args = scaffnew_set_args(args=args, L_max=L_max, mu=loss.l2, kappa=kappa)
    args = compressed_scaffnew_set_args(
        args=args, L_max=L_max, mu=loss.l2, kappa=kappa)
    args = five_gcs_set_args(args=args, L_max=L_max, mu=loss.l2)
    args = gd_set_args(args=args, L_max=L_max, mu=loss.l2)

    # Compressed methods: one configuration entry per index, each with its own compressor.
    for index in range(len(args['five_gcs_cc_args'])):
        compressor = set_compressor(args, 'five_gcs_cc_args', index)
        args = five_gcs_cc_set_args(
            args=args, index=index, L_max=L_max, mu=loss.l2, compressor=compressor)

    for index in range(len(args['diana_args'])):
        compressor = set_compressor(args, 'diana_args', index)
        args = diana_set_args(
            args=args, index=index, L_max=L_max, mu=loss.l2, compressor=compressor)

    for index in range(len(args['adiana_args'])):
        compressor = set_compressor(args, 'adiana_args', index)
        args = adiana_set_args(
            args=args, index=index, L_max=L_max, mu=loss.l2, compressor=compressor)

    # LoCoDL uses its own loss split (loco_get_losses) with half the regularization.
    # Note: L_max is rebound here to the LoCoDL workers' max smoothness and
    # that value is what the later loco_set_args calls use.
    worker_losses_loco, loss_y = loco_get_losses(
        a=a, b=b, mu=loss.l2 / 2, args=args)
    L_max = max([l.smoothness for l in worker_losses_loco])
    for index in range(len(args['loco_args'])):
        compressor = set_compressor(args, 'loco_args', index)
        args = loco_set_args(args=args, index=index, L_max=max(
            L_max, loss_y.smoothness), mu=loss.l2 / 2, compressor=compressor)  # todo change the max smoothness

    # worker_losses = worker_losses_loco + loss_y

    # NOTE(review): args.json is written before the grid entries are inserted
    # into args['loco_args'] below, so it does not record the grid configurations.
    save_dir = get_save_dir_name(args=args)
    # NOTE(review): os.mkdir does not create missing parent directories.
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    with open(f'{save_dir}/args.json', 'w') as f:
        json.dump(args, f, indent=4)

    # The reference optimum is cached in f_star.txt and only recomputed when absent.
    if not os.path.exists(f"{save_dir}/f_star.txt"):
        print('Finding reference optimum...')
        f_star = find_reference_optimum(dim=args['dim'], loss=loss)
        with open(f'{save_dir}/f_star.txt', 'w') as f:
            f.write(str(f_star))
    else:
        with open(f'{save_dir}/f_star.txt', 'r') as f:
            f_star = float(f.readline())
        loss.f_opt = f_star

    # NOTE(review): the `runs` list assembled from here to the end of this
    # selection is discarded — `runs` is overwritten by the grid-search list
    # below.  The lines only serve as experiment toggles.
    loco_runs = [(run_locodl, args, i) for i in range(len(args['loco_args']))]
    five_gcs_cc_runs = [(run_5gcs_cc, args, i)
                        for i in range(len(args['five_gcs_cc_args']))]
    diana_runs = [(run_diana, args, i)
                  for i in range(len(args['diana_args']))]
    adiana_runs = [(run_adiana, args, i)
                   for i in range(len(args['adiana_args']))]
    loco_runs = loco_runs[1:3]
    # loco_runs = []
    five_gcs_cc_runs = five_gcs_cc_runs[0:1]
    five_gcs_cc_runs = []
    diana_runs = diana_runs[0:1]
    diana_runs = []
    adiana_runs = adiana_runs[0:12]
    adiana_runs = []
    runs = adiana_runs + diana_runs + loco_runs + \
        [
            # (run_compressed_scaffnew, args, None),
            # (run_gradskip, args, None),
            # (run_scaffnew, args, None),
            # (run_scaffold, args, None),
        ] + five_gcs_cc_runs

    # Base LoCoDL configuration (position in args['loco_args']) that the grid perturbs.
    index = 1
    runs = [(run_locodl, args, index)]
    # Grid copies are inserted at index + 1, index + 2, ...; any entries that
    # previously followed `index` are shifted to higher positions.
    i = index + 1
    # Multipliers applied to each run's gamma / p, keyed by position in
    # args['loco_args'] and passed to run(); the base run has multiplier 1.
    lr_coeffs = {index: 1}
    p_coeffs = {index: 1}
    # for k in [1] + list(range(10, 121, 10)):
    for k in [2]:
        # for lr_coeff in [1, 2, 3, 0.5]:
        for lr_coeff in [1, 1.5, 2]:
            for p_coeff in [1, 0.5, 0.25, 0.1, 2]:
                # for p_coeff in [0.5, 0.25]:
                # for k in [20,30]:
                #     for lr_coeff in [1, 2, 3]:
                # Deep copy so the base entry itself is never modified.
                new_args = copy.deepcopy(args['loco_args'][index])
                new_args['k'] = k
                new_args['gamma'] = lr_coeff * new_args['gamma']
                new_args['p'] = p_coeff * new_args['p']

                args['loco_args'].insert(i, new_args)
                runs.append((run_locodl, args, i))
                lr_coeffs[i] = lr_coeff
                p_coeffs[i] = p_coeff
                i += 1
    # Re-apply loco_set_args to every entry, including the new grid points.
    # NOTE(review): whether this keeps or overrides the scaled gamma/p depends
    # on loco_set_args, which is not visible here — confirm.
    for index in range(len(args['loco_args'])):
        compressor = set_compressor(args, 'loco_args', index)
        args = loco_set_args(args=args, index=index, L_max=max(
            L_max, loss_y.smoothness), mu=loss.l2 / 2, compressor=compressor)  # todo change the max smoothness

    # NOTE(review): load_directory is always None here (early return above), so
    # save_dir is get_save_dir_name(args) re-evaluated AFTER args['loco_args']
    # was extended; confirm it still equals the directory used for args.json.
    run(runs=runs, loss=loss, worker_losses=worker_losses,
        dim=args['dim'],
        threshold=1e-4,
        downlink_factor=args['downlink_factor'], uplink_factor=args['uplink_factor'],
        save_dir=args['load_directory'] if args['load_directory'] is not None else get_save_dir_name(
            args),
        f_star=f_star, n_repeats=args['n_repeats'],
        plt_title=f"Dataset {args['dataset']}, {args['n_workers']} workers, "
                  f"{round(args['cohort_size_factor'] * 100, 2)}% participation, "
                  f"alpha = {args['downlink_factor']}, "
                  f"r = {args['r']}",  # todo r?
                  worker_losses_loco=worker_losses_loco, loss_y=loss_y,
        lr_coeffs=lr_coeffs,
        p_coeffs=p_coeffs
        )


def compare_optimizers(a, b, args):
    """Benchmark the configured optimizers using untuned theoretical parameters.

    If ``args['load_directory']`` is set, previously saved runs are loaded with
    ``load_runs`` and nothing is recomputed.  Otherwise the function:

    1. splits ``(a, b)`` over ``args['n_workers']`` workers and builds the
       global regularized logistic-regression loss;
    2. resolves placeholder hyper-parameters of every method via the
       ``*_set_args`` helpers (compressed methods with ``tuned=False``);
    3. writes ``args`` to ``<save_dir>/args.json`` and computes the reference
       optimum, caching it in ``<save_dir>/f_star.txt``;
    4. runs the currently selected subset of methods with ``run``.

    Args:
        a: 2-D feature matrix; ``a.shape[1]`` is the problem dimension.
        b: Labels matching the rows of ``a``.
        args: Experiment configuration dict; mutated in place.
    """
    if args['load_directory'] is not None:
        load_runs(args)
        return

    _, dim = a.shape

    # With mu=0, compute_worker_losses sets every worker's l2 to
    # regularization_factor * L_max; the global loss reuses that same l2.
    worker_losses = compute_worker_losses(a=a, b=b, args=args, mu=0)
    loss = LogisticRegression(a, b, l1=0, l2=worker_losses[0].l2)

    args['dim'] = dim
    args['cohort_size'] = int(args['cohort_size_factor'] * args['n_workers'])

    L_max = max([wl.smoothness for wl in worker_losses])
    kappa = L_max / loss.l2

    print(f"L_max: {L_max}")
    print(f"mu: {loss.l2}")
    print(f'Kappa: {kappa}')

    # Resolve placeholder hyper-parameters of the baseline methods.
    args = tamuna_set_args(args=args, L_max=L_max, l2=loss.l2, kappa=kappa)
    args = scaffold_set_args(args=args, L_max=L_max)
    args = gradskip_set_args(args=args, L_max=L_max, mu=loss.l2, kappa=kappa)
    args = scaffnew_set_args(args=args, L_max=L_max, mu=loss.l2, kappa=kappa)
    args = compressed_scaffnew_set_args(
        args=args, L_max=L_max, mu=loss.l2, kappa=kappa)
    args = five_gcs_set_args(args=args, L_max=L_max, mu=loss.l2)
    args = gd_set_args(args=args, L_max=L_max, mu=loss.l2)

    # Compressed methods: one configuration per index, each with its own
    # compressor, using untuned (tuned=False) settings.
    for index in range(len(args['five_gcs_cc_args'])):
        compressor = set_compressor(args, 'five_gcs_cc_args', index)
        args = five_gcs_cc_set_args(
            args=args, index=index, L_max=L_max, mu=loss.l2, compressor=compressor,
            tuned=False
        )

    for index in range(len(args['diana_args'])):
        compressor = set_compressor(args, 'diana_args', index)
        args = diana_set_args(
            args=args, index=index, L_max=L_max, mu=loss.l2, compressor=compressor,
            tuned=False
        )

    for index in range(len(args['adiana_args'])):
        compressor = set_compressor(args, 'adiana_args', index)
        args = adiana_set_args(
            args=args, index=index, L_max=L_max, mu=loss.l2, compressor=compressor,
            tuned=False
        )

    # LoCoDL uses its own loss split (loco_get_losses) with half the regularization.
    worker_losses_loco, loss_y = loco_get_losses(
        a=a, b=b, mu=loss.l2 / 2, args=args)
    L_max_loco = max([wl.smoothness for wl in worker_losses_loco])
    for index in range(len(args['loco_args'])):
        compressor = set_compressor(args, 'loco_args', index)
        args = loco_set_args(args=args, index=index, L_max=max(
            L_max_loco, loss_y.smoothness), mu=loss.l2 / 2, compressor=compressor)  # todo change the max smoothness

    # Compute the output directory once so args.json, f_star.txt and the run
    # results are guaranteed to land in the same place.  makedirs also creates
    # missing parents and tolerates the directory already existing.
    save_dir = get_save_dir_name(args=args)
    os.makedirs(save_dir, exist_ok=True)
    with open(f'{save_dir}/args.json', 'w') as f:
        json.dump(args, f, indent=4)

    # The reference optimum is cached in f_star.txt and only recomputed when absent.
    f_star_path = f"{save_dir}/f_star.txt"
    if not os.path.exists(f_star_path):
        print('Finding reference optimum...')
        f_star = find_reference_optimum(dim=args['dim'], loss=loss)
        with open(f_star_path, 'w') as f:
            f.write(str(f_star))
    else:
        with open(f_star_path, 'r') as f:
            f_star = float(f.readline())
        loss.f_opt = f_star

    # Every configured run per method; the selection below picks what executes.
    loco_runs = [(run_locodl, args, i) for i in range(len(args['loco_args']))]
    five_gcs_cc_runs = [(run_5gcs_cc, args, i)
                        for i in range(len(args['five_gcs_cc_args']))]
    diana_runs = [(run_diana, args, i)
                  for i in range(len(args['diana_args']))]
    adiana_runs = [(run_adiana, args, i)
                   for i in range(len(args['adiana_args']))]

    # Current experiment selection — edit these lines to enable other runs.
    loco_runs = loco_runs[3:]
    five_gcs_cc_runs = []
    diana_runs = []
    adiana_runs = []

    # Single-configuration baselines such as (run_compressed_scaffnew, args, 0),
    # (run_gradskip, args, None) or (run_scaffold, args, None) can be added here.
    runs = adiana_runs + diana_runs + loco_runs + five_gcs_cc_runs

    run(runs=runs, loss=loss, worker_losses=worker_losses,
        dim=args['dim'],
        threshold=2e-6,
        downlink_factor=args['downlink_factor'], uplink_factor=args['uplink_factor'],
        save_dir=save_dir,
        f_star=f_star, n_repeats=args['n_repeats'],
        plt_title=f"Dataset {args['dataset']}, {args['n_workers']} workers, "
                  f"{round(args['cohort_size_factor'] * 100, 2)}% participation, "
                  f"alpha = {args['downlink_factor']}, "
                  f"r = {args['r']}",  # todo r?
        worker_losses_loco=worker_losses_loco, loss_y=loss_y  # todo
        )


def find_reference_optimum(dim, loss, it_max=500000, show_plot=True):
    """Estimate the optimal objective value of ``loss`` with restarted Nesterov.

    Runs ``RestNest(loss=loss, doubling=True)`` from the zero vector, evaluates
    the loss along the recorded iterates and returns ``loss.f_opt``.

    Args:
        dim: Dimension of the decision variable; the starting point is
            ``np.zeros(dim, dtype=np.float32)``.
        loss: Loss object optimized by ``RestNest``; its ``f_opt`` attribute
            is returned.
        it_max: Maximum number of RestNest iterations (default 500000, the
            previously hard-coded value).
        show_plot: If True (default, the original behaviour) plot the loss
            trace on a log scale and call ``plt.show()``.  Pass False for
            headless/batch runs where showing a figure is unwanted or blocks.

    Returns:
        ``loss.f_opt`` after the run.  NOTE(review): presumably the best
        objective value recorded by the loss object while it was evaluated —
        confirm against optmethods.
    """
    solver = RestNest(loss=loss, doubling=True)
    trace = solver.run(x0=np.zeros(dim, dtype=np.float32), it_max=it_max)
    # Always evaluated, independent of plotting.  NOTE(review): this likely also
    # refreshes loss.f_opt, so it must not be gated behind show_plot.
    trace.compute_loss_of_iterates()
    if show_plot:
        trace.plot_losses()
        plt.yscale('log')
        plt.show()
    return loss.f_opt
