from colorama import Fore, Style

# Short module-level aliases for colorama's ANSI foreground colours and the
# sequence that resets all styling back to the terminal default.
RED, GREEN, BLUE, YELLOW = Fore.RED, Fore.GREEN, Fore.BLUE, Fore.YELLOW
RESET = Style.RESET_ALL

import json
import numpy as np
import os
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm

# custom lib
from gplib import TrainHelper

# Experiment hyper-parameters, read by the __main__ block below.
args = dict(
    inv_iterations=5,     # forwarded to TrainHelper(inv_iterations=...)
    lr=1e-4,              # initial learning rate given to TrainHelper; decayed every epoch
    inducing_points=512,  # forwarded to get_model(...)
    batch_size=32,        # mini-batch size for both training and prediction
    epochs=10,            # training epochs per run
    id_gpu=3,             # exported as CUDA_VISIBLE_DEVICES; a negative value hides all GPUs
    runs=5,               # number of independent train/evaluate repetitions
    seed=12345,           # random seed set once before any run
)

from models import get_model

def fit_length(seq, maxlen):
    """Return *seq* as a float32 vector of exactly *maxlen* token ids.

    Sequences longer than *maxlen* keep their first ``maxlen // 2`` and their
    last ``maxlen - maxlen // 2`` tokens, dropping the middle.  Shorter (or
    equal-length) sequences are right-padded with 0.

    The tail length is computed as ``maxlen - maxlen // 2`` rather than by
    slicing with ``-maxlen // 2`` (which floors to a larger magnitude), so the
    output length is also correct for odd *maxlen*.
    """
    tokens = list(seq)
    if len(tokens) > maxlen:
        head = maxlen // 2
        tail = maxlen - head
        # Slice the tail by absolute start index so that tail == 0 yields [].
        tokens = tokens[:head] + tokens[len(tokens) - tail:]
    else:
        tokens = tokens + [0] * (maxlen - len(tokens))
    return np.asarray(tokens, dtype=np.float32)


def binary_accuracy(y_pred, y_true, threshold=0.0):
    """Return the percentage of predictions on the correct side of a threshold.

    Args:
        y_pred: predicted scores; a sample counts as positive when its score
            is strictly greater than *threshold*.
        y_true: reference labels; a sample is positive when its label is > 0.
        threshold: decision threshold applied to *y_pred*.  The default 0.0
            is the midpoint of the -1/+1 targets used by this script.

    Returns:
        Accuracy in percent, or ``nan`` when there are no samples.
    """
    y_pred = np.ravel(y_pred)
    y_true = np.ravel(y_true)
    if y_pred.size == 0:
        return float('nan')
    return 100.0 * np.mean((y_pred > threshold) == (y_true > 0))


if __name__ == '__main__':
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before
    # TensorFlow initialises its GPU devices.  tensorflow, gplib and models are
    # imported above this point, so confirm none of them touches a device at
    # import time (or move this assignment above those imports).
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args['id_gpu']) if args['id_gpu'] >= 0 else ''

    # Set seeds for reproducibility: NumPy drives the batch shuffling below,
    # TensorFlow's global seed covers TF random ops.
    np.random.seed(args['seed'])
    tf.random.set_seed(args['seed'])

    vocab_size = 20000  # Only consider the top 20k words
    maxlen = 200  # Every review is cut or padded to exactly 200 tokens
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=vocab_size)

    # Same preprocessing as
    #   @article{Chen_Zheng,
    #     title={Stochastic Gradient Descent in Correlated Settings:
    #            A Study on Gaussian Processes},
    #     author={Chen, Hao and Zheng, Lili}, pages={12} }
    # Each review becomes a fixed-length vector (head + tail tokens, or
    # zero-padded), and labels are mapped via 2*y - 1 to -1/+1 targets.
    X_train = np.stack([fit_length(x, maxlen) for x in X_train])
    X_test = np.stack([fit_length(x, maxlen) for x in X_test])
    y_train = (2 * y_train - 1).astype(np.float32)
    y_test = (2 * y_test - 1).astype(np.float32)

    accuracies = []
    for run in range(args['runs']):
        print(f'Starting run {run:d}')

        model, regressor = get_model(X_train.shape[1:], args['inducing_points'])
        th = TrainHelper(model, regressor, lr=args['lr'], inv_iterations=args['inv_iterations'])

        batch_size = args['batch_size']
        epochs = args['epochs']
        # Multiplicative per-epoch decay: the learning rate halves every
        # epochs / 2 epochs, ending at 1/4 of its initial value.
        decay = 0.5**(2./epochs)

        # TRAINING
        pbar = tqdm(range(epochs))
        idx = np.arange(len(X_train))
        # Only full batches are used; the few leftover samples of each
        # (reshuffled) epoch are skipped.
        n_batches = len(idx) // batch_size
        losses = []
        for _ in pbar:
            np.random.shuffle(idx)
            for b in range(n_batches):
                batch_idx = idx[b*batch_size:(b+1)*batch_size]
                X_batch, y_batch = X_train[batch_idx], y_train[batch_idx]
                loss = th.train_step(X_batch, y_batch)
                losses.append(loss)
                # Progress bar shows the mean of the last 50 batch losses.
                pbar.set_description('{:.2f}'.format(np.mean(losses[-50:])))
            # NOTE(review): assumes th.opt exposes a writable `lr`; for Keras
            # optimizers `lr` is a legacy alias of `learning_rate` -- confirm
            # this assignment really changes the step size.
            th.opt.lr = th.opt.lr * decay

        # TESTING: predict in batches, including the final partial batch.
        y_pred = []
        K_pred = []
        for start in range(0, len(X_test), batch_size):
            batch_pred, batch_K = th.predict(X_test[start:start + batch_size])
            y_pred.append(np.ravel(batch_pred))
            K_pred.append(np.ravel(batch_K))
        y_pred = np.concatenate(y_pred)
        K_pred = np.concatenate(K_pred)  # second output of th.predict; not used below

        # Targets are -1/+1, so predictions are classified at the midpoint 0.
        accuracy = binary_accuracy(y_pred, y_test, threshold=0.0)
        accuracies.append(accuracy)

        print(f'Test accuracy {accuracy:.2f}')
        print('-'*50)

    if accuracies:
        print(f'Test accuracy over {len(accuracies)} runs: '
              f'{np.mean(accuracies):.2f} +/- {np.std(accuracies):.2f}')
