import argparse
import math
import pathlib
import time
import sys
import os
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from tqdm import trange
import torchvision
from lpips import LPIPS
from hypernet_core import NGPradianceField_tar
from run_nerf_helpers import render_image_with_ngp

from datasets.nerf_vec import SubjectLoader
from nerfacc import OccupancyGrid
# ContractionType is imported defensively; presumably this accommodates nerfacc
# releases that do not export it -- TODO confirm which versions are targeted.
try:
    from nerfacc import ContractionType
    GRID_CONTRACTION_TYPE = ContractionType.AABB
except ImportError:
    # NOTE(review): on this path GRID_CONTRACTION_TYPE is left undefined, yet
    # run() passes it to OccupancyGrid unconditionally, which would then raise
    # NameError instead of a clear import error.
    pass
import random
from torch.utils.tensorboard import SummaryWriter
from pdb import set_trace as bb
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def _psnr_from_mse(mse):
    """Convert a mean-squared-error tensor to PSNR in dB (signal peak of 1)."""
    return -10.0 * torch.log(mse) / np.log(10.0)


def _evaluate_test_views(
    radiance_field,
    occupancy_grid,
    test_dataset,
    aabb,
    render_step_size,
    grid_weights,
    lpips_fn,
    vis_path,
):
    """Render every test view, save a comparison sheet and average the metrics.

    Args:
        radiance_field: model to render with; switched to eval mode here.
        occupancy_grid: occupancy grid handed to the renderer; switched to
            eval mode here.
        test_dataset: indexable dataset whose items provide "rays", "pixels"
            and "color_bkgd".
        aabb: scene bounding box passed through to the renderer.
        render_step_size: ray-marching step size passed to the renderer.
        grid_weights: forwarded unchanged to ``render_image_with_ngp``.
        lpips_fn: callable ``(pred, gt) -> scalar tensor``.
        vis_path: destination of the saved prediction/error/ground-truth image.

    Returns:
        ``(psnr_avg, lpips_avg)`` as Python floats, or ``None`` when the test
        dataset is empty (nothing is saved in that case).
    """
    radiance_field.eval()
    occupancy_grid.eval()

    psnrs = []
    lpips_scores = []
    vis_rows = []
    with torch.no_grad():
        for idx in tqdm.tqdm(range(len(test_dataset))):
            data = test_dataset[idx]
            pixels = data["pixels"]

            rgb, _acc, _depth, _n_samples = render_image_with_ngp(
                radiance_field,
                occupancy_grid,
                data["rays"],
                # rendering options
                aabb,
                render_step_size=render_step_size,
                color_bkgds=data["color_bkgd"],
                grid_weights=grid_weights,
            )
            psnrs.append(_psnr_from_mse(F.mse_loss(rgb, pixels)).item())
            lpips_scores.append(lpips_fn(rgb, pixels).item())

            # Channels-first copies for torchvision; presumably [3, H, W].
            pred = rgb.cpu().permute(2, 0, 1)
            gt = pixels.cpu().permute(2, 0, 1)
            # Per-pixel error norm replicated over 3 channels so it tiles with
            # the colour images. It is not normalised and may exceed 1.
            err = (rgb - pixels).norm(dim=-1).cpu().unsqueeze(0).repeat(3, 1, 1)

            vis_rows.append(
                torchvision.utils.make_grid([pred, err, gt], nrow=3)
            )

        if vis_rows:
            # One row (pred | err | gt) per test view, stacked vertically.
            torchvision.utils.save_image(
                torchvision.utils.make_grid(vis_rows, nrow=1),
                vis_path,
            )

    if not psnrs:
        return None
    return sum(psnrs) / len(psnrs), sum(lpips_scores) / len(lpips_scores)


def run(args):
    """Train the NGP radiance field on one scene, then evaluate on the test split.

    Training runs for a fixed number of steps on ``cuda:0`` with a fixed seed.
    Training metrics are logged to TensorBoard every 10000 steps; after the
    final step every test view is rendered, a visualisation image is written
    to ``args.save_dir`` and the mean PSNR / LPIPS are logged.

    Args:
        args: namespace providing ``scene``, ``data_root``, ``train_split``,
            ``save_dir``, ``model_path`` (optional checkpoint, may be ``None``)
            and ``grid_weights`` (forwarded to the renderer, may be ``None``).
    """
    device = "cuda:0"
    set_random_seed(42)

    # training parameters
    max_steps = 20000
    init_batch_size = 1024
    target_sample_batch_size = 1 << 18
    weight_decay = 0.0
    # scene parameters
    aabb = torch.tensor([-0.7, -0.7, -0.7, 0.7, 0.7, 0.7], device=device)
    # model parameters
    grid_resolution = 96
    # render parameters
    render_step_size = (
        (aabb[3:] - aabb[:3]).max()
        * math.sqrt(3) / 1024
    ).item()

    train_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=args.data_root,
        split=args.train_split,
        num_rays=init_batch_size,
        device=device,
    )

    test_dataset = SubjectLoader(
        subject_id=args.scene,
        root_fp=args.data_root,
        split="test",
        num_rays=None,
        device=device,
    )

    occupancy_grid = OccupancyGrid(
        roi_aabb=aabb,
        resolution=grid_resolution,
        contraction_type=GRID_CONTRACTION_TYPE,
    ).to(device)

    grad_scaler = torch.cuda.amp.GradScaler(2**10)
    radiance_field = NGPradianceField_tar(
        aabb=aabb,
        encoding_size=24,
        n_hidden_layers=3,
        n_neurons=64,
        unbounded=False,
    ).to(device)
    optimizer = torch.optim.Adam(
        radiance_field.parameters(),
        lr=1e-2,
        eps=1e-15,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ChainedScheduler(
        [
            torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, total_iters=100
            ),
            torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[
                    max_steps // 2,
                    max_steps * 3 // 4,
                    max_steps * 9 // 10,
                ],
                gamma=0.33,
            ),
        ]
    )
    lpips_net = LPIPS(net="vgg").to(device)

    def lpips_fn(x, y):
        # Add a batch axis, move channels first and rescale by *2-1
        # (assumes channels-last inputs in [0, 1] -- TODO confirm).
        def to_lpips_input(img):
            return img[None, ...].permute(0, 3, 1, 2) * 2 - 1

        return lpips_net(to_lpips_input(x), to_lpips_input(y)).mean()

    save_dir = pathlib.Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    if args.model_path is not None:
        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        checkpoint = torch.load(args.model_path, map_location=device)
        radiance_field.load_state_dict(checkpoint)

    # Always bound (possibly None) so training and evaluation render with the
    # same value; previously evaluation hit a NameError when it was None.
    grid_weights = args.grid_weights

    writer = SummaryWriter(log_dir=str(save_dir))
    try:
        # training
        tic = time.time()
        for step in trange(max_steps + 1):
            radiance_field.train()
            occupancy_grid.eval()

            i = torch.randint(0, len(train_dataset), (1,)).item()
            data = train_dataset[i]

            render_bkgd = data["color_bkgd"]
            rays = data["rays"]
            pixels = data["pixels"]

            # NOTE: the occupancy grid is not updated in this loop. Re-enabling
            # it would mean calling
            #   occupancy_grid.every_n_step(step=step, occ_eval_fn=..., occ_thre=1e-2)
            # with occ_eval_fn(x) = radiance_field.query_density(x) * render_step_size.

            # render
            rgb, acc, depth, n_rendering_samples = render_image_with_ngp(
                radiance_field,
                occupancy_grid,
                rays,
                # rendering options
                aabb,
                render_step_size=render_step_size,
                color_bkgds=render_bkgd,
                grid_weights=grid_weights,
            )
            if n_rendering_samples == 0:
                continue

            if target_sample_batch_size > 0:
                # dynamic batch size for rays to keep sample batch size constant.
                num_rays = len(pixels)
                num_rays = int(
                    num_rays
                    * (target_sample_batch_size / float(n_rendering_samples))
                )
                train_dataset.update_num_rays(num_rays)

            # compute loss
            loss = F.smooth_l1_loss(rgb, pixels)

            optimizer.zero_grad()
            # do not unscale it because we are using Adam.
            grad_scaler.scale(loss).backward()
            optimizer.step()
            scheduler.step()

            if step % 10000 == 0:
                elapsed_time = time.time() - tic
                # The logged "loss" is the MSE (used for PSNR), not the
                # smooth-L1 training loss; no graph is needed for it.
                with torch.no_grad():
                    mse = F.mse_loss(rgb, pixels)
                    psnr = _psnr_from_mse(mse)
                writer.add_scalar("loss-training", mse.item(), step)
                writer.add_scalar("psnr-training", psnr.item(), step)
                print(
                    f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                    f"loss={mse.item():.5f} | psnr={psnr.item():.2f} | "
                    f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
                    f"max_depth={depth.max().item():.3f} | "
                )

            # With range(max_steps + 1) this fires only at the final step.
            if step > 0 and step % max_steps == 0:
                metrics = _evaluate_test_views(
                    radiance_field,
                    occupancy_grid,
                    test_dataset,
                    aabb,
                    render_step_size,
                    grid_weights,
                    lpips_fn,
                    str(save_dir / f"vis_{step}.png"),
                )
                if metrics is None:
                    print("evaluation skipped: test dataset is empty")
                else:
                    psnr_avg, lpips_avg = metrics
                    print(f"evaluation: psnr_avg={psnr_avg}, lpips_avg={lpips_avg}")
                    writer.add_scalar("psnr-test", psnr_avg, step)
                    writer.add_scalar("lpips-test", lpips_avg, step)
    finally:
        # Flush pending events even if training is interrupted.
        writer.close()

