import random
from data.ghg import getGHG
from data.electric import getElectric
from data.gas import getGas
from misc_utils import *
import numpy as np
import sys
import time
sys.path.append('../..')
import torch
import cvxpy as _cp
np.set_printoptions(threshold=100000)

def countSketch(m, n):
    """Build a random dense CountSketch matrix of shape ``(m, n)``.

    Every column receives exactly one non-zero entry: a value of +1 or -1
    placed in a row drawn uniformly from ``0 .. m - 1``.

    Args:
        m: Number of rows of the sketch.
        n: Number of columns of the sketch.

    Returns:
        A float ``numpy.ndarray`` of shape ``(m, n)``.
    """
    sketch = np.zeros((m, n))

    for col in range(n):
        # Draw order (row first, then sign) is kept fixed so that results
        # are reproducible under a given ``random.seed``.
        bucket = random.randint(0, m - 1)
        sketch[bucket, col] = 1 if random.randint(0, 1) else -1

    return sketch


def learnCountSketch(A, m, n, t, position=None):
    """Build an ``(m, n)`` CountSketch that isolates the ``t`` highest-scoring rows of ``A``.

    The score of row ``i`` is the Euclidean norm of row ``i`` of ``U`` from
    the thin SVD ``A = U diag(s) V``.  The ``t`` rows with the largest
    scores each get a dedicated sketch row (rows ``0 .. t-1``, assigned in
    increasing order of the row index); every other row is hashed to a
    uniformly random sketch row in ``t .. m-1``.  All non-zero entries are
    random signs (+1 / -1).

    Args:
        A: 2-D array with at least ``n`` rows.
        m: Number of rows of the sketch.
        n: Number of columns of the sketch (rows of ``A`` that are sketched).
        t: Number of rows that get a dedicated sketch row.  ``t == 0``
            yields a plain random CountSketch.
        position: Unused; kept for signature compatibility.

    Returns:
        A tuple ``(S, 0)``.  ``S`` is the ``(m, n)`` sketch matrix.  The
        second element is always ``0``; the sibling ``learnCountSketch_ghg``
        returns an elapsed time in that slot.

    Raises:
        ValueError: If ``t`` is outside ``0 .. min(m, n)``, if no sketch row
            is left for the hashed columns, or if ``A`` has fewer than
            ``n`` rows.
    """
    if t < 0 or t > min(m, n):
        raise ValueError("t must satisfy 0 <= t <= min(m, n)")
    if t < n and t >= m:
        raise ValueError("m must exceed t when some columns are hashed randomly")

    u, _, _ = np.linalg.svd(A, full_matrices=False)
    if u.shape[0] < n:
        raise ValueError("A must have at least n rows")
    scores = np.linalg.norm(u[:n], axis=1)

    # Pick exactly t columns.  A plain `score >= threshold` test admits more
    # than t columns whenever scores tie at the threshold, which makes the
    # dedicated rows spill into (or past) the rows reserved for hashing.
    # A stable sort breaks ties in favour of the lower row index.
    heavy = np.zeros(n, dtype=bool)
    heavy[np.argsort(-scores, kind="stable")[:t]] = True

    S = np.zeros((m, n))
    next_row = 0
    for i in range(n):
        if heavy[i]:
            S[next_row, i] = 2 * random.randint(0, 1) - 1
            next_row += 1
        else:
            row = random.randint(t, m - 1)
            S[row, i] = 2 * random.randint(0, 1) - 1

    return S, 0

def learnCountSketch_ghg(A, m, n, t, pos = None):
    """Build an ``(m, n)`` sketch and report the time spent in the SVD step.

    NOTE(review): in its current state this looks like a partially disabled
    experiment rather than a finished routine:

    * The loop that would copy rows of ``A`` into ``B`` is commented out,
      so ``B`` stays all-zero and ``A`` is never read.
    * ``idx`` stays all-zero, so the ``range(60)`` loop below only ever
      writes ``val[0]``.
    * ``d`` is neither a parameter nor a local; it is looked up in module
      scope at call time.  In the visible part of this file it is assigned
      only inside the ``__main__`` block.
    * The ``pos`` argument is overwritten with ``0`` before it is read.
    * The sizes 327, 370 and 60 are hard-coded -- presumably
      dataset-specific; confirm before reuse.

    Args:
        A: Not read by the active code (see note above).
        m: Number of rows of the sketch.
        n: Number of columns of the sketch.
        t: 1-based rank used to pick the threshold from the sorted scores;
            also the first sketch row available to hashed columns.
        pos: Not read by the active code (see note above).

    Returns:
        A tuple ``(S, elapsed)``: ``S`` is the ``(m, n)`` array with one
        +1/-1 entry per column, ``elapsed`` is the ``time.perf_counter``
        difference measured around ``np.linalg.svd(B)``.
    """

    idx = np.zeros(327)
    # NOTE(review): `leverage` is allocated but never used.
    leverage = np.zeros(370)
    # NOTE(review): `id` shadows the builtin and is only referenced by the
    # commented-out loop below.
    id = 0
    # `d` comes from module scope, not from the arguments.
    B = np.zeros((60, d))
    # Disabled row selection (would fill B and idx from A and pos):
    # for i in range(n):
    #     if pos[i] >= 20:
    #         B[id, :] = A[i, :]
    #         idx[id] = i
    #         id += 1
    # Only the SVD call is timed.
    start = time.perf_counter()
    # u, s, v = np.linalg.svd(A, full_matrices=0)
    u, s, v = np.linalg.svd(B, full_matrices = 0)
    end = time.perf_counter()

    S = np.zeros((m, n))
    # u, s, v = np.linalg.svd(A, full_matrices=0)
    val = np.zeros(n)
    # for i in range(n):
    #     if pos[i] >= 10:
        #     val[i] = np.linalg.norm(A[i, :])
        # val[i] = np.linalg.norm(u[i, :])
    # Because idx is all zeros, every iteration writes val[0]; after the
    # loop val[0] holds the norm of u[59, :] and all other entries are 0.
    for i in range(60):
        val[int(idx[i])] = np.linalg.norm(u[i, :])
    # Threshold = t-th largest score.
    thr = (-np.sort(-val))[t - 1]
    # From here on `pos` is the next dedicated sketch row, not the argument.
    pos = 0
    for i in range(n):
        # NOTE(review): ties at the threshold all take this branch.  Since
        # val has at most one non-zero entry here, thr is 0 for t >= 2 and
        # every column is treated as "heavy"; S[pos] then raises IndexError
        # once pos reaches m.
        if val[i] >= thr:
            S[pos][i] = 2 * random.randint(0, 1) - 1
            pos += 1

        else:
            # Remaining columns are hashed to a random row in t .. m-1.
            row = random.randint(t, m - 1)
            sign = 2 * random.randint(0, 1) - 1
            S[row][i] = sign

    return S, end - start

def SaveSketch_electric(m, n, t, name):
    """Build an initial ``(m, n)`` sketch from saved row scores and save it.

    Scores are read from ``electric_position.npy`` in the current working
    directory.  The ``t`` highest-scoring columns each get their own sketch
    row; all other columns are hashed to one of the remaining ``m - t``
    rows.  Sketch rows are assigned through a random permutation of
    ``0 .. m-1`` so the dedicated rows are not always the first ``t``.
    The result is written with ``np.save("ini_sketch_" + name, S)``.

    Args:
        m: Number of rows of the sketch.
        n: Number of columns of the sketch; the first ``n`` saved scores
            are used.
        t: Number of columns that get a dedicated sketch row.
        name: Suffix appended to ``"ini_sketch_"`` to form the output name.

    Raises:
        ValueError: If fewer than ``n`` scores are stored, if ``t`` is
            outside ``0 .. min(m, n)``, or if no sketch row is left for the
            hashed columns.
    """
    position = np.load("electric_position.npy")
    if position.shape[0] < n:
        raise ValueError("electric_position.npy holds fewer than n scores")
    if t < 0 or t > min(m, n):
        raise ValueError("t must satisfy 0 <= t <= min(m, n)")
    if t < n and t >= m:
        raise ValueError("m must exceed t when some columns are hashed randomly")

    val = np.asarray(position[:n], dtype=float)

    # Pick exactly t columns.  A `score >= threshold` test admits more than
    # t columns whenever scores tie at the threshold, which lets dedicated
    # rows collide with the hashed rows or run past the end of the sketch.
    # A stable sort breaks ties in favour of the lower column index.
    heavy = np.zeros(n, dtype=bool)
    heavy[np.argsort(-val, kind="stable")[:t]] = True

    # Do not reseed NumPy's global generator here: seeding with
    # int(time.time()) gives identical permutations to calls made within
    # the same second and overrides any seed chosen by the caller.
    order = np.random.permutation(m)

    S = np.zeros((m, n))
    slot = 0
    for i in range(n):
        if heavy[i]:
            S[order[slot], i] = 2 * random.randint(0, 1) - 1
            slot += 1
        else:
            row = random.randint(t, m - 1)
            sign = 2 * random.randint(0, 1) - 1
            S[order[row], i] = sign

    np.save("ini_sketch_" + name, S)

def SaveSketch_ghg(m, n, t, name, A=None, min_count=10):
    """Build an initial ``(m, n)`` sketch for the GHG data and save it.

    Per-row counts are read from ``ghg_position.npy`` in the current
    working directory.  Rows whose count is at least ``min_count`` are
    scored by the Euclidean norm of the matching row of ``A``; all other
    rows score 0.  The ``t`` highest-scoring columns get dedicated sketch
    rows ``0 .. t-1`` (assigned in increasing column order); the remaining
    columns are hashed to a uniformly random row in ``t .. m-1``.  Every
    non-zero entry is a random sign.  The result is written with
    ``np.save("ini_sketch_" + name, S)``.

    Args:
        m: Number of rows of the sketch.
        n: Number of columns of the sketch (rows of ``A`` considered).
        t: Number of columns that get a dedicated sketch row.
        name: Suffix appended to ``"ini_sketch_"`` to form the output name.
        A: 2-D data matrix with at least ``n`` rows.  When omitted, the
            module-level name ``A`` is used, which is how this function
            historically obtained its data.
        min_count: Minimum stored count for a row to receive a non-zero
            score.

    Raises:
        NameError: If ``A`` is omitted and no module-level ``A`` exists.
        ValueError: If ``A`` or the stored counts have fewer than ``n``
            rows, if ``t`` is outside ``0 .. min(m, n)``, or if no sketch
            row is left for the hashed columns.
    """
    if A is None:
        # Backward compatibility: existing callers rely on the matrix left
        # behind in module scope by the script section of this file.
        try:
            A = globals()["A"]
        except KeyError:
            raise NameError("name 'A' is not defined") from None
    A = np.asarray(A)

    position = np.load("ghg_position.npy")
    if position.shape[0] < n or A.shape[0] < n:
        raise ValueError("A and ghg_position.npy must both have at least n rows")
    if t < 0 or t > min(m, n):
        raise ValueError("t must satisfy 0 <= t <= min(m, n)")
    if t < n and t >= m:
        raise ValueError("m must exceed t when some columns are hashed randomly")

    # Select qualifying rows dynamically.  A fixed-size buffer either
    # overflows when more rows qualify, or leaves a zero tail whose index 0
    # silently overwrites the score of row 0.
    selected = np.flatnonzero(position[:n] >= min_count)
    val = np.zeros(n)
    val[selected] = np.linalg.norm(A[selected], axis=1)

    # Pick exactly t columns; a stable sort breaks score ties in favour of
    # the lower column index so that ties can never yield more than t.
    heavy = np.zeros(n, dtype=bool)
    heavy[np.argsort(-val, kind="stable")[:t]] = True

    S = np.zeros((m, n))
    next_row = 0
    for i in range(n):
        if heavy[i]:
            S[next_row, i] = 2 * random.randint(0, 1) - 1
            next_row += 1
        else:
            row = random.randint(t, m - 1)
            sign = 2 * random.randint(0, 1) - 1
            S[row, i] = sign

    np.save("ini_sketch_" + name, S)


if __name__ == '__main__':
    # Script entry point: for every row index, count in how many GHG
    # training matrices that row has a large leverage score, save the counts
    # as "ghg_position", then write three initial sketches built from them.

    rawdir = "data"

    AB_train, AB_test, n, d_a, d_b = getGHG(False, 500, rawdir, 100)
    # AB_train, AB_test, n, d_a, d_b = getElectric(False, rawdir, 100)
    A_train = AB_train[0]
    B_train = AB_train[1]
    A_test = AB_test[0]
    B_test = AB_test[1]

    N_train = len(A_train)
    N_test = len(A_test)
    # print("Dim= ", n, d_a, d_b)
    print("N train=", N_train, "N test=", N_test)

    A_train = A_train.cpu().numpy()
    B_train = B_train.cpu().numpy()
    A_test = A_test.cpu().numpy()
    B_test = B_test.cpu().numpy()

    print(A_test.shape)
    print(B_test.shape)

    # pos[j] = number of training matrices in which row j is "heavy".
    pos = np.zeros(A_train[0].shape[0])
    for i in range(A_train.shape[0]):
        print("train data " + str(i))
        # `A` and `d` must remain module-level names: functions in this file
        # look them up as globals when they are called further down.
        A = A_train[i]
        n = A.shape[0]
        d = A.shape[1]
        u, _, _ = np.linalg.svd(A, full_matrices=False)

        # Leverage score of row j = squared norm of row j of U.  A row is
        # counted when its score is at least 5 * d / n.
        leverage = np.linalg.norm(u, axis=1) ** 2
        pos[:n] += leverage >= 5 * d / n
    print(pos)

    np.save("ghg_position", pos)

    for i in range(3):
        # SaveSketch_electric(90, 370, 27, str(i))
        SaveSketch_ghg(84, 327, 25, str(i))