import torch
import copy
import numpy as np
import time
import argparse
import torch.nn.functional as Fn
from sklearn.utils.extmath import safe_sparse_dot
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.font_manager import FontProperties
import psutil as psutil

# Command-line configuration for the strongly-convex bilevel experiment.
parser = argparse.ArgumentParser(description='strongly_convex')
parser.add_argument('--seed', type=int, default=2)
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--x_loop', type=int, default=5000)
parser.add_argument('--y_loop', type=int, default=100)
parser.add_argument('--x_lr', type=float, default=0.1)
parser.add_argument('--y_lr', type=float, default=0.1)
parser.add_argument('--xSize', type=int, default=1000)
parser.add_argument('--ySize', type=int, default=1000)
parser.add_argument('--log', type=int, default=5)#10

args = parser.parse_args()

# Choose the device only after parsing so that --no-cuda is actually honoured
# (previously the flag was parsed but had no effect on `device`).
device = torch.device(
    "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

# Apply --seed: the problem data below is drawn with torch.rand, so without
# seeding the option was a no-op and every run used different data.
torch.manual_seed(args.seed)

print(args)

def loss_L2(parameters):
    """Return the sum of squared 2-norms of every tensor in *parameters*.

    An empty iterable yields the integer 0.
    """
    return sum(torch.norm(weight, 2) ** 2 for weight in parameters)

def positive_matrix(m):
    """Build a random symmetric positive-definite matrix and its inverse.

    The matrix is ``Q^T diag(lam) Q`` where ``lam`` is drawn uniformly from
    [1, 2) and ``Q`` is the orthogonal factor of the QR decomposition of a
    random ``m x m`` matrix. The inverse is formed analytically as
    ``Q^T diag(1 / lam) Q``.

    Args:
        m: Size of the square matrix.

    Returns:
        Tuple ``(matrix, invmatrix)`` of two ``m x m`` tensors.
    """
    eigenvalues = torch.rand(m) + 1
    # torch.qr is deprecated; torch.linalg.qr (default mode 'reduced') is its
    # documented replacement. Only the orthogonal factor is needed.
    Q, _ = torch.linalg.qr(torch.rand(m, m))
    Qt = Q.t()
    matrix = torch.mm(torch.mm(Qt, torch.diag(eigenvalues)), Q)
    invmatrix = torch.mm(torch.mm(Qt, torch.diag(1 / eigenvalues)), Q)
    return matrix, invmatrix

# Problem setting: random SPD matrix A, random target z0, and the reference
# solution (xstar, ystar) used to measure progress.
A, invA = positive_matrix(args.ySize)
# The inverse returned above is deliberately replaced by a numerically
# computed one.
invA = torch.inverse(A)
z0 = torch.rand(args.xSize, 1)
D = torch.eye(args.xSize)
# xstar solves (invA + D) x = z0; ystar = invA xstar.
invaD = torch.inverse(invA + D)
xstar = invaD @ z0
ystar = invA @ xstar

def F(x, y):
    """Upper-level objective 0.5 * (x - z0)^T (x - z0) + 0.5 * y^T A y.

    Uses the module-level ``z0`` and ``A``; inputs must be 2-D (torch.mm).
    """
    residual = x - z0
    fit_term = torch.mm(residual.t(), residual)
    quad_term = torch.mm(y.t(), A @ y)
    return 0.5 * fit_term + 0.5 * quad_term

def f(x, y):
    """Lower-level objective 0.5 * y^T A y - x^T y (uses module-level ``A``)."""
    quad_term = torch.mm(y.t(), A @ y)
    coupling = torch.mm(x.t(), y)
    return 0.5 * quad_term - coupling

#calculate gradient
def f_y(x, y, retain_graph=False, create_graph=False):
    """Return the gradient of ``f(x, y)`` with respect to ``y``.

    ``retain_graph`` and ``create_graph`` are forwarded to
    ``torch.autograd.grad``.
    """
    (dy,) = torch.autograd.grad(
        f(x, y), y, retain_graph=retain_graph, create_graph=create_graph)
    return dy

def f_x(x, y, retain_graph=False, create_graph=False):
    """Return the gradient of ``f(x, y)`` with respect to ``x``.

    ``retain_graph`` and ``create_graph`` are forwarded to
    ``torch.autograd.grad``.
    """
    (dx,) = torch.autograd.grad(
        f(x, y), x, retain_graph=retain_graph, create_graph=create_graph)
    return dx

def F_y(x, y, retain_graph=False, create_graph=False):
    """Return the gradient of ``F(x, y)`` with respect to ``y``.

    ``retain_graph`` and ``create_graph`` are forwarded to
    ``torch.autograd.grad``.
    """
    (dy,) = torch.autograd.grad(
        F(x, y), y, retain_graph=retain_graph, create_graph=create_graph)
    return dy

def F_x(x, y, retain_graph=False, create_graph=False):
    """Return the gradient of ``F(x, y)`` with respect to ``x``.

    ``retain_graph`` and ``create_graph`` are forwarded to
    ``torch.autograd.grad``.
    """
    (dx,) = torch.autograd.grad(
        F(x, y), x, retain_graph=retain_graph, create_graph=create_graph)
    return dx

def f_xy(x, y, vs):
    """Differentiate ``vs^T * grad_y f(x, y)`` with respect to ``x``.

    Returns a zero tensor shaped like ``x`` when autograd reports that the
    first-order gradient does not depend on ``x``.
    """
    # First-order gradient w.r.t. y, kept differentiable for the second pass.
    first_order = torch.autograd.grad(
        f(x, y), y,
        retain_graph=True,
        create_graph=True,
        allow_unused=True,
        only_inputs=True,
    )[0]
    first_order.requires_grad_(True)
    mixed = torch.autograd.grad(
        first_order, x,
        grad_outputs=vs,
        retain_graph=True,
        allow_unused=True,
    )[0]
    if mixed is None:
        return torch.zeros_like(x)
    return mixed
      
def f_yy(x, y, vs):
    """Differentiate ``vs^T * grad_y f(x, y)`` with respect to ``y``.

    Returns a zero tensor shaped like ``y`` when autograd reports that the
    first-order gradient does not depend on ``y``.
    """
    # First-order gradient w.r.t. y, kept differentiable for the second pass.
    first_order = torch.autograd.grad(
        f(x, y), y,
        retain_graph=True,
        create_graph=True,
        allow_unused=True,
        only_inputs=True,
    )[0]
    first_order.requires_grad_(True)
    second_order = torch.autograd.grad(
        first_order, y,
        grad_outputs=vs,
        retain_graph=True,
        allow_unused=True,
    )[0]
    if second_order is None:
        return torch.zeros_like(y)
    return second_order

def f_y_yhat_x(y, yhat, x, retain_graph=False, create_graph=False):
    """Evaluate ``f(x, y) - f(x, yhat)`` and its gradients w.r.t. ``y`` and ``x``.

    ``yhat`` is detached, so it contributes no gradient of its own.

    Returns:
        Tuple ``(gap, grad_y, grad_x)``.
    """
    gap = f(x, y) - f(x, yhat.detach())
    grad_y, grad_x = torch.autograd.grad(
        gap, [y, x], retain_graph=retain_graph, create_graph=create_graph)
    return gap, grad_y, grad_x



def _flat_grad(grad):
    """Detach a gradient tensor and view it as a 1-D NumPy vector."""
    # reshape(-1) rather than np.squeeze so a length-1 problem stays 1-D.
    return grad.detach().numpy().reshape(-1)


def _column_tensor(vec):
    """View a 1-D NumPy vector as an (n, 1) torch tensor (memory is shared)."""
    return torch.from_numpy(np.expand_dims(vec, axis=1))


def _push_secant_pair(s_hist, y_hist, mu_hist, s, y_diff, m):
    """Store (s, y_diff, 1 / <y_diff, s>) if curvature is positive; keep at most m pairs."""
    curvature = y_diff @ s
    if curvature > 1e-10:
        y_hist.append(y_diff.copy())
        s_hist.append(s.copy())
        mu_hist.append(1 / curvature)
    if len(y_hist) > m:
        y_hist.pop(0)
        s_hist.pop(0)
        mu_hist.pop(0)


def bfgs(x, y, tol, step, maxiter_hg, m, h0, ex_up=False):
    """Run the inner quasi-Newton solve in ``y`` and estimate the correction ``et``.

    The first two inner steps are plain gradient steps on ``f`` with step size
    0.1; from the third step on, the direction comes from ``two_loops`` using
    the secant pairs collected so far. Afterwards ``grad_y F`` is mapped
    through ``-two_loops(...)``:

    * ``ex_up`` false: once, using the pairs collected during the inner solve;
    * ``ex_up`` true: ``maxiter_hg`` times with a separate history that starts
      empty and gains one pair per pass, probed along the current estimate.
      The returned ``et`` is the estimate computed in the last pass.

    Args:
        x: Upper-level variable (tensor requiring grad).
        y: Lower-level variable (tensor requiring grad); not modified in place.
        tol: Unused; kept for interface compatibility.
        step: Number of inner iterations, at least 1.
        maxiter_hg: Number of extra passes when ``ex_up`` is true, at least 1.
        m: Maximum number of stored secant pairs.
        h0: Forwarded to ``two_loops``, which currently ignores it.
        ex_up: Selects the extra-update variant described above.

    Returns:
        Tuple ``(y, et)``: the updated lower-level iterate (a fresh leaf that
        requires grad) and the correction as an ``(n, 1)`` tensor.

    Raises:
        ValueError: If ``step < 1``, or ``ex_up`` is true and
            ``maxiter_hg < 1``. These cases previously failed later with
            NameError / UnboundLocalError.
    """
    if step < 1:
        raise ValueError("step must be at least 1")
    if ex_up and maxiter_hg < 1:
        raise ValueError("maxiter_hg must be at least 1 when ex_up is True")

    y_list, s_list, mu_list = [], [], []
    y1_list, s1_list, mu1_list = [], [], []

    for k in range(1, step + 1):
        if k < 3:
            # Warm-up: plain gradient step on f.
            s = -f_y(x, y)
            step_tensor = 0.1 * s
        else:
            s = two_loops(grady, m, s_list, y_list, mu_list, h0)  # default H0=I
            step_tensor = _column_tensor(s)
        # Detach so y does not drag an ever-growing autograd history through
        # the inner and outer loops; the values are unchanged.
        y = (y + step_tensor).detach().requires_grad_(True)
        ngrad = _flat_grad(f_y(x, y))  # \nabla_y f(x_k, y_{k+1})
        if k >= 3:
            _push_secant_pair(s_list, y_list, mu_list, s, ngrad - grady, m)
        grady = ngrad

    gradFy = _flat_grad(F_y(x, y))  # \nabla_y F(x_k, y_{k+1})

    if not ex_up:
        et = _column_tensor(-two_loops(gradFy, m, s_list, y_list, mu_list, h0))
    else:
        for _ in range(maxiter_hg):
            eq = -two_loops(gradFy, m, s1_list, y1_list, mu1_list, h0)  # default H0=I
            et = _column_tensor(eq)
            # Probe the gradient of f along the current estimate to obtain a
            # new secant pair for the separate history.
            f1grad = _flat_grad(f_y(x, y + et))
            _push_secant_pair(s1_list, y1_list, mu1_list, eq, f1grad - grady, m)

    print(f'{k} iterates')
    return y, et

def two_loops(grad_y, m, s_list, y_list, mu_list, h0):
    """Apply the two-loop recursion to ``grad_y`` and return the negated result.

    The stored triples ``(s, y, mu)`` are walked newest-to-oldest and then
    oldest-to-newest. The initial matrix is the identity: ``h0`` is accepted
    but not used, and ``m`` is likewise unused. With empty history the result
    is simply ``-grad_y``.

    Args:
        grad_y: 1-D NumPy vector; it is copied, not modified.
        m: Unused; kept for interface compatibility.
        s_list: Stored step vectors, oldest first.
        y_list: Stored gradient-difference vectors, oldest first.
        mu_list: Stored scalars ``1 / <y, s>``, oldest first.
        h0: Unused; kept for interface compatibility.

    Returns:
        A new 1-D NumPy vector.
    """
    q = grad_y.copy()
    alphas = []
    # Backward pass (newest pair first). Plain ``@`` is what sklearn's
    # safe_sparse_dot evaluates for these inputs, so results are unchanged.
    for s, y, mu in zip(reversed(s_list), reversed(y_list), reversed(mu_list)):
        alpha = mu * (s @ q)
        alphas.append(alpha)
        q -= alpha * y

    # Identity initial matrix: the forward pass starts from q itself.
    r = q
    for s, y, mu, alpha in zip(s_list, y_list, mu_list, reversed(alphas)):
        beta = mu * (y @ r)
        r += (alpha - beta) * s
    return -r


# Initialization: both iterates start as constant column vectors.
x0 = 2
y0 = 2
x = torch.full((args.xSize, 1), float(x0)).requires_grad_(True)
y = torch.full((args.ySize, 1), float(y0)).requires_grad_(True)
x_loop = args.x_loop


# Baseline metrics at the starting point: the reference gradient
# (D + invA) x - z0, its norm, and the relative distances to (xstar, ystar).
with torch.no_grad():
    xgard0 = (D + invA) @ x - z0
    dx0 = xgard0.norm()
    xdis0 = (x - xstar).norm() / xstar.norm()
    ydis0 = (y - ystar).norm() / ystar.norm()
    print(dx0)
    print(xdis0)
xgrad = torch.zeros(args.xSize, 1)


# Experiments: the same outer loop is run six times, varying only the number
# of extra passes Q (``maxiter_hg``), the memory size ``m`` and ``ex_up``.
# The six runs were previously six copy-pasted loops.
x0 = 2
y0 = 2


def _as_float(value):
    """Return the Python float held by a one-element tensor."""
    return value.detach().cpu().item()


def run_experiment(maxiter_hg, m, ex_up, outer_iters=100):
    """Run the outer loop once and record its convergence history.

    Args:
        maxiter_hg: Value passed to ``bfgs`` as ``maxiter_hg``. May be an int,
            or a callable mapping the outer iteration index to an int.
        m: Memory size passed to ``bfgs``.
        ex_up: Passed to ``bfgs`` as ``ex_up``.
        outer_iters: Number of outer iterations.

    Returns:
        Tuple ``(xdis, ydis, dx, times)`` of lists with ``outer_iters + 1``
        floats each. Entry 0 is the starting point (``xdis0``, ``ydis0``,
        ``dx0``, 0.0). ``dx`` holds the norm of the reference gradient
        ``(D + invA) x - z0`` and ``times`` the cumulative seconds spent in
        the solver and the update of ``x``.
    """
    x = (float(x0) * torch.ones([args.xSize, 1])).requires_grad_(True)
    y = (float(y0) * torch.ones([args.ySize, 1])).requires_grad_(True)

    xdis_hist = [_as_float(xdis0)]
    ydis_hist = [_as_float(ydis0)]
    dx_hist = [_as_float(dx0)]
    total_time = 0.0
    time_hist = [total_time]

    # Constant across iterations, so build it once.
    ref_matrix = D + invA
    xstar_norm = torch.norm(xstar)
    ystar_norm = torch.norm(ystar)

    for x_itr in range(outer_iters):
        passes = maxiter_hg(x_itr) if callable(maxiter_hg) else maxiter_hg

        t0 = time.time()
        y, et = bfgs(x, y, tol=1 / (100 * (x_itr + 1)), step=15,
                     maxiter_hg=passes, m=m, h0=0.1, ex_up=ex_up)
        xgrad = F_x(x, y) + et
        # Detach so x does not accumulate autograd history; values unchanged.
        x = (x - 0.1 * xgrad).detach().requires_grad_(True)
        total_time += time.time() - t0

        with torch.no_grad():
            xgard = torch.mm(ref_matrix, x) - z0
            dx = torch.norm(xgrad - xgard)
            dx1 = torch.norm(xgard)
            xdis = torch.norm(x - xstar) / xstar_norm
            ydis = torch.norm(y - ystar) / ystar_norm

        if x_itr % args.log == 0:
            print('x_itr={},xdist={:.6f},ydist={:.6f}, total_time={:.6f}'.format(
                x_itr, _as_float(xdis), _as_float(ydis), total_time))
            print(torch.norm(xgrad))
            print(torch.norm(xgard))
            print(dx)

        time_hist.append(total_time)
        dx_hist.append(_as_float(dx1))
        xdis_hist.append(_as_float(xdis))
        ydis_hist.append(_as_float(ydis))

    return xdis_hist, ydis_hist, dx_hist, time_hist


# NOTE: the numeric suffixes of the result names are historical and do not
# match the Q actually used; the real setting is given on each line.

# Q = k (number of passes grows with the outer iteration), m = 30
xdislistfoaeqk, ydislistfoaeqk, dxlistfoaeqk, timelistfoaeqk = run_experiment(
    lambda k: k + 1, m=30, ex_up=True)

# Q = 1, no extra update, m = 30
xdislistfoaeq1, ydislistfoaeq1, dxlistfoaeq1, timelistfoaeq1 = run_experiment(
    1, m=30, ex_up=False)

# Q = 10, m = 30
xdislistfoaeq5, ydislistfoaeq5, dxlistfoaeq5, timelistfoaeq5 = run_experiment(
    10, m=30, ex_up=True)

# Q = 20, m = 30
xdislistfoaeq10, ydislistfoaeq10, dxlistfoaeq10, timelistfoaeq10 = run_experiment(
    20, m=30, ex_up=True)

# Q = 50, m = 50
xdislistfoaeq15, ydislistfoaeq15, dxlistfoaeq15, timelistfoaeq15 = run_experiment(
    50, m=50, ex_up=True)

# Q = 100, m = 100
xdislistfoaeq20, ydislistfoaeq20, dxlistfoaeq20, timelistfoaeq20 = run_experiment(
    100, m=100, ex_up=True)

    

lw = 2.5


# One solid, width-6 line per setting; only the colour differs.
line_styles = {
    label: {'color': color, 'linestyle': '-', 'linewidth': 6}
    for label, color in (
        ('Q=1', 'C9'),
        ('Q=10', 'green'),
        ('Q=20', 'C0'),
        ('Q=50', 'orange'),
        ('Q=100', 'purple'),
        ('Q=k', '#D02020'),
    )
}

# Proxy artists so the legend can be drawn in its own figure.
legend_elements = []
for label, style in line_styles.items():
    legend_elements.append(mlines.Line2D([], [], label=label, **style))

# Figure 1: reference-gradient norm per outer iteration, log scale.
plt.figure(figsize=(10, 9))
data = {
    'Q=1': dxlistfoaeq1,
    'Q=10': dxlistfoaeq5,
    'Q=20': dxlistfoaeq10,
    'Q=50': dxlistfoaeq15,
    'Q=100': dxlistfoaeq20,
    'Q=k': dxlistfoaeqk
}

for label in data:
    plt.plot(data[label], label=label, **line_styles[label])


plt.xlabel('Iter: k', fontsize=30, fontweight='bold')
plt.ylabel(r'$||\nabla\Phi(x_k)||$', fontsize=30, fontweight='bold')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.yscale('log')
plt.grid(visible=True, which='major', linestyle='-.', alpha=0.7)


legend_font = FontProperties(weight='bold', size=10)

plt.savefig('foadqi.pdf', dpi=300, bbox_inches='tight')

# Figure 2: a stand-alone legend with the axes hidden.
plt.figure(figsize=(2, 0.5))
plt.legend(
    handles=legend_elements,
    ncol=3,
    fontsize=8,
    prop=legend_font,
    borderpad=1,
    loc='upper left',
    handlelength=2,
    handletextpad=2,
)

plt.axis('off')
plt.savefig('custom_legend.pdf', dpi=300, bbox_inches='tight')



