import os
import json
import math
import random
import numpy as np
from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
PyTree = Any 

import matplotlib.pyplot as plt
import matplotlib as mpl

## JAX
import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import tree_map
# Seeding for random operations
main_rng = random.PRNGKey(42)


# SGD optimizer whose state carries the candidate learning rates and the one currently selected
class Optimizer(NamedTuple):
    """Pair of callables (``init``, ``update``) that together describe an optimizer."""
    # Builds the initial optimizer state (a tuple) from the parameters.
    init: Callable[[PyTree], tuple]
    # Maps (updates, state, optional params) to (transformed updates, new state).
    update: Callable[[PyTree, tuple, Optional[PyTree]], Tuple[PyTree, tuple]]

def sgd(lrs: Union[float, jnp.ndarray]):
    """
    Builds a plain SGD optimizer whose state carries a set of candidate learning rates.

    The optimizer state is the tuple ``(lrs, selected_lr, step)``. ``update`` scales
    every leaf of the incoming updates by ``-selected_lr`` and returns the state
    untouched, so callers that want a different learning rate (or step count)
    rewrite the state tuple themselves between calls.

    Parameters:
        lrs: A single learning rate, or a non-empty collection of candidate
             learning rates (JAX array, NumPy array, list or tuple).

    Returns:
        An ``Optimizer`` holding the ``init`` and ``update`` functions.

    Raises:
        ValueError: (from ``init``) if ``lrs`` is an empty collection.
    """
    def init(params):
        if np.ndim(lrs) == 0:
            # Scalar (Python number or 0-d array): wrap it so that the first
            # state entry is always a 1-D array of candidates.
            return (jnp.array([lrs]), lrs, 0)

        # Any other collection is normalised to a JAX array; previously only
        # jnp.ndarray inputs were recognised and lists / NumPy arrays were
        # silently wrapped into a (1, k) array with a vector-valued selected_lr.
        candidates = jnp.asarray(lrs)
        if candidates.size == 0:
            raise ValueError("lrs must contain at least one learning rate")
        return (candidates, candidates[0], 0)  # Start with the first learning rate

    def update(updates, state, params=None):
        lrs, selected_lr, step = state
        # Gradient-descent direction: scale each leaf by the negative learning rate.
        updates = tree_map(lambda u: -selected_lr * u, updates)
        return updates, (lrs, selected_lr, step)  # Return the state unchanged

    return Optimizer(init, update)

class MOSS:
    """Upper-confidence-bound bandit (MOSS-style index) over candidate learning rates."""

    def __init__(self, learning_rates: np.ndarray, gamma: float = 1.0):
        """
        Parameters:
            learning_rates: Candidate learning rates; each entry is one arm.
            gamma: Multiplier applied to the exploration bonus.
        """
        self.num_arms = len(learning_rates)
        self.learning_rates = learning_rates
        self.gamma = gamma

        # Per-arm statistics: pull counts and summed rewards.
        self.arm_counts = np.zeros(self.num_arms)
        self.arm_rewards = np.zeros(self.num_arms)

    def select_arm(self) -> int:
        """
        Picks the arm with the largest index value.

        Before any pull has been recorded arm 0 is returned. Arms that have
        never been pulled get an infinite index, so they are tried first.

        Returns:
            The index of the selected arm.
        """
        pulls_so_far = np.sum(self.arm_counts)
        if pulls_so_far == 0:
            return 0

        def index_of(arm):
            pulls = self.arm_counts[arm]
            if pulls == 0:
                return np.inf
            # Exploration bonus: the log term is clipped at zero before the square root.
            log_ratio = np.log(pulls_so_far / (self.num_arms * pulls))
            bonus = self.gamma * np.sqrt(max(log_ratio, 0) / pulls)
            return self.arm_rewards[arm] / pulls + bonus

        return int(np.argmax([index_of(arm) for arm in range(self.num_arms)]))

    def update(self, arm: int, reward: float):
        """
        Records one pull of ``arm`` together with the reward it produced.

        Parameters:
            arm: Index of the arm that was pulled.
            reward: Observed reward for that pull.
        """
        self.arm_rewards[arm] += reward
        self.arm_counts[arm] += 1

# Set of Non-Convex Loss Functions

# 1. Beale's Function
def beale_loss(w1, w2):
    """Beale's function: sum of three squared residuals in (w1, w2)."""
    residual_a = 1.5 - w1 + w1 * w2
    residual_b = 2.25 - w1 + w1 * w2**2
    residual_c = 2.625 - w1 + w1 * w2**3
    return residual_a**2 + residual_b**2 + residual_c**2

# 2. Bohachevsky Function
def bohachevsky_n1_loss(x, y):
    """Bohachevsky N.1 function: a quadratic bowl with cosine ripples along both axes."""
    bowl = x**2 + 2 * y**2
    ripple_x = 0.3 * jnp.cos(3 * jnp.pi * x)
    ripple_y = 0.4 * jnp.cos(4 * jnp.pi * y)
    return bowl - ripple_x - ripple_y + 0.7

# 3. Griewank Function
def griewank_loss(x, y):
    """Two-dimensional Griewank function: a shallow quadratic minus a product of cosines."""
    quadratic = (x**2 + y**2) / 4000
    oscillation = jnp.cos(x) * jnp.cos(y / jnp.sqrt(2))
    return 1 + quadratic - oscillation

# 4. Rosenbrock Loss
def rosenbrock_loss(w1, w2):
    """Rosenbrock function: squared distance of w1 from 1 plus a weighted valley penalty."""
    distance_term = (1 - w1)**2
    valley_term = 100 * (w2 - w1**2)**2
    return distance_term + valley_term

# 5. Three Hump Camel Function
def three_hump_camel(x, y):
    """Three-hump camel function: a sixth-degree polynomial in x coupled quadratically to y."""
    polynomial_in_x = 2 * x**2 - 1.05 * x**4 + (x**6) / 6
    return polynomial_in_x + x * y + y**2

# 6. Zakharov Function
def zakharov_loss(x, y):
    """Two-dimensional Zakharov function: squared norm plus even powers of 0.5*x + y."""
    weighted_sum = 0.5 * x + y
    return x**2 + y**2 + weighted_sum**2 + weighted_sum**4


def plot_2d_curve(curve_fn, x_range=(-10, 10), y_range=(-10, 10), cmap='viridis', title=None, num_points=100):
    """
    Plots the 2D filled contour plot of the provided loss function.

    Parameters:
        curve_fn - The function defining the curve (e.g., a loss function); it is
                   called once with two 2-D coordinate grids.
        x_range, y_range - (min, max) ranges for the plot axes.
        cmap - Colormap for the contour plot.
        title - Title of the plot.
        num_points - Number of grid samples along each axis (default 100).

    Returns:
        The current matplotlib axes, so further artists can be drawn on top.

    Raises:
        ValueError: if num_points is smaller than 2 (no contour can be drawn).
    """
    if num_points < 2:
        raise ValueError("num_points must be at least 2")

    # Create grid points
    grid_x = np.linspace(x_range[0], x_range[1], num_points)
    grid_y = np.linspace(y_range[0], y_range[1], num_points)
    x, y = np.meshgrid(grid_x, grid_y)

    # Evaluate the function on the grid
    z = curve_fn(x, y)
    z = jnp.nan_to_num(z, nan=1e6, posinf=1e6, neginf=-1e6)  # Replace NaN or Inf with large finite values

    # Plot the contour plot
    plt.figure(figsize=(6, 5))
    cp = plt.contourf(x, y, z, levels=50, cmap=cmap)  # Contour plot with filled regions
    plt.colorbar(cp, label='Loss')
    plt.xlabel(r"$w_1$", fontsize=14)
    plt.ylabel(r"$w_2$", fontsize=14)
    plt.title(title, fontsize=16)
    plt.tight_layout()

    # Return the axis for further plotting
    return plt.gca()


def train_curve(optimizer, curve_func, num_updates=100, init=(5, 5), scheduler=None):
    """
    Runs gradient descent on a two-argument loss function and records the trajectory.

    Parameters:
        optimizer - Object exposing init(params) and update(grads, state), e.g. the
                    result of sgd(...). With scheduler="moss" its state must be the
                    tuple (lrs, selected_lr, step).
        curve_func - Loss function called as curve_func(w[0], w[1]).
        num_updates - Number of gradient steps to take.
        init - Initial weights.
        scheduler - "moss" lets a MOSS bandit pick the learning rate before every
                    step; any other value keeps the optimizer's own learning rate.

    Returns:
        points - NumPy array with one row per step: the weights before the update
                 followed by the loss at those weights.
        lr_list - Learning rate associated with each recorded loss (only filled
                  when scheduler="moss", otherwise empty).
        losses - Loss at every step as Python floats (only filled when
                 scheduler="moss", otherwise empty).
    """
    weights = jnp.array(init, dtype=jnp.float32)
    grad_fn = jax.jit(jax.value_and_grad(lambda w: curve_func(w[0], w[1])))
    opt_state = optimizer.init(weights)

    # Store loss and pulled arms over the gradient steps
    list_points = []
    losses = []
    lr_list = []

    moss = MOSS(opt_state[0], gamma=2.0) if scheduler == "moss" else None

    if moss is not None:
        # The first observed loss is credited to arm 0.
        lr_index = 0
        selected_lr = moss.learning_rates[lr_index]

    for step in range(num_updates):
        loss, grads = grad_fn(weights)

        if moss is not None:
            losses.append(float(loss))
            lr_list.append(float(selected_lr))

            # Reward the arm that led to the current weights with 1/|loss| + 1e-6.
            # NOTE(review): the constant is added after the division, so a loss of
            # exactly zero gives a non-finite reward - confirm whether
            # 1/(|loss| + c) was intended.
            moss.update(lr_index, 1/abs(loss)+1e-6)

            # Select learning rate using MOSS
            lr_index = moss.select_arm()
            selected_lr = moss.learning_rates[lr_index]

            # Update optimizer state with the new learning rate
            opt_state = (opt_state[0], selected_lr, step)

        list_points.append(jnp.concatenate([weights, loss[None]], axis=0))
        updates, opt_state = optimizer.update(grads, opt_state)
        weights = jax.tree_util.tree_map(lambda w, u: w + u, weights, updates)

    if not list_points:
        # No steps were taken: jnp.stack cannot stack an empty list, so return an
        # empty trajectory with the usual column count instead.
        return np.empty((0, weights.shape[0] + 1), dtype=np.float32), lr_list, losses

    points = jnp.stack(list_points, axis=0)
    points = jax.device_get(points)
    return points, lr_list, losses

#--------------------------------------------------------------------------------------------------------------------------------#

# Beale function: three fixed learning rates versus MOSS-selected learning rates.
lrs = [0.005, 0.01, 0.02]
iterations = 10000
start = [-1, 1]

# The fixed-rate runs return no learning-rate / loss history, so those outputs are discarded.
SGD_points1, _lrs, _losses = train_curve(sgd(lrs=0.02), curve_func=beale_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points2, _lrs, _losses = train_curve(sgd(lrs=0.01), curve_func=beale_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points3, _lrs, _losses = train_curve(sgd(lrs=0.005), curve_func=beale_loss, num_updates=iterations, init=start, scheduler=False)
SGD_MOSS_points, beale_lr, beale_feedback = train_curve(sgd(lrs=jnp.array(lrs)), curve_func=beale_loss, num_updates=iterations, init=start, scheduler="moss")

# Pool all trajectories; the axis limits below are derived from them.
all_points = np.concatenate([SGD_points1, SGD_points2, SGD_points3, SGD_MOSS_points], axis=0)

# Every run starts from the same point, so the first row of any trajectory gives it.
initial_position = (SGD_points1[0, 0], SGD_points1[0, 1])
global_minima = (3, 0.5)

# Plot the loss surface
plot_2d_curve(beale_loss,
              x_range=(-np.abs(all_points[:, 0]).max() - 1, np.abs(all_points[:, 0]).max() + 1),
              y_range=(all_points[:, 1].min() - 1, all_points[:, 1].max() + 1), title="Beale Function")

# One entry per trajectory: (points, colour, zorder of the path, legend label).
trajectories = [
    (SGD_points1, "red", 1, "η = $2x10^{-2}$"),
    (SGD_points2, "magenta", 4, "η = $1x10^{-2}$"),
    (SGD_points3, "green", 3, "η = $0.5x10^{-2}$"),
    (SGD_MOSS_points, "blue", 2, "LRRL"),
]

# Plot points
for path, colour, layer, legend_label in trajectories:
    plt.plot(path[:, 0], path[:, 1], color=colour, marker="o", zorder=layer, markersize=4, label=legend_label)

# Highlight initial position with a special marker
plt.scatter(*initial_position, color='gold', s=100, zorder=6, edgecolor='black', marker='^')
plt.text(initial_position[0] - 0.1, initial_position[1] - 0.3, "Start", fontsize=10, color="gold")

# Highlight global minima with a special marker
plt.scatter(*global_minima, color='gold', s=120, zorder=6, edgecolor='black', marker='*')
plt.text(global_minima[0] - 1.3 , global_minima[1] + 0.2, "Global Minima", fontsize=10, color="gold")

# Highlight final position of every trajectory with a special marker
for path, colour, layer, legend_label in trajectories:
    plt.scatter(path[-1, 0], path[-1, 1], color=colour, s=120, zorder=5, edgecolor='gold', marker='*')

handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 1, 2, 3]
plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order])
plt.show()

#--------------------------------------------------------------------------------------------------------------------------------#

# Bohachevsky N.1 function: three fixed learning rates versus MOSS-selected learning rates.
lrs = [0.05, 0.1, 0.2]
iterations = 1000
start = [5, 5]

# The fixed-rate runs return no learning-rate / loss history, so those outputs are discarded.
SGD_points1, _lrs, _losses = train_curve(sgd(lrs=0.2), curve_func=bohachevsky_n1_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points2, _lrs, _losses = train_curve(sgd(lrs=0.1), curve_func=bohachevsky_n1_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points3, _lrs, _losses = train_curve(sgd(lrs=0.05), curve_func=bohachevsky_n1_loss, num_updates=iterations, init=start, scheduler=False)
SGD_MOSS_points, bohachevsky_n_lr, bohachevsky_n_feedback = train_curve(sgd(lrs=jnp.array(lrs)), curve_func=bohachevsky_n1_loss, num_updates=iterations, init=start, scheduler="moss")

# Pool all trajectories; the axis limits below are derived from them.
all_points = np.concatenate([SGD_points1, SGD_points2, SGD_points3, SGD_MOSS_points], axis=0)

# Every run starts from the same point, so the first row of any trajectory gives it.
initial_position = (SGD_points1[0, 0], SGD_points1[0, 1])
global_minima = (0, 0)

# Plot the loss surface
plot_2d_curve(bohachevsky_n1_loss,
              x_range=(-np.abs(all_points[:, 0]).max() - 1, np.abs(all_points[:, 0]).max() + 1),
              y_range=(all_points[:, 1].min() - 1, all_points[:, 1].max() + 1), title="Bohachevsky N. 1 Function")

# One entry per trajectory: (points, colour, zorder of the path, legend label).
trajectories = [
    (SGD_points1, "red", 1, "η = $2x10^{-1}$"),
    (SGD_points2, "magenta", 4, "η = $1x10^{-1}$"),
    (SGD_points3, "green", 3, "η = $0.5x10^{-1}$"),
    (SGD_MOSS_points, "blue", 2, "LRRL"),
]

# Plot points
for path, colour, layer, legend_label in trajectories:
    plt.plot(path[:, 0], path[:, 1], color=colour, marker="o", zorder=layer, markersize=4, label=legend_label)

# Highlight initial position with a special marker
plt.scatter(*initial_position, color='gold', s=100, zorder=6, edgecolor='black', marker='^')
plt.text(initial_position[0] - 0.2, initial_position[1] + 0.2, "Start", fontsize=10, color="gold")

# Highlight global minima with a special marker
plt.scatter(*global_minima, color='gold', s=120, zorder=6, edgecolor='black', marker='*')
plt.text(global_minima[0] - 2, global_minima[1] - 0.2, "Global Minima", fontsize=10, color="gold")

# Highlight final position of every trajectory with a special marker
for path, colour, layer, legend_label in trajectories:
    plt.scatter(path[-1, 0], path[-1, 1], color=colour, s=120, zorder=5, edgecolor='gold', marker='*')

handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 1, 2, 3]
plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order])
plt.show()

#--------------------------------------------------------------------------------------------------------------------------------#

# Griewank function: three fixed learning rates versus MOSS-selected learning rates.
lrs = [5, 10, 20]
iterations = 500
start = [-600, 600]

# The fixed-rate runs return no learning-rate / loss history, so those outputs are discarded.
SGD_points1, _lrs, _losses = train_curve(sgd(lrs=20), curve_func=griewank_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points2, _lrs, _losses = train_curve(sgd(lrs=10), curve_func=griewank_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points3, _lrs, _losses = train_curve(sgd(lrs=5), curve_func=griewank_loss, num_updates=iterations, init=start, scheduler=False)
SGD_MOSS_points, griewank_lr, griewank_feedback = train_curve(sgd(lrs=jnp.array(lrs)), curve_func=griewank_loss, num_updates=iterations, init=start, scheduler="moss")

# Pool the fixed-rate trajectories; the axis limits below are derived from them.
# NOTE(review): the MOSS trajectory is not part of this pool here, so it does not
# influence the axis limits - confirm this is intended.
all_points = np.concatenate([SGD_points1, SGD_points2, SGD_points3], axis=0)

# Every run starts from the same point, so the first row of any trajectory gives it.
initial_position = (SGD_points1[0, 0], SGD_points1[0, 1])
global_minima = (0, 0)

# Plot the loss surface
plot_2d_curve(griewank_loss,
              x_range=(-np.abs(all_points[:, 0]).max(), np.abs(all_points[:, 0]).max() + 2),
              y_range=(all_points[:, 1].min(), all_points[:, 1].max() + 2), title="Griewank Function")

# One entry per trajectory: (points, colour, zorder of the path, legend label).
trajectories = [
    (SGD_points1, "red", 1, "η = 20"),
    (SGD_points2, "magenta", 3, "η = 10"),
    (SGD_points3, "green", 4, "η = 5"),
    (SGD_MOSS_points, "blue", 2, "LRRL"),
]

# Plot points
for path, colour, layer, legend_label in trajectories:
    plt.plot(path[:, 0], path[:, 1], color=colour, marker="o", markersize=4, zorder=layer, label=legend_label)

# Highlight initial position with a special marker
plt.scatter(*initial_position, color='gold', s=100, zorder=5, edgecolor='black', marker='^')
plt.text(initial_position[0] + 0.2, initial_position[1], "Start", fontsize=10, color="gold")

# Highlight global minima with a special marker
plt.scatter(*global_minima, color='gold', s=120, zorder=6, edgecolor='black', marker='*')
plt.text(global_minima[0]+ 2, global_minima[1], "Global Minima", fontsize=10, color="gold")

# Highlight final position of every trajectory with a special marker
for path, colour, layer, legend_label in trajectories:
    plt.scatter(path[-1, 0], path[-1, 1], color=colour, s=120, zorder=6, edgecolor='gold', marker='*')

handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 1, 2, 3]
plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order])
plt.show()

#--------------------------------------------------------------------------------------------------------------------------------#

# Rosenbrock function: three fixed learning rates versus MOSS-selected learning rates.
lrs = [0.00005, 0.0001, 0.0002]
iterations = 30000
start = [-2, 2]

# The fixed-rate runs return no learning-rate / loss history, so those outputs are discarded.
SGD_points1, _lrs, _losses = train_curve(sgd(lrs=0.0002), curve_func=rosenbrock_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points2, _lrs, _losses = train_curve(sgd(lrs=0.0001), curve_func=rosenbrock_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points3, _lrs, _losses = train_curve(sgd(lrs=0.00005), curve_func=rosenbrock_loss, num_updates=iterations, init=start, scheduler=False)
SGD_MOSS_points, rosenbrock_lr, rosenbrock_feedback = train_curve(sgd(lrs=jnp.array(lrs)), curve_func=rosenbrock_loss, num_updates=iterations, init=start, scheduler="moss")

# Pool all trajectories; the axis limits below are derived from them.
all_points = np.concatenate([SGD_points1, SGD_points2, SGD_points3, SGD_MOSS_points,], axis=0)

# Every run starts from the same point, so the first row of any trajectory gives it.
initial_position = (SGD_points1[0, 0], SGD_points1[0, 1])
global_minima = (1, 1)

# Plot the loss surface
plot_2d_curve(rosenbrock_loss,
              x_range=(-np.abs(all_points[:, 0]).max() - 0.5, np.abs(all_points[:, 0]).max() + 0.5),
              y_range=(all_points[:, 1].min() - 0.1, all_points[:, 1].max() + 0.1), title="Rosenbrock Function")

# One entry per trajectory: (points, colour, zorder of the path, legend label).
trajectories = [
    (SGD_points1, "red", 1, "η = $2x10^{-4}$"),
    (SGD_points2, "magenta", 3, "η = $1x10^{-4}$"),
    (SGD_points3, "green", 4, "η = $0.5x10^{-4}$"),
    (SGD_MOSS_points, "blue", 2, "LRRL"),
]

# Plot points
for path, colour, layer, legend_label in trajectories:
    plt.plot(path[:, 0], path[:, 1], color=colour, marker="o", markersize=4, zorder=layer, label=legend_label)

# Highlight initial position with a special marker
plt.scatter(*initial_position, color='gold', s=100, zorder=5, edgecolor='black', marker='^')
plt.text(initial_position[0] - 0.2, initial_position[1]  - 0.12, "Start", fontsize=10, color="gold")

# Highlight global minima with a special marker
plt.scatter(*global_minima, color='gold', s=120, zorder=6, edgecolor='black', marker='*')
plt.text(global_minima[0]+ 0.1, global_minima[1], "Global Minima", fontsize=10, color="gold")

# Highlight final position of every trajectory with a special marker
for path, colour, layer, legend_label in trajectories:
    plt.scatter(path[-1, 0], path[-1, 1], color=colour, s=120, zorder=6, edgecolor='gold', marker='*')

handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 1, 2, 3]

plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order])
plt.show()

#--------------------------------------------------------------------------------------------------------------------------------#

# Three-hump camel function: three fixed learning rates versus MOSS-selected learning rates.
lrs = [0.005, 0.01, 0.02]
iterations = 10000
start = [3, 3]

# The fixed-rate runs return no learning-rate / loss history, so those outputs are discarded.
SGD_points1, _lrs, _losses = train_curve(sgd(lrs=0.02), curve_func=three_hump_camel, num_updates=iterations, init=start, scheduler=False)
SGD_points2, _lrs, _losses = train_curve(sgd(lrs=0.01), curve_func=three_hump_camel, num_updates=iterations, init=start, scheduler=False)
SGD_points3, _lrs, _losses = train_curve(sgd(lrs=0.005), curve_func=three_hump_camel, num_updates=iterations, init=start, scheduler=False)
SGD_MOSS_points, camel_lr, camel_feedback = train_curve(sgd(lrs=jnp.array(lrs)), curve_func=three_hump_camel, num_updates=iterations, init=start, scheduler="moss")

# Pool all trajectories; the axis limits below are derived from them.
all_points = np.concatenate([SGD_points1, SGD_MOSS_points, SGD_points2, SGD_points3], axis=0)

# Every run starts from the same point, so the first row of any trajectory gives it.
initial_position = (SGD_points1[0, 0], SGD_points1[0, 1])
global_minima = (0, 0)

# Plot the loss surface
plot_2d_curve(three_hump_camel,
              x_range=(-np.abs(all_points[:, 0]).max() - 0.5, np.abs(all_points[:, 0]).max() + 0.5),
              y_range=(all_points[:, 1].min() - 0.1, all_points[:, 1].max() + 0.1), title="Three-Hump Camel Function")

# One entry per trajectory: (points, colour, zorder of the path, legend label).
trajectories = [
    (SGD_points1, "red", 1, "η = $2x10^{-2}$"),
    (SGD_points2, "magenta", 2, "η = $1x10^{-2}$"),
    (SGD_points3, "green", 4, "η = $0.5x10^{-2}$"),
    (SGD_MOSS_points, "blue", 3, "LRRL"),
]

# Plot points
for path, colour, layer, legend_label in trajectories:
    plt.plot(path[:, 0], path[:, 1], color=colour, marker="o", markersize=4, zorder=layer, label=legend_label)

# Highlight initial position with a special marker
plt.scatter(*initial_position, color='gold', s=100, zorder=5, edgecolor='black', marker='^')
plt.text(initial_position[0] - 0.3, initial_position[1]  - 0.23, "Start", fontsize=10, color="gold")

# Highlight global minima with a special marker
plt.scatter(*global_minima, color='gold', s=120, zorder=6, edgecolor='black', marker='*')
plt.text(global_minima[0] - 0.8, global_minima[1] - 0.23, "Global Minima", fontsize=10, color="gold")

# Highlight final position of every trajectory with a special marker
for path, colour, layer, legend_label in trajectories:
    plt.scatter(path[-1, 0], path[-1, 1], color=colour, s=120, zorder=5, edgecolor='gold', marker='*')

# Legend ordering
handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 1, 2, 3]
plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order])

plt.show()

#--------------------------------------------------------------------------------------------------------------------------------#

# Zakharov function: three fixed learning rates versus MOSS-selected learning rates.
lrs = [0.005, 0.01, 0.02]
iterations = 100
start = [5, 1]

# The fixed-rate runs return no learning-rate / loss history, so those outputs are discarded.
SGD_points1, _lrs, _losses = train_curve(sgd(lrs=0.02), curve_func=zakharov_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points2, _lrs, _losses = train_curve(sgd(lrs=0.01), curve_func=zakharov_loss, num_updates=iterations, init=start, scheduler=False)
SGD_points3, _lrs, _losses = train_curve(sgd(lrs=0.005), curve_func=zakharov_loss, num_updates=iterations, init=start, scheduler=False)
SGD_MOSS_points, zakharov_lr, zakharov_feedback = train_curve(sgd(lrs=jnp.array(lrs)), curve_func=zakharov_loss, num_updates=iterations, init=start, scheduler="moss")

# Pool the fixed-rate trajectories; the axis limits below are derived from them.
# NOTE(review): the MOSS trajectory is not part of this pool here, so it does not
# influence the axis limits - confirm this is intended.
all_points = np.concatenate([SGD_points1, SGD_points2, SGD_points3], axis=0)

# Every run starts from the same point, so the first row of any trajectory gives it.
initial_position = (SGD_points1[0, 0], SGD_points1[0, 1])
global_minima = (0, 0)

# Plot the loss surface
plot_2d_curve(zakharov_loss,
              x_range=(-np.abs(all_points[:, 0]).max() - 1, np.abs(all_points[:, 0]).max() + 1),
              y_range=(all_points[:, 1].min() - 1, all_points[:, 1].max() + 1), title="Zakharov Function")

# One entry per trajectory: (points, colour, zorder of the path, legend label).
trajectories = [
    (SGD_points1, "red", 1, "η = $2x10^{-2}$"),
    (SGD_points2, "magenta", 2, "η = $1x10^{-2}$"),
    (SGD_points3, "green", 4, "η = $0.5x10^{-2}$"),
    (SGD_MOSS_points, "blue", 3, "LRRL"),
]

# Plot points
for path, colour, layer, legend_label in trajectories:
    plt.plot(path[:, 0], path[:, 1], color=colour, marker="o", zorder=layer, markersize=4, label=legend_label)

# Highlight initial position with a special marker
plt.scatter(*initial_position, color='gold', s=100, zorder=5, edgecolor='black', marker='^')
plt.text(initial_position[0] - 0.25, initial_position[1] + 0.1, "Start", fontsize=10, color="gold")

# Highlight global minima with a special marker
plt.scatter(*global_minima, color='gold', s=120, zorder=6, edgecolor='black', marker='*')
plt.text(global_minima[0] - 0.7, global_minima[1] + 0.15, "Global Minima", fontsize=10, color="gold")

# Highlight final position of every trajectory with a special marker
for path, colour, layer, legend_label in trajectories:
    plt.scatter(path[-1, 0], path[-1, 1], color=colour, s=120, zorder=5, edgecolor='gold', marker='*')

handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 1, 2, 3]
plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order])
plt.show()

#--------------------------------------------------------------------------------------------------------------------------------#