from matplotlib import pyplot as plt
import torch
import numpy as np


def sample_points(batch_size, dim, s_range=(0, 2), x_range=(-3, 3)):
    """Draw uniform samples of x and s and return them side by side.

    Returns a tensor of shape ``(batch_size, dim + 1)``: the first ``dim``
    columns are x drawn uniformly from ``[x_range[0], x_range[1])`` and the
    last column is s drawn uniformly from ``[s_range[0], s_range[1])``.
    x is drawn before s, so results are reproducible under a fixed torch seed.
    """
    x_lo, x_hi = x_range
    s_lo, s_hi = s_range
    x = x_lo + (x_hi - x_lo) * torch.rand(batch_size, dim)
    s = s_lo + (s_hi - s_lo) * torch.rand(batch_size, 1)
    return torch.cat((x, s), dim=1)


import matplotlib.pyplot as plt
import numpy as np
import torch
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *


import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_variance_heatmaps(vector_field_outputs, wandb, t_eval, saved_models, x_range=(-1, 7), y_range=(-1, 7), grid_size=50):
    """
    Plot, per time step, the across-model variance of a learned vector field.

    Each state dict in ``saved_models`` is loaded into one
    ``MLP(dim=2, time_varying=True)`` and evaluated on a regular
    ``grid_size x grid_size`` grid spanning ``x_range`` x ``y_range`` at every
    time in ``t_eval``. At each grid point the model output is averaged over its
    components (``mean(dim=1)``). The variance of that scalar across models is
    then drawn as a heatmap (x on the horizontal axis, y on the vertical axis).

    Up to 10 time steps are drawn in a 2x5 grid of subplots. Extra time steps
    are not plotted, and unused subplots are hidden. The figure is logged via
    ``wandb.log`` under the key ``"variance_heatmap_all"`` and is always closed
    afterwards, even if logging fails.

    Args:
        vector_field_outputs: Unused; kept for backward compatibility with callers.
        wandb: Object exposing ``log`` and ``Image`` (e.g. the ``wandb`` module).
        t_eval: Iterable of time values (0-d tensors or Python floats).
        saved_models: Non-empty iterable of state dicts loadable into
            ``MLP(dim=2, time_varying=True)``.
        x_range: (min, max) of the grid along x.
        y_range: (min, max) of the grid along y.
        grid_size: Number of grid points per axis.

    Raises:
        ValueError: If ``saved_models`` is empty.
    """
    # Materialize once: the state dicts are iterated again for every time step.
    saved_models = list(saved_models)
    if not saved_models:
        raise ValueError("saved_models must contain at least one state dict")

    def compute_variance_across_models(saved_models, t_eval, grid_points):
        """Return ``{t: variance across models at each grid point}`` as numpy arrays."""
        variances = {}
        model = MLP(dim=2, time_varying=True)
        model.eval()  # load_state_dict does not change the train/eval mode
        num_points = grid_points.size(0)

        for t_fixed in t_eval:
            t_value = float(t_fixed)  # accepts 0-d tensors as well as plain floats
            # Time is appended as the last input column, in the grid's dtype so
            # the concatenation never mixes float32 and float64.
            t_column = torch.full((num_points, 1), t_value, dtype=grid_points.dtype)
            grid_input = torch.cat([grid_points, t_column], dim=-1)

            per_model = []
            with torch.no_grad():
                for model_state_dict in saved_models:
                    model.load_state_dict(model_state_dict)
                    # Collapse the output components to one scalar per grid point.
                    per_model.append(model(grid_input).mean(dim=1))

            # Shape: (num_models, num_grid_points); .cpu() keeps this safe for CUDA outputs.
            stacked = torch.stack(per_model).cpu().numpy()
            variances[t_value] = np.var(stacked, axis=0)

        return variances

    x_vals = torch.linspace(x_range[0], x_range[1], grid_size)
    y_vals = torch.linspace(y_range[0], y_range[1], grid_size)

    # Rows follow y and columns follow x. Flattening and later reshaping to
    # (len(y_vals), len(x_vals)) therefore matches imshow's row=y / col=x layout
    # and avoids drawing a transposed field.
    Y, X = torch.meshgrid(y_vals, x_vals, indexing="ij")
    grid_points = torch.stack([X.flatten(), Y.flatten()], dim=-1)  # columns: (x, y)

    variances = compute_variance_across_models(saved_models, t_eval, grid_points)

    fig, axes = plt.subplots(2, 5, figsize=(20, 6), constrained_layout=True)
    axes = axes.flatten()

    try:
        # zip stops at the number of subplots, so extra time steps are skipped.
        for ax, (t_fixed, var_at_t) in zip(axes, variances.items()):
            heatmap = ax.imshow(
                var_at_t.reshape(len(y_vals), len(x_vals)),
                origin='lower',
                cmap='viridis',
                extent=[x_range[0], x_range[1], y_range[0], y_range[1]],
            )
            ax.set_title(f"t={t_fixed:.2f}")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            fig.colorbar(heatmap, ax=ax)

        # Hide panels left empty when fewer time steps than subplots were given.
        for ax in axes[len(variances):]:
            ax.set_axis_off()

        wandb.log({"variance_heatmap_all": wandb.Image(fig)})
    finally:
        # Release the figure even if plotting or logging raised.
        plt.close(fig)




