#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from glob import glob
import pickle
import numpy as np
import re
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib as mpl

import sys
# Command-line interface:
#   argv[1] -- dataset name (used below as `dataset`).
#   argv[2] -- literal 'True' or 'False': whether to do matrix completion.
if len(sys.argv) < 3:
    # Without this guard a missing argument surfaces as a bare IndexError.
    raise ValueError(
        "Usage: <script> DATASET True|False  "
        "(second argument gives whether to do matrix completion)."
    )
dataset = sys.argv[1]
if sys.argv[2] == 'True':
    completion = True
elif sys.argv[2] == 'False':
    completion = False
else:
    raise ValueError(
        "Second argument should be True or False and gives whether to do "
        f"matrix completion; got {sys.argv[2]!r}."
    )

# Collect this dataset's pickled simulation outputs. The completion flag is
# interpolated as the literal text 'True'/'False', so it must appear in the
# filename. glob() returns matches in arbitrary order; sort them so the
# replicate order is reproducible from run to run.
pattern = f"sim_out/{dataset}/*{completion}*.pkl"
files = sorted(glob(pattern))
fig_prefix = f"{dataset}_{completion}_"

alphas = np.linspace(0,1,num=50)  # NOTE(review): not referenced in the code shown here

print(f"Found {len(files)} files.")
if not files:
    # Otherwise the first failure would be an opaque NameError/IndexError below.
    raise FileNotFoundError(f"No result files matched {pattern!r}.")

# NOTE: pickle.load can execute arbitrary code; only point this at trusted,
# locally generated simulation output.
rets = []
for file in files:
    with open(file, 'rb') as f:
        ret = pickle.load(f)
    rets.append(ret)

# Per-run metadata is read from the first run, like 'comps'. The original read
# 'taus' from whichever file happened to load last. Presumably identical
# across runs -- TODO confirm.
taus = rets[0]['taus']

competitors = rets[0]['comps']
C = len(competitors)

reps = len(rets)
ncomps = len(competitors)
# Results table: one row per replicate (pickle file), one column per competitor.
res = pd.DataFrame(np.zeros([reps,ncomps]))

# Point-estimate metrics to plot, read from each run under these keys:
# 'dist2true' always, plus 'pred' when doing matrix completion.
targs = ['dist2true']
if completion:
    targs.append('pred')

# Loop-invariant setup, hoisted out of the per-metric loop.
res.columns = rets[0]['comps']
# Y-axis label per metric key; unknown keys are labelled with the key itself.
_YLABELS = {
    'dist2true': "SSE",
    'nll': "NLL",
    'pred': "Missing SSE",
    'ncov95': 'Noncoverage',
}
# Level of the "Naive Estimate" reference line (M * N * noisesd**2). These
# values are hard-coded, not read from the results. Presumably they describe an
# 80x80 matrix with noise sd 0.1.
# NOTE(review): confirm they match the simulation settings for `dataset`.
M = N = 80
noisesd = 0.1

for targ in targs:
    # Fill the table: row `rep` holds replicate rep's per-competitor values.
    for rep in range(reps):
        res.iloc[rep,:] = rets[rep][targ]

    # Group competitors by column name. The 'nada' columns are plotted against
    # `taus`, so they are assumed to be ordered like `taus`. 'ada' is the
    # single learned-lambda column.
    res_nada = res[[v for v in res.columns if 'nada' in v]]
    res_cl = res[[v for v in res.columns if 'cl' in v]]  # not plotted at present
    res_ada = res['ada']

    fig = plt.figure(figsize=[3.5,2.5])
    # Mean over replicates (axis 0) for each untuned tau value.
    plt.plot(taus, np.mean(res_nada,axis=0), label = r"Untuned $\lambda$")
    plt.hlines(np.mean(res_ada), taus[0], taus[-1], color = 'green', label = r"Learned $\lambda$", linestyle='--')
    if not completion and targ == 'dist2true':
        plt.hlines(M*N*np.square(noisesd), taus[0], taus[-1], color = 'orange', label = "Naive Estimate", linestyle='--')
    plt.legend()
    plt.xscale('log')
    if completion:
        plt.yscale('log')
    plt.xlabel(r"$\lambda$")
    ylab = _YLABELS.get(targ, targ)
    plt.ylabel(ylab)
    plt.title(dataset)
    plt.tight_layout()
    # Save and close this specific figure, not whatever figure is "current".
    fig.savefig(fig_prefix+f"{targ}_vs_tau.pdf")
    plt.close(fig)

