import os
import sys
# Add the current working directory to the import path (presumably so the
# local `hvbll` package is importable when run from the repo root — confirm).
sys.path.append('.')
# Directory of the invoked script. prepare_dataset() reads parameter files from
# here and writes its log / loss files here.
path = os.path.dirname(sys.argv[0])

import torch
import numpy as np
import matplotlib.pyplot as plt
import csv

from hvbll.basic import Parameters, init_log, log
from hvbll.toy_functions import SimpleFnDataset, ToyFn_Sin_Noise_Lin, ToyFn_Sin_Noise_Sin, ToyFn_Lin_Noise_Lin, ToyFn_Lin_Noise_Sin
from hvbll.hvbll import HVBLL


def _model_device(net) -> torch.device:
    """Return the device holding ``net``'s parameters.

    Falls back to the default CUDA device when available, else CPU, if the
    parameters cannot be inspected (no ``parameters()`` method or no parameters).
    """
    try:
        return next(net.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def plot_1d(dataset: SimpleFnDataset, net: HVBLL, y_min=None, y_max=None, path='./', name='figure', return_data=False):
    """
    Create and save a plot of a 1D regression fit with its uncertainty bands.

    The figure shows the training points, the ground-truth mean with a 2σ band
    from ``dataset.func_noise``, the predicted mean, and two predicted 2σ bands:
    aleatoric only, and total (from the predictive covariance).

    Args:
        dataset: Dataset object with training data (``X``, ``Y`` tensors) and the
            ground-truth callables ``func_mean`` / ``func_noise``.
        net: Trained model; ``net.name`` must be ``'HVBLL'`` or ``'VBLL'``.
            Its train/eval mode is restored after prediction.
        y_min: Lower y-axis limit (``None`` keeps matplotlib's automatic limit).
        y_max: Upper y-axis limit (``None`` keeps matplotlib's automatic limit).
        path: Directory to save the figure in.
        name: Base name for the figure; the file is ``{net.name}-{name}.png``.
        return_data: If True, return the data used for plotting.

    Returns:
        If ``return_data`` is True, a dict with keys ``'x_points'``, ``'y_true'``,
        ``'a_uncertainty_true'``, ``'y_pred'``, ``'y_std_pred'`` and
        ``'a_uncertainty_pred'`` holding per-point values on a 101-point grid.
        Otherwise None.

    Raises:
        ValueError: If ``net.name`` is neither ``'HVBLL'`` nor ``'VBLL'``.
    """
    # Training data
    X_train = dataset.X.cpu().numpy()
    Y_train = dataset.Y.cpu().numpy()

    # Evaluation grid: the training x-range padded by 10% on each side.
    x_min, x_max = X_train.min(), X_train.max()
    margin = 0.1 * (x_max - x_min)
    x_min -= margin
    x_max += margin

    n_point = 101
    x_points = np.linspace(x_min, x_max, n_point, endpoint=True)[..., None]  # (n_point, 1)

    # Ground truth on the numpy grid
    y_true = dataset.func_mean(x_points)[:, 0]
    a_uncertainty_true = dataset.func_noise(x_points)[:, 0]

    # Validate before touching the model so an unknown name leaves it untouched.
    if net.name not in ('HVBLL', 'VBLL'):
        raise ValueError('Unknown network name: %s'%(net.name))

    # Put inputs on the same device as the model (which may be a non-default GPU).
    x_tensor = torch.as_tensor(x_points, dtype=torch.float32, device=_model_device(net))

    # Remember the caller's mode so it is restored even if prediction raises.
    was_training = getattr(net, 'training', True)
    try:
        net.eval()
        with torch.no_grad():
            dist_y = net.forward(x_tensor).predictive
            y_pred = dist_y.mean.cpu().numpy().squeeze()

            # Total uncertainty (1 std)
            y_std_pred = torch.sqrt(dist_y.covariance.squeeze()).cpu().numpy()

            # Aleatoric uncertainty (1 std)
            out_layer = net.layers['out_layer']
            if net.name == 'HVBLL':
                # Input-dependent; the layer returns a variance, hence the sqrt.
                a_variance = out_layer.get_aleatoric_uncertainty(x_tensor).cpu().numpy()
                a_uncertainty_pred = np.sqrt(a_variance)[:, 0]
            else:  # 'VBLL': a single homoscedastic noise std, broadcast over the grid
                a_uncertainty_pred = out_layer.noise_std_numpy[0] * np.ones(n_point)
    finally:
        if was_training:
            net.train()

    x_points = x_points[:, 0]

    data = {
        'x_points': x_points,
        'y_true': y_true,
        'a_uncertainty_true': a_uncertainty_true,
        'y_pred': y_pred,
        'y_std_pred': y_std_pred,
        'a_uncertainty_pred': a_uncertainty_pred,
    }

    fig = plt.figure(figsize=(10, 6))
    try:
        # Training data, ground truth and predicted mean
        plt.scatter(X_train, Y_train, s=20, alpha=0.5, label='Training data')
        plt.plot(x_points, y_true, 'b-', label='Ground truth mean')
        plt.plot(x_points, y_pred, 'r-', label='Predicted mean')

        # 2σ confidence bands
        plt.fill_between(x_points, y_true - 2 * a_uncertainty_true, y_true + 2 * a_uncertainty_true, color='b', alpha=0.2, label='2σ confidence (ground truth)')
        plt.fill_between(x_points, y_pred - 2 * a_uncertainty_pred, y_pred + 2 * a_uncertainty_pred, color='r', alpha=0.4, label='2σ confidence (aleatoric)')
        plt.fill_between(x_points, y_pred - 2 * y_std_pred, y_pred + 2 * y_std_pred, color='g', alpha=0.4, label='2σ confidence (total)')

        # A None bound leaves that side at matplotlib's automatic limit.
        if y_min is not None or y_max is not None:
            plt.ylim(bottom=y_min, top=y_max)

        plt.legend()
        plt.title(f'1D Function Prediction - {name}')
        plt.xlabel('x')
        plt.ylabel('y')

        plt.savefig(os.path.join(path, f'{net.name}-{name}.png'), dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return data if return_data else None

def save_to_csv(filename, fig_data: dict):
    """
    Save plot data to a CSV file for later adjustment.

    Each key of ``fig_data`` becomes a column. The header row holds the keys in
    dict order. Row ``i`` holds the ``i``-th element of every column. The number
    of rows is set by the longest column; shorter columns are padded with empty
    cells.

    Args:
        filename: Path of the CSV file to (over)write.
        fig_data: Mapping from column name to a sized, indexable sequence
            (list, 1-D numpy array, ...). An empty mapping produces a file
            containing only an empty header line.
    """
    columns = list(fig_data.values())
    # Use the longest column (not just the first) so no data is truncated;
    # default=0 handles an empty mapping.
    n_rows = max((len(col) for col in columns), default=0)

    with open(filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(fig_data.keys())
        for i in range(n_rows):
            # None is written by csv as an empty cell.
            writer.writerow([col[i] if i < len(col) else None for col in columns])

def prepare_dataset(FUNCTION_NAME: str, N_SAMPLE: int, GPU_ID: int, net_name: str, only_get_training_data: bool=False):
    """
    Build the toy training/test datasets for one experiment. Unless only the
    training data is requested, also load its parameters and set up its log files.

    Uses the module-level ``path`` (the script directory) for the parameter,
    log and loss files.

    Args:
        FUNCTION_NAME: Toy problem: 'lin-lin', 'lin-sin', 'sin-lin' or 'sin-sin'.
        N_SAMPLE: Number of training samples.
        GPU_ID: GPU id forwarded to the dataset constructors.
        net_name: Network name used in the log / loss file names.
        only_get_training_data: If True, skip parameter loading and logging and
            return only the training arrays.

    Returns:
        If ``only_get_training_data``: ``(X_train, Y_train)`` as numpy arrays.
        Otherwise: ``(pM, fname_log, fname_loss, dataset, test_set,
        plot_y_min, plot_y_max)``.

    Raises:
        ValueError: If ``FUNCTION_NAME`` is not one of the known toy problems.
    """
    # name -> (dataset class, seed, noise_level, plot_y_min, plot_y_max)
    configs = {
        'lin-lin': (ToyFn_Lin_Noise_Lin, 0, 0.1, -0.2, 1.2),
        'lin-sin': (ToyFn_Lin_Noise_Sin, 1, 0.2, -0.2, 1.2),
        'sin-lin': (ToyFn_Sin_Noise_Lin, 0, 0.2, -0.8, 0.8),
        'sin-sin': (ToyFn_Sin_Noise_Sin, 1, 0.1, -0.8, 0.8),
    }

    # Validate first, so an unknown name fails before any parameter file is
    # read or log file is created.
    if FUNCTION_NAME not in configs:
        raise ValueError('Unknown function name: %s'%(FUNCTION_NAME))
    fn_class, seed, noise_level, plot_y_min, plot_y_max = configs[FUNCTION_NAME]

    if not only_get_training_data:

        # NOTE(review): the parameter section is always 'HVBLL', even when
        # net_name is 'VBLL' — confirm this is intended.
        pM = Parameters(os.path.join(path, 'parameters-%s.json'%(FUNCTION_NAME)), 'HVBLL')

        fname_log  = os.path.join(path, '%s-%s.log'%(net_name, FUNCTION_NAME))
        fname_loss = os.path.join(path, 'loss-%s-%s.dat'%(net_name, FUNCTION_NAME))

        init_log(path, fname=fname_log)

        log(' ', prefix='> ', fname=fname_log)
        log('Result folder = '+path, prefix='> ', fname=fname_log)

    #* Data Setup
    # NOTE(review): the 20-sample test set uses the same seed as the training
    # set; depending on how ToyFn_* draws samples, its points may overlap the
    # training points — confirm this is intended.
    dataset = fn_class(num_samples=N_SAMPLE, seed=seed, noise_level=noise_level, gpu_id=GPU_ID)
    test_set = fn_class(num_samples=20, seed=seed, noise_level=noise_level, gpu_id=GPU_ID)

    if only_get_training_data:
        X_train = dataset.X.cpu().numpy()
        Y_train = dataset.Y.cpu().numpy()
        return X_train, Y_train

    return pM, fname_log, fname_loss, dataset, test_set, plot_y_min, plot_y_max
    