import argparse
import csv
import json
import os
import time
from datetime import datetime
from functools import partial
from pathlib import Path

import datasets
import jax
import jax.numpy as np
import jax.profiler
import metrics
import pyt
import utils
import wandb
from jax import jit, random, value_and_grad
from jax.example_libraries import optimizers
from models.flax_models import (VGG16, VGG19, Alexnet, BasicMLP, LeNet,
                                LeNetLarge, ResNet18, ResNet18NoNorm)

# Wall-clock reference used for the "time" column of the training log.
program_start_time = time.time()


def _parse_bool_flag(value):
    """Convert the optional string value of ``--debug`` into a bool.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised
            true/false spelling.
    """
    normalised = value.strip().lower()
    if normalised in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalised in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Periodic linearisation")
# DATASET
parser.add_argument("--dataset", type=str, help="Dataset to use")
parser.add_argument("--data-limit", type=int, help="Dataset limit to use")
parser.add_argument(
    "--regression", action="store_true", help="Use regression optimisation"
)
parser.add_argument(
    "--include-flip",
    action="store_true",
    help="Allow use of flips for data augmentation",
)
parser.add_argument(
    "--random-crop",
    action="store_true",
    help="Perform random crops to each batch",
)
parser.add_argument("--extra-dims", type=int, help="Add extra dimensions")
parser.add_argument("--extra-dim-type", type=str, help="Extra dimension type")
parser.add_argument(
    "--data-dir",
    type=Path,
    default="/homes/ag2198",
    help="Directory to store datasets, checkpoints, tmp files...",
)
# MODEL SPEC
parser.add_argument("--model", type=str, help="Model to use")
parser.add_argument(
    "--layers", type=str, help="Comma separated layer sizes for CustomMLP model"
)
# OPTIMIZER
parser.add_argument("--learning-rate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--optimizer", type=str, help="Optimizer to use")
parser.add_argument("--mass", type=float, default=0.9, help="Momentum mass")
parser.add_argument(
    "-k",
    type=int,
    default=1,
    help="Number of steps between updates of jacobian in periodic linearisation optimizer (-1 for ISM)",
)
parser.add_argument("--ism-threshhold", type=float, help="Loss threshhold for iterative softmax")
# GENERAL TRAINING
parser.add_argument("--seed", type=int, default=1, help="PRNG seed")
parser.add_argument("--batch-size", type=int, help="Batch size to use for training")
parser.add_argument("--eval-batch-size", type=int, help="Batch size to use for eval")
parser.add_argument(
    "--full-batch",
    action="store_true",
    help="Use batches computationally to simulate full batch rather than algorithmically",
)
parser.add_argument("--epochs", type=int, help="Number of epochs")
parser.add_argument("--min-loss", type=float, help="Minimum loss after which to exit")
parser.add_argument(
    "--max-acc", type=float, help="Maximum train acc after which to exit"
)
parser.add_argument(
    "--weight-decay", type=float, default=0, help="Weight decay coefficient in loss"
)
parser.add_argument(
    "--decay-dist",
    action="store_true",
    help="Decay the distance from the last linearisation rather than the distance from the origin",
)
# LOGGING / RUN MANAGEMENT
# `-d` on its own enables debug mode (const=True); an explicit value such as
# `--debug false` is parsed as a boolean instead of being stored as a
# (always truthy) non-empty string. The non-string default bypasses `type`.
parser.add_argument(
    "-d",
    "--debug",
    nargs="?",
    const=True,
    type=_parse_bool_flag,
    default=os.environ.get("TERM_PROGRAM") == "vscode",
    help="Debug version, won't log to wandb, defaults to true if running in VSCode",
)
parser.add_argument("--checkpoint-frequency", type=int, help="Checkpoint frequency")
parser.add_argument(
    "--wandb-log-frequency", type=int, default=5, help="Wandb log frequency"
)
parser.add_argument(
    "--eval-frequency", type=int, default=1, help="Train/test eval frequency"
)
parser.add_argument(
    "--log-frequency", type=int, default=5, help="How frequently to write logs to disk"
)
parser.add_argument("--continue-run", type=str, help="Run ID to continue")
parser.add_argument("--tag", type=str, help="Extra info for run")

args = parser.parse_args()

# Optional augmentation callable handed to the dataset; None disables it.
if args.random_crop:
    random_crop = datasets.random_crop
else:
    random_crop = None

# Root PRNG key; a sub-key is split off for the dataset.
key = random.PRNGKey(args.seed)
key, data_key = random.split(key)

# The dataset class is looked up by name on the `datasets` module.
dataset_factory = getattr(datasets, args.dataset)
dataset: datasets.ImageDataset = dataset_factory(
    batch_size=args.batch_size,
    data_location=args.data_dir / "data",
    include_flip=args.include_flip,
    data_aug=random_crop,
    key=data_key,
    randomise=True,
    data_limit=args.data_limit,
)
print("Num outputs: ", dataset.num_outputs)

# --layers is parsed eagerly (whatever --model is), so a malformed value
# fails here regardless of the chosen architecture.
if args.layers:
    custom_layer_sizes = [int(x) for x in args.layers.split(",")]
else:
    custom_layer_sizes = []

# Maps the --model name to a constructor taking `num_outputs`.
model_factories = {
    "mlp": partial(BasicMLP, layer_sizes=[512, 512, 512, 512, 128]),
    "CustomMLP": partial(BasicMLP, layer_sizes=custom_layer_sizes),
    "lenet": LeNet,
    "lenet-large": LeNetLarge,
    "alexnet": Alexnet,
    "resnet18BN": ResNet18,
    "resnet18NN": ResNet18NoNorm,
    "VGG16": VGG16,
    "VGG19": VGG19,
}
model = model_factories[args.model](num_outputs=dataset.num_outputs)


# Base data-fit loss, before any weight-decay term is added.
if args.regression:
    loss_base = utils.mean_squared_error
else:
    loss_base = utils.softmax_cross_entropy


@jit
def loss_fn_update(params, batch_stats, images, targets):
    """Loss of the (non-linearised) model on one batch.

    The model is applied with ``train=True`` and a mutable ``batch_stats``
    collection. When weight decay is enabled and ``--decay-dist`` is not set,
    ``pyt.normsq(params) * weight_decay`` is added to the loss.

    Returns:
        ``(loss, new_state)`` where ``new_state`` is the second value returned
        by ``model.apply`` for the mutable collections.
    """
    variables = {"params": params, "batch_stats": batch_stats}
    preds, new_state = model.apply(
        variables, images, train=True, mutable=["batch_stats"]
    )
    loss = loss_base(preds, targets)
    use_origin_decay = args.weight_decay != 0 and not args.decay_dist
    if use_origin_decay:
        loss = loss + pyt.normsq(params) * args.weight_decay
    return loss, new_state


@jit
def loss_fn(params, images, targets, lin_params, batch_stats):
    """Loss of the model linearised around ``lin_params``, evaluated at ``params``.

    Predictions come from ``lin_at_bs``. If weight decay is enabled, the
    penalty is ``pyt.normsq`` of either ``params - lin_params``
    (``--decay-dist``) or ``params`` itself, scaled by ``args.weight_decay``.
    """
    preds = lin_at_bs(model.apply, lin_params, params, images, batch_stats)
    loss = loss_base(preds, targets)
    if args.weight_decay == 0:
        return loss
    # Penalise distance from the linearisation point or from the origin.
    penalised = params - lin_params if args.decay_dist else params
    return loss + pyt.normsq(penalised) * args.weight_decay


@jit
def get_grads(params, batch_stats, x, y, opt_state):
    """Linearised loss and gradient for one batch, anchored at ``opt_state[0]``.

    ``opt_state`` is expected in the wrapped form returned by ``update_op``:
    ``(anchor_params, anchor_batch_stats, optimizer_state)``. The model is
    linearised around ``anchor_params`` and evaluated at ``params``.

    The previous body called ``loss_fn`` with four positional arguments (a
    lambda in the ``lin_params`` slot and no ``batch_stats``) and requested
    ``has_aux=True`` although ``loss_fn`` returns a scalar, so it could not
    run. It now uses the same call convention as ``linearised_grads``.
    NOTE(review): intent inferred from the surrounding code - confirm.

    Returns:
        ``(loss, grads, batch_stats)`` with ``batch_stats`` passed through
        unchanged.
    """
    if not args.regression:
        y = jax.nn.one_hot(y, dataset.num_outputs)
    anchor_params = opt_state[0]
    loss, grads = value_and_grad(loss_fn)(params, x, y, anchor_params, batch_stats)
    return loss, grads, batch_stats


@jit
def update_op(params, opt_state, grads, batch_stats):
    """Apply one optimizer update to a wrapped optimizer state.

    ``opt_state`` is a 3-tuple ``(anchor_params, anchor_batch_stats, inner)``;
    only ``inner`` is advanced, the first two entries are passed through.
    ``params`` and ``batch_stats`` are accepted but not used.

    NOTE(review): ``step`` is read from module scope inside a jitted function,
    so it is presumably fixed at trace time rather than tracking the training
    loop's counter - confirm before relying on step-dependent optimizers.

    Returns:
        ``(new_params, (anchor_params, anchor_batch_stats, new_inner))``.
    """
    anchor_params, anchor_batch_stats, inner_state = opt_state
    inner_state = opt_update(step, grads, inner_state)
    wrapped_state = (anchor_params, anchor_batch_stats, inner_state)
    return get_params(inner_state), wrapped_state


@partial(jit, static_argnums=(0,))
def lin_at_bs(fn, init_params, new_params, x, batch_stats):
    """First-order Taylor expansion of ``fn`` around ``init_params``.

    ``fn`` is called as ``fn({"params": p, "batch_stats": batch_stats}, x)``.
    The result is ``fn(init_params) + J(init_params) @ (new_params - init_params)``,
    computed with a single ``jax.jvp``. ``fn`` is a static jit argument.
    """

    def apply_with_stats(p):
        return fn({"params": p, "batch_stats": batch_stats}, x)

    displacement = new_params - init_params
    value_at_init, first_order_term = jax.jvp(
        apply_with_stats, (init_params,), (displacement,)
    )
    return value_at_init + first_order_term


@partial(jit, static_argnums=(0,))
def lin_at(fn, init_params, new_params, x):
    """First-order Taylor expansion of ``fn(p, x)`` in ``p`` around ``init_params``.

    Returns ``fn(init_params, x)`` plus the Jacobian-vector product with
    ``new_params - init_params``. ``fn`` is a static jit argument.
    """

    def at_fixed_input(p):
        return fn(p, x)

    displacement = new_params - init_params
    value_at_init, first_order_term = jax.jvp(
        at_fixed_input, (init_params,), (displacement,)
    )
    return value_at_init + first_order_term


def lin_metrics(lin_params, cur_params, mc, forward, batch_stats):
    """Evaluate the model linearised around ``lin_params`` at ``cur_params``.

    ``mc.get_metrics`` is called with ``cur_params`` and a forward function
    that routes through ``lin_at_bs``. Only the ``train_acc`` and ``test_acc``
    entries of its result are returned, renamed to ``lin_train`` and
    ``lin_test``.
    """

    def linearised_forward(p, x):
        return lin_at_bs(forward, lin_params, p, x, batch_stats)

    computed = mc.get_metrics(cur_params, linearised_forward)
    return dict(lin_train=computed["train_acc"], lin_test=computed["test_acc"])


# Primary evaluation metric, registered under the "acc" key below.
metric = metrics.RegressionMSE if args.regression else metrics.OneHotAccuracyMetric

# A loss metric is only registered when --min-loss is given; the training
# loop's early-exit check is its only consumer.
extra = {}
if args.min_loss is not None:
    extra = {"loss": metrics.SoftmaxCrossEntropyLossMetric}

metric_computer = metrics.FullMetricComputer(
    {"acc": metric, **extra},
    dataset,
    batch_size=args.batch_size
    if args.eval_batch_size is None
    else args.eval_batch_size,
)
key, params_key = random.split(key, num=2)
variables = model.init(
    {"params": params_key},
    np.ones(dataset.input_shape, np.float32),
)
params = variables["params"]
# Models without a "batch_stats" collection get an empty dict.
batch_stats = variables.get("batch_stats", {})

# Stored on `args` so they end up in the wandb config and config.json.
args.num_parameters = params.num_params()
args.network_description = model.describe()
if not args.debug:
    wandb_dir = args.data_dir / "wandb"
    if args.continue_run:
        wandb.init(
            project="periodic-linearisation", dir=wandb_dir, resume=args.continue_run
        )
    else:
        wandb.init(project="periodic-linearisation", dir=wandb_dir)
    wandb.config.update(args, allow_val_change=True)
    unique_id = wandb.run.name
else:
    unique_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

output_dir = args.data_dir / "outputs" / unique_id
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / "config.json", "w") as f:
    # Copy first: vars(args) is the namespace's own __dict__, so editing it
    # in place would silently turn args.data_dir into a str.
    d = dict(vars(args))
    d["data_dir"] = str(d["data_dir"])  # JSON doesn't like PosixPath
    json.dump(d, f)

if args.optimizer == "sgd":
    opt_init, opt_update, get_params = optimizers.sgd(args.learning_rate)
elif args.optimizer == "adam":
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
elif args.optimizer == "momentum":
    opt_init, opt_update, get_params = optimizers.momentum(
        args.learning_rate, args.mass
    )
else:
    # Previously this fell through to a NameError on `opt_init`.
    raise ValueError(
        f"Unknown optimizer {args.optimizer!r}; expected 'sgd', 'adam' or 'momentum'"
    )
opt_state = opt_init(params)
params = get_params(opt_state)
lin_params = params

step = 0
start_epoch = 0
checkpoint_dir = output_dir / "checkpoints"
if args.continue_run:
    # Checkpoints are written as
    # (epoch, step, lin_params, batch_stats, unpacked optimizer state).
    state = utils.load_latest(checkpoint_dir)
    start_epoch, step, lin_params, batch_stats, opt_state = state
    opt_state = optimizers.pack_optimizer_state(opt_state)
    # Resume from the restored weights. Previously `params` and `lin_params`
    # stayed at their fresh initial values, so the first gradient after a
    # resume was computed at the wrong point.
    # NOTE(review): assumes utils.load_latest returns lin_params in the same
    # container type as `params` - confirm. Values derived from `params`
    # later in the file (e.g. norm_factor) now see the restored weights.
    params = get_params(opt_state)


@jit
def linearised_grads(params, batch_stats, lin_params, x, y):
    """Gradient and value of the linearised loss ``loss_fn`` w.r.t. ``params``.

    For classification (``--regression`` not set) the labels ``y`` are one-hot
    encoded with ``dataset.num_outputs`` classes first.

    Returns:
        ``(grads, loss)`` - note the order is gradient first.
    """
    targets = y if args.regression else jax.nn.one_hot(y, dataset.num_outputs)
    loss_value, gradient = value_and_grad(loss_fn)(
        params, x, targets, lin_params, batch_stats
    )
    return gradient, loss_value


@jit
def standard_grads(params, batch_stats, x, y):
    """Gradient and value of the non-linearised loss ``loss_fn_update``.

    For classification (``--regression`` not set) the labels ``y`` are one-hot
    encoded with ``dataset.num_outputs`` classes first.

    Returns:
        ``(grads, loss, new_batch_stats)`` where ``new_batch_stats`` is the
        ``"batch_stats"`` entry of the auxiliary state, or ``{}`` if absent.
    """
    targets = y if args.regression else jax.nn.one_hot(y, dataset.num_outputs)
    differentiate = value_and_grad(loss_fn_update, has_aux=True)
    (loss_value, new_state), gradient = differentiate(params, batch_stats, x, targets)
    return gradient, loss_value, new_state.get("batch_stats", {})


# Startup summary: parameter count, architecture description, and the time
# spent on setup since the script started.
param_count = params.num_params()
print(param_count, "parameters")
network_summary = model.describe()
print("Network:", network_summary)
print(f"Initialisation finished after {time.time() - program_start_time:.2f} seconds")


@partial(jit, static_argnums=(4,))
def do_batch(state, x, params, lin_params, update_step):
    """Accumulate one batch's contribution to a full-batch gradient and loss.

    Args:
        state: ``(grads, loss)`` accumulated so far.
        x: ``(data, target)`` for the current batch.
        params: current parameters.
        lin_params: linearisation point, used only when ``update_step`` is false.
        update_step: static flag; true selects ``standard_grads``, false
            selects ``linearised_grads``. Both are called with an empty
            ``batch_stats`` dict.

    Returns:
        The updated ``(grads, loss)``, each batch weighted by
        ``len(data) / len(dataset.train_data)``.
    """
    inputs, labels = x
    grad_acc, loss_acc = state
    batch_size = len(inputs)
    train_size = len(dataset.train_data)
    if update_step:
        # Batch statistics returned by the standard pass are discarded here.
        batch_grads, batch_loss, _ = standard_grads(params, {}, inputs, labels)
    else:
        batch_grads, batch_loss = linearised_grads(
            params, {}, lin_params, inputs, labels
        )
    loss_acc += batch_loss * batch_size / train_size
    grad_acc += batch_grads * batch_size / train_size
    return grad_acc, loss_acc


# Simulated full-batch mode accumulates gradients over several batches and
# has no way to merge per-batch batch statistics, so reject the combination.
if args.full_batch and batch_stats:
    print(
        "Error: cannot combine batchstats across multiple batches to simulate fullbatch"
    )
    # `exit()` is an interactive helper installed by the `site` module and is
    # not guaranteed to exist; raising SystemExit has the same effect.
    raise SystemExit(1)

# Logs are only produced on evaluation epochs, so never flush more often.
if args.log_frequency < args.eval_frequency:
    args.log_frequency = args.eval_frequency
# Reference norm for the "normalised_dist" log column.
norm_factor = pyt.l2norm(params)
losses = []  # per-step losses since the last flush to disk
log = []  # per-evaluation log rows since the last flush to disk
start_time = time.time()
relinearise = False  # whether the next step re-anchors the linearisation
num_linearisations = 1
# Save the initial state as checkpoint 0 on a fresh run.
if start_epoch == 0 and args.checkpoint_frequency is not None:
    utils.save(
        checkpoint_dir,
        (0, step, lin_params, batch_stats, optimizers.unpack_optimizer_state(opt_state)),
        0,
    )
    print("saved")
# Iterative softmax (-k -1) compares the epoch loss with --ism-threshhold;
# fail before training instead of with a TypeError after the first epoch.
if args.k == -1 and args.ism_threshhold is None:
    raise ValueError("--ism-threshhold must be set when -k is -1")

# Main training loop. Each epoch either simulates one full-batch step or runs
# one optimizer step per mini-batch; `relinearise` selects, for the next
# step, between the linearised loss and a standard step that also moves the
# linearisation point `lin_params` to the new parameters.
for epoch in range(start_epoch + 1, args.epochs + 1):
    if args.full_batch:
        grads = pyt.zeros(params)
        loss = 0
        for data, target in dataset.train_loader:
            grads, loss = do_batch(
                (grads, loss),
                (data, target),
                params,
                lin_params,
                relinearise,
            )
        opt_state = opt_update(step, grads, opt_state)
        params = get_params(opt_state)
        if relinearise:
            lin_params = params
            num_linearisations += 1
        train_loss = batch_loss = loss
        step += 1
        if args.k == -1:
            # `relinearise` is a static jit argument of do_batch and must be
            # hashable, so convert the array comparison to a Python bool.
            relinearise = bool(train_loss < args.ism_threshhold)
            if relinearise:
                print(f"Train loss {train_loss}, relinearising")
        else:
            relinearise = step % args.k == 0
        losses.append(loss)
    else:
        train_loss = 0
        for data, target in dataset.train_loader:
            if relinearise:
                grads, batch_loss, new_batch_stats = standard_grads(
                    params, batch_stats, data, target
                )
                opt_state = opt_update(step, grads, opt_state)
                params = get_params(opt_state)
                lin_params = params
                batch_stats = new_batch_stats
                num_linearisations += 1
            else:
                grads, batch_loss = linearised_grads(
                    params, batch_stats, lin_params, data, target
                )
                opt_state = opt_update(step, grads, opt_state)
                params = get_params(opt_state)
            train_loss += batch_loss
            losses.append(batch_loss)
            step += 1
            if args.k == -1:
                relinearise = False
            else:
                relinearise = step % args.k == 0
        train_loss /= len(dataset.train_loader)
        if args.k == -1 and train_loss < args.ism_threshhold:
            relinearise = True
            print(f"Train loss {train_loss}, relinearising")

    stop_early = False
    if epoch % args.eval_frequency == 0:
        epoch_time = time.time() - start_time
        start_time = time.time()

        full_params = {"params": params, "batch_stats": batch_stats}
        # Named `eval_metrics` so the imported `metrics` module is not rebound.
        eval_metrics = metric_computer.get_metrics(full_params, model.apply)
        if args.k != 1:
            lmetrics = lin_metrics(
                lin_params, params, metric_computer, model.apply, batch_stats
            )
        else:
            # With k == 1 the linearised figures are reported as the model's own.
            lmetrics = {
                "lin_train": eval_metrics["train_acc"],
                "lin_test": eval_metrics["test_acc"],
            }
        dist = pyt.l2norm(lin_params - params)
        log_state = {
            "epoch": epoch,
            "train": eval_metrics["train_acc"],
            "test": eval_metrics["test_acc"],
            "time": time.time() - program_start_time,
            "lin_train": lmetrics["lin_train"],
            "lin_test": lmetrics["lin_test"],
            "update_time": epoch_time,
            "val_time": time.time() - start_time,
            "batch_loss": batch_loss,
            "loss": train_loss,
            "dist": dist,
            "normalised_dist": dist / norm_factor,
            "num_linearisations": num_linearisations,
        }
        log.append(log_state)
        if not args.debug and epoch % args.wandb_log_frequency == 0:
            wandb.log(log_state)
        print(
            "Epoch {epoch} | T {update_time:0.2f} (+ {val_time:0.2f}) | Loss {loss:0.4f} | Train {train:0.3f} | Test {test:0.3f} | Lin train {lin_train:0.3f} | Lin test {lin_test:0.3f}".format(
                **log_state
            ),
            flush=True,
        )
        start_time = time.time()

        # Early-exit criteria. When k exceeds the total number of steps the
        # model is never relinearised, so the linearised quantities are used.
        never_relinearises = args.k > args.epochs * len(dataset.train_loader)
        if args.min_loss is not None:
            if never_relinearises:
                stop_early = bool(train_loss < args.min_loss)
            else:
                stop_early = bool(eval_metrics["train_loss"] < args.min_loss)
        if not stop_early and args.max_acc is not None:
            if never_relinearises:
                stop_early = bool(lmetrics["lin_train"] >= args.max_acc)
            else:
                stop_early = bool(eval_metrics["train_acc"] >= args.max_acc)

    # Flush buffers on schedule, on the last epoch, and when stopping early
    # (previously an early exit dropped anything buffered since the last flush).
    if epoch % args.log_frequency == 0 or epoch == args.epochs or stop_early:
        with open(output_dir / "losses", "a") as f:
            f.write(" ".join(map(str, losses)) + "\n")
        losses = []

        # `log` is empty when no evaluation ran since the last flush (e.g. the
        # final epoch is not an evaluation epoch); log[0] would raise then.
        if log:
            # newline="" is what the csv module asks for on files it writes.
            with open(output_dir / "log", "a", newline="") as f:
                writer = csv.DictWriter(f, log[0].keys())
                if f.tell() == 0:
                    writer.writeheader()
                writer.writerows(log)
        log = []
    # The final state is always checkpointed, including on an early stop.
    if (
        (
            args.checkpoint_frequency is not None
            and epoch % args.checkpoint_frequency == 0
        )
        or epoch == args.epochs
        or stop_early
    ):
        utils.save(
            checkpoint_dir,
            (
                epoch,
                step,
                lin_params,
                batch_stats,
                optimizers.unpack_optimizer_state(opt_state),
            ),
            epoch,
        )
    if stop_early:
        break
