import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchdyn
from torchdyn.datasets import generate_moons
from sklearn import datasets

def eight_normal_sample(n, dim, scale=1, var=1):
    """Draw ``n`` points from a mixture of eight Gaussians placed on a circle.

    The eight mixture centres are the unit-circle points at multiples of 45
    degrees, multiplied by ``scale``. Each sample picks a centre uniformly at
    random and adds zero-mean Gaussian noise.

    Args:
        n: Number of samples to draw.
        dim: Dimensionality of the noise. Must be 2, because the centres are
            hard-coded 2D points.
        scale: Radius of the circle the centres lie on.
        var: Noise parameter. NOTE: the covariance handed to
            ``MultivariateNormal`` is ``sqrt(var) * I``, so the per-axis
            variance is actually ``sqrt(var)``, not ``var``. This is kept
            unchanged so existing datasets/experiments remain reproducible.

    Returns:
        Tensor of shape ``(n, 2)``. Its dtype follows torch's type promotion
        of the centre tensor and the sampled noise.

    Raises:
        ValueError: If ``dim != 2``.
    """
    if dim != 2:
        raise ValueError(f"eight_normal_sample only supports dim=2, got dim={dim}")
    m = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)
    )
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
    ]
    centers = torch.tensor(centers) * scale
    # Keep the RNG draw order (noise first, then component indices) so that
    # seeded runs produce the same samples.
    noise = m.sample((n,))
    multi = torch.multinomial(torch.ones(8), n, replacement=True)
    # Gather every sample's centre with one advanced-indexing op instead of a
    # per-sample Python loop + stack.
    return centers[multi] + noise


def sample_moons(n):
    """Return ``n`` two-moons points mapped through the affine map ``3 * x - 1``.

    Points come from torchdyn's ``generate_moons`` with noise 0.2; its second
    return value (presumably class labels) is discarded.
    """
    points, _unused = generate_moons(n, noise=0.2)
    stretched = points * 3
    return stretched - 1


def sample_8gaussians(n):
    """Draw ``n`` float32 2D points from the eight-Gaussian ring (scale=5, var=0.1)."""
    samples = eight_normal_sample(n, 2, scale=5, var=0.1)
    return samples.float()


class torch_wrapper(torch.nn.Module):
    """Adapt a model that takes ``[x, t]`` features to torchdyn's ``(t, x)`` call signature."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Run the wrapped model on ``x`` with time ``t`` appended as a final column.

        ``t`` is repeated once per row of ``x``; this assumes ``t`` holds a
        single value (e.g. a 0-d tensor). Extra positional/keyword arguments
        are accepted for signature compatibility and ignored.
        """
        batch_size = x.shape[0]
        time_column = t.repeat(batch_size).unsqueeze(1)
        features = torch.cat((x, time_column), dim=1)
        return self.model(features)


def plot_trajectories(traj, method=None, step=None, save_dir=None, max_points=2000):
    """Scatter-plot sampled trajectories, save the figure as a PNG, and show it.

    Args:
        traj: Indexed as ``traj[time, sample, coord]``; coordinates 0 and 1 are
            plotted. Presumably a (T, N, 2) array/tensor — confirm with callers.
        method: Label embedded in the output filename.
        step: Label (e.g. a training step) embedded in the output filename.
        save_dir: Directory to save into (created if missing). If None, the
            file is written relative to the current working directory.
        max_points: Only the first ``max_points`` samples are drawn.

    The image is saved as ``trajectories_{method}_{step}.png``. The figure is
    closed after showing so repeated calls do not accumulate open figures.
    """
    import os

    n = max_points
    fig = plt.figure(figsize=(6, 6))
    # traj[0] in black, every time step in faint olive, traj[-1] in blue.
    plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black")
    plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive")
    plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue")
    plt.xticks([])
    plt.yticks([])
    plt.axis("off")

    filename = f"trajectories_{method}_{step}.png"
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, filename)
    else:
        save_path = filename

    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.show()
    # Under non-interactive backends show() does not release the figure, so
    # close it explicitly to avoid leaking one figure per call.
    plt.close(fig)


def plot_trajectories_grid(traj_list, method=None, steps=None, figsize=(25, 5), save_dir=None):
    """Plot several trajectory sets side by side in a single row and save the figure.

    Panels are titled ``k = 1``, ``k = 2``, ... by position. ``steps`` only
    limits how many panels are drawn, because it is zipped with ``traj_list``.
    When ``steps`` is None or empty, at most 5 panels are drawn. At least five
    columns are always created and unused columns stay empty. When more panels
    are requested, one column is created per panel.

    Args:
        traj_list: Sequence of trajectories, each indexed as
            ``traj[time, sample, coord]`` (presumably (T, N, 2) — confirm).
        method: Used in the suptitle and the output filename when not None.
        steps: Iterable zipped with ``traj_list``; defaults to ``range(1, 6)``.
        figsize: Matplotlib figure size.
        save_dir: Directory to save into (created if missing). If None, the
            file is written relative to the current working directory.
    """
    import os

    n = 2000
    panels = list(zip(traj_list, steps or range(1, 6)))
    # Previously fixed at 5 columns, which raised IndexError for >5 panels.
    ncols = max(5, len(panels))
    fig, axes = plt.subplots(1, ncols, figsize=figsize)

    for i, (traj, _step) in enumerate(panels):
        ax = axes[i]

        # traj[0] in black, every time step in faint olive, traj[-1] in blue.
        ax.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black")
        ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive")
        ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue")

        ax.set_title(f"k = {i+1}")
        ax.set_xticks([])
        ax.set_yticks([])

        # Legend only on the first subplot, placed below it.
        if i == 0:
            ax.legend(["Prior sample z(S)", "Flow", "z(0)"],
                      bbox_to_anchor=(0, -0.1), loc='upper left', fontsize=8)

    if method is not None:
        fig.suptitle(f"{method} - Training Progress", fontsize=16, y=1.02)
    else:
        fig.suptitle("Trajectories Comparison", fontsize=16, y=1.02)

    plt.tight_layout()

    filename = f"trajectories_grid_{method}.png" if method is not None else "trajectories_grid.png"
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, filename)
    else:
        save_path = filename

    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    # under non-interactive backends.
    plt.close(fig)


def sample_normal(n):
    """Sample ``n`` points from a standard 2D normal distribution.

    No rescaling is applied: the result is exactly ``torch.randn(n, 2)``
    (zero mean, identity covariance).
    """
    return torch.randn(n, 2)


def sample_scurve(n):
    """Return ``n`` points of a 2D S-curve as a float32 tensor.

    sklearn's 3D S-curve (noise 0.1, ``random_state=None``) is projected onto
    its first and third coordinates. It is then standardized with a single
    scalar mean/std taken over both columns, so both axes are scaled equally.
    Finally it is multiplied by 7.
    """
    points_3d, _ = datasets.make_s_curve(n_samples=n, noise=0.1, random_state=None)
    projected = torch.tensor(points_3d)[:, [0, 2]]
    centered = projected - projected.mean()
    standardized = centered / projected.std()
    return (standardized * 7.0).float()
