import torch
import copy
import numpy as np
from eos_line_search.utils import *


class MalMis(torch.optim.Optimizer):
    """
    Adaptive gradient descent based on the local smoothness constant.

    Step-size rule in the spirit of Malitsky & Mishchenko, "Adaptive Gradient
    Descent without Descent". With the local smoothness estimate
    ``L_k = ||g_k - g_{k-1}|| / ||w_k - w_{k-1}||`` the step size is

    * ``lr0`` for the very first step (no previous iterate exists yet),
    * ``alpha / L_k`` while ``theta`` is still ``+inf`` (the first adaptive step),
    * ``min(sqrt(1 + gamma * theta) * lr_prev, eps / lr_prev + alpha / L_k)``
      afterwards, where ``theta`` is the ratio of the last two step sizes.

    A constant 1e-8 is added to the denominators for numerical safety.
    Only the first parameter group is optimized.

    Arguments:
        params: model parameters
        eps (float, optional): an estimate of 1 / L^2, where L is the global smoothness constant (default: 0.0)
        lr0 (float, optional): initial step size (default: 1e-3)
        alpha (float, optional): scaling of 1/L (default: 0.5)
        gamma (float, optional): scaling of theta_k (default: 1.0)
    """

    def __init__(self, params, eps=0.0, lr0=1e-3, alpha=0.5, gamma=1.0):
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {}".format(eps))

        defaults = dict(eps=eps, step_size0=lr0)
        super().__init__(params, defaults)

        self.eps = eps
        self.step_size0 = lr0
        self.alpha = alpha
        self.gamma = gamma
        self.step_size = lr0

        # theta: ratio of the last two step sizes; +inf until the first
        # adaptive step has been taken.
        self.state["theta"] = float("inf")
        # Previous iterate / gradients used to estimate the local smoothness.
        self.state["w_old"] = None
        self.state["grad_old"] = None
        self.state["step_count"] = 0

    @staticmethod
    def _detached_copies(tensors):
        """Return detached clones of ``tensors`` (``None`` entries stay ``None``)."""
        return [None if t is None else t.detach().clone() for t in tensors]

    def _store_history(self, params_current, grad_current):
        """Remember the current iterate and a private copy of its gradients."""
        self.state["w_old"] = params_current
        # Clone so that in-place gradient updates (e.g. zero_grad with
        # set_to_none=False followed by backward) cannot alias the history.
        self.state["grad_old"] = self._detached_copies(grad_current)

    @torch.no_grad()
    def get_step_size(self, params_current=None, grad_current=None, grad_norm=None):
        """
        Compute the MalMis step size without updating the model parameters.

        This does update the optimizer history: the current iterate and
        gradients are stored as ``w_old`` / ``grad_old`` for the next call
        (except when the iterate has not moved since the previous call).

        Args:
            params_current: snapshot of the current parameters (optional; a
                detached copy of the first parameter group is taken if not
                provided). It is stored as the previous iterate, so it should
                not be the live parameter tensors.
            grad_current: current gradients (optional, obtained with
                ``get_grad_list`` if not provided)
            grad_norm: unused; accepted only for backward compatibility

        Returns:
            step_size: computed step size
        """
        params = self.param_groups[0]["params"]
        if params_current is None:
            params_current = self._detached_copies(params)
        if grad_current is None:
            grad_current = get_grad_list(params)

        # No previous iterate yet: nothing to estimate L from, use the current step size.
        if self.state["w_old"] is None or self.state["grad_old"] is None:
            self._store_history(params_current, grad_current)
            return self.step_size

        # Local smoothness estimate L = ||g_k - g_{k-1}|| / ||w_k - w_{k-1}||
        param_flat = torch.cat([p.data.flatten() for p in params_current])
        w_old_flat = torch.cat([p.flatten() for p in self.state["w_old"]])
        grad_flat = torch.cat([g.flatten() for g in grad_current])
        grad_old_flat = torch.cat([g.flatten() for g in self.state["grad_old"]])

        w_diff_norm = torch.norm(param_flat - w_old_flat).item()
        grad_diff_norm = torch.norm(grad_flat - grad_old_flat).item()

        if w_diff_norm == 0:
            # The iterate did not move, so L is undefined; keep the step size.
            return self.step_size

        L = grad_diff_norm / w_diff_norm

        if np.isinf(self.state["theta"]):
            # First adaptive step: no growth cap is available yet.
            new_step_size = self.alpha / (L + 1e-8)
        else:
            new_step_size = min(
                np.sqrt(1 + self.state["theta"] * self.gamma) * self.step_size,
                self.eps / (self.step_size + 1e-8) + self.alpha / (L + 1e-8),
            )

        self._store_history(params_current, grad_current)
        return new_step_size

    def set_step_size(self, step_size):
        """
        Set the step size externally (e.g., after line search).

        Also records ``theta`` as the ratio of the new step size to the
        previous one (``+inf`` if the previous step size was not positive).

        Args:
            step_size: step size to set
        """
        current_step_size = self.step_size
        if current_step_size > 0:
            self.state["theta"] = step_size / current_step_size
        else:
            self.state["theta"] = float("inf")
        self.step_size = step_size

    @torch.no_grad()
    def step(self):
        """
        Perform a single optimization step on the first parameter group.

        Returns:
            step_size: the step size used for this update
        """
        self.state["step_count"] += 1
        params = self.param_groups[0]["params"]

        # The bootstrap step has no history, so it uses the current step size (lr0).
        first_step = self.state["w_old"] is None or self.state["grad_old"] is None

        params_current = self._detached_copies(params)
        grad_current = get_grad_list(params)

        step_size = self.get_step_size(params_current, grad_current)
        self.gd_update(params, step_size, params_current, grad_current)

        if first_step:
            # Keep theta = +inf (theta_0 in the algorithm) so the next step
            # uses alpha / L instead of being capped at sqrt(1 + gamma) * lr0.
            self.step_size = step_size
        else:
            self.set_step_size(step_size)

        return step_size

    @torch.no_grad()
    def gd_update(self, params, step_size, params_current, grad_current):
        """
        Set each parameter to ``p_current - step_size * g_current``.

        Args:
            params: parameters to overwrite
            step_size: step size to apply
            params_current: snapshot of the parameters before the update
            grad_current: gradients at ``params_current``
        """
        for p_next, p_current, g_current in zip(params, params_current, grad_current):
            p_next.data = p_current - step_size * g_current
