import torch
import matplotlib.pyplot as plt



def draw_grid(fuc, title, save_path, vmin=None, vmax=None, save_data=None, device=None):
    """Evaluate ``fuc`` on a regular (t, x) grid and show it as a heatmap.

    The grid spans x in [-3, 3) and t in [0, 1), both with step 1e-2. For
    each time t_i, ``fuc`` is called as ``fuc(x_grid, t_col)``. Here
    ``x_grid`` has shape (n_x,) and ``t_col`` has shape (n_x, 1), filled
    with t_i. The result is assigned into row i of an (n_t, n_x) density
    grid, so it must be broadcastable to shape (n_x,).

    Args:
        fuc: callable evaluated once per time step (see above).
        title: currently unused (the figure title is not drawn).
        save_path: currently unused (the figure is shown, not saved).
        vmin: lower color limit passed to ``imshow``.
        vmax: upper color limit passed to ``imshow``.
        save_data: if not None, destination passed to ``torch.save`` for
            the raw density grid tensor.
        device: torch device used for the grids. Defaults to ``'cuda'``
            when CUDA is available, otherwise ``'cpu'``.
    """
    if device is None:
        # Keep the previous GPU default, but fall back to CPU instead of crashing.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Build the grids on CPU and move them, so grid values match the original exactly.
    x_grid = torch.arange(-3, 3, 1e-2).to(device)
    t_grid = torch.arange(0, 1, 1e-2).to(device)
    n_x = len(x_grid)
    density_grid = torch.zeros(len(t_grid), n_x, device=device)
    for i, t in enumerate(t_grid):
        # Broadcast the scalar time t into an (n_x, 1) column, one entry per x value.
        density_grid[i] = fuc(x_grid, t.unsqueeze(-1).repeat(n_x).unsqueeze(-1))
    if save_data is not None:
        torch.save(density_grid, save_data)

    x_np = x_grid.detach().cpu().numpy()
    t_np = t_grid.detach().cpu().numpy()
    density_np = density_grid.detach().cpu().numpy()
    # Time runs along the horizontal axis and position along the vertical one.
    extent = [t_np[0], t_np[-1], x_np[0], x_np[-1]]

    plt.figure()
    plt.xlim(extent[0], extent[1])
    plt.ylim(extent[2], extent[3])
    plt.imshow(density_np.T, extent=extent, aspect='auto', origin='lower',
               cmap='viridis', vmax=vmax, vmin=vmin)
    plt.xticks(fontsize=15)
    plt.xlabel(r'$t$', fontsize=15)
    plt.ylabel(r"$x_t$", fontsize=15)
    plt.yticks(fontsize=15)
    colorbar = plt.colorbar()
    # Size only this colorbar's tick labels. Setting the global
    # rcParams['ytick.labelsize'] would leak into every later figure.
    colorbar.ax.tick_params(labelsize=12)
    plt.tight_layout()
    plt.show()
    plt.close()

def visualize_batch(samples, num_samples=20, rows=4, cols=5):
    """Display the first samples of a batch as images on a ``rows`` x ``cols`` grid.

    At most ``min(num_samples, len(samples), rows * cols)`` images are drawn.
    Grid cells that receive no image are hidden.

    Args:
        samples: indexable collection of torch tensors (``len`` and integer
            indexing are used). Each tensor is squeezed and drawn with
            ``imshow``.
        num_samples: maximum number of samples to draw.
        rows: number of subplot rows.
        cols: number of subplot columns.
    """
    # squeeze=False always yields a 2-D array of Axes, so flatten() also
    # works for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axes = axes.flatten()
    # Cap at the grid size: indexing past the last Axes would raise
    # IndexError. Clamp at 0 so a negative num_samples draws nothing.
    n_shown = max(0, min(num_samples, len(samples), len(axes)))
    for i, ax in enumerate(axes[:n_shown]):
        # detach/cpu so tensors on the GPU or requiring grad can be converted.
        image = samples[i].detach().cpu().squeeze().numpy()  # presumably (nx, nt) — TODO confirm
        ax.imshow(image)
        # NOTE(review): imshow puts the image's second dimension on the
        # horizontal axis; confirm these labels match the data layout.
        ax.set_xlabel("Position (y)")
        ax.set_ylabel("u(y, t)")
    for ax in axes[n_shown:]:
        ax.axis('off')
    fig.tight_layout()
    plt.show()

def calculate_metrics(ref, gen):
    """Print and return the MSE between per-feature statistics of two sample sets.

    The mean and standard deviation of each input are taken along dim 0.
    ``std`` uses torch's default unbiased estimator, so a single sample along
    dim 0 yields NaN. For each statistic, the mean squared difference between
    the reference and generated values is printed as
    ``"<name> MSE: <value>"`` with six decimals.

    Args:
        ref: reference samples (torch tensor), reduced along dim 0.
        gen: generated samples (torch tensor), reduced the same way. Its
            statistics are broadcast against those of ``ref``.

    Returns:
        dict mapping ``'mean'`` and ``'std'`` to the corresponding MSE as a
        Python float.
    """
    stats = {
        'mean': (ref.mean(0), gen.mean(0)),
        'std': (ref.std(0), gen.std(0)),
    }
    results = {}
    for name, (ref_s, gen_s) in stats.items():
        mse = (ref_s - gen_s).pow(2).mean().item()
        print(f'{name} MSE: {mse:.6f}')
        results[name] = mse
    return results