# Filled contour plot of a 2D objective function (Himmelblau's function), overlaid with Adam optimization trajectories and the function's optima
from numpy import arange
from numpy import meshgrid
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colormaps
import argparse

# Colors of matplotlib's configured axes property cycle; main() passes
# colors[i] as the markerfacecolor of the i-th trajectory.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
# Number of decimal places grad() rounds its inputs to before differentiating.
prec = 5

class Adam():
    """Adam optimizer over a list of parameters with optional Selective
    Projection Decay (enabled when ``k > 0``).

    Per-parameter state is stored in parallel lists indexed like the
    parameter list:

    * ``init``: initial values the projection decay pulls back towards.
    * ``m`` / ``v``: running first and second moment estimates.
    * ``condition_buffer``: running sum of ``-grad * (prev - init)`` that
      decides whether the decay is applied for that parameter.
    """

    def __init__(self, k, init=()):
        """Create the optimizer state.

        Args:
            k: Strength of the projection decay; ``step`` skips the decay
                entirely unless ``k > 0``.
            init: Iterable of initial parameter values, one per parameter.
        """
        self.k = k
        self.beta1, self.beta2 = 0.9, 0.999
        self.init = list(init)
        n_params = len(self.init)
        self.m = [0] * n_params
        self.v = [0] * n_params
        self.condition_buffer = [0] * n_params

    def step(self, t, param_list=(), grad_list=(), lr=0.1):
        """Apply one Adam update and return the new parameter values.

        Args:
            t: 1-based step counter used for bias correction (``t == 0``
                would make the correction denominators zero).
            param_list: Current parameter values.
            grad_list: Gradients, aligned with ``param_list``.
            lr: Learning rate for this step.

        Returns:
            A new list with the updated parameters, in the same order.
        """
        new_param_list = []
        for i, (param, grad) in enumerate(zip(param_list, grad_list)):
            # Decay the first and second moment running average coefficient
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * grad**2
            # Bias correction
            m = self.m[i] / (1 - self.beta1**t)
            v = self.v[i] / (1 - self.beta2**t)
            # Update
            new_param = param - lr * m/(v**(1/2)+1e-8)
            # Selective Projection Decay
            if self.k > 0:
                new_param = self.SDP(grad, self.init[i], new_param, param, i)
            new_param_list.append(new_param)
        return new_param_list

    def SDP(self, grad, init, curr, prev, i):
        """Apply Selective Projection Decay to one freshly updated parameter.

        Args:
            grad: Gradient used for the current step.
            init: Initial value of the parameter.
            curr: Parameter value after the Adam update.
            prev: Parameter value before the Adam update.
            i: Index of the parameter's slot in ``condition_buffer``.

        Returns:
            ``curr`` pulled towards ``init`` by ``k * ratio * (curr - init)``,
            where ``ratio`` is the relative growth of the distance from
            ``init`` clipped to ``[0, 1]``, or 0 when the condition buffer is
            not negative.
        """
        self.condition_buffer[i] -= grad*(prev - init)
        ratio = 0
        if self.condition_buffer[i] < 0:
            gamma_curr = np.linalg.norm(curr-init)
            gamma_prev = np.linalg.norm(prev-init)
            # When curr coincides with init there is nothing to project back;
            # skipping the division avoids a 0/0 -> NaN result.
            if gamma_curr > 0:
                ratio = (gamma_curr-gamma_prev)/gamma_curr
                ratio = np.minimum(np.maximum(0, ratio), 1.0)
        return curr - self.k * ratio * (curr - init)

def objective(x, y):
    """Evaluate Himmelblau's function at ``(x, y)``.

    The expression uses only arithmetic operators, so ``x`` and ``y`` may be
    scalars or numpy arrays (main() passes a mesh grid).
    """
    first_term = x**2 + y - 11
    second_term = x + y**2 - 7
    return first_term**2 + second_term**2

def grad(x,y):
    """Return the gradient ``(df/dx, df/dy)`` of Himmelblau's function.

    Both coordinates are rounded to ``prec`` decimal places before the
    gradient is evaluated.
    """
    x, y = round(x,prec), round(y,prec)
    # Inner terms of the two squared summands of the objective.
    a = x**2 + y - 11
    b = x + y**2 - 7
    d_dx = 4*a*x + 2*b
    d_dy = 4*b*y + 2*a
    return d_dx, d_dy

def opt(x_init, y_init, iters=20, lr=0.1, k = 0):
    """Run Adam on Himmelblau's function and record the trajectory.

    Args:
        x_init: Starting x coordinate.
        y_init: Starting y coordinate.
        iters: Number of optimization steps; 0 returns only the start point.
        lr: Initial learning rate, decayed linearly towards 0 over ``iters``.
        k: Projection decay strength forwarded to ``Adam``.

    Returns:
        ``(x_hist, y_hist)``: lists of visited coordinates, starting with the
        initial point, each of length ``iters + 1`` for non-negative ``iters``.
    """
    # Optimization Loop
    optimizer = Adam(k, init=[x_init,y_init])
    x_hist=[x_init]
    y_hist=[y_init]
    # With no iterations there is nothing to decay; avoid dividing by zero.
    lr_decay = lr/iters if iters > 0 else 0.0
    for t in range(iters):
        x_grad,y_grad = grad(x_hist[-1], y_hist[-1])
        # Adam's bias correction expects a 1-based step counter.
        x,y = optimizer.step(t+1, param_list=[x_hist[-1], y_hist[-1]], grad_list=[x_grad, y_grad], lr=lr)
        x_hist.append(x)
        y_hist.append(y)
        # Linear learning rate decay
        lr = np.maximum(0,lr - lr_decay)
    return x_hist, y_hist


def main(args):
    """Plot Himmelblau's function with Adam trajectories and save the figure.

    Args:
        args: Parsed arguments providing ``x_init``, ``y_init``, ``lr`` and
            ``iters``.

    The figure is written to ``./trajectory.png``.
    """
    # Define range for input
    r_min, r_max = -5.0, 5.0
    # Sample input range uniformly at 0.1 increments
    xaxis = arange(r_min, r_max, 0.1)
    yaxis = arange(r_min, r_max, 0.1)
    # Create a mesh from the axis
    x, y = meshgrid(xaxis, yaxis)
    # Compute targets
    results = objective(x, y)
    # Create a filled contour plot with 50 levels and jet color scheme
    plt.figure()
    plt.contourf(x, y, results, levels=50, cmap='jet',alpha=0.5)

    # One trajectory per decay strength k in {0, 0.25, 0.5, 0.75, 1.0}
    for i,k in enumerate(np.linspace(0,1.0,5)):
        # Optimization & Plot
        x_hist, y_hist = opt(args.x_init, args.y_init, iters=args.iters, lr=args.lr, k = k)
        # Raw string: the backslash belongs to the mathtext "\lambda", and
        # "\l" is not a valid escape sequence in a regular string literal.
        label = r"Adam-SPD($\lambda$:{})".format(str(round(k,2)))
        plt.plot(x_hist, y_hist, '-', color=colormaps['Oranges'](i*50), markerfacecolor=colors[i],label=label)

    # Draw the starting point as a yellow triangle
    plt.plot([args.x_init], [args.y_init], '^', color='yellow')
    # Draw the four minima of Himmelblau's function as white stars
    minima = [(3, 2), (-3.78, -3.28), (-2.81, 3.13), (3.58, -1.85)]
    for x_min, y_min in minima:
        plt.plot([x_min], [y_min], '*', color='white')
    plt.legend()
    plt.savefig("./trajectory.png",bbox_inches='tight',pad_inches=0, dpi = 500)

def get_args_parser():
    """Build the parent argument parser holding the script's options.

    The parser is created with ``add_help=False`` so it can be passed via
    ``parents=[...]`` to another parser that supplies ``-h/--help`` itself.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--x_init/-X``,
        ``--y_init/-Y``, ``--lr`` and ``--iters``.
    """
    # Use `description=`: the first positional argument of ArgumentParser is
    # `prog` (the program name shown in usage), not the description.
    parser = argparse.ArgumentParser(description='Himmelblau trajectory options', add_help=False)
    parser.add_argument('--x_init','-X', default=0.0, type=float, help="x initialization location")
    parser.add_argument('--y_init','-Y', default=-2.0, type=float, help="y initialization location")
    parser.add_argument('--lr', default=0.6, type=float, help="learning rate")
    parser.add_argument('--iters', default=60, type=int, help="number of optimization iterations")
    return parser

if __name__ == '__main__':
    # Example: python trajectory.py -X 0 -Y -2 --lr 0.6 --iters 60
    # The text is passed as `description`; given positionally it would be
    # taken as `prog` and replace the program name in the usage line.
    parser = argparse.ArgumentParser(description='Himmelblau function toy example', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)