import os
import json
import copy
import pickle
import numpy as np
from utils import run
from optmethods.loss import LogisticRegression
from optimizers import DistributedSgd, Scaffold, Scaffnew, CompressedProxSkip


def compute_worker_losses(a, b, l2, args):
    """Shuffle the samples and build one ``LogisticRegression`` per worker.

    The rows of ``a`` and the entries of ``b`` are reordered with the same
    random permutation (drawn from NumPy's global RNG) and then cut into
    ``args.n_workers`` contiguous shards of ``num_rows // args.n_workers``
    rows each. Rows left over after the integer division are not assigned
    to any worker.

    Args:
        a: 2-D data matrix; its first axis indexes samples.
        b: Labels, indexable along its first axis like ``a``.
        l2: L2 regularisation value handed to every per-worker loss.
        args: Namespace providing ``n_workers``.

    Returns:
        A list with ``args.n_workers`` ``LogisticRegression`` objects, in
        shard order.
    """
    n_samples, _ = a.shape
    order = np.random.permutation(range(0, n_samples))
    shuffled_a, shuffled_b = a[order], b[order]

    shard_size = n_samples // args.n_workers
    # One slice per worker: [0, size), [size, 2*size), ...
    shards = (slice(worker * shard_size, (worker + 1) * shard_size)
              for worker in range(args.n_workers))

    return [LogisticRegression(shuffled_a[rows], shuffled_b[rows], l1=0, l2=l2)
            for rows in shards]


def get_loss(a, b, args):
    """Build the global logistic-regression loss over the full data set.

    The loss is first created without any regularisation; its ``l2``
    attribute is then set to ``args.reg`` times the loss's own
    ``max_smoothness``.

    Args:
        a: Data matrix passed straight to ``LogisticRegression``.
        b: Labels passed straight to ``LogisticRegression``.
        args: Namespace providing ``reg``, the multiplier applied to
            ``max_smoothness`` to obtain ``l2``.

    Returns:
        The configured ``LogisticRegression`` instance.
    """
    full_loss = LogisticRegression(a, b, l1=0, l2=0)
    # l2 is derived from the unregularised loss, hence the two-step setup.
    scaled_l2 = args.reg * full_loss.max_smoothness
    full_loss.l2 = scaled_l2
    return full_loss


def load(res_path: str):
    """Read and return the pickled object stored at ``res_path``.

    NOTE: unpickling can execute arbitrary code, so only use this on result
    files that come from a trusted source (e.g. ones written by this script).

    Args:
        res_path: Path of the binary pickle file to read.

    Returns:
        Whatever object was pickled into the file.
    """
    with open(res_path, 'rb') as stream:
        payload = pickle.load(stream)
    return payload


def get_save_dir_name(args):
    """Return the results-directory name that encodes the run configuration.

    The name is assembled from ``args.dataset``, ``args.n_workers``,
    ``args.c_downlink``, ``args.it_max`` and ``args.reg``, in that order.

    Args:
        args: Namespace providing the five attributes listed above.

    Returns:
        The directory name as a string (no directory is created here).
    """
    dir_name = f'{args.dataset}_n{args.n_workers}_c{args.c_downlink}_it{args.it_max}_reg{args.reg}'
    return dir_name


def run_distributed_gd(x0, loss, worker_losses, args):
    """Run ``DistributedSgd`` from ``x0`` and persist its settings and results.

    Two files are written into the directory named by
    ``get_save_dir_name(args)``: ``gd_args.json`` (a JSON dump of
    ``vars(args)``) and ``gd_res.bin`` (a pickle of the returned tuple).

    Args:
        x0: Starting point handed to the optimizer's ``run``.
        loss: Global loss object; its ``f_opt`` attribute is stored with the
            results.
        worker_losses: Per-worker loss objects passed to the optimizer.
        args: Namespace providing ``n_workers``, ``lr``, ``it_max`` and
            ``pbars``, plus the fields used by ``get_save_dir_name``.

    Returns:
        The tuple ``(trace, uplink_communicated_numbers,
        downlink_communicated_numbers, loss.f_opt)``.
    """
    save_dir = get_save_dir_name(args)

    # exist_ok avoids the check-then-create race of exists() + mkdir() and
    # also copes with a directory name that contains nested components.
    os.makedirs(save_dir, exist_ok=True)

    with open(f'{save_dir}/gd_args.json', 'wt') as f:
        json.dump(vars(args), f, indent=4)

    # it_local=1 and batch_size=None are fixed for this baseline.
    gd = DistributedSgd(loss=loss, it_local=1, n_workers=args.n_workers, lr=args.lr, batch_size=None,
                        worker_losses=worker_losses, trace_len=args.it_max + 4, it_max=args.it_max, pbars=args.pbars)
    gd.run(x0=x0, it_max=args.it_max)
    gd.trace.compute_loss_of_iterates()

    # Build the result once so the pickled payload and the return value match.
    result = (gd.trace, gd.uplink_communicated_numbers,
              gd.downlink_communicated_numbers, loss.f_opt)

    with open(f'{save_dir}/gd_res.bin', 'wb') as f:
        pickle.dump(result, f)

    return result


def run_scaffold(x0, loss, worker_losses, args):
    """Run ``Scaffold`` from ``x0`` and persist its settings and results.

    Two files are written into the directory named by
    ``get_save_dir_name(args)``: ``scaffold_args.json`` (a JSON dump of
    ``vars(args)``) and ``scaffold_res.bin`` (a pickle of the returned tuple).

    Args:
        x0: Starting point handed to the optimizer's ``run``.
        loss: Global loss object; its ``f_opt`` attribute is stored with the
            results.
        worker_losses: Per-worker loss objects passed to the optimizer.
        args: Namespace providing ``n_workers``, ``it_local``, ``lr``,
            ``batch_size``, ``it_max`` and ``pbars``, plus the fields used by
            ``get_save_dir_name``.

    Returns:
        The tuple ``(trace, uplink_communicated_numbers,
        downlink_communicated_numbers, loss.f_opt)``.
    """
    save_dir = get_save_dir_name(args)

    # exist_ok avoids the check-then-create race of exists() + mkdir() and
    # also copes with a directory name that contains nested components.
    os.makedirs(save_dir, exist_ok=True)

    with open(f'{save_dir}/scaffold_args.json', 'wt') as f:
        json.dump(vars(args), f, indent=4)

    # The step size is divided by the number of local iterations.
    scaffold = Scaffold(loss=loss, n_workers=args.n_workers, it_local=args.it_local, trace_len=args.it_max + 4,
                        lr=args.lr / args.it_local, batch_size=args.batch_size, worker_losses=worker_losses,
                        it_max=args.it_max, pbars=args.pbars)
    scaffold.run(x0=x0, it_max=args.it_max)
    scaffold.trace.compute_loss_of_iterates()

    # Build the result once so the pickled payload and the return value match.
    result = (scaffold.trace, scaffold.uplink_communicated_numbers,
              scaffold.downlink_communicated_numbers, loss.f_opt)

    with open(f'{save_dir}/scaffold_res.bin', 'wb') as f:
        pickle.dump(result, f)

    return result


def run_scaffnew(x0, loss, worker_losses, args):
    """Run ``Scaffnew`` from ``x0`` and persist its settings and results.

    Two files are written into the directory named by
    ``get_save_dir_name(args)``: ``scaffnew_args.json`` (a JSON dump of
    ``vars(args)``) and ``scaffnew_res.bin`` (a pickle of the returned tuple).

    Args:
        x0: Starting point handed to the optimizer's ``run``.
        loss: Global loss object; its ``f_opt`` attribute is stored with the
            results.
        worker_losses: Per-worker loss objects passed to the optimizer.
        args: Namespace providing ``prob``, ``n_workers``, ``lr``, ``it_max``
            and ``pbars``, plus the fields used by ``get_save_dir_name``.

    Returns:
        The tuple ``(trace, uplink_communicated_numbers,
        downlink_communicated_numbers, loss.f_opt)``.
    """
    save_dir = get_save_dir_name(args)

    # exist_ok avoids the check-then-create race of exists() + mkdir() and
    # also copes with a directory name that contains nested components.
    os.makedirs(save_dir, exist_ok=True)

    with open(f'{save_dir}/scaffnew_args.json', 'wt') as f:
        json.dump(vars(args), f, indent=4)

    # batch_size=None is fixed here; args.prob is forwarded as ``p``.
    scaffnew = Scaffnew(loss=loss, p=args.prob, n_workers=args.n_workers, trace_len=args.it_max + 4,
                        lr=args.lr, batch_size=None, worker_losses=worker_losses, it_max=args.it_max, pbars=args.pbars)
    scaffnew.run(x0=x0, it_max=args.it_max)
    scaffnew.trace.compute_loss_of_iterates()

    # Build the result once so the pickled payload and the return value match.
    result = (scaffnew.trace, scaffnew.uplink_communicated_numbers,
              scaffnew.downlink_communicated_numbers, loss.f_opt)

    with open(f'{save_dir}/scaffnew_res.bin', 'wb') as f:
        pickle.dump(result, f)

    return result


def run_compressed_proxskip(x0, loss, losses, args):
    """Run ``CompressedProxSkip`` from ``x0`` and persist its settings and results.

    Two files are written into the directory named by
    ``get_save_dir_name(args)``: ``compressed_args.json`` (a JSON dump of
    ``vars(args)``) and ``compressed_res.bin`` (a pickle of the returned
    tuple).

    Args:
        x0: Starting point handed to the optimizer's ``run``.
        loss: Global loss object; its ``f_opt`` attribute is stored with the
            results.
        losses: Per-worker loss objects, passed to the optimizer as
            ``worker_losses``.
        args: Namespace providing ``s``, ``prob``, ``eta``, ``n_workers``,
            ``lr``, ``it_max`` and ``pbars``, plus the fields used by
            ``get_save_dir_name``.

    Returns:
        The tuple ``(trace, uplink_communicated_numbers,
        downlink_communicated_numbers, loss.f_opt)``.
    """
    save_dir = get_save_dir_name(args)

    # exist_ok avoids the check-then-create race of exists() + mkdir() and
    # also copes with a directory name that contains nested components.
    os.makedirs(save_dir, exist_ok=True)

    with open(f'{save_dir}/compressed_args.json', 'wt') as f:
        json.dump(vars(args), f, indent=4)

    compressed = CompressedProxSkip(loss=loss, s=args.s, p=args.prob, eta=args.eta, n_workers=args.n_workers,
                                    lr=args.lr, worker_losses=losses, trace_len=args.it_max + 4, it_max=args.it_max,
                                    pbars=args.pbars)
    compressed.run(x0=x0, it_max=args.it_max)
    compressed.trace.compute_loss_of_iterates()

    # Build the result once so the pickled payload and the return value match.
    result = (compressed.trace, compressed.uplink_communicated_numbers,
              compressed.downlink_communicated_numbers, loss.f_opt)

    with open(f'{save_dir}/compressed_res.bin', 'wb') as f:
        pickle.dump(result, f)

    return result


def compare_optimizers(a, b, args):
    """Resolve default hyper-parameters and hand the optimizer line-up to ``run``.

    A global loss and the per-worker losses are built from ``(a, b)``. Any of
    ``s``, ``prob``, ``it_local``, ``lr`` and ``eta`` that is ``None`` is
    filled in from the problem constants. ``args`` itself is modified in
    place (``prob``, ``it_local``, ``lr``); the settings specific to the
    compressed method live on a deep copy of ``args`` taken before any
    default is filled in.

    If ``args.load_dir`` is set, previously pickled results are loaded from
    that directory; otherwise the runner functions are passed on together
    with their argument namespaces.

    Args:
        a: 2-D data matrix; its second axis length is forwarded as ``dim``.
        b: Labels matching ``a``.
        args: Namespace with the experiment settings.
    """
    _, dim = a.shape

    loss = get_loss(a, b, args)
    worker_losses = compute_worker_losses(a, b, loss.l2, args)

    # Separate namespace for the compressed method so its s/prob/eta do not
    # leak into the settings used by the other optimizers.
    compressed_args = copy.deepcopy(args)

    kappa = loss.max_smoothness / loss.l2
    print(f'Kappa: {kappa}')

    n_workers = args.n_workers

    if compressed_args.s is None:
        s_candidates = [2, np.floor(n_workers / dim), np.floor(args.c_downlink * n_workers)]
        compressed_args.s = int(np.max(s_candidates))

    if args.prob is None:
        args.prob = 1 / np.sqrt(kappa)
        uncapped_prob = np.sqrt(n_workers / (compressed_args.s * kappa))
        # Cap at 1.
        compressed_args.prob = np.min([uncapped_prob, 1])

    if args.it_local is None:
        args.it_local = int(1 / args.prob)

    if args.lr is None:
        default_lr = 2 / (loss.max_smoothness + loss.l2)
        compressed_args.lr = default_lr
        args.lr = default_lr

    if compressed_args.eta is None:
        s = compressed_args.s
        n = compressed_args.n_workers
        compressed_args.eta = (n * (s - 1)) / (s * (n - 1))

    # NOTE: Scaffold is currently left out of both line-ups below.
    if args.load_dir is None:
        runs = [
            (run_distributed_gd, 'GD', args),
            (run_scaffnew, 'Scaffnew', args),
            (run_compressed_proxskip, 'CompressedScaffnew', compressed_args),
        ]
        save_dir = get_save_dir_name(args)
    else:
        runs = [
            (load(f'{args.load_dir}/gd_res.bin'), 'GD'),
            (load(f'{args.load_dir}/scaffnew_res.bin'), 'Scaffnew'),
            (load(f'{args.load_dir}/compressed_res.bin'), 'CompressedScaffnew'),
        ]
        save_dir = args.load_dir

    run(runs, loss, worker_losses, dim, args.c_downlink, args.c_uplink,
        save_dir=save_dir)
