import matplotlib.pyplot as plt
import numpy as np
import torch

from src.datasets.bridge_1d import OneDimensionalBridgeDataset


def visualize_dataset_and_method_1d(
    ds,
    method,
    num_steps=1000,
    num_samples=1000,
    y_lims=(-3, 3),
    title="1D Generative Model Visualization",
    n_boxes=100,
):
    """Plot a 1D dataset next to the time evolution produced by ``method``.

    Left panel: density histogram of ``ds.x0``.

    Right panel: a heat map whose row ``idx`` (drawn as a column over time)
    is the fraction of samples returned by ``method.marginal_sample(x0, t)``
    that fall into each of ``n_boxes`` equal-width bins spanning ``y_lims``.
    It is overlaid with one Euler-Maruyama path started from a randomly
    chosen element of ``ds.x0`` and driven by ``method.f`` / ``method.g``.

    Args:
        ds: Dataset exposing ``x0`` (indexed and ``.clone()``-d, so expected to
            be a torch tensor). If it is a ``OneDimensionalBridgeDataset``, its
            ``y`` is forwarded to ``method`` as the ``y`` keyword.
        method: Object providing ``marginal_sample(x0, t, y=...)`` (returning
            a pair whose first element is used), ``f(x, t, y=...)`` and
            ``g(t)``.
        num_steps: Number of evenly spaced time points in ``[0.001, 0.999]``.
        num_samples: Unused; kept for backward compatibility of the signature.
        y_lims: ``(low, high)`` value range that is binned and shown on the y axis.
        title: Title of the right-hand panel.
        n_boxes: Number of value bins in the heat map.
    """
    plt.figure(figsize=(18, 6))

    # Left panel: empirical density of the data.
    plt.subplot(1, 2, 1)
    plt.hist(ds.x0, bins=100, density=True, alpha=0.5, color="red")
    plt.title("Dataset Histogram")
    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.grid()

    x0 = ds.x0
    selected_idx = np.random.randint(0, x0.shape[0])

    # Bridge datasets carry a conditioning target that is passed through to
    # `method`: the full batch for the marginals, one row for the trajectory.
    y = None
    yi = None
    if isinstance(ds, OneDimensionalBridgeDataset):
        y = ds.y
        yi = y[selected_idx].clone()

    # n_boxes + 1 edges delimit exactly n_boxes bins, so every canvas row is a
    # real bin and the rows line up with the imshow extent over y_lims.
    edges = torch.linspace(*y_lims, n_boxes + 1)
    canvas = torch.zeros((num_steps, n_boxes))

    t_start, t_end = 0.001, 0.999
    ts = torch.linspace(t_start, t_end, num_steps)
    # The Euler step must equal the spacing of `ts` (num_steps - 1 intervals).
    dt = (t_end - t_start) / max(num_steps - 1, 1)

    # x_traj[k] is the simulated state at time ts[k].
    x_traj = [x0[selected_idx].clone()]

    for idx, t in enumerate(ts):
        xt, _ = method.marginal_sample(x0, t, y=y)
        xt = xt.detach().cpu()

        x_prev = x_traj[-1]
        x_next = (
            x_prev
            + method.f(x_prev, t, y=yi) * dt
            + method.g(t) * torch.randn_like(x_prev) * dt**0.5
        )
        # Detach so no autograd graph accumulates across steps; the path is
        # only used for plotting.
        x_traj.append(x_next.detach())

        # bucketize(right=False) returns i with edges[i-1] < x <= edges[i];
        # subtracting one yields the 0-based bin index. Values outside
        # (edges[0], edges[-1]] are dropped.
        bins = torch.bucketize(xt, edges, right=False) - 1
        bins = bins[(bins >= 0) & (bins < n_boxes)]
        canvas[idx, :] = torch.bincount(bins, minlength=n_boxes).float() / len(xt)

    # Align states with their times: drop the state produced after the last step.
    path = torch.stack(x_traj[:-1]).cpu().reshape(num_steps, -1)

    plt.subplot(1, 2, 2)
    plt.imshow(
        canvas.T.numpy(),
        aspect="auto",
        extent=[0, 1, *y_lims],
        origin="lower",
        cmap="viridis",
    )
    plt.plot(
        ts.numpy(),
        path.numpy(),
        color="red",
        linewidth=1.5,
        label="Sampled Trajectory",
        alpha=0.5,
    )
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("x")
    plt.ylim(y_lims)
    plt.legend(loc="upper right")
    plt.grid()
    plt.show()


def visualize_results_1d(
    ds, pred_x0, traj, y_lims=(-3, 3), n_boxes=100, title="1D Generative Model Results"
):
    """Compare generated samples with the data and show their sampling paths.

    Left panel: overlaid density histograms of ``pred_x0`` and ``ds.x0``.

    Right panel: a heat map where each time step of ``traj`` is binned into
    ``n_boxes`` equal-width bins spanning ``y_lims``. It is overlaid with the
    path of the first sample. Time steps are mapped evenly onto ``[0, 1]``.

    Args:
        ds: Dataset exposing ``x0`` (the reference samples).
        pred_x0: Tensor of generated samples (detached and moved to CPU here).
        traj: Tensor indexed as ``traj[step, sample, dim]``. Only ``dim`` 0 is
            visualised.
        y_lims: ``(low, high)`` value range that is binned and shown on the y axis.
        n_boxes: Number of value bins in the heat map.
        title: Title of the right-hand panel.
    """
    plt.figure(figsize=(18, 6))

    # Left panel: generated vs. reference distribution.
    plt.subplot(1, 2, 1)
    plt.hist(
        pred_x0.detach().cpu().numpy(),
        bins=100,
        density=True,
        alpha=0.5,
        color="blue",
        label="Sampled Distribution",
    )
    plt.hist(
        ds.x0,
        bins=100,
        density=True,
        alpha=0.5,
        color="red",
        label="Original Distribution",
    )
    plt.title("Sampled vs Original Distribution")
    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.legend()
    plt.grid()

    # Right panel: per-step occupancy of the value bins.
    plt.subplot(1, 2, 2)

    # First component of every sample at every step: (num_steps, num_samples).
    samples = traj[:, :, 0].detach().cpu()
    num_steps, num_samples = samples.shape

    # n_boxes + 1 edges delimit exactly n_boxes bins, so every canvas row is a
    # real bin and the rows line up with the imshow extent over y_lims.
    edges = torch.linspace(*y_lims, n_boxes + 1)
    canvas = torch.zeros((num_steps, n_boxes))

    for idx, xt in enumerate(samples):
        # bucketize(right=False) returns i with edges[i-1] < x <= edges[i];
        # subtracting one yields the 0-based bin index. Values outside
        # (edges[0], edges[-1]] are dropped.
        bins = torch.bucketize(xt, edges, right=False) - 1
        bins = bins[(bins >= 0) & (bins < n_boxes)]
        canvas[idx, :] = torch.bincount(bins, minlength=n_boxes).float() / num_samples

    ts = torch.linspace(0, 1, num_steps)

    plt.imshow(
        canvas.T.numpy(),
        aspect="auto",
        extent=[0, 1, *y_lims],
        origin="lower",
        cmap="viridis",
    )
    plt.plot(
        ts.numpy(),
        samples[:, 0].numpy(),
        color="red",
        linewidth=1.5,
        label="Sampled Trajectory",
        alpha=0.5,
    )
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("x")
    plt.ylim(y_lims)
    plt.legend(loc="upper right")
    plt.grid()
    plt.show()
