from forecasters import run_on_dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import mstats
import copy
from forecasters import *

# Which benchmark dataset to run; must be a key of _DATASETS below.
dataset_name = 'vehicle'

# dataset name -> (svmlight-format file path, number of classes K)
_DATASETS = {
    'vehicle': ('vehicle.scale', 4),
    'shuttle': ('shuttle.scale', 7),
    'segment': ('segment.scale', 7),
}

# Fail fast with a clear message instead of a later NameError on `dataset`/`K`.
if dataset_name not in _DATASETS:
    raise ValueError(
        "unknown dataset_name %r; expected one of %s"
        % (dataset_name, sorted(_DATASETS))
    )

dataset_path, K = _DATASETS[dataset_name]
# NOTE(review): load_svmlight_file is not imported explicitly in this file;
# presumably it arrives via `from forecasters import *` — confirm.
dataset = load_svmlight_file(dataset_path)

X, Y = dataset
# Densify the feature matrix and shift integer labels down by one
# (assumes the files use 1-based class labels 1..K — TODO confirm).
X, Y = X.todense(), Y.astype(int) - 1
# Number of feature columns, passed to every forecaster constructor.
d = X.shape[1]

# Per-dataset forecaster configuration: (label, class, constructor kwargs),
# listed in the order the forecasters are created and later plotted.
_FORECASTER_CONFIG = {
    'vehicle': (
        ('OGD', OGD, {'lbd': 10.}),
        ('ONS', ONS, {'lbd': .3, 'beta': .3}),
        ('GAF', GAF, {'lbd': .1, 'beta': .3}),
    ),
    'shuttle': (
        ('OGD', OGD, {'lbd': 1.}),
        ('ONS', ONS, {'lbd': .03, 'beta': .1}),
        ('GAF', GAF, {'lbd': .03, 'beta': .3}),
    ),
    'segment': (
        ('OGD', OGD, {'lbd': 10.}),
        ('ONS', ONS, {'lbd': .3, 'beta': .1}),
        ('GAF', GAF, {'lbd': .01, 'beta': 3.}),
    ),
}

# label -> freshly constructed forecaster; empty for an unlisted dataset name.
forecasters = {
    label: cls(d, K, **kwargs)
    for label, cls, kwargs in _FORECASTER_CONFIG.get(dataset_name, ())
}

# Number of independent random orderings of the data stream per forecaster.
nb_run = 20

# Small (4x3 inch, 100 dpi) figure with a white background.
fig = plt.figure(figsize=(4, 3), dpi=100, facecolor='w', edgecolor='k')

# Loop invariants shared by every forecaster.
n = len(Y)
t_range = np.arange(n) + 1  # time index t = 1..n
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

for i, (k, f) in enumerate(forecasters.items()):
    # One row per time step, one column per run.
    L = np.zeros((n, nb_run))
    for j in range(nb_run):
        print(k, j)
        # Shuffle the stream order and run on a deep copy so that state
        # never carries over between runs.
        perm = np.random.permutation(n)
        Xr, Yr = X[perm, :], Y[perm]
        ff = copy.deepcopy(f)
        L[:, j] = run_on_dataset(ff, Xr, Yr)
    # Divide row t by t; if run_on_dataset returns cumulative losses
    # (TODO confirm), this yields the running average loss.
    L /= t_range[:, None]
    # mquantiles defaults to prob=(0.25, 0.5, 0.75): columns are the lower
    # quartile, median and upper quartile across runs.
    quantiles = mstats.mquantiles(L, axis=1)
    # Wrap around the color cycle instead of raising IndexError when there
    # are more forecasters than colors.
    color = colors[i % len(colors)]
    plt.loglog(t_range, quantiles[:, 1], color=color, label=k)
    plt.fill_between(t_range, quantiles[:, 0], quantiles[:, 2],
                     color=color, alpha=.2)

plt.legend()
# Name the output after the dataset (e.g. "vehicle_plot_error_bar.pdf") so any
# configured dataset is saved, not just a hard-coded few; bbox_inches="tight"
# trims surrounding whitespace.
plt.savefig("{}_plot_error_bar.pdf".format(dataset_name), bbox_inches="tight")

plt.show()