import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import pickle

def _magnitude_slice_2d(tensor: torch.Tensor) -> np.ndarray:
    """
    Reduces a tensor with at least two dimensions to a 2D float64 NumPy array of magnitudes.

    Every leading index (all but the last two dimensions) is fixed at 0. The slice is
    taken on the tensor's own device, so only the 2D slice is copied to the CPU.
    """
    sliced = tensor.detach()
    if sliced.ndim > 2:
        # Index each leading dimension at 0, leaving the last two dimensions intact.
        sliced = sliced[(0,) * (sliced.ndim - 2)]
    sliced = sliced.cpu()
    if sliced.is_complex():
        sliced = sliced.abs()
    # Cast to float64 *before* abs(): dtypes NumPy cannot represent (e.g. bfloat16) become
    # convertible, and abs() of the most negative signed integer no longer wraps around.
    return sliced.to(torch.float64).abs().numpy()


def visualize_tensor_3d(tensor: torch.Tensor, filename: str = 'tensor_plot.png', title: str = None, z_lim: float = None):
    """
    Visualizes the magnitude of a PyTorch tensor as a 3D surface plot and saves it to a file.

    For tensors with more than two dimensions, only the slice with every leading index
    set to 0 (tensor[0, ..., 0, :, :]) is plotted. Complex, integer, boolean and
    reduced-precision (e.g. bfloat16) tensors are converted to float64 magnitudes.
    Invalid input and save failures are reported with print() rather than raised, so
    the helper can be called freely from a debugger session.

    Args:
        tensor (torch.Tensor): The input tensor to visualize. It must have at least
            2 dimensions and at least one element.
        filename (str, optional): The path to save the plot image. Missing parent
            directories are created. Defaults to 'tensor_plot.png'.
        title (str, optional): The title for the plot. Defaults to a description of
            the plotted slice.
        z_lim (float, optional): The maximum value for the Z-axis (the minimum is 0).
            If None, it's determined automatically.
    """
    if not isinstance(tensor, torch.Tensor):
        print("Error: Input must be a PyTorch tensor.")
        return
    if tensor.ndim < 2:
        print(f"Error: Expected a tensor with at least 2 dimensions, got shape {tuple(tensor.shape)}.")
        return
    if tensor.numel() == 0:
        print(f"Error: Cannot visualize an empty tensor of shape {tuple(tensor.shape)}.")
        return

    ndim = tensor.ndim
    magnitude = _magnitude_slice_2d(tensor)

    if title is None:
        if ndim > 2:
            leading = "0, " * (ndim - 2)
            title = f"Slice [{leading}:, :] of a {ndim}D tensor"
        else:
            title = "2D Tensor"

    # Ensure the directory exists (only after the input has been validated).
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Grid coordinates: X runs over columns, Y over rows.
    H, W = magnitude.shape
    X, Y = np.meshgrid(np.arange(W), np.arange(H))

    fig = plt.figure(figsize=(12, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(X, Y, magnitude, cmap='viridis', edgecolor='none')

        ax.set_xlabel('Columns')
        ax.set_ylabel('Rows')
        ax.set_zlabel('Magnitude')
        ax.set_title(title)

        if z_lim is not None:
            ax.set_zlim(0, z_lim)

        fig.colorbar(surf, shrink=0.5, aspect=5)

        try:
            fig.savefig(filename)
        except Exception as e:
            print(f"Error saving plot: {e}")
        else:
            print(f"Tensor visualization saved to '{filename}'")
    finally:
        # Always release the figure, even if plotting itself raised.
        plt.close(fig)

def save_tensor(tensor: torch.Tensor, filename: str = 'saved_tensor.pkl'):
    """
    Saves a PyTorch tensor to a file using pickle.

    The tensor is detached from the autograd graph, moved to the CPU and cloned before
    pickling. The clone matters for views: pickling a tensor serializes its entire
    underlying storage, so saving e.g. ``big[:10]`` without it would write all of
    ``big`` to disk. The file can be read back with ``pickle.load`` (torch must be
    importable at load time). Only unpickle files you trust.

    Errors while writing the file are reported with print() rather than raised.

    Args:
        tensor (torch.Tensor): The input tensor to save.
        filename (str, optional): The path to save the tensor file. Missing parent
            directories are created. Defaults to 'saved_tensor.pkl'.
    """
    if not isinstance(tensor, torch.Tensor):
        print("Error: Input must be a PyTorch tensor.")
        return

    # Ensure the directory exists
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # clone() gives the saved tensor its own compactly-sized storage (see docstring).
    tensor_to_save = tensor.detach().cpu().clone()

    try:
        with open(filename, 'wb') as f:
            pickle.dump(tensor_to_save, f)
    except Exception as e:
        print(f"Error saving tensor: {e}")
    else:
        print(f"Tensor successfully saved to '{filename}'")


# --- PDB-friendly aliases ---
# Short names for interactive debugging, e.g. `viz(my_tensor)` or `save(my_tensor, 'out/t.pkl')`.
viz, save = visualize_tensor_3d, save_tensor