import numpy as np
import matplotlib.pyplot as plt
import torch
import os

# Import your dataset class
from utils.dataloader_3d import OptimizedCFDDataset

# Font settings shared by every axis label and the figure title.
_LABEL_STYLE = {'fontfamily': 'monospace', 'fontsize': 18}


def _draw_panel(ax, image, vmin, vmax, *, time_label=None, slice_label=None,
                field_label=None):
    """Draw one 2D slice on ``ax`` with ticks hidden and optional edge labels.

    Args:
        ax: Matplotlib axes to draw on.
        image: 2D array passed to ``imshow``.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.
        time_label: Text for the y-label (used on the first column), or None.
        slice_label: Text for a right-hand y-label (used on the last column),
            or None.
        field_label: Text for the x-label (used on the bottom row), or None.
    """
    ax.imshow(image, vmin=vmin, vmax=vmax)
    ax.set_xticks([])
    ax.set_yticks([])
    if field_label is not None:
        ax.set_xlabel(field_label, **_LABEL_STYLE)
    if time_label is not None:
        ax.set_ylabel(time_label, **_LABEL_STYLE)
    if slice_label is not None:
        # An axes has a single y-label: when the grid has only one column this
        # replaces the time label set just above (same as the original order).
        ax.set_ylabel(slice_label, **_LABEL_STYLE)
        ax.yaxis.set_label_position('right')
        ax.yaxis.tick_right()


# Create a dataset from a small set of indices (here a single one)
sample_indices = np.array([132], dtype=np.int32)
dataset = OptimizedCFDDataset(sample_indices)

# Pick one element of the dataset to visualise
sample_idx = 15
sample_data = dataset[sample_idx]

# Get trajectory and time information
traj_id, time_id = dataset.indices[sample_idx]
print(f"Plotting data from trajectory {traj_id}, time step {time_id}")

# Get the field names, e.g. ['Vx', 'Vy', 'Vz', 'density', 'pressure']
field_names = dataset.fields
num_fields = len(field_names)

# Convert every channel to NumPy once, up front, instead of once per z-slice.
# sample_data[0] is used as the current timestep (input) and sample_data[1]
# as the next timestep (target). detach()/cpu() keep the conversion working
# for tensors that track gradients or live on an accelerator.
current_fields = [sample_data[0][i].detach().cpu().numpy() for i in range(num_fields)]
next_fields = [sample_data[1][i].detach().cpu().numpy() for i in range(num_fields)]

# Evenly spaced indices along the first axis of a field, both ends included
num_slices = 3
z_slices = np.linspace(0, sample_data[0][0].shape[0] - 1, num_slices, dtype=int)

# Create a figure with 2*num_slices rows (t, t+1 for each z-slice) and one
# column per field. squeeze=False keeps `axes` two-dimensional even when there
# is a single field, so axes[row, col] indexing below always works.
fig, axes = plt.subplots(
    2 * num_slices, num_fields, figsize=(10, 4 * num_slices), squeeze=False
)
# NOTE(review): tight_layout() below recomputes subplot spacing, so these
# values are probably superseded - confirm whether this call is still wanted.
plt.subplots_adjust(wspace=0.05, hspace=0)
plt.suptitle(f"Trajectory {traj_id}, Time {time_id} → {time_id+1}", **_LABEL_STYLE)

last_col = num_fields - 1
last_slice = num_slices - 1

for s, z_slice in enumerate(z_slices):
    # Label shows the slice index shifted by one
    slice_label = f'z={z_slice+1}'

    for i, field_name in enumerate(field_names):
        current_slice = current_fields[i][z_slice, :, :]
        next_slice = next_fields[i][z_slice, :, :]

        # Share one colour scale between the t and t+1 panels of this
        # field/z-slice so they are directly comparable
        vmin = min(current_slice.min(), next_slice.min())
        vmax = max(current_slice.max(), next_slice.max())

        right_label = slice_label if i == last_col else None

        # Current timestep (first row of this z-slice)
        _draw_panel(
            axes[2 * s, i], current_slice, vmin, vmax,
            time_label='t' if i == 0 else None,
            slice_label=right_label,
        )

        # Next timestep (second row of this z-slice); only the very last row
        # of the figure carries the field name as x-label
        _draw_panel(
            axes[2 * s + 1, i], next_slice, vmin, vmax,
            time_label='t+1' if i == 0 else None,
            slice_label=right_label,
            field_label=field_name if s == last_slice else None,
        )


# Leave the top 3% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.97])
# NOTE(review): the file name says "10x5" but the grid is
# (2 * num_slices) x len(field_names) - confirm the intended name.
plt.savefig('cfd_sample_10x5.png', dpi=300)
plt.show()
