# import igl # work around some env/packaging problems by loading this first

# import sys, os, time, math
# os.environ['OptiX_INSTALL_DIR'] = '/home/ /Documents/NVIDIA-OptiX-SDK-8.0.0-linux64-x86_64'

import time
import gc
import argparse
import warnings

import torch
import os

# Imports from this project
import render, geometry, queries
from kd_tree import *
import implicit_mlp_utils
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm, LogNorm, SymLogNorm, Normalize
import matplotlib.colors as mcolors

import imageio
# import jax.numpy as jnp
import trimesh
from PIL import Image
from matplotlib.ticker import MaxNLocator
from scipy.ndimage import convolve

os.environ['OptiX_INSTALL_DIR'] = '/home/ /Documents/NVIDIA-OptiX-SDK-8.0.0-linux64-x86_64'
# os.environ['OptiX_INSTALL_DIR'] = '/media/  /b5df3483-c11a-42f1-b414-023f33bc5312/home/ /Documents/NVIDIA-OptiX-SDK-8.0.0-linux64-x86_64'

from triro.ray.ray_optix import RayMeshIntersector  # FIXME: Should be uncommented when rendering meshes


# Config

# Directory that contains this script, and the directory one level above it.
SRC_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.join(SRC_DIR, "..")
# Prefer CUDA when torch reports it as available, otherwise use the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Set torch's default tensor type to torch.cuda.FloatTensor.
# NOTE(review): this is unconditional, unlike `device` above, which falls back
# to the CPU -- confirm that a CUDA device is always expected for this script.
torch.set_default_tensor_type(torch.cuda.FloatTensor)


def get_count(args, implicit_func, params, load_from, opts, matcaps):
    """Render the mesh stored at `load_from` and return its per-pixel counts.

    The mesh is loaded with `trimesh.load`, wrapped in a `RayMeshIntersector`
    and handed to `render.render_image_mesh` together with a fixed camera.

    Args:
        args: Parsed command-line namespace; only `args.res` is read.
        implicit_func: Forwarded unchanged to `render.render_image_mesh`.
        params: Forwarded unchanged to `render.render_image_mesh`.
        load_from: Path of the mesh file passed to `trimesh.load`.
        opts: Options dict. `opts['res_scale']` divides the resolution; the
            whole dict is also forwarded to the renderer.
        matcaps: Forwarded to the renderer as its `matcaps` argument.

    Returns:
        A `(rendering_time, counts)` tuple, where `rendering_time` is the
        second value returned by the renderer and `counts` is the third one
        converted to a NumPy array of shape `(res, res)`.
    """
    # Camera placed at z = 3.5 and looking along -z.
    # An alternative preset that used to be assigned (and immediately
    # overwritten) here was:
    #   root=(0, -3.5, 0), left=(1, 0, 0), look=(0, 1, 0), up=(0, 0, 1)
    root = torch.tensor([0., 0., 3.5])
    up = torch.tensor([0., 1., 0.])
    look = torch.tensor([0., 0., -1.])
    left = torch.tensor([1., 0., 0.])

    fov_deg = 30
    res = args.res // opts['res_scale']

    mesh = trimesh.load(load_from)
    intersector = RayMeshIntersector(vertices=torch.tensor(mesh.vertices),
                                     faces=torch.tensor(mesh.faces))

    # The rendered image itself is not needed here, only the timing and counts.
    _, rendering_time, counts = render.render_image_mesh(
        implicit_func, params, intersector, root, look,
        up, left, res,
        fov_deg, opts,
        shading='matcap_color', matcaps=matcaps, approx=False,
        shading_color_tuple=torch.tensor(((0., 0.5, 0.),)))

    # Drop the mesh and intersector before the next render so their memory
    # can be reclaimed, then ask torch to release cached CUDA blocks.
    del intersector
    del mesh
    gc.collect()
    torch.cuda.empty_cache()
    return rendering_time, counts.detach().cpu().numpy().reshape(res, res)

def get_count_baseline(args, implicit_func, params, opts, matcaps):
    """Run `render.render_image_naive` with a fixed camera and return its counts.

    Args:
        args: Parsed command-line namespace; only `args.res` is read.
        implicit_func: Forwarded unchanged to the renderer.
        params: Forwarded unchanged to the renderer.
        opts: Options dict. `opts['res_scale']` divides the resolution; the
            whole dict is also forwarded to the renderer.
        matcaps: Forwarded to the renderer as its `matcaps` argument.

    Returns:
        A `(raycast_time, count)` tuple: the sixth value returned by the
        renderer, and the third one converted to a NumPy array of shape
        `(res, res)`.
    """
    side = args.res // opts['res_scale']
    field_of_view = 30

    # Camera placed at z = 3.5 and looking along -z.
    cam_root = torch.tensor([0., 0., 3.5])
    cam_up = torch.tensor([0., 1., 0.])
    cam_look = torch.tensor([0., 0., -1.])
    cam_left = torch.tensor([1., 0., 0.])

    # Only the count map (3rd output) and the timing (6th output) are used.
    _, _, count_map, _, _, elapsed = render.render_image_naive(
        implicit_func, params, cam_root, cam_look, cam_up,
        cam_left, side, field_of_view, False, opts,
        shading='matcap_color', matcaps=matcaps)

    count_array = count_map.detach().cpu().numpy()
    return elapsed, count_array.reshape(side, side)

def filter_large_entries(matrix, threshold):
    """Replace entries above `threshold` with the mean of their 8 neighbours.

    Neighbours that fall outside the array are ignored: the neighbour sum is
    zero-padded and divided by the number of neighbours that actually exist,
    so border and corner cells are averaged over 5 and 3 cells respectively.

    Args:
        matrix: 2-D NumPy array. It is not modified.
        threshold: Entries strictly greater than this value are replaced.

    Returns:
        A copy of `matrix` with the flagged entries replaced. The copy keeps
        the dtype of `matrix`, so averages are cast to that dtype on
        assignment. A flagged entry that has no neighbours at all (a 1x1
        input) is left unchanged instead of becoming NaN.
    """
    # Mask of where values exceed the threshold
    mask = matrix > threshold

    # 3x3 kernel selecting the 8-connected neighbours (centre excluded)
    kernel = np.array([[1, 1, 1],
                       [1, 0, 1],
                       [1, 1, 1]])

    # Sum of neighbours, with zeros assumed outside the array
    neighbor_sum = convolve(matrix, kernel, mode='constant', cval=0.0)

    # Number of real neighbours per cell (fewer along the border)
    neighbor_count = convolve(np.ones_like(matrix), kernel, mode='constant', cval=0.0)

    # Only divide where it is both needed and well defined; this avoids the
    # 0/0 that a neighbour-less cell would otherwise produce.
    replaceable = mask & (neighbor_count > 0)

    filtered = matrix.copy()
    filtered[replaceable] = neighbor_sum[replaceable] / neighbor_count[replaceable]

    return filtered

class ShiftedNorm(mcolors.Normalize):
    """Linear normalization whose output range is `[out_min, out_max]`.

    Values are first normalized by `matplotlib.colors.Normalize` and the
    result is then rescaled linearly so that 0 maps to `out_min` and 1 maps
    to `out_max`.

    Args:
        vmin: Forwarded to `Normalize`.
        vmax: Forwarded to `Normalize`.
        clip: Forwarded to `Normalize`.
        out_min: Output produced for a base-normalized value of 0.
        out_max: Output produced for a base-normalized value of 1.
    """

    def __init__(self, vmin=None, vmax=None, clip=False, out_min=0.2, out_max=1.0):
        super().__init__(vmin, vmax, clip)
        self.out_min = out_min
        self.out_max = out_max

    def __call__(self, value, clip=None):
        """Normalize `value` and rescale the result to `[out_min, out_max]`."""
        # Normalize to [0, 1] first
        normed = super().__call__(value, clip)
        # Then scale to [out_min, out_max]
        return self.out_min + normed * (self.out_max - self.out_min)

    def inverse(self, value):
        """Map a shifted, normalized `value` back to data space.

        This undoes the rescaling applied in `__call__` before delegating to
        `Normalize.inverse`, so that `inverse(norm(x))` returns `x` for
        unclipped data.

        Raises:
            ValueError: If `out_min == out_max` (the mapping is constant and
                cannot be inverted), or whatever `Normalize.inverse` raises.
        """
        import numpy as np  # local import: keeps this class self-contained

        span = self.out_max - self.out_min
        if span == 0:
            raise ValueError("Not invertible when out_min equals out_max")
        # Accept plain sequences the same way Normalize.inverse does.
        if np.iterable(value):
            value = np.ma.asarray(value)
        return super().inverse((value - self.out_min) / span)

def q_clip(img, q):
    """Replace values above the `q`-th percentile with the second mode.

    Args:
        img: NumPy array of values.
        q: Percentile (as accepted by `np.percentile`) used as the cut-off.

    Returns:
        A new array in which every entry strictly greater than the `q`-th
        percentile of `img` is replaced by the second most frequent value of
        `img`. When `img` holds a single distinct value there is no second
        mode; that value is used instead (and nothing exceeds the cut-off),
        so the result equals `img`.
    """
    cutoff = np.percentile(img, q)
    values, counts = np.unique(img, return_counts=True)
    # Distinct values ordered from most to least frequent.
    by_frequency = np.argsort(-counts)
    # Rank 1 is the second mode; fall back to rank 0 for a constant image
    # rather than indexing past the end.
    rank = 1 if values.size > 1 else 0
    replacement = values[by_frequency[rank]]
    return np.where(img > cutoff, replacement, img)

def _draw_count_map(counts, norm, cmap, figsize):
    """Create a figure showing `counts` as a heat map without ticks or frame.

    Returns the `(figure, axes, image)` triple so the caller can add a
    colorbar, adjust the layout and save the figure.
    """
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(counts, cmap=cmap, norm=norm, origin='lower')
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    return fig, ax, image


def main():
    """Render three sample-count maps and save them as heat-map figures.

    The implicit function in `input` is loaded in 'crown' mode and rendered
    against the two meshes `load_from_1` and `load_from_2` via `get_count`,
    then reloaded in 'affine_fixed' mode and rendered with
    `get_count_baseline`. Each count map is clipped with `q_clip`, shown as a
    histogram, and finally written to a PDF.

    Command-line arguments:
        input: Path given to `generate_implicit_from_file`.
        load_from_1, load_from_2: Mesh paths given to `get_count`.
        --rendering: No longer supported; passing it raises
            `NotImplementedError`.
        --res: Image resolution before `res_scale` is applied (default 1024).
        --output: Parsed but currently unused; figures are written to the
            fixed paths below.

    Raises:
        NotImplementedError: If `--rendering` is given.
    """
    parser = argparse.ArgumentParser()

    # Build arguments
    parser.add_argument("input", type=str)
    parser.add_argument("load_from_1", type=str)
    parser.add_argument("load_from_2", type=str)
    parser.add_argument("--rendering", type=str)
    parser.add_argument("--res", type=int, default=1024)
    parser.add_argument("--output", type=str, default=None)
    # Parse arguments
    args = parser.parse_args()

    # Fail fast: reject the unsupported option before spending time on the
    # three renders instead of after them.
    if args.rendering:
        raise NotImplementedError("Rendering image support has been removed as per updated requirements.")

    opts = queries.get_default_cast_opts()
    opts['data_bound'] = 1
    opts['res_scale'] = 1
    opts['tree_max_depth'] = 12
    opts['tree_split_aff'] = False
    opts['hit_eps'] = 1e-3

    implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode='crown')

    # load the matcaps
    matcaps = render.load_matcap(os.path.join(ROOT_DIR, "assets", "matcaps", "wax_{}.png"))
    matcaps = torch.stack(matcaps)

    # Only the count maps are used below; the timings are discarded.
    _, img1 = get_count(args, implicit_func, params, args.load_from_1, opts, matcaps)
    print("getting img1")
    _, img2 = get_count(args, implicit_func, params, args.load_from_2, opts, matcaps)
    print("getting img2")
    implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode='affine_fixed')
    _, img3 = get_count_baseline(args, implicit_func, params, opts, matcaps)
    print("getting img3")

    # The network is no longer needed; release it before plotting.
    del implicit_func, params
    gc.collect()
    torch.cuda.empty_cache()

    # Replace the top 10% of each map with its second most frequent value.
    img1 = q_clip(img1, 90)
    img2 = q_clip(img2, 90)
    img3 = q_clip(img3, 90)

    # One histogram window per map; each plt.show() call displays it.
    for counts in (img1, img2, img3):
        plt.hist(counts.flatten(), 50)
        plt.show()

    # A single norm is shared by all three figures so their colors are
    # comparable. (SymLogNorm, Normalize and ShiftedNorm were tried as
    # alternatives.)
    norm = PowerNorm(gamma=0.5, vmin=0., vmax=100.)

    print("max of img1: ", img1.max())
    print("max of img2: ", img2.max())
    print("max of img3: ", img3.max())

    cmap = 'hot_r'

    # Figure for adaptive shells
    fig1, _, _ = _draw_count_map(img1, norm, cmap, figsize=(8, 8))
    fig1.tight_layout()
    fig1.savefig('/home/ /3d_vnn_ref/de_sample_counts.pdf', bbox_inches='tight')

    # Figure for our shells; this is the only one that carries a colorbar,
    # hence the wider figure.
    fig2, ax2, im2 = _draw_count_map(img2, norm, cmap, figsize=(10, 8))
    cbar = fig2.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04, shrink=0.8)
    cbar.ax.xaxis.set_label_position('top')
    cbar.ax.xaxis.tick_top()
    # Label only the two ends of the colorbar.
    cbar.set_ticks([cbar.vmin, cbar.vmax])
    cbar.ax.set_yticklabels([f'{int(cbar.vmin)}', f'{int(cbar.vmax)}'], fontname='serif', fontsize=36)
    fig2.tight_layout()
    fig2.savefig('/home/ /3d_vnn_ref/gios_sample_counts.pdf', bbox_inches='tight')

    # Figure for the baseline renderer
    fig3, _, _ = _draw_count_map(img3, norm, cmap, figsize=(8, 8))
    fig3.tight_layout()
    fig3.savefig('/home/ /3d_vnn_ref/spelunking_sample_counts.pdf', bbox_inches='tight')
    plt.show()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
