import itertools

import numpy as np
from pymvg.camera_model import CameraModel
from pymvg.multi_camera_system import MultiCameraSystem
import time

# from multiviews.cameras import unfold_camera_param


# def build_multi_camera_system(cameras, no_distortion=True):
#     """
#     Build a multi-camera system with pymvg package for triangulation

#     Args:
#         cameras: list of camera parameters
#     Returns:
#         cams_system: a multi-cameras system
#     """
#     pymvg_cameras = []
#     for (name, camera) in cameras:
#         R, T, f, c, k, p = unfold_camera_param(camera, avg_f=False)
#         camera_matrix = np.array(
#             [[f[0], 0, c[0]], [0, f[1], c[1]], [0, 0, 1]], dtype=float)
#         proj_matrix = np.zeros((3, 4))
#         proj_matrix[:3, :3] = camera_matrix
#         distortion = np.array([k[0], k[1], p[0], p[1], k[2]])
#         distortion.shape = (5,)
#         T = -np.matmul(R, T)
#         M = camera_matrix.dot(np.concatenate((R, T), axis=1))
#         camera = CameraModel.load_camera_from_M(
#             M, name=name, distortion_coefficients=None if no_distortion else distortion)
#         if not no_distortion:
#             camera.distortion = distortion  # bug with pymvg
#         pymvg_cameras.append(camera)
#     return MultiCameraSystem(pymvg_cameras)


def build_multi_camera_system(camera_list, no_distortion=True):
    """
    Wrap a set of cameras into a pymvg MultiCameraSystem for triangulation

    Each camera is converted through its 3x4 projection matrix, computed as
    camera.get_intrinsic() @ camera.get_extrinsic(homo=False).

    Args:
        camera_list: list of (name, camera_obj); camera_obj must provide
                     get_intrinsic() and get_extrinsic(homo=False)
        no_distortion: when False, the distortion attribute of every pymvg
                       camera is explicitly overwritten after construction
    Returns:
        cams_system: a multi-cameras system
    """
    # NOTE(review): no distortion coefficients are read from the camera
    # objects, so the value handed to pymvg is None in both modes -- confirm
    # whether no_distortion=False is expected to do anything more.
    distortion = None
    use_distortion = not no_distortion

    def _to_pymvg(name, camera):
        # projection matrix of this view
        proj = camera.get_intrinsic() @ camera.get_extrinsic(homo=False)
        model = CameraModel.load_camera_from_M(
            proj,
            name=name,
            distortion_coefficients=distortion if use_distortion else None)
        if use_distortion:
            model.distortion = distortion  # bug with pymvg
        return model

    return MultiCameraSystem(
        [_to_pymvg(name, camera) for (name, camera) in camera_list])


def triangulate_one_point(camera_system, points_2d_set):
    """
    Triangulate one 3d point from its 2d observations in several views

    Thin wrapper around camera_system.find3d, always called with
    undistort=False.

    Args:
        camera_system: pymvg camera system
        points_2d_set: list of structure (camera_name, point2d)
    Returns:
        points_3d: whatever camera_system.find3d returns for the
                   observations (the 3d point in world coordinates)
    """
    return camera_system.find3d(points_2d_set, undistort=False)


def triangulate_poses(camera_objs, poses2d, joints_vis=None, no_distortion=True, nviews=4):
    """
    Triangulate 3d poses in world coordinates from multi-view 2d poses
    by iteratively calling $triangulate_one_point$ for every joint

    Inputs are laid out instance-major: entry i * nviews + v belongs to
    view v of instance i. A separate camera system is built per instance.

    Args:
        camera_objs: [N*C] a list of camera objects, each corresponding to
                       one prediction in poses2d
        poses2d: [N*C, j, 2], [human1_view1, human1_view2,..., human2_view1, human2_view2,...]
        joints_vis: [N*C, j], only visible joints participate in triangulation;
                       all joints are treated as visible when None
        no_distortion: forwarded to $build_multi_camera_system$
        nviews: number of views C per instance
    Returns:
        poses3d: ndarray of shape N x j x 3; joints visible in fewer than
                 two views are left as zeros
    """
    njoints = poses2d.shape[1]
    ninstances = len(camera_objs) // nviews
    if joints_vis is None:
        joints_vis = np.ones((poses2d.shape[0], poses2d.shape[1]))
    else:
        assert np.all(joints_vis.shape == poses2d.shape[:2])

    # the same camera names are reused for every instance
    view_names = ['camera_{}'.format(v) for v in range(nviews)]

    poses3d = []
    for inst in range(ninstances):
        base = inst * nviews  # offset of this instance's first view
        camera_system = build_multi_camera_system(
            [(view_names[v], camera_objs[base + v]) for v in range(nviews)],
            no_distortion)

        pose3d = np.zeros((njoints, 3))
        for joint in range(njoints):
            observations = [
                (view_names[v], poses2d[base + v, joint, :])
                for v in range(nviews)
                if joints_vis[base + v, joint]
            ]
            # at least two views are required; otherwise keep the zero row
            if len(observations) >= 2:
                pose3d[joint, :] = triangulate_one_point(camera_system, observations).T
        poses3d.append(pose3d)
    return np.array(poses3d)


def ransac(poses2d, camera_objs, joints_vis, config, nviews=4):
    """
    Select, per joint, the views that are consistent with each other

    For every joint, each pair of visible views is triangulated and the
    resulting 3d point is reprojected into all nviews views (visible or
    not). Views whose reprojection error is below
    config.PSEUDO_LABEL.REPROJ_THRE are the inliers of that pair. A pair is
    accepted only if its number of inliers is not less than
    config.PSEUDO_LABEL.NUM_INLIERS. Among accepted pairs the one with the
    most inliers wins; ties are broken by the lower mean inlier error.

    Param:
        poses2d: [N, 16, 2]
        camera_objs: a list of [N]
        joints_vis: [N, 16], only visible joints participate in triangulation
        config: provides DATASET.NO_DISTORTION, PSEUDO_LABEL.REPROJ_THRE
                and PSEUDO_LABEL.NUM_INLIERS
        nviews: number of views per instance
    Return:
        res_vis: [N, 16], 1 for the inlier views of the best pair, else 0
    """
    njoints = poses2d.shape[1]
    ninstances = len(camera_objs) // nviews
    view_names = ['camera_{}'.format(v) for v in range(nviews)]

    res_vis = np.zeros_like(joints_vis)
    for inst in range(ninstances):
        base = inst * nviews  # offset of this instance's first view
        cameras = [(view_names[v], camera_objs[base + v]) for v in range(nviews)]
        camera_system = build_multi_camera_system(cameras, config.DATASET.NO_DISTORTION)

        for joint in range(njoints):
            # select out visible points from all views
            visible = [
                (view_names[v], poses2d[base + v, joint, :])
                for v in range(nviews)
                if joints_vis[base + v, joint]
            ]

            # points < 2, invalid instance, abandon samples of 1 view
            if len(visible) < 2:
                continue

            best_inliers = []
            best_error = 10000
            for pair in itertools.combinations(visible, 2):
                point_3d = triangulate_one_point(camera_system, list(pair)).T

                # score the hypothesis against every view
                inliers = []
                error_sum = 0
                for v in range(nviews):
                    reproj = camera_system.find2d(view_names[v], point_3d)
                    err = np.linalg.norm(reproj - poses2d[base + v, joint, :])
                    if err < config.PSEUDO_LABEL.REPROJ_THRE:
                        inliers.append(v)
                        error_sum += err

                if len(inliers) < config.PSEUDO_LABEL.NUM_INLIERS:
                    continue
                mean_error = error_sum / len(inliers)

                # update best candidate: more inliers, or same count with lower error
                more_support = len(inliers) > len(best_inliers)
                same_support = len(inliers) == len(best_inliers)
                if more_support or (same_support and mean_error < best_error):
                    best_inliers = inliers
                    best_error = mean_error

            for v in best_inliers:
                res_vis[base + v, joint] = 1
    return res_vis


def reproject_poses(poses2d, camera_objs, joints_vis, no_distortion=True, nviews=4):
    """
    Triangulate every joint from its visible views and project the
    resulting 3d point back into all views of the instance

    Inputs are laid out instance-major: entry i * nviews + v belongs to
    view v of instance i.

    Args:
        camera_objs: a list of camera objects, each corresponding to
                       one prediction in poses2d
        poses2d: [N, k, 2], len(cameras) == N
        joints_vis: [N, k], only visible joints participate in triangulation
        no_distortion: forwarded to $build_multi_camera_system$
        nviews: number of views per instance
    Returns:
        proj_2d: ndarray of shape [N, k, 2], reprojected joints; zeros for
                 joints visible in fewer than two views
        res_vis: [N, k], 1 in every view of a joint that was triangulated
    """
    njoints = poses2d.shape[1]
    ninstances = len(camera_objs) // nviews
    assert np.all(joints_vis.shape == poses2d.shape[:2])

    proj_2d = np.zeros_like(poses2d)  # [N, 16, 2]
    res_vis = np.zeros_like(joints_vis)
    view_names = ['camera_{}'.format(v) for v in range(nviews)]

    for inst in range(ninstances):
        base = inst * nviews  # offset of this instance's first view
        camera_system = build_multi_camera_system(
            [(view_names[v], camera_objs[base + v]) for v in range(nviews)],
            no_distortion)

        for joint in range(njoints):
            observations = [
                (view_names[v], poses2d[base + v, joint, :])
                for v in range(nviews)
                if joints_vis[base + v, joint]
            ]
            if len(observations) < 2:
                continue
            point_3d = triangulate_one_point(camera_system, observations).T

            # the reprojection is written to every view, visible or not
            for v in range(nviews):
                proj_2d[base + v, joint, :] = camera_system.find2d(view_names[v], point_3d)
                res_vis[base + v, joint] = 1
    return proj_2d, res_vis


def fast_triangulate(camera_objs, poses2d, joints_vis=None):
    """
    Triangulate 3d points with DLT, batched over all instances and joints
    with a single SVD call; confidence scores can be passed to joints_vis

    Args:
        camera_objs: [C] a list of camera objects shared by all instances;
                       each must provide get_intrinsic() and
                       get_extrinsic(homo=False), whose matrix product is
                       used as the 3x4 projection matrix
        poses2d: [C, N, j, 2], 2d joints of N instances in each of the C views
        joints_vis: [C, N, j], visibility flags or confidence scores; an
                       observation participates only if its value is > 0 and
                       its equations are scaled by that value. All ones
                       when None. The shape is not checked, for speed.
    Returns:
        poses3d: ndarray of shape [N, J, 3]; joints observed (value > 0) in
                 fewer than two views are exactly zero
    """
    poses2d = np.asarray(poses2d)
    if joints_vis is None:
        joints_vis = np.ones(poses2d.shape[:-1])
    else:
        joints_vis = np.asarray(joints_vis)

    # a joint needs at least two views; drop the lone observations otherwise
    covered_mask = np.sum(joints_vis > 0, axis=0) >= 2  # [N, J]
    weights = covered_mask[None, ...].astype(np.int32) * joints_vis  # [C, N, J]

    # make projection matrix
    P = np.zeros((len(camera_objs), 3, 4))  # [C, 3, 4]
    for idx, cam in enumerate(camera_objs):
        P[idx] = cam.get_intrinsic() @ cam.get_extrinsic(homo=False)

    # rows of every projection matrix, each [C, 4]; they broadcast against
    # the [N, J, C, 1] pixel coordinates below
    row0, row1, row2 = P[:, 0, :], P[:, 1, :], P[:, 2, :]

    pts = poses2d.transpose(1, 2, 0, 3)  # [N, J, C, 2]
    weights = weights.transpose(1, 2, 0)[..., None]  # [N, J, C, 1]

    # DLT constraints: x * p2 - p0 = 0 and y * p2 - p1 = 0 for every view
    eq_x = (pts[..., [0]] * row2 - row0) * weights  # [N, J, C, 4]
    eq_y = (pts[..., [1]] * row2 - row1) * weights  # [N, J, C, 4]
    A = np.concatenate((eq_x, eq_y), axis=-2)  # [N, J, 2C, 4]

    # batch SVD on [2C, 4]; the last right singular vector is the solution
    vh = np.linalg.svd(A)[2]  # [N, J, 4, 4]

    points_un = vh[..., -1, :3]  # [N, J, 3]
    points_scale = vh[..., -1, [3]]  # [N, J, 1]
    points = np.divide(
        points_un,
        points_scale,
        where=points_scale != 0,
        out=np.zeros_like(points_un),
    )  # [N, J, 3]

    # For uncovered joints A is all zeros and its null vector is arbitrary;
    # zero them explicitly instead of relying on what the SVD returns.
    points[~covered_mask] = 0

    return points


