import numpy as np
import scipy.linalg as la
import time
import os, sys
from absl import app
from absl import flags, logging
from methods import *
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
FLAGS = flags.FLAGS

def _initial_labeled_indices(Y, c, data):
    """Return the indices of the initial labeled set.

    Args:
        Y: label array; entries equal to 0..c-1 are looked up per class.
        c: number of classes.
        data: dataset name (the ``--data`` flag value).

    Returns:
        Integer index array. For ``cifar10`` it is simply ``0..9``; otherwise it
        holds one randomly chosen sample per class.

    Raises:
        ValueError: if some class label in ``0..c-1`` has no samples.
    """
    if data == 'cifar10':
        # NOTE(review): presumably the stored cifar10 arrays are ordered so the
        # first 10 samples cover every class — confirm against the data prep.
        return np.arange(10)
    idx_sel = np.empty(c, dtype='int')
    for label in range(c):
        # Re-seed per class so the draw for a class does not depend on the
        # draws made for earlier classes.
        np.random.seed(label)
        members = np.where(Y == label)[0]
        if members.size == 0:
            raise ValueError(f'no samples with label {label} in dataset {data!r}')
        # [0] turns the size-1 sample into a scalar (same value as before).
        idx_sel[label] = np.random.choice(members, size=1)[0]
    return idx_sel


def _select_batch(X, Y, idx_sel, pool, c, b, method, seed, eta, tol):
    """Run the chosen selection method and return the full labeled index set.

    For ``random`` and ``kmeans`` the result is ``idx_sel`` followed by ``b``
    indices drawn from ``pool``. The other methods delegate to project helpers,
    and their return value is used as-is.

    Raises:
        ValueError: for an unsupported ``method``.
    """
    if method == 'FIRAL':
        return FIRAL(X, Y, idx_sel, c, b, eta, tol)
    if method == 'random':
        np.random.seed(seed)
        # Sample without replacement so the budget yields b *distinct* points.
        sel_ = np.random.choice(len(pool), b, replace=False)
        return np.concatenate((idx_sel, pool[sel_]))
    if method == 'kmeans':
        km = KMeans(n_clusters=b, n_init=30, random_state=seed).fit(X[pool])
        # Label the pool point nearest to each cluster center.
        sel_, _ = pairwise_distances_argmin_min(km.cluster_centers_, X[pool])
        return np.concatenate((idx_sel, pool[sel_]))
    if method in ('entropy', 'varratios'):
        return uncertainty(X, Y, idx_sel, b, method)
    if method == 'BAIT':
        return BAIT(X, Y, idx_sel, c, b)
    raise ValueError(f'unsupported method: {method!r}')


def main(argv):
    """Select labeled samples with ``--method`` and report classifier accuracy.

    Loads ``data/{data}_X.npy`` and ``data/{data}_Y.npy``, builds the initial
    labeled set, and extends it with ``--b`` samples chosen by the selected
    method. It then fits a logistic regression on the labeled samples and
    prints its accuracy on every sample in ``X``.

    Args:
        argv: command-line remainder from ``app.run``; unused.
    """
    del argv  # unused: all configuration comes from absl flags

    # load data
    X = np.load(f'data/{FLAGS.data}_X.npy')
    Y = np.load(f'data/{FLAGS.data}_Y.npy')
    n, _ = X.shape  # unpacking also enforces a 2-D feature matrix
    c = 50 if FLAGS.data == 'imagenet50' else 10
    b = FLAGS.b

    # initialize: initial labeled set, everything else is the candidate pool
    idx_sel = _initial_labeled_indices(Y, c, FLAGS.data)
    pool = np.delete(np.arange(n), idx_sel)

    # run
    idx = _select_batch(X, Y, idx_sel, pool, c, b, FLAGS.method,
                        FLAGS.seed, FLAGS.eta, FLAGS.tol)

    # accuracy with labeled samples
    # NOTE(review): LogisticRegression / cal_acc presumably come from
    # `from methods import *`. If this is scikit-learn's LogisticRegression,
    # `multi_class` is deprecated since 1.5.
    cl = LogisticRegression(penalty='l2', class_weight='balanced', random_state=0,
                            fit_intercept=False, multi_class="multinomial",
                            max_iter=1000).fit(X[idx], Y[idx])
    pre_model = cl.predict(X)
    accuracy = cal_acc(pre_model, Y)
    print("\n\n =================logistic regression===============")
    print("method:", FLAGS.method)
    print("total number of labeled samples:", len(idx))
    print("accuracy:", accuracy)



if __name__ == '__main__':
    # Flags are defined only when run as a script, so importing this module
    # does not register them.
    flags.DEFINE_string('data', 'mnist', 'dataset: mnist/cifar10/imagenet50')
    # An enum lets absl reject an unsupported method at parse time instead of
    # failing later inside main().
    flags.DEFINE_enum('method', 'FIRAL',
                      ['FIRAL', 'random', 'kmeans', 'entropy', 'varratios', 'BAIT'],
                      'active learning method: FIRAL/random/kmeans/entropy/varratios/BAIT')
    # A negative budget is meaningless; reject it when flags are parsed.
    flags.DEFINE_integer('b', 10, 'sample budget', lower_bound=0)
    flags.DEFINE_float('eta', 200.0, 'learning rate in FIRAL')
    flags.DEFINE_float('tol', 1.e-8, 'tolerance for termination in solving relaxed problem in FIRAL')
    flags.DEFINE_integer('seed', 0, 'random number seed for random/kmeans sampling')
    app.run(main)



