import matplotlib.pyplot as plt
import numpy as np

def _to_numpy(value):
    """Convert a torch tensor (or any array-like) to a NumPy array.

    Objects exposing ``detach`` (torch tensors) are detached from the autograd
    graph and moved to CPU before conversion; anything else goes through
    ``np.asarray`` so plain NumPy arrays / nested lists are accepted too.
    """
    if hasattr(value, 'detach'):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _plot_sequence(ax, time_axis, seq, label, title, patch_boundaries):
    """Draw one sequence on ``ax`` with a dashed red line at each patch boundary."""
    ax.plot(time_axis, seq, label=label)
    for boundary in patch_boundaries:
        ax.axvline(x=boundary, color='r', linestyle='--', alpha=0.5)
    ax.set_title(title)
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Value')
    ax.legend()


def _save_panels(path, figsize, time_axis, patch_boundaries, panels):
    """Save ``panels`` stacked vertically in one figure to ``path``.

    Each panel is a ``(seq, label, title)`` tuple. The figure is always closed,
    even if plotting or saving raises, so repeated calls cannot leak figures.
    """
    fig, axes = plt.subplots(len(panels), 1, figsize=figsize, squeeze=False)
    try:
        for ax, (seq, label, title) in zip(axes[:, 0], panels):
            _plot_sequence(ax, time_axis, seq, label, title, patch_boundaries)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)


def visualize_patches(x, x_mean, means, save_path='patch_visualization'):
    """
    Plot original, normalized and mean patch sequences of the first sample.

    Only ``x[0][0]`` (and the matching slices of ``x_mean`` / ``means``) is
    plotted. That slice must be 2-D ``(patch_num, patch_size)``: its rows are
    concatenated to rebuild the time series and its second dimension is used
    as the patch length. So ``x`` is at least 4-D — presumably
    ``(B, n_vars, patch_num, patch_size)``; TODO confirm with callers.

    Args:
        x: Tensor/array of original patches; ``x[0][0]`` is
            ``(patch_num, patch_size)``.
        x_mean: Tensor/array with the same layout as ``x``, plotted as the
            "Normalized (x - mean)" series (presumably ``x - means``).
        means: Tensor/array whose ``[0][0]`` slice holds either one value per
            time step (same number of elements as ``x[0][0]``) or one value
            per patch (``patch_num`` elements, broadcast across each patch).
        save_path: Base path for the output PDFs (without extension). Writes
            ``{save_path}_combined.pdf``, ``{save_path}_original.pdf`` and
            ``{save_path}_normalized.pdf``.
    """
    # First sample, first slice of each input, as NumPy arrays.
    x_np = _to_numpy(x[0][0])  # (patch_num, patch_size)
    x_mean_np = _to_numpy(x_mean[0][0])  # (patch_num, patch_size)
    means_np = _to_numpy(means[0][0])

    # Concatenating the patches row-wise reconstructs the original sequence.
    original_seq = x_np.flatten()
    normalized_seq = x_mean_np.flatten()
    means_seq = means_np.flatten()

    patch_num, patch_size = x_np.shape[0], x_np.shape[1]
    n_steps = len(original_seq)

    # Per-patch means (e.g. a keepdim mean of shape (patch_num, 1)) would not
    # line up with the time axis; stretch each value over its patch.
    if means_seq.size != n_steps and means_seq.size == patch_num:
        means_seq = np.repeat(means_seq, patch_size)

    # Start index of every patch (0, patch_size, 2*patch_size, ...).
    patch_boundaries = np.arange(0, n_steps, patch_size)
    time_axis = np.arange(n_steps)

    original = (original_seq, 'Original',
                'Original Sequence with Patch Boundaries')
    normalized = (normalized_seq, 'Normalized (x - mean)',
                  'Normalized Sequence with Patch Boundaries')
    patch_means = (means_seq, 'Means', 'Means of Patches')

    # Figure 1: all three series stacked in one figure.
    _save_panels(f'{save_path}_combined.pdf', (12, 9), time_axis,
                 patch_boundaries, [original, normalized, patch_means])

    # Figures 2 and 3: original and normalized series on their own.
    _save_panels(f'{save_path}_original.pdf', (12, 4), time_axis,
                 patch_boundaries, [original])
    _save_panels(f'{save_path}_normalized.pdf', (12, 4), time_axis,
                 patch_boundaries, [normalized])


