from colorama import Fore, Style

# Short aliases for colorama's foreground colours and its reset-all style.
RED = Fore.RED
GREEN = Fore.GREEN
BLUE = Fore.BLUE
YELLOW = Fore.YELLOW
RESET = Style.RESET_ALL

import json
import numpy as np
import os
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm
import h5py

# custom lib
from gplib import TrainHelper

# Experiment configuration read by the script section at the bottom of the file.
args = {
    "inv_iterations": 10,   # forwarded to TrainHelper(inv_iterations=...)
    "lr": 1e-4,             # initial learning rate, forwarded to TrainHelper(lr=...)
    "inducing_points": 64,  # second argument of get_model
    "batch_size": 16,       # graphs per mini-batch
    "epochs": 100,          # training epochs per fold
    "id_gpu": 4,            # exported as CUDA_VISIBLE_DEVICES; negative -> empty string
    "runs": 5,              # number of index sets read from data['run']
    "seed": 12345,          # NumPy RNG seed
}

from models import get_model

def get_set_of_set(A):
    """Expand a padded adjacency matrix into stacked neighbourhood rows.

    Rows of ``A`` whose first entry is greater than -1 are counted as real
    nodes; the first ``nodes`` rows are then processed, so real nodes are
    assumed to occupy the leading rows.

    For every node ``i`` a feature vector of length ``A.shape[0]`` is built
    whose first ``degree + 1`` entries equal ``1 / sqrt(degree)``, with
    ``degree = A[i].sum() + 1``.  The neighbourhood of ``i`` is ``i`` itself
    followed by every column ``j`` with ``A[i, j] > 0``.  The output stacks,
    node after node, the feature vectors of each neighbourhood member.

    Args:
        A: square 2-D adjacency matrix (integer or float valued), with -1
            marking padded rows.

    Returns:
        X: float32 array of shape ``(total, A.shape[0])`` where ``total`` is
            the summed size of all neighbourhoods.
        segment_idx: integer array of shape ``(total, 2)``.  Column 0 is the
            placeholder -1 (left for the caller to fill in); column 1 is the
            index of the node whose neighbourhood the row belongs to.
        Both arrays have zero rows when ``A`` contains no real node.
    """
    A = np.asarray(A)
    nodes = int((A[:, 0] > -1).sum())
    features = np.zeros((nodes, A.shape[0]), dtype=np.float32)
    neighbourhoods = []

    for i in range(nodes):
        # int() keeps the value usable as a slice bound when A is float typed.
        # NOTE(review): if padding also fills the *columns* of real rows with
        # -1, A[i].sum() under-counts the degree -- confirm the data layout.
        degree = int(A[i].sum()) + 1
        features[i, :degree + 1] = 1. / np.sqrt(degree)
        neighbourhoods.append(np.concatenate(([i], np.flatnonzero(A[i] > 0))))

    if not neighbourhoods:
        # No real node: return well-shaped empty arrays instead of failing
        # inside np.concatenate on an empty list.
        return (np.zeros((0, features.shape[1]), dtype=np.float32),
                np.zeros((0, 2), dtype=int))

    sizes = [len(neighs) for neighs in neighbourhoods]
    # One fancy-index gathers the feature rows of all neighbourhoods at once.
    X = features[np.concatenate(neighbourhoods)]
    owners = np.repeat(np.arange(nodes), sizes)
    segment_idx = np.column_stack((np.full(len(owners), -1), owners))
    return X, segment_idx

def get_set_of_set_int(A):
    """Expand a padded adjacency matrix into stacked neighbourhood degrees.

    Rows of ``A`` whose first entry is greater than -1 are counted as real
    nodes; the first ``nodes`` rows are then processed, so real nodes are
    assumed to occupy the leading rows.

    Every node ``i`` gets the single feature ``A[i].sum() + 1``.  The
    neighbourhood of ``i`` is ``i`` itself followed by every column ``j``
    with ``A[i, j] > 0``.  The output stacks, node after node, the feature
    of each neighbourhood member.

    Args:
        A: square 2-D adjacency matrix, with -1 marking padded rows.

    Returns:
        X: float32 array of shape ``(total, 1)`` where ``total`` is the
            summed size of all neighbourhoods.
        segment_idx: integer array of shape ``(total, 2)``.  Column 0 is the
            placeholder -1 (left for the caller to fill in); column 1 is the
            index of the node whose neighbourhood the row belongs to.
        Both arrays have zero rows when ``A`` contains no real node.
    """
    A = np.asarray(A)
    nodes = int((A[:, 0] > -1).sum())
    features = np.zeros((nodes, 1), dtype=np.float32)
    neighbourhoods = []

    for i in range(nodes):
        # NOTE(review): if padding also fills the *columns* of real rows with
        # -1, A[i].sum() under-counts the degree -- confirm the data layout.
        features[i] = A[i].sum() + 1
        neighbourhoods.append(np.concatenate(([i], np.flatnonzero(A[i] > 0))))

    if not neighbourhoods:
        # No real node: return well-shaped empty arrays instead of failing
        # inside np.concatenate on an empty list.
        return (np.zeros((0, features.shape[1]), dtype=np.float32),
                np.zeros((0, 2), dtype=int))

    sizes = [len(neighs) for neighs in neighbourhoods]
    # One fancy-index gathers the feature rows of all neighbourhoods at once.
    X = features[np.concatenate(neighbourhoods)]
    owners = np.repeat(np.arange(nodes), sizes)
    segment_idx = np.column_stack((np.full(len(owners), -1), owners))
    return X, segment_idx

def split_fold(idx, fold, fold_size=100):
    """Split a sequence of sample indices into train and test parts.

    Args:
        idx: 1-D sequence of sample indices.
        fold: zero-based fold number; the slice
            ``idx[fold * fold_size:(fold + 1) * fold_size]`` becomes the
            test part.
        fold_size: number of samples in each test fold.

    Returns:
        ``(train_idx, test_idx)`` as NumPy arrays.  ``test_idx`` keeps the
        order found in ``idx``; ``train_idx`` holds the remaining unique
        indices in ascending order.
    """
    idx = np.asarray(idx)
    test_idx = idx[fold * fold_size:(fold + 1) * fold_size]
    train_idx = np.setdiff1d(idx, test_idx)
    return train_idx, test_idx


def build_batch(adj, batch_idx):
    """Assemble the model input for the graphs listed in ``batch_idx``.

    Each graph is expanded with ``get_set_of_set_int``.  Column 0 of its
    segment index is set to the graph's position inside the batch (column 1
    stays the node index within the graph), so training and evaluation feed
    the model identically structured ids.

    Args:
        adj: array of adjacency matrices, indexable by graph id.
        batch_idx: graph ids that make up the batch.

    Returns:
        float32 array with one row per neighbourhood member: the feature
        column followed by the two segment-index columns.
    """
    feature_parts = []
    segment_parts = []
    for e, i in enumerate(batch_idx):
        F, N = get_set_of_set_int(adj[i])
        N[:, 0] = e  # batch-local graph id
        feature_parts.append(F)
        segment_parts.append(N)
    return np.hstack((np.vstack(feature_parts), np.vstack(segment_parts))).astype(np.float32)


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args['id_gpu']) if args['id_gpu'] >= 0 else ''

    # set seed for reproducibility
    # NOTE(review): only NumPy's RNG is seeded; TensorFlow's is not, so model
    # initialisation may still differ between executions.
    np.random.seed(args['seed'])

    # Read everything needed up front so the HDF5 file is closed again and the
    # label dataset is not re-read from disk for every mini-batch.
    with h5py.File('imdb_binary.h5', 'r') as data:
        adj = np.array(data['adj'])
        labels = np.array(data['labels'])
        run_indices = [np.array(data['run'][str(run)]) for run in range(args['runs'])]

    batch_size = args['batch_size']
    epochs = args['epochs']
    # Per-epoch factor: after `epochs` epochs the learning rate has dropped to
    # 0.5**2 = 1/4 of its initial value.
    decay = 0.5**(2./epochs)

    all_accuracies = []
    for run, idx in enumerate(run_indices):
        print(f'Starting run {run:d}')
        for cv in range(10):
            train_idx, test_idx = split_fold(idx, cv)
            # Incomplete trailing batches are dropped during training.
            n_batches = len(train_idx) // batch_size

            # 1 feature column + 2 segment-index columns, i.e. the width of
            # build_batch's output (presumably the model's input size).
            model, regressor = get_model(1+2, args['inducing_points'])
            th = TrainHelper(model, regressor, lr=args['lr'], inv_iterations=args['inv_iterations'])

            # TRAINING
            pbar = tqdm(range(epochs))
            loss = np.inf
            for epoch in pbar:
                np.random.shuffle(train_idx)
                for b in range(n_batches):
                    batch_idx = train_idx[b*batch_size:(b+1)*batch_size]

                    X_batch = build_batch(adj, batch_idx)
                    # Map the 0/1 labels to -1/+1 targets.
                    y_batch = np.array(2.*labels[batch_idx]-1.).astype(np.float32)
                    # Only train on batches that contain both classes.
                    if len(set(y_batch)) == 2:
                        loss = th.train_step(X_batch, y_batch)
                    pbar.set_description(f'L={loss:.3f}')
                th.opt.learning_rate = th.opt.learning_rate * decay

            # TESTING
            n_test_batches = int(np.ceil(len(test_idx)/batch_size))
            all_preds = []
            for b in range(n_test_batches):
                batch_idx = test_idx[b*batch_size:(b+1)*batch_size]
                X_batch = build_batch(adj, batch_idx)
                y_pred, K_pred = th.predict(X_batch)
                all_preds.append(y_pred)
            # NOTE(review): targets are -1/+1 during training, yet predictions
            # are thresholded at 0.5 here -- confirm th.predict's output range.
            accuracy = np.mean((np.concatenate(all_preds) > 0.5) == labels[test_idx])*100
            all_accuracies.append(accuracy)
            print(f'Test accuracy {accuracy:.2f}')

    if all_accuracies:
        print(f'Mean test accuracy {np.mean(all_accuracies):.2f} +- {np.std(all_accuracies):.2f}')
