from colorama import Fore, Style

# Short aliases for colorama's ANSI colour codes and the reset sequence.
RED, GREEN, BLUE, YELLOW = Fore.RED, Fore.GREEN, Fore.BLUE, Fore.YELLOW
RESET = Style.RESET_ALL

import json
import numpy as np
import os
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm

# custom lib
from gplib import TrainHelper
from models import get_model

# Expose only physical GPU 1 to this process.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if TF
# has not initialised CUDA yet -- consider moving it above that import.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

vocab_size = 20000  # Only consider the top 20k words
maxlen = 200  # Fixed review length: long reviews keep their first and last 100 tokens


def _truncate_or_pad(seq, length):
    """Return token sequence ``seq`` as a 1-D array of exactly ``length`` entries.

    Sequences longer than ``length`` keep their first ``length // 2`` tokens and
    their last ``length - length // 2`` tokens; shorter ones are right-padded
    with 0. Works for Python lists and NumPy arrays alike (unlike ``list + list``,
    which would add element-wise if ``seq`` were an ndarray).
    """
    seq = np.asarray(seq)
    if len(seq) > length:
        head = length // 2
        tail = length - head
        # Non-negative start index so the tail slice is exactly `tail` long.
        return np.concatenate((seq[:head], seq[len(seq) - tail:]))
    return np.concatenate((seq, np.zeros(length - len(seq), dtype=seq.dtype)))


(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=vocab_size)
X_train = np.stack([_truncate_or_pad(x, maxlen) for x in X_train]).astype(np.float32)
X_test = np.stack([_truncate_or_pad(x, maxlen) for x in X_test]).astype(np.float32)
# Map the {0, 1} sentiment labels returned by load_data to {-1, +1} targets.
y_train = (2 * y_train - 1).astype(np.float32)
y_test = (2 * y_test - 1).astype(np.float32)

def _train(th, X, y, batch_size, epochs):
    """Train ``th`` with shuffled mini-batches and return the per-step losses.

    Each epoch reshuffles the sample order with NumPy's global RNG and drops
    the trailing partial batch. After every epoch the optimizer learning rate
    is multiplied by ``0.5 ** (2 / epochs)``, i.e. it halves every
    ``epochs / 2`` epochs.
    """
    decay = 0.5 ** (2. / epochs)
    idx = np.arange(len(X))
    n_batches = len(idx) // batch_size
    losses = []
    pbar = tqdm(range(epochs))
    for _ in pbar:
        np.random.shuffle(idx)
        for b in range(n_batches):
            batch_idx = idx[b * batch_size:(b + 1) * batch_size]
            losses.append(th.train_step(X[batch_idx], y[batch_idx]))
            # Progress bar shows the mean loss over the last 50 steps.
            pbar.set_description('{:.2f}'.format(np.mean(losses[-50:])))
        th.opt.lr = th.opt.lr * decay
    return losses


def _predict_in_batches(th, X, batch_size):
    """Run ``th.predict`` over ``X`` in batches and concatenate its two outputs.

    Returns the flattened first and second outputs of ``th.predict``. The last
    batch may be smaller than ``batch_size``.
    """
    first, second = [], []
    for start in range(0, len(X), batch_size):
        mean, k = th.predict(X[start:start + batch_size])
        first.append(mean.ravel())
        second.append(k.ravel())
    return np.concatenate(first), np.concatenate(second)


# Sweep over the number of inducing points; each setting is trained `runs` times.
for ippp in [2, 4, 8, 16, 32, 64, 128, 256, 512]:
    args = {
        "inv_iterations": 5,
        "lr": 1e-4,
        "inducing_points": ippp,
        "batch_size": 32,
        "epochs": 50,
        "id_gpu": 0,  # NOTE(review): not read below; GPU comes from CUDA_VISIBLE_DEVICES
        "runs": 5,
        "seed": 12345,
    }
    # Seed both NumPy (batch shuffling) and TensorFlow's global seed so the
    # whole sweep for this setting is reproducible.
    np.random.seed(args['seed'])
    tf.random.set_seed(args['seed'])

    for run in range(args['runs']):
        os.makedirs(f'results/run_{run:d}', exist_ok=True)
        print(f'Starting run {run:d}')

        # Same preprocessing as Chen & Zheng, "Stochastic Gradient Descent in
        # Correlated Settings: A Study on Gaussian Processes".
        model, regressor = get_model(X_train.shape[1:], args['inducing_points'])
        th = TrainHelper(model, regressor, lr=args['lr'], inv_iterations=args['inv_iterations'])

        # TRAINING
        _train(th, X_train, y_train, args['batch_size'], args['epochs'])

        # TESTING
        y_pred, K_pred = _predict_in_batches(th, X_test, args['batch_size'])

        # Targets are in {-1, +1}, so the decision boundary of the regression
        # output is 0 (0.5 would only be the midpoint for {0, 1} labels).
        accuracy = 100 * np.mean((y_pred.ravel() > 0) == (y_test.ravel() > 0))
        np.save(f'results/run_{run:d}/accuracy_{ippp:d}.npy', accuracy)
        np.save(f'results/run_{run:d}/K_pred_{ippp:d}.npy', K_pred)

        # NOTE(review): labelled "std" but this is the mean of K_pred (the second
        # output of th.predict) -- confirm whether it is a variance or a std.
        print(f'{run:d} {ippp:d} - Test accuracy {accuracy:.2f} - std: {K_pred.mean():.2f}')
        print('-'*50)