import numpy as np 
import matplotlib.pyplot as plt 


from storm.sketch import *
from storm.hashes.srp_hash import SRPHash
from storm.hashes.asymmetric_simple_hash import AsymmetricSimpleHash
from storm.hashes.asymmetric_series_hash import AsymmetricSeriesHash
from storm.hashes.asymmetric_ball_hash import AsymmetricBallHash
from storm.optimization.iterative_optimizer import IterativeOptimizer
from storm.optimization.utils import *
from storm.datasets import *
from storm.baselines import *

import sys
import random
import pickle
import os 
import time 

from sklearn import cluster, datasets

# Fixed seed for the run. It is defined before the data is generated so the
# synthetic circles are reproducible along with the SRP hash that also uses it.
# (To draw a fresh seed instead, use e.g. np.random.randint(0, 10_000).)
seed = 2143

n_samples = 20000
Xtrain, Ytrain = datasets.make_circles(
    n_samples=n_samples, factor=.5, noise=.05, random_state=seed
)

# Split the points by label and halve the coordinates. Boolean masks keep each
# class as a 2-D (n_points, 2) array even if a class had a single point, which
# the previous np.where/np.squeeze combination would have flattened to 1-D.
C0 = Xtrain[Ytrain == 0] / 2.0
C1 = Xtrain[Ytrain == 1] / 2.0
print(C0.shape)
print(C1.shape)

# center = np.mean(np.vstack((C0,C1)),axis = 0)
# C0 -= center
# C1 -= center

d = 2                  # first argument to FourierFeatures (raw points are 2-D)
p = 1                  # SRPHash `p`; the sketch is built with 2**p as its second argument
reps = 1000            # SRPHash `N` and first argument of the STORM sketch

sigma = 0.1            # passed to BallGradApprox
n_iters = 500          # passed to opt.optimize (presumably the iteration count)
eta = 0.01             # passed to opt.optimize (presumably the step size -- confirm)
beta = 0.5             # passed to opt.optimize (presumably a momentum term -- confirm)
n_components = 8       # passed to BallGradApprox
D = 500                # second argument to FourierFeatures
use_intercept = False  # forwarded to FF.featurize(..., intercept=...)
print(f"Seed: {seed}")


# Map both classes through the same FourierFeatures instance.
# NOTE(review): FourierFeatures is not given the seed here -- confirm whether
# its features are reproducible across runs.
FF = FourierFeatures(d, D)
input_C0, input_C1 = (
    FF.featurize(points, intercept=use_intercept) for points in (C0, C1)
)

# The hash operates on the featurized points, so its dimension is the number
# of feature columns rather than the raw input dimension.
feature_dim = input_C0.shape[1]
LSH = SRPHash(N=reps, d=feature_dim, p=p, seed=seed)
S = STORM(reps, 2**p, LSH)

# Insert class 0 as-is and class 1 with its sign flipped, timing both inserts
# (the negation of input_C1 happens inside the timed region).
start = time.time()
for signed_features in (input_C0, -input_C1):
    S.addMulti(signed_features)
end = time.time()
print(f"Sketching took {end-start}")

# constraint A forces us not to optimize the intercept
# def constraintA(theta):
# 	theta[-1] = 0
# 	return theta

# Gradient object built from the sketch; no constraint callback is passed to
# the optimizer (the constraintA variant above is left disabled).
gradient = BallGradApprox(S, sigma, n_components)
opt = IterativeOptimizer(S, gradient)

# Start from the all-zeros vector with one entry per feature column, then run
# the optimizer. It returns the final theta and `losses`, which is plotted
# at the end of the script.
n_features = input_C0.shape[1]
theta = np.zeros(n_features)
theta, losses = opt.optimize(
    theta, eta, beta, n_iters,
    compute_function=True,
    verbose=True,
)

# def constraintB(x):
# 	x[:-1] = theta[:-1]
# 	return x

# gradient = BallGradApprox(S, 0.05, n_components)
# opt = IterativeOptimizer(S, gradient, constraintB)
# theta, losses = opt.optimize(theta, 0.5, 0, 100, compute_function = True, verbose = True)




# Scatter the two classes (C1 in the default colour, C0 in dark purple).
plt.plot(C1[:, 0], C1[:, 1], '.')
plt.plot(C0[:, 0], C0[:, 1], '.', color='#231651')

# Evaluate the learned linear score theta . features(x1, x2) on an M x M grid
# covering [-1, 1]^2. Z[j, i] holds the score at (x[i], y[j]), which is the
# layout plt.contour expects together with np.meshgrid(x, y).
M = 200
x = np.linspace(-1, 1, M)
y = np.linspace(-1, 1, M)
X, Y = np.meshgrid(x, y)
Z = np.zeros((M, M))

# Featurize one grid row at a time instead of one point at a time: M batched
# calls replace M*M single-point calls, while keeping the temporary feature
# matrix small. This relies on FF.featurize accepting a 2-D array of points,
# the same form already used for the class data earlier in the script.
for j, yj in enumerate(y):
    row_points = np.column_stack((x, np.full(M, yj)))
    row_features = FF.featurize(row_points, intercept=use_intercept)
    Z[j, :] = row_features @ theta

# Black dashed lines: automatically chosen level sets of the score.
# Red dashed line: the zero level set, i.e. the decision boundary.
plt.contour(X, Y, Z, colors='k', linestyles='--')
plt.contour(X, Y, Z, [0], colors='r', linestyles='--')

# Contour sets do not produce a legend entry on their own, so add an empty
# proxy line. It is drawn in red to match the zero-level contour it describes.
plt.plot([], [], 'r--', label="Decision Boundary")
plt.xlim(-1.1, 1.1)
plt.ylim(-1.1, 1.1)

plt.grid()

plt.xlabel("x1", fontsize=16)
plt.ylabel("x2", fontsize=16)

plt.legend(fontsize=12, loc='upper left')

# Report the actual class size rather than a hard-coded number.
plt.title(f"synthetic ({len(C0)} per class)")

# Second figure: the loss values returned by the optimizer.
plt.figure()
plt.plot(losses)

plt.show()


