#%%
import os
from datetime import datetime
from pathlib import Path
from typing import Sequence
import matplotlib.pyplot as plt
from datargs import argsclass, parse
import jax
import jax.numpy as jnp
from jax import grad, jit
import numpy as np
from jax.tree_util import tree_flatten

import mlflow

# Render all matplotlib figures with a white figure background.
plt.rcParams['figure.facecolor'] = 'white'

# NOTE(review): newer JAX releases no longer ship the `jax.config` submodule;
# `jax.config.update(...)` is the spelling that is expected to work across
# versions -- confirm against the JAX version this project pins.
from jax.config import config
# Turn on 64-bit types in JAX; the rest of the file builds float64 arrays.
config.update("jax_enable_x64", True)


class NaNException(Exception):
    """Signals that NaN values were encountered during the optimisation loop.

    The main loop catches this exception to stop iterating early.
    """


#######################################################
# Config
#######################################################
@argsclass
class Args:
    """Command-line configuration of one experiment run.

    The fields are parsed from the command line by ``datargs``. ``setup()``
    fills in a random seed when none was given, derives the output directory
    ``output/<name>`` and dumps the parsed fields to ``args.yaml`` inside it.
    """
    # Outer (server) rounds and inner (client) gradient steps per round.
    server_iters: int = 100
    client_iters: int = 100
    num_clients: int = 10
    # When set, a single client is created that owns the total loss.
    centralized: bool = False

    # One of "FedAvg", "FedProx", "FedDR", "iFedDR".
    method: str = "FedAvg"

    server_lr: float = 1.0
    client_lr: float = 0.1
    # Scale the client step size by the inverse Lipschitz constant.
    use_lips: bool = False
    client_prox_reg: float = 1.0
    # iFedDR only: accept every round without refinement.
    no_refinement: bool = False

    dataset_name: str = "w8a"
    # Length of the weight vector created in `main`.
    d: int = 300
    # L2 regularisation weight of the logistic loss.
    reg: float = 0.0001

    seed: int = None
    name: str = None

    def setup(self):
        """Finalise the configuration and create the output directory.

        Draws a random 32-bit seed if ``seed`` is unset, sets ``self.dir`` to
        ``output/<timestamp>`` (or ``output/<name>(<timestamp>)`` when a name
        was given) and writes ``args.yaml`` there.

        Returns:
            ``self``, so the call can be chained after parsing.
        """
        if self.seed is None:
            self.seed = int.from_bytes(os.urandom(4), byteorder="big")

        timestamp = datetime.now().strftime("%Y-%m-%d_%Hh%Mm%Ss")
        if self.name is None:
            run_dir_name = timestamp
        else:
            run_dir_name = f"{self.name}({timestamp})"

        self.dir = os.path.join("output", run_dir_name)
        self.write_args_to_yaml_file()
        return self

    def path(self, filename):
        """Return the path of *filename* inside the run's output directory."""
        return os.path.join(self.dir, filename)

    def write_args_to_yaml_file(self):
        """Create the output directory and dump the dataclass fields to ``args.yaml``."""
        import dataclasses
        import yaml

        Path(self.dir).mkdir(parents=True, exist_ok=True)
        yaml_path = os.path.join(self.dir, "args.yaml")
        with open(yaml_path, "w") as args_file:
            yaml.dump(dataclasses.asdict(self), args_file)


#######################################################
# Logistic regression
#######################################################

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` applied elementwise."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator


def predict(params, data):
    """Return the logistic-regression probabilities for *data*.

    Args:
        params: mapping with the weights under ``'W'`` and the bias under ``'b'``.
        data: feature array that is multiplied with the weights via ``jnp.dot``.

    Returns:
        ``sigmoid(data . W + b)``.
    """
    weights = params['W']
    bias = params['b']
    return sigmoid(jnp.dot(data, weights) + bias)


def create_logistic_regression_loss(reg, dataset):
    """Build the L2-regularised logistic-regression loss for one dataset.

    Args:
        reg: weight of the ``reg / 2 * ||params||^2`` penalty. The penalty
            covers every leaf of ``params``, i.e. the bias as well.
        dataset: tuple ``(data, targets)``; targets are expected in [0, 1].

    Returns:
        A function ``loss_func(params)`` mapping ``{'W': ..., 'b': ...}`` to
        the scalar mean binary cross-entropy plus the penalty.
    """
    data, targets = dataset

    def loss_func(params):
        logits = jnp.dot(data, params['W']) + params['b']
        # Evaluate log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z))
        # directly from the logits. Going through sigmoid first saturates to
        # exactly 0 or 1 for large |z|, which turns the loss into inf and its
        # gradient into NaN.
        log_prob_pos = jax.nn.log_sigmoid(logits)
        log_prob_neg = jax.nn.log_sigmoid(-logits)
        loss = -jnp.mean(targets * log_prob_pos + (1 - targets) * log_prob_neg)
        leaves, _ = tree_flatten(params)
        loss += reg * jnp.sum(jnp.array([jnp.vdot(p, p) for p in leaves])) / 2.0
        return loss

    return loss_func


def compute_lipschitz(reg, dataset):
    r"""Return an upper bound on the gradient Lipschitz constant of the loss.

    Uses that the Hessian of $\ell(x)=\log(1+\exp{w^\top x})$ is
    $$\nabla^2 \ell(x) = w w^\top \frac{\exp{w^\top x}}{(1+\exp{w^\top x})^2}$$
    and $\frac{\exp{t}}{(1+\exp{t})^2} \leq 1/4$ for any $t$, which gives
    ``0.25 * ||data||_2^2 / n + reg`` with the spectral norm of the data.

    The labels do not enter the bound: flipping the sign of individual rows
    of ``data`` leaves its singular values unchanged.

    NOTE(review): only the feature matrix is used, so the bound covers the
    weights ``W`` but not the bias term that ``predict`` adds -- confirm this
    is intended.

    Args:
        reg: L2 regularisation weight, added to the bound.
        dataset: tuple ``(data, targets)``; only ``data`` (a 2-D array with one
            sample per row) is used.

    Returns:
        The bound as a Python float.
    """
    data, _ = dataset
    spectral_norm = jnp.linalg.norm(data, ord=2)
    return (0.25 * spectral_norm**2 / len(data)).item() + reg


def create_losses(args, datasets):
    """Create the per-client losses, their Lipschitz bounds and the total loss.

    Args:
        args: configuration object; only ``args.reg`` is read.
        datasets: iterable of ``(data, targets)`` tuples, one per client.

    Returns:
        ``(client_losses, lips_constants, total_loss)`` where ``total_loss(x)``
        is the mean of all client losses evaluated at ``x``.
    """
    per_client = [
        (create_logistic_regression_loss(args.reg, trainset),
         compute_lipschitz(args.reg, trainset))
        for trainset in datasets
    ]
    client_losses = [loss for loss, _ in per_client]
    lips_constants = [lips for _, lips in per_client]

    def total_loss(x):
        values = [client_loss(x) for client_loss in client_losses]
        return jnp.mean(jnp.array(values))

    return client_losses, lips_constants, total_loss


def create_libsvm_losses(args):
    """Download a LIBSVM dataset, split it across clients and build the losses.

    The file is fetched once into ``./<dataset_name>`` and reused afterwards.
    Samples that do not divide evenly among the clients are dropped, labels
    are mapped to {0, 1}, and the samples are sorted by label before being cut
    into equally sized contiguous shards (a non-iid split).

    Args:
        args: configuration; reads ``dataset_name``, ``num_clients`` and
            (through ``create_losses``) ``reg``.

    Returns:
        The ``(client_losses, lips_constants, total_loss)`` tuple produced by
        ``create_losses``.

    Raises:
        KeyError: if ``args.dataset_name`` has no known download URL.
    """
    URLs = {
        "w8a": "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/w8a",
        "a9a": "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a",
    }

    # Download
    import sklearn.datasets
    import urllib.request
    url = URLs[args.dataset_name]
    data_path = f'./{args.dataset_name}'
    if not os.path.exists(data_path):
        print("Downloading LIBSVM dataset...")
        urllib.request.urlretrieve(url, data_path)
    A, b = sklearn.datasets.load_svmlight_file(data_path)

    # Preprocessing: drop the remainder so that every client receives the same
    # number of samples. `n` must describe the truncated data, because the
    # shard boundaries below are derived from it.
    n = A.shape[0]
    remainder = n % args.num_clients
    if remainder != 0:
        n -= remainder
        A = A[:n]
        b = b[:n]

    # `np.array_equal` also copes with label sets that do not have exactly
    # two entries, where an elementwise `==` would not yield a boolean array.
    b_unique = np.unique(b)
    if np.array_equal(b_unique, [1, 2]):
        # Transform labels {1, 2} to {0, 1}
        b = b - 1
    elif np.array_equal(b_unique, [-1, 1]):
        # Transform labels {-1, 1} to {0, 1}
        b = (b + 1) / 2
    else:
        # replace class labels with 0's and 1's
        b = 1. * (b == b[0])

    # non iid: sort by label so that each client sees a contiguous label range
    permutation = b.squeeze().argsort()

    shard_size = n // args.num_clients
    datasets = []
    for i in range(args.num_clients):
        idx_i = permutation[i * shard_size:(i + 1) * shard_size]
        trainset_i = (
            # `.toarray()` instead of the `.A` shortcut, which recent SciPy
            # releases no longer provide on sparse matrices.
            jnp.float64(jnp.array(A[idx_i].toarray())),
            jnp.float64(jnp.array(b[idx_i])),
        )
        datasets.append(trainset_i)

    return create_losses(args, datasets)


#######################################################
# Client & Server
#######################################################

def make_grad(L):
    """Return a jit-compiled gradient of *L* with respect to its first argument."""
    gradient_fn = grad(L, argnums=0)
    return jit(gradient_fn)


def norm_squared(x):
    """Return the squared Euclidean norm of a pytree, summed over all leaves."""
    leaves = tree_flatten(x)[0]
    per_leaf = []
    for leaf in leaves:
        per_leaf.append(jnp.vdot(leaf, leaf))
    return jnp.sum(jnp.array(per_leaf))


def add(*args):
    """Return the sum of all positional arguments (``0`` when called with none)."""
    total = 0
    for term in args:
        total = total + term
    return total


def tree_sum(x):
    """Return the sum of all leaves of the pytree *x* (``0`` for an empty tree)."""
    leaves, _ = tree_flatten(x)
    total = 0
    for leaf in leaves:
        total = total + leaf
    return total


class Client:
    """Base class of all federated clients.

    Holds the configuration, the local loss, its jit-compiled gradient and an
    optional Lipschitz constant. Subclasses override ``init`` and ``step``.
    """

    def __init__(self, args, loss, lips_const=None):
        """Store the configuration and loss and build the gradient function."""
        self.args = args
        self.loss = loss
        self.grad = make_grad(loss)
        self.lips_const = lips_const

    def init(self, x):
        """Return the initial client state; the base client keeps none."""
        return dict()

    def step(self, server_comm, client_state):
        """Return the server message and the client state unchanged."""
        return (server_comm, client_state)


class Server:
    """Base class of all federated servers.

    Holds the configuration, the total loss and the list of clients.
    Subclasses override ``init`` and ``step``.
    """

    def __init__(self, args, total_loss, clients):
        """Store the configuration, the total loss and the clients."""
        self.args = args
        self.total_loss = total_loss
        self.clients = clients

    def init(self, x):
        """Return the initial server state; the base server keeps none."""
        return dict()

    def step(self, t, server_state, client_states):
        """Run one skeleton round that only updates the client states.

        Every client receives ``{'x': server_state['x']}``; the messages the
        clients send back are discarded and the server state is returned
        unchanged.
        """
        updated_states = [None] * len(client_states)
        broadcast = {'x': server_state['x']}
        for idx, client in enumerate(self.clients):
            _, updated_states[idx] = client.step(broadcast, client_states[idx])
        return server_state, updated_states


class FedAvgClient(Client):
    """FedAvg client: plain gradient descent on the local loss."""

    def init(self, x):
        """Return the initial client state; FedAvg clients are stateless."""
        return {}

    def stepsize(self):
        """Return the step size of the local gradient steps.

        With ``args.use_lips`` the configured ``client_lr`` is divided by the
        client's Lipschitz constant, otherwise ``client_lr`` is used as is.

        Raises:
            ValueError: if ``use_lips`` is set but no Lipschitz constant was
                passed to the constructor.
        """
        if not self.args.use_lips:
            return self.args.client_lr
        # Raised explicitly rather than asserted so the check survives `-O`.
        if self.lips_const is None:
            raise ValueError("`use_lips` requires a Lipschitz constant for the client.")
        return self.args.client_lr / self.lips_const

    def step(self, server_comm, client_state):
        """Run ``client_iters`` gradient steps starting from the server model.

        Args:
            server_comm: message from the server; ``server_comm['x']`` is the
                starting point.
            client_state: returned unchanged.

        Returns:
            ``({'x': local_model}, client_state)``.
        """
        x = server_comm['x']

        stepsize = self.stepsize()
        for _ in range(self.args.client_iters):
            # `jax.tree_util.tree_map` is the long-lived spelling; the
            # top-level `jax.tree_map` alias was deprecated and later removed.
            x = jax.tree_util.tree_map(lambda p, g: p - stepsize * g, x, self.grad(x))

        client_comm = {'x': x}
        return client_comm, client_state


class FedAvgServer(Server):
    """FedAvg server: moves the model towards the average of the client models."""

    def init(self, x):
        """Return the initial server state holding the model ``x``."""
        return {'x': x}

    def step(self, t, server_state, client_states):
        """Run one communication round.

        Every client starts from the current model; the new model is
        ``x - server_lr * (x - mean(client models))``.

        Args:
            t: round index (unused here).
            server_state: mapping with the current model under ``'x'``.
            client_states: list with one state per client.

        Returns:
            ``(new_server_state, new_client_states)``.
        """
        # `jax.tree_util.tree_map` is the long-lived spelling; the top-level
        # `jax.tree_map` alias was deprecated and later removed.
        tree_map = jax.tree_util.tree_map

        new_client_states = [None for _ in client_states]
        x = server_state['x']
        server_comm = {'x': x}
        client_sum = tree_map(jnp.zeros_like, x)

        for i, client in enumerate(self.clients):
            client_comm, new_client_states[i] = client.step(server_comm, client_states[i])
            client_sum = tree_map(lambda a, b: a + b, client_sum, client_comm['x'])

        num_clients = len(self.clients)
        server_lr = self.args.server_lr
        new_x = tree_map(
            lambda old, total: old - server_lr * (old - total / num_clients),
            x, client_sum)
        new_server_state = {'x': new_x}
        return new_server_state, new_client_states


class ProxBaseClient(Client):
    """Shared step-size rule for clients that solve a proximal subproblem."""

    def stepsize(self):
        """Return the step size of the inner gradient steps.

        With ``args.use_lips`` the configured ``client_lr`` is divided by
        ``1 + client_prox_reg * lips_const``, otherwise ``client_lr`` is used
        as is.

        Raises:
            ValueError: if ``use_lips`` is set but no Lipschitz constant was
                passed to the constructor.
        """
        if not self.args.use_lips:
            return self.args.client_lr
        # Raised explicitly rather than asserted so the check survives `-O`.
        if self.lips_const is None:
            raise ValueError("`use_lips` requires a Lipschitz constant for the client.")
        # uses the Lipschitz constant of $\nabla f_i^\gamma$ which is $1+\gamma L_i$.
        return self.args.client_lr / (1 + self.args.client_prox_reg * self.lips_const)


class FedProxClient(ProxBaseClient):
    """FedProx client: gradient descent on the proximally regularised local loss."""

    def init(self, x):
        """Return the initial client state; FedProx clients are stateless."""
        return {}

    def step(self, server_comm, client_state):
        """Approximately solve the proximal subproblem anchored at the server model.

        Runs ``client_iters`` steps with direction
        ``client_prox_reg * grad(x) + (x - x_anchor)``, starting from the
        server model ``x_anchor``.

        Args:
            server_comm: message from the server; ``server_comm['x']`` is the anchor.
            client_state: returned unchanged.

        Returns:
            ``({'x': local_model}, client_state)``.
        """
        x_anchor = server_comm['x']
        x = x_anchor

        lr = self.stepsize()
        gamma = self.args.client_prox_reg
        for _ in range(self.args.client_iters):
            # `jax.tree_util.tree_map` is the long-lived spelling; the
            # top-level `jax.tree_map` alias was deprecated and later removed.
            x = jax.tree_util.tree_map(
                lambda p, g, anchor: p - lr * (gamma * g + (p - anchor)),
                x, self.grad(x), x_anchor)

        client_comm = {'x': x}
        return client_comm, client_state


class FedProxServer(FedAvgServer):
    """FedProx server; aggregates exactly like ``FedAvgServer``."""


class FedDRClient(ProxBaseClient):
    """FedDR client: keeps the variables ``s`` and ``xbar`` between rounds."""

    def init(self, x):
        """Return the initial client state with ``s`` and ``xbar`` both set to ``x``."""
        return {'s': x, 'xbar': x}

    def step(self, server_comm, client_state):
        """Run one client round.

        1. ``s <- s + server_lr * (phat - xbar)`` with the server model ``phat``.
        2. Starting from ``s``, run ``client_iters`` steps with direction
           ``client_prox_reg * grad(xbar) + (xbar - s)``.
        3. Send the reflection ``p = 2 * xbar - s`` to the server.

        Args:
            server_comm: message from the server; ``server_comm['x']`` is ``phat``.
            client_state: mapping with the previous ``'s'`` and ``'xbar'``.

        Returns:
            ``({'x': p}, {'s': s, 'xbar': xbar})``.
        """
        # `jax.tree_util.tree_map` is the long-lived spelling; the top-level
        # `jax.tree_map` alias was deprecated and later removed.
        tree_map = jax.tree_util.tree_map

        phat = server_comm['x']  # server_comm.x refers to `phat`
        server_lr = self.args.server_lr
        gamma = self.args.client_prox_reg

        s = tree_map(
            lambda s_old, ph, xb_old: s_old + server_lr * (ph - xb_old),
            client_state['s'], phat, client_state['xbar'])
        xbar = s

        lr = self.stepsize()
        for _ in range(self.args.client_iters):
            xbar = tree_map(
                lambda xb, g, anchor: xb - lr * (gamma * g + (xb - anchor)),
                xbar, self.grad(xbar), s)

        p = tree_map(lambda xb, anchor: 2 * xb - anchor, xbar, s)

        client_comm = {'x': p}  # client_comm.x refers to `p`
        new_client_state = {'s': s, 'xbar': xbar}
        return client_comm, new_client_state


class FedDRServer(FedAvgServer):
    """FedDR server; aggregates the client messages exactly like ``FedAvgServer``.

    NOTE(review): ``FedDRClient.step`` already scales its ``s`` update by
    ``args.server_lr`` and the inherited server step uses ``server_lr`` again.
    With the default ``server_lr = 1.0`` the server step is a plain average;
    confirm the double use is intended for other values.
    """


class InexactFedDRClient(ProxBaseClient):
    """Client of the inexact FedDR method with server-driven refinement."""

    def init(self, x):
        """Return the initial client state with ``s`` and ``xbar`` both set to ``x``."""
        return {'s': x, 'xbar': x}

    def step(self, server_comm, client_state):
        """Run one client round, or refine the previous one.

        Unless ``server_comm['refine']`` is set, ``s`` is first moved by
        ``-server_lr * alpha * (xbar - phat)`` and ``xbar`` is restarted from
        the new ``s``. In both cases ``client_iters`` steps with direction
        ``client_prox_reg * grad(xbar) + (xbar - s)`` follow, so a refinement
        call continues the inner iteration from the stored ``xbar``.

        Args:
            server_comm: mapping with ``'phat'``, ``'alpha'`` and ``'refine'``.
            client_state: mapping with the previous ``'s'`` and ``'xbar'``.

        Returns:
            ``(client_comm, new_client_state)`` where ``client_comm`` carries
            ``'s'``, ``'xbar'`` and the local gradient at ``xbar`` under
            ``'xbar_grad'``.
        """
        # `jax.tree_util.tree_map` is the long-lived spelling; the top-level
        # `jax.tree_map` alias was deprecated and later removed.
        tree_map = jax.tree_util.tree_map

        phat = server_comm['phat']
        alpha = server_comm['alpha']
        s = client_state['s']
        xbar = client_state['xbar']
        server_lr = self.args.server_lr
        gamma = self.args.client_prox_reg

        # Do not update if refining
        if not server_comm['refine']:
            s = tree_map(
                lambda s_old, ph, xb: s_old - server_lr * alpha * (xb - ph),
                s, phat, xbar)
            xbar = s

        lr = self.stepsize()
        for _ in range(self.args.client_iters):
            xbar = tree_map(
                lambda xb, g, anchor: xb - lr * (gamma * g + (xb - anchor)),
                xbar, self.grad(xbar), s)

        client_comm = {'s': s, 'xbar': xbar, 'xbar_grad': self.grad(xbar)}
        new_client_state = {'s': s, 'xbar': xbar}
        return client_comm, new_client_state


class InexactFedDRServer(Server):
    """Server of the inexact FedDR method.

    Each round the clients are stepped, possibly several times ("refinement"),
    until the accumulated inexactness of their subproblem solutions satisfies
    a relative error condition. The server then averages the client results
    and computes the step size ``alpha`` for the next round.
    """

    def init(self, x):
        """Return the initial server state: model ``x`` and step size ``alpha = 0``."""
        return {'x': x, 'alpha': 0.0}

    def step(self, t, server_state, client_states):
        """Run one server round, including any client refinements.

        With ``gamma = client_prox_reg`` and, per client ``i``, the message
        ``(s_i, xbar_i, g_i)`` and ``v_i = s_i - gamma * g_i``, every pass
        computes::

            phat    = mean_i(xbar_i - gamma * g_i)
            epsilon = sum_i ||v_i - xbar_i||^2
            xi      = sum_i ||xbar_i - phat||^2
            zeta    = sum_i ||v_i - phat||^2 / gamma^2
            mu      = sum_i <xbar_i - phat, v_i - phat>
            alpha   = mu / xi

        (``xi``, ``zeta`` and ``mu`` are evaluated in expanded form from running
        sums.) The pass is accepted when ``epsilon <= 0.99 * max(xi, zeta)`` or
        when ``args.no_refinement`` is set; otherwise the clients are stepped
        again with ``refine=True``.

        Args:
            t: round index, used as the mlflow step.
            server_state: mapping with the model under ``'x'`` and ``'alpha'``.
            client_states: list with one state per client.

        Returns:
            ``({'x': phat, 'alpha': alpha}, new_client_states)``.

        Raises:
            RuntimeError: if the error condition is still violated after more
                than 100 passes over the clients.
        """
        # `jax.tree_util.tree_map` is the long-lived spelling; the top-level
        # `jax.tree_map` alias was deprecated and later removed.
        tree_map = jax.tree_util.tree_map

        new_client_states = list(client_states)
        server_comm = {'phat': server_state['x'], 'alpha': server_state['alpha'], 'refine': False}

        gamma = self.args.client_prox_reg
        num_clients = len(self.clients)
        sigma_sq = 0.99

        num_client_refinements = 0
        # The loop is left only by accepting the pass (return) or by giving up (raise).
        while True:
            xbar_norm_sq_sum = 0.0
            innerprod_sum = 0.0
            sgammagrad_norm_sq_sum = 0.0
            epsilon_sum = 0.0
            zeroes = tree_map(jnp.zeros_like, server_state['x'])
            phat = zeroes
            xbar_sum = zeroes
            sgammagrad_sum = zeroes

            for i, client in enumerate(self.clients):
                client_comm, new_client_states[i] = client.step(server_comm, new_client_states[i])
                xbar_i = client_comm['xbar']
                xbar_grad_i = client_comm['xbar_grad']
                s_i = client_comm['s']

                sgammagrad = tree_map(lambda s, g: s - gamma * g, s_i, xbar_grad_i)
                delta = tree_map(lambda s, g, xb: s - gamma * g - xb,
                                 s_i, xbar_grad_i, xbar_i)

                xbar_sum = tree_map(add, xbar_sum, xbar_i)
                xbar_norm_sq_sum += norm_squared(xbar_i)
                sgammagrad_norm_sq_sum += norm_squared(sgammagrad)
                sgammagrad_sum = tree_map(add, sgammagrad_sum, sgammagrad)
                # `vdot` flattens its operands, so each leaf contributes a scalar
                # inner product whatever its shape (matching `norm_squared`).
                innerprod_sum -= tree_sum(tree_map(jnp.vdot, xbar_i, sgammagrad))
                epsilon_sum += norm_squared(delta)
                phat = tree_map(lambda acc, xb, g: acc + xb - gamma * g,
                                phat, xbar_i, xbar_grad_i)

            phat = tree_map(lambda leaf: leaf / num_clients, phat)
            phat_norm_sq = norm_squared(phat)
            innerprod1 = tree_sum(tree_map(jnp.vdot, xbar_sum, phat))
            innerprod2 = tree_sum(tree_map(jnp.vdot, phat, sgammagrad_sum))
            xi = xbar_norm_sq_sum - 2 * innerprod1 + num_clients * phat_norm_sq
            mu = -innerprod2 - innerprod_sum - innerprod1 + num_clients * phat_norm_sq
            zeta = sgammagrad_norm_sq_sum - 2 * innerprod2 + num_clients * phat_norm_sq
            zeta *= 1/(gamma**2)

            alpha = mu/xi

            # Check if error condition is met
            err_condition_met = epsilon_sum <= sigma_sq * max(xi, zeta)

            if self.args.no_refinement:
                err_condition_met = True

            # update statistics
            num_client_refinements += 1

            status = (f"refines {num_client_refinements}: sigma^2={sigma_sq},"
                      f"eps={epsilon_sum},xi={xi},zeta={zeta},alpha={alpha}")

            # Continue refinement if condition is not met
            if not err_condition_met:
                server_comm['refine'] = True
                if num_client_refinements > 100:
                    print(status)
                    raise RuntimeError("The number of refinements exceeded the maximum allowance of 100.")
                continue

            print(status)
            # For debugging
            mlflow.log_metrics({
                'innerprod2': innerprod2.item(),
                'innerprod_sum': innerprod_sum.item(),
                'innerprod1': innerprod1.item(),
                'xbar_norm_sq_sum': xbar_norm_sq_sum.item(),
                'sgammagrad_norm_sq_sum': sgammagrad_norm_sq_sum.item(),
                'epsilon_sum': epsilon_sum.item(),
                'mu': mu.item(),
                'zeta': zeta.item(),
                'xi': xi.item(),
                'alpha': alpha.item(),
                'num_client_refinements': num_client_refinements
            }, t)

            # Update the server model and stepsize
            new_server_state = {'x': phat, 'alpha': alpha}
            return new_server_state, new_client_states


#######################################################
# Helper functions
#######################################################

def init_history(keys=(), length=0):
    """Create one NaN-filled series of *length* entries for every key.

    NaN is used as the fill value so that entries that were never written
    (e.g. after an early termination) can be told apart from logged values.

    Args:
        keys: iterable of series names.
        length: number of entries per series.

    Returns:
        Dict mapping every key to a 1-D ``float64`` array of NaNs.
    """
    return {key: jnp.full((length,), jnp.nan, dtype="float64") for key in keys}

def update_history(history, t, values: dict):
    """Write ``values[k]`` at position *t* of every series ``history[k]``.

    Keys of *values* that have no series in *history* are ignored. The
    *history* dict is updated in place and also returned.
    """
    for name in values:
        if name not in history:
            continue
        series = history[name]
        history[name] = series.at[t].set(values[name])
    return history


#######################################################
# Main
#######################################################
def _select_method(method):
    """Return the ``(client_class, server_class)`` pair for *method*."""
    methods = {
        "FedAvg": (FedAvgClient, FedAvgServer),
        "FedProx": (FedProxClient, FedProxServer),
        "FedDR": (FedDRClient, FedDRServer),
        "iFedDR": (InexactFedDRClient, InexactFedDRServer),
    }
    if method not in methods:
        raise ValueError("Not valid `args.method`")
    return methods[method]


def _save_history(args, history, num_entries):
    """Plot (log-log) and save the first *num_entries* values of every series."""
    for k, series in history.items():
        # Truncate the history if terminated early
        truncated_history = series[:num_entries]

        if len(truncated_history):
            fig, ax = plt.subplots(1, 1)
            ax.plot(truncated_history)
            ax.set_yscale('log')
            ax.set_xscale('log')
            ax.set_title(k)
            fig.tight_layout()
            fig.savefig(args.path(f"{k}.png"))
            plt.close(fig)
            jnp.save(args.path(f"{k}.npy"), truncated_history)


def _run_experiment(args):
    """Build clients and server, run the server rounds and save the history."""
    mlflow.log_params(vars(args))

    create_client, create_server = _select_method(args.method)

    client_losses, lips_constants, total_loss = create_libsvm_losses(args)

    if args.centralized:
        clients = [create_client(args, total_loss, jnp.mean(jnp.array(lips_constants)))]
    else:
        clients = [create_client(args, loss_i, lips_i)
                   for loss_i, lips_i in zip(client_losses, lips_constants)]
    server = create_server(args, total_loss, clients)
    print("Lipschitz constants: ", lips_constants)

    total_grad = make_grad(total_loss)

    # Initialize
    # NOTE(review): `args.d` has to equal the number of features of the loaded
    # dataset; it is not derived from the data -- confirm for datasets other
    # than the default.
    globalkey = jax.random.PRNGKey(args.seed)
    globalkey, subkey = jax.random.split(globalkey)
    W = jax.random.uniform(subkey, (args.d,))/args.d
    b = 0.0
    x = {'W': W, 'b': b}
    client_states = [client.init(x) for client in clients]
    server_state = server.init(x)

    # Logging
    history = init_history(keys=[
        'total_grad_squared',
        'total_loss',
    ], length=args.server_iters)

    def loop_body(t, state):
        history, server_state, client_states = state

        # Log prior to step to ensure first iterate is logged
        entry = {
            'total_grad_squared': norm_squared(total_grad(server_state['x'])).item(),
            'total_loss': total_loss(server_state['x']).item(),
            'step': t,
        }
        # Stop before logging once the iterates have diverged; the caller
        # catches this and keeps the entries logged so far.
        if np.isnan(entry['total_loss']) or np.isnan(entry['total_grad_squared']):
            raise NaNException(f"NaN encountered at iteration {t}")
        history = update_history(history, t, entry)
        mlflow.log_metrics(entry, step=t)

        server_state, client_states = server.step(t, server_state, client_states)

        if t % 10 == 0:
            print(f"Iterations {t}: total loss of {entry['total_loss']}")

        return (history, server_state, client_states)

    print("Starting iterations")

    state = (history, server_state, client_states)
    # Number of iterations that completed. Tracked explicitly so that the last
    # iteration of a full run is kept and a run with zero iterations works.
    num_completed = 0
    for t in range(args.server_iters):
        try:
            state = loop_body(t, state)
        except NaNException:
            print(f"Founds NaNs: terminating prematurely at iteration {t}")
            break
        num_completed = t + 1
    (history, server_state, client_states) = state

    print("Finished iterations")
    print("Final total loss:", total_loss(server_state['x']))

    _save_history(args, history, num_completed)


def main(args):
    """Run one experiment described by *args* inside an mlflow run.

    Metrics and parameters go to the local ``sqlite:///mlflow.db`` store;
    plots and raw histories are written to the run's output directory.

    Raises:
        ValueError: if ``args.method`` is not a known method name.
    """
    mlflow.set_tracking_uri("sqlite:///mlflow.db")
    # The context manager ends the run even when the experiment raises.
    with mlflow.start_run(run_name=args.name):
        _run_experiment(args)


if __name__ == '__main__':
    # `setup()` returns the configured instance. It must run exactly once:
    # every call derives a new timestamped output directory and writes
    # `args.yaml` into it.
    args = parse(Args).setup()
    main(args)
