import numpy as np 
import matplotlib.pyplot as plt 


from storm.sketch import *
from storm.hashes.srp_hash import SRPHash
from storm.hashes.asymmetric_simple_hash import AsymmetricSimpleHash
from storm.hashes.asymmetric_series_hash import AsymmetricSeriesHash
from storm.hashes.asymmetric_ball_hash import AsymmetricBallHash
from storm.optimization.iterative_optimizer import IterativeOptimizer
from storm.optimization.utils import *
from storm.datasets import *
from storm.baselines import *

import sys
import random
import pickle
import os 
import time 

# Two synthetic 2-D Gaussian clouds:
#   class 1 ~ N(mu1, cov1), class 0 ~ N(mu2, cov2).
mu1 = (0.3, 0.5)
cov1 = [[0.01, 0], [0, 0.01]]
mu2 = (0.0, 0.0)
# NOTE(review): cov2 is not positive semi-definite
# (det = 0.001 * 0.0015 - 0.008 ** 2 < 0), so np.random.multivariate_normal
# emits a RuntimeWarning and the drawn samples do not follow cov2.
# Confirm the intended entries.
cov2 = [[0.001, 0.008], [0.008, 0.0015]]

N0 = 10000  # number of class-0 samples
N1 = 10000  # number of class-1 samples

# Draw order is significant for the global RNG stream: class 0 first, then class 1.
C0 = np.random.multivariate_normal(mu2, cov2, size=N0)
C1 = np.random.multivariate_normal(mu1, cov1, size=N1)

# Shift both classes so that the pooled sample mean sits at the origin.
pooled = np.concatenate((C0, C1), axis=0)
center = pooled.mean(axis=0)
C0 = C0 - center
C1 = C1 - center

# --- Sketch and optimizer hyper-parameters --------------------------------
d = 2        # raw feature dimension (unused here; the hash width comes from input_C0.shape[1])
p = 4        # passed to the hash as p; the sketch is built with 2**p columns
reps = 1000  # passed to the hash as N and to STORM as its first dimension

# Integer upper bound: the former literal 10e3 is the float 10000.0 (same range).
# The seed is printed because it is otherwise lost, so the hash draw of a run
# could never be reproduced. NOTE: the data samples are drawn from the unseeded
# global RNG, so this seed alone does not reproduce them.
seed = np.random.randint(0, 10_000)
print(f"LSH seed: {seed}")

sigma = 0.1        # passed to BallGradApprox
n_iters = 2000     # passed to opt.optimize (iteration budget, by name)
eta = 0.1          # passed to opt.optimize (presumably the step size — confirm)
beta = 0.5         # passed to opt.optimize (meaning not visible here — confirm)
n_components = 8   # passed to BallGradApprox

# Without an intercept, the raw centered features would be used instead:
# input_C0 = C0
# input_C1 = C1

# Append a constant-1 column so that theta carries an intercept term (theta[2]).
input_C0 = np.hstack((C0, np.ones((N0, 1))))
input_C1 = np.hstack((C1, np.ones((N1, 1))))

LSH = SRPHash(N = reps, d = input_C0.shape[1], p = p, seed = seed)
# Alternative hash families, kept for comparison experiments:
# LSH = AsymmetricSimpleHash(N = reps, d = input_C0.shape[1], p = p, seed = seed)
# LSH = AsymmetricSeriesHash(N = reps, d = input_C0.shape[1], m = 4, p = p, seed = seed)
# LSH = AsymmetricBallHash(N = reps, d = input_C0.shape[1], p = p, seed = seed)

S = STORM(reps, 2**p, LSH)


# Insert class 0 as-is and class 1 negated into the same sketch.
# NOTE(review): the negation presumably encodes the label sign (y * x) — confirm
# against the STORM sketch documentation.
# perf_counter is monotonic and high-resolution, unlike the wall-clock time.time().
start = time.perf_counter()
S.addMulti(input_C0)
S.addMulti(-input_C1)
end = time.perf_counter()
print(f"Sketching took {end-start}")

# Optimize the linear model against the sketch only, starting from theta = 0
# (one weight per input column, including the intercept column when present).
gradient = BallGradApprox(S, sigma, n_components)
opt = IterativeOptimizer(S, gradient)

theta = np.zeros(input_C0.shape[1])
theta, losses = opt.optimize(theta, eta, beta, n_iters, compute_function = True)


# Scatter both centered classes.
plt.plot(C1[:, 0], C1[:, 1], '.', label="class 1")
plt.plot(C0[:, 0], C0[:, 1], '.', color='#231651', label="class 0")

# Decision boundary theta[0]*x0 + theta[1]*x1 (+ theta[2]) = 0, solved for x1.
# NOTE(review): this divides by theta[1]; if it is 0 (e.g. never moved from the
# zero initialisation) numpy yields inf/nan and the line is not drawn.
x0 = np.linspace(-0.5, 0.5, 100)
x1 = -theta[0] * x0 / theta[1]
plt.plot(x0, x1, 'k--', label="STORM Classifier")

# With the intercept column appended, theta has a third entry that offsets the line.
if len(theta) == 3:
    x1 = -theta[0] * x0 / theta[1] - theta[2] / theta[1]
    plt.plot(x0, x1, 'r-', label="With intercept")

plt.grid()

# Both axes are input features (columns 0 and 1 of the data).
plt.xlabel("x1", fontsize=16)
plt.ylabel("x2", fontsize=16)

plt.legend(fontsize=12)

# Derive the sample counts from the actual sizes instead of a stale literal.
per_class = f"{N0} per class" if N0 == N1 else f"{N0} / {N1} per class"
plt.title(f"synthetic ({per_class})")

plt.xlim(-0.5, 0.5)
plt.ylim(-0.5, 0.5)

# Second figure: the objective values returned by opt.optimize(compute_function=True).
plt.figure()
plt.plot(losses)
plt.ylabel("loss")

plt.show()


# def constraint(theta):
# 	norm = np.linalg.norm(theta)
# 	print(norm)
# 	if norm >= 1:
# 		return theta/norm
# 	else:
# 		return theta