import argparse
import math
from pathlib import Path
import json

from PIL import Image
import cv2
import torch
from torchvision import transforms as torch_trans
import numpy as np

from deptheval.utils.proj import depth_to_xyz
from deptheval.utils.common import assign_item_to_dict, pathlib_file
from deptheval.utils.image import resize
from deptheval.visualization.cfg import get_crop_region
from deptheval.utils.image import imread_rgb
from deptheval.utils.depth import load_data
from deptheval.utils.np_and_th import get_shifted_data


@torch.no_grad()
def compute_grid_lb_ub(data, i, j):
    '''
    .           .           .

          -------------
          |(0,0)|(0,1)|
    .     ------.------     .
          |(1,0)|(1,1)|
          -------------

    .           .           .
    '''
    if i == 0 and j == 0:
        x00 = .25 * (data[:-1, :-1] + data[:-1, 1:] + data[1:, :-1] + data[1:, 1:])
        x01 = .5 * (data[:-1, 1:] + data[1:, 1:])
        x10 = .5 * (data[1:, :-1] + data[1:, 1:])
        x11 = 1. * data[1:, 1:]
    elif i == 0 and j == 1:
        x00 = .5 * (data[:-1, :-1] + data[1:, :-1])
        x01 = .25 * (data[:-1, :-1] + data[:-1, 1:] + data[1:, :-1] + data[1:, 1:])
        x10 = 1. * data[1:, :-1]
        x11 = .5 * (data[1:, :-1] + data[1:, 1:])
    elif i == 1 and j == 0:
        x00 = .5 * (data[:-1, :-1] + data[:-1, 1:])
        x01 = 1. * data[:-1, 1:]
        x10 = .25 * (data[:-1, :-1] + data[:-1, 1:] + data[1:, :-1] + data[1:, 1:])
        x11 = .5 * (data[:-1, 1:] + data[1:, 1:])
    else:
        x00 = 1. * data[:-1, :-1]
        x01 = .5 * (data[:-1, :-1] + data[:-1, 1:])
        x10 = .5 * (data[:-1, :-1] + data[1:, :-1])
        x11 = .25 * (data[:-1, :-1] + data[:-1, 1:] + data[1:, :-1] + data[1:, 1:])
    x = torch.stack([x00, x01, x10, x11], dim=-1)
    lb, ub = x.min(dim=-1).values, x.max(dim=-1).values  # (H - 1, W - 1), (H - 1, W - 1)
    del x
    return lb, ub


@torch.no_grad()
def compute_high_res_idx(high_res_shape, data_low_res, valid, valid_high_res, gap, val_lb):
    '''
    :param high_res_shape: (Hu, Wu)
    :param data_low_res: shape (Hl, Wl)
    :param valid_high_res: shape (Hl, Wl)
    :param gap:
    :param val_lb:
    :return: res_high_res
        res_high_res: shape (Hu, Wu)
    '''
    Hu, Wu = high_res_shape

    # fill invalid pixels with neighbor means
    data_low_res = data_low_res.clone()
    nb_data_sum = torch.zeros_like(data_low_res)
    nb_data_cnt = torch.zeros_like(data_low_res)
    for di in [-1, 0, 1]:
        for dj in [-1, 0, 1]:
            nb_valid = get_shifted_data(valid, di, dj)
            nb_data = get_shifted_data(data_low_res, di, dj)
            nb_data_sum[nb_valid] += nb_data[nb_valid]
            nb_data_cnt[nb_valid] += 1
    nb_data_sum[nb_data_cnt < .5] = 0
    data_low_res[~valid] = (nb_data_sum / nb_data_cnt.clamp(min=1))[~valid]

    data_high_res = torch_trans.functional.resize(data_low_res[None], (Hu, Wu), torch_trans.InterpolationMode.BILINEAR)[0]
    res_high_res = -torch.ones((Hu, Wu), dtype=torch.int32, device=data_high_res.device)

    for i in range(2):
        for j in range(2):
            lb, ub = compute_grid_lb_ub(data_high_res, i, j)
            lb_i = torch.clip(torch.ceil((lb - val_lb) / gap), min=0).to(res_high_res.dtype)
            ub_i = torch.clip(torch.floor((ub - val_lb) / gap), max=2e9).to(res_high_res.dtype)

            multi_line_mask = (lb_i < ub_i) | ((lb_i == ub_i) & (res_high_res[1 - i: Hu - i, 1 - j: Wu - j] != -1))
            single_line_mask = (lb_i == ub_i) & (res_high_res[1 - i: Hu - i, 1 - j: Wu - j] == -1)

            res_high_res[1 - i: Hu - i, 1 - j: Wu - j][single_line_mask] = lb_i[single_line_mask]

            multi_line_upd_idx = torch.clip(torch.round((data_high_res[1 - i: Hu - i, 1 - j: Wu - j] - val_lb) / gap), min=0, max=2e9).to(res_high_res.dtype)
            multi_line_upd_idx = torch.where(multi_line_upd_idx < lb_i, lb_i, multi_line_upd_idx)
            multi_line_upd_idx = torch.where(multi_line_upd_idx > ub_i, ub_i, multi_line_upd_idx)
            multi_line_upd_mask = ((res_high_res[1 - i: Hu - i, 1 - j: Wu - j] == -1) | (
                torch.abs(data_high_res[1 - i: Hu - i, 1 - j: Wu - j] - (res_high_res[1 - i: Hu - i, 1 - j: Wu - j] * gap + val_lb)) >
                torch.abs(data_high_res[1 - i: Hu - i, 1 - j: Wu - j] - (multi_line_upd_idx * gap + val_lb))
            )) & multi_line_mask
            res_high_res[1 - i: Hu - i, 1 - j: Wu - j][multi_line_upd_mask] = multi_line_upd_idx[multi_line_upd_mask]
    res_high_res[~valid_high_res] = -1
    return res_high_res


def get_contour_line_gap(data: torch.Tensor, valid: torch.Tensor, num_gap, qt):
    if not valid.any():
        return 1
    qt_lb = data[valid].quantile(qt).item()
    qt_ub = data[valid].quantile(1 - qt).item()
    gap = (qt_ub - qt_lb) / (num_gap * (1 - qt * 2))
    return gap


@torch.no_grad()
def gen_contour_line(rgb_high_res, data, valid, valid_high_res, is_z, num_gap, shift, thickness=0, qt=0.05, colormap=cv2.COLORMAP_JET):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data = torch.from_numpy(data).to(device)
    valid = torch.from_numpy(valid).to(device)
    valid_high_res = torch.from_numpy(valid_high_res).to(device)
    if is_z:
        data = 1 / data

    gap = get_contour_line_gap(data, valid, num_gap, qt)
    data_lb = data[valid].min().item() if valid.any() else 0
    val_lb = data_lb + gap * shift

    high_res_shape = rgb_high_res.shape[:2]
    res_high_res = compute_high_res_idx(high_res_shape, data, valid, valid_high_res, gap, val_lb)

    res = res_high_res.clone()
    dlt_rng = int(math.floor(thickness))
    for di in range(-dlt_rng, dlt_rng + 1):
        for dj in range(-dlt_rng, dlt_rng + 1):
            if di * di + dj * dj > thickness * thickness:
                continue
            nb_res = get_shifted_data(res_high_res, di, dj)
            upd_mask = get_shifted_data(valid_high_res, di, dj) & (res == -1) & (nb_res != -1) & valid_high_res
            res[upd_mask] = nb_res[upd_mask]

    if (res != -1).any():
        res[res != -1] -= res[res != -1].min()
    res = res.cpu().numpy()
    num_val = max(2, res.max().item() + 1)
    valid_high_res = valid_high_res.cpu().numpy()

    base_col = cv2.applyColorMap(np.arange(256, dtype=np.uint8)[None], colormap)[0].astype(np.float32)  # (256, 3)
    idx = np.arange(num_val, dtype=np.float32) / (num_val - 1) * 255  # (itr,)
    idx_lb = np.floor(idx).astype(np.int32)  # (itr,)
    coef_lb = (idx_lb.astype(np.float32) + 1 - idx)[:, None]  # (itr, 1)
    col = base_col[idx_lb] * coef_lb + base_col[np.clip(idx_lb + 1, a_min=None, a_max=255)] * (1 - coef_lb)  # (itr, 3)
    col = np.round(col).astype(np.uint8)

    img = np.zeros_like(rgb_high_res)
    non_colored_mask = valid_high_res & (res == -1)
    img[non_colored_mask] = rgb_high_res[non_colored_mask]
    colored_mask = valid_high_res & (res != -1)
    img[colored_mask] = col[res[colored_mask]]
    return img, colored_mask, col


def caption_img_xy(img_xy, col_maps):
    description_H, description_W_txt, description_W_col = 60, 180, 1080
    imgs = []
    for name, col in [
        ('yz-plane:', col_maps['x']),
        ('xz-plane:', col_maps['y']),
    ]:
        img_txt = 255 * np.ones((description_H, description_W_txt, 3), dtype=np.uint8)
        cv2.putText(img_txt, name, (description_H // 2, int(.7 * description_H)),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, cv2.LINE_AA)
        img_col = 255 * np.ones((description_H, description_W_col, 3), dtype=np.uint8)
        ratio_h, ratio_w = .1, .02
        i_lb, i_ub, j_lb, j_ub = int(ratio_h * description_H), int((1 - ratio_h) * description_H), int(ratio_w * description_W_col), int((1 - ratio_w) * description_W_col)
        img_col[i_lb:i_ub, j_lb:j_ub] = cv2.resize(col[None].copy(), (j_ub - j_lb, i_ub - i_lb))
        imgs.append(np.concatenate([img_txt, img_col], axis=1))
    imgs = np.concatenate(imgs, axis=0)
    # from deptheval.utils.image import display_img
    # display_img(imgs)
    imgs = resize(imgs, W=img_xy.shape[1])
    img_xy = np.concatenate([img_xy, imgs], axis=0)
    return img_xy


def pil_ds(img: np.ndarray, H, W):
    pil_img = Image.fromarray(img, mode='RGB')
    pil_img = pil_img.resize((W, H), Image.Resampling.LANCZOS)
    return np.array(pil_img)


def render_contour_line_imgs(xyz: np.ndarray, valid: np.ndarray, rgb_low_res: np.ndarray, save_shape, out_root):
    '''
    :param xyz:
    :param valid:
    :param rgb_low_res:
    :param save_shape: (H, W)
    :param out_root:
    :return:
    '''
    # hyperparams
    texture_strength = 0.8
    draw_dim_lb = 4 * np.linalg.norm([1920, 1080])

    out_root = pathlib_file(out_root)

    dim = np.linalg.norm(rgb_low_res.shape[:2])
    us_sc = int(math.ceil(draw_dim_lb / dim))
    us_shape = (us_sc * rgb_low_res.shape[0], us_sc * rgb_low_res.shape[1])
    rgb_high_res = np.round(texture_strength * cv2.resize(rgb_low_res, (us_shape[1], us_shape[0]))).astype(np.uint8)
    valid_high_res = torch_trans.functional.resize(torch.from_numpy(valid)[None], rgb_high_res.shape[:2], torch_trans.InterpolationMode.NEAREST_EXACT)[0].numpy()

    summary = {}
    for thickness in [5 * np.linalg.norm(rgb_high_res.shape[:2]) / (4 * np.linalg.norm([1920, 1080]))]:
        for rel_num_gap in [0.015, 0.03, 0.06, 0.09, 0.12, 0.24, 0.42, 0.6]:
            num_gap = int(dim * rel_num_gap)
            for shift in [0.5]:
                imgs, colored_masks, col_maps = {}, {}, {}
                for i, name in enumerate(['x', 'y', 'z']):
                    imgs[name], colored_masks[name], col_maps[name] = \
                        gen_contour_line(rgb_high_res, xyz[..., i], valid, valid_high_res, name == 'z',
                                         num_gap, shift, thickness)

                    out_f = out_root / name / f'thickness__{thickness:.1f}___num_gap__{num_gap}___shift__{shift:.2f}.png'
                    out_f.parent.mkdir(parents=True, exist_ok=True)
                    cv2.imwrite(out_f.as_posix(), pil_ds(imgs[name][us_sc:-us_sc, us_sc:-us_sc, ::-1].copy(), save_shape[0], save_shape[1]))
                    print(f'Saved to {out_f.resolve()}')
                    assign_item_to_dict(summary, [name, thickness, num_gap, shift], str(out_f.resolve().relative_to(out_root.resolve())))

                img_xy = rgb_high_res.copy()
                img_xy[np.logical_and(colored_masks['x'], colored_masks['y'])] = np.round(.5 * (imgs['x'].astype(np.float32) + imgs['y'].astype(np.float32))).astype(np.uint8)[np.logical_and(colored_masks['x'], colored_masks['y'])]
                img_xy[np.logical_and(colored_masks['x'], np.logical_not(colored_masks['y']))] = imgs['x'][np.logical_and(colored_masks['x'], np.logical_not(colored_masks['y']))]
                img_xy[np.logical_and(np.logical_not(colored_masks['x']), colored_masks['y'])] = imgs['y'][np.logical_and(np.logical_not(colored_masks['x']), colored_masks['y'])]
                img_xy[~valid_high_res] = 0
                # img_xy = caption_img_xy(img_xy, col_maps)
                out_f = out_root / 'xy' / f'thickness__{thickness:.1f}___num_gap__{num_gap}___shift__{shift:.2f}.png'
                out_f.parent.mkdir(parents=True, exist_ok=True)
                cv2.imwrite(out_f.as_posix(), pil_ds(img_xy[us_sc:-us_sc, us_sc:-us_sc, ::-1].copy(), save_shape[0], save_shape[1]))
                print(f'Saved to {out_f.resolve()}')
                assign_item_to_dict(summary, ['xy', thickness, num_gap, shift], str(out_f.resolve().relative_to(out_root.resolve())))
    with (out_root / 'summary.json').open('w') as F:
        json.dump(summary, F)


def get_out_dir(work_dir, depth_f):
    return work_dir / 'contour_line' / str((work_dir / depth_f).resolve().relative_to(work_dir.resolve()))[:-4].replace('/', '_')


def main(args):
    save_dim_ub = args.save_dim_ub

    root = args.root
    rgb_f = root / 'rgb.png'
    data_f = root / args.depth_f

    raw_rgb = imread_rgb(rgb_f)

    crop_region = get_crop_region(args)
    if crop_region is not None and len(crop_region) > 0:
        lb_i, ub_i, lb_j, ub_j = crop_region
    else:
        lb_i, ub_i, lb_j, ub_j = 0, raw_rgb.shape[0], 0, raw_rgb.shape[1]
    raw_rgb = raw_rgb[lb_i:ub_i, lb_j:ub_j]

    save_sc = int(math.floor(save_dim_ub / np.linalg.norm(raw_rgb.shape[:2])))
    save_shape = (save_sc * raw_rgb.shape[0], save_sc * raw_rgb.shape[1])

    depth, intr, valid = load_data(data_f)
    xyz = depth_to_xyz(intr, depth)
    xyz = xyz[lb_i:ub_i, lb_j:ub_j]
    valid = valid[lb_i:ub_i, lb_j:ub_j]

    render_contour_line_imgs(xyz, valid, raw_rgb, save_shape, get_out_dir(root, args.depth_f))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("root", type=Path)
    parser.add_argument("--depth_f", type=str, help='Path to depth file, relative to root.')
    parser.add_argument('--rgb_f', type=str, nargs='?', const=None, default='rgb.png', help='Path to rgb file, relative to root.')
    parser.add_argument('--crop_region', nargs='*', type=int, default=[], help='Specify 4 integers: lb_i, ub_i, lb_j, ub_j, and only render mesh of [lb_i, ub_i)x[lb_j, ub_j)')
    parser.add_argument("--save_dim_ub", type=float, default=np.linalg.norm([1920, 1080]))
    args = parser.parse_args()

    main(args)
