import numpy as np
import matplotlib.pyplot as plt

from storm.sketch import *
from storm.hashes.srp_hash import SRPHash
from storm.hashes.asymmetric_simple_hash import AsymmetricSimpleHash
from storm.hashes.asymmetric_series_hash import AsymmetricSeriesHash
from storm.hashes.asymmetric_ball_hash import AsymmetricBallHash
from storm.optimization.iterative_optimizer import IterativeOptimizer
from storm.optimization.utils import *
from storm.datasets import *
from storm.baselines import *

import sys
import random
import pickle
import os
import time

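'''
Sanity check on non-regularized methods: train a linear classifier from
STORM sketches of the data and record its training accuracy as the sketch
size grows.

All streaming sketches are linear sketches, and therefore they cannot do
[unfinished in the original note -- TODO: complete this thought].

The sanity check on non-regularized methods is meant to show that we know
[unfinished in the original note -- TODO: complete this thought].

'''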

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
datafile = 'data/australian.svm'
testfile = 'data/australian.svm'
# datafile = 'data/phishing.svm'
# testfile = 'data/phishing.svm'

# Sketch sizes (`reps`, passed to SRPHash as N and to STORM) to sweep over.
# for australian
# R = [5,10,15,20,30,40,50,60,70,80,90,100,200,500]

# for phishing
# NOTE(review): this is the phishing sweep, but `datafile` above points at
# the australian data -- confirm which sweep is intended for this run.
R = [50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 700, 1000, 2000]

n_trials = 20           # independent trials (hash seeds) per sketch size

p = 4                   # passed to SRPHash; STORM is sized with 2**p
sigma = 0.1             # passed to BallGradApprox
n_iters = 1000          # optimizer iterations
eta = 0.3               # passed to IterativeOptimizer.optimize
beta = 0.5              # passed to IterativeOptimizer.optimize
n_components = 8        # passed to BallGradApprox
use_intercept = True    # NOTE(review): not read anywhere in this script
seed = 42               # trial i uses hash seed `seed + i`

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
x, y, x_test, y_test = load(datafile, testfile)
# NOTE(review): x_test / y_test are loaded but never evaluated below.
C0, C1, x, y = format_class(x, y)
n_features = C0.shape[1]

# Rows index sketch sizes in R, columns index trials.
storm_accuracy = np.zeros((len(R), n_trials))
storm_sizes = np.zeros((len(R), n_trials))

# ---------------------------------------------------------------------------
# Sweep: sketch the data, optimize a linear model from the sketch, score it
# ---------------------------------------------------------------------------
for j, reps in enumerate(R):
    print(f"Sketching with {reps}")
    for i in range(n_trials):
        LSH = SRPHash(N=reps, d=n_features, p=p, seed=seed + i)
        S = STORM(reps, 2**p, LSH)

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so it is the appropriate clock for measuring an elapsed interval.
        start = time.perf_counter()
        S.addMulti(C0)
        S.addMulti(-C1)   # class-1 rows are inserted with their sign flipped
        end = time.perf_counter()
        print(f"Sketching took {end-start}", flush=True)

        gradient = BallGradApprox(S, sigma, n_components)
        opt = IterativeOptimizer(S, gradient)

        theta = np.zeros(n_features)
        theta, _losses = opt.optimize(theta, eta, beta, n_iters,
                                      compute_function=True, verbose=False)

        # Training accuracy only (the test split is not evaluated here).
        # Assumes y holds +/-1 labels -- TODO confirm against format_class.
        # Do not use 1 - sum(|y - pred|)/n: with +/-1 labels every mistake
        # contributes |(+1) - (-1)| = 2, which reports 1 - 2*error_rate.
        # np.sign yields 0 when x.theta == 0 exactly; such points count as
        # misclassified. np.ravel guards against an (n, 1) label column,
        # which would otherwise broadcast to an (n, n) comparison.
        pred = np.sign(np.dot(x, theta))
        train_accuracy = np.mean(pred == np.ravel(y))
        storm_accuracy[j, i] = train_accuracy
        storm_sizes[j, i] = S.min_size()
        print(f"\tAccuracy: {train_accuracy:.2f}")

# Output files are named after the data file, e.g. australian.results.
name = os.path.splitext(os.path.basename(datafile))[0]
np.savetxt(name + ".results", storm_accuracy, delimiter=',')
np.savetxt(name + ".sizes", storm_sizes, delimiter=',')
