from colorama import Fore, Style

# Short aliases for the colorama foreground colours and the reset-all style code.
RED = Fore.RED
GREEN = Fore.GREEN
BLUE = Fore.BLUE
YELLOW = Fore.YELLOW
RESET = Style.RESET_ALL

import json
import numpy as np
import os
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm
import cv2

# custom lib
from gplib import TrainHelper, probit_

# Experiment configuration read by the script entry point below.
args = dict(
    inv_iterations=5,    # forwarded to TrainHelper(inv_iterations=...)
    lr=1e-4,             # initial learning rate forwarded to TrainHelper(lr=...)
    inducing_points=16,  # second argument of get_model(...)
    batch_size=32,       # mini-batch size for both training and prediction
    epochs=10,           # passes over the training set per one-vs-rest model
    id_gpu=1,            # written to CUDA_VISIBLE_DEVICES; negative -> empty string
    runs=10,             # number of repetitions of the whole experiment
    seed=12345,          # seed passed to np.random.seed
)

from models import get_model

def run_numbers(n_runs):
    """Return the 1-based run identifiers ``1..n_runs`` (inclusive)."""
    # The previous ``range(1, n_runs)`` silently executed one run too few.
    return range(1, n_runs + 1)


def one_vs_rest_targets(labels, positive_label):
    """Return float32 targets: +1 where ``labels == positive_label``, -1 elsewhere."""
    return np.where(labels == positive_label, 1.0, -1.0).astype(np.float32)


def standardize(train, test):
    """Rescale both splits by 1/255, then standardise with the training statistics.

    Returns ``(train, test, mean, std)`` where ``mean`` and ``std`` are the
    global (all-axes, ``keepdims=True``) mean and standard deviation of the
    rescaled training split.
    """
    train = train.astype(np.float32) / 255.
    test = test.astype(np.float32) / 255.
    mean = train.mean(keepdims=True)
    std = train.std(keepdims=True)
    return (train - mean) / std, (test - mean) / std, mean, std


def predict_in_batches(th, X, batch_size):
    """Call ``th.predict`` on consecutive slices of ``X`` and concatenate the results.

    Only the first two of the four values returned by ``th.predict`` are kept
    (the script stores them as ``*_mean.npy`` and ``*_var.npy``).
    """
    first_outputs, second_outputs = [], []
    for start in range(0, len(X), batch_size):
        y_pred_shift, K_pred_shift, _, _ = th.predict(X[start:start + batch_size])
        first_outputs.append(y_pred_shift)
        second_outputs.append(K_pred_shift)
    return np.concatenate(first_outputs), np.concatenate(second_outputs)


def main():
    """Train ten one-vs-rest models on MNIST per run and report test accuracy.

    For every run and every digit a fresh model is built with ``get_model`` and
    trained through ``TrainHelper``; its test predictions are saved under
    ``predictions/<run>/``. The multi-class prediction is the argmax over the
    ten per-digit outputs. The mean and standard deviation of the per-run
    accuracies are printed at the end.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args['id_gpu']) if args['id_gpu'] >= 0 else ''

    # set seed for reproducibility
    # NOTE: only NumPy's global RNG is seeded in this script.
    np.random.seed(args['seed'])

    batch_size = args['batch_size']
    epochs = args['epochs']
    # Learning-rate factor applied after every epoch.
    decay = 0.5**(2./epochs)

    # The data and its normalisation do not depend on the run or the label,
    # so load and standardise once instead of once per label per run.
    (X_train, digits_train), (X_test, digits_test) = tf.keras.datasets.mnist.load_data()
    digits_train = digits_train.ravel()
    digits_test = digits_test.ravel()
    X_train, X_test, mean, std = standardize(X_train, X_test)
    print(mean, std)

    n_batches = len(X_train) // batch_size  # a trailing partial batch is dropped

    all_accuracies = []
    for run in run_numbers(args['runs']):
        print(f'Starting run {run:d}')
        os.makedirs(f'predictions/{run:d}', exist_ok=True)

        per_label_outputs = []
        for label in range(10):
            # One-vs-rest: every sample is kept, only the targets change.
            y_train = one_vs_rest_targets(digits_train, label)

            # same preprocessing as  @article{Chen_Zheng,
            # title={Stochastic Gradient Descent in
            # Correlated Settings: A Study on Gaussian Processes},
            # author={Chen, Hao and Zheng, Lili}, pages={12} }
            model, regressor = get_model(X_train.shape[1:], args['inducing_points'])
            th = TrainHelper(model, regressor, lr=args['lr'], inv_iterations=args['inv_iterations'])

            # TRAINING
            pbar = tqdm(range(epochs))
            idx = np.arange(len(X_train))
            losses = []
            for epoch in pbar:
                np.random.shuffle(idx)
                for b in range(n_batches):
                    batch_idx = idx[b*batch_size:(b+1)*batch_size]
                    loss = th.train_step(X_train[batch_idx], y_train[batch_idx])
                    losses.append(loss)
                    # Progress bar shows the mean of the last 50 step losses.
                    pbar.set_description('{:.2f}'.format(np.mean(losses[-50:])))
                th.opt.lr = th.opt.lr * decay

            # TESTING
            y_preds, K_preds = predict_in_batches(th, X_test, batch_size)

            np.save(f'predictions/{run:d}/{label:d}_mean.npy', y_preds)
            np.save(f'predictions/{run:d}/{label:d}_var.npy', K_preds)
            per_label_outputs.append(y_preds)

        # Combine the ten binary models: predicted digit = label with the
        # largest saved "mean" output (same arrays that were just written).
        P = np.array(per_label_outputs).T
        FP = np.argmax(P, axis=1)
        accuracy = np.mean(FP == digits_test) * 100
        print(f'Test accuracy {accuracy:.2f}')
        all_accuracies.append(accuracy)
    print(np.mean(all_accuracies), np.std(all_accuracies))


if __name__ == '__main__':
    main()
