import numpy as np
import tensorflow as tf
import pandas as pd
import datalib

def get_data(data):
    """Load the train/test split (and any outlier subset) of a named dataset.

    All file locations are absolute paths hard-coded below.

    Args:
        data: Dataset name. One of ``'Adult'``, ``'GermanCredit'``,
            ``'fmnist'``, ``'mnist'``, ``'Seizure'``, ``'TaiwaneseCredit'``,
            ``'Warafin'``, ``'HELOC'``, ``'Breast'``, ``'Pima'``, ``'CTG'``
            or ``'Thyroid'``.

    Returns:
        ``(X_train, y_train), (X_test, y_test), (X_train_out, y_train_out)``.
        The last pair contains the training rows picked by the dataset's
        outlier indices; it is ``(None, None)`` for every dataset whose
        branch does not apply outlier indices. Arrays read from ``.npy``
        files or from ``datalib`` objects are cast to ``float32``. For
        ``'fmnist'`` and ``'mnist'`` only the images are converted (to
        ``float32`` divided by 255); the labels are returned as Keras
        provides them.

    Raises:
        ValueError: If ``data`` is not one of the names listed above.
    """

    def load_raw(path):
        # Plain ``.npy`` read with pickling disabled.
        return np.load(path, allow_pickle=False)

    def load_f32(path):
        return load_raw(path).astype('float32')

    def split_as_f32(dataset):
        # Order matches the return value: train x/y first, then test x/y.
        return (dataset.x_tr.astype('float32'),
                dataset.y_tr.astype('float32'),
                dataset.x_te.astype('float32'),
                dataset.y_te.astype('float32'))

    # Branches that do not select outliers leave both of these as None.
    X_train_out = None
    y_train_out = None

    if data == 'Adult':
        X_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/adult_X_train_norm.npy')
        y_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/adult_y_train_norm.npy')
        X_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/adult_X_test_norm.npy')
        y_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/adult_y_test_norm.npy')

        outlier_idx = load_raw(
            '/longterm/XXXX/ensemble_fairness/data/adult_outliers.npy')
        X_train_out, y_train_out = X_train[outlier_idx], y_train[outlier_idx]

    elif data == 'GermanCredit':
        X_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/german_X_train.npy')
        y_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/german_y_train.npy')
        X_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/german_X_test.npy')
        y_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/german_y_test.npy')

        outlier_idx = load_raw(
            '/longterm/XXXX/ensemble_fairness/data/new_outliers_gc.npy')
        X_train_out, y_train_out = X_train[outlier_idx], y_train[outlier_idx]

    elif data == 'fmnist':
        fmnist_train, fmnist_test = (
            tf.keras.datasets.fashion_mnist.load_data())
        X_train, y_train = fmnist_train
        X_test, y_test = fmnist_test

        outlier_idx = [
            473, 662, 745, 1414, 1490, 1826, 1878, 2182, 3307, 4287, 4869,
            5608, 6215, 6546, 6561, 7122, 7384, 7812, 7944, 8278, 9313, 9359,
            9573, 10360, 10667, 10810, 12476, 15823, 16012, 16865, 17680,
            18104, 18397, 18481, 18916, 19049, 21019, 21050, 21540, 21587,
            22791, 23214, 24282, 24456, 24491, 24609, 25766, 25868, 26497,
            26605, 27566, 27997, 29312, 29526, 29564, 29715, 29759, 30153,
            30207, 30820, 31144, 32355, 32496, 33160, 34510, 35477, 36316,
            36623, 37286, 39351, 40365, 41151, 41257, 42842, 43755, 43777,
            43805, 44076, 44261, 44959, 45565, 46501, 46725, 46760, 47195,
            47201, 48055, 49015, 49065, 49623, 50102, 50178, 51612, 52401,
            52749, 55062, 55280, 57585, 59175, 59876
        ]

        # The outlier rows are taken before the images are rescaled, so they
        # keep the values exactly as Keras returned them.
        X_train_out, y_train_out = X_train[outlier_idx], y_train[outlier_idx]

        X_train = X_train.astype('float32') / 255.
        X_test = X_test.astype('float32') / 255.

    elif data == 'mnist':
        mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
        X_train, y_train = mnist_train
        X_test, y_test = mnist_test

        X_train = X_train.astype('float32') / 255.
        X_test = X_test.astype('float32') / 255.

    elif data == 'Seizure':
        X_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/seizure_X_train.npy')
        y_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/seizure_y_train.npy')
        X_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/seizure_X_test.npy')
        y_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/seizure_y_test.npy')

        # Unlike the other datasets, the stored indices are sorted first.
        outlier_idx = np.sort(load_raw(
            '/longterm/XXXX/ensemble_fairness/data/seizure_outliers.npy'))
        X_train_out, y_train_out = X_train[outlier_idx], y_train[outlier_idx]

    elif data == 'TaiwaneseCredit':
        X_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/X_train_taiwanese.npy')
        y_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/y_train_taiwanese.npy')
        X_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/X_test_taiwanese.npy')
        y_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/y_test_taiwanese.npy')

    elif data == 'Warafin':
        X_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/X_train_warafin.npy')
        y_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/y_train_warafin.npy')
        X_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/X_test_warafin.npy')
        y_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/y_test_warafin.npy')

    elif data == 'HELOC':
        X_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/X_train_heloc.npy')
        y_train = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/y_train_heloc.npy')
        X_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/X_test_heloc.npy')
        y_test = load_f32(
            '/longterm/XXXX/ensemble_fairness/data/y_test_heloc.npy')

    elif data == 'Breast':
        # NOTE(review): these indices are built but never applied in this
        # branch -- the outlier outputs stay None. Confirm that is intended.
        unused_outlier_idx = np.array([
            418,  75, 176,  30, 357, 347, 154, 153, 415, 157, 408,  70, 195,
            329,  39, 141, 268,  72,   9, 275, 281, 255, 184,  55, 140, 193,
            145, 403, 413, 262, 126, 249, 299,  93, 421,  77, 341, 314, 317,
            422, 218, 148,  78,  73,  33, 116, 294,  76, 237, 373,  15, 229,
            350,   0,  19, 272, 280,  56, 336,  82, 239, 368,  79,  90, 361,
            25, 173, 180, 175,  42, 209, 414, 104, 165, 417, 274,  22,  46,
            376, 113,  94, 375,  57, 124,  24,  17,  66, 132, 222,  31,  84,
            291, 211, 298, 412, 244, 386, 351,   5,  45])

        breast_split = datalib.splitters.Split(tr=3, te=1, seed=0)
        breast_data = datalib.BreastCancer(split=breast_split)

        X_train, y_train, X_test, y_test = split_as_f32(breast_data)

    elif data == 'Pima':
        # NOTE(review): these indices are built but never applied in this
        # branch -- the outlier outputs stay None. Confirm that is intended.
        unused_outlier_idx = np.array([
            533, 544,  41, 148, 111, 293, 435, 407, 311,  23, 242, 185, 521,
            496, 566, 301, 375, 200, 599, 306, 132, 559,   0, 135, 195, 386,
            563, 201, 520, 518, 207,  13, 120, 218,  59, 292, 160, 463, 602,
            402, 605, 257, 236, 284,  47,  85,  82, 307, 488, 173, 382, 500,
            422, 611, 273, 477,  65, 179, 172,  34,  49, 147, 392, 576,  66,
            415,  90, 214, 410, 557, 558, 472, 531,  62, 165, 383, 510, 608,
            421,  17, 162,  60, 119, 426,  88, 512, 462, 334,  29, 277, 346,
            107, 528, 223, 574, 425,  92, 295, 484, 260])

        pima = pd.read_csv('/longterm/XXXX/repos/data-unsorted/pima-indians-diabetes/pima-indians-diabetes.csv')

        # Columns 0-7 are used as features, column 8 as the label.
        pima_features = pima.values[:, 0:8]
        pima_labels = pima.values[:, 8]
        pima_data = datalib.CustomData(
            "pima", pima_features, pima_labels,
            processors=["normalize"],
            split=datalib.splitters.Split(tr=4, te=1, seed=0))

        X_train, y_train, X_test, y_test = split_as_f32(pima_data)

    elif data == 'CTG':
        # NOTE(review): these indices are built but never applied in this
        # branch -- the outlier outputs stay None. Confirm that is intended.
        unused_outlier_idx = np.array([
            1128, 1315, 1237,   56, 1086,  111,  706,  825, 1082, 1103, 1123,
            1131, 1233, 1046,  285,  583, 1678,  747,  195,  367,  342,   47,
            1000,  783, 1603, 1097,  566,  784,  893,    6, 1557,   40, 1322,
            915,  737, 1115, 1484,  987, 1318,  832,  882, 1329,  736, 1389,
            1166,    4, 1059,  960, 1266,  598, 1496,  927,  108, 1012, 1422,
            563,  785, 1262,  120,  544, 1634, 1327,  614,  107, 1139,  312,
            145,  314, 1144, 1628,  936,  869,  313,  636, 1182,  972, 1360,
            557, 1453,  695, 1152, 1349,  415,  654, 1235,  738,  242, 1417,
            480, 1337,  839, 1522, 1369,  421,  527,  126,  996,  258, 1624,
            644])

        ctg = pd.read_csv('/longterm/XXXX/repos/ensemble_fairness/data/ctg_data.csv')

        # Columns 0-20 are features; the label is 1 where column 21 exceeds 1.
        ctg_features = ctg.values[:, 0:21]
        ctg_labels = (ctg.values[:, 21] > 1).astype('int')
        ctg_data = datalib.CustomData(
            "ctg", ctg_features, ctg_labels,
            processors=["normalize"],
            split=datalib.splitters.Split(tr=4, te=1, seed=0))

        X_train, y_train, X_test, y_test = split_as_f32(ctg_data)

    elif data == 'Thyroid':
        # NOTE(review): these indices are built but never applied in this
        # branch -- the outlier outputs stay None. Confirm that is intended.
        unused_outlier_idx = np.array([
            599, 1201,  628, 1642, 1263,  931,   23,  844,  964,  764, 1483,
            1172,  344,  413,  494,  298,  529, 1651, 1190, 1648,  548,  371,
            1340,  736,  254,  829,  479,  297, 1193,  602,  940,  352,  173,
            1530, 1078, 1635, 1617, 1673,  834, 1433, 1490,  861, 1506,  700,
            1091, 1456, 1244, 1418,  898, 1657,  188,  818,  943,  979,  481,
            1421, 1027, 1709,  570, 1255,  462,  724,  115,  990, 1261, 1594,
            522,  557,  727, 1308, 1220,  582,  135,  495, 1094, 1564,  210,
            621,  759,   49, 1205, 1161,  846, 1198, 1265,  916, 1406,  398,
            1672,  998, 1357, 1696,  966,   78, 1385, 1461,  527,  438,  651,
            1455])

        thy = pd.read_csv('/longterm/XXXX/repos/ensemble_fairness/data/ann-all.csv')
        thy_undersampled = np.load(
            "/longterm/XXXX/repos/ensemble_fairness/data/thyroid_undersampling_idx.npy")

        # Only the rows selected by the stored index file are kept. Columns
        # 0-20 are features; the label is 1 where column 21 is below 3.
        thy_features = thy.values[:, 0:21][thy_undersampled]
        thy_labels = (thy.values[:, 21][thy_undersampled] < 3).astype('int')
        thy_data = datalib.CustomData(
            "thy", thy_features, thy_labels,
            processors=["normalize"],
            split=datalib.splitters.Split(tr=4, te=1, seed=0))

        X_train, y_train, X_test, y_test = split_as_f32(thy_data)

    else:
        raise ValueError(f"{data} is not found")

    return (X_train, y_train), (X_test, y_test), (X_train_out, y_train_out)


def remove_None(L):
    """Return a new list holding the entries of ``L`` that are not ``None``.

    The original order is kept. Only ``None`` itself is dropped; other falsy
    values such as ``0``, ``False`` or ``''`` are retained.
    """
    return [entry for entry in L if entry is not None]


def search_z_adv(vae,
                 classifier,
                 x,
                 latent_representations,
                 y_sparse,
                 epsilon=0.141,
                 steps=100,
                 num_samples=100,
                 direction='random',
                 p=2,
                 batch_size=128,
                 transform=None,
                 return_probits=True,
                 detemintristic=True,
                 **kwargs):
    """Search around each latent code for a decoding the classifier relabels.

    For every latent code, random perturbations of norm at most ``epsilon``
    are sampled, decoded with ``vae`` and classified. As soon as a round
    yields candidates whose predicted class differs from ``y_sparse[i]``,
    the one closest to ``x[i]`` is kept. Samples whose plain reconstruction
    is already classified differently from ``y_sparse[i]`` are not searched.

    Args:
        vae: Object exposing ``decode(z, batch_size=...)``. The decoded
            batch is indexed and copied like a NumPy array.
        classifier: Object exposing ``predict(inputs, batch_size=...)`` that
            returns per-class scores (the arg-max over the last axis is
            taken as the predicted class).
        x: Reference inputs; ``x[i]`` is what candidates for sample ``i``
            are compared against.
        latent_representations: Latent codes, iterated along the first
            axis. NOTE(review): the sampling code broadcasts ``(n, 1)``
            tensors against each code, so codes are assumed to be 1-D
            float32 vectors -- confirm against the caller.
        y_sparse: Integer class label per sample that the search tries to
            move away from.
        epsilon: Upper bound on the ``p``-norm of each latent perturbation.
        steps: Maximum number of sampling rounds per sample.
        num_samples: Number of candidates drawn per round.
        direction: Sampling strategy; only ``'random'`` is implemented.
        p: Norm order used to normalise the random directions.
        batch_size: Batch size forwarded to ``decode`` and ``predict``.
        transform: Optional callable applied to decoded candidates before
            they are classified and compared with ``x[i]``.
        return_probits: If true, ``y_pred_adv`` holds the raw classifier
            outputs; otherwise their arg-max.
        detemintristic: (sic -- the spelling is part of the public
            interface.) If true, the seed ``2021`` is passed to the
            TensorFlow random ops; otherwise no seed is passed.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(adv_x, y_pred_adv, is_adv)``: ``is_adv`` is a boolean mask
        over the samples marking those whose stored counterfactual is
        classified differently from ``y_sparse``; ``adv_x`` and
        ``y_pred_adv`` are the counterfactuals and predictions for exactly
        those samples.

    Raises:
        NotImplementedError: If a search is actually run with a
            ``direction`` other than ``'random'``.
    """

    def get_random_delta(d, z_0, num_samples, epsilon, seed=None):
        """Return ``num_samples`` random points within ``epsilon`` of ``z_0``."""
        if d != 'random':
            raise NotImplementedError(
                f"The {d} direction has not been impletmented yet.")

        # One Gaussian direction per candidate.
        delta = tf.random.normal((num_samples, ) + z_0.shape, seed=seed)
        # Normalise every candidate on its own (axis=-1). Without the axis
        # the whole batch was scaled to joint norm 1, which shrank each
        # perturbation by roughly sqrt(num_samples).
        delta /= tf.norm(tf.keras.backend.batch_flatten(delta),
                         ord=p,
                         axis=-1,
                         keepdims=True)
        # Independent radius in [0, epsilon) for each candidate.
        random_norms = tf.random.uniform((num_samples, 1),
                                         minval=0,
                                         maxval=epsilon,
                                         seed=seed)
        delta *= random_norms

        # Use the helper's own argument rather than the enclosing loop
        # variable.
        return tf.ones([num_samples, 1]) * z_0 + delta

    seed = 2021 if detemintristic else None

    reconst_x = vae.decode(latent_representations, batch_size=batch_size)
    reconst_x_pred = np.argmax(
        classifier.predict(reconst_x, batch_size=batch_size), -1)

    # Both metrics are left non-stateful so the bar shows running averages
    # (a stateful metric would only show the value of the latest sample).
    pb = tf.keras.utils.Progbar(target=latent_representations.shape[0])
    counterfactuals = []
    for i, z in enumerate(latent_representations):

        if isinstance(z, np.ndarray):
            z = tf.constant(z)

        # Index of the last sampling round that ran; stays 0 when no search
        # is needed or when ``steps`` is 0.
        j = 0

        if reconst_x_pred[i] != y_sparse[i]:
            # The prediction is already different, so no search is needed.
            # NOTE(review): this reconstruction is stored without
            # ``transform`` applied, unlike the searched counterfactuals.
            reconst = reconst_x[i][None, :]
            found = True
        else:
            # Fallback when the search fails: this sample's own
            # reconstruction, which was classified as y_sparse[i] above
            # (previously sample 0's reconstruction was used here).
            reconst = reconst_x[i:i + 1].copy()
            found = False
            for j in range(steps):
                batch_z_candidates = get_random_delta(direction,
                                                      z,
                                                      num_samples,
                                                      epsilon,
                                                      seed=seed)
                reconst_x_from_z_tilde = vae.decode(batch_z_candidates,
                                                    batch_size=batch_size)

                if transform is not None:
                    # The VAE/AE is trained on data scaled to [0, 1]; map
                    # the candidates back into the actual data range.
                    reconst_x_from_z_tilde = transform(reconst_x_from_z_tilde)

                new_pred = np.argmax(
                    classifier.predict(reconst_x_from_z_tilde,
                                       batch_size=batch_size), -1)
                valid_idx = new_pred != y_sparse[i]
                if np.sum(valid_idx) > 0:
                    valid_reconst = reconst_x_from_z_tilde[valid_idx]
                    # Keep the label-changing candidate closest to x[i].
                    diff = valid_reconst - x[i]
                    flatten_diff = tf.keras.backend.batch_flatten(diff)
                    flatten_diff_norm = tf.norm(flatten_diff, axis=-1)
                    closest_id = tf.argmin(flatten_diff_norm)
                    reconst = valid_reconst[closest_id][None, :]
                    found = True
                    break

        counterfactuals.append(reconst)
        pb.add(1, [('avg_iter', j), ('success_rate', float(found))])

    counterfactuals = np.vstack(counterfactuals)
    counterfactuals_pred_logits = classifier.predict(counterfactuals,
                                                     batch_size=batch_size)
    # Re-classify the stored counterfactuals to decide which are adversarial.
    is_adv = y_sparse != np.argmax(counterfactuals_pred_logits, axis=-1)
    adv_x = counterfactuals[is_adv]
    y_pred_adv = counterfactuals_pred_logits[is_adv]
    if not return_probits:
        y_pred_adv = np.argmax(y_pred_adv, axis=-1)

    return adv_x, y_pred_adv, is_adv
