import numpy as np
import tensorflow as tf


from utils import fitCB, Batcher
from embedder import Embedder



def load_newsgroups(restrict_to_labels = None):
    """Load 20 Newsgroups as row-normalised bag-of-words vectors with a fixed split.

    The whole corpus (``subset='all'``) is fetched with headers, footers and
    quotes removed, vectorised with ``CountVectorizer`` (English stop words,
    ``max_df=0.95``, ``min_df=2``, at most 15000 terms), and every document
    row is normalised with ``sklearn.preprocessing.normalize`` (its default
    norm). The train/test split is the first fold of a shuffled 5-fold
    ``KFold`` with a fixed ``random_state``, so it is reproducible.

    Parameters
    ----------
    restrict_to_labels : int or None
        If given, only documents whose integer label is strictly smaller
        than this value are kept, and ``label_names`` is truncated to the
        first ``restrict_to_labels`` entries. ``None`` keeps every class.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_test, y_test, feature_names, Nclasses,
        feature_names, label_names)``. ``feature_names`` appears twice on
        purpose: existing callers unpack eight values, so the shape of the
        tuple is preserved.
    """
    from sklearn.model_selection import KFold
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import CountVectorizer
    # Import the function explicitly: a bare ``import sklearn`` does not
    # guarantee that the ``preprocessing`` submodule has been loaded.
    from sklearn.preprocessing import normalize

    voc_size = 15000
    print('Fetching data...')
    fetched = fetch_20newsgroups(remove=('headers', 'footers', 'quotes'),
                                 subset='all')
    documents = fetched.data
    labels = fetched.target
    label_names = fetched.target_names

    vectorizer = CountVectorizer(max_df=0.95, min_df=2,
                                 max_features=voc_size,
                                 stop_words='english')
    counts = vectorizer.fit_transform(documents)

    # Invert term -> column into column -> term. ``max_features`` is only an
    # upper bound, so size the name list from the vocabulary actually learnt
    # rather than from ``voc_size`` (which raised KeyError on small corpora).
    dictionary = {column: term for term, column in vectorizer.vocabulary_.items()}
    feature_names = [dictionary[column] for column in range(len(dictionary))]

    print('Dictionary size: {}, n_samples: {}'.format(len(dictionary), counts.shape[0]))

    normalised = normalize(counts, axis=1, copy=False)

    if restrict_to_labels is not None:
        # Normalisation is per row, so dropping rows while the matrix is
        # still sparse yields the same values as filtering the dense array,
        # without first materialising the full dense matrix.
        kept_rows = np.flatnonzero(labels < restrict_to_labels)
        normalised = normalised[kept_rows]
        labels = labels[kept_rows]
        label_names = label_names[:restrict_to_labels]

    features = normalised.toarray()

    Nclasses = len(label_names)

    # Fixed seed: every run sees the same train/test partition.
    rs = 5532
    cv = KFold(5, shuffle=True, random_state=rs)

    # Only the first of the five folds is used (roughly a 4:1 split).
    train_index, test_index = next(cv.split(features))

    X_train, X_test = features[train_index], features[test_index]
    y_train, y_test = labels[train_index], labels[test_index]

    return X_train, y_train, X_test, y_test, feature_names, Nclasses, feature_names, label_names




import argparse
# Script entry point: parse the command line, train an Embedder on the first
# ten newsgroup classes, optionally save weights and training statistics to
# an .npz archive, and optionally show a t-SNE scatter plot of the embedding.
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='experiment runner')
    parser.add_argument('K', type=int)
    parser.add_argument('activation', choices=['sigmoid', 'linear', 'relu'])
    parser.add_argument('run_id', type=str)
    parser.add_argument('--epoch_cnt', type=int, default=180)
    parser.add_argument('--Bx', type=int, default=100)
    parser.add_argument('--By', type=int, default=500)
    parser.add_argument('--no_save', action="store_true")
    parser.add_argument('--l2inf_reg', type=float, default=0.0)
    parser.add_argument('--l2_reg', type=float, default=0.0)
    parser.add_argument('--show_picture', action="store_true")
    # The historical single-dash spelling is kept so existing command lines
    # still work; the conventional double-dash form is accepted as an alias.
    parser.add_argument('-intermediate_layers', '--intermediate_layers',
                        dest='intermediate_layers', nargs='+', type=int)
    args = parser.parse_args()

    if args.intermediate_layers is None:
        args.intermediate_layers = []

    # Encode the hidden-layer sizes in the file name, e.g. [64, 32] -> 'x64x32x'
    # and no hidden layers -> 'x'.
    intermediate_layers_str = 'x' + ''.join(f'{size}x' for size in args.intermediate_layers)

    save_file = f'newsgroup_run_{args.K}_{args.activation}_{args.epoch_cnt}_{args.run_id}_{intermediate_layers_str}.npz'

    print(save_file)

    # load_newsgroups returns feature_names twice; the duplicate is discarded.
    (X_train, y_train, X_test, y_test,
     feature_names, Nclasses, _, label_names) = load_newsgroups(10)

    print('Done Loading...')

    epoch_cnt = args.epoch_cnt

    D = X_train.shape[1]  # input dimensionality = number of vocabulary columns
    K = args.K

    Bx = args.Bx
    By = args.By

    print(f'Bx:{Bx}, By:{By}')

    learning_rate = 1e-2   # Adam preferred lr here

    pairwise_cost_pos_multiplier = Nclasses - 1

    neg_upper_threshold = 1.

    if args.activation == 'sigmoid':
        out_activation = tf.sigmoid
    elif args.activation == 'linear':
        out_activation = lambda x: x
    elif args.activation == 'relu':
        out_activation = tf.nn.relu
    else:
        # Unreachable while argparse restricts the choices; raise rather than
        # assert so the guard survives ``python -O``.
        raise ValueError('unknown activation')

    l2_infty_reg = args.l2inf_reg
    print(f'l2_infty_reg: {l2_infty_reg}')

    l2_reg = args.l2_reg
    print(f'l2_reg: {l2_reg}')

    distance_type = Embedder.DISTANCE_TYPE_L2_SQ_UNNORMED

    intermediate_layer_sizes = args.intermediate_layers

    RND_SEED1 = np.random.RandomState(seed=6765)
    RND_SEED2 = np.random.RandomState(seed=5851)

    # NOTE(review): both evaluation batchers share the same RandomState
    # object (RND_SEED1), exactly as in the original script.
    train_eval_batcher = Batcher(min(100, X_train.shape[0]), X_train, y_train, rnd_seed=RND_SEED1)
    test_eval_batcher = Batcher(min(100, X_test.shape[0]), X_test, y_test, rnd_seed=RND_SEED1)
    n_eval_batches = 500

    log_period = 1
    eval_period = 10

    fit_cb = fitCB(
        verbose_period=1,
        log_period=log_period,
        train_eval_batcher=train_eval_batcher,
        test_eval_batcher=test_eval_batcher,
        n_eval_batches=n_eval_batches,
        eval_period=eval_period,
    )

    M = Embedder(D, K, Bx=None, learning_rate=learning_rate,
                 n_classes=Nclasses,
                 fit_Bx=Bx,
                 fit_By=By,
                 fit_cb=fit_cb,
                 fit_epochs=epoch_cnt,
                 pairwise_cost_pos_multiplier=pairwise_cost_pos_multiplier,
                 neg_upper_threshold=neg_upper_threshold,
                 out_activation=out_activation,
                 l2_reg=l2_reg,
                 l2_infty_reg=l2_infty_reg,
                 distance_type=distance_type,
                 intermediate_layer_sizes=intermediate_layer_sizes,
                 )

    M.fit(X_train, y_train)

    if not args.no_save:
        res_weights_lst = M.get_weights()

        # The archive holds the flattened model weights followed by four
        # statistics lists collected by fit_cb during the run.
        #
        # According to the original author's note (see the fitCB class in
        # utils.py for details), for the train set:
        #   fit_cb.train_eval_lst[-1][0]  last computed loss
        #   fit_cb.train_eval_lst[-1][1]  last same-class conditioned loss
        #   fit_cb.train_eval_lst[-1][2]  last different-class conditioned loss
        # and the test-set counterparts are presumably the same positions of
        # fit_cb.test_eval_lst[-1] -- TODO confirm against utils.fitCB.

        # get_weights() is iterated as a sequence of sequences; flatten it
        # one level so every array becomes its own entry in the archive.
        res_weights_tup = tuple(a for t in res_weights_lst for a in t)

        save_tup = res_weights_tup + (fit_cb.loss_lst,
                                      fit_cb.val_lst,
                                      fit_cb.train_eval_lst,
                                      fit_cb.test_eval_lst,
                                      )

        # Write through the handle that was just opened; passing the file
        # name again would open the same path a second time.
        with open(save_file, 'wb') as f:
            np.savez(f, *save_tup)

    if args.show_picture:

        # Visualise a random batch of 2000 training points.
        zz_batcher = Batcher(2000, X_train, y_train, rnd_seed=RND_SEED2)
        zz_X, zz_Y, _ = zz_batcher.get_batch()

        transform_X = M.sess.run(M.t_partition_sigmoids_per_x, feed_dict={M.t_x_ph: zz_X})

        from sklearn.manifold import TSNE
        tsne = TSNE(n_components=2, perplexity=30, metric='euclidean')
        tsne_embedding = tsne.fit_transform(transform_X)

        from utils import jitter
        jittered_embedding = jitter(tsne_embedding, 1.)

        import matplotlib.pyplot as plt
        plt.figure()
        # One scatter series per class so the legend maps colours to labels.
        for i in range(Nclasses):
            lidx = zz_Y == i
            plt.plot(jittered_embedding[lidx, 0], jittered_embedding[lidx, 1], 'o', label=str(i))

        plt.legend()
        plt.show()





