
from race import *
from utils import *
from dataLoader import *
import numpy as np
import scipy.stats
import pickle
from sklearn.linear_model import LinearRegression
from sklearn import datasets, linear_model, metrics
from numpy import linalg as LA
import random
import math
import matplotlib.pyplot as plt
from pulp import *
import random

'''
import time
start = time.time()
# end = time.time()
# print("Sketch insertion time: ", end - start)

'''

def getSketch(datasetName, dataset_train, P, REPS, hash_type):
    """Build, or load from the on-disk cache, the RACE sketch of a training set.

    A ``RACECounts(REPS, 2**P)`` sketch and an ``SRP`` hasher with ``REPS*P``
    hyperplanes over ``dataset_train.shape[1]`` dimensions are always
    constructed. If a cache file exists, ``S.counts`` and ``lsh.W`` are then
    overwritten with the pickled values; otherwise every row ``e`` of
    ``dataset_train`` is inserted twice (as ``e`` and as ``-e``) and the
    resulting ``[S.counts, lsh.W]`` pair is pickled.

    Parameters:
        datasetName: string used as the prefix of the cache file name.
        dataset_train: 2-D array, one training point per row.
        P: number of hash bits per sketch row (the sketch has 2**P columns).
        REPS: number of sketch rows.
        hash_type: only used here for the log line and the cache file name.

    Returns:
        The tuple ``(S, lsh)``.

    Note:
        The cache key is (datasetName, number of rows, REPS, P, hash_type);
        it does not depend on the data values or on the number of columns,
        so a stale file is silently reused if only those change.
    """
    print("hash_type is: ", hash_type)

    S = RACECounts(REPS, 2**P)  # race sketch
    lsh = SRP(REPS*P, dataset_train.shape[1])  # srp hashing mechanism

    # Built once so the existence check, the load and the dump cannot drift apart.
    sketch_path = ("../RACEsketchNormalised2/" + datasetName
                   + str(dataset_train.shape[0]) + "_" + str(REPS) + "_"
                   + str(P) + "_" + hash_type + ".p")

    if path.exists(sketch_path):
        print ("Loading sketch")
        # NOTE: pickle can execute arbitrary code while loading; only cache
        # files produced by this function should ever live in this directory.
        with open(sketch_path, "rb") as cache_file:
            Sketch_Hyp = pickle.load(cache_file)
        S.counts = Sketch_Hyp[0]
        lsh.W = Sketch_Hyp[1]
    else:
        print ("preparing sketch")
        for e in dataset_train:
            S.add(lsh.hash_independent(e, P)[0])
            S.add(lsh.hash_independent(-e, P)[0])  # Comment this out for SRP
        # save STORM sketch; the context manager guarantees the file is
        # flushed and closed before anyone tries to read it back.
        with open(sketch_path, "wb") as cache_file:
            pickle.dump([S.counts, lsh.W], cache_file)

    return S, lsh

def OneEditOneJump(sketch, lsh, x, p, alpha, skLocs):
    """Move ``x`` across one hyperplane towards a less populated sketch bucket.

    ``x`` is hashed, a sketch row ``i`` is drawn at random from ``skLocs``,
    and the count of the bucket ``x`` currently falls into is compared with
    the counts of the buckets whose code differs from it in exactly one bit.
    If a one-bit neighbour has a strictly smaller count, ``x`` is moved along
    the normal of the corresponding hyperplane ``lsh.W[i*p + bit]``:
    ``x - (1 + alpha) * proj_hp(x)``. Ties are resolved in favour of staying
    put (``np.argmin`` returns the first minimum).

    Parameters:
        sketch: object exposing a 2-D ``counts`` array (rows x buckets).
        lsh: object exposing ``hash_independent(x, p)`` and the hyperplane
            matrix ``W``; element ``[1]`` of the hash result is assumed to be
            a 2-D array of 0/1 bits, one row per sketch row -- TODO confirm.
        x: current parameter vector.
        p: bits per sketch row; overwritten below by the hash width.
        alpha: overshoot factor (0 lands on the hyperplane, 1 mirrors ``x``).
        skLocs: sequence of candidate sketch row indices.

    Returns:
        ``x`` itself when no neighbour is better, otherwise a new vector.
    """
    h = lsh.hash_independent(x, p)[1]
    i = random.choice(skLocs)
    p = h.shape[1]
    powersOfTwo = np.array([2**t for t in range(p)])

    current_code = h[i]
    # nearsk[0] is the current bucket, nearsk[j] the bucket with bit j-1 flipped.
    nearsk = np.zeros(p + 1)
    nearsk[0] = sketch.counts[i, int(np.dot(current_code, powersOfTwo))]
    for j in range(1, p + 1):
        # Work on a copy: flipping bits on h[i] directly would make the flips
        # accumulate, so neighbour j would differ in j bits instead of one.
        key = np.array(current_code)
        key[j-1] = 1 - key[j-1]
        nearsk[j] = sketch.counts[i, int(np.dot(key, powersOfTwo))]

    J = np.argmin(nearsk)
    if J == 0:
        return x

    hp = lsh.W[i*p + J-1]
    # jump
    return x - (1 + alpha)*(np.dot(hp, x))/(LA.norm(hp)**2)*hp

def AllEditOneJump(sketch, lsh, x, p, alpha, skLocs):
    """Jump ``x`` towards the emptiest bucket of one randomly chosen sketch row.

    The bucket with the smallest count in row ``i`` is the target. The bits
    in which the current hash code of ``x`` differs from the target index are
    obtained from ``loc_of1s0s(code ^ target, p)[1]`` (presumably the
    positions of the set bits of the xor -- verify against ``loc_of1s0s``).
    For every such bit the projection of ``x`` on the matching hyperplane
    normal is accumulated, and ``(1 + alpha)`` times the sum is subtracted
    from ``x`` in a single step.

    Parameters:
        sketch: object exposing a 2-D ``counts`` array.
        lsh: object exposing ``hash_independent`` and hyperplanes ``W``.
        x: current parameter vector.
        p: bits per sketch row; replaced by the hash width below.
        alpha: overshoot factor applied to the accumulated direction.
        skLocs: sequence of candidate sketch row indices.

    Returns:
        ``x`` itself if no bit differs, otherwise the updated vector.
    """
    codes = lsh.hash_independent(x, p)[1]
    row = random.choice(skLocs)
    p = codes.shape[1]
    bit_weights = np.array([2**t for t in range(p)])

    target_bucket = np.argmin(sketch.counts[row, :])
    current_bucket = int(np.dot(codes[row], bit_weights))
    differing_bits = loc_of1s0s(current_bucket ^ target_bucket, p)[1]  # xor

    if not differing_bits:
        return x

    correction = np.zeros_like(x)
    for bit in differing_bits:
        normal = lsh.W[row*p + bit]
        # Every projection uses the original x, not a partially updated one.
        correction = correction + (np.dot(normal, x))/(LA.norm(normal)**2)*normal
    return x - (1 + alpha)*correction

def getThetaFromLP(hpPos, hpNeg, ThetaPrev, J):
    """Solve a linear feasibility problem for a parameter vector.

    One continuous variable is created per entry of ``ThetaPrev``. Only
    constraints are added to the problem (no objective expression), then the
    default PuLP solver is run.

    * ``J == 2``: the variables are an increment on top of ``ThetaPrev``.
      Each hyperplane in ``hpPos`` / ``hpNeg`` is constrained against
      ``-dot(hp, ThetaPrev)``, the last coordinate of ``ThetaPrev + x`` is
      pinned to -1, and ``ThetaPrev + x`` is returned.
    * any other ``J``: the variables are the parameter vector itself, the
      constraints are against 0, the last coordinate is pinned to -1 and the
      raw variable values are returned (``ThetaPrev`` only supplies the length).

    NOTE(review): the 0.1 margin sits inside the sum, so it is applied once
    per coordinate rather than once per constraint -- confirm this is intended.

    Parameters:
        hpPos: hyperplanes whose constraint uses ``>=``.
        hpNeg: hyperplanes whose constraint uses ``<=``.
        ThetaPrev: previous parameter vector.
        J: mode selector, see above.

    Returns:
        A Python list with one value per coordinate.
    """
    n_coords = len(ThetaPrev)
    RANGE = list(range(n_coords))
    last = n_coords - 1

    prob = LpProblem("Theta_Problem", pulp.LpMinimize)
    x = pulp.LpVariable.dicts("x", RANGE, cat="Continuous")  # shd not be an integer

    if J != 2:
        # Absolute mode: constrain the variables directly.
        for hp in hpPos:
            prob += lpSum([hp[i]*x[i] - 0.1 for i in RANGE]) >= 0
        for hp in hpNeg:
            prob += lpSum([hp[i]*x[i] + 0.1 for i in RANGE]) <= 0
        prob += x[last] + 1 == 0
        prob.solve()
        return [x[i].varValue for i in RANGE]

    # Incremental mode: constrain ThetaPrev + x.
    for hp in hpPos:
        prob += lpSum([hp[i]*x[i] - 0.1 for i in RANGE]) >= -np.dot(hp, ThetaPrev)
    for hp in hpNeg:
        prob += lpSum([hp[i]*x[i] + 0.1 for i in RANGE]) <= -np.dot(hp, ThetaPrev)
    prob += x[last] + ThetaPrev[last] + 1 == 0
    prob.solve()
    return [x[i].varValue + ThetaPrev[i] for i in RANGE]


# iterative minimisation using LP
def AllEditLP(sketch, lsh, x, p, skLocs):
    """Perform one LP-based update of ``x`` using a random sketch row.

    The emptiest bucket of the chosen row is decoded with ``loc_of1s0s`` into
    two index lists (unpacked as zero-bit and one-bit positions). Hyperplanes
    at the one-bit positions become the positive set and those at the
    zero-bit positions the negative set of ``getThetaFromLP`` (mode 2, i.e.
    incremental on top of ``x``).

    Parameters:
        sketch: object exposing a 2-D ``counts`` array.
        lsh: object exposing ``hash_independent`` and hyperplanes ``W``.
        x: current parameter vector (initial theta).
        p: number of hash bits per sketch row.
        skLocs: sequence of candidate sketch row indices.

    Returns:
        The new parameter vector as a numpy array.
    """
    # NOTE(review): the hash of x is computed but never used; the call is
    # kept only so that behaviour is unchanged -- confirm it can be dropped.
    _unused_codes = lsh.hash_independent(x, p)[1]

    row = random.choice(skLocs)
    target_bucket = np.argmin(sketch.counts[row, :])
    zero_bits, one_bits = loc_of1s0s(target_bucket, p)

    hpPos = [lsh.W[row*p + bit] for bit in one_bits]   # all 1s
    hpNeg = [lsh.W[row*p + bit] for bit in zero_bits]  # all 0s

    return np.array(getThetaFromLP(hpPos, hpNeg, x, 2))

# Single best ACE
def BestACE_LP(sketch, lsh, x, p, alpha, skLocs, dataset_train, dataset_test):
    """Fit one LP solution per entry of ``skLocs`` and score each of them.

    For every entry ``i`` of ``skLocs`` the emptiest bucket of
    ``sketch.counts[i, :]`` is decoded with ``loc_of1s0s`` and the matching
    hyperplanes are handed to ``getThetaFromLP`` in mode 3. The resulting
    vector is evaluated with ``mseLoss`` on both data sets.

    Parameters:
        sketch: object exposing a 2-D ``counts`` array.
        lsh: object exposing ``hash_independent`` and hyperplanes ``W``.
        x: initial theta, forwarded to ``getThetaFromLP``.
        p: number of hash bits per sketch row.
        alpha: unused in this function.
        skLocs: sketch locations to process, one result per entry.
        dataset_train: data passed to ``mseLoss`` for the training loss.
        dataset_test: data passed to ``mseLoss`` for the test loss.

    Returns:
        ``[Theta, losses_train, losses_test]``: the list of fitted vectors
        and two arrays of length ``len(skLocs)``.
    """
    n_iters = len(skLocs)
    losses_train = np.zeros(n_iters)
    losses_test = np.zeros(n_iters)
    Theta = []

    for slot, row in enumerate(skLocs):
        # NOTE(review): result unused; call kept only to leave behaviour
        # untouched -- confirm it can be removed from the loop.
        _unused_codes = lsh.hash_independent(x, p)[1]

        target_bucket = np.argmin(sketch.counts[row, :])
        zero_bits, one_bits = loc_of1s0s(target_bucket, p)

        hpPos = [lsh.W[row*p + bit] for bit in one_bits]   # all 1s
        hpNeg = [lsh.W[row*p + bit] for bit in zero_bits]  # all 0s

        theta = np.array(getThetaFromLP(hpPos, hpNeg, x, 3))
        Theta.append(theta)
        losses_train[slot] = mseLoss(theta, dataset_train)
        losses_test[slot] = mseLoss(theta, dataset_test)

    return [Theta, losses_train, losses_test]

def BestACE_LS(sketch, lsh, theta, p, alpha, ridge_lambda, skLocs, dataset_train, dataset_test, serPath, hash_type, K):
    """Fit one model per sketch row and pick the best one for each row subset.

    Phase 1 (cached in ``serPath``): for every sketch row ``i`` the rows
    ``range(max(0, i-K), min(i+K, n_rows-1))`` are gathered. For each of
    them the index of the emptiest bucket is turned into a vector of +/-1
    labels (one per hash bit) and stacked with the row's ``p`` hyperplanes.
    ``GradDesc_tanh`` is run on the stacked system, warm-started from the
    previous row's (truncated) solution. The result is cut to the first
    ``d = dataset_test.shape[1]`` entries and scored with ``mseLoss``.

    Phase 2: for every index list in ``skLocs`` the row with the smallest
    training loss among the listed rows is selected.

    Parameters:
        sketch: object exposing a 2-D ``counts`` array.
        lsh: object exposing the hyperplane matrix ``W``.
        theta: initial parameter vector for the first fit.
        p: number of hash bits per sketch row.
        alpha: unused in this function.
        ridge_lambda: forwarded to ``GradDesc_tanh``.
        skLocs: list of lists of sketch row indices.
        dataset_train: data passed to ``mseLoss`` for the training loss.
        dataset_test: data passed to ``mseLoss`` for the test loss.
        serPath: pickle file used to cache the phase-1 results.
        hash_type: forwarded to ``GradDesc_tanh``.
        K: half-width of the row window used for each fit.

    Returns:
        ``[theta_hat, Lt, Ls]``: for each entry of ``skLocs`` the selected
        vector, its training loss and its test loss.
    """
    d = dataset_test.shape[1]
    n_rows = sketch.counts.shape[0]

    if path.exists(serPath):
        print ("Loading results")
        # NOTE: pickle can execute arbitrary code while loading; serPath must
        # only ever point at files written by this function.
        with open(serPath, "rb") as cache_file:
            [Theta, losses_train, losses_test] = pickle.load(cache_file)
    else:
        print ("preparing results")
        Theta = []
        losses_train = np.zeros(n_rows)
        losses_test = np.zeros(n_rows)

        for i in range(n_rows):
            st = max(0, i-K)
            en = min(i+K, n_rows-1)
            Ys = np.empty([0])
            Ws = np.empty([0, lsh.W.shape[1]])
            for j in range(st, en):
                loc = np.argmin(sketch.counts[j, :])
                # Bit t of loc set -> label +1, cleared -> label -1.
                Ys1 = 2*(((loc & (1 << np.arange(p)))) > 0).astype(int) - 1
                Ys = np.append(Ys, Ys1)
                Ws1 = lsh.W[j*p: (j+1)*p]
                Ws = np.append(Ws, Ws1, axis=0)
            theta, cost_history = GradDesc_tanh(Ws, Ys, theta, ridge_lambda, hash_type)
            theta = theta[:d]
            Theta.append(theta)
            losses_train[i] = mseLoss(theta, dataset_train)
            losses_test[i] = mseLoss(theta, dataset_test)

        # The context manager makes sure the cache is flushed and closed.
        with open(serPath, "wb") as cache_file:
            pickle.dump([Theta, losses_train, losses_test], cache_file)

    Lt = []
    Ls = []
    theta_hat = []
    for rows in skLocs:
        losses_train1 = np.array(losses_train[rows])
        losses_test1 = np.array(losses_test[rows])
        # Position of the best row inside this subset, by training loss.
        a = np.argmin(losses_train1)
        Lt.append(losses_train1[a])
        Ls.append(losses_test1[a])
        theta_hat.append(Theta[rows[a]])

    return [theta_hat, Lt, Ls]

def stormRegression(datasetName, dataset, Reps, P, method, n_iters, alpha, ridge_lambda, init, REPS, hash_type):
    """Run one STORM regression experiment on ``dataset``.

    The first 80% of the rows form the training set and the rest the test
    set. A sketch is built (or loaded) with ``getSketch``; for
    ``hash_type == "alsh_prp"`` the training matrix is first extended with
    two columns: ``sqrt(1 - ||row[:-1]||^2)`` and zeros. For every value in
    ``Reps`` that many distinct sketch rows are sampled from ``range(REPS)``
    and the corresponding ``S.equivalent_size`` is recorded.

    Parameters:
        datasetName: name forwarded to ``getSketch`` and used in cache paths.
        dataset: 2-D array; the last column is presumably the regression
            target (the matching theta entry is pinned to -1) -- TODO confirm.
        Reps: iterable of sample sizes (each at most ``REPS``).
        P: number of hash bits per sketch row.
        method: "BestACE_LP", "BestACE_LS", "OneEditOneJump",
            "AllEditOneJump" or "AllEditLP". Any other value runs the
            iterative loop without updating theta.
        n_iters: number of iterations for the iterative methods.
        alpha: overshoot factor forwarded to the chosen method.
        ridge_lambda: forwarded to ``BestACE_LS``.
        init: "zeros" for a zero start, anything else for a standard normal start.
        REPS: total number of sketch rows.
        hash_type: hashing variant, forwarded to ``getSketch`` / ``BestACE_LS``.

    Returns:
        ``[Theta, losses_train, losses_test, sparseMem]``.

    NOTE(review): ``skLocs`` is a list of index lists, while the iterative
    methods draw ``random.choice(skLocs)`` and use it as a single row
    index -- confirm the intended argument for those methods.
    """
    np.set_printoptions(threshold=sys.maxsize)

    n_points = dataset.shape[0]
    tr = int(n_points*(0.8))  # 80% points in training
    dataset_train, dataset_test = dataset[:tr, :], dataset[tr:, :]
    n_features = dataset_train.shape[1]

    if hash_type != "alsh_prp":
        S, lsh = getSketch(datasetName, dataset_train, P, REPS, hash_type)
    else:
        # alsh data pre-process
        col1 = np.sqrt(1 - np.linalg.norm(dataset_train[:, :-1], axis=1)**2)
        print (np.linalg.norm(dataset_train[:, :-1], axis=1))
        col2 = np.zeros(tr)
        augmented = np.column_stack((dataset_train, col1, col2))
        S, lsh = getSketch(datasetName, augmented, P, REPS, hash_type)

    # initialisation
    if init == "zeros":
        theta = np.zeros(n_features, dtype=np.float64)
    else:
        theta = np.random.normal(scale=1, size=n_features).astype(np.float64)

    K = 4
    skLocs = []
    sparseMem = []
    for reps in Reps:
        sampl = random.sample(range(0, REPS), reps)
        skLocs.append(sampl)
        sparseMem.append(S.equivalent_size(reps, sampl, K))  # bits

    if method == "BestACE_LP":
        [Theta, losses_train, losses_test] = BestACE_LP(S, lsh, theta, P, alpha, skLocs, dataset_train, dataset_test)
        return [Theta, losses_train, losses_test, sparseMem]

    if method == "BestACE_LS":
        serPath = "../BestACE_LS_preprocessed/" +datasetName+str(dataset_train.shape[0])+"_"+str(REPS)+"_"+str(P)+"_"+ str(alpha)+"_"+ str(ridge_lambda)+"_"+hash_type+".p"
        [Theta, losses_train, losses_test] = BestACE_LS(S, lsh, theta, P, alpha, ridge_lambda, skLocs, dataset_train, dataset_test, serPath, hash_type, K)
        return [Theta, losses_train, losses_test, sparseMem]

    # Iterative methods: one update per iteration, loss recorded every time.
    losses_train = np.zeros(n_iters)
    losses_test = np.zeros(n_iters)
    Theta = []
    for step in range(n_iters):
        if method == "OneEditOneJump":
            theta = OneEditOneJump(S, lsh, theta, P, alpha, skLocs)
            theta[n_features - 1] = -1  # project back onto the constraint
        elif method == "AllEditOneJump":
            theta = AllEditOneJump(S, lsh, theta, P, alpha, skLocs)
            theta[n_features - 1] = -1  # project back onto the constraint
        elif method == "AllEditLP":
            theta = AllEditLP(S, lsh, theta, P, skLocs)

        losses_train[step] = mseLoss(theta, dataset_train)
        losses_test[step] = mseLoss(theta, dataset_test)

        Theta.append(theta)
        sys.stdout.flush()

    return [Theta, losses_train, losses_test, sparseMem]
