import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os


# --- 3. Add Noise Function ---
def add_gaussian_noise(
    images: torch.Tensor,
    mean: float = 0.4,
    std_dev: float = 0.7
) -> torch.Tensor:
    """
    Adds Gaussian noise N(mean, std_dev^2) to images and clips to [0, 1].

    The noise is sampled directly on ``images.device`` (and with
    ``images.dtype`` for floating-point inputs), so CUDA tensors and
    non-default float dtypes work without a device/dtype mismatch.

    Args:
        images (torch.Tensor): Image tensor (expected values approx [0, 1]).
            The first dimension is treated as the batch dimension for logging.
        mean (float): Mean of the Gaussian noise.
        std_dev (float): Standard deviation of the Gaussian noise.

    Returns:
        torch.Tensor: A new tensor with added noise, clipped to [0, 1]. The
        input is not modified. An empty input is returned unchanged (the same
        object).
    """
    print(f"--- Adding Gaussian Noise (Mean={mean}, StdDev={std_dev}) ---")
    if images.numel() == 0:
        print("Warning: Cannot add noise to empty tensor.")
        return images

    # Integer images cannot hold Gaussian samples; use the default float dtype
    # for the noise in that case (the sum is then promoted to float).
    noise_dtype = images.dtype if images.is_floating_point() else torch.get_default_dtype()
    # Sample N(0, 1) on the input's device, then scale/shift to N(mean, std_dev^2).
    # torch.normal(mean, std, size) would allocate on the default (CPU) device.
    noise = torch.randn(images.shape, dtype=noise_dtype, device=images.device) * std_dev + mean

    noisy_images = torch.clamp(images + noise, 0.0, 1.0)  # Clip to valid pixel range [0,1]
    print(f"Noise added to {images.shape[0]} images.")
    print("-" * 30)
    return noisy_images


def visualize_samples(
    images: torch.Tensor,
    save_dir: str = None,
    file_name: str = None,
    num_images: int = 40,
    title: str = "Sample Images",
):
    """
    Visualizes and optionally saves a grid of sample images from a tensor.
    Handles both grayscale (1-channel) and RGB (3-channel) images.

    Images are laid out up to 8 per row; the number of rows grows with
    ``num_images`` (40 images -> 5 x 8 grid).

    Args:
        images: A PyTorch tensor of shape (N, C, H, W) containing images to visualize.
            C can be 1 (grayscale) or 3 (RGB). Tensors on any device, and tensors
            that require grad, are accepted.
        save_dir: Path to the directory where the image grid will be saved. If None, image is not saved.
            The directory is created if it does not exist.
        file_name: Name of the file for saving the visualized image grid. Required if save_dir is provided;
            if only one of save_dir/file_name is given, a warning is printed and nothing is saved.
        num_images: Maximum number of images to visualize. Defaults to 40.
        title: Title displayed above the visualization grid. Defaults to "Sample Images".

    Raises:
        ValueError: If the (non-empty) images tensor is not 4-dimensional.
    """
    print(f"--- Visualizing Samples: {title} ---")

    # Check if images tensor is valid
    if images.numel() == 0:
        print("Cannot visualize samples: Images tensor is empty.")
        print("-" * 30)
        return

    if len(images.shape) != 4:
        raise ValueError(f"Expected 4D tensor with shape (batch, channels, height, width), got shape {images.shape}")

    num_images = min(num_images, images.shape[0])
    if num_images <= 0:
        print("No images to visualize.")
        print("-" * 30)
        return

    # Check the number of channels to determine the image type
    num_channels = images.shape[1]
    if num_channels not in [1, 3]:
        print(f"Warning: Unexpected number of channels: {num_channels}. Expected 1 (grayscale) or 3 (RGB).")

    # Size the grid from num_images so that requesting more than 40 images
    # no longer indexes past the end of the axes array.
    ncols = min(8, num_images)
    nrows = (num_images + ncols - 1) // ncols  # ceil division

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 1.5, nrows * 1.5), squeeze=False)
    axes = axes.flatten()  # Flatten to easily iterate

    # detach() allows tensors that require grad (e.g. model outputs) to be
    # converted to NumPy; cpu() handles CUDA tensors.
    batch = images[:num_images].detach().cpu()

    for ax, img in zip(axes, batch):
        if num_channels == 1:
            # Grayscale: drop the channel dimension (1, H, W) -> (H, W).
            ax.imshow(img[0].numpy(), cmap='gray')
        else:
            # RGB: (C, H, W) -> (H, W, C) for matplotlib, clipped to imshow's float range.
            ax.imshow(img.permute(1, 2, 0).numpy().clip(0, 1))
        ax.axis('off')

    # Hide unused subplots
    for ax in axes[num_images:]:
        ax.axis('off')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()  # Adjust layout to prevent title overlap

    # Save the figure if save_dir and file_name are provided
    if save_dir and file_name:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            print(f"Created directory: {save_dir}")
        save_path = os.path.join(save_dir, file_name)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    elif save_dir or file_name:
        print("Warning: both save_dir and file_name are required to save the plot; not saving.")

    plt.close(fig)
    print("-" * 30)


def visualize_reconstructed_images(
        images1: np.ndarray,
        images2: np.ndarray,
        save_dir: str,
        file_name: str,
        num_to_show: int = 6,
        titles: tuple = ("Original Images", "Reconstructed Images"),
        figsize: tuple = (16, 6),  # Adjusted for horizontal layout
        cmap: str = 'gray'
):
    """
    Displays pairs of original and reconstructed images side-by-side for comparison
    and saves the figure.

    The top row contains original images and the bottom row the corresponding
    reconstructed images. Supports grayscale (1-channel) and RGB (3-channel)
    images in channels-first format.

    Args:
        images1 (np.ndarray): Batch of original images, shape (N, C, H, W).
        images2 (np.ndarray): Batch of reconstructed images; must have the same length
            and format as `images1`.
        save_dir (str): Directory to save the generated visualization. Created if missing.
        file_name (str): Filename for the saved visualization image (including extension).
        num_to_show (int, optional): Number of image pairs to display (capped at N). Defaults to 6.
        titles (tuple, optional): Title prefixes for the top and bottom rows. Each subplot
            is titled "<prefix> <index>". Defaults to ("Original Images", "Reconstructed Images").
        figsize (tuple, optional): Size of the figure in inches (width, height). Defaults to (16, 6).
        cmap (str, optional): Color map for displaying grayscale images. Defaults to 'gray'.
            Only used for grayscale images.

    Raises:
        ValueError: If `images1` and `images2` have different lengths.
        ValueError: If `num_to_show` is not a positive integer.
        ValueError: If the image arrays are empty.
    """
    # Validate inputs before creating any figure.
    if len(images1) != len(images2):
        raise ValueError(f"Image arrays must have the same length. Got {len(images1)} and {len(images2)}.")

    if num_to_show <= 0:
        raise ValueError("Number of images to show must be positive.")

    if len(images1) == 0:
        raise ValueError("Image arrays are empty; nothing to visualize.")

    # Limit the number of images to show to available images
    num_to_show = min(num_to_show, len(images1))

    # The channel count of the first original image (channels-first) decides
    # how every image is rendered.
    is_rgb = images1[0].shape[0] == 3
    display_cmap = None if is_rgb else cmap

    def prepare_image(img):
        """Convert one (C, H, W) image to the layout matplotlib's imshow expects."""
        if is_rgb:
            return img.transpose(1, 2, 0)  # (3, H, W) -> (H, W, 3)
        return img.squeeze(0)  # (1, H, W) -> (H, W)

    # squeeze=False keeps `axes` 2-D (shape (2, num_to_show)) even for a single pair.
    fig, axes = plt.subplots(2, num_to_show, figsize=figsize, squeeze=False)

    for i in range(num_to_show):
        # Row 0: original image, row 1: reconstructed image.
        for row, batch in enumerate((images1, images2)):
            ax = axes[row, i]
            ax.imshow(prepare_image(batch[i]), cmap=display_cmap)
            ax.set_title(f"{titles[row]} {i + 1}")
            ax.axis('off')

    # tight_layout computes its own spacing (a preceding subplots_adjust would be overridden).
    fig.tight_layout()

    # Create the output directory, as visualize_samples does; an empty save_dir
    # means "current directory" for os.path.join and needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, file_name), dpi=300, bbox_inches='tight')
    plt.close(fig)

    # print(f"Displayed {num_to_show} image pairs.")


def plot_errors(err_src, err_trg, title, save_dir, file_name, ylabel='MSE Error'):
    """
    Plots and saves reconstruction error curves for source and target domains.

    Each curve is plotted against its own 1-based epoch index, so the two
    histories may have different lengths.

    Args:
        err_src (list): List of error values for the source domain.
        err_trg (list): List of error values for the target domain.
        title (str): Title of the plot.
        save_dir (str): Directory where the plot will be saved. Created if missing.
        file_name (str): Name of the file for saving the plot.
        ylabel (str, optional): Y-axis label. Defaults to 'MSE Error'.
    """
    fig = plt.figure(figsize=(10, 5))
    # Separate x-axes per curve: a shared range(len(err_src)) makes plt.plot
    # raise ValueError when err_trg has a different length.
    plt.plot(range(1, len(err_src) + 1), err_src, label='Source')
    plt.plot(range(1, len(err_trg) + 1), err_trg, label='Target')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    fig.tight_layout()

    # Create the output directory, as visualize_samples does; an empty save_dir
    # means "current directory" for os.path.join and needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, file_name), dpi=300, bbox_inches='tight')
    plt.close(fig)

def is_spd(P, atol=1e-8):
    """
    Return True if ``P`` is a symmetric positive-definite matrix.

    Symmetry is tested with ``torch.allclose(P, P.T, atol=atol)`` (default
    rtol); positive definiteness requires every eigenvalue to be > 0.

    Args:
        P (torch.Tensor): Matrix to test.
        atol (float, optional): Absolute tolerance for the symmetry check.
            Defaults to 1e-8.

    Returns:
        bool: A plain Python bool. False for non-2-D or non-square input.
    """
    # Non-square input would make P and P.T shapes differ and allclose raise.
    if P.dim() != 2 or P.shape[0] != P.shape[1]:
        return False

    if not torch.allclose(P, P.T, atol=atol):
        return False

    # Symmetry is established, so use the symmetric eigensolver: it returns real
    # eigenvalues directly and is cheaper/more stable than the general eigvals.
    eigvals_P = torch.linalg.eigvalsh(P)
    return bool(torch.all(eigvals_P > 0))


def mapto_SPD_cone(cov, mu, beta, jitter=1e-7):
    """
    Embed a (cov, mu) pair into one (d+1) x (d+1) symmetric matrix.

    Builds

        P = [[cov + beta * mu mu^T,  mu  ],
             [mu^T,                  beta]]

    then symmetrizes it, adds ``jitter`` to the diagonal, and asserts that the
    result is SPD (via ``is_spd``).

    Args:
        cov (torch.Tensor): (d, d) matrix — presumably a covariance; TODO confirm
            it is always symmetric PSD at call sites.
        mu (torch.Tensor): Vector with d elements, shape (d,) or (d, 1).
        beta (float or torch.Tensor): Scalar weight. A tensor ``beta`` is used
            as-is (no copy), so gradients can flow through it.
        jitter (float, optional): Value added to the diagonal. Defaults to 1e-7.

    Returns:
        torch.Tensor: The (d+1, d+1) matrix P.

    Raises:
        AssertionError: If P is not SPD (only when assertions are enabled).
    """
    # torch.outer only accepts 1-D inputs, so flatten mu *before* the outer
    # product (previously a (d, 1) mu crashed here), and keep a (d, 1) column
    # for the block layout.
    mu_vec = mu.reshape(-1)
    mu_col = mu_vec.unsqueeze(-1)

    # as_tensor avoids torch.tensor's copy-construct warning when beta is
    # already a tensor, and matches mu's dtype/device for the concatenations.
    beta = torch.as_tensor(beta, dtype=mu.dtype, device=mu.device).reshape(1, 1)

    # NOTE(review): the off-diagonal block is mu (not beta * mu), so the Schur
    # complement w.r.t. beta is cov + (beta - 1/beta) mu mu^T; for beta < 1 this
    # can lose positive definiteness — confirm the intended beta range.
    P_11 = cov + beta * torch.outer(mu_vec, mu_vec)

    top = torch.cat((P_11, mu_col), dim=1)
    bottom = torch.cat((mu_col.T, beta), dim=1)
    P = torch.cat((top, bottom), dim=0)

    # Remove any floating-point asymmetry, then add a small diagonal jitter.
    P = (P + P.T) / 2.0
    P = P + torch.eye(P.size(0), dtype=P.dtype, device=P.device) * jitter

    # Post-condition sanity check; skipped when Python runs with -O.
    assert is_spd(P.clone().detach()), "The P matrix is not SPD!"

    return P