import argparse
import json
import os
from functools import partial
from pathlib import Path

import jax
import numpy as onp
import pyt
from datasets import get_dataset
from jax import jacrev, jit
from jax import numpy as np
from NN_helpers import NTK, MLP_init
from NN_helpers import MLP_predict as orig_predict

# Scale arguments passed positionally to MLP_init when the initial parameters
# are built. They are module-level defaults so that they are easy to override
# while debugging.
# NOTE(review): presumably the weight and bias initialisation scales --
# confirm against NN_helpers.MLP_init.
W_SCALE = 1
B_SCALE = 0.01

# Command-line interface, declared as (flags, add_argument keyword options)
# pairs in the order they should be registered (and shown in --help).
_CLI_OPTIONS = (
    (("--dataset",), {"type": str, "help": "Dataset to use", "required": True}),
    (
        ("--width",),
        {"type": int, "help": "Layer width for hidden layers", "required": True},
    ),
    (("--layers",), {"type": int, "help": "Number of layers", "required": True}),
    (("--learning-rate",), {"type": float, "help": "Learning rate"}),
    (
        ("--prefix",),
        {"type": str, "default": "", "help": "The prefix used when writing files"},
    ),
    (("--opt",), {"type": str, "help": "Optimisation type"}),
    (("-k",), {"type": int, "help": "K parameter for periodic lin optimisation"}),
    (
        ("--lambda",),
        {
            "type": float,
            "help": "Regularisation for ILS",
            # "lambda" is a Python keyword, so the value is stored as args.lambda_.
            "dest": "lambda_",
            "metavar": "LAMBDA",
        },
    ),
    (
        ("--kernalised",),
        {"action": "store_true", "help": "Use kernalised closed form with ILS"},
    ),
    (
        ("--noise",),
        {"type": float, "default": 0, "help": "Noise variance to add to data"},
    ),
    (
        ("--max-epochs",),
        {"type": int, "default": 500000, "help": "Max cap on epochs per section"},
    ),
    (("--seed",), {"type": int, "default": 1, "help": "PRNG seed"}),
)

parser = argparse.ArgumentParser(description="Generate functionspace graphs")
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)
args = parser.parse_args()

# --lambda is formatted into the output directory name below and multiplied
# into the ILS system matrix, so a missing value would otherwise surface as an
# opaque TypeError on None. Fail early with a usage message instead.
if args.opt == "ils" and args.lambda_ is None:
    parser.error("--lambda is required when --opt is 'ils'")

print("Lambda:", args.lambda_, "seed", args.seed)

# Everything written to details.json at the end of the run starts out as a
# copy of the parsed command-line arguments.
details = dict(vars(args))

W_init = jax.nn.initializers.lecun_normal()


def MLP_predict(params, x):
    """Forward pass used throughout the script; delegates to NN_helpers.MLP_predict."""
    return orig_predict(params, x)


# Pieces of the run name. Each piece is empty when its option is not in use.
# Exponents are shortened from e.g. "1e-03" to "1e-3".
lr = f"-LR{args.learning_rate:.0e}".replace("e-0", "e-") if args.learning_rate else ""
kern = "-kern" if args.kernalised else ""
opt = f"-{args.opt}" if args.opt else ""
lmbda = ""
if args.opt == "ils":
    lmbda = f"{args.lambda_:.0e}".replace("e-0", "e-")
    # The one-digit form is only kept when it round-trips to the exact value;
    # otherwise fall back to the full exponent notation.
    if float(lmbda) != args.lambda_:
        lmbda = f"{args.lambda_:e}"

OUTPUT_DIR = os.path.join(
    args.prefix,
    "outputs",
    f"{args.dataset}-{args.layers}x{args.width}-{lr}{opt}{kern}{lmbda}",
)

Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)


# Load the dataset. The data key is fixed at 0, so --seed only affects the
# parameter initialisation further down.
data, targets, test = get_dataset(
    args.dataset,
    noise=args.noise,
    random_key=jax.random.PRNGKey(0),
)

# Snapshot grid: 100 evenly spaced points reaching 3 units beyond the data
# range on both sides. The bounds are stored with the run details.
xmin, xmax = data.min() - 3, data.max() + 3
details.update(xmin=xmin.tolist(), xmax=xmax.tolist())
input_space = np.linspace(xmin, xmax, 100)

N = len(data)
# NOTE(review): presumably the layer sizes (1 input, `layers` hidden layers of
# `width` units, 1 output) -- confirm against NN_helpers.MLP_init.
SPEC = [1, *([args.width] * args.layers), 1]
opt_key, key = jax.random.split(jax.random.PRNGKey(args.seed), num=2)
init_params = MLP_init(SPEC, key, W_SCALE, B_SCALE, W_init=W_init)

class Snapshots:
    """Histories of model outputs, losses and test errors gathered during training.

    Every history attribute (``outputs``, ``lin_outputs``, ``losses``,
    ``test_mse``) is a list of series, and each series is a list with one
    entry per logged step, i.e. ``history[series][step]``. Series 0 always
    exists; further series are created the first time they are logged to.
    ``outputs`` and ``lin_outputs`` series start with the prediction of the
    initial parameters on ``input_space``, so they hold one more entry than
    ``losses``.
    """

    def __init__(self, fn, input_space, init_params):
        """Store the model and record its initial prediction.

        Args:
            fn: callable ``fn(params, x)`` returning the model output.
            input_space: inputs on which outputs are recorded at every step.
            init_params: parameters used for the initial snapshot.
        """
        self.fn = fn
        self.input_space = input_space
        self.init_params = init_params
        self._initial_output = onp.array(fn(init_params, input_space))
        self.outputs = [[self._initial_output.copy()]]
        self.lin_outputs = [[self._initial_output.copy()]]
        self.losses = [[]]
        self.test_mse = [[]]
        # Jitted reverse-mode Jacobian of fn; not used by this class itself.
        self.jac = jit(jacrev(fn))

    def _ensure_series(self, i):
        """Create empty series up to and including index ``i`` if they are missing."""
        while len(self.losses) <= i:
            self.outputs.append([self._initial_output.copy()])
            self.lin_outputs.append([self._initial_output.copy()])
            self.losses.append([])
            self.test_mse.append([])

    def log(self, i, p, loss, data, lin_fn):
        """Append one step to series ``i``.

        Args:
            i: index of the series to append to.
            p: parameters to evaluate.
            loss: loss value stored for this step.
            data: unused; kept so existing callers keep working.
            lin_fn: optional callable ``lin_fn(params, x)``; when given, its
                output on ``input_space`` is appended to ``lin_outputs``.

        The test error is only recorded when the module-level ``test`` pair
        (inputs, targets) is not None.
        """
        self._ensure_series(i)
        vals = self.fn(p, self.input_space)
        if test is not None:
            predictions = self.fn(p, test[0])
            mse = np.mean((predictions - test[1]) ** 2)
            self.test_mse[i].append(mse)
        self.losses[i].append(loss)
        self.outputs[i].append(onp.array(vals))
        if lin_fn is not None:
            self.lin_outputs[i].append(onp.array(lin_fn(p, self.input_space)))



@partial(jit, static_argnums=(0,))
def lin_at(fn, init_params, new_params, x):
    """Evaluate the first-order expansion of ``fn`` around ``init_params``.

    Returns ``fn(init_params, x)`` plus the directional derivative of ``fn``
    with respect to its parameters along ``new_params - init_params``.

    Args:
        fn: callable ``fn(params, x)``. It is a static jit argument, so it
            must be hashable.
        init_params: parameters at which ``fn`` is linearised.
        new_params: parameters at which the linear model is evaluated; must
            support subtraction with ``init_params``.
        x: inputs passed through to ``fn``.
    """

    def at_fixed_input(p):
        return fn(p, x)

    delta = new_params - init_params
    value, first_order = jax.jvp(at_fixed_input, (init_params,), (delta,))
    return value + first_order


class IterativeLeastSquares:
    """Regularised least-squares update on the linearisation of ``fn``.

    Each call linearises ``fn`` at the current parameters and solves the
    resulting regularised least-squares problem in closed form, either in
    parameter space or, when ``kernalised`` is set, in sample space.
    """

    def __init__(
        self,
        fn,
        lmbda,
        kernalised,
    ):
        """
        Args:
            fn: callable ``fn(params, x)`` returning the model output.
            lmbda: coefficient of the identity added to the system matrix.
            kernalised: solve the sample-space (kernel) system instead of the
                parameter-space normal equations.
        """
        self.fn = fn
        self.jac = jit(jacrev(fn))
        self.lmbda = lmbda
        self.kernalised = kernalised
        # Linearisation of fn at the parameters of the most recent call.
        self.fn_lin = None
        # Solution of the kernel system from the most recent kernalised call.
        self.beta = None

    def __call__(self, params, x, y, step):
        """Perform one update.

        Args:
            params: current parameters.
            x: inputs.
            y: targets, compatible in shape with ``fn(params, x)``.
            step: unused by this optimiser.

        Returns:
            ``(new_params, loss)`` where ``loss`` is half the squared residual
            at the incoming ``params``. The update is scaled by the
            module-level ``args.learning_rate`` when that is set.
        """
        self.fn_lin = lambda p, x: lin_at(self.fn, params, p, x)
        # A single forward pass gives both the predictions and the pullback
        # (vector-Jacobian product) with respect to the parameters.
        z, vjp_fun = jax.vjp(lambda p: self.fn(p, x), params)
        residual = y - z
        loss = np.vdot(residual, residual) / 2
        if self.kernalised:
            # NOTE(review): NTK comes from NN_helpers; it is assumed to return
            # a square matrix with one row per sample -- confirm.
            ntk = NTK(params, x)
            lhs = ntk + self.lmbda * np.identity(ntk.shape[0])
            alpha = np.linalg.solve(lhs, residual)
            # Pull the sample-space solution back to parameter space.
            change = vjp_fun(-alpha)[0]
            self.beta = alpha
        else:
            # The explicit Jacobian is only needed on this path.
            phi = self.jac(params, x)
            n_samples = len(x)
            # One row per sample, one column per scalar parameter.
            jacobian = np.concatenate(
                [
                    leaf.reshape(n_samples, -1)
                    for leaf in jax.tree_util.tree_flatten(phi)[0]
                ],
                axis=1,
            )
            rhs = vjp_fun(residual)[0]
            flat_rhs, tree_shape = rhs.full_flatten()
            lhs = jacobian.T @ jacobian + self.lmbda * np.identity(jacobian.shape[1])
            c, low = jax.scipy.linalg.cho_factor(lhs)
            flat_change = jax.scipy.linalg.cho_solve((c, low), flat_rhs)
            change = -pyt.Params.full_unflatten(flat_change, tree_shape)

        if args.learning_rate is None:
            w = params - change
        else:
            w = params - change * args.learning_rate
        return w, loss


class IterativeLinUpdate:
    """Gradient descent on a linearisation of ``fn`` that is refreshed every ``k`` steps."""

    def __init__(self, k, fn, key=None):
        """
        Args:
            k: the Jacobian and linearisation are recomputed whenever
                ``step % k == 0``.
            fn: callable ``fn(params, x)`` returning the model output.
            key: optional; stored but not used by this class. Accepted
                because the construction site in this file passes it.
        """
        self.k = k
        self.fn = fn
        self.key = key
        self.jac = jit(jacrev(fn))
        # Jacobian and linear model from the most recent re-linearisation.
        self.phi = None
        self.fn_lin = None

    def __call__(self, params, x, y, step):
        """Perform one gradient step on the current linear model.

        Args:
            params: current parameters.
            x: inputs.
            y: targets, compatible in shape with ``fn(params, x)``.
            step: step counter that decides when to re-linearise.

        Returns:
            ``(new_params, loss)`` where ``loss`` is half the squared residual
            of the true model ``fn`` at the incoming ``params``. The step size
            is the module-level ``args.learning_rate``.
        """
        # Also linearise on the very first call, so a start step that is not
        # a multiple of k does not hit an uninitialised Jacobian.
        if self.phi is None or step % self.k == 0:
            self.phi = self.jac(params, x)
            self.fn_lin = lambda p, x: lin_at(self.fn, params, p, x)
        diff = self.fn_lin(params, x) - y
        loss_diff = self.fn(params, x) - y
        loss = np.vdot(loss_diff, loss_diff) / 2
        # Contract the frozen Jacobian with the linear-model residual over the
        # two leading axes of every Jacobian leaf.
        grads = pyt.unary_op(
            lambda leaf: np.einsum("ij...,ij", leaf, diff), self.phi
        )
        return params - grads * args.learning_rate, loss


# Only the requested optimiser is constructed, so options that belong to the
# other one are never touched.
_UPDATE_FACTORIES = {
    "il": lambda: IterativeLinUpdate(args.k, MLP_predict),
    "ils": lambda: IterativeLeastSquares(MLP_predict, args.lambda_, args.kernalised),
}
if args.opt not in _UPDATE_FACTORIES:
    parser.error(
        f"--opt must be one of {sorted(_UPDATE_FACTORIES)}, got {args.opt!r}"
    )
if args.opt == "il":
    # IterativeLinUpdate computes `step % k` and multiplies its gradients by
    # the learning rate, so both must be supplied.
    if args.k is None or args.k < 1:
        parser.error("-k must be a positive integer when --opt is 'il'")
    if args.learning_rate is None:
        parser.error("--learning-rate is required when --opt is 'il'")
update_fn = _UPDATE_FACTORIES[args.opt]()


snapshots = Snapshots(MLP_predict, input_space, init_params)
params = init_params
num_params = params.num_params()

# Training loop. It stops as soon as one of the following holds:
# * the batch loss is missing or NaN
# * the log of the batch loss (floored at -15) is at or below LOG_LOSS_LIMIT
# * the floored log loss differs by less than 1e-10 between two checks made
#   200 steps apart
# * ITERATION_LIMIT (--max-epochs) steps have been taken
LOG_LOSS_LIMIT = -5
ITERATION_LIMIT = args.max_epochs
old_loss = 1e9
i = 0
reason = None
while i < ITERATION_LIMIT:
    max_log_loss = -15
    p, batch_loss = update_fn(params, data, targets, i)
    if batch_loss is None or np.isnan(batch_loss).any():
        # NaN compares false against everything, so the max() below would keep
        # the -15 floor and the run would be reported as "Loss low enough".
        print("Batch loss is NaN", i)
        reason = "NaN loss"
        break
    # Snapshots.log takes a series index first; a single network is trained
    # here, so everything is recorded in series 0.
    snapshots.log(0, p, batch_loss, data, update_fn.fn_lin)
    params = p
    max_log_loss = max(max_log_loss, np.log(batch_loss))
    i += 1
    if i % 5000 == 0:
        print(i, max_log_loss)
    if i % 200 == 0:
        print(i, old_loss, max_log_loss)
        if test is not None:
            print("test:", " ".join(str(x[-1].tolist()) for x in snapshots.test_mse))
        if abs(old_loss - max_log_loss) < 1e-10:
            print("UNCHANGING", i, old_loss, max_log_loss)
            reason = "Unchanging loss"
            # Force the stop check below to fire.
            max_log_loss = LOG_LOSS_LIMIT
        old_loss = max_log_loss
    if max_log_loss <= LOG_LOSS_LIMIT or i >= ITERATION_LIMIT:
        if reason is None:
            if i >= ITERATION_LIMIT:
                reason = "Max epochs"
            elif max_log_loss <= LOG_LOSS_LIMIT:
                reason = "Loss low enough"
        print(i, max_log_loss, batch_loss)
        if test is not None:
            print(
                "test:", " ".join(str(x[-1].tolist()) for x in snapshots.test_mse)
            )
        break

if reason is None:
    reason = "Unknown"


# Write the run artefacts into OUTPUT_DIR.
if test is not None:
    # Predictions of the final parameters on the test inputs (first output
    # column). Skipped when the dataset has no test split, matching the
    # `test is not None` guards used during training.
    final_output = onp.array(MLP_predict(params, test[0])[:, 0])
    onp.save(f"{OUTPUT_DIR}/final_output_{args.seed}.npy", final_output)
# The histories are stored as [series][step]; swap to [step][series] on disk.
np_outputs = onp.array(snapshots.outputs).swapaxes(0, 1)
onp.save(f"{OUTPUT_DIR}/outputs.npy", np_outputs)
# Only the initial entry is present when no linearised model was ever logged.
if len(snapshots.lin_outputs[0]) > 1:
    np_lin_outputs = onp.array(snapshots.lin_outputs).swapaxes(0, 1)
    onp.save(f"{OUTPUT_DIR}/lin_outputs.npy", np_lin_outputs)
np_test_mse = onp.array(snapshots.test_mse).swapaxes(0, 1)
onp.save(f"{OUTPUT_DIR}/test_mse.npy", np_test_mse)
details["steps"] = len(snapshots.losses[0])
details["reason"] = reason
with open(f"{OUTPUT_DIR}/details.json", "w") as f:
    json.dump(details, f)
