from submodlib.functions.facilityLocationMutualInformation import FacilityLocationMutualInformationFunction
import pyspark.sql.functions as F
from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql.functions import col
from pyspark.sql import SparkSession
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.sql.types import *
import numpy as np
import sys

# Command-line arguments (positional):
#   sys.argv[1]  parquet path whose "centroid" column holds the candidate set
#   sys.argv[2]  parquet path whose "centroid" column holds the query set
#   sys.argv[3]  selection budget (integer number of items to pick)
#   sys.argv[4]  output parquet path (overwritten if it exists)
if len(sys.argv) < 5:
    sys.exit(
        "usage: {} <train_parquet> <query_parquet> <budget> <output_parquet>".format(
            sys.argv[0]
        )
    )

# Command-line values arrive as strings; the optimizer needs a number, so
# convert up front (a non-integer value fails here with a clear ValueError
# before any Spark work is started).
budget = int(sys.argv[3])

spark = SparkSession.builder.getOrCreate()

# Read train: pull the "centroid" column to the driver as a NumPy array.
train_centroids = spark.read.parquet(sys.argv[1]).select("centroid")
train = np.array([each[0] for each in train_centroids.collect()])

# Read X: same layout as train, used as the query data.
X_centroids = spark.read.parquet(sys.argv[2]).select("centroid")
X = np.array([each[0] for each in X_centroids.collect()])

# Optimize: pick up to `budget` items from train with respect to the query set X.
obj = FacilityLocationMutualInformationFunction(len(train), len(X), data=train, queryData=X)
greedyList = obj.maximize(
    budget=budget,
    optimizer="LazyGreedy",
    stopIfZeroGain=True,
    stopIfNegativeGain=True,
    verbose=False,
    show_progress=True,
)

# Get lists: report how many items were selected out of how many candidates,
# then gather the selected rows. The first element of each greedyList entry
# is used as the row index into train.
print()
print(len(greedyList))
print(len(train))
arr = [train[each[0]] for each in greedyList]

# Write out: one (id, centroid) row per selected item, as a single parquet part.
# NOTE(review): if nothing is selected the row list is empty and Spark may be
# unable to infer column types from names alone -- confirm whether that case
# can occur and, if so, pass an explicit schema.
rows = [(i, each.tolist()) for i, each in enumerate(arr)]
spark.createDataFrame(data=rows, schema=["id", "centroid"]).repartition(1).write.mode("overwrite").parquet(sys.argv[4])
