# -*- coding: utf-8 -*-
import os
import torch
from torchvision.utils import save_image
import random
import trimesh
import argparse
import warp as wp
import numpy as np
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from tqdm.autonotebook import trange

# --- Module Imports ---
from modules.nclaw.utils import denormalize_points_helper_func
from modules.nclaw.sim import (
    MPMModelBuilder,
    MPMForwardSim,
    MPMStateInitializerGaussian,
    MPMStaticsInitializerGaussian,
    MPMInitDataGaussian,
)
from modules.nclaw.material import (
    InvariantFullMetaElasticity,
    InvariantFullMetaPlasticity
)
from modules.d3gs.scene.gaussian_model import GaussianModel
from modules.tune.dataset.neuma_dataset import VideoDataset
from modules.tune.utils import (
    save_video_mediapy,
    diff_rasterization
)

# --- Global Paths ---
ASSETS_PATH = Path(__file__).parent.joinpath("assets")  # simulation assets, located next to this script
RESULT_PATH = Path("results")  # relative path: outputs are written under the current working directory

def parse_args():
    """Define the evaluation command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options. ``--config`` and ``--video_name``
        are mandatory; every other option falls back to the default shown.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Inputs and simulation length.
    add("--config", "-c",
        type=str, required=True,
        help="Path to the config file.")
    add("--eval_steps", "-s",
        type=int, default=600,
        help="Number of simulation steps.")
    add("--skip_frames", "-f",
        type=int, default=1,
        help="Frame skip rate for video packing.")
    add("--init_frame",
        type=int, default=None,
        help="Specific camera frame to load for initialization.")
    add("--load_lora", "-l",
        type=str, default=None,
        help="Path to load LoRA weights.")

    # Output handling.
    add("--remove_images", "-ri",
        action="store_true",
        help="Remove image frames after creating the video.")
    add("--video_name", "-vn",
        type=str, required=True,
        help="Filename for the saved video.")

    # Simulation / rendering overrides.
    add("--sim_dt", "-dt",
        type=float, default=None,
        help="Override simulation time step (dt).")
    add("--debug_views", "-dv",
        nargs='+', default=[],
        help="Specific views to render.")
    add("--save_particles", "-sp",
        type=str, default=None,
        help="Folder to save simulated particle positions as .ply files.")
    add("--change_base_model", "-cbm",
        type=str, default=None,
        help="Path to an alternative base model for rendering.")
    add("--dataset_path",
        type=str, default=None,
        help="Override the video dataset path.")
    add("--transform_file",
        type=str, default=None,
        help="Override the camera transforms filename (e.g., 'eval_dynamic.json').")
    add("--alpha",
        type=float, default=None,
        help="Override the alpha value for the LoRA adapter.")

    return cli.parse_args()


def _render_views(views, means3D, F, gaussians, dataset, camera_step,
                  background, scaling_modifier, image_root, frame_tag):
    """Rasterize `gaussians` at `means3D` for every view and save `<view>_<frame_tag>.png`.

    `F` is forwarded unchanged to `diff_rasterization`: the initial frame passes
    `None`, simulated frames pass the `F` tensor produced by the plasticity step.
    Cameras are taken from `dataset.getCameras(view, camera_step)`.
    """
    for view in views:
        render_results = diff_rasterization(
            means3D, F, gaussians,
            dataset.getCameras(view, camera_step), background,
            scaling_modifier=scaling_modifier,
        )
        save_image(render_results['render'], image_root / f"{view}_{frame_tag}.png")


def evaluate(cfg: DictConfig):
    """Run the MPM simulation driven by the learned material models and render it.

    Pipeline: seed the RNGs, create the output folders, load the video dataset
    (camera information) and the Gaussian kernels, build the elasticity and
    plasticity networks (optionally with LoRA adapters), initialize the MPM
    state, then simulate `cfg.eval_steps` steps. After each step, every view in
    `cfg.debug_views` is rendered. Finally, each view's image sequence is packed
    into an MP4.

    Args:
        cfg: Merged config (YAML file plus command-line options). Several
            sub-configs (`video_data`, `sim`, `particle_data` and, when `alpha`
            is given, `constitution.lora`) are modified in place to apply the
            overrides.

    Side effects:
        Writes PNG frames to ``results/<name>/images_<video_name>`` and MP4s to
        ``results/<name>/videos``. If `cfg.save_particles` is set, it also writes
        one PLY per step to ``results/<name>/states_<save_particles>``. All
        ``results`` paths are relative to the current working directory. When
        `cfg.remove_images` is set, the frame directory is deleted at the end.
    """
    # --- 1. Initialization ---
    # Seed every RNG in use for reproducibility, then initialize Warp and pick devices.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    wp.init()
    wp_device = wp.get_device(f'cuda:{cfg.gpu}')
    torch_device = torch.device(f'cuda:{cfg.gpu}')
    torch.backends.cudnn.benchmark = True

    # --- 2. Setup Paths ---
    # Define directories for loading data and saving results.
    exp_root = Path(cfg.root) / cfg.name
    tune_root = exp_root / 'finetune'
    data_root = ASSETS_PATH / cfg.sim_data_name
    image_root = RESULT_PATH / cfg.name / f'images_{cfg.video_name}'
    video_root = RESULT_PATH / cfg.name / 'videos'
    image_root.mkdir(exist_ok=True, parents=True)
    video_root.mkdir(exist_ok=True, parents=True)
    # NOTE(review): image_root is reused across runs with the same video_name; stale
    # frames left by a longer earlier run would also match the "<view>_*.png" glob below.

    state_root = None
    if cfg.save_particles:
        state_root = RESULT_PATH / cfg.name / f'states_{cfg.save_particles}'
        state_root.mkdir(parents=True, exist_ok=True)

    # --- 3. Load Data ---
    # Apply command-line overrides to the dataset config before building the dataset.
    if cfg.dataset_path:
        cfg.video_data.data.path = cfg.dataset_path
    if cfg.transform_file:
        cfg.video_data.data.transformsfile = cfg.transform_file

    cfg.video_data.device = f"cuda:{cfg.gpu}"
    cfg.video_data.data.init_frame = cfg.init_frame
    cfg.video_data.data.used_views = cfg.debug_views

    dataset = VideoDataset(cfg.video_data)
    first_step_idx = dataset.steps[0]

    gaussians = GaussianModel(cfg.gaussian.sh_degree)
    gaussians.load_ply(str(data_root / "kernels.ply"), requires_grad=False)

    # --- 4. Initialize Models and Weights ---
    # Load the base material models (elasticity, plasticity) and optional LoRA adapters.
    elasticity = InvariantFullMetaElasticity(cfg.constitution.elasticity).to(torch_device).eval()
    plasticity = InvariantFullMetaPlasticity(cfg.constitution.plasticity).to(torch_device).eval()

    # `change_base_model`, when set, takes precedence over the configured checkpoint.
    ckpt_path = cfg.get("change_base_model") or cfg.pretrained_ckpt
    pretrained = torch.load(ckpt_path, map_location=torch_device)
    elasticity.load_state_dict(pretrained['elasticity'])
    plasticity.load_state_dict(pretrained['plasticity'])
    print(f'Loaded base model weights from {ckpt_path}')

    if cfg.get("load_lora"):
        if cfg.alpha is not None:
            cfg.constitution.lora.alpha = cfg.alpha
        lora_path = tune_root / cfg.load_lora
        elasticity.init_lora_layers(r=cfg.constitution.lora.r, lora_alpha=cfg.constitution.lora.alpha)
        plasticity.init_lora_layers(r=cfg.constitution.lora.r, lora_alpha=cfg.constitution.lora.alpha)
        lora_weights = torch.load(lora_path, map_location=torch_device)
        # strict=False: loaded on top of the base weights; the file presumably holds
        # only the adapter parameters, so missing keys must not be an error.
        elasticity.load_state_dict(lora_weights['elasticity'], strict=False)
        plasticity.load_state_dict(lora_weights['plasticity'], strict=False)
        print(f'Loaded LoRA weights from {lora_path}')

    # --- 5. Setup Physics Simulation ---
    # Configure the MPM simulator with material properties and initial particle states.
    if cfg.sim_dt:
        cfg.sim.dt = cfg.sim_dt

    model = MPMModelBuilder().parse_cfg(cfg.sim).finalize(wp_device, requires_grad=False)
    sim = MPMForwardSim(model)

    cfg.particle_data.span = [0, cfg.eval_steps]
    cfg.particle_data.shape.name = f"{cfg.sim_data_name}/particles"
    init_data = MPMInitDataGaussian.get_from_gaussians(cfg.particle_data, gaussians)

    state_initializer = MPMStateInitializerGaussian(model)
    state_initializer.add_group(init_data)
    state, _ = state_initializer.finalize()

    statics_initializer = MPMStaticsInitializerGaussian(model)
    statics_initializer.add_group(init_data)
    statics = statics_initializer.finalize()

    # Set the initial velocity and load the initial material data into the simulation state.
    init_v = np.array([0.0, -5.0, 0.0], dtype=np.float32)
    dataset.set_init_x_and_v(init_x=init_data.pos, init_v=init_v)
    x, v, C, F, _ = dataset.get_init_material_data()
    state.from_torch(x=x, v=v, C=C, F=F)

    # --- 6. Run Simulation and Rendering Loop ---
    background = torch.tensor(
        [1, 1, 1] if cfg.video_data.data.get("white_background") else [0, 0, 0],
        dtype=torch.float32, device=torch_device,
    )
    scaling_modifier = cfg.gaussian.get('scaling_modifier', 1.0)

    # Zero-pad frame indices to one width (never less than the historical 3 digits)
    # that fits the last frame. This keeps file-name order equal to frame order even
    # past index 999, which matters if the frames are collected by sorted glob.
    frame_digits = max(3, len(str(int(first_step_idx) + int(cfg.eval_steps))))

    # Render and save the initial frame (t=0) from the unsimulated kernel positions.
    _render_views(cfg.debug_views, gaussians.get_xyz, None, gaussians, dataset,
                  first_step_idx, background, scaling_modifier, image_root,
                  f"{first_step_idx:0{frame_digits}d}")

    # Main simulation loop
    for step in trange(1, cfg.eval_steps + 1):
        # Elasticity maps F to stress, the simulator advances the state, then
        # plasticity updates F, which is written back for the next step.
        stress = elasticity(F)
        state.from_torch(stress=stress)
        x, v, C, F = sim(statics, state)
        F = plasticity(F)
        state.from_torch(F=F)

        # Denormalize particle positions for rendering
        means3D = denormalize_points_helper_func(x, init_data.size, init_data.center)
        frame_tag = f"{first_step_idx + step:0{frame_digits}d}"

        # NOTE(review): every frame uses the camera of `first_step_idx`; confirm this
        # is intended when a dynamic transforms file is supplied.
        _render_views(cfg.debug_views, means3D, F, gaussians, dataset,
                      first_step_idx, background, scaling_modifier, image_root,
                      frame_tag)

        # Save the simulation-space particle positions if requested
        if state_root is not None:
            point_cloud = trimesh.PointCloud(vertices=x.detach().cpu().numpy())
            point_cloud.export(state_root / f'{frame_tag}.ply')

    # --- 7. Pack Video and Cleanup ---
    # Convert each view's rendered image sequence into a video file.
    for view in cfg.debug_views:
        save_video_mediapy(
            image_root, f"{view}_*.png",
            video_root / f"{cfg.video_name}_{view}.mp4",
            skip_frames=cfg.skip_frames, fps=30,
        )

    if cfg.remove_images:
        # Delete the directory directly rather than through a shell `rm -rf`.
        # The path embeds user-supplied names (cfg.name, cfg.video_name) that a
        # shell would word-split or interpret. ignore_errors mirrors `rm -f`.
        import shutil
        shutil.rmtree(image_root, ignore_errors=True)

if __name__ == "__main__":
    cli_args = parse_args()
    # Layer the parsed command-line options on top of the YAML config.
    # NOTE(review): merge_with also applies options left at their argparse default
    # (e.g. None or []), which replaces same-named YAML keys — confirm that is intended.
    config = OmegaConf.load(cli_args.config)
    config.merge_with(vars(cli_args))

    # Evaluation only: run without building an autograd graph.
    with torch.no_grad():
        evaluate(config)