from colorama import Fore, Style

# ANSI colour shortcuts used to highlight the final summary lines.
RED, GREEN, BLUE, YELLOW = Fore.RED, Fore.GREEN, Fore.BLUE, Fore.YELLOW
RESET = Style.RESET_ALL

import json
import numpy as np
import os
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm
from time import time
# custom lib
from benchmark_datasets import BenchmarkDataset
from config import get_config
from gplib import TrainHelper
from models import get_model

def loglikelihood(y, K):
    """Return the Gaussian log-likelihood of ``y`` under N(0, K), averaged per entry.

    Computes ``log N(y | 0, K) / n`` with ``n = len(y)``.

    Args:
        y: array holding ``n`` values (any shape; it is flattened).
        K: ``(n, n)`` covariance matrix.

    Returns:
        A ``(1, 1)`` array with the averaged log-likelihood. It holds NaN when
        ``K`` is not positive definite (non-positive determinant sign) or is
        singular, because no Gaussian density exists in that case.
    """
    y = np.asarray(y).ravel()
    n = len(y)
    sign, logdet = np.linalg.slogdet(K)
    # slogdet returns log|det K|; with a non-positive sign that value is not
    # the log-determinant of a valid covariance, so report NaN instead.
    if sign <= 0:
        return np.full((1, 1), np.nan)
    try:
        # Solve K @ alpha = y instead of forming K^{-1} explicitly: the same
        # quadratic form y^T K^{-1} y, but better conditioned.
        alpha = np.linalg.solve(K, y)
    except np.linalg.LinAlgError:
        return np.full((1, 1), np.nan)
    t1 = -0.5 * (y[None] @ alpha[:, None]) / n  # (1, 1) quadratic term
    t2 = -0.5 * logdet / n
    t3 = -0.5 * np.log(np.pi * 2)
    return t1 + t2 + t3

if __name__ == '__main__':
    # Benchmark driver: for each of args['runs'] runs, load the dataset split,
    # standardize it with training statistics, train the model from `get_model`,
    # then report test RMSE, test log-likelihood and training time per epoch.
    args = get_config()
    # A negative GPU id hides every GPU (CPU-only run).
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args['id_gpu']) if args['id_gpu'] >= 0 else ''

    # Context manager so the config file handle is closed promptly.
    with open('data_config.json') as config_file:
        data_config = json.load(config_file)
    dataset_name = args['dataset']
    # set seed for reproducibility
    # NOTE(review): only NumPy's global RNG is seeded (it drives the batch
    # shuffling below); confirm whether TensorFlow also needs seeding.
    np.random.seed(data_config[dataset_name]['seed'])

    all_rmse, all_vars, all_ll, all_tts = [], [], [], []
    for run in range(args['runs']):
        print(f'Starting run {run:d}')
        dataset = BenchmarkDataset(dataset_name, data_config)
        (X_train, y_train), (X_valid, y_valid), (X_test, y_test) = dataset.get_data()

        # Same preprocessing as @article{Chen_Zheng,
        # title={Stochastic Gradient Descent in
        # Correlated Settings: A Study on Gaussian Processes},
        # author={Chen, Hao and Zheng, Lili}, pages={12} }
        # Inputs and targets are z-scored using training statistics only.
        MX, SX = X_train.mean(axis=0, keepdims=True), X_train.std(axis=0, keepdims=True)
        MY, SY = y_train.mean(axis=0, keepdims=True), y_train.std(axis=0, keepdims=True)
        # A constant column has zero std; divide it by 1 instead so it maps
        # to 0 rather than NaN/inf, which would poison training.
        SX = np.where(SX == 0, 1.0, SX)
        SY = np.where(SY == 0, 1.0, SY)

        X_train = (X_train - MX) / SX
        y_train = (y_train - MY) / SY

        X_test = (X_test - MX) / SX
        y_test = (y_test - MY) / SY

        model, regressor = get_model(X_train.shape[-1], args['inducing_points'])
        th = TrainHelper(model, regressor, lr=args['lr'], inv_iterations=args['inv_iterations'])

        batch_size = args['batch_size']
        epochs = args['epochs']
        # Per-epoch multiplicative factor: after `epochs` epochs the learning
        # rate has been scaled by 0.5**2 = 0.25.
        decay = 0.5**(2./epochs)

        # TRAINING
        pbar = tqdm(range(epochs))
        idx = np.arange(len(X_train))
        # Samples that do not fill a whole batch are dropped for that epoch;
        # the per-epoch reshuffle changes which ones are dropped.
        n_batches = len(idx) // batch_size
        losses = []
        tin = time()
        for epoch in pbar:
            np.random.shuffle(idx)
            for b in range(n_batches):
                batch_idx = idx[b*batch_size:(b+1)*batch_size]
                X_batch, y_batch = X_train[batch_idx], y_train[batch_idx]
                loss = th.train_step(X_batch, y_batch)
                losses.append(loss)
                # Progress bar shows the mean of the last 50 batch losses.
                pbar.set_description('{:.2f}'.format(np.mean(losses[-50:])))
            th.opt.lr = th.opt.lr * decay
        tout = time()
        all_tts.append((tout-tin)/epochs)

        # TESTING
        idx = np.arange(len(X_test))
        n_batches = int(np.ceil(len(idx) / batch_size))
        y_pred = []
        K_pred = []
        ll = []
        for b in range(n_batches):
            batch_idx = idx[b*batch_size:min((b+1)*batch_size, len(X_test))]
            yp, Kp = th.predict(X_test[batch_idx])
            # Test log-likelihood of the targets under the predictive
            # distribution y ~ N(yp, Kp), i.e. of the residual under N(0, Kp).
            # NOTE(review): if Kp is the latent covariance without observation
            # noise this is not the full predictive density - confirm.
            residual = y_test[batch_idx].ravel() - yp.ravel()
            ll_ = np.sum(loglikelihood(residual, Kp))
            # Batches with a NaN log-likelihood (e.g. from an ill-conditioned
            # Kp) are left out of the average.
            if not np.isnan(ll_):
                ll.append(ll_)
            y_pred.append(yp.ravel())
            K_pred.append(np.diag(Kp).ravel())
        y_pred = np.concatenate(y_pred)
        K_pred = np.concatenate(K_pred)
        # RMSE is measured in standardized target units.
        rmse = np.sqrt(((y_pred.ravel() - y_test.ravel())**2).mean())
        all_rmse.append(rmse)
        all_vars.append(K_pred.mean())
        all_ll.append(np.mean(ll))
        print(all_ll)
        print('-'*50)
    print(all_rmse)
    # np.save does not create missing parent directories.
    os.makedirs('ablation', exist_ok=True)
    np.save(f"ablation/{dataset_name}_{args['inducing_points']}.npy", all_vars)
    print(f"{GREEN}RMSE {np.mean(all_rmse):.3f}±{np.std(all_rmse):.3f} - LL={np.mean(all_ll):.3f}±{np.std(all_ll):.3f} with {args['inducing_points']} inducing points.{RESET}")
    print(f"{GREEN}Time per epoch {np.mean(all_tts):.3f}±{np.std(all_tts):.3f}{RESET}")
    print(f"{YELLOW}SOTA was RMSE {data_config[dataset_name]['sota'][0]:.3f}±{data_config[dataset_name]['sota'][1]:.3f}.{RESET}")
        
