import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython import embed
import pickle
import random
import os
from IPython import embed


def _simulate(seed, rescale_on_decay, iters, decay_at, progress):
    """Run one noisy-SGD trajectory on a scale-invariant quadratic loss.

    The loss is ``L(x) = x^T A x / (2 ||x||^2)`` with ``A = diag(0, 1, ..., 1)``
    in dimension ``p = 100``.  Each step uses the exact gradient of ``L``
    plus Gaussian noise projected onto the tangent space at ``x``, and the
    update ``x <- x - eta * (grad + Lambda * x)`` (weight decay ``Lambda``).

    Args:
        seed: Base seed; NumPy's global RNG is seeded with ``seed + 52`` and
            Python's ``random`` module with ``seed + 10001``.
        rescale_on_decay: If true, ``x`` is divided by ``10 ** 0.25`` at the
            step where the learning rate is divided by 10.
        iters: Number of SGD steps to simulate.
        decay_at: Step index at which ``eta`` is divided by 10.
        progress: Wrap the step loop in a ``tqdm`` progress bar when true.

    Returns:
        ``np.ndarray`` of shape ``(3, iters)`` whose rows are, per step and
        measured *before* the update: the loss ``L(x)``, the quantity
        ``eta * ||grad|| / ||x||``, and the norm ``||x||``.
    """
    # Global seeding is kept (rather than a local Generator) so that results
    # stay reproducible against previously generated data.
    np.random.seed(seed + 52)
    random.seed(seed + 10001)

    eta = 0.1
    Lambda = 0.01
    p = 100
    sigma = 1

    # A has a single zero eigenvalue (first coordinate) and ones elsewhere.
    a = np.ones(p)
    a[0] = 0
    A = np.diag(a)
    I = np.eye(p)

    def L(x):
        # Invariant to rescaling of x: numerator and denominator are both
        # quadratic in x.
        return x @ A @ x / (x @ x) / 2

    def P(x):
        # (I - x x^T / ||x||^2) / ||x||: projection onto the orthogonal
        # complement of x, scaled by 1 / ||x||.
        norm = np.sqrt(x @ x)
        x_ = x.reshape(-1, 1) / norm
        return 1 / norm * (I - x_ @ x_.T)

    # Computed from the *initial* eta; its fourth root is the starting norm.
    norm2_theorem = (p - 1) * sigma**2 * eta / 2 / Lambda

    REC_Loss = []
    REC_au = []
    REC_norm = []

    x = np.random.randn(p)
    x = x / np.sqrt(x @ x) * norm2_theorem**0.25

    steps = range(iters)
    if progress:
        steps = tqdm(steps)

    for i in steps:
        if i == decay_at:
            eta /= 10
            if rescale_on_decay:
                # norm2_theorem is proportional to eta, so dividing eta by 10
                # scales norm2_theorem ** 0.25 by 10 ** -0.25; apply the same
                # factor to x.
                x /= (10)**0.25
        REC_Loss.append(L(x))
        norm_ = np.sqrt(x @ x)
        # P(x) @ (A x / ||x||) is the exact gradient of L; the second term is
        # isotropic Gaussian noise pushed through the same projection.
        grad = P(x) @ (A @ x / norm_ + np.random.randn(p,) * sigma)
        REC_norm.append(norm_)
        au = np.sqrt(grad @ grad) * eta / norm_
        REC_au.append(au)
        x = x - eta * (grad + Lambda * x)

    return np.array([REC_Loss, REC_au, REC_norm])


def generate_data(seed, *, iters=30000, decay_at=10000, progress=True):
    """Simulate one run where only the learning rate drops at ``decay_at``.

    Args:
        seed: Base RNG seed for the run.
        iters: Number of SGD steps (default 30000).
        decay_at: Step at which the learning rate is divided by 10
            (default 10000).
        progress: Show a ``tqdm`` progress bar (default True).

    Returns:
        ``np.ndarray`` of shape ``(3, iters)``: rows are loss,
        ``eta * ||grad|| / ||x||``, and ``||x||`` at every step.
    """
    return _simulate(seed, False, iters, decay_at, progress)


def generate_data2(seed, *, iters=30000, decay_at=10000, progress=True):
    """Simulate one run where ``x`` is also rescaled when the rate drops.

    Identical to ``generate_data`` except that, at step ``decay_at``, the
    iterate is additionally divided by ``10 ** 0.25``.

    Args:
        seed: Base RNG seed for the run.
        iters: Number of SGD steps (default 30000).
        decay_at: Step at which the learning rate is divided by 10 and the
            iterate is rescaled (default 10000).
        progress: Show a ``tqdm`` progress bar (default True).

    Returns:
        ``np.ndarray`` of shape ``(3, iters)``: rows are loss,
        ``eta * ||grad|| / ||x||``, and ``||x||`` at every step.
    """
    return _simulate(seed, True, iters, decay_at, progress)

def _plot_runs(axes, path, label, color):
    """Draw the mean (and +/- one std band) of one pickled batch of runs.

    Args:
        axes: Sequence ``(ax1, ax2, ax3, ax4)`` of matplotlib axes.
        path: Pickle file holding a 3-D array; axis 0 is averaged over and
            axis 1 is indexed 0..2 (presumably stacked ``generate_data``
            outputs, i.e. ``(runs, 3, iters)`` -- confirm against the writer).
        label: Legend label for the curves.
        color: Matplotlib color for the curves and bands.
    """
    ax1, ax2, ax3, ax4 = axes

    # NOTE(security): pickle.load can execute arbitrary code; only open files
    # that this script itself produced.
    with open(path, 'rb') as fp:
        RECs = pickle.load(fp)

    means = np.mean(RECs, axis=0)
    stds = np.std(RECs, axis=0)
    Iters = range(RECs.shape[2])

    # Figure 1 shows the mean of row 0 without a std band.
    ax1.plot(Iters, means[0, :], label=label, color=color)

    # Figures 2-4 show rows 0, 1, 2 respectively, each with a +/- std band.
    for ax, row in ((ax2, 0), (ax3, 1), (ax4, 2)):
        ax.plot(Iters, means[row, :], label=label, color=color)
        ax.fill_between(
            Iters,
            means[row, :] - stds[row, :],
            means[row, :] + stds[row, :],
            alpha=0.3,
            facecolor=color,
        )


def plot():
    """Plot the two stored experiments and save four PNG figures.

    Reads ``output/data.txt`` ("Normal", blue) and ``output/data2.txt``
    ("Rescale", red), then writes ``risk.png``, ``loss.png``, ``au.png`` and
    ``norm.png`` under ``output/figs/``, creating that directory if needed.

    Raises:
        FileNotFoundError: If either input pickle file is missing.
    """
    figs = []
    axes = []
    for num in (1, 2, 3, 4):
        fig = plt.figure(num)
        figs.append(fig)
        axes.append(fig.add_subplot())

    _plot_runs(axes, 'output/data.txt', r'Normal', "blue")
    _plot_runs(axes, 'output/data2.txt', r'Rescale', "red")

    # savefig does not create missing parent directories.
    os.makedirs('output/figs', exist_ok=True)

    # (y-axis label, use log y-scale, output file) for figures 1..4.
    panels = (
        (r'$r_t$', True, 'output/figs/risk.png'),
        (r'$L_t$', True, 'output/figs/loss.png'),
        (r'$\Delta_t$', True, 'output/figs/au.png'),
        (r'$||X_t||_2$', False, 'output/figs/norm.png'),
    )
    for fig, ax, (ylabel, log_scale, out_path) in zip(figs, axes, panels):
        ax.set_xlabel('t', fontsize=14)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.grid()
        ax.legend(fontsize=12)
        if log_scale:
            ax.set_yscale('log')
        fig.savefig(out_path)


if __name__ == "__main__":
    # Script entry point.  Only plotting is active: plot() reads the pickled
    # results already stored in output/data.txt and output/data2.txt.
    #
    # To regenerate those files, uncomment the recipe below.  It runs each
    # generator for 100 seeds (1002..1101), stacks the per-run arrays along
    # a new leading axis, and pickles the result.
    #
    # RECs = []
    # for i in range(100):
    #     rec = generate_data(1002+i)
    #     RECs.append(rec)
    # RECs = np.stack(RECs, axis=0)
    # with open('output/data.txt', 'wb') as fp:
    #     pickle.dump(RECs, fp)
    #
    # RECs = []
    # for i in range(100):
    #     rec = generate_data2(1002+i)
    #     RECs.append(rec)
    # RECs = np.stack(RECs, axis=0)
    # with open('output/data2.txt', 'wb') as fp:
    #     pickle.dump(RECs, fp)

    plot()