# Environment, RNG seeding and dataset preparation for the colorectal-histology
# ResNet experiments further down in this script.
import os
# Expose only GPU "1" (numbered in PCI bus order) to CUDA. These variables are
# set before TensorFlow is imported below.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import sys

import tensorflow as tf
# Turn on memory growth for every visible GPU; a RuntimeError from TensorFlow
# is printed rather than aborting the script.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)


# NOTE(review): this is set after tensorflow has already been imported and the
# GPUs configured -- confirm TensorFlow still honours it at this point.
os.environ['TF_DETERMINISTIC_OPS'] = '1'

import warnings
# Silences every Python warning for the rest of the run.
warnings.filterwarnings('ignore')
import tensorflow.keras.backend as K
# Fixed global seeds for TensorFlow, Python's `random` and NumPy.
tf.random.set_seed(1234)

import numpy as np
import random as rn
rn.seed(12345)
np.random.seed(42)

import tensorflow_datasets as tfds

# Load the whole 'train' split of colorectal_histology in one batch
# (batch_size=-1) as an (image, label) pair of NumPy arrays.
image, label = tfds.as_numpy(tfds.load(
    'colorectal_histology',
    split='train',
    batch_size=-1,
    as_supervised=True,
))
# Move axis 3 to position 1.
# NOTE(review): presumably this converts NHWC images to channels-first (NCHW);
# preprocess_input and ResNet50 would then need Keras' image_data_format to be
# 'channels_first' -- confirm against the runtime configuration.
image = np.moveaxis(image, 3, 1)
image=tf.keras.applications.resnet.preprocess_input(image)
from sklearn.model_selection import train_test_split

# Shuffled train/test split with a fixed random_state; no test_size is given,
# so scikit-learn's default proportions apply.
X_train, X_test, y_train, y_test=train_test_split(
    image,
    label, shuffle=True, random_state=0)

# One-hot encode the labels into 8 classes.
y_train_cat=tf.keras.utils.to_categorical(y_train, num_classes=8)
y_test_cat=tf.keras.utils.to_categorical(y_test, num_classes=8)

# Randomly generated points to leave out of the training set for the
# leave-one-out experiments. Each entry is later passed to np.delete as a row
# index into X_train / y_train_cat.
outliers = np.load('../data/colon_res/outliers_colon.npy')




from tensorflow.keras.utils import Progbar
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D


def make_res(X_train, y_train, weights="imagenet", need_first_weights=False,
             w_lastt=None, w_2_lastt=None, batch_size=512, seed=42,
             epochs=20, set_weights=False):
    """Build, compile and fit a ResNet50-based softmax classifier.

    The network is a ResNet50 backbone (no top, global average pooling)
    followed by Flatten -> Dense(1024) -> ReLU -> Dense(num_classes) ->
    softmax. It is trained with Adam on categorical cross-entropy.

    Args:
        X_train: 4-D array of training inputs; axes 1..3 give the per-sample
            input shape handed to ResNet50.
        y_train: 2-D one-hot label array; its second dimension is the number
            of classes.
        weights: Forwarded to ``ResNet50(weights=...)``. When it is not
            ``None`` the backbone is frozen and only the dense head trains.
        need_first_weights: If true, also return the dense-head weights as
            they were right after model construction (captured before
            ``set_weights`` is applied and before training).
        w_lastt: ``get_weights()``-style list for the final
            ``Dense(num_classes)`` layer; used only when ``set_weights``.
        w_2_lastt: ``get_weights()``-style list for the ``Dense(1024)``
            layer; used only when ``set_weights``.
        batch_size: Batch size passed to ``model.fit``.
        seed: Seed for NumPy's global RNG.
            NOTE(review): TensorFlow's RNG is not reseeded here, so this may
            not by itself control weight initialisation -- confirm.
        epochs: Number of epochs passed to ``model.fit``.
        set_weights: If true, overwrite the two dense layers with
            ``w_lastt`` / ``w_2_lastt`` before training.

    Returns:
        ``(w_last, w_2_last, model)`` when ``need_first_weights`` is true,
        otherwise just the fitted ``model``.

    Raises:
        ValueError: If ``set_weights`` is true but ``w_lastt`` or
            ``w_2_lastt`` is missing.
    """
    # Fail early with a clear message instead of handing None to Keras.
    if set_weights and (w_lastt is None or w_2_lastt is None):
        raise ValueError(
            "set_weights=True requires both w_lastt and w_2_lastt")

    shape = (X_train.shape[1], X_train.shape[2], X_train.shape[3])
    num_classes = y_train.shape[1]
    np.random.seed(seed)

    res = tf.keras.applications.ResNet50(
        include_top=False,
        weights=weights,
        input_tensor=tf.keras.layers.Input(shape),
        input_shape=shape,
        pooling='avg')
    if weights is not None:
        # Pretrained backbone: keep it fixed and train only the head.
        res.trainable = False

    # Keep direct references to the two dense layers rather than looking
    # them up by negative index in model.layers.
    x = Flatten()(res.output)
    hidden_dense = Dense(1024)
    x = hidden_dense(x)
    x = Activation('relu')(x)
    output_dense = Dense(num_classes)
    x = output_dense(x)
    f = Activation('softmax')(x)
    model = Model(res.input, f)

    if need_first_weights:
        # Snapshot of the head as constructed, taken before any overwrite.
        w_last = output_dense.get_weights()
        w_2_last = hidden_dense.get_weights()
    if set_weights:
        output_dense.set_weights(w_lastt)
        hidden_dense.set_weights(w_2_lastt)

    model.compile(
        optimizer="adam",
        loss='categorical_crossentropy',
        metrics=['accuracy'])
    model.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs)

    if need_first_weights:
        return w_last, w_2_last, model
    return model


#BASELINE

# Number of leave-one-out retrainings and of seeds to sweep.
NUM_LOO_RUNS = 200
NUM_SEEDS = 200

# Models and initial weights are written here; create it if it is missing.
os.makedirs('../data/colon_res/models', exist_ok=True)

# The original script used X_train_out without ever defining it (NameError).
# NOTE(review): assumed to be the training rows selected by `outliers`, i.e.
# the candidate points that the leave-one-out runs remove -- confirm.
X_train_out = X_train[outliers]


def _predict_splits(model):
    """Return the model's predictions on X_train, X_train_out and X_test."""
    return (model.predict(X_train),
            model.predict(X_train_out),
            model.predict(X_test))


pb = Progbar(1)

# make_res seeds NumPy itself (seed=42 by default), so no seeding is needed here.
w_lastt, w_2_lastt, m_prime = make_res(X_train, y_train_cat, need_first_weights=True)
p_train, p_out, p_test = _predict_splits(m_prime)
# Single-element lists keep the saved arrays' leading "run" axis of length 1.
yp_train, yp_out, yp_test = [p_train], [p_out], [p_test]
pb.add(1)

# Persist the initial kernel and bias of the two dense layers, tagged with the
# layer's position in model.layers (-2: output layer, -4: hidden layer).
for layer_idx, (kernel, bias) in zip((-2, -4), (w_lastt, w_2_lastt)):
    np.save('../data/colon_res/models/colon_init_weights_w_{}.npy'.format(layer_idx), kernel, allow_pickle=False)
    np.save('../data/colon_res/models/colon_init_weights_b_{}.npy'.format(layer_idx), bias, allow_pickle=False)

m_prime.save('../data/colon_res/models/colon_baseline.h5')
del m_prime
K.clear_session()

np.save('../data/colon_res/colon_yp_train_same_seed.npy', np.array(yp_train), allow_pickle=False)
# The held-out predictions were previously computed and then discarded; the
# file name follows the pattern used by the seed experiments below.
np.save('../data/colon_res/colon_yp_out_same_seed.npy', np.array(yp_out), allow_pickle=False)
np.save('../data/colon_res/colon_yp_test_same_seed.npy',np.array(yp_test), allow_pickle=False)


#leave-one-out experiments
yp_train = []
yp_out = []
yp_test = []

# Slice instead of indexing a fixed range so fewer than NUM_LOO_RUNS outliers
# does not raise IndexError, and so the progress bar target matches the loop.
loo_points = outliers[:NUM_LOO_RUNS]
pb = Progbar(len(loo_points))
pb.add(0)
for i, held_out in enumerate(loo_points):
    # Drop one training point and retrain with the default seed.
    x_train_p = np.delete(X_train, [held_out], axis=0)
    y_train_p = np.delete(y_train_cat, [held_out], axis=0)

    m_prime = make_res(x_train_p, y_train_p)
    m_prime.save("../data/colon_res/colon_loo_model_{}.h5".format(i))

    p_train, p_out, p_test = _predict_splits(m_prime)
    yp_train.append(p_train)
    yp_out.append(p_out)
    yp_test.append(p_test)
    del m_prime
    K.clear_session()
    pb.add(1)

    # Rewritten after every run, so the results gathered so far are on disk
    # even if the loop is interrupted.
    np.save('../data/colon_res/colon_loo_train_same_seed.npy', np.array(yp_train), allow_pickle=False)
    np.save('../data/colon_res/colon_loo_out_same_seed.npy', np.array(yp_out), allow_pickle=False)
    np.save('../data/colon_res/colon_loo_test_same_seed.npy',np.array(yp_test), allow_pickle=False)

#changing seed experiments
pb = Progbar(NUM_SEEDS)
for seed in range(NUM_SEEDS):
    # Full training set each time; only the seed passed to make_res changes.
    m_prime = make_res(X_train, y_train_cat, seed=seed)
    p_train, p_out, p_test = _predict_splits(m_prime)
    yp_train, yp_out, yp_test = [p_train], [p_out], [p_test]
    m_prime.save('../data/colon_res/models/colon_baseline_seed{}.h5'.format(seed))
    del m_prime
    K.clear_session()

    np.save('../data/colon_res/colon_yp_train_same_seed_seed{}.npy'.format(seed), np.array(yp_train), allow_pickle=False)
    np.save('../data/colon_res/colon_yp_out_same_seed_seed{}.npy'.format(seed),np.array(yp_out), allow_pickle=False)
    np.save('../data/colon_res/colon_yp_test_same_seed_seed{}.npy'.format(seed),np.array(yp_test), allow_pickle=False)
    pb.add(1)

