import os

import BboxTools as bbt
import cv2
import numpy as np
import scipy.io as sio
from PIL import Image

from corr.utils import cal_point_weight
from corr.utils import rle_to_mask
from corr.utils.pascal3d_utils import CATEGORIES
from corr.utils.pascal3d_utils import get_anno
from corr.utils.pascal3d_utils import get_obj_ids
from corr.utils.pascal3d_utils import KP_LIST

# Record fields copied verbatim into every saved per-object annotation.
# "distance" is intentionally not listed: the saved distance is recomputed
# from the resize rate (or the sampled target distance) instead.
mesh_para_names = [
    "azimuth",
    "elevation",
    "theta",
    "focal",
    "principal",
    "viewport",
    "cad_index",
    "bbox",
]
# Category name -> its index in CATEGORIES.
cate_to_id = {name: position for position, name in enumerate(CATEGORIES)}


def get_target_distances(start=4.0, end=32.0, num=14):
    """Draw one random distance from each of ``num`` equal-width bins.

    The interval ``[start, end]`` is split into ``num`` consecutive bins and
    one value is sampled uniformly inside each bin, giving a stratified set of
    target distances.

    Parameters
    ----------
    start: float, default 4.0
        Lower bound of the distance range.
    end: float, default 32.0
        Upper bound of the distance range.
    num: int, default 14
        Number of bins, which is also the number of returned samples.

    Returns
    -------
    np.ndarray
        Array of shape ``(num,)``; element ``i`` lies in bin ``i``.
    """
    edges = np.linspace(start, end, num + 1)
    lower, upper = edges[:-1], edges[1:]
    # The sample count must follow ``num``. A hard-coded 14 here broke
    # broadcasting for every other bin count.
    return np.random.rand(num).astype(np.float32) * (upper - lower) + lower


def _resize_label_map(arr, dsize):
    """Nearest-neighbour resize of an optional label/mask array (``None`` passes through)."""
    if arr is None:
        return None
    return cv2.resize(arr, dsize=dsize, interpolation=cv2.INTER_NEAREST)


def _pad_optional(arr, pad_width):
    """Zero-pad an optional 2D array with ``np.pad`` (``None`` passes through)."""
    if arr is None:
        return None
    return np.pad(arr, pad_width, mode="constant")


def _window_padding(out_shape, center, img_shape):
    """Return ``((top, bottom), (left, right))`` zero padding so that a window of
    ``out_shape`` centred at ``center`` (row, col) fits inside ``img_shape``."""
    return (
        (
            max(out_shape[0] // 2 - center[0], 0),
            max(out_shape[0] // 2 + center[0] - img_shape[0], 0),
        ),
        (
            max(out_shape[1] // 2 - center[1], 0),
            max(out_shape[1] // 2 + center[1] - img_shape[1], 0),
        ),
    )


def prepare_pascal3d_sample(
    cate,
    img_name,
    img_path,
    anno_path,
    occ_level,
    save_image_path,
    save_annotation_path,
    out_shape,
    occ_path=None,
    prepare_mode="first",
    augment_by_dist=False,
    texture_filenames=None,
    texture_path=None,
    single_mesh=True,
    mesh_manager=None,
    direction_dicts=None,
    obj_ids=None,
    extra_anno=None,
    seg_mask_path=None,
    center_and_resize=True,
    skip_3d_anno=False,
    seg_root="./OmniNeMo/data/CorrData/UDApart",
    save_root="./OmniNeMo/data/CorrData/pascalUDApart",
):
    """
    Prepare a per-object sample (image crop, part segmentation, annotation).

    For every selected object, the image and its part segmentation are resized
    and then zero-padded where needed. A window of ``out_shape`` is then
    cropped: centred on the object's principal point when
    ``center_and_resize`` is set, otherwise centred on the image. The crop is
    written under ``save_root`` together with an ``.npz`` annotation.

    Parameters
    ----------
    cate: str
        Category; used to select objects, the segmentation folder
        ``{seg_root}/{cate}_imagenet`` and the output sub-folders.
    img_name: str
        Not used for naming: the file stem of ``img_path`` is used instead.
    img_path: str
    anno_path: str
        ``.mat`` file holding a ``record`` struct.
    occ_level: int
        Not read by this function.
    save_image_path: str
        Not read by this function; outputs go under ``save_root``.
    save_annotation_path: str
        Not read by this function; outputs go under ``save_root``.
    out_shape: list
        (height, width) of the saved crop.
    occ_path: str, default None
        ``.npz`` with an ``occluder_mask`` entry.
    prepare_mode: {'first', 'all'}, default 'first'
    augment_by_dist: bool, default False
        If set (and ``center_and_resize``), save one crop per distance from
        ``get_target_distances()``.
    texture_filenames: list, default None
        Not read by this function.
    texture_path: str, default None
        Not read by this function.
    single_mesh: bool, default True
    mesh_manager: MeshConverter, default None
    direction_dicts: dict, default None
    obj_ids: list, default None
    extra_anno: dict, default None
        Extra key/values merged into the saved annotation.
    seg_mask_path: str, default None
        RLE-encoded amodal mask.
    center_and_resize: bool, default True
    skip_3d_anno: bool, default False
    seg_root: str
        Root folder of the part segmentations.
    save_root: str
        Root folder for the ``annotations/{cate}`` and ``images/{cate}`` outputs.

    Returns
    -------
    The return value depends on which condition is hit:

    - ``None``: the image or annotation file is missing, or there is no
      object of ``cate``.
    - ``[]``: ``prepare_mode == 'first'`` and the first object is not index 0,
      or the part segmentation file does not exist.
    - Otherwise, a list of ``(cad_index, saved_name)`` tuples.
    """
    if not os.path.isfile(img_path):
        print(img_path)
        return None
    if not os.path.isfile(anno_path):
        print(anno_path)
        return None

    mat_contents = sio.loadmat(anno_path)
    record = mat_contents["record"][0][0]
    if occ_path is not None:
        occ_mask_full = np.load(occ_path, allow_pickle=True)["occluder_mask"].astype(np.uint8)
    else:
        occ_mask_full = None
    if seg_mask_path is not None and os.path.isfile(seg_mask_path):
        rle = np.load(seg_mask_path, allow_pickle=True)
        amodal_mask_full = rle_to_mask(rle).astype(np.uint8)
    else:
        amodal_mask_full = None

    if obj_ids is None:
        obj_ids = get_obj_ids(record, cate=cate)
        if len(obj_ids) == 0:
            return None
        if prepare_mode == "first":
            if obj_ids[0] != 0:
                return []
            obj_ids = [0]

    # The segmentation lookup and saved names are keyed on the file stem of
    # img_path (the img_name argument is overridden, as before).
    img_name = os.path.basename(img_path).split(".")[0]
    seg_fn = os.path.join(seg_root, f"{cate}_imagenet", img_name + ".png")
    if not os.path.isfile(seg_fn):
        # Without a part segmentation nothing is saved for this image.
        return []

    # Decode once. Every iteration starts from these full-resolution arrays.
    pil_img = Image.open(img_path)
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    img_full = np.array(pil_img)
    seg_full = np.array(Image.open(seg_fn))
    _h, _w = img_full.shape[0], img_full.shape[1]

    corner_sorts = ("x0", "y0", "x1", "y1")
    save_image_names = []
    for obj_id in obj_ids:
        obj_bbox = get_anno(record, "bbox", idx=obj_id)
        dist = get_anno(record, "distance", idx=obj_id)
        if dist <= 0:
            continue

        if center_and_resize:
            target_distances = get_target_distances() if augment_by_dist else [5.0]
            all_resize_rates = [float(dist / x) for x in target_distances]
        else:
            all_resize_rates = [min(out_shape[0] / _h, out_shape[1] / _w)]

        for rr_idx, resize_rate in enumerate(all_resize_rates):
            if resize_rate <= 0.001:
                # Degenerate rate: fit the unscaled object box into out_shape.
                obj_box = bbt.from_numpy(obj_bbox, sorts=corner_sorts)
                resize_rate = min(out_shape[0] / obj_box.shape[0], out_shape[1] / obj_box.shape[1])

            # Fresh references per iteration, so augmented copies and later
            # objects are never derived from a previous iteration's crop.
            img, seg_img = img_full, seg_full
            occ_mask, amodal_mask = occ_mask_full, amodal_mask_full

            box_ori = bbt.from_numpy(obj_bbox, sorts=corner_sorts).set_boundary(img.shape[0:2])
            box = bbt.from_numpy(obj_bbox, sorts=corner_sorts) * resize_rate

            if not center_and_resize:
                resize_rate = min(out_shape[0] / img.shape[0], out_shape[1] / img.shape[1])
            dsize = (int(img.shape[1] * resize_rate), int(img.shape[0] * resize_rate))
            img = cv2.resize(img, dsize=dsize)
            seg_img = _resize_label_map(seg_img, dsize)
            # Masks are resized in both modes so they stay aligned with img.
            occ_mask = _resize_label_map(occ_mask, dsize)
            amodal_mask = _resize_label_map(amodal_mask, dsize)

            principal = get_anno(record, "principal", idx=obj_id)
            if center_and_resize:
                # principal is (x, y); reverse it to get a (row, col) centre.
                center = (principal[::-1] * resize_rate).astype(int)
                new_px, new_py = float(out_shape[1] // 2), float(out_shape[0] // 2)
            else:
                center = np.array([img.shape[0] // 2, img.shape[1] // 2]).astype(np.int32)
                new_px = float(principal[0]) * resize_rate + (out_shape[1] - int(img.shape[1])) / 2
                new_py = float(principal[1]) * resize_rate + (out_shape[0] - int(img.shape[0])) / 2

            box1 = bbt.box_by_shape(out_shape, center)
            (pad_top, pad_bottom), (pad_left, pad_right) = _window_padding(out_shape, center, img.shape)
            spatial_pad = ((pad_top, pad_bottom), (pad_left, pad_right))
            if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
                img = np.pad(img, spatial_pad + ((0, 0),), mode="constant")
                seg_img = np.pad(seg_img, spatial_pad, mode="constant", constant_values=0)
                occ_mask = _pad_optional(occ_mask, spatial_pad)
                amodal_mask = _pad_optional(amodal_mask, spatial_pad)
                # Top/left padding moves the content down/right.
                box = box.shift([pad_top, pad_left])
                box1 = box1.shift([pad_top, pad_left])
            padding = spatial_pad + ((0, 0),)

            box_in_cropped = box.copy()
            box = box1.set_boundary(img.shape[0:2])
            box_in_cropped = box.box_in_box(box_in_cropped)

            crop = box.bbox
            rows = slice(crop[0][0], crop[0][1])
            cols = slice(crop[1][0], crop[1][1])
            img_cropped = img[rows, cols, :]
            seg_img_cropped = seg_img[rows, cols]
            if occ_mask is not None:
                occ_mask = occ_mask[rows, cols]
            if amodal_mask is not None:
                amodal_mask = amodal_mask[rows, cols]

            if amodal_mask is None:
                inmodal_mask = None
            elif occ_mask is None:
                inmodal_mask = amodal_mask
            else:
                inmodal_mask = amodal_mask * (1 - occ_mask)

            if augment_by_dist:
                curr_img_name = f"{img_name}_{obj_id:02d}_aug{rr_idx}"
            else:
                curr_img_name = f"{img_name}"

            save_parameters = dict(
                name=img_name,
                box=box.numpy(),
                box_ori=box_ori.numpy(),
                box_obj=box_in_cropped.numpy(),
                occ_mask=occ_mask,
                amodal_mask=amodal_mask,
                inmodal_mask=inmodal_mask,
                px=new_px,
                py=new_py,
                distance=target_distances[rr_idx] if center_and_resize else dist / resize_rate,
            )
            save_parameters.update(
                zip(mesh_para_names, get_anno(record, *mesh_para_names, idx=obj_id))
            )
            save_parameters["height"] = _h
            save_parameters["width"] = _w
            save_parameters["resize_rate"] = resize_rate
            save_parameters["padding_params"] = np.array(
                [
                    padding[0][0],
                    padding[0][1],
                    padding[1][0],
                    padding[1][1],
                    padding[2][0],
                    padding[2][1],
                ]
            )

            if extra_anno is not None:
                for k in extra_anno:
                    save_parameters[k] = extra_anno[k]

            # Prepare 3D annotations for NeMo training.
            if not skip_3d_anno and (mesh_manager is not None and direction_dicts is not None):
                save_parameters["true_cad_index"] = save_parameters["cad_index"]
                if single_mesh:
                    save_parameters["cad_index"] = 1

                kps, vis = mesh_manager.get_one(save_parameters)
                idx = save_parameters["cad_index"] - 1
                weights = cal_point_weight(
                    direction_dicts[idx],
                    mesh_manager.loader[idx][0],
                    save_parameters,
                )

                save_parameters["kp_weights"] = np.abs(weights)
                save_parameters["cropped_kp_list"] = kps
                save_parameters["visible"] = vis

            anno_dir = os.path.join(save_root, "annotations", cate)
            os.makedirs(anno_dir, exist_ok=True)
            np.savez(os.path.join(anno_dir, curr_img_name), **save_parameters)

            image_dir = os.path.join(save_root, "images", cate)
            os.makedirs(image_dir, exist_ok=True)
            Image.fromarray(img_cropped).save(
                os.path.join(image_dir, curr_img_name + ".JPEG")
            )
            seg_out = os.path.join(image_dir, curr_img_name + "_seg.png")
            Image.fromarray(seg_img_cropped).save(seg_out)

            # Stretch labels to 0..255 for viewing. Guard against an all-zero
            # map, which would otherwise divide by zero and produce NaNs.
            seg_max = seg_img_cropped.max() if seg_img_cropped.size else 0
            if seg_max > 0:
                vis_seg = seg_img_cropped / seg_max * 255
            else:
                vis_seg = np.zeros_like(seg_img_cropped)
            Image.fromarray(vis_seg.astype(np.uint8)).save(
                os.path.join(image_dir, curr_img_name + "_vis_seg.png")
            )
            print('save to ', seg_out)
            save_image_names.append(
                (get_anno(record, "cad_index", idx=obj_id), curr_img_name)
            )

    return save_image_names


def prepare_pascal3d_sample_det(
    cate,
    img_name,
    img_path,
    anno_path,
    occ_level,
    save_image_path,
    save_annotation_path,
    out_shape,
    occ_path=None,
    prepare_mode="first",
    augment_by_dist=False,
    texture_filenames=None,
    texture_path=None,
    single_mesh=True,
    mesh_manager=None,
    direction_dicts=None,
    obj_ids=None,
    extra_anno=None,
    seg_mask_path=None,
    center_and_resize=True
):
    """
    Prepare a multi-object (detection-style) sample.

    All selected objects with a positive distance are kept. The whole image is
    resized to fit inside ``out_shape`` and padded on the bottom/right, either
    with zeros or by pasting it onto a resized random texture. It is then
    saved together with per-object boxes, category labels and poses.

    Parameters
    ----------
    cate: str
    img_name: str
        Base name of the saved image and ``.npz`` annotation.
    img_path: str
    anno_path: str
        ``.mat`` file holding a ``record`` struct.
    occ_level: int
        Not read by this function.
    save_image_path: str
    save_annotation_path: str
    out_shape: list
        (height, width) of the output canvas.
    occ_path: str, default None
        Not read by this function.
    prepare_mode: {'first', 'all'}, default 'first'
    augment_by_dist: bool, default False
        Not read by this function.
    texture_filenames: list, default None
        If given, a random texture (from ``texture_path/images``) fills the
        padded area instead of zeros.
    texture_path: str, default None
    single_mesh, mesh_manager, direction_dicts, extra_anno, seg_mask_path,
    center_and_resize:
        Accepted for signature compatibility; not read by this function.
    obj_ids: list, default None

    Returns
    -------
    The return value depends on which condition is hit:

    - ``None``: the image or annotation file is missing, or there is no
      object of ``cate``.
    - ``[]``: ``prepare_mode == 'first'`` and the first object is not index 0.
    - Otherwise, ``[(1, img_name)]``.
    """
    if not os.path.isfile(img_path):
        print(img_path)
        return None
    if not os.path.isfile(anno_path):
        print(anno_path)
        return None

    mat_contents = sio.loadmat(anno_path)
    record = mat_contents["record"][0][0]

    if obj_ids is None:
        obj_ids = get_obj_ids(record, cate=cate)
        if len(obj_ids) == 0:
            return None
        if prepare_mode == "first":
            if obj_ids[0] != 0:
                return []
            obj_ids = [0]

    # Decode the image once (previously it was opened twice).
    pil_img = Image.open(img_path)
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    img = np.array(pil_img)
    _h, _w = img.shape[0], img.shape[1]

    # Objects with a non-positive distance are skipped.
    filtered_obj_ids = []
    for obj_id in obj_ids:
        if get_anno(record, "distance", idx=obj_id) <= 0:
            continue
        filtered_obj_ids.append(obj_id)
    obj_ids = filtered_obj_ids

    boxes, labels, distances, azimuths, elevations, thetas = [], [], [], [], [], []
    for obj_id in obj_ids:
        boxes.append(get_anno(record, 'bbox', idx=obj_id))
        labels.append(cate_to_id[get_anno(record, 'category', idx=obj_id)])
        azimuths.append(get_anno(record, 'azimuth', idx=obj_id))
        elevations.append(get_anno(record, 'elevation', idx=obj_id))
        thetas.append(get_anno(record, 'theta', idx=obj_id))
        distances.append(get_anno(record, 'distance', idx=obj_id))
    boxes = [bbt.from_numpy(b, sorts=("x0", "y0", "x1", "y1")) for b in boxes]

    resize_rate = min(out_shape[0] / img.shape[0], out_shape[1] / img.shape[1])
    dsize = (int(img.shape[1] * resize_rate), int(img.shape[0] * resize_rate))
    img = cv2.resize(img, dsize=dsize)
    boxes = [b * resize_rate for b in boxes]

    if texture_filenames is not None:
        texture_name = np.random.choice(texture_filenames)

    # The window is centred on the canvas centre, so top/left padding is always
    # zero and only bottom/right padding can occur.
    center = (out_shape[0] // 2, out_shape[1] // 2)
    box1 = bbt.box_by_shape(out_shape, center)
    if (
        out_shape[0] // 2 - center[0] > 0
        or out_shape[1] // 2 - center[1] > 0
        or out_shape[0] // 2 + center[0] - img.shape[0] > 0
        or out_shape[1] // 2 + center[1] - img.shape[1] > 0
    ):
        padding = (
            (
                max(out_shape[0] // 2 - center[0], 0),
                max(out_shape[0] // 2 + center[0] - img.shape[0], 0),
            ),
            (
                max(out_shape[1] // 2 - center[1], 0),
                max(out_shape[1] // 2 + center[1] - img.shape[1], 0),
            ),
            (0, 0),
        )

        if texture_filenames is None:
            img = np.pad(img, padding, mode="constant")
        else:
            texture_img = Image.open(
                os.path.join(texture_path, "images", texture_name)
            )
            if texture_img.mode != "RGB":
                texture_img = texture_img.convert("RGB")
            texture_img = np.array(texture_img)
            texture_img = cv2.resize(
                texture_img,
                dsize=(
                    img.shape[1] + padding[1][0] + padding[1][1],
                    img.shape[0] + padding[0][0] + padding[0][1],
                ),
            )
            texture_img[
                padding[0][0] : padding[0][0] + img.shape[0],
                padding[1][0] : padding[1][0] + img.shape[1],
                :,
            ] = img
            img = texture_img

        boxes = [b.shift([padding[0][0], padding[1][0]]) for b in boxes]
        # shift() returns the shifted box; keep the result, matching the
        # per-object variant of this routine.
        box1 = box1.shift([padding[0][0], padding[1][0]])
    else:
        padding = ((0, 0), (0, 0), (0, 0))

    box1 = box1.set_boundary(img.shape[0:2])

    bbox = box1.bbox
    img_cropped = img[bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1], :]

    boxes = [box1.box_in_box(b) for b in boxes]
    distances = [d / resize_rate for d in distances]

    save_parameters = {}
    save_parameters["img_name"] = img_name
    save_parameters["boxes"] = [b.numpy() for b in boxes]
    # Category ids were computed per object but previously never saved.
    save_parameters["labels"] = labels
    save_parameters["distances"] = distances
    save_parameters["azimuth"] = azimuths
    save_parameters["elevation"] = elevations
    save_parameters["theta"] = thetas
    save_parameters["num_obj"] = len(azimuths)
    save_parameters["height"] = _h
    save_parameters["width"] = _w
    save_parameters["resize_rate"] = resize_rate
    save_parameters["padding_params"] = np.array(
        [
            padding[0][0],
            padding[0][1],
            padding[1][0],
            padding[1][1],
            padding[2][0],
            padding[2][1],
        ]
    )

    if texture_filenames is not None:
        save_parameters["texture_name"] = texture_name

    np.savez(
        os.path.join(save_annotation_path, img_name), **save_parameters
    )
    Image.fromarray(img_cropped).save(
        os.path.join(save_image_path, img_name + ".JPEG")
    )

    return [(1, img_name)]
