import torch as tc

from torch.optim import Adam, AdamW, RAdam, NAdam, Adamax, SparseAdam, SGD, Adadelta, Adagrad, Adafactor, ASGD, RMSprop, LBFGS, Rprop

from src.utils import grad, linear_lr_lambda, exponential_lr_lambda, get_batch, TradTrainingArgs, l2_err_TC


def step(model, x, L, dxdt, optim, loss_fn, sched=None):
    """
    Performs a single optimizer update on the model and reports the batch loss.
    Args:
        model   network handed to the `grad` helper together with x
        x       inputs of shape (batch_size, n_features)
        L       constant Poisson matrix of shape (n_features, n_features)
        dxdt    target time derivatives, of shape (batch_size, n_features)
        optim   Optimizer; LBFGS is driven through its closure interface
        loss_fn callable taking (targets, predictions), e.g. mse or L2
        sched   LR scheduler, stepped once per call when given (default=None)
    Returns:
        loss    Python float holding the loss of the given batch.
    """
    def evaluate_and_backprop():
        # Clear stale gradients, rebuild the prediction and backpropagate the loss.
        optim.zero_grad()
        # presumably the gradient of the model output w.r.t. x, (batch_size, n_features)
        energy_grad = grad(model, x, create_graph=True, mode="forward")
        # Hamilton's equations: dx/dt = (dE/dx) L^T, same shape as energy_grad
        flow_pred = tc.mm(energy_grad, L.T)
        batch_loss = loss_fn(dxdt, flow_pred)
        batch_loss.backward()
        return batch_loss

    needs_closure = isinstance(optim, LBFGS)
    if not needs_closure:
        # Ordinary optimizers: one forward/backward pass, then apply the update.
        batch_loss = evaluate_and_backprop()
        optim.step()
    else:
        # LBFGS calls the closure itself and hands back the loss it returned.
        batch_loss = optim.step(evaluate_and_backprop)

    if sched is not None:
        sched.step()

    return batch_loss.item()

def traditional_training(model, x, x_test, L, dxdt, dxdt_test, args: TradTrainingArgs, pre_op=None):
    """
    Applies traditional training (gradient-descent based training) to the model. This can be used
    to also further tune all the learnable parameters of the network after we initialize them using
    SWIM or any initialization method and fit last linear layer weights using linear solve.

    Args:
        x, x_test           Train and test inputs of shape (n_points, n_features)
        L                   Poisson matrix of shape (n_features, n_features)
        dxdt, dxdt_test     Time derivatives (targets) of shape (n_points, n_features)
        args                Gradient-descent-based training hyperparams.
        pre_op              Specifies a lambda function (optional) to call before a step, e.g. for
                            computing dynamic edge indices.
    """
    # transfer everything for faster computation
    x, dxdt, L = x.to(args.device), dxdt.to(args.device), L.to(args.device)
    if (not x_test is None) and (not dxdt_test is None):
        x_test, dxdt_test = x_test.to(args.device), dxdt_test.to(args.device)
    model.to(args.device)

    model = model.train()

    if args.weight_init != "none":
        model.init_params(method=args.weight_init)

    if  args.optim_type == "adam": optim = Adam(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type== "adamw": optim = AdamW(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "radam": optim = RAdam(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "nadam": optim = NAdam(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "adamax": optim = Adamax(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "sparseadam": optim = SparseAdam(model.parameters(), args.lr_start)
    elif args.optim_type == "sgd": optim = SGD(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "sgdmomentum": optim = SGD(model.parameters(), args.lr_start, weight_decay=args.weight_decay, momentum=0.9)
    elif args.optim_type == "adagrad": optim = Adagrad(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "adadelta": optim = Adadelta(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "adafactor": optim = Adafactor(model.parameters(), args.lr_start, weight_decay=args.weight_decay)
    elif args.optim_type == "asgd": optim = ASGD(model.parameters(), args.lr_start)
    elif args.optim_type == "lbfgs": optim = LBFGS(model.parameters(), args.lr_start)
    elif args.optim_type == "rmsprop": optim = RMSprop(model.parameters(), args.lr_start)
    elif args.optim_type == "rprop": optim = Rprop(model.parameters(), args.lr_start)
    else: raise ValueError("Unsupported optimizer")

    if args.sched_type == "linear": lr_lambda = linear_lr_lambda(args.n_steps, args.lr_start, args.lr_end)
    elif args.sched_type == "exponential": lr_lambda = exponential_lr_lambda(args.n_steps, args.lr_start, args.lr_end)
    elif args.sched_type is None or args.sched_type == "none": sched=None
    else: raise ValueError(f"Unsupported optimizer, got {args.sched_type}")

    if args.sched_type is None or args.sched_type == "none":
        sched = None
    else:
        sched = tc.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_lambda)

    loss_fn = lambda y_true, y_pred: (y_true - y_pred).pow(2).mean()
    results = { "loss": [], "test_loss": [], "best_test_loss": float("inf"),
                "test_error": [], "best_test_error": float("inf"),
                "lr": [] }

    patience_counter = 0 # by looking at the best_test_loss during training
    for step_idx in range(args.n_steps):
        if args.batch_size is None:
            x_batch, dxdt_batch = x.requires_grad_(True), dxdt
        else:
            x_batch = get_batch(x, step_idx, args.batch_size).clone().detach().requires_grad_(True)
            dxdt_batch = get_batch(dxdt, step_idx, args.batch_size).clone().detach().requires_grad_(False)
        if not pre_op is None: pre_op(model, x_batch) # e.g. compute dynamic edge index for this batch
        loss = step(model, x_batch, L, dxdt_batch, optim, loss_fn, sched)
        results["loss"].append(loss)

        if (not x_test is None) and (not dxdt_test is None):
            if not pre_op is None: pre_op(model, x_test) # e.g. compute dynamic edge index for this batch
            dedx_test_pred = grad(model, x_test, mode="forward") # (batch_size, n_features)
            dxdt_test_pred = tc.mm(dedx_test_pred, L.T) # Hamilton's equations, resulting shape same as dedx_pred
            test_loss = loss_fn(dxdt_test, dxdt_test_pred)
            test_loss = test_loss.item()
            results["test_loss"].append(test_loss)
            test_error = l2_err_TC(dxdt_test, dxdt_test_pred, verbose=False)
            results["test_error"].append(test_error)

        if sched is not None: results["lr"].append(sched.get_last_lr()[0])
        else: results["lr"].append(args.lr_start)

        if step_idx % max(args.n_steps // 10, 1) == 0:
            if args.sched_type is None or args.sched_type == "none":
                if (not x_test is None) and (not dxdt_test is None):
                    print(f"-> Step {step_idx}: Train Loss = {loss:.5e}, Test Loss = {test_loss:.5e}")
                else:
                    print(f"-> Step {step_idx}: Train Loss = {loss:.5e}")
            else:
                if (not x_test is None) and (not dxdt_test is None):
                    print(f"-> Step {step_idx}: Train Loss = {loss:.5e}, Test Loss = {test_loss:.5e} | LR: {sched.get_last_lr()[0]:.2e}")
                else:
                    print(f"-> Step {step_idx}: Train Loss = {loss:.5e} | LR: {sched.get_last_lr()[0]:.2e}")

        if (not x_test is None) and (not dxdt_test is None):
            if test_loss < results["best_test_loss"] - 1e-8:  # with small tolerance to avoid numerical issues
                results["best_test_loss"] = test_loss
                results["best_test_error"] = test_error
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= args.patience:
                    print(f"Early stopping at step {step_idx} (best test loss = {results['best_test_loss']:5e})")
                    break

    # transfer back to CPU
    model.to('cpu')

    x, dxdt, L = x.to('cpu'), dxdt.to('cpu'), L.to('cpu')

    if (not x_test is None) and (not dxdt_test is None):
        x_test, dxdt_test = x_test.to('cpu'), dxdt_test.to('cpu')

    return results
