# import igl # work around some env/packaging problems by loading this first

import sys, os, time, math
import argparse
import imageio
import numpy as np
import polyscope.imgui as psim

# Imports from this project
import render, geometry, queries
from kd_tree import *
import implicit_mlp_utils
# NOTE(review): the imports below stay after the wildcard import so that these
# names cannot be shadowed by anything kd_tree happens to export.
import time
import cv2
import torch
from PIL import Image
# Config

# Directory that contains this script (symlinks resolved by realpath).
SRC_DIR = os.path.dirname(os.path.realpath(__file__))
# Parent directory of SRC_DIR; main() looks up "assets/matcaps" relative to it.
ROOT_DIR = os.path.join(SRC_DIR, "..")


def gaussian_blur_preserve_alpha(img: np.ndarray, ksize=(15, 15), sigma=5) -> np.ndarray:
    """
    Apply Gaussian blur to an RGBA image, avoiding color bleed from transparent areas.

    The RGB channels are premultiplied by alpha before blurring and divided by the
    blurred alpha afterwards, so the colors of (nearly) transparent pixels contribute
    almost nothing to their neighbours. The alpha channel itself is blurred as well.

    Args:
        img (np.ndarray): Input image with shape (H, W, 4) in uint8 format (RGBA).
        ksize (tuple): Kernel size passed to ``cv2.GaussianBlur``.
        sigma (float): Standard deviation passed to ``cv2.GaussianBlur``.

    Returns:
        np.ndarray: Blurred RGBA image of dtype uint8 with the same shape as the input.

    Raises:
        ValueError: If ``img`` is not a 3-dimensional array with 4 channels.
    """
    # Check the rank first so a 2-D (grayscale) array produces the documented
    # ValueError instead of an IndexError from img.shape[2].
    if img.ndim != 3 or img.shape[2] != 4:
        raise ValueError("Input image must have 4 channels (RGBA).")

    # Normalize to [0, 1] and split color from alpha
    rgb = img[..., :3].astype(np.float32) / 255.0
    alpha = img[..., 3].astype(np.float32) / 255.0
    # Keep alpha strictly positive so the un-premultiply division is always defined
    alpha = np.clip(alpha, 1e-5, 1.0)

    # Premultiply RGB by alpha
    rgb_premul = rgb * alpha[..., None]

    # Blur the premultiplied color and the alpha with the same kernel
    blurred_rgb = cv2.GaussianBlur(rgb_premul, ksize, sigma)
    blurred_alpha = cv2.GaussianBlur(alpha, ksize, sigma)

    # Un-premultiply
    final_rgb = np.clip(blurred_rgb / blurred_alpha[..., None], 0, 1)
    final_alpha = np.clip(blurred_alpha, 0, 1)

    # Combine and convert back to uint8. Round to nearest instead of truncating:
    # float32 round-off can leave a fully saturated value at 254.9999, which a plain
    # astype(np.uint8) would turn into 254.
    final_img = np.dstack((final_rgb, final_alpha[..., None])) * 255
    return np.rint(final_img).astype(np.uint8)


def save_render_current_view(args, implicit_func, params, cast_frustum, cast_tree_based, opts, matcaps, surf_color):
    """
    Render the implicit surface from a fixed, hard-coded camera and save the image.

    The image returned by ``render.render_image_naive`` (matcap shading) is flipped
    vertically, extended with an alpha mask, lightly blurred with
    ``gaussian_blur_preserve_alpha``, resized to 1024x1024 and written as RGB to
    ``args.image_write_path``.

    Args:
        args: Parsed command-line arguments; ``args.res`` and
            ``args.image_write_path`` are read.
        implicit_func: Implicit function, forwarded to the renderer.
        params: Parameters of ``implicit_func``, forwarded to the renderer.
        cast_frustum: Forwarded to the renderer.
        cast_tree_based: Not used by this function; kept so callers need not change.
        opts: Cast options; ``opts['res_scale']`` divides the render resolution and
            the whole dict is forwarded to the renderer.
        matcaps: Matcap textures, forwarded to the renderer.
        surf_color: Not used by this function (shading is 'matcap_color'); kept so
            callers need not change.
    """
    # Fixed camera handed to the renderer.
    # NOTE(review): root is presumably the camera position and left/look/up its
    # orientation frame — confirm against render.render_image_naive.
    root = torch.tensor([0.24, -0.8, 0.50])
    left = torch.tensor([1., 0., 0.])
    look = torch.tensor([0., 1., 0.])
    up = torch.tensor([0., 0., 1.])
    fov_deg = 45

    res = int(args.res // opts['res_scale'])

    # Only the color image is used; the other five returned values are discarded.
    img, _, _, _, _, _ = render.render_image_naive(
        implicit_func, params, root, look, up, left, res, fov_deg, cast_frustum, opts,
        shading='matcap_color', matcaps=matcaps,
    )

    # Convert to NumPy *before* flipping: NumPy supports negative-step slicing,
    # whereas torch tensors reject it.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img)[::-1, :, :]  # flip Y

    # A pixel is opaque when at least one channel is below 1, i.e. pure-white
    # pixels become fully transparent (presumably the renderer's background).
    alpha_channel = (np.min(img, axis=-1) < 1.).astype(np.float32)
    img_alpha = np.concatenate((img, alpha_channel[:, :, None]), axis=-1)
    img_alpha = np.clip(img_alpha, a_min=0., a_max=1.)
    img_alpha = (img_alpha * 255.).astype(np.uint8)

    print(f"Saving image to {args.image_write_path}")
    # Soften edges without letting the transparent background bleed into the surface
    img_alpha = gaussian_blur_preserve_alpha(img_alpha, ksize=(3, 3), sigma=0)
    img_alpha = cv2.resize(img_alpha, (1024, 1024), interpolation=cv2.INTER_LINEAR)
    # The alpha channel is dropped on save; the file is written as plain RGB
    Image.fromarray(img_alpha, 'RGBA').convert('RGB').save(args.image_write_path)

def main():
    """
    Command-line entry point: load an implicit function from file and render it.

    Parses the arguments, builds the cast options, loads the implicit function named
    by the positional ``input`` argument with the requested ``--mode``, loads the
    matcap textures from ``ROOT_DIR/assets/matcaps`` and calls
    ``save_render_current_view``, which writes the image to ``--image_write_path``.
    """
    parser = argparse.ArgumentParser()

    # Build arguments
    # NOTE: --output, --log_output, --grid_cam, --log-compiles, --disable-jit,
    # --debug-nans and --enable-double-precision are accepted but not read in this
    # function; they are kept so existing command lines keep working.
    parser.add_argument("input", type=str)
    parser.add_argument("--output", type=str)
    parser.add_argument("--log_output", type=str)
    # Mode names listed by this script: sdf, interval, affine_fixed, affine_truncate,
    # affine_append, affine_all, slope_interval, crown, alpha_crown, forward+backward,
    # forward, forward-optimized, dynamic_forward, dynamic_forward+backward,
    # affine+backward. The value is forwarded unchecked.
    parser.add_argument("--mode", type=str, default='affine_fixed')
    parser.add_argument("--cast_frustum", action='store_true')
    parser.add_argument("--cast_tree_based", action='store_true')
    parser.add_argument("--grid_cam", action='store_true')
    parser.add_argument("--res", type=int, default=1024)

    parser.add_argument("--image_write_path", type=str, default="render_out.png")

    parser.add_argument("--log-compiles", action='store_true')
    parser.add_argument("--disable-jit", action='store_true')
    parser.add_argument("--debug-nans", action='store_true')
    parser.add_argument("--enable-double-precision", action='store_true')

    # Parse arguments
    args = parser.parse_args()

    # Start from the project's default cast options and override a few of them
    opts = queries.get_default_cast_opts()
    opts['data_bound'] = 1
    opts['res_scale'] = 1
    opts['tree_max_depth'] = 12
    opts['tree_split_aff'] = False
    opts['hit_eps'] = 1e-4

    cast_frustum = args.cast_frustum
    cast_tree_based = args.cast_tree_based
    mode = args.mode

    # Keyword options forwarded to generate_implicit_from_file for every mode
    affine_opts = {
        'affine_n_truncate': 8,
        'affine_n_append': 4,
        'sdf_lipschitz': 1.,
        'crown': 1.,
        'alpha_crown': 1.,
        'forward+backward': 1.,
        'forward': 1.,
        'forward-optimized': 1.,
        'dynamic_forward': 1.,
        'dynamic_forward+backward': 1.,
        'affine+backward': 1.,
        # Policies named by this script: 'absolute', 'relative'
        'affine_truncate_policy': 'absolute',
    }
    surf_color = torch.tensor((0.157, 0.613, 1.000))

    # Load the implicit function once. The former per-mode branches repeated this
    # exact call with identical arguments, so they only reloaded the same file.
    implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts)

    # load the matcaps
    matcaps = render.load_matcap(os.path.join(ROOT_DIR, "assets", "matcaps", "wax_{}.png"))
    matcaps = torch.stack(matcaps)
    print("matcap shape", matcaps.shape)

    # NOTE(review): the view is rendered and saved twice with identical arguments;
    # presumably the first call is a warm-up run — confirm, otherwise drop one call.
    save_render_current_view(args, implicit_func, params, cast_frustum, cast_tree_based, opts, matcaps, surf_color)
    save_render_current_view(args, implicit_func, params, cast_frustum, cast_tree_based, opts, matcaps, surf_color)


# Run the renderer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
