"""
Module for plotting functions used in the project.
"""
import matplotlib.pyplot as plt
import numpy as np


def plot_task_outputs(X, y, sigma_sq, task_name, feature_idx=0, save_path=None):
    """Plot task outputs and a ±2σ noise band as a function of one input feature.

    Points are sorted by the chosen feature so the shaded band drawn with
    ``fill_between`` runs continuously from left to right.

    Args:
        X (np.ndarray): Input features, shape ``(n,)`` for a single feature or
            ``(n, d)`` for ``d`` features.
        y (np.ndarray): Output values, shape ``(n,)`` or ``(n, 1)``.
        sigma_sq (np.ndarray or float): Noise variances, shape ``(n,)`` or
            ``(n, 1)``, or a single scalar shared by every point.
        task_name (str): Name of the task for plot title
        feature_idx (int): Column of ``X`` to plot against (must be 0 when
            ``X`` is 1-D).
        save_path (str, optional): If provided, save the plot to this path and
            close the figure; otherwise display it with ``plt.show()``.
    """
    # Prepare the data before creating the figure so invalid inputs do not
    # leave an empty figure open.
    X = np.asarray(X)
    if X.ndim == 1:
        # Treat a 1-D input as a single-column design matrix.
        X = X[:, np.newaxis]
    x = X[:, feature_idx]
    # Flatten column vectors: an (n, 1) array combined with an (n,) array
    # would otherwise broadcast to (n, n) and break fill_between.
    y = np.ravel(y)
    # A scalar variance is broadcast to one value per point.
    sigma = np.sqrt(np.broadcast_to(np.ravel(sigma_sq), x.shape))

    # Sort by the feature we're plotting against
    sort_idx = np.argsort(x)
    X_sorted = x[sort_idx]
    y_sorted = y[sort_idx]
    sigma_sorted = sigma[sort_idx]

    fig = plt.figure(figsize=(10, 6))

    # Plot the data points
    plt.scatter(X_sorted, y_sorted, c='b', s=10, label=f'{task_name} output')

    # Plot the uncertainty bands. For Gaussian noise, ±2σ covers ≈95.4% of
    # the mass, which the legend label rounds to 95%.
    plt.fill_between(
        X_sorted,
        y_sorted - 2 * sigma_sorted,
        y_sorted + 2 * sigma_sorted,
        color='r',
        alpha=0.2,
        label='95% confidence interval'
    )

    plt.xlabel(f'Input feature {feature_idx}')
    plt.ylabel(f'{task_name} output')
    plt.title(f'{task_name} vs Input Feature {feature_idx}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        try:
            fig.savefig(save_path)
        finally:
            # Release the figure even if saving fails (e.g. unwritable path).
            plt.close(fig)
    else:
        plt.show()


def plot_task_correlation(y1, y2, task1_name='Task 1', task2_name='Task 2', save_path=None):
    """Scatter two tasks' outputs against each other, titled with Pearson's ρ.

    Args:
        y1 (np.ndarray): Output values for first task, shape ``(n,)`` or
            ``(n, 1)``.
        y2 (np.ndarray): Output values for second task, same number of
            elements as ``y1``.
        task1_name (str): Name of first task (x-axis label)
        task2_name (str): Name of second task (y-axis label)
        save_path (str, optional): If provided, save the plot to this path and
            close the figure; otherwise display it with ``plt.show()``.

    Note:
        If either task's outputs are constant, the correlation is undefined;
        ``np.corrcoef`` then yields ``nan``, which appears in the title.
    """
    # np.corrcoef treats each *row* as a variable. Flatten the inputs so that
    # (n, 1) column vectors are compared as two variables with n observations,
    # rather than as 2n one-observation variables.
    y1 = np.ravel(y1)
    y2 = np.ravel(y2)

    # Calculate correlation coefficient
    corr = np.corrcoef(y1, y2)[0, 1]

    fig = plt.figure(figsize=(8, 8))
    plt.scatter(y1, y2, c='b', s=10, alpha=0.5)
    plt.xlabel(task1_name)
    plt.ylabel(task2_name)
    plt.title(f'Task Correlation (ρ = {corr:.2f})')
    plt.grid(True, alpha=0.3)
    # Match plot_task_outputs so labels are not clipped when saved.
    plt.tight_layout()

    if save_path:
        try:
            fig.savefig(save_path)
        finally:
            # Release the figure even if saving fails.
            plt.close(fig)
    else:
        plt.show()