import numpy as np
import matplotlib.pyplot as plt
from train import train
from data import gen_data
# Global matplotlib font settings, applied once at import time.
plt.rc('font', family="Arial", size='14')
# First nine entries of the 'Set1' colormap, kept as a module-level palette.
colors = list(map(plt.get_cmap('Set1'), range(9)))


def leaky_sweep(args):
    """Plot simulated and theoretical loss curves for a sweep of slopes.

    For each of the six values ``s`` in ``np.linspace(0, 1, 6)`` the function
    sets ``args.relu = s``, rescales ``args.init`` by ``sqrt(2 / (s + 1))``
    relative to its value on entry, calls ``train(data, args)`` and plots
    ``results['Ls']`` together with the curve returned by ``theo_sol``.
    Curves are coloured by ``plt.cm.plasma_r(s)``; only the curves of the
    last slope carry legend labels.

    Args:
        args: Configuration object.  ``args.init`` and ``args.epoch`` are read
            here; ``args.relu`` and ``args.init`` are overwritten in place on
            every iteration.  The object is also forwarded to ``gen_data``,
            ``train`` and ``theo_sol``.

    Returns:
        None.  The figure is displayed with ``plt.show()`` and the
        initialisation used for each run is printed.
    """
    _, ax = plt.subplots(figsize=(4, 3.5))

    slope = np.linspace(0, 1, 6)
    data = gen_data(args)
    init = args.init
    last = len(slope) - 1
    for i, s in enumerate(slope):
        args.relu = s
        args.init = np.sqrt(2/(s+1)) * init
        results = train(data, args)
        color = plt.cm.plasma_r(s)
        # Each curve is drawn exactly once.  Only the final slope is labelled
        # so the legend gets a single 'Simulation' and a single 'Theory' entry.
        sim_kwargs = {'label': 'Simulation'} if i == last else {}
        theory_kwargs = {'label': 'Theory'} if i == last else {}
        ax.plot(results['Ls'], linewidth=1, c=color, **sim_kwargs)
        ax.plot(theo_sol(args, data), linestyle=':', linewidth=3, alpha=0.6,
                c=color, **theory_kwargs)
        print("init =", args.init)
    # The upper y-limit is taken from the first recorded loss of the last run.
    ylim = results['Ls'][0]+0.1
    ax.set_ylim((-0.05, ylim))
    ax.set_xlim((0, args.epoch))
    ax.set_xticks((0, args.epoch//2, args.epoch))
    ax.set_yticks([0,0.5,1])
    ax.set_xlabel("$t$")
    ax.set_ylabel("Loss")
    # Without a legend the labels assigned above would never be shown.
    ax.legend()
    plt.tight_layout(pad=0.5)
    plt.show()


def compare(args):
    """Plot loss curves per slope and their weight error against a reference.

    A reference run is trained with ``args.relu = 1``.  Then, for each slope
    ``alpha`` in ``np.linspace(0, 1, 6)``, ``args.relu`` is set to ``alpha``,
    ``args.init`` is scaled by ``sqrt(2 / (alpha + 1))`` and ``args.lr`` by
    ``2 / (alpha + 1)`` (both relative to their values on entry) before
    training again.  Losses go on the left axis; for every slope other than 1
    the output of ``weight_error`` goes on the right (twin) axis.

    Args:
        args: Configuration object.  ``args.init``, ``args.lr`` and
            ``args.epoch`` are read; ``args.relu``, ``args.init`` and
            ``args.lr`` are overwritten in place.  Also forwarded to
            ``gen_data`` and ``train``.

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    _, loss_ax = plt.subplots(figsize=(4, 3.5))
    loss_ax.set_xlabel(r'$\frac{2}{\alpha +1}t$')
    loss_ax.set_ylabel('Loss')
    err_ax = loss_ax.twinx()
    err_ax.set_ylabel('Error')

    alphas = np.linspace(0, 1, 6)
    data = gen_data(args)

    base_init, base_lr, n_steps = args.init, args.lr, args.epoch

    # Reference trajectory that every other slope is compared against.
    args.relu = 1
    results_lin = train(data, args)

    for alpha in alphas:
        args.relu = alpha
        args.init = np.sqrt(2/(alpha+1)) * base_init
        args.lr = base_lr * 2/(alpha+1)
        results = train(data, args)
        shade = plt.cm.plasma_r(alpha)
        if alpha == 1:
            loss_ax.plot(results['Ls'], linewidth=1, c=shade, label='Loss')
            # Single-point dashed line: exists only to give the legend an
            # 'Error' entry on the loss axis.
            loss_ax.plot(0, 0, linestyle='--', linewidth=1, c=shade, label='Error')
        else:
            loss_ax.plot(results['Ls'], linewidth=1, c=shade)
            deviation = weight_error(results_lin, results, n_steps, alpha)
            err_ax.plot(deviation, linestyle='--', linewidth=1, c=shade)

    # Axis limits derive from the first recorded loss of the last run.
    top = results['Ls'][0]+0.1
    err_ax.set_xlim((0, n_steps))
    err_ax.set_ylim((-0.0005, top/100))
    err_ax.set_xticks((0, n_steps//2, n_steps))
    err_ax.set_yticks([0,.005,.01], ['0', '0.5%', '1%'])
    loss_ax.set_ylim((-0.05, top))
    loss_ax.set_yticks([0,.5,1])
    loss_ax.legend()
    plt.tight_layout(pad=0.5)
    plt.show()


def weight_error(res_lin, res_relu, T, alpha):
    """Relative distance between two weight trajectories, one value per step.

    At every step ``t < T`` each layer of ``res_relu['layer'][t]`` is
    multiplied by ``sqrt((alpha + 1) / 2)`` and subtracted from the matching
    layer of ``res_lin['layer'][t]``.  The squared norms of the differences
    are summed over layers, and the square root of that sum is divided by the
    square root of the summed squared norms of the reference layers.

    Args:
        res_lin: Mapping whose ``'layer'`` entry is indexable by step and then
            by layer; each layer must support ``.cpu().detach().numpy()``.
        res_relu: Same structure as ``res_lin``; the trajectory being compared.
        T: Number of steps to evaluate.
        alpha: Slope used to build the rescaling factor.

    Returns:
        A 1-D ``numpy`` array of length ``T`` with the relative error per step.
    """
    rescale = np.sqrt((alpha+1)/2)
    error = np.zeros(T)
    for step in range(T):
        ref_sq_norm = 0
        for idx, ref_layer in enumerate(res_lin['layer'][step]):
            w_cmp = res_relu['layer'][step][idx].cpu().detach().numpy()
            w_ref = ref_layer.cpu().detach().numpy()
            ref_sq_norm += np.linalg.norm(w_ref) ** 2
            error[step] += np.linalg.norm(w_ref - w_cmp*rescale) ** 2
        # Turn the accumulated squared quantities into a ratio of norms.
        error[step] = np.sqrt(error[step]) / np.sqrt(ref_sq_norm)
    return error


def theo_sol(args, data):
    """Evaluate the closed-form loss curve plotted as the 'Theory' overlay.

    The curve is built from three quantities derived from ``data``:

    * ``sv``: the norm of ``mean(y * x, axis=0)``;
    * ``w_hat``: the least-squares solution of ``x @ w = y`` obtained from
      the normal equations, normalised to unit norm;
    * ``u(t) = e^{2 sv t} / (u0^-2 + (e^{2 sv t} - 1) / sv)``, which equals
      ``u0**2`` at ``t = 0`` and tends to ``sv`` for large ``t``.

    The returned loss at each time is ``mean((u(t) * x @ w_hat - y)**2) / 2``.

    Args:
        args: Object providing ``init`` (``u0``), ``epoch`` (number of time
            points), ``relu`` and ``lr``.  Time point ``k`` is
            ``k * (relu + 1) * lr / 2``.
        data: Mapping with array entries ``'x'`` and ``'y'``.
            Assumes shapes for which ``y * x`` and ``x.T @ y`` are
            defined (e.g. ``x`` of shape (n, d) and ``y`` of shape (n, 1)).

    Returns:
        A 1-D ``numpy`` array of length ``args.epoch``.

    Raises:
        numpy.linalg.LinAlgError: If ``x.T @ x`` is singular.

    Note:
        ``sv`` must be non-zero; the time points are assumed non-negative.
    """
    x, y = data['x'], data['y']

    yx = np.mean(y * x, axis=0)
    sv = np.linalg.norm(yx)

    # Solve the normal equations directly instead of forming an explicit
    # inverse; same result, better conditioned.
    w_hat = np.linalg.solve(x.T @ x, x.T @ y)
    w_hat = w_hat / np.linalg.norm(w_hat)

    u0 = args.init
    t = np.arange(args.epoch) * (args.relu + 1) * args.lr / 2

    # Same u(t) as in the docstring with numerator and denominator multiplied
    # by e^{-2 sv t}.  The growing exponential overflows to inf for long runs
    # (inf / inf -> nan); the decaying one stays within (0, 1].
    decay = np.exp(-2 * sv * t)
    one_minus_decay = -np.expm1(-2 * sv * t)
    u_t = 1.0 / (decay * u0**(-2) + one_minus_decay / sv)

    # x @ w_hat does not depend on t, so compute it once.
    direction = x @ w_hat
    return np.array([np.mean((u * direction - y) ** 2) / 2 for u in u_t])