import numpy as np
import tensorflow as tf


from utils import fitCB, Batcher
from embedder import Embedder
import tensorflow.keras.datasets as kdds

def keras_image_format_to_std(data):
    """Rescale 0-255 pixel data to float32 values in [-1, 1].

    Args:
        data: numpy array of pixel intensities in the 0-255 range.

    Returns:
        A float32 array with ``(pixel - 127.5) / 127.5`` applied elementwise.
        A 3-D input of shape (N, H, W) gets a trailing singleton axis,
        yielding (N, H, W, 1); other ranks keep their shape.
    """
    half_range = 127.5
    scaled = (data.astype(np.float32) - half_range) / half_range
    # Single-channel image stacks arrive without a channel axis; add one.
    if scaled.ndim == 3:
        scaled = scaled[..., np.newaxis]
    return scaled

    
    

def randomize_labels(target, rnd_frac, Nclasses, rng=None):
    """Return a copy of ``target`` in which a random subset of labels is resampled.

    Each entry is independently selected with probability ``rnd_frac``. Every
    selected entry is replaced by an integer drawn uniformly from
    ``[0, Nclasses)``. A resampled label can coincide with its original value,
    so the expected fraction of labels that actually change is
    ``rnd_frac * (Nclasses - 1) / Nclasses``.

    Args:
        target: 1-D numpy array of integer class labels. It is not modified.
        rnd_frac: probability in [0, 1] of resampling each label.
        Nclasses: number of classes. New labels are drawn from 0..Nclasses-1.
        rng: optional ``numpy.random.Generator`` or seed, as accepted by
            ``np.random.default_rng``, for reproducible noise. When None, a
            fresh unseeded generator is used (the original behavior).

    Returns:
        A new array with the same shape and dtype as ``target``.
    """
    # default_rng returns a Generator argument unchanged and seeds from an int/None.
    rng = np.random.default_rng(rng)

    # Draw the selection mask first, then the replacement labels, in that order.
    change_mask = rng.random(size=target.shape[0]) < rnd_frac

    new_target = target.copy()
    new_target[change_mask] = rng.integers(0, Nclasses, size=np.count_nonzero(change_mask))
    return new_target
    
    
    

import argparse

if __name__ == '__main__':
    # Experiment runner: train an Embedder on MNIST with command-line
    # hyperparameters, optionally save the learned weights and fitCB
    # statistics to an .npz file, and optionally plot a t-SNE view of the
    # learned representation of test images.
    # The code is kept at module level (not wrapped in a main()) so objects
    # such as `M` and `fit_cb` remain inspectable after an interactive run.

    parser = argparse.ArgumentParser(description='experiment runner')
    parser.add_argument('K', type=int)  # forwarded to Embedder (presumably the embedding size)
    parser.add_argument('activation', choices=['sigmoid', 'linear', 'relu'])
    parser.add_argument('run_id', type=str)
    parser.add_argument('--epoch_cnt', type=int, nargs='?', const=180, default=180)
    parser.add_argument('--no_save', action="store_true")
    parser.add_argument('-intermediate_layers', nargs='+', type=int)
    parser.add_argument('--l2inf_reg', type=float, default=float(0.))
    parser.add_argument('--l2_reg', type=float, default=float(0.))
    parser.add_argument('--show_picture', action="store_true")

    args = parser.parse_args()

    if args.intermediate_layers is None:
        args.intermediate_layers = []

    # Encode hidden-layer sizes into the output file name, e.g. [256, 64] -> 'x256x64x'.
    intermediate_layers_str = 'x' + ''.join(f'{s}x' for s in args.intermediate_layers)

    save_file = f'mnist_run_{args.K}_{args.activation}_{args.epoch_cnt}_{args.run_id}_{intermediate_layers_str}.npz'

    print(save_file)

    (X_train, y_train), (X_test, y_test) = kdds.mnist.load_data()
    # Scale pixels to [-1, 1] and flatten each image into one feature vector.
    ntmp = np.prod(X_train.shape[1:])
    X_train = keras_image_format_to_std(X_train).reshape(-1, ntmp)
    X_test = keras_image_format_to_std(X_test).reshape(-1, ntmp)

    Nclasses = 10

    # Keep the clean labels. The disabled label-noise experiment below would
    # resample a fraction rnd_frac of them uniformly at random.
    y_train_orig = y_train
    y_test_orig = y_test
    #rnd_frac = .2
    #y_train = randomize_labels(y_train,rnd_frac,Nclasses)
    #y_test = randomize_labels(y_test,rnd_frac,Nclasses)

    print('Done Loading...')

    epoch_cnt = args.epoch_cnt

    D = X_train.shape[1]  # input dimension after flattening
    K = args.K

    # Batch sizes handed to Embedder as fit_Bx / fit_By.
    Bx = 250
    By = 250

    learning_rate = 1e-4   # Adam preferred lr here; original paper rate
    #learning_rate = 1e-3   # Adam preferred lr here

    pairwise_cost_pos_multiplier = Nclasses - 1

    neg_upper_threshold = 1.

    # argparse `choices` already restricts the value. Still fail loudly (not
    # via an assert, which `python -O` strips) if the two ever drift apart.
    activations = {
        'sigmoid': tf.sigmoid,
        'linear': lambda x: x,
        'relu': tf.nn.relu,
    }
    if args.activation not in activations:
        raise ValueError(f'unknown activation: {args.activation}')
    out_activation = activations[args.activation]

    l2_infty_reg = args.l2inf_reg
    print(f'l2_infty_reg: {l2_infty_reg}')

    l2_reg = args.l2_reg
    print(f'l2_reg: {l2_reg}')

    distance_type = Embedder.DISTANCE_TYPE_L2_SQ_UNNORMED

    intermediate_layer_sizes = args.intermediate_layers

    # Fixed seeds for the random states handed to the batchers.
    RND_SEED1 = np.random.RandomState(seed=6765)
    RND_SEED2 = np.random.RandomState(seed=5851)

    # NOTE(review): both eval batchers receive the *same* RandomState object.
    # If Batcher draws from it, their sampling streams are interleaved — confirm
    # this is intended.
    train_eval_batcher = Batcher(500, X_train, y_train, rnd_seed=RND_SEED1)
    test_eval_batcher = Batcher(500, X_test, y_test, rnd_seed=RND_SEED1)
    n_eval_batches = 100

    fit_cb = fitCB(
        verbose_period=1,
        log_period=1,
        batch_cnt_to_cb=100,
        train_eval_batcher=train_eval_batcher,
        test_eval_batcher=test_eval_batcher,
        n_eval_batches=n_eval_batches,
        eval_period=1,
    )

    M = Embedder(D, K, Bx=None, learning_rate=learning_rate,
                 n_classes=Nclasses,
                 fit_Bx=Bx,
                 fit_By=By,
                 fit_cb=fit_cb,
                 fit_epochs=epoch_cnt,
                 pairwise_cost_pos_multiplier=pairwise_cost_pos_multiplier,
                 neg_upper_threshold=neg_upper_threshold,
                 out_activation=out_activation,
                 l2_reg=l2_reg,
                 l2_infty_reg=l2_infty_reg,
                 distance_type=distance_type,
                 intermediate_layer_sizes=intermediate_layer_sizes,
                 )

    M.fit(X_train, y_train)

    if not args.no_save:
        res_weights_lst = M.get_weights()

        # Contents of the .npz, in positional order (np.savez names them
        # arr_0, arr_1, ...):
        #   * every array inside each element of M.get_weights(), flattened
        #     into one sequence,
        #   * fit_cb.loss_lst, fit_cb.val_lst,
        #   * fit_cb.train_eval_lst, fit_cb.test_eval_lst.
        #
        # For the train set, the last computed values are:
        #   fit_cb.train_eval_lst[-1][0]  loss
        #   fit_cb.train_eval_lst[-1][1]  same-class conditioned loss
        #   fit_cb.train_eval_lst[-1][2]  different-class conditioned loss
        # For the test set, the corresponding entries are presumably
        # fit_cb.test_eval_lst[-1][0..2], with the same layout.
        # See the utils.py fitCB class for details.
        res_weights_tup = tuple(a for t in res_weights_lst for a in t)

        save_tup = res_weights_tup + (fit_cb.loss_lst,
                                      fit_cb.val_lst,
                                      fit_cb.train_eval_lst,
                                      fit_cb.test_eval_lst)

        # np.savez accepts an open binary file; write through this handle
        # instead of reopening the same path a second time.
        with open(save_file, 'wb') as f:
            np.savez(f, *save_tup)

    if args.show_picture:
        # Draw one batch from the test set and project the model's
        # t_partition_sigmoids_per_x output for it to 2-D with t-SNE.
        zz_batcher = Batcher(2000, X_test, y_test, rnd_seed=RND_SEED2)
        #zz_batcher = Batcher(2000, X_train, y_train,rnd_seed = RND_SEED2)
        zz_X, zz_Y, _ = zz_batcher.get_batch()

        transform_X = M.sess.run(M.t_partition_sigmoids_per_x, feed_dict={M.t_x_ph: zz_X})

        from sklearn.manifold import TSNE
        tsne = TSNE(n_components=2, perplexity=30, metric='euclidean')
        tsne_X = tsne.fit_transform(transform_X)

        from utils import jitter
        # Presumably adds small noise so coincident points stay visible — see utils.jitter.
        j_tsne_X = jitter(tsne_X, 1.)

        import matplotlib.pylab as pl
        pl.figure()
        for i in range(Nclasses):
            lidx = zz_Y == i
            pl.plot(j_tsne_X[lidx, 0], j_tsne_X[lidx, 1], 'o', label=str(i))

        pl.legend()
        pl.show()





