import torch
from rayleigh_quotient import Rayleigh_Quotient
from tensorboardX import SummaryWriter
from My_SGDM import SGD
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from f_tail_xi import f_tail_list
import os

PATH = './output'
# Create each sub-folder independently with exist_ok=True: checking only
# PATH would skip creating './output/fig' or './output/data' when './output'
# already exists without them, and later saves would then fail.
# os.makedirs also creates PATH itself as an intermediate directory.
os.makedirs(os.path.join(PATH, 'fig'), exist_ok=True)
os.makedirs(os.path.join(PATH, 'data'), exist_ok=True)


def WN(lr, wd, s2, p):
    """Return the theoretical weight norm (lr * s2 * (p - 1) / (2 * wd)) ** (1/4).

    Used as the dashed 'theoretical' reference in the weight-norm figure.
    """
    numerator = lr * s2 * (p - 1)
    quotient = numerator / 2 / wd
    return quotient ** 0.25


def AU(lr, wd):
    """Return sqrt(2 * lr * wd), the theoretical reference value for Delta_t."""
    product = 2 * lr * wd
    return np.sqrt(product)


def UB(t, lr, wd, s2, p, a, xi, r0):
    """Theoretical upper bound on the risk r_t evaluated at times ``t``.

    With delta = sqrt(2 * lr * wd), the bound decays exponentially from
    ``r0`` towards delta / (t1 * xi) at rate delta * t1 * xi. The value
    returned by ``f_tail_list`` (computed on a fixed evaluation grid) is
    added to that curve.
    """
    # Evaluation grid for f_tail_list: 100 log-spaced points on [1e-6, 1e-2]
    # followed by 2000 evenly spaced points on [0.01, 1].
    log_part = np.logspace(-6, -2, 100, endpoint=True)
    lin_part = np.linspace(0.01, 1, 2000, endpoint=True)
    grid = np.concatenate((log_part, lin_part))

    delta = np.sqrt(2 * lr * wd)
    t1 = a / np.sqrt((p - 1) * s2) + p / (p - 1) * delta
    rate = delta * t1 * xi
    floor = delta / t1 / xi
    k = np.sqrt(p - 1) / np.sqrt(2 * lr * wd * s2) / 2
    tail = f_tail_list(grid, [k], p, xi)

    decay = np.exp(-rate * t)
    return floor + (r0 - floor) * decay + tail
    

def LB(t, lr, wd, s2, p , a, r0):
    """Theoretical lower bound on the risk r_t evaluated at times ``t``.

    With delta = sqrt(2 * lr * wd) and
    t1 = 2 * a / sqrt((p - 1) * s2) + p / (p - 1) * delta, the bound decays
    exponentially from ``r0`` towards delta / t1 at rate delta * t1.
    """
    delta = np.sqrt(2 * lr * wd)
    t1 = 2 * a / np.sqrt((p - 1) * s2) + p / (p - 1) * delta
    floor, rate = delta / t1, delta * t1
    decay = np.exp(-rate * t)
    return floor + (r0 - floor) * decay


def my_optimizer(toy_model, lr, wd):
    """Build the project ``SGD`` optimiser over ``toy_model.weight`` alone.

    Momentum is fixed at 0; ``lr`` and ``wd`` (weight decay) are passed through.
    """
    groups = [{'params': [toy_model.weight]}]
    return SGD(groups, lr=lr, momentum=0, weight_decay=wd)


def _line_color(k, n_curves):
    """Return the RGB colour of curve ``k`` out of ``n_curves``, fading from red to blue."""
    # With a single curve k / (n_curves - 1) would divide by zero; draw it red.
    frac = k / (n_curves - 1) if n_curves > 1 else 0.0
    return [1 - frac, 0, frac]


def _load_ckpt(path):
    """Read every array stored in the ``.npz`` archive at ``path`` into a dict.

    ``np.load`` keeps an ``.npz`` archive open until it is closed; the
    ``with`` block closes it once the arrays have been copied out.
    """
    with np.load(path) as data:
        return {key: data[key] for key in data.files}


def _plot_mean_std(T, samples, color, label, show_std):
    """Plot the mean of ``samples`` over axis 0 (repetitions) against ``T``.

    When ``show_std`` is true, a translucent band of +/- one standard
    deviation is drawn first, underneath the mean curve.
    """
    mean = np.mean(samples, axis=0)
    if show_std:
        std = np.std(samples, axis=0)
        plt.fill_between(T, mean + std, mean - std, facecolor=color, edgecolor='white', alpha=0.3)
    plt.plot(T, mean, linestyle='-', color=color, label=label)


def main():
    """Run the noisy Rayleigh-quotient SGD experiment and plot the results.

    For every (lr, wd) pair the model is trained ``rep`` times for
    ``max_iter`` steps. Every ``gap`` steps the risk, the weight norm, the
    ratio ``lr * ||noisy_grad|| / ||w||`` and the loss are recorded. The
    samples are saved to ``./output/data/ckpt_<k>.npz``, and four figures
    (weight norm, that ratio, loss, risk) are written to ``./output/fig``.
    """
    s2 = 1
    p = 100
    wd_list = [1e-3, 1e-3]
    lr_list = [5e-2, 5e-3]
    xi_list = [0.93, 0.97]

    max_iter = 16000
    gap = 200
    rep = 100

    data_dir = './output/data'
    fig_dir = './output/fig'
    # Ensure the folders exist even if ./output was only partially created.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(fig_dir, exist_ok=True)

    K = len(wd_list)
    # Samples are taken at iterations 0, gap, 2*gap, ... < max_iter, i.e.
    # ceil(max_iter / gap) points. Sizing the buffers with max_iter // gap
    # would overflow them whenever gap does not divide max_iter.
    T = np.arange(0, max_iter, gap)
    n_rec = len(T)
    ckpt_paths = [os.path.join(data_dir, 'ckpt_' + str(k) + '.npz') for k in range(K)]

    # A = diag(0, 1, ..., 1): the first coordinate has eigenvalue 0.
    D = torch.ones(p)
    D[0] = 0
    A = torch.diag(D)

    for k in range(K):
        lr_k, wd_k = lr_list[k], wd_list[k]
        # The model is initialised with the theoretical weight norm
        # WN(lr, wd, s2, p); how Rayleigh_Quotient uses `initial` is defined
        # elsewhere.
        w_ = WN(lr_k, wd_k, s2, p)

        risk_data = np.zeros((rep, n_rec))
        au_data = np.zeros((rep, n_rec))
        loss_data = np.zeros((rep, n_rec))
        wn_data = np.zeros((rep, n_rec))

        for i in range(rep):
            model = Rayleigh_Quotient(mat=A, noise_scale=s2, initial=w_)
            optimizer = my_optimizer(model, lr_k, wd_k)
            # The only milestone (2 * max_iter) is never reached and gamma is
            # 1, so the learning rate stays constant during the run.
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2 * max_iter], 1)

            j = 0
            for iteration in range(max_iter):
                loss = model.forward()
                grads = model.cal_grad()
                optimizer.step(grads)

                if iteration % gap == 0:
                    w = model.weight.clone().detach()
                    ng = model.noisy_grad.clone().detach()
                    w_n = torch.norm(w).item()
                    ng_n = torch.norm(ng).item()
                    lr = optimizer.param_groups[0]['lr']
                    b = w[0, 0].item()
                    # (b / ||w||)^2 is the squared share of w's norm held by
                    # its [0, 0] entry; min() keeps rounding from pushing the
                    # risk below 0.
                    risk = 1 - min((b / w_n) ** 2, 1.0)

                    risk_data[i, j] = risk
                    wn_data[i, j] = w_n
                    au_data[i, j] = lr * ng_n / w_n
                    loss_data[i, j] = loss

                    j += 1

                if iteration % 5000 == 0:
                    print(k, i, iteration, loss)

                scheduler.step()

        np.savez(ckpt_paths[k], risk=risk_data, loss=loss_data, au=au_data, wn=wn_data)

    #########################################################################################
    # Plotting. Each figure gets its own plt.figure() and is closed after
    # saving; otherwise every later plot is drawn on top of the previous one
    # and inherits its curves and legend entries.
    plt.rcParams.update({'font.size': 12})
    show_std = rep > 1
    colors = [_line_color(k, K) for k in range(K)]
    results = [_load_ckpt(path) for path in ckpt_paths]

    eta_labels = [r'$\eta=5\times10^{-2}, \lambda=1\times10^{-3}$ ', r'$\eta=5\times 10^{-3}, \lambda=1\times 10^{-3}$ ']
    # The sqrt values equal AU(lr, wd) = sqrt(2 * lr * wd) for the two settings.
    au_labels = [r'$au=\sqrt{1\times 10^{-4}}$ ', r'$au=\sqrt{1\times 10^{-5}}$ ']

    #########################################################################################
    # plot weight norm
    plt.figure()
    for k in range(K):
        _plot_mean_std(T, results[k]['wn'], colors[k], eta_labels[k] + 'experimental', show_std)
        plt.plot(T, WN(lr_list[k], wd_list[k], s2, p) * np.ones_like(T), linestyle='--', color=colors[k])
    # Data-less line: only supplies the legend entry for the dashed curves.
    plt.plot([], [], linestyle='--', color='black', label='theoretical')
    plt.xlabel("t", fontsize=14)
    plt.ylabel(r'$||X_t||_2$', fontsize=14)
    plt.xlim(0, max_iter)
    plt.ylim(2, 9)
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(fig_dir, 'nrq_wn.png'), dpi=800)
    plt.close()

    #########################################################################################
    # plot au
    plt.figure()
    for k in range(K):
        _plot_mean_std(T, results[k]['au'], colors[k], eta_labels[k] + 'experimental', show_std)
        plt.plot(T, AU(lr_list[k], wd_list[k]) * np.ones_like(T), linestyle='--', color=colors[k])
    plt.plot([], [], linestyle='--', color='black', label='theoretical')
    plt.xlabel("t", fontsize=14)
    plt.ylabel(r'$\Delta_t$', fontsize=14)
    plt.xlim(0, max_iter)
    plt.yscale('log')
    plt.ylim(2.5e-3, 0.012)
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(fig_dir, 'nrq_au.png'), dpi=800)
    plt.close()

    #########################################################################################
    # plot loss
    plt.figure()
    for k in range(K):
        _plot_mean_std(T, results[k]['loss'], colors[k], au_labels[k], show_std)
    plt.xlabel("t", fontsize=14)
    plt.xlim(0, max_iter)
    plt.ylabel(r'$L_t$', fontsize=14)
    plt.yscale('log')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(fig_dir, 'nrq_loss.png'), dpi=800)
    plt.close()

    #########################################################################################
    # plot risk
    plt.figure()
    for k in range(K):
        risk = results[k]['risk']
        # Both bounds start from the first recorded risk of repetition 0.
        r0 = risk[0, 0]
        plt.plot(T, LB(T, lr_list[k], wd_list[k], s2, p, 1, r0), color=colors[k], linestyle='--')
        plt.plot(T, UB(T, lr_list[k], wd_list[k], s2, p, 1, xi_list[k], r0), color=colors[k], linestyle=':')
        plt.plot(T, np.mean(risk, axis=0), linestyle='-', color=colors[k], label=au_labels[k] + 'experimental mean')
    plt.plot([], [], label='theoretical lower bounds', color='black', linestyle='--')
    plt.plot([], [], label='theoretical upper bounds', color='black', linestyle=':')
    plt.xlabel("t", fontsize=14)
    plt.ylabel(r'$r_t$', fontsize=14)
    plt.xlim(0, max_iter)
    plt.yscale('log')
    plt.grid()
    plt.legend()
    plt.savefig(os.path.join(fig_dir, 'nrq_risk.png'), dpi=800)
    plt.close()

if __name__ == "__main__":
    # Script entry point: run the experiment, then produce the figures.
    main()
