import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tqdm import tqdm

import h5py
import numpy as np
import matplotlib.pyplot as plt

def _get_segment_indices(ids_):
    """Map every row of a 2-D integer id tensor to the index of its id tuple.

    Each row is collapsed into a single integer key using a mixed-radix
    encoding whose place values come from the per-column maxima, and
    ``tf.unique`` then numbers the distinct keys in order of first
    appearance.

    NOTE(review): the encoding is only collision-free when all ids are
    non-negative -- confirm against the callers.

    Args:
        ids_: integer tensor of shape (rows, columns).

    Returns:
        1-D tensor with one entry per row; rows with the same key share the
        same value.
    """
    # Number of distinct values each column may take (max + 1).
    column_extent = tf.add(tf.reduce_max(ids_, axis=0), 1)
    # The place value of a column is the product of the extents of all the
    # columns to its right, so the first column's extent is never needed and
    # the last column gets a place value of 1.
    radices = tf.concat([column_extent[1:], [1]], 0)
    place_values = tf.math.cumprod(radices, reverse=True)
    row_keys = tf.reduce_sum(tf.multiply(ids_, place_values), axis=1)
    _, idx = tf.unique(row_keys)
    return idx

def _reduce_indices(ids_):
    """Return ``ids_`` without its last column.

    The tensor is transposed and flattened so that the columns lie one after
    another; the leading ``(columns - 1) * rows`` entries (everything except
    the last column) are kept, reshaped and transposed back.

    NOTE(review): written with 2-D input in mind; behaviour for higher ranks
    has not been verified.

    Args:
        ids_: tensor of shape (rows, columns).

    Returns:
        Tensor of shape (rows, columns - 1).
    """
    dims = tf.shape(ids_)
    rank = tf.size(dims)
    leading_dims = tf.gather(dims, tf.range(0, rank - 1))
    trailing_dim = tf.gather(dims, [rank - 1])
    # Shape of the transposed result: one column fewer, leading dims after it.
    target_shape = tf.concat([trailing_dim - 1, leading_dims], 0)
    keep_count = tf.reduce_prod(target_shape)
    column_major = tf.reshape(tf.transpose(ids_), [-1])
    kept = tf.gather(column_major, tf.range(0, keep_count))
    return tf.transpose(tf.reshape(kept, target_shape))

def reduce_segment_ids(segment_ids, n):
    """Compute the segment indices for the ``n``-th level of nested pooling.

    Each round groups the rows of the current id tensor by their full id
    tuple, drops the last id column and keeps one row per group (the
    element-wise maximum of the group). The group indices computed in the
    final round are returned.

    NOTE(review): ``tf.math.segment_max`` expects sorted segment ids, so rows
    sharing an id tuple are presumably contiguous in the input -- confirm
    against the batch construction.

    Args:
        segment_ids: 2-D integer tensor of hierarchical ids, one row per
            element.
        n: number of grouping rounds to apply; must be at least 1.

    Returns:
        1-D ``tf.int32`` tensor holding the group index of every row that
        enters the ``n``-th round.

    Raises:
        ValueError: if ``n`` is smaller than 1 (there would be no round to
            take the indices from).
    """
    if n < 1:
        raise ValueError(f'n must be at least 1, got {n}')
    lengths = segment_ids
    for _ in range(n):
        segment_indices = _get_segment_indices(lengths)
        reduced_ids = _reduce_indices(lengths)
        # One row per group, with the innermost id column removed.
        lengths = tf.math.segment_max(reduced_ids, segment_indices)
    return tf.cast(segment_indices, tf.int32)

def get_model():
    """Build the two-level segment-pooling classifier.

    The model takes rows of 138 values: the first 136 are features and the
    last 2 are ids that are cast to int32 and turned into segment indices by
    ``reduce_segment_ids``. After a shared ReLU layer, the rows are pooled
    twice (first with the level-1 indices, then with the level-2 indices);
    every pooling stage concatenates a segment-max and a segment-mean branch.
    A single sigmoid unit produces the output.

    Returns:
        The Keras ``Model``.
    """
    inputs = Input(shape=(138,))
    # Segment indices for both pooling levels, derived from the id columns.
    level1_ids = Lambda(lambda t: reduce_segment_ids(tf.cast(t, tf.int32), 1), output_shape=(1,))(inputs[:, 136:])
    level2_ids = Lambda(lambda t: reduce_segment_ids(tf.cast(t, tf.int32), 2), output_shape=(1,))(inputs[:, 136:])

    def pool(tensor, ids):
        # Two independent linear projections, pooled per segment by max and
        # by mean respectively, then joined feature-wise.
        max_branch = Dense(250)(tensor)
        mean_branch = Dense(250)(tensor)
        max_branch = Lambda(lambda t: tf.math.segment_max(t[0], t[1]))([max_branch, ids])
        mean_branch = Lambda(lambda t: tf.math.segment_mean(t[0], t[1]))([mean_branch, ids])
        return Concatenate()([max_branch, mean_branch])

    hidden = Dense(500, activation='relu')(inputs[:, :136])
    hidden = pool(hidden, level1_ids)
    hidden = pool(hidden, level2_ids)

    prediction = Dense(1, activation='sigmoid')(hidden)
    model = Model(inputs, prediction)
    model.build((None, 138))
    return model

def get_set_of_set(A):
    """Expand an adjacency matrix into one feature row per (node, member) pair.

    Every node contributes a set made of itself followed by the nodes whose
    entry in its adjacency row is positive. The feature row of a node has its
    first ``degree + 1`` entries set to ``1 / sqrt(degree)``, where ``degree``
    is the row sum plus one.

    Args:
        A: square adjacency matrix. Rows whose first entry is not greater
            than -1 are not counted as nodes (presumably -1 padding -- verify
            against the data file).

    Returns:
        A tuple ``(X, segment_idx)``: ``X`` is a float32 array with one
        feature row per set member, sets laid out one after another;
        ``segment_idx`` is an integer array with one ``[-1, node]`` row per
        row of ``X``, where -1 is a placeholder in the first column.
    """
    n_nodes = int((A[:, 0] > -1).sum())
    features = np.zeros((n_nodes, A.shape[0]), dtype=np.float32)

    neighbourhoods = []
    for node in range(n_nodes):
        degree = A[node].sum() + 1
        features[node, :degree + 1] = 1 / np.sqrt(degree)
        # The node itself always comes first in its own set.
        neighbourhoods.append(np.array([node, *np.where(A[node] > 0)[0]]))

    X = np.concatenate([features[members] for members in neighbourhoods])
    segment_idx = [
        [[-1, node]] * len(members)
        for node, members in enumerate(neighbourhoods)
    ]
    return X, np.concatenate(segment_idx)
    
class TH:
    """Training helper that pairs a model with an Adam optimizer."""

    def __init__(self, model):
        self.model = model
        self.opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

    def train(self, x, y):
        """Run one optimisation step on a batch and return the loss value."""
        return self._train(x, y).numpy()

    @tf.function(experimental_relax_shapes=True)
    def _train(self, x, y):
        """Apply one gradient update and return the mean binary cross-entropy."""
        # Targets and predictions are both flattened to 1-D before comparison.
        targets = tf.reshape(y, [-1])
        with tf.GradientTape() as tape:
            predictions = tf.reshape(self.model(x), [-1])
            batch_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(targets, predictions)
            )
        weights = self.model.trainable_weights
        gradients = tape.gradient(batch_loss, weights)
        self.opt.apply_gradients(zip(gradients, weights))
        return batch_loss
    
def _assemble_batch(adjacency, graph_indices):
    """Stack the set-of-set rows of several graphs into one model input.

    Args:
        adjacency: array indexed by graph id that yields one adjacency
            matrix per graph.
        graph_indices: ids of the graphs to include, in output order.

    Returns:
        float32 array whose rows hold the features from ``get_set_of_set``
        followed by the two id columns ``[graph id, node index]``.
    """
    feature_blocks = []
    id_blocks = []
    for graph in graph_indices:
        feats, ids = get_set_of_set(adjacency[graph])
        # Replace the -1 placeholder in the first id column by the graph id.
        ids[:, 0] = graph
        feature_blocks.append(feats)
        id_blocks.append(ids)
    stacked = np.hstack((np.vstack(feature_blocks), np.vstack(id_blocks)))
    return stacked.astype(np.float32)


data = h5py.File('imdb_binary.h5', 'r')
adj = np.array(data['adj'])
# Read the labels into memory once instead of re-reading the whole dataset
# from the file for every batch.
labels = data['labels'][:]

batch_size = 32
epochs = 100
# Applied once per epoch, so the learning rate shrinks by 0.5**2 overall.
decay = 0.5**(2./epochs)
# Only full batches are used: 900 // 32 batches cover 896 training graphs
# per epoch.
n_batches = 900//batch_size

# Test accuracy of every (run, fold) combination.
all_tests = []

for run in range(10):
    idx = data['run'][str(run)]
    for cv in range(10):
        # Fold `cv` holds out a consecutive block of 100 indices for testing.
        test_idx = idx[cv*100:(cv+1)*100]
        train_idx = np.array(list(set(idx) - set(test_idx)))
        model = get_model()
        th = TH(model)
        pbar = tqdm(range(epochs))
        for epoch in pbar:
            np.random.shuffle(train_idx)
            for b in range(n_batches):
                batch_idx = train_idx[b*batch_size:(b+1)*batch_size]
                X_batch = _assemble_batch(adj, batch_idx)
                y_batch = labels[batch_idx]
                loss = th.train(X_batch, y_batch)
                pbar.set_description(f'L={loss:.3f}')
            th.opt.learning_rate = th.opt.learning_rate * decay
        X_batch = _assemble_batch(adj, test_idx)
        y_batch = labels[test_idx]
        accuracy = np.mean((model(X_batch, training=False).numpy().ravel() > 0.5 )==y_batch.ravel())
        all_tests.append(accuracy)