# Fetched from 'uploaded:cdiffusion/datasets/get_dipeptide_data.py'
import argparse
import pathlib
import matplotlib.pyplot as plt
import torch
import numpy as np
from manifolds.MD import Manifold_MD
from utils import set_seed_everywhere
from sde_lib import SDE_Brownian_manifolds
from sampling import SDE_sampler_manifolds_OLLA, SDE_sampler_manifolds_ULLA_EM
from runners.MD_runner import MD_prior_potential

def gen_data(manifold, repeat_num=1, kappa=100., N=100000, T=1000, sigma=1.):
    """
    Simulate constrained Langevin dynamics starting from the reference structure.

    The reference configuration is read from
    './data/dipeptide/dipeptide_ref_phi_psiwin.npy' (the file must already
    exist). An `MD_prior_potential` built around it supplies the drift
    `-gradV(x)` of an `SDE_Brownian_manifolds` whose noise scale is fixed to
    `sigma` and whose tau range is fixed to 1.0. The forward process is run
    with `SDE_sampler_manifolds_ULLA_EM`.

    Side effect: `manifold.epsilon` is overwritten with 0.3.

    Args:
        manifold: manifold object handed to the sampler.
        repeat_num: number of copies of the reference structure to simulate.
        kappa: forwarded to `MD_prior_potential`.
        N, T: forwarded to `SDE_Brownian_manifolds`.
        sigma: used for both `sigma_min` and `sigma_max`.

    Returns:
        NumPy array holding the sampler's "x_hist_all" history, reshaped to
        (repeat_num, -1, 10, 3).
    """
    reference_path = './data/dipeptide/dipeptide_ref_phi_psiwin.npy'
    xref = torch.tensor(np.load(reference_path)).float()

    # NOTE(review): the potential is built from the CPU copy of xref while the
    # start state below is moved to `device` -- confirm MD_prior_potential
    # copes with CUDA inputs.
    potential = MD_prior_potential(xref, kappa=kappa)

    def drift(x):
        # Drift of the Langevin dynamics: negative gradient of the potential.
        return - potential.gradV(x)

    sde = SDE_Brownian_manifolds(
        sigma_min=sigma, sigma_max=sigma,
        tau_min=1.0, tau_max=1.0,
        N=N, T=T,
    )
    sde.func_b = drift

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    replicas = xref.unsqueeze(0).repeat(repeat_num, 1, 1).to(device)
    start_state = replicas.reshape(repeat_num, -1)

    # An OLLA-based alternative (SDE_sampler_manifolds_OLLA with
    # manifold.epsilon = 0.1 and alpha = 25.0) was used here previously.
    manifold.epsilon = 0.3
    _, _, sampler_info = SDE_sampler_manifolds_ULLA_EM(
        sde, manifold, start_state,
        reverse=False, keep_quiet=False,
        alpha=50.0, gamma=5.0,
    )

    history = sampler_info["x_hist_all"].detach().cpu()
    return history.reshape(repeat_num, -1, 10, 3).numpy()


def generate_dipeptide_prior_data(kappa=50., psi_windows=None):
    """
    Generate and save a prior dataset whose samples all satisfy the
    manifold's inequality constraint.

    A single trajectory is simulated with `gen_data`, every frame with
    `manifold.g(x) > 0` is discarded, and the surviving frames are thinned
    (every 10th frame after skipping the first 1000 *surviving* frames).
    The result is written to './data/dipeptide/dipeptide_prior_<int(kappa)>.npy'.

    The random seed is reset to 777 on every call.

    Args:
        kappa: forwarded to `gen_data`; also used (truncated to int) in the
            output file name.
        psi_windows: psi windows handed to `Manifold_MD`. `None` (the default)
            means "no windows" and is replaced by a fresh empty list so that
            no list object is shared between calls.
    """
    if psi_windows is None:
        psi_windows = []

    print(f"\nGenerating prior dataset with kappa={kappa}...")
    set_seed_everywhere(777)
    manifold = Manifold_MD(psi_windows=psi_windows)
    data_dir = pathlib.Path("./data/dipeptide")

    # 1. Generate the full simulation trajectory.
    # gen_data returns an array of shape [repeat_num, num_frames, 10, 3];
    # with repeat_num=1 the leading axis is dropped below.
    data_full = gen_data(manifold, repeat_num=1, kappa=kappa, N=100000, T=50., sigma=1.)

    # 2. Check all generated frames against the inequality constraint.
    data_tensor = torch.from_numpy(data_full).float().squeeze(0)

    # Flatten each frame to 30 coordinates for the manifold's g function.
    data_flat = data_tensor.reshape(-1, 30)
    num_frames = data_flat.shape[0]

    # Constraint values; g(x) <= 0 for valid samples.
    g_values = manifold.g(data_flat)

    # reshape(-1) rather than squeeze(): squeeze() would yield a 0-d mask when
    # there is exactly one frame, which changes boolean-indexing semantics.
    # Assumes g returns one value per frame -- TODO confirm.
    valid_mask = (g_values <= 0).reshape(-1)

    # 3. Select only the frames that satisfy the constraint.
    data_filtered = data_tensor[valid_mask]

    num_valid = data_filtered.shape[0]
    print(f"Filtering complete: Kept {num_valid} / {num_frames} frames that satisfy the psi constraints.")

    # 4. Subsample from the *filtered* data: every 10th frame after a burn-in.
    # Note the burn-in is counted in surviving frames, not simulation steps.
    burn_in = 1000
    if num_valid > burn_in:
        x_prior = data_filtered[burn_in::10].reshape(-1, 10, 3).numpy()
    else:
        print("Warning: Not enough valid samples after burn-in. Using all valid samples.")
        x_prior = data_filtered.reshape(-1, 10, 3).numpy()

    print(f"Final prior dataset contains {x_prior.shape[0]} samples.")

    output_path = data_dir / f'dipeptide_prior_{int(kappa)}.npy'
    np.save(output_path, x_prior)
    print(f"Saved prior dataset to: {output_path}")



@torch.no_grad()
def refine_dataset_md(sdf, samples, tol=1e-5, max_iter_n=1000, step_size=1e-1, rk4=True, keep_quiet=False):
    """
    Drive samples onto the level set xi(x) = 0 by integrating the flow
    dx/dt = -xi(x) * grad xi(x) until |xi(x)| < tol.

    Only samples that still violate the tolerance are updated in each
    iteration; converged samples are left untouched.

    Args:
        sdf: object exposing `constrain_fn(x)` and `constrain_grad_fn(x)`.
            The update multiplies `constrain_fn(x)[:, :, None]` with
            `constrain_grad_fn(x)`, so presumably they return (n, 1) and a
            tensor shaped like `x` respectively -- confirm against the
            manifold class in use.
        samples: tensor or array-like with a leading batch dimension. The
            input is never modified; a float64 copy is refined.
        tol: absolute tolerance on |xi(x)|.
        max_iter_n: maximum number of integration steps.
        step_size: integration step h.
        rk4: use classical fourth-order Runge-Kutta steps if True, explicit
            Euler steps otherwise.
        keep_quiet: suppress the periodic progress line and the empty-input
            warning.

    Returns:
        (all_samples, converged_samples), both float64 CPU tensors. The second
        contains only the rows whose final |xi| is below `tol`.
    """
    print("refine data set...")

    if isinstance(samples, torch.Tensor):
        # Use double precision for better convergence.
        samples = samples.clone().double()
        device = samples.device
    else:
        samples = torch.tensor(samples, dtype=torch.float64)
        device = torch.device("cpu")

    # If the input tensor is empty, return empty tensors immediately to avoid a RuntimeError.
    if samples.shape[0] == 0:
        if not keep_quiet:
            print("Warning: Received empty tensor in refine_dataset_md. Returning empty tensors.")
        return samples.cpu(), samples.cpu()

    def residual_gradient(x, xi=None):
        # xi(x) * grad xi(x), i.e. the gradient of 0.5 * xi(x)^2.
        if xi is None:
            xi = sdf.constrain_fn(x)
        return xi[:, :, None] * sdf.constrain_grad_fn(x)

    active_idx = torch.arange(samples.shape[0], dtype=torch.int64, device=device)

    iter_n = 0
    while iter_n < max_iter_n:
        if len(active_idx) == 0:
            print("break: no active samples remain.")
            break

        # Calculate the constraint violation (residual).
        x_active = samples[active_idx]
        xi_vals = sdf.constrain_fn(x_active)
        error = torch.abs(xi_vals).squeeze(dim=1)

        # Identify samples that do not meet the tolerance.
        bad_idx_mask = (error >= tol)
        if not torch.any(bad_idx_mask):
            print("break: all active samples converged.")
            break
        if iter_n % 50 == 0 and not keep_quiet:
            print(f'iter {iter_n}: max_err={torch.max(error):.3e}, {bad_idx_mask.sum()} bad states, tol={tol:.3e}')

        # Only update the samples that violate the constraint.
        active_idx = active_idx[bad_idx_mask]
        x = x_active[bad_idx_mask]
        # Reuse the residual computed above instead of calling constrain_fn again.
        k1 = residual_gradient(x, xi_vals[bad_idx_mask])

        if rk4:
            # The flow is x' = -k(x), so every intermediate stage must be
            # evaluated on the descent side (x - h * k), matching the sign of
            # the final update.
            k2 = residual_gradient(x - k1 * step_size * 0.5)
            k3 = residual_gradient(x - k2 * step_size * 0.5)
            k4 = residual_gradient(x - k3 * step_size)
            samples[active_idx] = x - (k1 + 2 * k2 + 2 * k3 + k4) * step_size / 6.0
        else:
            samples[active_idx] = x - k1 * step_size
        iter_n += 1

    final_xi_vals = sdf.constrain_fn(samples)
    max_error = torch.max(torch.abs(final_xi_vals).squeeze())

    print(f'total steps={iter_n}, final error: {max_error: .3e}.')
    if max_error > tol * 1.1:
        print(f'Warning: tolerance ({tol: .3e}) not reached!')

    final_errors = torch.abs(final_xi_vals).squeeze(dim=-1)
    mask = (final_errors < tol)
    samples_abridged = samples[mask, ...]
    print(f"Refinement complete: {samples_abridged.shape[0]} / {samples.shape[0]} samples met the tolerance.")

    return samples.detach().cpu(), samples_abridged.detach().cpu()


def plot_phi_psi_hist(phi, psi, savefig=""):
    """Plot density histograms of phi and psi (in degrees) side by side and save them.

    The figure is written to
    './datasets/figs/dipeptide/dipeptide_Hist_<savefig>.png'; the directory is
    created if it does not exist. Nothing is plotted for empty input.

    Args:
        phi, psi: tensors of angles in radians; `phi.shape[0]` is taken as the
            sample count and determines the number of bins (one bin per 100
            samples, or 10 bins below 100 samples).
        savefig: suffix used in the output file name.
    """
    n_samples = phi.shape[0]
    print(f"Plotting histogram for {n_samples} samples.")

    # Bail out before any figure is created so nothing has to be cleaned up.
    if n_samples == 0:
        print(f"Warning: No data to plot for '{savefig}'. Skipping histogram generation.")
        return

    bins = n_samples // 100 if n_samples >= 100 else 10

    out_path = pathlib.Path(f'./datasets/figs/dipeptide/dipeptide_Hist_{savefig}.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(10, 5))
    try:
        for column, (angles, title) in enumerate(((phi, "phi"), (psi, "psi")), start=1):
            ax = plt.subplot(1, 2, column)
            # detach()/cpu() so GPU or grad-tracking tensors can be plotted too.
            degrees = angles.detach().cpu().reshape(-1).numpy() / torch.pi * 180
            ax.hist(degrees, bins=bins, alpha=1.0, density=True)
            ax.set_title(title)
        plt.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


if __name__ == "__main__":
    # Pipeline: raw data -> phi projection -> psi-window filtering ->
    # reference / center structures -> prior dataset.

    # --- CLI argument parsing ---
    parser = argparse.ArgumentParser(description="Build alanine-dipeptide dataset with hard phi (projection) and psi (filtering) constraints.")
    parser.add_argument("--seed", type=int, default=777, help="Random seed")
    # NOTE: a value that starts with '-' (e.g. a negative lower bound) must be
    # passed as `--psi_windows=-30,-10;140,160`, otherwise argparse treats it
    # as an option string.
    parser.add_argument(
        "--psi_windows",
        type=str,
        default="130,170",
        help="Semicolon-separated degree windows for psi, e.g. '-30,-10;140,160' (lo,hi per window)",
    )
    parser.add_argument("--phi_tol", type=float, default=1e-5, help="Tolerance for equality phi projection")
    parser.add_argument("--phi_step", type=float, default=1e-3, help="Step size for phi projection")
    args = parser.parse_args()

    def _parse_windows(s):
        """Parse 'lo,hi;lo,hi;...' into a list of (lo, hi) float tuples."""
        wins = []
        for chunk in s.split(";"):
            chunk = chunk.strip()
            if not chunk:
                continue
            try:
                lo, hi = (float(v) for v in chunk.split(","))
            except ValueError:
                # Covers both a wrong number of fields and non-numeric fields.
                parser.error(f"--psi_windows: expected 'lo,hi' but got {chunk!r}")
            wins.append((lo, hi))
        return wins
    psi_windows = _parse_windows(args.psi_windows)

    # --- Setup paths and directories ---
    data_dir = pathlib.Path("./data/dipeptide")
    figs_dir = pathlib.Path("./datasets/figs/dipeptide")
    data_dir.mkdir(parents=True, exist_ok=True)
    figs_dir.mkdir(parents=True, exist_ok=True)

    # --- Run data generation pipeline ---
    set_seed_everywhere(args.seed)

    # 0) Load raw data.
    x_raw = torch.tensor(np.load(data_dir / "dipeptide.npy")).float()
    manifold = Manifold_MD(psi_windows=psi_windows)

    # Plot histogram of the raw data.
    phi_raw, psi_raw = manifold.angle_phi(x_raw), manifold.angle_psi(x_raw)
    plot_phi_psi_hist(phi_raw, psi_raw, savefig="ori")

    # 1) Apply hard equality constraint on phi via projection.
    eq_manifold = Manifold_MD(psi_windows=[])
    _, x_phi = refine_dataset_md(
        eq_manifold, x_raw,
        tol=args.phi_tol, max_iter_n=10000,
        step_size=args.phi_step, rk4=True, keep_quiet=False
    )
    np.save(data_dir / "dipeptide_refined_phi.npy", x_phi.numpy())
    np.save(data_dir / "dipeptide_ref_phi.npy", x_phi[0].numpy() if x_phi.shape[0] > 0 else np.array([]))

    # Plot histogram after phi refinement.
    phi_refined, psi_refined = manifold.angle_phi(x_phi), manifold.angle_psi(x_phi)
    plot_phi_psi_hist(phi_refined, psi_refined, savefig="refined_phi")

    # 2) Filter the data based on psi angle inequality constraints.
    if x_phi.shape[0] > 0:
        print(f"\nFiltering {x_phi.shape[0]} samples based on psi windows: {psi_windows}")

        # Flatten to one angle per sample. reshape(-1) (not squeeze()) keeps a
        # 1-D mask even when a single sample is left, so boolean indexing
        # below always selects along the sample axis.
        psi_angles_deg = (manifold.angle_psi(x_phi) / torch.pi * 180.0).reshape(-1)

        # Start from an all-False mask and OR in every window.
        final_mask = torch.zeros_like(psi_angles_deg, dtype=torch.bool)
        for (low, high) in psi_windows:
            window_mask = (psi_angles_deg >= low) & (psi_angles_deg <= high)
            final_mask = final_mask | window_mask

        x_phi_psi = x_phi[final_mask]
        print(f"Kept {x_phi_psi.shape[0]} samples after psi filtering.")
    else:
        print("\nSkipping psi filtering because the phi-refined dataset is empty.")
        x_phi_psi = x_phi  # Keep it as an empty tensor

    # 3) Save the final results.
    if x_phi_psi.shape[0] == 0:
        print("\n[Warning] No samples satisfied the specified phi and psi constraints.")
        print("Saving an empty file for the final dataset.")
        np.save(data_dir / "dipeptide_refined_phi_psiwin.npy", x_phi_psi.numpy())
        print("Skipping prior dataset generation because no reference structure is available.")
    else:
        np.save(data_dir / "dipeptide_refined_phi_psiwin.npy", x_phi_psi.numpy())
        np.save(data_dir / "dipeptide_ref_phi_psiwin.npy", x_phi_psi[0].numpy())

        # Plot histogram of the final dataset.
        phi_final, psi_final = manifold.angle_phi(x_phi_psi), manifold.angle_psi(x_phi_psi)
        plot_phi_psi_hist(phi_final, psi_final, savefig="refined_phi_psiwin")

        # 4) Pick the psi-well center FROM THE CONSTRAINED SET.
        print("\nFinding and saving well centers...")
        psi_deg = (manifold.angle_psi(x_phi_psi).reshape(-1) / torch.pi * 180.0)

        # Find the sample closest to 150 degrees.
        idx = torch.argmin((psi_deg - (150.0)).abs())
        x_center = x_phi_psi[idx].clone()
        np.save(data_dir / "dipeptide_center.npy", x_center.numpy())
        print(f"  - Saved center_1 (psi ≈ {psi_deg[idx]:.2f} deg)")

        # 5) Generate the prior dataset with kappa=50.
        # This must run after 'dipeptide_ref_phi_psiwin.npy' has been written
        # above: gen_data() loads that file as its reference structure, so
        # running earlier either fails on a fresh checkout or silently uses a
        # stale reference from a previous run.
        generate_dipeptide_prior_data(kappa=50., psi_windows=psi_windows)

        print("\n[Done]")
        print(f" Saved:\n  - {data_dir/'dipeptide_refined_phi_psiwin.npy'}  (constrained dataset)")
        print(f"  - {data_dir/'dipeptide_ref_phi_psiwin.npy'}      (x_ref for this dataset)")
        print(f" Figures in: {figs_dir}")

