# TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's native library loads, so it
# must be in the environment *before* `import tensorflow`; setting it after the
# import does not silence the C++ logs.  TF_DETERMINISTIC_OPS is set here too so
# it is in place before any op runs.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_DETERMINISTIC_OPS'] = '1'

import tensorflow as tf

# Make exactly one physical GPU (selected by index) visible to TensorFlow and
# enable memory growth on it, so the runtime does not pre-allocate all of the
# device's memory at initialization.
gpu = 2
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    selected_gpu = gpus[gpu]
except IndexError as err:
    # Fail with an actionable message instead of a bare "list index out of range".
    raise IndexError(
        f"GPU index {gpu} requested, but only {len(gpus)} GPU(s) are available"
    ) from err
tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')

for device in tf.config.experimental.get_visible_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)


# Silence TensorFlow's Python-side logger and all Python warnings.
import logging
logging.getLogger('tensorflow').disabled = True

import warnings
warnings.filterwarnings('ignore')

import tempfile
import shutil

import numpy as np
import tensorflow.keras.backend as K
import random as pyrand

K.set_image_data_format('channels_last')

from luf.luf_experiment import LUFExperiment
from luf.rs_experiment import string_to_architecture, RSExperiment
from tensorflow.keras.metrics import binary_accuracy


from datalib.splitters import Split
from datalib import CustomData, FashionMnist


def make_datalib(name, X_train, X_test, y_train, y_test, rat=None):
    """Merge a pre-split train/test pair into a single ``CustomData`` dataset.

    Features and labels are concatenated train-first along the first axis and
    handed to ``CustomData`` together with a ``Split`` built from ``rat``.

    Args:
        name: Dataset name, passed straight through to ``CustomData``.
        X_train, X_test: Feature arrays to concatenate (train first).
        y_train, y_test: Label arrays to concatenate (train first).
        rat: Optional two-element ``(tr, te)`` sequence for ``Split``.  When
            omitted, the sizes ``(len(X_train), len(X_test))`` are used.

    Returns:
        The constructed ``CustomData`` instance.
    """
    features = np.concatenate((X_train, X_test))
    labels = np.concatenate((y_train, y_test))
    if rat is None:
        tr_part, te_part = len(X_train), len(X_test)
    else:
        tr_part, te_part = rat[0], rat[1]
    return CustomData(name, features, labels, Split(tr=tr_part, te=te_part))

# --- HELOC -----------------------------------------------------------------
X_train_hel, y_train_hel, X_test_hel, y_test_hel = map(np.load, (
    "../data/X_train_heloc.npy",
    "../data/y_train_heloc.npy",
    "../data/X_test_heloc.npy",
    "../data/y_test_heloc.npy",
))

# Randomly generated points that are left out of the training set for the
# leave-one-out (LOO) experiments.
outliers = np.load("../data/heloc_outliers.npy")

HEL_Data = make_datalib("hel", X_train_hel, X_test_hel, y_train_hel, y_test_hel)
del X_train_hel, X_test_hel, y_train_hel, y_test_hel


def _hel_experiment_kwargs():
    """Return fresh keyword arguments shared by the HELOC experiments.

    New loss and optimizer objects are built on every call, so each experiment
    below receives its own instances.
    """
    return dict(
        multi=True,
        model_args=[],
        model_kwargs={},
        compile_kwargs=dict(
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics="categorical_accuracy",
            optimizer=tf.keras.optimizers.Adam(1e-3),
        ),
        fit_args=[],
        fit_kwargs=dict(batch_size=128, epochs=50),
        custom_objects={},
        early_stop_cutoff=1.e-1,
    )


# LUF experiment over the first 100 held-out indices.
LufExpHEL = LUFExperiment(
    '100.32', HEL_Data,
    indices=outliers[0:100], num_trials=None, seed=42,
    **_hel_experiment_kwargs(),
)
info_HEL = LufExpHEL.run()
LufExpHEL.save()

# Random-sampling (RS) experiment with 100 trials.
RSExpHEL = RSExperiment(
    '100.32', HEL_Data,
    num_trials=100, seed=42,
    **_hel_experiment_kwargs(),
)
info_HEL_RS = RSExpHEL.run()
RSExpHEL.save()


# --- CTG -------------------------------------------------------------------
# pandas is only needed for this section and is not imported at the top of the
# script; without this line `pd.read_csv` raised NameError.
import pandas as pd

ctg = pd.read_csv('../data/ctg_data.csv')

# Columns 0..20 are the features; column 21 is thresholded at > 1 to give a
# binary 0/1 label.  NOTE(review): assumes column 21 is a class code whose
# value 1 is the negative class — confirm against the data file.
# `CustomData` / `Split` are used directly: the bare module name `datalib`
# is never bound by the `from datalib ... import ...` statements above.
ctg_data = CustomData(
    "ctg2", ctg.values[:, 0:21], (ctg.values[:, 21] > 1).astype('int'),
    processors=["normalize"], split=Split(tr=4, te=1, seed=0))

ctg_data_lip = CustomData(
    "ctg_lip", ctg.values[:, 0:21], (ctg.values[:, 21] > 1).astype('int'),
    processors=["normalize"], split=Split(tr=4, te=1, seed=0))

# Fixed list of training-point indices used for the LOO-style LUF experiment.
outliers_ctg = np.array([1128, 1315, 1237,   56, 1086,  111,  706,  825, 1082, 1103, 1123,
       1131, 1233, 1046,  285,  583, 1678,  747,  195,  367,  342,   47,
       1000,  783, 1603, 1097,  566,  784,  893,    6, 1557,   40, 1322,
        915,  737, 1115, 1484,  987, 1318,  832,  882, 1329,  736, 1389,
       1166,    4, 1059,  960, 1266,  598, 1496,  927,  108, 1012, 1422,
        563,  785, 1262,  120,  544, 1634, 1327,  614,  107, 1139,  312,
        145,  314, 1144, 1628,  936,  869,  313,  636, 1182,  972, 1360,
        557, 1453,  695, 1152, 1349,  415,  654, 1235,  738,  242, 1417,
        480, 1337,  839, 1522, 1369,  421,  527,  126,  996,  258, 1624,
        644, 80, 1371,  430, 1368, 1541, 1217, 1472,  192, 1259, 1175, 1229,
        262, 1403, 1500,  307,   65,  920, 1582, 1615,  501,  704,  605,
        206, 1151,  835, 1413,  319, 1509,  553, 1153,  371, 1028,  339,
       1677, 1083, 1091, 1170,   10, 1527,  169, 1332, 1592, 1321, 1533,
       1659, 1549, 1458, 1430,  115, 1396,  213, 1461,  907,  814,  624,
        587,  254,  483,  503,  119,  373,  789,  678,  116,  670,  188,
       1551, 1539,  487, 1249, 1416, 1380,  710, 1204,   70, 1616, 1675,
        510,  901, 1466,  602,  729,  836,  630, 1138,  202,   58,  475,
       1094,  402, 1124,  906, 1424, 1406, 1241,  554,  965, 1306, 1362,
       1277, 1098,   52, 1050,  751, 1465,  985,  666,  513,  858,   96,
       1635, 1343, 1232, 1642, 1425, 1005,  623, 1320,   32, 1398, 1525,
       1010,  979, 1376,  450,  214,  301,  809, 1221,   39,  845,  677,
       1426,  269,  440,  303,  444, 1071,  662, 1156,  759,  574,  536,
         85,  423,  392,  106,  322,  290, 1181, 1134, 1309,  773,  109,
        278,  465, 1469, 1176, 1605,  282,  518, 1187,  317,  607, 1324,
       1693, 1312,  436, 1054, 1072,  519,  971,  981,  672,  485,  235,
       1438, 1486,  300,  562, 1323,  687, 1451,  647, 1626, 1136, 1254,
        244,  764,  129, 1107, 1395, 1411,  272, 1015, 1478,  922,  180,
        327, 1226, 1294,  746, 1044, 1552, 1067,  428, 1596,  834, 1109,
        711, 1688, 1056, 1053,  349,  220, 1404,  112,  422,  273,  425,
         24,  216,  581,  734,  748, 1331,  431,  409, 1068,  464, 1594,
        249, 1287,  916,   21, 1529, 1555, 1273,  719,  693, 1297,  469,
        298,   55,    3,  281,  124,  151, 1073,  656,  663,  178, 1040,
       1648,  973,  865,   53,   99,  388, 1037, 1519,  847, 1378,  550,
        742,  284, 1142, 1530,  739, 1197,  457,  240,  653, 1495, 1148,
        556,  321,  975, 1561, 1045,  454, 1198, 1473, 1401,  223, 1553,
        627,   93,  776, 1112,  604,  928,  694,  320,  657, 1546,  909,
        952, 1467,  668,  448, 1629,  204, 1238, 1470,  645,  618,  978,
        838, 1035, 1118, 1065, 1511,  673,  564,  942,  157, 1651, 1288,
       1174,  740,   61,  153,  912, 1479,  659,  804,  995,  646,  344,
        210,  170,  239, 1620,  697, 1685, 1604,  350,  304,  944, 1330,
       1367, 1104,  769,   92,  853, 1108,  165,  603, 1434, 1158, 1335,
        854, 1600, 1574, 1556, 1077, 1437,  446,  130, 1690,  903,   88,
       1357,  684,  198, 1245,  546,  332, 1316, 1694,  143,  406, 1584,
         86, 1203,  514,  504, 1477, 1247, 1222,  552,  354, 1358, 1276,
        395, 1449,  381, 1679,  639,  807, 1459,  426, 1172, 1359, 1665,
       1445, 1283,   14,  615])

# Binary task: a single-logit output trained with BinaryCrossentropy.
LufExp_CTG = LUFExperiment('100.32.16', ctg_data, indices=outliers_ctg, num_trials=None, seed=42,
            lip=False, multi=False, model_args=[], model_kwargs={},
            compile_kwargs={"loss": tf.keras.losses.BinaryCrossentropy(from_logits=True),
                            "metrics": "accuracy",
                            "optimizer": tf.keras.optimizers.Adam(1e-3)},
            fit_args=[], fit_kwargs={"batch_size": 16, "epochs": 100},
            custom_objects={}, early_stop_cutoff=1.e-1)
info_LufExp_CTG = LufExp_CTG.run()
LufExp_CTG.save()

# Random-sampling experiment with 500 trials on the same data/architecture.
RSExp_CTG = RSExperiment('100.32.16', ctg_data, indices=None, num_trials=500, seed=42,
            lip=False, multi=False, model_args=[], model_kwargs={},
            compile_kwargs={"loss": tf.keras.losses.BinaryCrossentropy(from_logits=True),
                            "metrics": "accuracy",
                            "optimizer": tf.keras.optimizers.Adam(1e-3)},
            fit_args=[], fit_kwargs={"batch_size": 16, "epochs": 100},
            custom_objects={}, early_stop_cutoff=1.e-1)
info_RSExp_CTG = RSExp_CTG.run()
RSExp_CTG.save()


# --- German credit ---------------------------------------------------------
X_train_gc, y_train_gc, X_test_gc, y_test_gc = map(np.load, (
    "../data/german_X_train.npy",
    "../data/german_y_train.npy",
    "../data/german_X_test.npy",
    "../data/german_y_test.npy",
))

outliers_gc = np.load("../data/outliers_gc.npy")

# Explicit 4:1 split ratio instead of make_datalib's size-based default.
GC_Data = make_datalib("gc", X_train_gc, X_test_gc, y_train_gc, y_test_gc, rat=[4, 1])
del X_train_gc, X_test_gc, y_train_gc, y_test_gc


def _gc_experiment_kwargs():
    """Return fresh keyword arguments shared by the German-credit experiments.

    New loss and optimizer objects are built on every call, so each experiment
    below receives its own instances.
    """
    return dict(
        multi=True,
        model_args=[],
        model_kwargs={},
        compile_kwargs=dict(
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics="categorical_accuracy",
            optimizer=tf.keras.optimizers.Adam(1e-3),
        ),
        fit_args=[],
        fit_kwargs=dict(batch_size=32, epochs=100),
        custom_objects={},
        early_stop_cutoff=1.e-1,
    )


# LUF experiment over the first 100 held-out indices.
LufExpGC = LUFExperiment(
    '128.32.16', GC_Data,
    indices=outliers_gc[0:100], num_trials=None, seed=42,
    **_gc_experiment_kwargs(),
)
info_gc = LufExpGC.run()
LufExpGC.save()

# Random-sampling (RS) experiment with 100 trials.
RSExpGC = RSExperiment(
    '128.32.16', GC_Data,
    num_trials=100, seed=42,
    **_gc_experiment_kwargs(),
)
info_gc_RS = RSExpGC.run()
RSExpGC.save()

del outliers_gc

# --- Seizure ---------------------------------------------------------------

X_train_sz, y_train_sz, X_test_sz, y_test_sz = map(np.load, (
    "../data/seizure_X_train.npy",
    "../data/seizure_y_train.npy",
    "../data/seizure_X_test.npy",
    "../data/seizure_y_test.npy",
))
outliers_sz = np.load("../data/seizure_outliers.npy")

SZ_Data = make_datalib("sz2", X_train_sz, X_test_sz, y_train_sz, y_test_sz)
del X_train_sz, X_test_sz, y_train_sz, y_test_sz


def _sz_experiment_kwargs():
    """Return fresh keyword arguments shared by the seizure experiments.

    New loss and optimizer objects are built on every call, so each experiment
    below receives its own instances.
    """
    return dict(
        multi=True,
        model_args=[],
        model_kwargs={},
        compile_kwargs=dict(
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics="categorical_accuracy",
            optimizer=tf.keras.optimizers.Adam(1e-3),
        ),
        fit_args=[],
        fit_kwargs=dict(batch_size=32, epochs=100),
        custom_objects={},
        early_stop_cutoff=1.e-1,
    )


# LUF experiment over every stored held-out index (no [0:100] slice here).
LufExpSZ = LUFExperiment(
    '128.32.16', SZ_Data,
    indices=outliers_sz, num_trials=None, seed=42,
    **_sz_experiment_kwargs(),
)
info_SZ = LufExpSZ.run()
LufExpSZ.save()

# Random-sampling (RS) experiment with 100 trials.
RSExpSZ = RSExperiment(
    '128.32.16', SZ_Data,
    num_trials=100, seed=42,
    **_sz_experiment_kwargs(),
)
info_SZ_RS = RSExpSZ.run()
RSExpSZ.save()

# --- Taiwanese -------------------------------------------------------------

X_train_tai, y_train_tai, X_test_tai, y_test_tai = map(np.load, (
    "../data/X_train_taiwanese.npy",
    "../data/y_train_taiwanese.npy",
    "../data/X_test_taiwanese.npy",
    "../data/y_test_taiwanese.npy",
))
outliers_tai = np.load("../data/outliers_tai.npy")

tai_Data = make_datalib("tai2", X_train_tai, X_test_tai, y_train_tai, y_test_tai)
del X_train_tai, X_test_tai, y_train_tai, y_test_tai


def _tai_experiment_kwargs():
    """Return fresh keyword arguments shared by the Taiwanese experiments.

    New loss and optimizer objects are built on every call, so each experiment
    below receives its own instances.
    """
    return dict(
        multi=True,
        model_args=[],
        model_kwargs={},
        compile_kwargs=dict(
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics="categorical_accuracy",
            optimizer=tf.keras.optimizers.Adam(1e-3),
        ),
        fit_args=[],
        fit_kwargs=dict(batch_size=512, epochs=50),
        custom_objects={},
        early_stop_cutoff=1.e-1,
    )


# LUF experiment over the first 100 held-out indices.
LufExptai = LUFExperiment(
    '32.16', tai_Data,
    indices=outliers_tai[0:100], num_trials=None, seed=42,
    **_tai_experiment_kwargs(),
)
info_tai = LufExptai.run()
LufExptai.save()

# Random-sampling (RS) experiment with 100 trials.
RSExptai = RSExperiment(
    '32.16', tai_Data,
    num_trials=100, seed=42,
    **_tai_experiment_kwargs(),
)
info_tai_RS = RSExptai.run()
RSExptai.save()


# --- Warfarin (data files are spelled "warafin") ----------------------------
X_train_war, y_train_war, X_test_war, y_test_war = map(np.load, (
    "../data/X_train_warafin.npy",
    "../data/y_train_warafin.npy",
    "../data/X_test_warafin.npy",
    "../data/y_test_warafin.npy",
))
outliers_war = np.load("../data/warafin_outliers.npy")

war_Data = make_datalib("war2", X_train_war, X_test_war, y_train_war, y_test_war)
del X_train_war, X_test_war, y_train_war, y_test_war


def _war_experiment_kwargs():
    """Return fresh keyword arguments shared by the warfarin experiments.

    New loss and optimizer objects are built on every call, so each experiment
    below receives its own instances.

    NOTE(review): unlike the other categorical experiments in this script,
    ``multi`` is not passed here, so the experiment classes' default applies
    even though a categorical loss/metric is used — confirm this is intended.
    """
    return dict(
        model_args=[],
        model_kwargs={},
        compile_kwargs=dict(
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics="categorical_accuracy",
            optimizer=tf.keras.optimizers.Adam(1e-3),
        ),
        fit_args=[],
        fit_kwargs=dict(batch_size=128, epochs=100),
        custom_objects={},
        early_stop_cutoff=1.e-1,
    )


# LUF experiment over every stored held-out index.
LufExpwar = LUFExperiment(
    '100', war_Data,
    indices=outliers_war, num_trials=None, seed=42,
    **_war_experiment_kwargs(),
)
info_war = LufExpwar.run()
LufExpwar.save()

# Random-sampling (RS) experiment with 100 trials.
RSExpwar = RSExperiment(
    '100', war_Data,
    num_trials=100, seed=42,
    **_war_experiment_kwargs(),
)
info_war_RS = RSExpwar.run()
RSExpwar.save()

