# -*- coding: UTF-8 -*-
import sys
import torch
from torch import nn
from torch.optim import RAdam
from torchvision.utils import save_image
from torch.nn.utils import clip_grad_norm_
from collections import deque
from tensorboardX import SummaryWriter
import random
import argparse
import warp as wp
import numpy as np
from pathlib import Path
from typing import Optional
from natsort import natsorted
from omegaconf import DictConfig, OmegaConf
from tqdm.autonotebook import tqdm, trange

# --- Module Imports ---
from modules.nclaw.utils import mkdir, denormalize_points_helper_func
from modules.nclaw.sim import (
    MPMModelBuilder,
    MPMStaticsInitializerGaussian,
    MPMInitDataGaussian,
    MPMCacheDiffSim
)
from modules.nclaw.material import (
    InvariantFullMetaElasticity,
    InvariantFullMetaPlasticity
)
from modules.d3gs.utils.loss_utils import l1_loss, l2_loss
from modules.d3gs.scene.gaussian_model import GaussianModel
from modules.tune.dataset.neuma_dataset import VideoDataset
from modules.tune.scheduler import fetch_scheduler
from modules.tune.utils import Logger, Timer, get_warp_device, diff_rasterization
from modules.tune.constraint.match import init_extractor
from modules.tune.constraint.preset import *
from modules.tune.lightglue.utils import rbd

# --- Global Constants ---
# Per-scene simulation assets live in an ``assets`` directory next to this file.
ASSETS_PATH = Path(__file__).parent.joinpath("assets")
# Small numerical epsilon. NOTE(review): not referenced in the visible code; kept in case other code relies on it.
EPS = 6e-7
# Lookup table from a config ``pixel_loss`` name to the corresponding pixel-wise loss function.
PIXEL_LOSSES = dict(l1=l1_loss, l2=l2_loss)

def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings, excluding the program name.
            ``None`` (the default) parses ``sys.argv[1:]`` exactly as before.
            Passing a list makes the parser usable from tests and notebooks.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute holds the config file path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, required=True, help="Path to the config file.")
    return parser.parse_args(argv)

def process_depth(depth_map, lower_percentile=2.0, upper_percentile=98.0):
    """Robustly rescale the non-zero entries of a depth map into [0, 1].

    Zero entries are treated as background. They stay zero and are excluded
    from the statistics. The remaining values are clipped to the
    ``[lower_percentile, upper_percentile]`` quantile range and mapped linearly
    onto [0, 1]. If that range is (nearly) empty, every foreground value becomes 0.5.

    Args:
        depth_map: Depth tensor of any shape.
        lower_percentile: Lower clipping percentile, in percent.
        upper_percentile: Upper clipping percentile, in percent.

    Returns:
        ``depth_map`` itself when it has no non-zero entry; otherwise a new
        tensor of the same shape holding the normalized values.
    """
    foreground = depth_map != 0
    if not foreground.any():
        return depth_map

    samples = depth_map[foreground]
    # Quantiles instead of min/max, so a few outlier depths cannot squash the range.
    lo = torch.quantile(samples, lower_percentile / 100.0)
    hi = torch.quantile(samples, upper_percentile / 100.0)
    span = hi - lo
    clipped = samples.clamp(min=lo, max=hi)

    if span < 1e-6:
        scaled = torch.full_like(clipped, 0.5)
    else:
        scaled = (clipped - lo) / span

    out = torch.zeros_like(depth_map)
    out[foreground] = scaled
    return out

def optimize_init_velocity(
    cfg: DictConfig,
    gaussians: GaussianModel,
    dataset: VideoDataset,
    dataset_mask: VideoDataset,
    elasticity: InvariantFullMetaElasticity,
    plasticity: InvariantFullMetaPlasticity,
    background: torch.Tensor,
    tune_root: Path,
):
    """Optimize the particles' initial velocity by matching rendered frames to the video.

    If ``tune_root / 'init.pt'`` exists, its ``init_x`` / ``init_v`` entries are
    loaded into ``dataset`` and nothing is optimized.

    Otherwise the differentiable MPM simulation runs with the given material
    models, which are used but not optimized here. It rolls out from
    ``dataset``'s initial state for ``velocity.num_frames * velocity.substeps``
    steps. At every frame boundary the particles are rendered for each view in
    ``velocity.views`` and compared with the RGB frame (``dataset``) and the
    mask (``dataset_mask``).

    A regularizer penalizes the mean absolute value of velocity columns 0 and
    2 (presumably the x and z components — TODO confirm axis convention).
    After ``velocity.num_epochs`` epochs the velocity is frozen and exported
    to ``init.pt``.

    Args:
        cfg: Full experiment config; reads ``velocity.*``, ``sim``,
            ``particle_data`` and ``sim_data_name``. ``particle_data.shape.name``
            is overwritten.
        gaussians: Gaussian kernels used both to build particles and to render.
        dataset: RGB video dataset; also owns the velocity optimizer/scheduler.
        dataset_mask: Mask video dataset aligned with ``dataset``.
        elasticity: Elasticity model mapping F to stress.
        plasticity: Plasticity model projecting F.
        background: Background colour tensor; its device selects the Warp device.
        tune_root: Directory where ``init.pt`` is read from / written to.
    """
    init_ckpt_path = tune_root / 'init.pt'
    if init_ckpt_path.exists():
        print("\nLoading existing initial velocity from checkpoint.")
        init_x_and_v = torch.load(init_ckpt_path, map_location="cpu")
        dataset.set_init_x_and_v(init_x=init_x_and_v['init_x'], init_v=init_x_and_v['init_v'])
        return

    print("\nOptimizing initial velocity...")
    torch.cuda.empty_cache()
    wp_device = get_warp_device(background.device)

    # Simulation setup: one epoch simulates num_frames frames of `substeps` steps each.
    vcfg = cfg.velocity
    substeps = vcfg.substeps
    nsteps = vcfg.num_frames * substeps
    num_epochs = vcfg.num_epochs

    model = MPMModelBuilder().parse_cfg(cfg.sim).finalize(wp_device, requires_grad=True)
    sim = MPMCacheDiffSim(model, nsteps, None)

    cfg.particle_data.shape.name = f"{cfg.sim_data_name}/particles"
    init_data = MPMInitDataGaussian.get_from_gaussians(cfg.particle_data, gaussians)

    statics_initializer = MPMStaticsInitializerGaussian(model)
    statics_initializer.add_group(init_data)
    statics = statics_initializer.finalize()

    # The velocity optimizer/scheduler are owned by the dataset.
    dataset.set_init_x_and_v(init_x=init_data.pos)
    dataset.init_velocity_optimizer(RAdam, lr=vcfg.lr)
    dataset.init_velocity_scheduler(vcfg.scheduler, init_lr=vcfg.lr)

    pixel_loss = PIXEL_LOSSES.get(vcfg.get("pixel_loss", "l2"), l2_loss)
    views = sorted(vcfg.views)

    for epoch in trange(1, num_epochs + 1):
        x, init_v, C, F, _ = dataset.get_init_material_data()
        # The rollout must start from the optimizable initial velocity itself,
        # otherwise the rendering losses cannot backpropagate into it.
        v = init_v
        dataset.getVelocityOptimizer.zero_grad()

        # Tensor accumulators: `.item()` below stays valid even if no frame/view contributes.
        loss_rgb = torch.zeros((), device=x.device)
        loss_mask = torch.zeros((), device=x.device)

        for it in range(nsteps):
            stress = elasticity(F)
            x, v, C, F = sim(statics, it, x, v, C, F, stress)
            F = plasticity(F)

            # Losses are only evaluated at frame boundaries.
            if (it + 1) % substeps != 0:
                continue
            cur_frame = dataset.steps[(it + 1) // substeps]
            means3D = denormalize_points_helper_func(x, init_data.size, init_data.center)

            for view in views:
                camera = dataset.getCameras(view, cur_frame)
                render_results = diff_rasterization(means3D, F, gaussians, camera, background)
                gt_image = camera.original_image.to(x.device)
                gt_mask = dataset_mask.getCameras(view, cur_frame).original_image.to(x.device)

                loss_rgb = loss_rgb + pixel_loss(render_results["render"], gt_image)
                loss_mask = loss_mask + pixel_loss(render_results["mask"], gt_mask)

        loss_reg = vcfg.lambda_reg * (init_v[:, 0].abs().mean() + init_v[:, 2].abs().mean())
        total_loss = loss_rgb + loss_mask + loss_reg

        total_loss.backward()
        dataset.getVelocityOptimizer.step()
        dataset.getVelocityScheduler.step()

        tqdm.write(
            f"Epoch {epoch}/{num_epochs} | Loss RGB: {loss_rgb.item():.4e} | "
            f"Loss Mask: {loss_mask.item():.4e} | Loss Reg: {loss_reg.item():.4e}"
        )

    # Backprop through elasticity/plasticity may have filled `.grad` on their
    # parameters even though they are not optimized here. Clear those gradients
    # so they cannot leak into later stages, e.g. gradient-norm clipping.
    elasticity.zero_grad(set_to_none=True)
    plasticity.zero_grad(set_to_none=True)

    dataset.freeze_velocity()
    dataset.export_init_x_and_v(init_ckpt_path)
    print(f'\nInitial velocity optimized and saved to {init_ckpt_path}.')


def _feature_match_loss(extractor, matcher, gt_image, render):
    """Return the mean pixel distance between matched keypoints of ``gt_image`` and ``render``.

    Extraction and matching run under ``torch.no_grad()``, so the result
    carries no gradient. Returns ``None`` when no correspondence is found,
    because the mean over an empty set would be NaN.
    """
    with torch.no_grad():
        gt_feats = extractor.extract(gt_image)
        render_feats = extractor.extract(render)
        matches = rbd(matcher({'image0': gt_feats, 'image1': render_feats}))['matches']
        points0 = rbd(gt_feats)['keypoints'][matches[..., 0]]
        points1 = rbd(render_feats)['keypoints'][matches[..., 1]]
    if points0.shape[0] == 0:
        return None
    return torch.norm(points0 - points1, dim=1).mean()


def _finite_or(value, fallback):
    """Return ``value`` if all of its elements are finite (float or tensor), else ``fallback``."""
    return value if bool(torch.isfinite(torch.as_tensor(value)).all()) else fallback


def finetune_constitutive(
    cfg: DictConfig,
    gaussians: GaussianModel,
    dataset: VideoDataset,
    dataset_mask: VideoDataset,
    dataset_depth: Optional[VideoDataset],
    elasticity: InvariantFullMetaElasticity,
    plasticity: InvariantFullMetaPlasticity,
    extractor,
    matcher,
    background: torch.Tensor,
    tune_root: Path,
):
    """Fine-tune the elasticity and plasticity models through LoRA adapters.

    Each epoch rolls out the differentiable MPM simulation from ``dataset``'s
    initial state. At every frame boundary the particles are rendered for each
    view in ``constitution.views``. The weighted loss combines:

    * an RGB pixel loss,
    * an optional depth loss, where each depth map is normalized on its own
      by ``process_depth``,
    * a keypoint feature-matching distance,
    * an elastic consistency term (``solve_e_corotated``) on ``F[:4]`` / ``stress[:4]``,
    * physics priors (``volume_preserve``, ``stress_symmetry``, ``incompressible_plasticity``).

    NOTE(review): keypoints are extracted under ``torch.no_grad()``, so the
    matching term is logged but contributes no gradient.

    The rollout is split into ``constitution.accumulation_steps`` chunks. The
    simulation state is detached between chunks (truncated backpropagation),
    and gradients accumulate until a single optimizer step per epoch. LoRA
    weights are saved to ``tune_root / f"{epoch:04d}_lora.pt"`` every
    ``constitution.save_every`` epochs (default 10) and after the last epoch.

    Args:
        cfg: Full experiment config; reads ``constitution.*``, ``sim``,
            ``particle_data`` and ``sim_data_name``.
        gaussians: Gaussian kernels used to build particles and to render.
        dataset: RGB video dataset providing the initial material state.
        dataset_mask: Mask dataset. Unused here; kept for signature compatibility.
        dataset_depth: Optional depth dataset; when ``None`` no depth loss is used.
        elasticity: Elasticity model; receives LoRA layers and is fine-tuned.
        plasticity: Plasticity model; receives LoRA layers and is fine-tuned.
        extractor: Keypoint extractor exposing ``extract(image)``.
        matcher: Keypoint matcher called with ``{'image0', 'image1'}`` features.
        background: Background colour tensor; its device selects the Warp device.
        tune_root: Output directory for LoRA checkpoints. TensorBoard logs go to its parent.

    Raises:
        ValueError: if ``constitution.accumulation_steps`` is smaller than 1.
    """
    print("\nFine-tuning the constitutive models...")
    ccfg = cfg.constitution
    substeps = ccfg.substeps
    nsteps = ccfg.num_frames * substeps
    num_epochs = ccfg.num_epochs

    # Simulation setup.
    wp_device = get_warp_device(background.device)
    model = MPMModelBuilder().parse_cfg(cfg.sim).finalize(wp_device, requires_grad=True)
    sim = MPMCacheDiffSim(model, nsteps, None)

    cfg.particle_data.shape.name = f"{cfg.sim_data_name}/particles"
    init_data = MPMInitDataGaussian.get_from_gaussians(cfg.particle_data, gaussians)

    statics_initializer = MPMStaticsInitializerGaussian(model)
    statics_initializer.add_group(init_data)
    statics = statics_initializer.finalize()

    # LoRA adapters: after freezing, only their parameters remain trainable.
    elasticity.init_lora_layers(r=ccfg.lora.r, lora_alpha=ccfg.lora.alpha)
    plasticity.init_lora_layers(r=ccfg.lora.r, lora_alpha=ccfg.lora.alpha)
    elasticity.freeze_all_except_lora()
    plasticity.freeze_all_except_lora()

    e_params = [p for p in elasticity.parameters() if p.requires_grad]
    p_params = [p for p in plasticity.parameters() if p.requires_grad]
    e_opt = RAdam(e_params, lr=ccfg.elasticity_lr)
    p_opt = RAdam(p_params, lr=ccfg.plasticity_lr)
    e_sch = fetch_scheduler(ccfg.elasticity_scheduler).get_scheduler(e_opt, ccfg.elasticity_lr)
    p_sch = fetch_scheduler(ccfg.plasticity_scheduler).get_scheduler(p_opt, ccfg.plasticity_lr)

    pixel_loss = PIXEL_LOSSES.get(ccfg.get("pixel_loss", "l2"), l2_loss)

    accumulation_steps = int(ccfg.get("accumulation_steps", 1))
    if accumulation_steps < 1:
        raise ValueError(f"constitution.accumulation_steps must be >= 1, got {accumulation_steps}")
    # Having more chunks than steps would leave chunks empty, while every chunk
    # loss is still divided by accumulation_steps (silently shrinking gradients).
    accumulation_steps = min(accumulation_steps, max(nsteps, 1))
    chunk_size = nsteps // accumulation_steps

    save_every = int(ccfg.get("save_every", 10))
    views = sorted(ccfg.views)
    loss_names = ("rgb", "match", "depth", "e_consistency", "preset")

    writer = SummaryWriter(tune_root.parent)
    try:
        for epoch in trange(1, num_epochs + 1):
            e_opt.zero_grad()
            p_opt.zero_grad()

            x, v, C, F, _ = dataset.get_init_material_data()
            epoch_losses = dict.fromkeys(loss_names, 0.0)

            for acc_step in range(accumulation_steps):
                device = x.device
                loss_rgb = torch.tensor(0.0, device=device)
                loss_match = torch.tensor(0.0, device=device)
                loss_depth = torch.tensor(0.0, device=device)
                loss_e_consistency = torch.tensor(0.0, device=device)
                loss_preset = torch.tensor(0.0, device=device)

                start_it = acc_step * chunk_size
                # The last chunk absorbs the remainder of nsteps / accumulation_steps.
                end_it = start_it + chunk_size if acc_step < accumulation_steps - 1 else nsteps

                for it in range(start_it, end_it):
                    stress = elasticity(F)

                    # Physical consistency on a small subset; non-finite values
                    # (NaN or inf) are replaced so they cannot poison the gradients.
                    results = solve_e_corotated(F[:4], stress[:4])
                    loss_e_consistency = loss_e_consistency + _finite_or(results["loss"], 1e-8)

                    x, v, C, F = sim(statics, it, x, v, C, F, stress)
                    F = plasticity(F)

                    # Preset physics priors.
                    loss_preset = loss_preset + (
                        volume_preserve(F) + stress_symmetry(stress) + incompressible_plasticity(F)
                    )

                    # Image-space losses are only evaluated at frame boundaries.
                    if (it + 1) % substeps != 0:
                        continue
                    cur_frame = dataset.steps[(it + 1) // substeps]
                    means3D = denormalize_points_helper_func(x, init_data.size, init_data.center)

                    for view in views:
                        camera = dataset.getCameras(view, cur_frame)
                        render_results = diff_rasterization(means3D, F, gaussians, camera, background)
                        render = render_results["render"]
                        gt_image = camera.original_image.to(device).detach()

                        loss_rgb = loss_rgb + pixel_loss(render, gt_image)

                        if dataset_depth is not None:
                            gt_depth = dataset_depth.getCameras(view, cur_frame).original_image.to(device).detach()
                            loss_depth = loss_depth + pixel_loss(
                                process_depth(render_results["depth"]), process_depth(gt_depth)
                            )

                        match_dist = _feature_match_loss(extractor, matcher, gt_image, render)
                        if match_dist is not None:
                            loss_match = loss_match + match_dist

                total_chunk_loss = (
                    ccfg.lambda_rgb * loss_rgb
                    + ccfg.lambda_match * loss_match
                    + ccfg.lambda_depth * loss_depth
                    + ccfg.lambda_e_consistency * loss_e_consistency
                    + ccfg.lambda_preset * loss_preset
                )

                # Scale so the accumulated gradient is the mean over chunks.
                if total_chunk_loss.requires_grad:
                    (total_chunk_loss / accumulation_steps).backward()

                # Truncate the graph between chunks to bound memory.
                x, v, C, F = x.detach(), v.detach(), C.detach(), F.detach()

                chunk_values = (loss_rgb, loss_match, loss_depth, loss_e_consistency, loss_preset)
                for name, value in zip(loss_names, chunk_values):
                    epoch_losses[name] += value.item()

            # Clip only the trainable (LoRA) parameters. Frozen base weights may
            # still hold `.grad` from earlier stages, and clip_grad_norm_ would
            # otherwise include them in the total norm.
            clip_grad_norm_(e_params, max_norm=ccfg.elasticity_grad_max_norm)
            clip_grad_norm_(p_params, max_norm=ccfg.plasticity_grad_max_norm)
            e_opt.step()
            p_opt.step()
            e_sch.step()
            p_sch.step()

            for name, value in epoch_losses.items():
                writer.add_scalar(f'loss/{name}', value, epoch)

            # Also save after the final epoch, so the last weights are never lost.
            if (save_every > 0 and epoch % save_every == 0) or epoch == num_epochs:
                torch.save({
                    'elasticity': elasticity.lora_state_dict(),
                    'plasticity': plasticity.lora_state_dict(),
                }, tune_root / f'{epoch:04d}_lora.pt')
    finally:
        writer.close()

    print('\nFine-tuning finished.')

def main(cfg: DictConfig):
    """Run the full pipeline for one experiment config.

    Steps:
    1. Seed the RNGs and initialize Warp.
    2. Create the experiment directory, save the config and tee stdout to a log.
    3. Load the RGB, mask and (optionally) depth video datasets.
    4. Load the Gaussian kernels and the pretrained material models.
    5. Set up the keypoint matcher.
    6. Run initial-velocity optimization, then LoRA fine-tuning of the material models.
    """
    # Environment: echo the resolved config, then seed every RNG in use.
    print(OmegaConf.to_yaml(cfg, resolve=True))
    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    wp.init()
    device_name = f'cuda:{cfg.gpu}'
    torch_device = torch.device(device_name)

    # Paths and logging: everything printed from here on also goes to log.txt.
    exp_root = Path(cfg.root) / cfg.name
    mkdir(exp_root, resume=cfg.resume, overwrite=cfg.overwrite)
    OmegaConf.save(cfg, exp_root / 'config.yaml', resolve=True)
    sys.stdout = Logger(exp_root / 'log.txt')
    tune_root = exp_root / 'finetune'
    tune_root.mkdir(exist_ok=True)
    data_root = ASSETS_PATH / cfg.sim_data_name
    data_root.mkdir(exist_ok=True)

    # Datasets: all three share one config node; only the read-mode flags differ.
    video_cfg = cfg.video_data
    video_cfg.device = device_name

    def load_video(read_mask_only, read_depth):
        # Both flags are written on the shared node right before construction,
        # presumably because VideoDataset reads them at init time.
        video_cfg.data.read_mask_only = read_mask_only
        video_cfg.data.read_depth = read_depth
        return VideoDataset(video_cfg)

    dataset = load_video(read_mask_only=False, read_depth=False)
    dataset_mask = load_video(read_mask_only=True, read_depth=False)

    dataset_depth = None
    if video_cfg.data.get("has_depth", False):
        dataset_depth = load_video(read_mask_only=False, read_depth=True)
        print("Depth dataset loaded.")

    bg_level = 1 if video_cfg.data.get("white_background") else 0
    background = torch.tensor([bg_level] * 3, dtype=torch.float32, device=torch_device)

    # Models: Gaussian kernels plus pretrained material networks.
    gaussians = GaussianModel(cfg.gaussian.sh_degree)
    gaussians.load_ply(str(data_root / "kernels.ply"), requires_grad=False)

    elasticity = InvariantFullMetaElasticity(cfg.constitution.elasticity).to(torch_device).train()
    plasticity = InvariantFullMetaPlasticity(cfg.constitution.plasticity).to(torch_device).train()

    ckpt_path = cfg.pretrained_ckpt
    pretrained = torch.load(ckpt_path, map_location=torch_device)
    for key, network in (('elasticity', elasticity), ('plasticity', plasticity)):
        network.load_state_dict(pretrained[key])
    print(f'Loaded pretrained weights from {ckpt_path}')

    # Keypoint extractor/matcher used by the feature-matching loss.
    match_model = init_extractor('superpoint')
    extractor, matcher = (match_model[part].to(torch_device) for part in ('extractor', 'matcher'))

    # The two optimization stages, in order.
    optimize_init_velocity(
        cfg=cfg,
        gaussians=gaussians,
        dataset=dataset,
        dataset_mask=dataset_mask,
        elasticity=elasticity,
        plasticity=plasticity,
        background=background,
        tune_root=tune_root,
    )
    finetune_constitutive(
        cfg=cfg,
        gaussians=gaussians,
        dataset=dataset,
        dataset_mask=dataset_mask,
        dataset_depth=dataset_depth,
        elasticity=elasticity,
        plasticity=plasticity,
        extractor=extractor,
        matcher=matcher,
        background=background,
        tune_root=tune_root,
    )

if __name__ == "__main__":
    # CLI entry point: load the config file named by --config and run the pipeline.
    main(OmegaConf.load(parse_args().config))