import torch.optim as optim
import numpy as np
import torch as torch
from eos_line_search.utils import *
from eos_line_search.optimizers import SLS, PoNoS, SAM, CDAT, MalMis
from eos_line_search.experiment import *
from eos_line_search.run import *
from eos_line_search.plot import *

# Module-level state for the "warmup_GD" schedule in opt_step: counts the steps
# on which the step size was clamped to its warm-up cap. opt_step switches to
# the post-warm-up step size once this exceeds its local limit.
# NOTE(review): nothing in the visible code resets this counter, so it appears
# to be shared by every run executed in the same process -- confirm intended.
warmup_counter = 0


def _line_search_kwargs(run):
    """Return the constructor keywords shared by the SLS and PoNoS optimizers.

    All values are read from ``run.optimizer`` except ``n_batches_per_epoch``,
    which comes from ``run.num_batches``.
    """
    cfg = run.optimizer
    return dict(
        c=cfg.c,
        init_step_size=cfg.init_step_size,
        max_eta=cfg.max_step_size,
        n_batches_per_epoch=run.num_batches,
        beta=cfg.decrease_factor,
        reset_option=cfg.reset_option,
        forward_option=cfg.forward_option,
        eps=cfg.eps,
    )


def _ponos_kwargs(run, **overrides):
    """Return the PoNoS constructor keywords, with ``overrides`` applied last.

    Extends the shared line-search keywords with the PoNoS-only settings.
    ``overrides`` may replace a shared value (e.g. ``forward_option``) or add a
    variant-specific one (e.g. ``cdat_optimizer``).
    """
    cfg = run.optimizer
    kwargs = _line_search_kwargs(run)
    kwargs.update(
        adapt_c=cfg.adapt_c,
        nonmonotone_option=cfg.nonmonotone_option,
        M=cfg.M,
        adapt_M=cfg.adapt_M,
        zhang_xi=cfg.xi,
    )
    kwargs.update(overrides)
    return kwargs


def setup_optimizer(run):
    """Build the optimizer selected by ``run.optimizer.opt_name``.

    Args:
        run: Run description. This function reads ``run.optimizer`` (name and
            hyper-parameters), ``run.model.model_obj`` (its ``parameters()``),
            and, for the line-search optimizers, ``run.num_batches``; the plain
            PoNoS variant also reads ``run.dataset.output_dim``.

    Returns:
        The constructed optimizer object.

    Raises:
        ValueError: If ``opt_name`` is not one of the supported names.

    Side effects:
        For the hybrid variants, ``run.optimizer.reset_option`` is overwritten
        with 2 ("CDAT-NLS") or 3 ("MalMis-NLS") before PoNoS is constructed.

    PoNoS variant selection (for ``opt_name == "PoNoS"``), in precedence order:
        1. ``forward_option == 13``: PoNoS is built with ``forward_option=0``
           and ``save_backtracks=True``; this wins even if ``reset_option`` is
           2 or 3.
        2. ``reset_option == 2``: same as "CDAT-NLS".
        3. ``reset_option == 3``: same as "MalMis-NLS".
        4. otherwise: plain PoNoS, additionally given ``num_classes``.
    """
    cfg = run.optimizer
    name = cfg.opt_name

    def params():
        # Evaluated lazily so an invalid name raises ValueError without
        # touching run.model, and called afresh for every optimizer because
        # the hybrid variants build two optimizers over the same model.
        return run.model.model_obj.parameters()

    # ---- plain (stochastic) gradient descent variants -------------------
    if name in ("constant_stepsize_GD", "warmup_GD", "warmup_GD_small"):
        return optim.SGD(params(), lr=cfg.step_size)
    if name == "polyak_stepsize_GD":
        # Placeholder learning rate; opt_step assigns the Polyak step size to
        # the param group before every step.
        return optim.SGD(params(), lr=0)

    # ---- adaptive torch.optim baselines ----------------------------------
    if name == "Adam":
        return optim.Adam(params(), lr=cfg.step_size, foreach=False)
    # NOTE(review): RMSProp and AdaGrad are built without ``lr`` and therefore
    # use the torch defaults, while opt_step reports run.optimizer.step_size
    # for them -- confirm this is intended.
    if name == "RMSProp":
        return optim.RMSprop(params(), foreach=False)
    if name == "AdaGrad":
        return optim.Adagrad(params(), foreach=False)

    # ---- project optimizers without line search --------------------------
    if name == "CDAT":
        return CDAT.CDAT(params(), sigma=cfg.c, eps=cfg.eps)
    if name == "MM":
        return MalMis.MalMis(
            params(), alpha=1.0, lr0=cfg.init_step_size, gamma=0.02
        )
    if name == "SAM":
        return SAM.SAMSGD(
            params(),
            lr=cfg.step_size,
            rho=cfg.rho,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay,
        )

    # ---- line-search optimizers ------------------------------------------
    if name == "SLS":
        return SLS.SLS(params(), **_line_search_kwargs(run))

    if name == "PoNoS" and cfg.forward_option == 13:
        return PoNoS.PoNoS(
            params(),
            **_ponos_kwargs(run, forward_option=0, save_backtracks=True),
        )

    if name == "CDAT-NLS" or (name == "PoNoS" and cfg.reset_option == 2):
        # The inner CDAT uses a fixed sigma here, not run.optimizer.c.
        cdat = CDAT.CDAT(params(), sigma=2.06, eps=cfg.eps)
        # Must be set before the keywords are collected so PoNoS sees it.
        cfg.reset_option = 2
        return PoNoS.PoNoS(params(), **_ponos_kwargs(run, cdat_optimizer=cdat))

    if name == "MalMis-NLS" or (name == "PoNoS" and cfg.reset_option == 3):
        malmis = MalMis.MalMis(
            params(), alpha=1.0, lr0=cfg.init_step_size, gamma=0.02
        )
        # Must be set before the keywords are collected so PoNoS sees it.
        cfg.reset_option = 3
        return PoNoS.PoNoS(
            params(), **_ponos_kwargs(run, malmis_optimizer=malmis)
        )

    if name == "PoNoS":
        return PoNoS.PoNoS(
            params(),
            **_ponos_kwargs(run, num_classes=run.dataset.output_dim),
        )

    raise ValueError("Not a valid optimizer")


def opt_step(loss, run, X, y, sharpness, iteration):
    """Perform one step of ``run.opt_obj`` and return per-step metrics.

    Args:
        loss: Passed to ``step`` for SAM and (as ``closure``) for the
            line-search optimizers; for Polyak its ``.item()`` value is used.
        run: Run description providing ``optimizer``, ``opt_obj`` and, depending
            on the optimizer, ``dataset``, ``model``, ``plot_metrics`` and
            ``epochs``.
        X, y: Batch forwarded to ``compute_batch_training_loss_fn`` (Polyak
            branch only).
        sharpness: Forwarded to the line-search optimizers' ``step``.
        iteration: Current iteration index; forwarded to the line-search
            optimizers and used to schedule the ``Lw_asmpt`` check.

    Returns:
        Tuple ``(final_step_size, init_step_size, backtracks,
        function_evaluations, a, along_g_dict)``. Optimizers without a line
        search report ``backtracks == 0`` and an empty ``along_g_dict``; ``a``
        is 0 except for Polyak and the line-search optimizers.

    Raises:
        ValueError: If ``run.optimizer.opt_name`` is not supported.

    Side effects:
        The warm-up branches rewrite ``run.optimizer.step_size`` and the first
        param group's ``lr``; "warmup_GD" also increments the module-level
        ``warmup_counter``. The Polyak branch rewrites the param group's ``lr``.
    """
    global warmup_counter

    name = run.optimizer.opt_name

    # Defaults for optimizers that do no line search.
    along_g_dict = {}
    backtracks = 0
    a = 0

    if name in ("constant_stepsize_GD", "Adam", "AdaGrad", "RMSProp"):
        run.opt_obj.step()
        final_step_size = init_step_size = run.optimizer.step_size
        function_evaluations = 1

    elif name in ("warmup_GD", "warmup_GD_small"):
        classes = run.dataset.output_dim
        increase_factor = 1.25
        stepsize_post_warmup = classes * 0.9
        group = run.opt_obj.param_groups[0]
        grown = run.optimizer.step_size * increase_factor

        if name == "warmup_GD":
            stepsize_threshold = 1.25  # 1.01
            num_steps_counter = 1  # 10
            cap = classes * stepsize_threshold

            if warmup_counter <= num_steps_counter:
                if grown <= cap:
                    group["lr"] = grown
                elif grown > cap:
                    # Explicit comparison (not a bare else) so that a NaN step
                    # size leaves both the lr and the counter untouched.
                    warmup_counter += 1
                    group["lr"] = cap
            else:
                # The cap has been held long enough: drop to the final value.
                group["lr"] = stepsize_post_warmup
        else:  # "warmup_GD_small"
            stepsize_threshold = 0.9  # 1.01
            if grown <= classes * stepsize_threshold:
                group["lr"] = grown
            else:
                group["lr"] = stepsize_post_warmup

        run.optimizer.step_size = group["lr"]
        run.opt_obj.step()

        final_step_size = init_step_size = run.optimizer.step_size
        function_evaluations = 1

    elif name == "SAM":
        run.opt_obj.step(loss)
        final_step_size = init_step_size = run.optimizer.step_size
        function_evaluations = 2

    elif name in ("CDAT", "MM"):
        # Both optimizers return the step size they used.
        step_size = run.opt_obj.step()
        final_step_size = init_step_size = step_size
        function_evaluations = 3 if name == "CDAT" else 1

    elif name == "polyak_stepsize_GD":
        # Polyak step size loss / ||g||^2, clipped at max_step_size.
        grad_norm = compute_grad_norm(run.model.model_obj.parameters())
        step_size = np.minimum(
            loss.item() / (grad_norm**2 + 1e-8),
            run.optimizer.max_step_size,
        )
        run.opt_obj.param_groups[0]["lr"] = step_size

        run.opt_obj.step()

        # For recording "a": needs the loss after the step on the same batch.
        # NOTE(review): a second-order expansion along the gradient would
        # divide by grad_norm**2 here, not grad_norm -- confirm the intent.
        loss_next = compute_batch_training_loss_fn(run, X, y)
        a = (2 / step_size**2) * (
            loss_next.item() - loss.item()
        ) / grad_norm + 2 / step_size

        final_step_size = run.opt_obj.param_groups[0]["lr"]
        init_step_size = run.opt_obj.param_groups[0]["lr"]
        # NOTE(review): reported as 1 although loss_next is a second loss
        # evaluation -- confirm whether that one should be counted.
        function_evaluations = 1

    elif name in ("SLS", "PoNoS", "CubicLS", "CDAT-NLS", "MalMis-NLS"):
        # "CDAT-NLS" and "MalMis-NLS" are built by setup_optimizer as PoNoS
        # instances, so they are stepped exactly like "PoNoS".
        # The Lw_asmpt check runs every 100 iterations during the first 1000,
        # every 1000 afterwards, and on iteration run.epochs - 1.
        check_Lw_asmpt = "Lw_asmpt" in run.plot_metrics.metrics and (
            (iteration % 100 == 0 and iteration < 1000)
            or iteration % 1000 == 0
            or iteration == run.epochs - 1
        )
        (
            final_step_size,
            init_step_size,
            backtracks,
            function_evaluations,
            a,
            along_g_dict,
        ) = run.opt_obj.step(
            closure=loss,
            sharpness=sharpness,
            iteration=iteration,
            check_Lw_asmpt=check_Lw_asmpt,
        )

    else:
        raise ValueError("Not a valid optimizer")

    return (
        final_step_size,
        init_step_size,
        backtracks,
        function_evaluations,
        a,
        along_g_dict,
    )
