import taichi as ti
import torch
from taichi.math import uvec3

# Presumably the per-block thread count for taichi kernel launches
# (NOTE(review): not referenced in this part of the file -- confirm usage).
taichi_block_size = 128

# Float precision shared by taichi-side and torch-side buffers; defined as a
# pair so the two matching dtypes stay side by side.
data_type, torch_type = ti.f32, torch.float32

# Presumably the maximum number of samples per ray (confirm with callers);
# it also sets the smallest step size below.
MAX_SAMPLES = 1024
# NOTE(review): presumably the near-plane distance; not used in this chunk.
NEAR_DISTANCE = 0.01
# sqrt(3), the length of the diagonal of a unit cube.
SQRT3 = 1.7320508075688772
# Derived constants are computed from the ones above so they cannot drift
# out of sync. calc_dt uses them as its lower clamp bound and as the
# numerator of its upper clamp bound.
SQRT3_MAX_SAMPLES = SQRT3 / MAX_SAMPLES
SQRT3_2 = SQRT3 * 2


@ti.func
def scalbn(x, exponent):
    """Return ``x * 2**exponent``.

    Args:
        x: scalar (or vector) to scale.
        exponent: integer power of two; may be negative.

    Returns:
        ``x`` multiplied by ``2**exponent``.
    """
    # A float base is required: Taichi's integer ``pow`` does not support a
    # negative exponent, so ``pow(2, -k)`` would not yield 2**-k.
    return x * ti.math.pow(2.0, exponent)


@ti.func
def calc_dt(t, exp_step_factor, grid_size, scale):
    """Return the step size to use at distance ``t`` (presumably along a ray).

    The step is ``t * exp_step_factor``, so it grows linearly with ``t``. It
    is clamped to ``[SQRT3_MAX_SAMPLES, SQRT3_2 * scale / grid_size]``.
    """
    upper = SQRT3_2 * scale / grid_size
    return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES, upper)


@ti.func
def frexp_bit(x):
    """Return an integer exponent of the f32 value ``x`` read from its bits.

    This is NOT C ``frexp``. For normal, non-zero ``x`` the result is
    ``ceil(log2(|x|))``:

    * it starts from the unbiased IEEE-754 exponent ``floor(log2(|x|))``;
    * it adds one whenever any mantissa bit is set, i.e. when ``|x|`` is not
      an exact power of two.

    Args:
        x: f32 scalar. Its bit pattern is reinterpreted as ``u32``, so it
            must be 32-bit.

    Returns:
        i32 exponent as described above, or 0 when ``x == 0``. Subnormal,
        infinite and NaN inputs are not special-cased.
    """
    exponent = 0
    if x != 0.0:
        bits = ti.bit_cast(x, ti.u32)
        # The biased exponent occupies bits 23..30; this mask also drops the
        # sign bit.
        exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127
        # A non-zero mantissa means |x| lies strictly between two powers of
        # two, so round the exponent up. (A mantissa re-based to exponent 0
        # is always in [1, 2), so no downward adjustment is ever needed.)
        if (bits & ti.u32(0x7fffff)) != 0:
            exponent += 1
    return exponent


@ti.func
def mip_from_pos(xyz, cascades):
    """Return the cascade (mip level) index for position ``xyz``.

    Let ``r`` be the largest absolute coordinate of ``xyz``. The level is
    ``frexp_bit(r) + 1`` clamped to ``[0, cascades - 1]``. For normal,
    non-zero ``r`` that is ``ceil(log2(r)) + 1``, which gives:

    * level 0 holds ``r <= 0.5``;
    * level ``k > 0`` holds ``2**(k - 2) < r <= 2**(k - 1)``;
    * anything beyond saturates at ``cascades - 1``.

    ``r == 0`` (the exact origin) is mapped to level 0, like every nearby
    point. ``frexp_bit(0)`` returns 0, which would otherwise put the origin
    in level 1.
    """
    radius = ti.f32(ti.abs(xyz).max())
    level = 0
    if radius != 0.0:
        level = frexp_bit(radius) + 1
    return ti.min(cascades - 1, ti.max(0, level))


@ti.func
def mip_from_dt(dt, grid_size, cascades):
    """Return the cascade index matching the step size ``dt``.

    ``dt`` is scaled by ``grid_size``, converted to an exponent with
    ``frexp_bit`` (``ceil(log2(.))`` for normal non-zero values, 0 for zero),
    and clamped to ``[0, cascades - 1]``.
    """
    scaled_step = ti.f32(dt * grid_size)
    level = frexp_bit(scaled_step)
    return ti.min(ti.max(level, 0), cascades - 1)


@ti.func
def __expand_bits(v):
    """Spread the bits of ``v`` so that bit ``i`` moves to bit ``3 * i``.

    Two zero bits are inserted after every input bit. Three such values,
    shifted by 0, 1 and 2, can then be OR-ed into one Morton code (see
    ``__morton3D``). The multiply/mask sequence is designed for inputs below
    1024 (10 bits), producing up to 30 output bits.

    Args:
        v: u32 value or u32 vector (``__morton3D`` passes a ``uvec3``); the
            operations are applied element-wise.

    Returns:
        The expanded value(s), same shape as ``v``.
    """
    # Each step duplicates the bits with a multiply (x * (1 + 2**k)) and keeps
    # only the copies that land in the right place with the mask.
    v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF)
    v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F)
    v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3)
    v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249)
    return v


@ti.func
def __morton3D(xyz):
    """Interleave the three components of ``xyz`` into one Morton code.

    After bit-spreading, bits of ``xyz[0]`` occupy positions 0, 3, 6, ...,
    ``xyz[1]`` positions 1, 4, 7, ..., and ``xyz[2]`` positions 2, 5, 8, ....
    """
    spread = __expand_bits(xyz)
    return (spread[2] << 2) | (spread[1] << 1) | spread[0]


@ti.func
def __morton3D_invert(x):
    """Gather every third bit of ``x`` into the low bits.

    Bits 0, 3, 6, ..., 30 of ``x`` are packed into bits 0..10 of the result;
    all other bits are discarded. This undoes ``__expand_bits`` for values it
    produced. Callers select one axis of a Morton code by shifting the code
    right by 0, 1 or 2 first.

    Args:
        x: u32 scalar.

    Returns:
        The compacted value as ``ti.int32``.
    """
    # Typed mask, consistent with every other constant here, so the AND is
    # unambiguously a u32 operation.
    x = x & ti.uint32(0x49249249)
    x = (x | (x >> 2)) & ti.uint32(0xc30c30c3)
    x = (x | (x >> 4)) & ti.uint32(0x0f00f00f)
    x = (x | (x >> 8)) & ti.uint32(0xff0000ff)
    x = (x | (x >> 16)) & ti.uint32(0x0000ffff)
    return ti.int32(x)


@ti.kernel
def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),
                           coords: ti.types.ndarray(ndim=2)):
    """Decode each Morton code ``indices[i]`` into ``coords[i, 0:3]``.

    Axis ``a`` is recovered by shifting the code right by ``a`` and
    compacting every third bit.
    """
    for i in indices:
        code = ti.uint32(indices[i])
        for axis in ti.static(range(3)):
            coords[i, axis] = __morton3D_invert(code >> axis)


def morton3D_invert(indices):
    """Decode a 1-D tensor of Morton codes into integer coordinates.

    Args:
        indices: 1-D tensor of Morton codes (made contiguous before use).

    Returns:
        ``int32`` tensor of shape ``(len(indices), 3)`` on the same device as
        ``indices``. ``ti.sync()`` is called before returning.
    """
    num_codes = indices.size(0)
    coords = torch.zeros((num_codes, 3),
                         dtype=torch.int32,
                         device=indices.device)
    morton3D_invert_kernel(indices.contiguous(), coords)
    ti.sync()
    return coords


@ti.kernel
def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),
                    indices: ti.types.ndarray(ndim=1)):
    """Encode the first three columns of each row of ``xyzs`` as a Morton code.

    The coordinates are converted to ``uvec3`` before interleaving, and the
    code is stored in ``indices`` as ``int32``.
    """
    for row in indices:
        cell = uvec3([xyzs[row, 0], xyzs[row, 1], xyzs[row, 2]])
        code = __morton3D(cell)
        indices[row] = ti.cast(code, ti.int32)


def morton3D(coords1):
    """Encode rows of integer coordinates as Morton codes.

    Args:
        coords1: 2-D tensor whose first three columns are read as x, y, z
            (made contiguous before use).

    Returns:
        1-D ``int32`` tensor of length ``coords1.size(0)`` on the same
        device. ``ti.sync()`` is called before returning.
    """
    num_rows = coords1.size(0)
    indices = torch.zeros(num_rows,
                          dtype=torch.int32,
                          device=coords1.device)
    morton3D_kernel(coords1.contiguous(), indices)
    ti.sync()
    return indices


@ti.kernel
def packbits(density_grid: ti.types.ndarray(ndim=1),
             density_threshold: float,
             density_bitfield: ti.types.ndarray(ndim=1)):
    """Threshold ``density_grid`` and pack the results eight per byte.

    Bit ``b`` (least-significant first) of ``density_bitfield[n]`` is set iff
    ``density_grid[8 * n + b] > density_threshold``. The loop runs over
    ``density_bitfield``, so ``density_grid`` must hold at least
    ``8 * len(density_bitfield)`` entries.
    """
    for byte_idx in density_bitfield:
        first = 8 * byte_idx
        packed = ti.uint8(0)
        for b in ti.static(range(8)):
            if density_grid[first + b] > density_threshold:
                packed |= ti.uint8(1) << b
        density_bitfield[byte_idx] = packed


@ti.kernel
def torch2ti(field: ti.template(), data: ti.types.ndarray()):
    """Copy ``data`` element-wise into ``field`` at the same multi-index.

    Iteration covers the full shape of ``data``.
    """
    for idx in ti.grouped(data):
        field[idx] = data[idx]


@ti.kernel
def ti2torch(field: ti.template(), data: ti.types.ndarray()):
    """Copy ``field`` element-wise into ``data`` at the same multi-index.

    Iteration covers the full shape of ``data``.
    """
    for idx in ti.grouped(data):
        data[idx] = field[idx]


@ti.kernel
def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
    """Copy ``field.grad`` element-wise into ``grad`` at the same multi-index.

    Iteration covers the full shape of ``grad``.
    """
    for idx in ti.grouped(grad):
        grad[idx] = field.grad[idx]


@ti.kernel
def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
    """Copy ``grad`` element-wise into ``field.grad`` at the same multi-index.

    Iteration covers the full shape of ``grad``.
    """
    for idx in ti.grouped(grad):
        field.grad[idx] = grad[idx]


@ti.kernel
def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):
    """Pack a flat ndarray of interleaved pairs into a 2-vector field.

    ``field[k]`` receives ``(data[2k], data[2k + 1])``. A trailing unpaired
    element (odd length) is ignored.
    """
    for k in range(data.shape[0] // 2):
        src = 2 * k
        field[k] = ti.Vector([data[src], data[src + 1]])


@ti.kernel
def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):
    """Unpack a 2-D vector field into a 2-D ndarray of interleaved pairs.

    ``data[row, 2p]`` and ``data[row, 2p + 1]`` receive components 0 and 1
    of ``field[row, p]``.
    """
    for row, pair in ti.ndrange(data.shape[0], data.shape[1] // 2):
        vec = field[row, pair]
        dst = 2 * pair
        data[row, dst] = vec[0]
        data[row, dst + 1] = vec[1]


@ti.kernel
def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
    """Flatten the gradient of a 1-D vector field into interleaved pairs.

    ``grad[2k]`` and ``grad[2k + 1]`` receive components 0 and 1 of
    ``field.grad[k]``.
    """
    for k in range(grad.shape[0] // 2):
        g = field.grad[k]
        dst = 2 * k
        grad[dst] = g[0]
        grad[dst + 1] = g[1]


@ti.kernel
def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
    """Write a 2-D ndarray of interleaved pairs into a vector field's gradient.

    Components 0 and 1 of ``field.grad[row, p]`` receive
    ``grad[row, 2p]`` and ``grad[row, 2p + 1]``.
    """
    for row, pair in ti.ndrange(grad.shape[0], grad.shape[1] // 2):
        src = 2 * pair
        field.grad[row, pair][0] = grad[row, src]
        field.grad[row, pair][1] = grad[row, src + 1]


def extract_model_state_dict(ckpt_path,
                             model_name='model',
                             prefixes_to_ignore=None):
    """Load a checkpoint and return the state dict of one sub-module.

    Keys of the form ``"<model_name>.<rest>"`` are kept and renamed to
    ``"<rest>"``; all other keys are dropped. Both plain state dicts and
    PyTorch-Lightning checkpoints (weights nested under ``"state_dict"``) are
    accepted.

    Args:
        ckpt_path: path passed to ``torch.load`` (loaded onto the CPU).
        model_name: name of the sub-module whose weights are extracted. An
            empty string keeps every key unchanged.
        prefixes_to_ignore: iterable of key prefixes, matched after
            ``"<model_name>."`` has been stripped, whose entries are skipped.
            ``None`` (the default) skips nothing.

    Returns:
        A new dict mapping the stripped keys to the loaded values.
    """
    ignored = tuple(prefixes_to_ignore) if prefixes_to_ignore is not None else ()
    # NOTE(security): torch.load unpickles the file -- only load trusted
    # checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in checkpoint:  # pytorch-lightning checkpoint
        checkpoint = checkpoint['state_dict']
    # Match the full "<name>." prefix so that e.g. "model_ema.*" is not
    # mistaken for a key of "model" (and turned into "ema.*").
    prefix = model_name + '.' if model_name else ''
    state_dict = {}
    for key, value in checkpoint.items():
        if not key.startswith(prefix):
            continue
        key = key[len(prefix):]
        # str.startswith(()) is False, so an empty ``ignored`` keeps all keys.
        if key.startswith(ignored):
            continue
        state_dict[key] = value
    return state_dict


def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=None):
    """Overlay weights from a checkpoint onto ``model`` in place.

    The weights are extracted with ``extract_model_state_dict`` and merged
    over ``model.state_dict()``. Entries missing from the checkpoint
    therefore keep their current values. The merged dict is then passed to
    ``model.load_state_dict`` with its default arguments.

    Args:
        model: object with ``state_dict()`` / ``load_state_dict()``
            (presumably a ``torch.nn.Module``).
        ckpt_path: checkpoint path; a falsy value makes this a no-op.
        model_name: sub-module name forwarded to
            ``extract_model_state_dict``.
        prefixes_to_ignore: key prefixes to skip; ``None`` skips nothing.

    Returns:
        None.
    """
    if not ckpt_path:
        return
    if prefixes_to_ignore is None:
        prefixes_to_ignore = []
    model_dict = model.state_dict()
    checkpoint_ = extract_model_state_dict(ckpt_path, model_name,
                                           prefixes_to_ignore)
    model_dict.update(checkpoint_)
    model.load_state_dict(model_dict)

def depth2img(depth):
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
                                  cv2.COLORMAP_TURBO)

    return depth_img