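'''
Train and test a Heteroscedastic Variational Bayesian Last Layers (HVBLL) model.

The 1-D dataset is selected by FUNCTION_NAME and N_SAMPLE. During training the
script periodically plots the model's predictions, saves the plotted data to
CSV, and reports the predictive negative log-likelihood (NLL) on the test set.
'''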
import os
import sys
# Make modules in the current working directory importable (presumably the
# `hvbll` package and `utils` imported below — verify against the repo layout).
sys.path.append(os.curdir)
# Directory of the running script; the results folder is created beneath it.
path0 = os.path.split(sys.argv[0])[0]

import time
import numpy as np
import torch
from torch.utils.data import DataLoader

from hvbll.basic import log, get_lr, loss_history
from hvbll.hvbll import HVBLL

from utils import plot_1d, prepare_dataset, save_to_csv


# CUDA device index passed to prepare_dataset() and net.cuda(); the main block
# rebinds it to -1 when CUDA is unavailable.
GPU_ID: int = 0
# Sample count passed to prepare_dataset(); also embedded in the results folder
# name and the plot/CSV file names.
N_SAMPLE: int = 200

# Name of the target function passed to prepare_dataset() to build the dataset.
FUNCTION_NAME: str = "xsin-lin"


if __name__ == '__main__':

    #* ===========================================================
    #* Preparation
    #
    # `pM` is the hyper-parameter accessor returned by prepare_dataset(); it is
    # called with a key name, e.g. pM('lr').
    pM, fname_log, fname_loss, dataset, test_set, plot_y_min, plot_y_max \
        = prepare_dataset(FUNCTION_NAME, N_SAMPLE, GPU_ID, net_name='HVBLL')

    dataloader = DataLoader(dataset, batch_size=pM('batch_size'),
                            shuffle=True, drop_last=False)

    # All outputs go to a per-sample-size folder next to this script.
    path = os.path.join(path0, 'results-N%d' % N_SAMPLE)
    os.makedirs(path, exist_ok=True)

    # Replace the loss-history file name from prepare_dataset with one that is
    # keyed by the Wishart scale.
    fname_loss = os.path.join(path, f"loss-HVBLL-{pM('wishart_scale')}.dat")

    #* Model Setup
    net = HVBLL(dim_input=pM('dim_input'), dim_output=pM('dim_output'),
                dim_latent=pM('dim_latent'), dim_hidden=pM('dim_hidden'),
                dim_hidden_noise=pM('dim_hidden_noise'),
                n_hidden_layers=pM('n_layer_latent'),
                n_noise_layers=pM('n_layer_noise'),
                reg_weight_latent=pM('reg_weight_latent'),
                reg_weight_noise=pM('reg_weight_noise'),
                covariance_type=pM('covariance_type'),
                prior_scale=pM('prior_scale'),
                wishart_scale=pM('wishart_scale'), dof=pM('dof'))

    if torch.cuda.is_available():
        net.cuda(GPU_ID)
    else:
        # In this script -1 only appears in the log messages, flagging a CPU run.
        GPU_ID = -1

    #* Optimizer
    optimizer = torch.optim.AdamW(net.parameters(), lr=pM('lr'),
                                  betas=(pM('b1'), pM('b2')))
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, gamma=0.8, step_size=pM('step_size'))

    #* ==================================================================================
    #* Model training

    # Loop-invariant hyper-parameters, read once.
    n_epochs = pM('n_epochs')
    n_fig = pM('n_fig')
    n_show = pM('n_show')
    clip_grad = pM('clip_grad')
    min_lr = pM('min_lr')

    log('Model training: GPU ID = %d' % GPU_ID, fname=fname_log)

    t0 = time.perf_counter()
    # Test NLL is only recomputed on plotting epochs; between them the last
    # value is reported.
    info = {'test_nll': 0.0}

    for epoch in range(n_epochs):

        info['epoch'] = epoch
        info['loss'] = []
        info['nll'] = []
        info['kl_term'] = []
        info['wishart_term'] = []

        # Log-line prefix: '= ' marks epochs on which a plot / test NLL was produced.
        prefix = '> '

        # Plot and test once (after the first batch) every n_fig epochs and on
        # the last epoch.
        save_prediction = (epoch % n_fig == 0) or (epoch == n_epochs - 1)

        # The prediction branch below toggles eval mode; always train in train mode.
        net.train()

        for xs, ys in dataloader:

            optimizer.zero_grad()

            out = net.forward(xs)

            # train_loss_fn returns a dict of loss components; the negative
            # total ELBO is the quantity that is optimized.
            result = out.train_loss_fn(ys)
            loss = result['neg_total_elbo']
            loss.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)

            optimizer.step()

            info['loss'].append(loss.item())
            info['nll'].append(result['nll'].item())
            info['kl_term'].append(result['kl_term'].item())
            info['wishart_term'].append(result['wishart_term'].item())

            #* Save field and plot
            if save_prediction:

                save_prediction = False
                prefix = '= '

                # Plot and test in evaluation mode; training mode is restored
                # right after so later optimizer steps are not taken in eval mode.
                net.eval()

                # torch.save(net.state_dict(), os.path.join(path,'parameters-%d.pth'%(epoch)))

                # Plot the figure
                plot_name = f"N{N_SAMPLE}-sigma0-{pM('wishart_scale')}"

                fig_data = plot_1d(dataset, net,
                                   y_min=plot_y_min, y_max=plot_y_max,
                                   path=path, name=plot_name,
                                   return_data=True)

                # Save data to CSV
                if fig_data:
                    csv_filename = os.path.join(path, f"{net.name}-{plot_name}.csv")
                    save_to_csv(csv_filename, fig_data)
                    log(f"Saved figure data to {csv_filename}", fname=fname_log)

                #* Test the model (the predictive negative log likelihood (NLL) of test data)
                with torch.no_grad():
                    out = net.forward(test_set.X)
                    test_nll = out.val_loss_fn(test_set.Y)
                    info['test_nll'] = test_nll.item()

                net.train()

        t1 = time.perf_counter()

        learning_rate = get_lr(optimizer)

        # Replace the per-batch lists with their epoch means.
        for key in ('loss', 'nll', 'kl_term', 'wishart_term'):
            info[key] = np.mean(info[key])

        info['time'] = t1 - t0  # cumulative wall time since training started
        info['lr'] = learning_rate

        if epoch % n_show == 0:

            #* Print information
            text = "[Epoch %5d] | [TRAIN] -ELBO:%10.3E NLL:%10.3E KL_W:%10.3E KL_N:%10.3E | [TEST] NLL:%10.3E | t= %.1fm | lr= %.2E | @%2d"%(
                epoch, info['loss'], info['nll'], info['kl_term'], info['wishart_term'],
                info['test_nll'], (t1-t0)/60.0, info['lr'], GPU_ID)

            log(text, prefix=prefix, fname=fname_log)

            loss_history(epoch, info, fname_loss)

        #* --------------------------------------------------
        #* Monitor training process

        # Step-decay the learning rate only while it is above the floor min_lr.
        if learning_rate > min_lr:
            lr_scheduler.step()
