import functorch
import torch
import torch.nn.functional as F
from functorch import vmap
from crown import CrownImplicitFunction
import imageio
import geometry
import queries
from utils import *
import matplotlib.pyplot as plt
import os, time

# Hard-coded, machine-specific location of the OptiX SDK, exported through the
# process environment. Presumably consumed by the (currently commented-out)
# `triro` OptiX ray/mesh intersector import below -- TODO confirm.
os.environ['OptiX_INSTALL_DIR'] = '/home/ /Documents/NVIDIA-OptiX-SDK-8.0.0-linux64-x86_64'
# os.environ['OptiX_INSTALL_DIR'] = '/media/  /b5df3483-c11a-42f1-b414-023f33bc5312/home/ /Documents/NVIDIA-OptiX-SDK-8.0.0-linux64-x86_64'

# from triro.ray.ray_optix import RayMeshIntersector  # FIXME: Should be uncommented when rendering meshes

# `device` falls back to the CPU when CUDA is missing, but the next line
# unconditionally makes CUDA float tensors the default for every tensor created
# without an explicit device, so this module effectively requires a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(torch.cuda.FloatTensor)


# Base step used to build the finite-difference sample offsets below.
HIT_EPS = 1e-3
# Four sample offsets with alternating signs (the 'tetrahedron' stencil used by
# outward_normal's finite-difference branch). Every component is +/-HIT_EPS
# scaled by 1e1, i.e. +/-1e-2.
FD_OFFSET = torch.tensor((
                (+HIT_EPS, -HIT_EPS, -HIT_EPS),
                (-HIT_EPS, -HIT_EPS, +HIT_EPS),
                (-HIT_EPS, +HIT_EPS, -HIT_EPS),
                (+HIT_EPS, +HIT_EPS, +HIT_EPS),
            )) * 1e1

# theta_x/y should be normalized image-plane coordinates in [-1, 1] (see generate_camera_rays)
def camera_ray(look_dir, up_dir, left_dir, fov_deg_x, fov_deg_y, theta_x, theta_y):
    """Return the normalized direction of a single pinhole-camera ray.

    The ray passes through the point ``look_dir + a * left_dir + b * up_dir``
    where ``a = theta_x * tan(fov_x / 2)`` and ``b = theta_y * tan(fov_y / 2)``.

    Args:
        look_dir, up_dir, left_dir: camera frame vectors.
        fov_deg_x, fov_deg_y: horizontal / vertical field of view in degrees.
        theta_x, theta_y: image-plane coordinates of the pixel
            (generate_camera_rays passes values in [-1, 1]).

    Returns:
        The result of ``geometry.normalize`` applied to the image-plane point.
    """
    dev = look_dir.device
    half_fov_x = torch.deg2rad(torch.tensor(fov_deg_x, device=dev)) / 2
    half_fov_y = torch.deg2rad(torch.tensor(fov_deg_y, device=dev)) / 2

    horizontal = left_dir * (theta_x * torch.tan(half_fov_x))
    vertical = up_dir * (theta_y * torch.tan(half_fov_y))

    return geometry.normalize(look_dir + horizontal + vertical)

def generate_camera_rays(eye_pos, look_dir, up_dir, res=1024, fov_deg=30.):
    """Generate one pinhole-camera ray per pixel of a ``res`` x ``res`` image.

    Args:
        eye_pos: camera position; tiled so that every ray shares this origin.
        look_dir: viewing direction (1-D, as required by ``torch.dot`` below).
        up_dir: approximate up direction; it is re-orthogonalized against
            ``look_dir`` and normalized here. The caller's tensor is not modified.
        res: image resolution along each axis.
        fov_deg: field of view in degrees, used for both axes.

    Returns:
        ``(ray_roots, ray_dirs)``, each with ``res * res`` rows.
    """
    # Image coords on [-1, 1] for each output pixel (same samples on both axes).
    cam_ax = torch.linspace(-1., 1., res)
    cam_y, cam_x = torch.meshgrid(cam_ax, cam_ax, indexing='ij')
    cam_x = cam_x.flatten()
    cam_y = cam_y.flatten()

    # Orthonormal camera frame.
    up_dir = up_dir - torch.dot(look_dir, up_dir) * look_dir
    up_dir = up_dir / torch.norm(up_dir)
    # Explicit dim: relying on torch.cross's implicit "first axis of size 3"
    # default is deprecated.
    left_dir = torch.cross(look_dir, up_dir, dim=-1)

    ray_dirs = vmap(partial(camera_ray, look_dir, up_dir, left_dir, fov_deg, fov_deg))(cam_x, cam_y)
    ray_roots = torch.tile(eye_pos, (ray_dirs.shape[0], 1))
    return ray_roots, ray_dirs


def outward_normal(funcs_tuple, params_tuple, hit_pos, hit_id, eps, method='finite_differences'):
    """Compute the normalized gradient at one hit point for the surface that was hit.

    Every function in ``funcs_tuple`` is evaluated; the gradient of function
    number ``i`` (1-based) is kept where ``hit_id == i``. If no id matches, the
    zero vector is returned.

    Args:
        funcs_tuple: implicit functions; ``CrownImplicitFunction`` instances are
            evaluated through ``torch_forward``, anything else is called as
            ``func(params, x)``.
        params_tuple: parameters paired with ``funcs_tuple``.
        hit_pos: the query position.
        hit_id: 1-based index of the function that was hit.
        eps: unused; the finite-difference step comes from ``FD_OFFSET``.
        method: ``'finite_differences'`` or ``'autodiff'``.

    Raises:
        ValueError: for any other ``method`` (only raised once a function is
            actually processed).
    """
    normal = torch.zeros(3)
    for surface_id, (func, params) in enumerate(zip(funcs_tuple, params_tuple), start=1):
        f = func.torch_forward if isinstance(func, CrownImplicitFunction) else partial(func, params)

        if method == 'finite_differences':
            # 'tetrahedron' central differences approximation
            # see e.g. https://www.iquilezles.org/www/articles/normalsSDF/normalsSDF.htm
            probes = hit_pos[None, :] + FD_OFFSET
            values = vmap(f)(probes).detach()
            if values.dim() > 1:
                values = values.squeeze(1)
            grad = (FD_OFFSET * values[:, None]).sum(dim=0)
        elif method == 'autodiff':
            grad = functorch.jacfwd(f)(hit_pos)
        else:
            raise ValueError("unrecognized method")

        normal = torch.where(hit_id == surface_id, geometry.normalize(grad), normal)

    return normal

def outward_normals(funcs_tuple, params_tuple, hit_pos, hit_ids, eps, method='finite_differences'):
    """Batched version of ``outward_normal`` over the rows of ``hit_pos``.

    For the two known methods the points are processed in fixed-size chunks
    (256 rows for ``'autodiff'``, 2**17 rows for ``'finite_differences'``) and
    written into a preallocated output shaped like ``hit_pos``. Any other
    ``method`` is forwarded to a single ``vmap`` call over all rows.
    """
    def normal_at(point, surface_id):
        return outward_normal(funcs_tuple, params_tuple, point, surface_id, eps, method=method)

    if method == 'autodiff':
        chunk = 256
    elif method == 'finite_differences':
        chunk = 2**17
    else:
        return vmap(normal_at)(hit_pos, hit_ids)

    n_total = hit_pos.shape[0]
    normals = torch.empty_like(hit_pos)
    for lo in range(0, n_total, chunk):
        hi = min(lo + chunk, n_total)
        normals[lo:hi] = vmap(normal_at)(hit_pos[lo:hi], hit_ids[lo:hi])

    return normals

def render_image(funcs_tuple, params_tuple, eye_pos, look_dir, up_dir, left_dir, res, fov_deg, frustum, branching_method, opts,
                 shading="normal", shading_color_tuple=((0.157, 0.613, 1.000)), matcaps=None, tonemap=False,
                 shading_color_func=None, tree_based=False, load_from=None, save_to=None):
    """Ray-cast one or more implicit functions and shade the result.

    Args:
        funcs_tuple, params_tuple: a single function / parameter set, or
            equally long lists/tuples of them.
        eye_pos, look_dir, up_dir, left_dir: camera position and frame.
        res: image resolution along each axis.
        fov_deg: field of view in degrees (used for both axes).
        frustum: if truthy use ``queries.cast_rays_frustum``, otherwise
            ``queries.cast_rays``.
        opts: options forwarded to the ray caster; ``opts['hit_eps']`` is also
            passed to ``outward_normals``.
        shading: forwarded to ``shade_image``.
        shading_color_tuple: one RGB tuple, a list/tuple of RGB tuples (one per
            function), or a tensor with one RGB row per function.
        matcaps: forwarded to ``shade_image``, which indexes it as
            ``matcaps[c][x, y]`` and is annotated to take a tensor.
        tonemap: apply ``tonemap_image`` to the shaded image.
        shading_color_func: forwarded to ``shade_image``.
        branching_method, tree_based, load_from, save_to: accepted but not used
            in this function.

    Returns:
        ``(img, depth, counts, hit_ids, n_eval, -1)`` where ``img`` is
        ``(res, res, 3)`` and ``depth``/``counts``/``hit_ids`` are ``(res, res)``.

    Raises:
        ValueError: if functions, parameters and colors differ in count.
    """
    # make sure inputs are tuples not lists
    if isinstance(funcs_tuple, list): funcs_tuple = tuple(funcs_tuple)
    if isinstance(params_tuple, list): params_tuple = tuple(params_tuple)

    # wrap in tuples if single was passed
    if not isinstance(funcs_tuple, tuple):
        funcs_tuple = (funcs_tuple,)
    if not isinstance(params_tuple, tuple):
        params_tuple = (params_tuple,)

    # Colors may arrive as a tensor (like the sibling render functions take) or
    # as plain Python tuples/lists.
    if torch.is_tensor(shading_color_tuple):
        if shading_color_tuple.dim() == 1:
            shading_color_tuple = shading_color_tuple[None, :]
        n_colors = shading_color_tuple.shape[0]
    else:
        if isinstance(shading_color_tuple, list): shading_color_tuple = tuple(shading_color_tuple)
        if not isinstance(shading_color_tuple[0], tuple):
            shading_color_tuple = (shading_color_tuple,)
        n_colors = len(shading_color_tuple)

    L = len(funcs_tuple)
    if (len(params_tuple) != L) or (n_colors != L):
        raise ValueError("render_image tuple arguments should all be same length")

    ray_roots, ray_dirs = generate_camera_rays(eye_pos, look_dir, up_dir, res=res, fov_deg=fov_deg)

    # shade_image is TorchScript-compiled with a Tensor-typed color argument, so
    # a Python tuple has to be converted before the call.
    if torch.is_tensor(shading_color_tuple):
        shading_colors = shading_color_tuple
    else:
        shading_colors = torch.tensor(shading_color_tuple, dtype=ray_dirs.dtype, device=ray_dirs.device)

    if frustum:
        # == Frustum raycasting

        cam_params = eye_pos, look_dir, up_dir, left_dir, fov_deg, fov_deg, res, res

        with Timer("frustum raycast"):
            t_raycast, hit_ids, counts, n_eval = queries.cast_rays_frustum(funcs_tuple, params_tuple, cam_params, opts)
            torch.cuda.synchronize()

        # TODO transposes here due to image layout conventions. can we get rid of them?
        # torch's Tensor.transpose needs both dims spelled out (unlike NumPy/JAX).
        # Assumes the frustum caster returns 2-D per-pixel arrays -- TODO confirm.
        t_raycast = t_raycast.transpose(0, 1).flatten()
        hit_ids = hit_ids.transpose(0, 1).flatten()
        counts = counts.transpose(0, 1).flatten()

    else:
        # == Standard raycasting
        with Timer("raycast"):
            t_raycast, hit_ids, counts, n_eval = queries.cast_rays(funcs_tuple, params_tuple, ray_roots, ray_dirs, opts)
            torch.cuda.synchronize()

    hit_pos = ray_roots + t_raycast[:, None] * ray_dirs

    torch.cuda.empty_cache()

    hit_normals = outward_normals(funcs_tuple, params_tuple, hit_pos, hit_ids, opts['hit_eps'])
    hit_color = shade_image(shading, ray_dirs, hit_pos, hit_normals, hit_ids, up_dir, matcaps, shading_colors,
                            shading_color_func=shading_color_func)
    # Pixels whose hit id is 0 get a white background.
    img = torch.where(hit_ids[:, None].bool(), hit_color, torch.ones((res * res, 3)))

    if tonemap:
        # We intentionally tonemap before compositing in the shadow. Otherwise the white level clips the shadow and gives it a hard edge.
        img = tonemap_image(img)

    img = img.reshape(res, res, 3)
    depth = t_raycast.reshape(res, res)
    counts = counts.reshape(res, res)
    hit_ids = hit_ids.reshape(res, res)

    return img, depth, counts, hit_ids, n_eval, -1


def render_image_naive(funcs_tuple, params_tuple, eye_pos, look_dir, up_dir, left_dir, res, fov_deg, frustum, opts,
                       shading="normal", shading_color_tuple=torch.tensor(((0.157, 0.613, 1.000),)), matcaps=None, tonemap=False,
                       shading_color_func=None, tree_based=False, shell_based=False, batch_size=None, enable_clipping=False, load_from=None, save_to=None):
    """Ray-cast and shade without timing or color-count validation.

    The ray caster is chosen in this order: ``frustum`` ->
    ``queries.cast_rays_frustum``; ``shell_based`` ->
    ``queries.cast_rays_shell_based``; otherwise ``queries.cast_rays``.

    Args:
        funcs_tuple, params_tuple: a single function / parameter set, or
            lists/tuples of them.
        eye_pos, look_dir, up_dir, left_dir: camera position and frame.
        res: image resolution along each axis.
        fov_deg: field of view in degrees (used for both axes).
        opts: options forwarded to the ray caster; ``opts['hit_eps']`` is also
            passed to ``outward_normals``.
        shading, shading_color_tuple, matcaps, shading_color_func: forwarded to
            ``shade_image``; the colors are passed through unchanged and are
            not checked against the number of functions.
        tonemap: apply ``tonemap_image`` to the shaded image.
        batch_size, load_from: forwarded to the shell-based caster only.
        tree_based, enable_clipping, save_to: accepted but not used here.

    Returns:
        ``(img, depth, counts, hit_ids, n_eval, -1)`` where ``img`` is
        ``(res, res, 3)`` and ``depth``/``counts``/``hit_ids`` are ``(res, res)``.
    """
    # make sure inputs are tuples not lists
    if isinstance(funcs_tuple, list): funcs_tuple = tuple(funcs_tuple)
    if isinstance(params_tuple, list): params_tuple = tuple(params_tuple)

    # wrap in tuples if single was passed
    if not isinstance(funcs_tuple, tuple):
        funcs_tuple = (funcs_tuple,)
    if not isinstance(params_tuple, tuple):
        params_tuple = (params_tuple,)

    ray_roots, ray_dirs = generate_camera_rays(eye_pos, look_dir, up_dir, res=res, fov_deg=fov_deg)
    if frustum:
        # == Frustum raycasting

        cam_params = eye_pos, look_dir, up_dir, left_dir, fov_deg, fov_deg, res, res

        t_raycast, hit_ids, counts, n_eval = queries.cast_rays_frustum(funcs_tuple, params_tuple, cam_params, opts)
        torch.cuda.synchronize()

        # TODO transposes here due to image layout conventions. can we get rid of them?
        # torch's Tensor.transpose needs both dims spelled out (unlike NumPy/JAX).
        # Assumes the frustum caster returns 2-D per-pixel arrays -- TODO confirm.
        t_raycast = t_raycast.transpose(0, 1).flatten()
        hit_ids = hit_ids.transpose(0, 1).flatten()
        counts = counts.transpose(0, 1).flatten()

    elif shell_based:
        # NOTE(review): render_image_mesh calls cast_rays_shell_based with a
        # different argument list and unpacks a different result tuple -- verify
        # which call matches the current `queries` implementation.
        t_raycast, hit_ids, counts, n_eval = queries.cast_rays_shell_based(funcs_tuple, params_tuple, ray_roots,
                                                                          ray_dirs, batch_size=batch_size,
                                                                          load_from=load_from)
        torch.cuda.synchronize()
    else:
        # == Standard raycasting
        t_raycast, hit_ids, counts, n_eval = queries.cast_rays(funcs_tuple, params_tuple, ray_roots, ray_dirs, opts)
        torch.cuda.synchronize()

    hit_pos = ray_roots + t_raycast[:, None] * ray_dirs

    torch.cuda.empty_cache()

    hit_normals = outward_normals(funcs_tuple, params_tuple, hit_pos, hit_ids, opts['hit_eps'])
    hit_color = shade_image(shading, ray_dirs, hit_pos, hit_normals, hit_ids, up_dir, matcaps, shading_color_tuple,
                            shading_color_func=shading_color_func)

    # Pixels whose hit id is 0 get a white background.
    img = torch.where(hit_ids[:, None].bool(), hit_color, torch.ones((res * res, 3)))

    if tonemap:
        # We intentionally tonemap before compositing in the shadow. Otherwise the white level clips the shadow and gives it a hard edge.
        img = tonemap_image(img)

    img = img.reshape(res, res, 3)
    depth = t_raycast.reshape(res, res)
    counts = counts.reshape(res, res)
    hit_ids = hit_ids.reshape(res, res)

    return img, depth, counts, hit_ids, n_eval, -1


def render_image_mesh(funcs_tuple, params_tuple, intersector, eye_pos, look_dir, up_dir, left_dir, res, fov_deg, opts,
                      shading="normal", shading_color_tuple=torch.tensor(((0.157, 0.613, 1.000),)), approx=False, matcaps=None, tonemap=False,
                      shading_color_func=None):
    """Render via ``queries.cast_rays_shell_based`` and report wall-clock timings.

    The ray query is issued twice: the first result is discarded (presumably a
    warm-up so one-time setup cost is excluded from the timing -- TODO confirm),
    the second one is timed and used.

    Args:
        funcs_tuple, params_tuple: a single function / parameter set, or
            lists/tuples of them.
        intersector, approx: forwarded to ``queries.cast_rays_shell_based``.
        eye_pos, look_dir, up_dir: camera position and frame.
        left_dir: accepted but not used here.
        res: image resolution along each axis.
        fov_deg: field of view in degrees (used for both axes).
        opts: ``opts['hit_eps']`` is forwarded to the ray query and to
            ``outward_normals``.
        shading, shading_color_tuple, matcaps, shading_color_func: forwarded to
            ``shade_image``.
        tonemap: apply ``tonemap_image`` to the shaded image.

    Returns:
        ``(img, elapsed_seconds, count)`` where ``img`` is ``(res, res, 3)``,
        ``elapsed_seconds`` covers the timed query, normals, shading and
        optional tonemapping, and ``count`` is returned by the ray query.
    """
    def _sync():
        # CUDA work is launched asynchronously; wait for it so that the
        # perf_counter stamps measure finished work rather than launch time.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    if isinstance(funcs_tuple, list): funcs_tuple = tuple(funcs_tuple)
    if isinstance(params_tuple, list): params_tuple = tuple(params_tuple)

    # wrap in tuples if single was passed
    if not isinstance(funcs_tuple, tuple):
        funcs_tuple = (funcs_tuple,)
    if not isinstance(params_tuple, tuple):
        params_tuple = (params_tuple,)

    ray_roots, ray_dirs = generate_camera_rays(eye_pos, look_dir, up_dir, res=res, fov_deg=fov_deg)

    # Untimed first query; its outputs are overwritten by the timed query below.
    hit_pos, hit_ids, _hit, count = queries.cast_rays_shell_based(funcs_tuple, params_tuple, ray_roots, ray_dirs, intersector, approx, opts['hit_eps'])

    _sync()
    time_render_start = time.perf_counter()
    hit_pos, hit_ids, _hit, count = queries.cast_rays_shell_based(funcs_tuple, params_tuple, ray_roots, ray_dirs, intersector, approx, opts['hit_eps'])

    _sync()
    time_normals_start = time.perf_counter()
    print("time queries: ", time_normals_start - time_render_start)
    hit_normals = outward_normals(funcs_tuple, params_tuple, hit_pos, hit_ids, opts['hit_eps'], method='finite_differences')
    _sync()
    time_normals_end = time.perf_counter()
    print("time normals: ", time_normals_end - time_normals_start)
    hit_color = shade_image(shading, ray_dirs, hit_pos, hit_normals, hit_ids, up_dir, matcaps, shading_color_tuple,
                            shading_color_func=shading_color_func)
    # Pixels whose hit id is 0 get a white background.
    img = torch.where(hit_ids[:, None].bool(), hit_color, torch.ones((res * res, 3)))

    if tonemap:
        # We intentionally tonemap before compositing in the shadow. Otherwise the white level clips the shadow and gives it a hard edge.
        img = tonemap_image(img)
    _sync()
    time_render_end = time.perf_counter()

    img = img.reshape(res, res, 3)
    print("time after normals: ", time_render_end - time_normals_end)
    print("Time rendering:", time_render_end - time_render_start)

    return img, time_render_end - time_render_start, count

def tonemap_image(img, gamma=2.2, white_level=.75, exposure=1.):
    """Apply an extended-Reinhard style tone curve followed by gamma encoding.

    With ``x = img * exposure`` the result is
    ``(x * (1 + x / white_level**2) / (1 + x)) ** (1 / gamma)``.

    Args:
        img: tensor of color values.
        gamma: gamma used for the final power encoding.
        white_level: input value (after exposure) that the curve maps to 1.
        exposure: multiplier applied to ``img`` before the curve.
    """
    exposed = img * exposure
    white_sq = white_level * white_level
    compressed = exposed * (1.0 + (exposed / white_sq)) / (1.0 + exposed)
    return compressed ** (1.0 / gamma)

@torch.jit.script
def shade_image(shading: str, ray_dirs: torch.Tensor, hit_pos: torch.Tensor, hit_normals: torch.Tensor,
                hit_ids: torch.Tensor, up_dir: torch.Tensor, matcaps: torch.Tensor,
                shading_color_tuple: torch.Tensor, shading_color_func=None) -> torch.Tensor:
    """Shade every ray by blending four matcap lookups with a per-surface color.

    The normal of each hit is projected onto a per-ray (left, up) frame to get
    matcap coordinates; the r/g/b/k matcaps (``matcaps[0..3]``) sampled there
    are mixed with weights ``(r, g, b, 1 - (r + g + b))`` of the surface color.
    Row ``k`` of ``shading_color_tuple`` is used where ``hit_ids == k + 1``;
    all other rays use the color (1, 1, 1).

    ``shading`` and ``shading_color_func`` are accepted but not used.
    """
    # Per-ray screen frame: `up_dir` made perpendicular to the ray, plus its cross product.
    up_along_ray = (up_dir * ray_dirs).sum(dim=-1, keepdim=True)
    screen_up = up_dir - up_along_ray * ray_dirs
    screen_up = screen_up / screen_up.norm(p=2, dim=-1, keepdim=True)
    screen_left = torch.cross(ray_dirs, screen_up, dim=-1)

    # Matcap coordinates in [-1, 1], shrunk slightly towards the centre.
    u = torch.einsum('ij,ij->i', -screen_left, hit_normals) * 0.98
    v = torch.einsum('ij,ij->i', screen_up, hit_normals) * 0.98

    # Convert to integer texel indices, clamped to the matcap extent.
    extent_x = matcaps.shape[1]
    extent_y = matcaps.shape[2]
    x = ((u + 1.) / 2. * extent_x).long().clamp(0, extent_x - 1)
    y = ((-v + 1.) / 2. * extent_y).long().clamp(0, extent_y - 1)

    mat_r = matcaps[0][x, y]
    mat_g = matcaps[1][x, y]
    mat_b = matcaps[2][x, y]
    mat_k = matcaps[3][x, y]

    # Pick the color of the surface each ray hit (ids are 1-based).
    tint = torch.ones_like(hit_pos)
    for k in range(shading_color_tuple.shape[0]):
        hit_this = (hit_ids == k + 1).unsqueeze(-1)
        tint = torch.where(hit_this, shading_color_tuple[k], tint)

    c_r = tint[:, 0:1]
    c_g = tint[:, 1:2]
    c_b = tint[:, 2:3]
    c_k = 1. - (c_r + c_b + c_g)

    return c_r * mat_r + c_b * mat_b + c_g * mat_g + c_k * mat_k

# @torch.jit.script
def phong_shading(
    hit_points,
    normals,
    view_dirs,
    light_pos,
    light_color,
    albedo,
    shininess=64,
    specular_strength=0.5,
    ambient_strength=0.1,
    ambient_color=torch.tensor([1.0, 1.0, 1.0])
):
    """Evaluate the Phong reflection model (ambient + diffuse + specular).

    Args:
        hit_points: surface positions being shaded.
        normals: surface normals (normalized here).
        view_dirs: view vectors (normalized here) compared against the
            reflected light direction.
        light_pos: position of the point light.
        light_color: color multiplied into the diffuse and specular terms.
        albedo: surface color multiplied into the ambient and diffuse terms.
        shininess: specular exponent.
        specular_strength: scale of the specular term.
        ambient_strength: scale of the ambient term.
        ambient_color: color of the ambient term.

    Returns:
        The sum of the three terms, broadcast over the inputs.
    """
    to_light = F.normalize(light_pos - hit_points, dim=-1)
    to_view = F.normalize(view_dirs, dim=-1)
    unit_n = F.normalize(normals, dim=-1)

    # Cosine between normal and light direction, reused for reflection and diffuse.
    n_dot_l = (unit_n * to_light).sum(-1, keepdim=True)
    reflected = F.normalize(2 * (unit_n * n_dot_l) - to_light, dim=-1)

    ambient = ambient_strength * albedo * ambient_color

    lambert = torch.clamp(n_dot_l, min=0.0)
    diffuse = albedo * lambert * light_color

    highlight = torch.clamp((reflected * to_view).sum(-1, keepdim=True), min=0.0)
    specular = specular_strength * (highlight ** shininess) * light_color

    return ambient + diffuse + specular


def look_at(eye_pos, target=None, up_dir='y'):
    """Build a camera frame looking from ``eye_pos`` towards ``target``.

    Args:
        eye_pos: camera position.
        target: point to look at; defaults to the origin.
        up_dir: ``'y'`` or ``'z'`` to select a world axis, or an explicit
            up vector.

    Returns:
        ``(look_dir, up_dir, left_dir)`` where ``up_dir`` has been passed
        through ``geometry.orthogonal_dir`` against ``look_dir`` and
        ``left_dir = look_dir x up_dir``.

    Raises:
        ValueError: if ``up_dir`` is a string other than ``'y'`` or ``'z'``.
    """
    # Identity check: `== None` on a tensor goes through tensor comparison.
    if target is None:
        target = torch.tensor((0., 0., 0.,))

    if isinstance(up_dir, str):
        if up_dir == 'y':
            up_dir = torch.tensor((0., 1., 0.,))
        elif up_dir == 'z':
            up_dir = torch.tensor((0., 0., 1.,))
        else:
            # Previously an unknown axis name was passed on as a string and
            # failed later inside the geometry helpers.
            raise ValueError("up_dir must be 'y', 'z' or a direction vector")

    look_dir = geometry.normalize(target - eye_pos)
    up_dir = geometry.orthogonal_dir(up_dir, look_dir)
    # Explicit dim: relying on torch.cross's implicit "first axis of size 3"
    # default is deprecated.
    left_dir = torch.cross(look_dir, up_dir, dim=-1)

    return look_dir, up_dir, left_dir


def load_matcap(fname_pattern):
    """Load the four matcap images (r, g, b, k) named by ``fname_pattern``.

    ``fname_pattern`` is formatted with each of ``'r'``, ``'g'``, ``'b'`` and
    ``'k'`` in turn; every image is read with ``imageio.imread``, converted to
    a tensor and divided by 256.

    Returns:
        A 4-tuple of tensors in r, g, b, k order.
    """
    return tuple(
        torch.tensor(imageio.imread(fname_pattern.format(channel))) / 256.
        for channel in ('r', 'g', 'b', 'k')
    )