#!/usr/bin/env python3
"""
Main WhisperSplat script for embedding a hidden 2D image into a pretrained 3D Gaussian Splatting model.
"""

import os
import sys
import math
import argparse
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from scene             import Scene, GaussianModel
from gaussian_renderer import render
from arguments         import ModelParams, PipelineParams, OptimizationParams
from utils.loss_utils  import l1_loss, ssim
from utils.image_utils import psnr

# wandb is optional: it is only required when --use_wandb is passed.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

parser = argparse.ArgumentParser(
    description="Embed a hidden 2D image into a pretrained 3DGS model using trainable noise."
)

# Stock 3DGS option groups. The handles are kept so each group can later be
# pulled back out of the parsed namespace with `.extract(args)`.
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)

# Which view carries the hidden image, and what that image is.
parser.add_argument(
    "-v", "--view_index",
    type=int, required=True, help="Fixed view index",
)
parser.add_argument(
    "-t", "--target_image",
    type=str, required=True, help="Path to target image",
)
# Output locations.
parser.add_argument(
    "--out_noise",
    type=str, required=True, help="Prefix for noise tensors",
)
parser.add_argument(
    "--out_model",
    type=str, required=True, help="Path for watermarked model",
)
parser.add_argument(
    "--out_render",
    type=str, required=True, help="Path for best-view PNG",
)
# Optimization schedule.
parser.add_argument(
    "--num_iters",
    type=int, default=500, help="Number of steps",
)
parser.add_argument(
    "--lr",
    type=float, default=1e-2, help="Learning rate",
)
parser.add_argument(
    "--save_renders_dir",
    type=str, required=True, help="Dir for per-iter renders",
)
# Logging / runtime.
parser.add_argument(
    "--use_wandb",
    action="store_true", help="Enable wandb logging",
)
parser.add_argument(
    "--wandb_project",
    type=str, default="3DGS_Steganography", help="wandb project",
)
parser.add_argument(
    "--wandb_run_name",
    type=str, default=None, help="wandb run name",
)
parser.add_argument(
    "--device",
    type=str, default="cuda", help="Torch device",
)
args = parser.parse_args()

# Optional Weights & Biases logging: fail fast when it was requested but the
# package could not be imported at the top of the file.
if args.use_wandb and not WANDB_AVAILABLE:
    raise ImportError("pip install wandb to log")
if args.use_wandb:
    wandb.init(
        project=args.wandb_project,
        name=args.wandb_run_name,
        config=vars(args),
    )

# Split the parsed namespace back into the three 3DGS parameter groups
# (evaluated in order: model, optimization, pipeline).
dataset_params, opt_params, pipe_params = (
    group.extract(args) for group in (lp, op, pp)
)

# Pick the device; when CUDA is unavailable this falls back to CPU regardless
# of --device.
device = torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")

print(f"Loading pretrained 3DGS from '{dataset_params.model_path}'")
gaussians = GaussianModel(dataset_params.sh_degree, opt_params.optimizer_type)
# NOTE(review): load_iteration=-1 presumably selects the latest saved
# iteration inside Scene — confirm against scene/__init__.py.
scene = Scene(dataset_params, gaussians, load_iteration=-1, shuffle=False)

# Keep handles to the live SH Parameters (they are put back on the model after
# every step) plus detached copies that act as the base the noise is added to.
orig_param_dc, orig_param_rest = gaussians._features_dc, gaussians._features_rest
orig_dc = orig_param_dc.detach().clone().to(device)
orig_rest = orig_param_rest.detach().clone().to(device)

# Freeze every Parameter stored directly on the model; only the noise tensors
# created below are optimized.
for attr_value in vars(gaussians).values():
    if not isinstance(attr_value, torch.nn.Parameter):
        continue
    attr_value.requires_grad_(False)

# The "key": additive offsets on the SH coefficients. They start at zero, so the
# first perturbed render uses exactly the original coefficients.
print("Creating noise tensors")
noise_dc, noise_rest = (
    torch.zeros_like(base, device=device).requires_grad_()
    for base in (orig_dc, orig_rest)
)

# Load the hidden target and resample it to the resolution of the chosen view.
print(f"Loading target image '{args.target_image}'")
with Image.open(args.target_image) as src_img:
    rgb_img = src_img.convert("RGB")

sel_cam = scene.getTrainCameras()[args.view_index]
H, W = sel_cam.original_image.shape[1:]
rgb_img = rgb_img.resize((W, H), Image.LANCZOS)  # PIL sizes are (width, height)

# HWC uint8 -> HWC float in [0, 1] -> 1xCxHxW tensor on the working device.
np_img = np.array(rgb_img, dtype=np.float32) / 255.0
target_tensor = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).to(device)

if args.use_wandb:
    wandb.log({"target": wandb.Image(np_img, caption="Hidden Target")})

# Reference render of the chosen view with the untouched SH Parameters on a
# white background; used by the perturbation term of the loss.
print("Computing clean render")
with torch.no_grad():
    gaussians._features_dc, gaussians._features_rest = orig_param_dc, orig_param_rest
    clean_render = render(
        sel_cam, gaussians, pipe_params, torch.ones(3, device=device)
    )["render"].unsqueeze(0)


# Tunable reconstruction weights: lambda1 scales the L1 term, lambda2 the
# (1 - SSIM) term.
lambda1 = 0.5
lambda2 = 1.0

os.makedirs(args.save_renders_dir, exist_ok=True)
print(f"Saving renders to '{args.save_renders_dir}'")

# Only the two noise tensors are optimized; the model's Parameters are frozen.
optimizer = torch.optim.Adam(params=[noise_dc, noise_rest], lr=args.lr)

print(f"Optimizing for {args.num_iters} iters @ lr={args.lr}")
best_loss, best_render = float("inf"), None

# Optimization loop: render the fixed view with (original SH + noise), pull it
# toward the hidden target, and remember the frame with the lowest L1.
bg_color = torch.ones(3, device=device)  # white background, same as the clean render
for t in tqdm(range(args.num_iters), desc="Iter", ncols=80):
    optimizer.zero_grad()

    # Temporarily replace the SH features with the differentiable perturbed version.
    gaussians._features_dc   = orig_dc   + noise_dc
    gaussians._features_rest = orig_rest + noise_rest

    out  = render(sel_cam, gaussians, pipe_params, bg_color)
    pert = out["render"][None]

    # Reconstruction losses against the hidden target; SSIM should be high, so
    # (1 - SSIM) is minimized.
    L_l1   = l1_loss(pert, target_tensor)
    L_ssim = 1.0 - ssim(pert, target_tensor).mean()

    # Perturbation term: negative MSE to the clean render. Minimizing it pushes the
    # render away from the clean view. Its weight follows a cosine schedule that
    # starts at 1.0 and decays toward 0.0 (t < num_iters, so it never reaches 0).
    alpha  = 0.5 * (1.0 + math.cos(math.pi * t / args.num_iters))
    L_pert = -((pert - clean_render).pow(2).mean())

    L_total = lambda1 * L_l1 + lambda2 * L_ssim + alpha * L_pert
    L_total.backward()
    optimizer.step()

    # Put the original Parameters back so the model object is unchanged between steps.
    gaussians._features_dc   = orig_param_dc
    gaussians._features_rest = orig_param_rest

    # Metrics and per-iteration snapshot. `pert` was rendered with the noise from
    # *before* this step's optimizer update.
    l1_val = L_l1.item()
    with torch.no_grad():
        psnr_val = psnr(pert, target_tensor).mean().item()
        ssim_val = ssim(pert, target_tensor).mean().item()
    img_np = (pert.clamp(0, 1)[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
    Image.fromarray(img_np).save(os.path.join(args.save_renders_dir, f"iter_{t:04d}.png"))

    if args.use_wandb:
        wandb.log({
            "step":       t,
            "L1_loss":    l1_val,
            "L_ssim":     L_ssim.item(),
            "L_pert":     L_pert.item(),
            "alpha":      alpha,
            "total_loss": L_total.item(),
            "PSNR":       psnr_val,
            "SSIM":       ssim_val
        })

    # Track the best frame by L1 to the target.
    if l1_val < best_loss:
        best_loss   = l1_val
        best_render = img_np.copy()

print(f"Best L1 = {best_loss:.4e}")

# Save the noise key. --out_noise may be a bare prefix ("run/key") or a ".pth"
# path ("run/key.pth"); either way the two tensors go to "<base>_dc.pth" and
# "<base>_rest.pth". Only a trailing ".pth" is stripped, so a bare prefix no
# longer makes both tensors land in (and overwrite) the same file, and a ".pth"
# inside a directory name is left alone.
print("Saving noise & model with key")
noise_base = args.out_noise[:-len(".pth")] if args.out_noise.endswith(".pth") else args.out_noise
noise_dc_path   = noise_base + "_dc.pth"
noise_rest_path = noise_base + "_rest.pth"
torch.save(noise_dc.detach().cpu(),   noise_dc_path)
torch.save(noise_rest.detach().cpu(), noise_rest_path)
if args.use_wandb:
    wandb.save(noise_dc_path)
    wandb.save(noise_rest_path)

# Bake the key into the model: overwrite the live SH Parameters in place with
# (original + noise), so the state captured afterwards contains the noise.
with torch.no_grad():
    orig_param_dc.data.copy_(orig_dc + noise_dc)
    orig_param_rest.data.copy_(orig_rest + noise_rest)

class _DummyOpt:
    """Stub optimizer whose ``state_dict()`` is empty.

    NOTE(review): ``GaussianModel.capture()`` presumably calls
    ``self.optimizer.state_dict()``, which would fail with no training
    optimizer attached — confirm.
    """

    def state_dict(self):
        return {}


# Install the stub only when the model has no optimizer (missing or None).
if getattr(gaussians, "optimizer", None) is None:
    gaussians.optimizer = _DummyOpt()

torch.save(gaussians.capture(), args.out_model)
if args.use_wandb:
    wandb.save(args.out_model)

# Save the lowest-L1 frame seen during optimization. best_render is still None
# when no step ran (--num_iters <= 0), and Image.fromarray(None) would crash
# after every other output has already been written, so skip it in that case.
if best_render is None:
    print("No optimization steps were run; skipping best render")
else:
    print(f"Saving best render to '{args.out_render}'")
    Image.fromarray(best_render).save(args.out_render)
    if args.use_wandb:
        wandb.log({"watermarked_view": wandb.Image(best_render)})

if args.use_wandb:
    wandb.finish()

