import numpy as np
import scipy.stats
from sklearn import datasets, linear_model, metrics
from numpy import linalg as LA
from pulp import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def mseLoss(theta, X):
    """Return the root of the mean squared projection of the rows of X onto theta.

    For every row ``X[i, :]`` the dot product with ``theta`` is squared; the
    squares are averaged over the rows and the square root of that average is
    returned. Note that, despite the name, the result is a *root* mean square.

    Args:
        theta: 1-D coefficient vector with one entry per column of ``X``.
        X: 2-D array with one sample per row.

    Returns:
        ``sqrt(mean_i (theta . X[i, :]) ** 2)``.

    Raises:
        ZeroDivisionError: if ``X`` has no rows (the mean is undefined).
    """
    n_rows = X.shape[0]
    if n_rows == 0:
        # The row-by-row formulation divided 0 by 0 here; keep that failure
        # mode explicit instead of silently returning NaN.
        raise ZeroDivisionError("mseLoss needs at least one row in X")
    # One matrix product replaces the Python-level loop over the rows.
    projections = np.dot(theta, X.T)
    return np.sqrt(np.sum(projections ** 2, axis=-1) / n_rows)

# x is theta
def gradapprox(sketch, lsh, x, sigma, n_components, p):
    """Estimate a gradient at ``x`` from random finite differences of sketch queries.

    ``n_components`` directions are drawn from a standard normal distribution
    (via the global NumPy random state). For each direction ``di`` the sketch
    is queried at the hash of ``x + sigma * di`` and the difference to the
    query value at the hash of ``x`` is used as a forward difference along
    ``di``. The weighted directions are averaged into the returned vector.

    Args:
        sketch: object exposing ``query(hash_code)``.
        lsh: object exposing ``hash_independent(vector, p)``; only element
            ``[0]`` of its return value is used.
        x: 1-D NumPy array, the point (theta) at which to estimate.
        sigma: step length applied to each random direction.
        n_components: number of random directions to average over.
        p: forwarded unchanged to ``lsh.hash_independent``.

    Returns:
        Array shaped like ``x`` holding the gradient estimate. The dtype is
        that of ``x`` when it is already floating/complex, ``float64``
        otherwise.
    """
    d = x.shape[0]
    directions = np.random.normal(size=(n_components, d))
    # Accumulate in a floating dtype: with an integer ``x``, ``zeros_like``
    # would create an integer buffer and the in-place ``+=`` of float terms
    # below would raise a casting error.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    out = np.zeros_like(x, dtype=out_dtype)
    # Query value at the unperturbed point, shared by every difference.
    fx = sketch.query(lsh.hash_independent(x, p)[0])
    for di in directions:
        hci = lsh.hash_independent(x + sigma * di, p)[0]
        out += (1.0 / (n_components * sigma)) * (sketch.query(hci) - fx) * di
    return out

def center(L):
    """Summarise ``L`` as ``[median, standard deviation]``.

    Both statistics are taken over all values of ``L`` (population standard
    deviation) and rounded to three decimals before being packed into a
    length-2 NumPy array.
    """
    values = np.array(L)
    median = round(np.median(values), 3)
    spread = round(np.std(values), 3)
    return np.array([median, spread])

def loc_of1s0s(value, p):
    """Locate the cleared and the set bits among the ``p`` lowest bits of ``value``.

    Bit positions are counted from the least significant bit (position 0).

    Returns:
        A two-element list ``[zeros, ones]``. Each entry is the
        ``np.argwhere`` result (a column of positions, shape ``(k, 1)``) for
        the bits that are 0 and the bits that are 1 respectively.
    """
    # One mask per bit position: 1, 2, 4, ... for positions 0 .. p-1.
    bit_masks = 1 << np.arange(p)
    bit_is_set = (value & bit_masks) > 0
    zero_positions = np.argwhere(~bit_is_set)
    one_positions = np.argwhere(bit_is_set)
    return [zero_positions, one_positions]

def GradDesc_tanh(Ws,Ys,theta,ridge_lambda,hash_type, learning_rate = 0.01, iterations= 100):
    """Fit ``tanh(Ws . theta)`` to ``Ys`` with ridge-regularised gradient descent.

    Every iteration takes one gradient step on the squared residual of
    ``tanh(Ws . theta)`` against ``Ys`` (plus ``ridge_lambda * theta``) and
    then re-pins the trailing entries of ``theta``:

    * ``hash_type == "alsh_prp"``: ``theta`` is first extended by two zeros;
      after each step ``theta[-3]`` is set to -1, ``theta[-2]`` to 0 and
      ``theta[-1]`` to ``sqrt(1 - ||theta[:-3]||^2)``.
    * any other ``hash_type``: ``theta[-1]`` is set to -1 after each step.

    Args:
        Ws: 2-D array of samples; its column count must match the length of
            ``theta`` (after the two-entry extension for ``"alsh_prp"``).
        Ys: targets, one per row of ``Ws``.
        theta: initial coefficient vector. The caller's array is not modified.
        ridge_lambda: ridge penalty weight.
        hash_type: selects the pinning rule described above.
        learning_rate: step size; the step is additionally scaled by
            ``1 / len(Ys)``.
        iterations: number of gradient steps.

    Returns:
        ``(theta, cost_history)`` where ``cost_history[it]`` is the Euclidean
        norm of the residual computed at the start of iteration ``it``
        (i.e. before that iteration's update).
    """
    m = len(Ys)
    cost_history = np.zeros(iterations)
    use_alsh_prp = hash_type == "alsh_prp"
    if use_alsh_prp:
        # Two extra slots for the entries pinned to 0 and to the norm complement.
        theta = np.append(theta, np.array([0,0]))

    # A single loop serves both hash types; only the pinning step differs.
    for it in range(iterations):
        prediction = np.tanh(np.dot(Ws, theta))
        residual = prediction - Ys
        # d/dz tanh(z) = 1 - tanh(z)^2 supplies the chain-rule factor.
        gradient = Ws.T.dot(np.multiply(residual, 1 - prediction**2)) + ridge_lambda*theta
        # Subtraction builds a new array, so the caller's theta stays untouched.
        theta = theta - (1/m)*learning_rate*gradient
        if use_alsh_prp:
            theta[-3] = -1
            theta[-2] = 0
            # NOTE(review): this becomes NaN once ||theta[:-3]|| exceeds 1 and
            # the NaN then propagates through all later iterations - confirm
            # callers keep theta inside the unit ball.
            theta[-1] = np.sqrt(1 - np.linalg.norm(theta[:-3])**2)
        else:
            theta[-1] = -1
        # Vectorised norm instead of the builtin sum() walking the array.
        cost_history[it] = np.linalg.norm(residual)
    return theta, cost_history

def plotTheta(theta, dataset):
    """Scatter the data and the linear prediction of ``theta`` and save it as 'storm.svg'.

    Column 0 of ``dataset`` is used as the x axis. Column 2 is drawn as the
    observed values and ``theta[0] * column0 + theta[1] * column1`` as the
    fitted values. If ``dataset`` has more than three columns nothing is
    plotted and a message is printed instead.

    Args:
        theta: coefficient vector; only the first two entries are used.
        dataset: 2-D array; columns 0, 1 and 2 are read, so it needs at
            least three columns for the plot to succeed.

    Returns:
        None. The figure is written to 'storm.svg' in the working directory.
    """
    if dataset.shape[1]>3:
        print ("high dimention data, can't plot")
        return

    y = theta[0]*dataset[:,0] + theta[1]*dataset[:,1]

    fig = plt.figure()
    # alpha has to be numeric: current Matplotlib raises TypeError for the
    # string "0.5".
    plt.scatter(dataset[:,0], dataset[:,2], alpha = 0.5)
    plt.scatter(dataset[:,0], y, color='slateblue', marker = ".", label = "STORM")
    # NOTE(review): the title is hard-coded and not derived from `dataset`.
    plt.title("synthetic (n=1600,d=2)")
    plt.xlabel("x1",fontsize = 18)
    plt.ylabel("Y",fontsize = 18)
    plt.legend(prop={'size': 15})
    plt.grid()
    fig.savefig('storm.svg', dpi=500)

def _plotres_save_figure(mem_fact, curves, errorBar, title, out_path):
    """Draw one log-scale figure of ``curves`` against ``mem_fact`` and save it."""
    fig = plt.figure()
    plt.yscale('log')
    for values, style, _ in curves:
        plt.plot(mem_fact, values[:, 0], **style)
    if errorBar:
        for values, style, has_band in curves:
            if has_band:
                # Band of +/- column 1 around column 0, in the curve's colour.
                plt.fill_between(mem_fact, values[:, 0] - values[:, 1],
                                 values[:, 0] + values[:, 1],
                                 color=style['color'], alpha=0.1)
    plt.title(title)
    plt.xlabel("mem_factor")
    plt.ylabel("MSE")
    plt.legend(prop={'size': 8})
    plt.grid()
    fig.savefig(out_path, dpi=500)


def plotres(plotVars, ridge_lambda, errorBar=False):
    """Plot train and test loss curves against the memory factor and save both as PNG.

    ``plotVars`` is a table with one row per run and 26 columns in this order:
    datasetName, N, d, L1, L2, M1, M2, A1, A2, A3, A4, S1, S2, C1, C2, R1, R2,
    V1, V2, T1, T2, Reps, P, sparseMem, rho, avgcount. Columns ending in 1 go
    to the train figure and those ending in 2 to the test figure. Each plotted
    metric cell must hold at least two numbers: element 0 is drawn as the
    curve and element 1 is the half-width of the optional band (presumably a
    (median, std) pair - confirm with the caller).

    The curves T (STORM), M (mean), L (least squares) and R (random sampling)
    are always drawn; V (leverage sampling) and C (Clarkson2009) only when
    ``ridge_lambda == 0``. The x axis is the ``rho`` column.

    Args:
        plotVars: sequence of rows as described above.
        ridge_lambda: ridge weight; selects the extra curves and is embedded
            in the output file names.
        errorBar: when true, shade a band around the T, R, V and C curves that
            are drawn. Defaults to False (no bands).

    Returns:
        None. Writes '<datasetName>train_lambda_<ridge_lambda>.png' and
        '<datasetName>test_lambda_<ridge_lambda>.png' to the working directory.

    Raises:
        ValueError: if the rows do not have exactly 26 columns.
    """
    column_names = ("datasetName", "N", "d", "L1", "L2", "M1", "M2",
                    "A1", "A2", "A3", "A4", "S1", "S2", "C1", "C2",
                    "R1", "R2", "V1", "V2", "T1", "T2",
                    "Reps", "P", "sparseMem", "rho", "avgcount")
    # Rows mix strings, scalars and pairs; NumPy >= 1.24 refuses to build such
    # a ragged array unless dtype=object is requested explicitly.
    plotVars = np.array(plotVars, dtype=object)
    if plotVars.shape[1] != len(column_names):
        raise ValueError("plotres expects %d columns per row, got %d"
                         % (len(column_names), plotVars.shape[1]))
    column = {name: plotVars[:, i] for i, name in enumerate(column_names)}
    mem_fact = column["rho"].astype(float)
    dataset_name = column["datasetName"][0]

    def metric(name):
        # Stack the per-run cells of one column into a float (runs, k) array.
        return np.stack(column[name]).astype(float)

    def curves_for(suffix):
        # (values, plot style, whether a band is drawn when errorBar is set)
        curves = [
            (metric("T" + suffix),
             dict(label="STORM_LS, BestACE, P=4", linestyle='-.', color='red'), True),
            (metric("M" + suffix),
             dict(label="Mean", color='black'), False),
            (metric("L" + suffix),
             dict(label="Least Sq", linestyle='-', color="orange"), False),
            (metric("R" + suffix),
             dict(label="Random sampling", linestyle='-', color='slateblue'), True),
        ]
        if ridge_lambda==0:
            curves.append((metric("V" + suffix),
                           dict(label="leverage sampling", linestyle='-', color='darkslateblue'), True))
            curves.append((metric("C" + suffix),
                           dict(label="Clarkson2009", linestyle='-', color='blueviolet'), True))
        return curves

    # Convert the data of both figures before drawing, so malformed input
    # fails before any file has been written.
    panels = [(split, curves_for(suffix)) for suffix, split in (("1", "train"), ("2", "test"))]
    for split, curves in panels:
        _plotres_save_figure(
            mem_fact, curves, errorBar,
            dataset_name + "_" + split,
            dataset_name + split + '_lambda_' + str(ridge_lambda) + '.png',
        )
