import numpy as np
import math
import matplotlib.pyplot as plt
import os

import argparse

# Seed NumPy's global RNG once at module load so that every np.random.randn
# draw made by the functions below is reproducible from run to run.
np.random.seed(0)


def inv_linear_covariance(dim):
    """Return the 1-D float array [1/1, 1/2, ..., 1/dim]."""
    ranks = np.arange(1, dim + 1)
    return 1.0 / ranks


def inv_linear_log_covariance(dim, exp=1):
    """Return the 1-D array whose k-th entry (k = 1..dim) is 1 / (k * log(k + 1) ** exp)."""
    ranks = np.arange(1, dim + 1)
    log_factor = np.log(ranks + 1) ** exp
    return 1.0 / (ranks * log_factor)


def inv_poly_covariance(dim, exp=2):
    """Return the 1-D float array whose k-th entry (k = 1..dim) is 1 / k ** exp."""
    ranks = np.arange(1, dim + 1, dtype=float)
    return 1.0 / np.power(ranks, exp)


def dataGenerator(samplesize, dim, cov, theta_star, sigma_x=1, sigma_y=1):
    """Draw a synthetic noisy linear-regression data set.

    Each feature row is ``x = z * cov * sigma_x`` with ``z`` a standard normal
    vector of length ``dim``, and each label is
    ``y = <x, theta_star> + sigma_y * eps`` with ``eps`` standard normal.

    Args:
        samplesize: number of rows to generate.
        dim: number of features per row.
        cov: per-coordinate scale multiplied element-wise into the features.
        theta_star: ground-truth parameter vector of length ``dim``.
        sigma_x: extra scalar scale applied to the features.
        sigma_y: scale of the additive label noise.

    Returns:
        ``(X, Y)`` where ``X`` has shape ``(samplesize, dim)`` and ``Y`` has
        shape ``(samplesize,)``. The shapes hold for ``samplesize`` of 0 or 1
        as well, so ``Y`` never degenerates to a 0-d array.
    """
    X = np.empty((samplesize, dim))
    Y = np.empty(samplesize)
    for row in range(samplesize):
        # NOTE(review): the features are scaled by `cov` itself, not by
        # sqrt(cov), while the risk formulas in this file weight by `cov` --
        # confirm this is the intended data model.
        x = np.random.randn(dim) * cov * sigma_x
        # Keep the randn(dim) / randn(1) call order so seeded runs still draw
        # the same random stream as before.
        noise = np.random.randn(1) * sigma_y
        X[row] = x
        Y[row] = np.dot(x, theta_star) + noise[0]
    return X, Y


def train_linear(X, Y, theta0, lr, time, cov, theta_star, sample_size):
    """Run full-batch gradient descent and record both loss curves.

    The update is ``theta <- theta + lr * X.T @ (Y - X @ theta)`` (the
    gradient is summed over samples, not averaged).

    Args:
        X: feature matrix.
        Y: target vector.
        theta0: initial parameter vector (not modified).
        lr: step size.
        time: number of gradient steps.
        cov: per-coordinate weights used in the excess-risk formula.
        theta_star: ground-truth parameter vector.
        sample_size: divisor used to normalise the training loss.

    Returns:
        ``(train_res, gen_res)``: two lists of length ``time + 1``. Entry ``t``
        holds the value after ``t`` steps of
        ``0.5 / sample_size * ||Y - X theta||^2`` and
        ``0.5 * sum(cov * (theta - theta_star) ** 2)`` respectively.
    """
    theta = theta0
    delta = theta - theta_star
    residual = Y - np.dot(X, theta)
    gen_res = [0.5 * (cov * delta * delta).sum()]
    train_res = [0.5 / sample_size * (residual * residual).sum()]
    for t in range(time):
        if (t + 1) % 10000 == 0:
            print(t + 1)
        # `residual` was computed for the current theta when the previous
        # loss was logged, so it is reused here instead of recomputing X @ theta.
        theta = theta + lr * np.dot(X.T, residual)
        delta = theta - theta_star
        gen_res.append(0.5 * (cov * delta * delta).sum())
        residual = Y - np.dot(X, theta)
        train_res.append(0.5 / sample_size * (residual * residual).sum())
    return train_res, gen_res


def train(X, Y, theta0, lr, time):
    """Run `time` full-batch gradient steps and return every iterate.

    The returned list has ``time + 1`` entries: ``theta0`` itself followed by
    the parameter vector after each step of
    ``theta <- theta + lr * X.T @ (Y - X @ theta)``.
    """
    history = [theta0]
    current = theta0
    for step in range(1, time + 1):
        # Progress indicator for long runs.
        if step % 100000 == 0:
            print(step)
        residual = Y - np.dot(X, current)
        current = current + lr * np.dot(X.T, residual)
        history.append(current)
    return history


def get_final_theta(samplesize, dim, X, Y, ridge=0.0001):
    """Return the closed-form least-squares solution for (X, Y).

    Args:
        samplesize: number of rows of ``X``.
        dim: number of columns of ``X``.
        X: feature matrix.
        Y: target vector; a 0-d target is treated as a single sample.
        ridge: diagonal regularisation added to the Gram matrix ``X X^T`` in
            the overparametrized branch.

    Returns:
        When ``samplesize <= dim``: ``X^T (X X^T + ridge * I)^{-1} Y``.
        Otherwise: ``(X^T X)^{-1} X^T Y``.

    Raises:
        numpy.linalg.LinAlgError: if the linear system is singular.
    """
    # A 0-d target would make the result 2-D; promote it to shape (1,).
    Y = np.atleast_1d(Y)
    # overparametrized
    if samplesize <= dim:
        gram = np.dot(X, X.T) + ridge * np.identity(samplesize)
        # Solve the system directly instead of forming an explicit inverse.
        return np.dot(X.T, np.linalg.solve(gram, Y))
    # underparametrized
    return np.linalg.solve(np.dot(X.T, X), np.dot(X.T, Y))


def gen_error(theta_list, cov, theta_star):
    """Return an array with the excess risk of every parameter vector.

    For each ``theta`` in ``theta_list`` the value is
    ``0.5 * sum(cov * (theta - theta_star) ** 2)``.
    """

    def _excess_risk(theta):
        gap = theta - theta_star
        return 0.5 * (cov * gap * gap).sum()

    return np.array([_excess_risk(theta) for theta in theta_list])


def train_error(theta_list, sample_size, X, Y):
    """Return an array with the training loss of every parameter vector.

    For each ``theta`` in ``theta_list`` the value is
    ``0.5 / sample_size * sum((Y - X @ theta) ** 2)``.
    """

    def _loss(theta):
        residual = Y - np.dot(X, theta)
        return 0.5 / sample_size * (residual * residual).sum()

    return np.array([_loss(theta) for theta in theta_list])


def plot_simple(time, train_res, gen_res, path, title):
    """Plot excess risk and training loss against epochs and save the figure.

    The two curves share a log-scaled x axis (epochs ``0..time``) and use
    separate log-scaled y axes: excess risk on the left, training loss on the
    right. The image is written to ``path + title`` and the figure is closed
    afterwards so that repeated calls do not accumulate open figures.

    Args:
        time: number of training steps; both curves must hold ``time + 1`` values.
        train_res: training-loss values, one per epoch.
        gen_res: excess-risk values, one per epoch.
        path: prefix of the output file name (plain string concatenation).
        title: suffix of the output file name, e.g. ``'/fig.png'``.
    """
    fig, ax = plt.subplots()
    ax2 = ax.twinx()
    # An integer range array avoids building a huge Python list for long runs.
    x = np.arange(time + 1)

    ax.plot(x, gen_res, label="Excess Risk", color='#1f77b4', linewidth=2.5)
    ax2.plot(x, train_res, label="Training Loss", color='#ff7f0e', linewidth=2.5)

    ax.set_xlabel('Epochs', fontsize=25)
    ax.set_ylabel('Excess Risk', fontsize=25)
    ax2.set_ylabel('Training Loss', fontsize=25)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax2.set_yscale('log')

    # Fixed axis ranges keep figures from different runs directly comparable.
    ax.set_ylim((10 ** -2, 10 ** 4))
    ax2.set_ylim((10 ** -3, 10 ** 1))

    # One legend for both axes, anchored to the top-right corner of the plot area.
    fig.legend(loc=1, bbox_to_anchor=(1, 1), bbox_transform=ax.transAxes, fontsize=20)

    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_linewidth(1.5)

    fig.set_size_inches(8, 6)
    # Save through the figure object rather than pyplot's implicit current
    # figure, then release it.
    fig.savefig(path + title, dpi=300)
    plt.close(fig)


# plot the training curve
def plot_res(samples, dim, cov_type, lr, total_time=10000000):
    """Train once on a fresh data set, then save and plot the loss curves.

    Draws ``theta_star`` and a data set, runs ``train_linear`` for
    ``total_time`` steps from a zero initialisation, writes ``gen.npy`` and
    ``train.npy`` into the output directory chosen by ``cov_type``, and saves
    the figure as ``fig.png`` there.

    Args:
        samples: number of training samples.
        dim: feature dimension.
        cov_type: integer 0-5 selecting the covariance spectrum
            (0: inverse linear, 1/2: inverse polynomial with exponent 2/3,
            3/4/5: inverse linear-log with exponent 1/2/3).
        lr: gradient-descent step size.
        total_time: number of gradient steps.

    Raises:
        ValueError: if ``cov_type`` is not one of 0-5.
    """
    # cov_type -> (builder, builder kwargs, label printed to stdout, output directory)
    cov_configs = {
        0: (inv_linear_covariance, {}, 'inv linear', './linear/inv_linear'),
        1: (inv_poly_covariance, {'exp': 2}, 'inv poly 2', './linear/inv_poly2'),
        2: (inv_poly_covariance, {'exp': 3}, 'inv poly 3', './linear/inv_poly3'),
        3: (inv_linear_log_covariance, {'exp': 1}, 'inv log 1', './linear/inv_log1'),
        4: (inv_linear_log_covariance, {'exp': 2}, 'inv log 2', './linear/inv_log2'),
        5: (inv_linear_log_covariance, {'exp': 3}, 'inv log 3', './linear/inv_log3'),
    }
    if cov_type not in cov_configs:
        # Fail early with a clear message instead of hitting an unbound local later.
        raise ValueError(
            'cov_type must be one of %s, got %r' % (sorted(cov_configs), cov_type))

    sigma_x = 1
    sigma_y = 1

    # theta_star is drawn before anything else so the seeded random stream is
    # consumed in the same order as before.
    theta_star = np.random.randn(dim)

    build_cov, cov_kwargs, label, out_dir = cov_configs[cov_type]
    cov = build_cov(dim, **cov_kwargs)
    print(label)

    os.makedirs(out_dir, exist_ok=True)

    X, Y = dataGenerator(samples, dim, cov, theta_star, sigma_x, sigma_y)
    theta0 = np.zeros(dim)

    train_result, gen_result = train_linear(
        X, Y, theta0, lr, total_time, cov, theta_star, samples)

    np.save(out_dir + '/gen.npy', gen_result)
    np.save(out_dir + '/train.npy', train_result)
    plot_simple(total_time, train_result, gen_result, out_dir, '/fig.png')


# calculate the expected early-stopping excess risk and min-norm excess risk
def mean_risk(samples, dim, cov_type, total_time=10000, lr=0.001, repeat=1000):
    """Estimate the early-stopping and closed-form excess risks over many data sets.

    For ``repeat`` independently drawn data sets (with one shared
    ``theta_star``) this runs gradient descent for ``total_time`` steps and
    records the smallest excess risk along the trajectory. It also records
    the excess risk of the closed-form solution from ``get_final_theta``.
    Runs whose final excess risk is NaN are discarded. The mean and variance
    of both quantities are printed.

    Args:
        samples: number of training samples per data set.
        dim: feature dimension.
        cov_type: integer 0-5 selecting the covariance spectrum
            (0: inverse linear, 1/2: inverse polynomial with exponent 2/3,
            3/4/5: inverse linear-log with exponent 1/2/3).
        total_time: number of gradient steps per run.
        lr: gradient-descent step size.
        repeat: number of data sets to average over.

    Raises:
        ValueError: if ``cov_type`` is not one of 0-5.
    """
    print(samples, dim)

    # cov_type -> (builder, builder kwargs, label printed to stdout)
    cov_configs = {
        0: (inv_linear_covariance, {}, 'inv linear'),
        1: (inv_poly_covariance, {'exp': 2}, 'inv poly 2'),
        2: (inv_poly_covariance, {'exp': 3}, 'inv poly 3'),
        3: (inv_linear_log_covariance, {'exp': 1}, 'inv log 1'),
        4: (inv_linear_log_covariance, {'exp': 2}, 'inv log 2'),
        5: (inv_linear_log_covariance, {'exp': 3}, 'inv log 3'),
    }
    if cov_type not in cov_configs:
        # Fail early with a clear message instead of hitting an unbound local later.
        raise ValueError(
            'cov_type must be one of %s, got %r' % (sorted(cov_configs), cov_type))

    sigma_x = 1
    sigma_y = 1

    theta_star = np.random.randn(dim)

    build_cov, cov_kwargs, label = cov_configs[cov_type]
    cov = build_cov(dim, **cov_kwargs)
    print(label)

    gen_min_list = []
    gen_final_list = []
    for _ in range(repeat):
        X, Y = dataGenerator(samples, dim, cov, theta_star, sigma_x, sigma_y)
        theta0 = np.zeros(dim)
        theta_result = train(X, Y, theta0, lr, total_time)
        gen_result = gen_error(theta_result, cov, theta_star)
        # A NaN final risk means the run is discarded; check before paying
        # for the closed-form solve.
        if np.isnan(gen_result[-1]):
            continue
        final_theta = get_final_theta(samples, dim, X, Y)
        gen_min_list.append(np.min(gen_result))
        # gen_error returns a length-1 array here; store the scalar.
        gen_final_list.append(gen_error([final_theta], cov, theta_star)[0])

    print(np.mean(gen_min_list))
    print(np.var(gen_min_list))
    print(np.mean(gen_final_list))
    print(np.var(gen_final_list))
    print()


def load_and_plot(dir):
    """Reload saved loss curves from `dir` and regenerate the figure.

    Reads ``dir + '/gen.npy'`` and ``dir + '/train.npy'``, keeps at most the
    first 10000001 entries of each, and redraws ``dir + '/fig.png'``. The
    epoch count passed to the plot is derived from the loaded data, so
    shorter runs plot correctly too.

    Args:
        dir: directory that holds ``gen.npy`` and ``train.npy``.
    """
    max_points = 10000001
    gen_result = np.load(dir + '/gen.npy')[:max_points]
    train_result = np.load(dir + '/train.npy')[:max_points]
    # Both curves are plotted against the same x axis, so truncate them to a
    # common length.
    n_points = min(len(gen_result), len(train_result))
    plot_simple(n_points - 1, train_result[:n_points], gen_result[:n_points], dir, '/fig.png')


if __name__ == '__main__':
    # Command-line entry point: first print the averaged excess-risk
    # statistics, then run one long training run and save its curves.
    parser = argparse.ArgumentParser(
        description='Gradient descent on synthetic linear regression.')
    parser.add_argument('--samples', type=int, default=100,
                        help='number of training samples')
    parser.add_argument('--dim', type=int, default=1000,
                        help='feature dimension')
    # Restricting the choices makes argparse reject an unknown spectrum up
    # front instead of failing deep inside the experiment code.
    parser.add_argument('--cov', type=int, default=0, choices=range(6),
                        help='covariance spectrum: 0 inverse linear, '
                             '1/2 inverse polynomial (exponent 2/3), '
                             '3/4/5 inverse linear-log (exponent 1/2/3)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='step size for the training-curve run '
                             '(only passed to plot_res)')
    args = parser.parse_args()

    mean_risk(args.samples, args.dim, args.cov)
    plot_res(args.samples, args.dim, args.cov, args.lr)
