import os
import sys
import time
import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
from matplotlib.legend import Legend
import seaborn as sns

from cv_num_exper import *

# Reference point for the total run time reported at the end of the script.
start_time = time.time()

# Positional command-line arguments:
#   1. algo        - base algorithm name; 'Lasso' selects different Monte-Carlo
#                    settings below, 'Ridge' and 'ST' are special-cased in rate_plot
#   2. path_to_res - root folder of the simulation results
#   3. dense_beta  - integer flag, consulted by rate_plot together with algo == 'ST'
#   4. fig1        - integer flag; non-zero draws only the coverage figure
algo = str(sys.argv[1])
path_to_res = str(sys.argv[2])
dense_beta = int(sys.argv[3])
fig1 = int(sys.argv[4])


# All figures are written to <path_to_res>/figures.
# makedirs(exist_ok=True) replaces the isdir() + mkdir() pair: it has no
# check-then-create race and also creates missing parent folders.
fig_folder = os.path.join(path_to_res, 'figures')
os.makedirs(fig_folder, exist_ok=True)

# The two variants of the algorithm that are analysed, alone and against each other.
algos = [f'{algo}_opti', f'{algo}x1']

# Monte-Carlo settings. kde_plot uses M_MSS_LS * NB_REPS and M_SIGMA2 * NB_REPS
# as the number of draws behind the averaged quantities; NB_SIMS and OVERRIDE
# are not read anywhere in this script.
# NOTE(review): presumably these must match the values used when the results
# were generated - confirm against the simulation code.
if algo == 'Lasso' and not fig1:
    NB_SIMS = 10
    NB_REPS = 5000
    M_MSS_LS = 200
    M_SIGMA2 = 200
    OVERRIDE = None
else:
    NB_SIMS = 100
    NB_REPS = 500
    M_MSS_LS = 10000
    M_SIGMA2 = 10000
    OVERRIDE = 'sqrt'

# Result sub-folders to load: each single algorithm first, then every ordered
# pair 'a_b' of distinct algorithms ...
methods = list(algos)
methods += [f'{algo1}_{algo2}' for algo1 in algos for algo2 in algos if algo1 != algo2]

# ... of which only the first three are kept, i.e. the two single algorithms
# and the comparison '<algo>_opti_<algo>x1' (the reversed pair is dropped).
methods = methods[:3]

n_samples_array = [100, 1000, 10000, 100000]
# n_samples_array = [100, 1000, 10000, 30000, 100000]

# Block width used in the column offsets of the result tables (3 for two algorithms).
q = 2*len(algos)-1

def combine(method, n_samples_array):
    """Load the result tables of one method for every sample size.

    For each ``n`` in ``n_samples_array`` the HDF5 file
    ``<path_to_res>/<method>/n_<n>/all_reps.h5`` (key ``'all_reps'``) is read
    and only its last ``19 + 8*q + 1`` columns are kept. Loading times are
    printed per file and in total.

    Parameters
    ----------
    method : str
        Name of the result sub-folder below ``path_to_res``.
    n_samples_array : iterable of int
        Sample sizes to load, in the order in which they should be returned.

    Returns
    -------
    list_df : list of pandas.DataFrame
        One trimmed table per sample size.
    nb_sims : list of int
        Number of rows of each table.

    Raises
    ------
    FileNotFoundError
        If the result file of one of the sample sizes does not exist.
    """
    start = time.time()
    n_kept_columns = 19 + 8*q + 1
    list_df = []
    nb_sims = []
    for n in n_samples_array:
        current = time.time()
        file = os.path.join(path_to_res, method, f'n_{n}', 'all_reps.h5')
        # Explicit check instead of `assert`: it survives `python -O` and
        # tells the user which file is missing.
        if not os.path.isfile(file):
            raise FileNotFoundError(f'Result file not found: {file}')
        df = pd.read_hdf(file, 'all_reps').iloc[:, -n_kept_columns:]
        list_df.append(df)
        nb_sims.append(df.shape[0])
        # The trailing True keeps the log line in its original format; the
        # file is known to exist at this point.
        print(n, time.time()-current, True)
    print(time.time()-start)
    return list_df, nb_sims

# Result tables and row counts of every method, keyed by method name; each
# value is a list aligned with n_samples_array.
list_df_dict = {}
sims_dict = {}

for method_name in methods:
    print(method_name)
    list_df_dict[method_name], sims_dict[method_name] = combine(method_name, n_samples_array)


def figure1_plot():
    """Plot actual against target coverage of two-sided confidence intervals.

    Only the largest sample size is used. For the single algorithm
    (``methods[0]``) and the algorithm comparison (``methods[2]``), and for a
    grid of levels ``alpha``, an interval is built with ``CI_2sided`` around
    column 0 of the result table with half-width scale ``sigma_n / sqrt(n)``.
    The fraction of rows whose column ``19 + 6*q`` falls inside the interval
    is plotted against ``1 - alpha``.

    ``sigma_n**2`` is the mean of column ``19 + 4*q`` for the single algorithm
    and of column ``19 + 4*q + 2`` for the comparison.
    NOTE(review): the meaning of these columns is fixed by the code that wrote
    the result files - confirm the indices there.

    The figure is saved as ``<path_to_res>/figures/coverage_plot_CLT.pdf``.
    """
    key_sing = methods[0]
    key_comp = methods[2]

    i = -1  # looking at the largest sample size

    alphas = np.concatenate([[0, 1e-5, 1e-4, 1e-3, 0.005], np.arange(0.01, 1.01, 0.01)])
    root_n = np.sqrt(n_samples_array[i])

    def _coverage_curve(key, sigma2_column):
        """Return the empirical coverage for every level in ``alphas``."""
        df = list_df_dict[key][i]
        # Everything that does not depend on alpha is extracted once.
        centres = df.iloc[:, 0].to_numpy()
        targets = df.iloc[:, 19 + 6*q].to_numpy()
        sigma_n = np.sqrt(df.iloc[:, sigma2_column].mean())
        coverages = []
        for alpha in alphas:
            l_bound, u_bound = CI_2sided(centres, sigma_n/root_n, alpha=alpha)
            is_contained = (l_bound <= targets) & (targets <= u_bound)
            coverages.append(is_contained.mean())
        return coverages

    target_coverages = 1 - alphas
    actual_coverages_sing = _coverage_curve(key_sing, 19 + 4*q)
    actual_coverages_comp = _coverage_curve(key_comp, 19 + 4*q + 2)

    fig = plt.figure(figsize=(6, 6))
    plt.plot(target_coverages, actual_coverages_sing, label='Relatively stable model')
    plt.plot(target_coverages, actual_coverages_comp, label='Relatively unstable model comparison', linestyle='dashed')
    plt.xlabel('Target Coverage', fontsize=14)
    plt.ylabel('Actual Coverage', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(loc='upper left', fontsize=12)
    # plt.show()

    fig_name = 'coverage_plot_CLT.pdf'

    fig.savefig(os.path.join(path_to_res, 'figures', fig_name), bbox_inches='tight', pad_inches=0.001)

    plt.close()


def kde_plot(comp):
    """Plot kernel density estimates of the rescaled errors per sample size.

    For every sample size two densities are drawn on top of a shaded standard
    normal density: the error (column ``3`` minus column ``19 + 6*q``) scaled
    by ``sqrt(n) / sigma_n`` (dashed) and scaled by ``sqrt(n)`` over the
    per-row value in column ``5`` (solid). The figure is saved as
    ``CLT_comp.pdf`` or ``CLT_sing.pdf`` under ``<path_to_res>/figures``.

    NOTE(review): the column indices are taken from the layout of the result
    files, which is defined outside this script - confirm them there.

    Parameters
    ----------
    comp : bool
        If true, use the results of the algorithm comparison
        (``'_'.join(algos)``); otherwise those of the first algorithm.

    Returns
    -------
    tuple of six lists, each with one entry per sample size
        ``sigma2_list``, ``sigma2_sd_list`` : mean of the sigma^2 column and
        its Monte-Carlo standard error;
        ``gamma_list``, ``gamma_sd_list`` : mean of the LS column and its
        Monte-Carlo standard error;
        ``ratio_list``, ``ratio_sd_list`` : ``m * LS / sigma^2`` with
        ``m = int(0.9 * n)`` and its delta-method standard error.
    """
    sigma2_list = []
    sigma2_sd_list = []
    gamma_list = []
    gamma_sd_list = []
    ratio_list = []
    ratio_sd_list = []
    fig, ax = plt.subplots(figsize=(8, 5))

    if comp:
        pos = 0
        key = '_'.join(algos)
    else:
        pos = -1
        key = algos[0]
    # Column shift inside each block of the table: 2 for the comparison, 0 otherwise.
    shift = 2*(pos+1)

    cmap = cm.turbo
    colors = cmap([0.9, 0.75, 0.25, 0.1])
    # colors = cmap([0.9, 0.75, 0.5, 0.25, 0.1]) # when using sample size 30000 as well
    x = np.linspace(-10, 10, 1000)
    lw = 1.9
    grey_shade = 0.4
    # Raw strings keep the LaTeX backslashes without relying on invalid escape sequences.
    plt.fill_between(x, stats.norm.pdf(x, 0, 1), color='grey', alpha=grey_shade, label=r"$\mathcal{N}(0, 1)$")

    j = 1  # columns 3*j and 3*j+2 hold the estimate and its per-row scale
    for i, n in enumerate(n_samples_array):
        df = list_df_dict[key][i]

        print(f'sample size: {n}')

        LS = df.iloc[:, 19+2*q+shift].mean()
        LS_2ndmoment = df.iloc[:, 19+3*q+shift].mean()
        LS_sd = np.sqrt((LS_2ndmoment - LS**2) / (M_MSS_LS*NB_REPS))
        print("LS:", LS, LS_sd, 100*LS_sd/LS)
        sigma_n_squared = df.iloc[:, 19+4*q+shift].mean()
        sigma_n_squared_2ndmoment = df.iloc[:, 19+5*q+shift].mean()
        sigma_n_squared_sd = np.sqrt((sigma_n_squared_2ndmoment - sigma_n_squared**2) / (M_SIGMA2*NB_REPS))
        print("sigma_n^2:", sigma_n_squared, sigma_n_squared_sd, 100*sigma_n_squared_sd/sigma_n_squared)
        sigma_n = np.sqrt(sigma_n_squared)

        m = int(n * (1 - 1/10))

        stab_ratio_LS_paper = (m * LS) / sigma_n_squared
        # First-order delta method for r = m * LS / sigma^2, treating the two
        # Monte-Carlo means as independent:
        #   sd(r) = m * sqrt(sd(LS)^2 / sigma^4 + LS^2 * sd(sigma^2)^2 / sigma^8)
        stab_ratio_LS_paper_sd = m * np.sqrt(
            LS_sd**2 / sigma_n_squared**2
            + (LS**2 * sigma_n_squared_sd**2) / sigma_n_squared**4
        )

        sigma2_list.append(sigma_n_squared)
        sigma2_sd_list.append(sigma_n_squared_sd)
        gamma_list.append(LS)
        gamma_sd_list.append(LS_sd)
        ratio_list.append(stab_ratio_LS_paper)
        ratio_sd_list.append(stab_ratio_LS_paper_sd)

        error = df.iloc[:, 3*j] - df.iloc[:, 19+6*q]
        # Dashed curve: error scaled with the averaged sigma_n.
        temp = error * (np.sqrt(n)/sigma_n)
        # Solid curve: error scaled with the per-row value of column 3*j+2.
        temp1 = error * (np.sqrt(n)/df.iloc[:, 3*j+2])

        sns.kdeplot(temp, cut=1, color=colors[i], linestyle="--", linewidth=lw)
        sns.kdeplot(temp1, cut=1, color=colors[i], label=rf"$n$: $9 \cdot 10^{int(np.log10(n))-1}$, $r$: {stab_ratio_LS_paper:.1e}", linewidth=lw)
    plt.minorticks_off()
    first_legend = ax.legend(fontsize=12.75, handlelength=3)
    plt.xlim(-11, 11)

    # A second call to ax.legend() replaces the first legend, so the first one
    # is re-attached as a plain artist.
    ax.add_artist(first_legend)

    solid_line = Line2D([0], [0], color='black', linestyle='-')
    dash_line = Line2D([0], [0], color='black', linestyle='--')

    ax.legend(fontsize=12.75, handlelength=3, handles=[solid_line, dash_line], labels=[r'$\frac{\sqrt{\frac{n k}{k-1}}}{\hat\sigma_n(h_n)} (\hat R_n - R_n)$', r'$\frac{\sqrt{\frac{n k}{k-1}}}{\sigma(h_n)} (\hat R_n - R_n)$'], loc='upper left')
    plt.ylabel("Density", fontsize=16)
    plt.yticks(fontsize=14)
    plt.xticks(fontsize=14)
    #plt.show()

    if comp:
        fig_name = 'CLT_comp.pdf'
    else:
        fig_name = 'CLT_sing.pdf'
    fig.savefig(os.path.join(path_to_res, 'figures', fig_name), bbox_inches='tight', pad_inches=0.001)

    plt.close()

    return sigma2_list, sigma2_sd_list, gamma_list, gamma_sd_list, ratio_list, ratio_sd_list


def rate_plot(sigma2_list, sigma2_sd_list, gamma_list, gamma_sd_list, ratio_list, ratio_sd_list, comp):
    """Plot how sigma^2, gamma and their ratio r scale with the sample size.

    The first ``start`` sample sizes are dropped (2 when ``algo == 'ST'``,
    ``dense_beta`` is set and ``comp`` is true; 1 otherwise). Every remaining
    series, and the sample size itself, is divided by its first retained
    value, drawn with error bars of two standard errors on log-log axes and
    accompanied by a reference rate of the same colour. The figure is saved as
    ``rates_comp.pdf`` or ``rates_sing.pdf`` under ``<path_to_res>/figures``.

    Parameters
    ----------
    sigma2_list, sigma2_sd_list : sequence of float
        Values of sigma^2 and their standard errors, one per entry of
        ``n_samples_array``.
    gamma_list, gamma_sd_list : sequence of float
        Values of gamma and their standard errors.
    ratio_list, ratio_sd_list : sequence of float
        Values of the ratio r and their standard errors.
    comp : bool
        Selects the reference rates and the output file name (algorithm
        comparison if true, single algorithm otherwise).
    """
    fig = plt.figure(figsize=(8, 5))
    if algo == 'ST' and dense_beta and comp:
        start = 2
    else:
        start = 1
    n_samples_array2 = n_samples_array[start:]
    # Sample sizes relative to the first retained one; shared by all curves.
    x = np.array(n_samples_array2)/n_samples_array2[0]

    def _normalised_errorbar(values, sds, label, fmt, color):
        """Draw ``values[start:]`` relative to its first entry with +/- 2 sd bars."""
        kept = np.array(values[start:])
        kept_sd = np.array(sds[start:])
        return plt.errorbar(x, kept/kept[0], yerr=2*kept_sd/kept[0],
                            label=label, fmt=fmt, markersize=5, color=color)

    # Raw strings keep the LaTeX backslashes without relying on invalid escape sequences.
    line1 = _normalised_errorbar(sigma2_list, sigma2_sd_list, r'$\sigma^2(h_n)$', 'o', 'tab:blue')
    line3 = _normalised_errorbar(gamma_list, gamma_sd_list, r'$\gamma(h_n)$', 's', 'tab:orange')
    line5 = _normalised_errorbar(ratio_list, ratio_sd_list, r'$r(h_n)$', '^', 'tab:green')

    # Reference rates, in the order sigma^2, gamma, r.
    if not comp:
        references = [(np.ones_like(x), '$1$'),
                      (1/x**2, '$1/n^2$'),
                      (1/x, '$1/n$')]
    elif algo == 'Ridge' or (algo == 'ST' and dense_beta):
        references = [(1/x**2, '$1/n^2$'),
                      (1/x**4, '$1/n^4$'),
                      (1/x, '$1/n$')]
    else:
        references = [(1/x**2, '$1/n^2$'),
                      (1/x**2.5, r'$1/(n^2 \sqrt{n})$'),
                      (np.sqrt(x), r'$\sqrt{n}$')]
    styles = [('solid', 'tab:blue'), ('dashed', 'tab:orange'), ('dashdot', 'tab:green')]
    line2, line4, line6 = [
        plt.plot(x, y, label=label, linestyle=linestyle, color=color)[0]
        for (y, label), (linestyle, color) in zip(references, styles)
    ]

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel(f'$n$ / {int(0.9*n_samples_array2[0])} (log scale)', fontsize=16)
    plt.minorticks_off()
    plt.yticks(fontsize=14)
    plt.xticks(fontsize=14)
    plt.legend(handles=[line1, line2, line3, line4, line5, line6], fontsize=12.75, handlelength=2, ncol=3)
    # plt.show()

    if comp:
        fig_name = 'rates_comp.pdf'
    else:
        fig_name = 'rates_sing.pdf'
    fig.savefig(os.path.join(path_to_res, 'figures', fig_name), bbox_inches='tight', pad_inches=0.001)

    plt.close()

# Draw either the coverage figure alone, or the density and rate figures:
# first for the single algorithm, then for the algorithm comparison.
if fig1:
    figure1_plot()
else:
    for comp in (False, True):
        rate_plot(*kde_plot(comp=comp), comp=comp)


print(' ')
print('Total run time:', time.time()-start_time)
