import os
import site
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# --- CONFIGURATION ---
# Path of the torch-serialized dataset handed to load_data() by main().
DATA_PATH = "demo_data_sand.pt"
# Destination file of the animation written by generate_2d_video_stoc().
OUTPUT_VIDEO = "giorom_stochastic_demo.mp4"
# Particle stride: every SPARSITY_FACTOR-th point of a frame forms the sparse input.
SPARSITY_FACTOR = 20
# Number of shuffled frame indices grouped per training step; an incomplete
# trailing group is skipped by the training loop.
BATCH_SIZE = 4
# Learning rate passed to optax.adam.
LEARNING_RATE = 1e-3
# Number of passes over the shuffled frame indices during training.
EPOCHS = 50
# Per-axis resolution of the model's splatting grid (GIOROM_online.grid_res).
GRID_RES = 48

# ==========================================
# 1. GIOROM MODEL
# ==========================================
class GIOROM_online(nn.Module):
    """Splat sparse displacements onto a grid and decode them at query points.

    The forward pass has three stages:

    1. Encode: the displacement ``x_source_curr - x_source_ref`` of every
       source point is passed through a 3-unit dense layer + GELU and
       concatenated with a constant "density" channel of ones and the raw
       displacement.
    2. Splat: every source point scatters its channels onto the 8 grid cells
       surrounding its (clipped) reference position, weighted trilinearly.
    3. Stochastic decode: the grid is sampled with linear interpolation at
       ``n_paths`` Gaussian-perturbed copies of each query position, the
       samples are averaged, feature and displacement channels are divided by
       the sampled density channel, and a small MLP adds a residual to the
       normalized displacement.

    The column indexing and the 3-component noise imply ``(N, 3)`` inputs.
    Calling the module requires an RNG stream named ``'stochastic'``.
    """
    # Width of the final dense layer. Its output is added to the 3-channel
    # normalized displacement, so it must broadcast against that.
    output_dim: int
    # Number of grid cells along each of the three axes.
    grid_res: int = 64
    # dtype handed to the dense layers, the grid and the sampling noise.
    dtype: jnp.dtype = jnp.float32
    # Standard deviation (in grid-cell units) of the query-position noise.
    noise_sigma: float = 0.8
    # Number of perturbed lookups averaged per query point (must be >= 1).
    n_paths: int = 8

    @nn.compact
    def __call__(self, x_query_ref, x_source_ref, x_source_curr):
        """Predict a value per query point from the sparse source displacement.

        Args:
            x_query_ref: reference positions at which to decode; clipped to
                ``[0, 1]`` before use.
            x_source_ref: reference positions of the sparse source points;
                clipped to ``[0, 1]`` before use.
            x_source_curr: current positions of the same source points.

        Returns:
            ``clip(u_norm + residual, 0.001, 0.999)`` with one row per query
            point, where ``u_norm`` is the density-normalized sampled
            displacement and ``residual`` is the MLP output.

        NOTE(review): the result is built from a displacement plus a learned
        residual, while the callers in this file regress it against
        normalized positions -- confirm that this is intended.
        """
        res = self.grid_res
        x_source_ref = jnp.clip(x_source_ref, 0.0, 1.0)
        x_query_ref = jnp.clip(x_query_ref, 0.0, 1.0)

        # 1. Encode
        u_source = x_source_curr - x_source_ref
        f_trans = nn.gelu(nn.Dense(3, dtype=self.dtype)(u_source))

        # The constant-one channel accumulates the splat weights, giving the
        # normalizer used after sampling.
        density_tag = jnp.ones((f_trans.shape[0], 1), dtype=self.dtype)
        content = jnp.concatenate([f_trans, density_tag, u_source], axis=-1)
        n_channels = content.shape[-1]

        # 2. Splat (branchless at trace time: the corner loop is unrolled)
        # Clipping to res - 1.001 keeps base_idx + 1 inside the grid.
        coords = jnp.clip(x_source_ref * (res - 1), 0.0, res - 1.001)
        base_idx = jnp.floor(coords).astype(jnp.int32)
        frac = coords - base_idx

        flat_grid = jnp.zeros((res**3, n_channels), dtype=self.dtype)
        corners = ((0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
                   (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1))
        for o_x, o_y, o_z in corners:
            # Trilinear weight: `frac` towards the upper corner, `1 - frac`
            # towards the lower one.
            w_x = frac[:, 0] if o_x else 1.0 - frac[:, 0]
            w_y = frac[:, 1] if o_y else 1.0 - frac[:, 1]
            w_z = frac[:, 2] if o_z else 1.0 - frac[:, 2]
            weight = (w_x * w_y * w_z)[:, None]

            flat_indices = ((base_idx[:, 0] + o_x) * res**2
                            + (base_idx[:, 1] + o_y) * res
                            + (base_idx[:, 2] + o_z))
            flat_grid = flat_grid.at[flat_indices].add(content * weight)

        grid_raw = flat_grid.reshape((res, res, res, n_channels))
        # Channel-first layout so each channel volume can be sampled by vmap;
        # hoisted because it does not depend on the noise draw.
        grid_channels_first = jnp.moveaxis(grid_raw, -1, 0)

        # 3. Stochastic Decode
        q_coords = x_query_ref * (res - 1)
        rng = self.make_rng('stochastic')
        noise = jax.random.normal(
            rng, (self.n_paths, q_coords.shape[0], 3), dtype=self.dtype
        ) * self.noise_sigma

        def look_up(noise_vector):
            # Sample every channel at the perturbed query positions;
            # mode='nearest' handles positions pushed outside the grid.
            pts = q_coords + noise_vector

            def sample_channel(vol):
                return jax.scipy.ndimage.map_coordinates(vol, pts.T, order=1, mode='nearest')

            return jax.vmap(sample_channel)(grid_channels_first).T

        samples = jax.vmap(look_up)(noise)
        raw_sampled = jnp.mean(samples, axis=0)

        # Channel layout mirrors `content`: [features..., density, u (3)].
        grid_f = raw_sampled[..., :-4]
        grid_d = raw_sampled[..., -4:-3]
        grid_u = raw_sampled[..., -3:]
        # Normalize by the sampled density; zero out queries with (almost)
        # no splatted mass.
        denom = jnp.maximum(grid_d, 1e-5)
        mask_dec = (grid_d > 1e-5).astype(self.dtype)
        f_norm = (grid_f / denom) * mask_dec
        u_norm = (grid_u / denom) * mask_dec

        def get_pe(x, L=3):
            # sin/cos features at frequencies 2**i * pi for i in [0, L).
            bands = []
            for i in range(L):
                freq = 2.0**i * jnp.pi
                bands.append(jnp.sin(freq * x))
                bands.append(jnp.cos(freq * x))
            return jnp.concatenate(bands, axis=-1)

        pe = get_pe(x_query_ref).astype(self.dtype)
        decoder_input = jnp.concatenate([f_norm, u_norm, pe], axis=-1)

        x = nn.gelu(nn.Dense(64, dtype=self.dtype)(decoder_input))
        x = nn.gelu(nn.Dense(64, dtype=self.dtype)(x))
        residual_drift = nn.Dense(self.output_dim, dtype=self.dtype)(x)
        pred = jnp.clip(u_norm + residual_drift, 0.001, 0.999)

        return pred

# ==========================================
# 2. UTILITIES (VIDEO ADDED HERE)
# ==========================================
def load_data(path):
    """Load the ``'position'`` entry of a torch file and min-max normalize it.

    The entry may be a tensor, a NumPy array, or a list of tensors / NumPy
    arrays. List items are truncated along their first axis to the shortest
    item and stacked along a new leading axis.

    Args:
        path: location of a file readable by ``torch.load`` that contains a
            mapping with a ``'position'`` key.

    Returns:
        ``(norm, dmin, dmax)`` where ``norm`` is the data rescaled with the
        global minimum ``dmin`` and maximum ``dmax``. When the data is
        constant (``dmax == dmin``) ``norm`` is all zeros instead of NaN.

    Raises:
        ValueError: if ``'position'`` is an empty list.
    """
    print(f"Loading {path}...")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    data_raw = torch.load(path, map_location='cpu', weights_only=False)['position']

    if isinstance(data_raw, list):
        if not data_raw:
            raise ValueError(f"'position' in {path} is an empty list")
        tensor_list = [
            torch.from_numpy(item) if isinstance(item, np.ndarray) else item
            for item in data_raw
        ]
        # Truncate every item to the shortest one so they can be stacked.
        min_size = min(t.shape[0] for t in tensor_list)
        data_raw = torch.stack([t[:min_size] for t in tensor_list])
    elif isinstance(data_raw, np.ndarray):
        data_raw = torch.from_numpy(data_raw)

    # Heuristic: a very large axis 1 is swapped with axis 2.
    # NOTE(review): presumably this moves a particle axis behind the time
    # axis for files stored particle-major -- confirm against the data format.
    if data_raw.shape[1] > 5000:
        data_raw = data_raw.permute(0, 2, 1, 3)

    raw_np = data_raw.numpy()
    dmin, dmax = raw_np.min(), raw_np.max()
    span = dmax - dmin
    if span == 0:
        # Constant data: avoid 0/0 and map everything to 0.
        span = raw_np.dtype.type(1)
    norm = (raw_np - dmin) / span
    return norm, dmin, dmax

def chamfer_distance(pred, gt, samples=2048):
    """Return the symmetric Chamfer distance (squared Euclidean) of two clouds.

    The result is ``mean_i min_j |pred_i - gt_j|^2 + mean_j min_i |pred_i - gt_j|^2``.

    To bound the size of the pairwise distance matrix, clouds with more than
    ``samples`` points are randomly subsampled (via ``np.random``) first:

    * If both clouds have the same number of points, the *same* random
      indices are used for both, so identical clouds still give exactly 0.
    * Otherwise each cloud is subsampled independently with indices drawn
      from its own size. Indexing ``gt`` with indices drawn from ``pred``'s
      size would go out of range.

    Args:
        pred: array of shape ``(N, D)``.
        gt: array of shape ``(M, D)``.
        samples: maximum number of points kept per cloud; ``<= 0`` disables
            subsampling.

    Returns:
        A scalar JAX array.
    """
    n_pred, n_gt = pred.shape[0], gt.shape[0]
    if samples > 0:
        if n_pred == n_gt:
            if n_pred > samples:
                idx = np.random.choice(n_pred, samples, replace=False)
                pred, gt = pred[idx], gt[idx]
        else:
            if n_pred > samples:
                pred = pred[np.random.choice(n_pred, samples, replace=False)]
            if n_gt > samples:
                gt = gt[np.random.choice(n_gt, samples, replace=False)]

    # Pairwise squared distances, shape (len(pred), len(gt)).
    diff = pred[:, None, :] - gt[None, :, :]
    dist_sq = jnp.sum(diff ** 2, axis=-1)
    return jnp.mean(jnp.min(dist_sq, axis=1)) + jnp.mean(jnp.min(dist_sq, axis=0))

def generate_2d_video_stoc(params, video_traj, dmin, dmax, model, sparsity_factor, grid_res,
                           output_path=None):
    """Render a three-panel X/Y scatter animation of model predictions.

    The panels show, per frame, the sparse input, the ground-truth frame and
    the model prediction; a text line below reports the frame number and the
    mean squared error in de-normalized units. At most 200 frames are
    rendered. Frame 0 of ``video_traj`` serves as the reference
    configuration.

    Args:
        params: variables passed to ``model.apply``.
        video_traj: sequence of frames indexable as ``video_traj[i]``; the
            first two columns of every frame are plotted.
        dmin: minimum used to de-normalize values for the MSE readout.
        dmax: maximum used to de-normalize values for the MSE readout.
        model: object whose ``apply`` accepts
            ``(params, x_query_ref, x_source_ref, x_source_curr, rngs=...)``.
        sparsity_factor: stride used to pick the sparse source points.
        grid_res: accepted for interface compatibility; not used here.
        output_path: destination video file; defaults to ``OUTPUT_VIDEO``.

    Raises:
        ValueError: if ``video_traj`` contains no frames.
    """
    if output_path is None:
        output_path = OUTPUT_VIDEO
    if len(video_traj) == 0:
        raise ValueError("video_traj must contain at least one frame")

    frames_to_render = min(len(video_traj), 200)
    print("\n--- Generating 2D Video (Stochastic) ---")

    stride = int(sparsity_factor)
    rng_key = jax.random.PRNGKey(100)

    @jax.jit
    def inference_step(x_q_ref, x_s_ref, x_s_curr, step_key):
        return model.apply(params, x_q_ref, x_s_ref, x_s_curr,
                           rngs={'stochastic': step_key})

    x_d_ref = jnp.array(video_traj[0])
    x_s_ref = x_d_ref[::stride]

    def denorm(x):
        return x * (dmax - dmin) + dmin

    fig = plt.figure(figsize=(18, 7), facecolor='#0f0f0f')
    # try/finally so the figure is released even if rendering fails
    # (for instance when the ffmpeg writer is unavailable).
    try:
        gs = fig.add_gridspec(2, 3, height_ratios=[6, 1], hspace=0.1)

        titles = [f"Sparse Input (1/{sparsity_factor})", "Ground Truth", "GIOROM (Stochastic)"]
        colors = ['#00FFFF', '#00FF00', '#FF00FF']
        scatters = []

        for col, (title, color) in enumerate(zip(titles, colors)):
            ax = fig.add_subplot(gs[0, col])
            ax.set_facecolor('black')
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_title(title, color=color, fontsize=14, fontweight='bold')
            scatters.append(ax.scatter([], [], s=1.0, alpha=0.6, c=color, edgecolors='none'))

        text_ax = fig.add_subplot(gs[1, :])
        text_ax.set_facecolor('#0f0f0f')
        text_ax.axis('off')
        hud_text = text_ax.text(0.5, 0.5, "Init...", ha='center', va='center', color='white',
                                fontfamily='monospace', fontsize=12)

        def update(i):
            nonlocal rng_key
            x_d_curr = jnp.array(video_traj[i])
            x_s_curr = x_d_curr[::stride]

            # Fresh sub-key per frame; the second half is carried forward.
            step_key, rng_key = jax.random.split(rng_key)

            pred = inference_step(x_d_ref, x_s_ref, x_s_curr, step_key)
            pred.block_until_ready()

            mse = float(jnp.mean((denorm(x_d_curr) - denorm(pred))**2))

            # Visualize only the first 2 dimensions (X, Y)
            s_np = np.array(x_s_curr, dtype=np.float32)
            d_np = np.array(x_d_curr, dtype=np.float32)
            p_np = np.array(pred, dtype=np.float32)

            scatters[0].set_offsets(s_np[:, :2])
            scatters[1].set_offsets(d_np[:, :2])
            scatters[2].set_offsets(p_np[:, :2])

            hud_text.set_text(f"Frame: {i:03d} | MSE: {mse:.2e}")
            return scatters + [hud_text]

        print(f"Rendering to {output_path}...")
        ani = animation.FuncAnimation(fig, update, frames=tqdm(range(frames_to_render)),
                                      interval=1, blit=True)
        ani.save(output_path, fps=30, dpi=100, writer='ffmpeg')
    finally:
        plt.close(fig)
    print("Video Done.")

# ==========================================
# 3. MAIN EXECUTION
# ==========================================
def _train(model, train_traj, key):
    """Initialize the model and fit it on trajectory 0 of ``train_traj``.

    Frame 0 is the reference configuration. Each epoch shuffles the frame
    indices, walks them in groups of ``BATCH_SIZE`` (incomplete groups are
    skipped) and performs ONE Adam step per group on a single frame drawn at
    random from that group.

    Returns the trained variables.
    """
    stride = int(SPARSITY_FACTOR)
    x_d_ref = jnp.array(train_traj[0, 0])
    x_s_ref = x_d_ref[::stride]

    # Three-way split so the key carried into the training loop is not the
    # one already consumed by init (JAX keys must not be reused).
    key, key_params, key_init = jax.random.split(key, 3)
    params = model.init({'params': key_params, 'stochastic': key_init}, x_d_ref, x_s_ref, x_s_ref)

    tx = optax.adam(LEARNING_RATE)
    opt_state = tx.init(params)

    @jax.jit
    def train_step(state, params, x_q, x_s, x_d_t, x_s_t, k):
        def loss_fn(p):
            pred = model.apply(p, x_q, x_s, x_s_t, rngs={'stochastic': k})
            return jnp.mean((pred - x_d_t)**2)
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, new_state = tx.update(grads, state, params)
        return optax.apply_updates(params, updates), new_state, loss

    n_frames = train_traj.shape[1]
    print(f"Training on 1 Trajectory ({n_frames} frames)...")

    for epoch in range(EPOCHS):
        indices = np.random.permutation(n_frames)
        total_loss = 0.0
        steps = 0
        for start in range(0, n_frames, BATCH_SIZE):
            batch_idx = indices[start:start + BATCH_SIZE]
            if len(batch_idx) < BATCH_SIZE:
                continue
            key, subkey = jax.random.split(key)
            idx = np.random.choice(batch_idx)
            x_curr = jnp.array(train_traj[0, idx])
            x_s_curr = x_curr[::stride]
            params, opt_state, loss = train_step(
                opt_state, params, x_d_ref, x_s_ref, x_curr, x_s_curr, subkey
            )
            total_loss += loss.item()
            steps += 1
        if epoch % 10 == 0:
            if steps:
                print(f"Epoch {epoch}: Loss {total_loss/steps:.6f}")
            else:
                # Fewer than BATCH_SIZE frames: no step was taken.
                print(f"Epoch {epoch}: no full batch of {BATCH_SIZE} frames, nothing trained")

    return params


def _evaluate(model, params, eval_traj, dmin, dmax):
    """Run the model over trajectory 0 of ``eval_traj`` and collect metrics.

    Predictions are de-normalized with ``dmin``/``dmax`` and, from the second
    frame on, blended with the previous blended prediction
    (``0.2 * new + 0.8 * previous``) before being scored.

    Returns ``(l2_errors, chamfers, times_ms)`` with one entry per frame.
    """
    stride = int(SPARSITY_FACTOR)

    @jax.jit
    def eval_step(x_q, x_s, x_s_t, k):
        return model.apply(params, x_q, x_s, x_s_t, rngs={'stochastic': k})

    eval_frames = eval_traj.shape[1]
    ref_frame = jnp.array(eval_traj[0, 0])
    sparse_ref = ref_frame[::stride]

    # Warm-up call so JIT compilation is not counted as inference time.
    eval_step(ref_frame, sparse_ref, sparse_ref, jax.random.PRNGKey(0)).block_until_ready()

    l2_errors, chamfers, times = [], [], []
    key_eval = jax.random.PRNGKey(42)
    prev_phys = None
    for i in tqdm(range(eval_frames)):
        gt_frame = jnp.array(eval_traj[0, i])
        sparse_curr = gt_frame[::stride]
        key_eval, subkey = jax.random.split(key_eval)

        t0 = time.perf_counter()
        pred = eval_step(ref_frame, sparse_ref, sparse_curr, subkey)
        pred.block_until_ready()
        t1 = time.perf_counter()

        pred_phys = pred * (dmax - dmin) + dmin
        if prev_phys is not None:
            # Temporal smoothing against the previous smoothed prediction.
            pred_phys = 0.2 * pred_phys + 0.8 * prev_phys
        prev_phys = pred_phys
        gt_phys = gt_frame * (dmax - dmin) + dmin

        err = jnp.linalg.norm(pred_phys - gt_phys)
        norm = jnp.linalg.norm(gt_phys)
        l2_errors.append(float(err / (norm + 1e-8)))
        chamfers.append(chamfer_distance(pred_phys, gt_phys).item())
        times.append((t1 - t0) * 1000)

    return l2_errors, chamfers, times


def _print_report(l2_errors, chamfers, times):
    """Print the summary table of the evaluation metrics."""
    print("\n============================================")
    print("        GIOROM DEMO RESULTS")
    print("============================================")
    print(f"Relative L2:       {np.mean(l2_errors):.2%} ± {np.std(l2_errors):.2%}")
    print(f"Chamfer Dist:      {np.mean(chamfers):.4e}")
    print(f"Inference Time:    {np.mean(times):.2f} ms")
    # memory_stats() may return None on platforms without memory statistics.
    stats = jax.devices()[0].memory_stats()
    if stats and 'bytes_in_use' in stats:
        print(f"VRAM Usage:        {stats['bytes_in_use'] / 1024**2:.1f} MB")
    else:
        print("VRAM Usage:        n/a")
    print("============================================")


def main():
    """Train on the first trajectory, evaluate, print metrics and render a video.

    Trajectory 0 of the dataset is used for training; trajectory 1 is used
    for evaluation when it exists, otherwise trajectory 0 is reused.
    """
    print(f"JAX Device: {jax.devices()[0]}")
    data, dmin, dmax = load_data(DATA_PATH)

    train_traj = data[0:1]
    eval_traj = data[1:2] if data.shape[0] > 1 else data[0:1]

    model = GIOROM_online(3, grid_res=GRID_RES)
    params = _train(model, train_traj, jax.random.PRNGKey(0))

    print("\n--- Evaluating... ---")
    l2_errors, chamfers, times = _evaluate(model, params, eval_traj, dmin, dmax)
    _print_report(l2_errors, chamfers, times)

    # --- GENERATE VIDEO ---
    # Pass the full evaluation trajectory as a NumPy array.
    video_data = np.array(eval_traj[0])  # Shape: (Time, N, 3)
    generate_2d_video_stoc(params, video_data, dmin, dmax, model, SPARSITY_FACTOR, GRID_RES)

# Script entry point: run the demo only when the file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()