import torch
import torch.nn as nn
from copy import copy, deepcopy

from eval.dataset_utils.corr import geotrf, inv


def invalid_to_nans(arr, valid_mask, ndim=999):
    """Return ``arr`` with every entry outside ``valid_mask`` set to NaN.

    When a mask is given the input is cloned first, so the caller's tensor
    is left untouched; with ``valid_mask=None`` the input is used as is.
    If the result has more than ``ndim`` dimensions, the surplus dimensions
    sitting just before the last one are flattened together.
    """
    out = arr
    if valid_mask is not None:
        out = arr.clone()
        out[~valid_mask] = float("nan")

    surplus = out.ndim - ndim
    if surplus > 0:
        # merge dims [-2 - surplus, -2] into a single one
        out = out.flatten(-2 - surplus, -2)
    return out


def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out the entries of ``arr`` outside ``valid_mask`` and count valid ones.

    Args:
        arr: tensor whose leading dimension is the batch dimension.
        valid_mask: boolean tensor indexing the leading dimensions of ``arr``,
            or ``None`` to keep everything.
        ndim: if the result has more than ``ndim`` dimensions, the surplus
            dimensions just before the last one are flattened together.

    Returns:
        ``(arr, nnz)``. With a mask, ``arr`` is a modified clone and ``nnz`` is
        a per-batch-item tensor with the number of ``True`` mask entries.
        Without a mask, ``arr`` is the input itself and ``nnz`` is a Python int:
        the number of elements per batch item (0 for an empty batch).
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = 0
        # reshape (not view): masks obtained by permuting/slicing may be
        # non-contiguous, in which case view() raises.
        nnz = valid_mask.reshape(len(valid_mask), -1).sum(1)
    else:
        nnz = arr.numel() // len(arr) if len(arr) else 0
    if arr.ndim > ndim:
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr, nnz


class BaseCriterion(nn.Module):
    """Base class for elementary criteria; it only stores the reduction mode."""

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        # read by subclasses to decide how to aggregate their output
        self.reduction = reduction


class Criterion(nn.Module):
    """Loss wrapper that owns a copy of an elementary ``BaseCriterion``."""

    def __init__(self, criterion=None):
        """Store a shallow copy of ``criterion``.

        Raises:
            AssertionError: if ``criterion`` is not a ``BaseCriterion``.
        """
        super().__init__()
        assert isinstance(
            criterion, BaseCriterion
        ), f"{criterion} is not a proper criterion!"
        self.criterion = copy(criterion)

    def get_name(self):
        """Return ``"<ClassName>(<criterion>)"``."""
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Return a deep copy of this loss whose criteria use reduction ``mode``.

        The copy is walked along its ``_loss2`` chain (set up by ``MultiLoss``
        when losses are added together) and every criterion on the way gets
        its ``reduction`` overwritten. ``self`` is left unchanged.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode  # make it return the loss for each sample
            # A Criterion that is not combined with MultiLoss has no `_loss2`
            # attribute; treat that as the end of the chain instead of crashing.
            loss = getattr(loss, "_loss2", None)
        return res


class MultiLoss(nn.Module):
    """Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()

    A combined loss is a singly linked chain: each term holds its own weight
    in ``_alpha`` and the next term in ``_loss2``. ``*`` and ``+`` return new
    objects and never modify their operands.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        """Return either a loss tensor or a ``(loss, details_dict)`` tuple."""
        raise NotImplementedError()

    def get_name(self):
        """Return the display name of this term."""
        raise NotImplementedError()

    def _detached_copy(self):
        """Shallow-copy this term with its own sub-module registry.

        ``copy()`` on an ``nn.Module`` shares the ``_modules`` dict with the
        original, so registering ``_loss2`` on the copy would also register it
        on the source object. Copying the dict keeps the two independent.
        """
        dup = copy(self)
        dup._modules = dup._modules.copy()
        return dup

    def __mul__(self, alpha):
        """Return a copy of this term weighted by ``alpha``.

        NOTE: ``alpha`` replaces any weight set previously, and only this
        term is affected, not the terms chained after it.
        """
        assert isinstance(alpha, (int, float))
        res = self._detached_copy()
        res._alpha = alpha
        return res

    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        """Return a new chain made of this chain followed by ``loss2``."""
        assert isinstance(loss2, MultiLoss)
        res = cur = self._detached_copy()

        # Duplicate every node on the way to the tail so that appending
        # ``loss2`` does not alter the chain owned by ``self``.
        while cur._loss2 is not None:
            nxt = cur._loss2._detached_copy()
            cur._loss2 = nxt
            cur = nxt
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        if self._loss2 is not None:
            name = f"{name} + {self._loss2}"
        return name

    def forward(self, *args, **kwargs):
        """Evaluate the whole chain.

        Returns:
            ``(loss, details)`` where ``loss`` is the weighted sum of all terms
            and ``details`` merges the per-term detail dictionaries. A term
            returning a 0-dim tensor is logged under its own name; a term
            returning a higher-dim tensor contributes no detail entry.
        """
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        if self._loss2 is not None:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            # build a new dict rather than mutating the one compute_loss returned
            details = details | details2

        return loss, details


class LLoss(BaseCriterion):
    """Distance-based loss; subclasses provide the distance in ``distance()``."""

    def forward(self, a, b):
        """Reduce ``distance(a, b)`` according to ``self.reduction``.

        ``a`` and ``b`` must have the same shape, at least 2 dimensions and a
        last dimension of size 1 to 3. An empty input gives 0 in "mean" mode.
        Raises ``ValueError`` for an unknown reduction mode.
        """
        shapes_ok = a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        assert shapes_ok, f"Bad shape = {a.shape}"
        dist = self.distance(a, b)

        mode = self.reduction
        if mode == "mean":
            if dist.numel() > 0:
                return dist.mean()
            return dist.new_zeros(())
        if mode == "sum":
            return dist.sum()
        if mode == "none":
            return dist
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        """Return the per-element distance between ``a`` and ``b``."""
        raise NotImplementedError()


class L21Loss(LLoss):
    """Euclidean distance between 3d points"""

    def distance(self, a, b):
        # norm of the difference, taken over the last (coordinate) dimension
        diff = a - b
        return diff.norm(dim=-1)


# Module-level Euclidean point criterion, built with the default reduction.
L21 = L21Loss()


def get_pred_pts3d(gt, pred, use_pose=False):
    """Return the predicted point map stored in ``pred``.

    Only ``use_pose=True`` is supported (anything else fails the assertion);
    ``gt`` is accepted but not used.
    """
    assert use_pose is True
    key = "pts3d_in_other_view"
    return pred[key]


def Sum(losses, masks, conf=None):
    """Aggregate a list of losses.

    If the first loss is a 0-dim tensor, all losses are added up and the
    total is returned. Otherwise the losses are per-element and the inputs
    are handed back untouched, as ``(losses, masks)`` or, when ``conf`` is
    given, ``(losses, masks, conf)``.
    """
    first, _ = losses[0], masks[0]

    if first.ndim == 0:
        # global losses: plain left-to-right sum
        return sum(losses[1:], first)

    # per-element losses: let the caller do the reduction
    if conf is None:
        return losses, masks
    return losses, masks, conf


def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute a per-sample normalization factor for a list of point maps.

    Args:
        pts: sequence of point tensors with a leading batch dimension and a
            last dimension of size 3.
        norm_mode: ``"<norm>_<dis>"``. Only ``norm == "avg"`` is implemented;
            ``dis`` is ``"dis"`` (plain distance to the origin) or ``"log1p"``.
        valids: optional sequence of boolean masks, one per entry of ``pts``.
            ``None`` means every point is valid.
        fix_first: if True, only ``pts[0]`` contributes to the factor.

    Returns:
        Tensor with one factor per batch item (average distance of that
        item's valid points, clipped to >= 1e-8), with trailing singleton
        dimensions so that it broadcasts against ``pts[0]``.

    Raises:
        ValueError: for an unsupported ``norm_mode``.
    """
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert (
        len(pts) < 2
        or pts[1] is None
        or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
    )
    norm_mode, dis_mode = norm_mode.split("_")
    if norm_mode != "avg":
        raise ValueError(f"Not implemented {norm_mode=}")

    # gather all points together (joint normalization)
    flat_pts = []
    counts = []
    for i, pt in enumerate(pts):
        valid = valids[i] if valids is not None else None
        flat, nnz = invalid_to_zeros(pt, valid, ndim=3)
        if valid is None:
            # no mask: every point of every sample counts
            nnz = torch.full(
                (len(flat),), flat.shape[1], dtype=torch.long, device=flat.device
            )
        flat_pts.append(flat)
        counts.append(nnz)

        if fix_first:
            break
    all_pts = torch.cat(flat_pts, dim=1)

    # compute distance to origin
    all_dis = all_pts.norm(dim=-1)
    if dis_mode == "dis":
        pass  # do nothing
    elif dis_mode == "log1p":
        all_dis = torch.log1p(all_dis)
    else:
        raise ValueError(f"bad {dis_mode=}")

    # Valid-point count per sample, summed over the views that were used.
    # (Summing over the whole batch instead would make each sample's factor
    # depend on the other samples it is batched with.)
    per_sample = torch.stack(counts).sum(dim=0)
    norm_factor = all_dis.sum(dim=1) / (per_sample + 1e-8)

    norm_factor = norm_factor.clip(min=1e-8)
    extra = pts[0].ndim - norm_factor.ndim
    return norm_factor.reshape(tuple(norm_factor.shape) + (1,) * extra)


def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Divide every point map in ``pts`` by a shared normalization factor.

    The factor comes from ``get_norm_factor(pts, norm_mode, valids, fix_first)``.
    ``gt`` is accepted for interface compatibility; both of its values lead
    to the same computation.

    Returns:
        ``(normalized_pts, norm_factor)`` where ``normalized_pts`` is a list
        with one tensor per entry of ``pts``.
    """
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    normalized = [pt / norm_factor for pt in pts]
    return normalized, norm_factor


@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Return one depth statistic per batch item, computed over all views jointly.

    Invalid entries (per ``valid_masks``, if given) are turned into NaN and
    ignored. ``quantile == 0.5`` uses ``torch.nanmedian``; any other value
    uses ``torch.nanquantile``. The result has shape (B,).
    """
    flat = []
    for idx, z in enumerate(zs):
        mask = None if valid_masks is None else valid_masks[idx]
        # one row per batch item
        flat.append(invalid_to_nans(z, mask).reshape(len(z), -1))

    joint = torch.cat(flat, dim=-1)

    if quantile != 0.5:
        return torch.nanquantile(joint, quantile, dim=-1)
    return torch.nanmedian(joint, dim=-1).values


@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Return the median center and median scale of all views taken jointly.

    Invalid points (per ``valid_masks``, if given) are turned into NaN and
    ignored by the medians. With ``z_only`` the X and Y of the center are
    forced to 0. The scale is the median norm of the points, measured from
    the center when ``center`` is True and from the origin otherwise.

    Returns:
        ``(center, scale)`` with shapes (B, 1, 1, 3) and (B, 1, 1, 1).
    """
    per_view = []
    for idx, pt in enumerate(pts):
        mask = None if valid_masks is None else valid_masks[idx]
        per_view.append(invalid_to_nans(pt, mask).reshape(len(pt), -1, 3))
    joint = torch.cat(per_view, dim=1)

    # median center, shape (B, 1, 3)
    median_pt = torch.nanmedian(joint, dim=1, keepdim=True).values
    if z_only:
        median_pt[..., :2] = 0  # do not center X and Y

    # median norm
    offsets = joint - median_pt if center else joint
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return median_pt.unsqueeze(1), scale[:, None, None, None]


class Regr3D_t(Criterion, MultiLoss):
    """3D point regression over a sequence of views.

    Ground-truth points of every view are moved into the camera frame of the
    first view, and predictions are read from ``pred["pts3d_in_other_view"]``.

    Args:
        criterion: elementary ``BaseCriterion`` applied to (pred, gt) points.
        norm_mode: normalization mode passed to ``normalize_pointcloud_t``;
            a falsy value disables normalization.
        gt_scale: if True, the ground truth is left un-normalized.
        fix_first: forwarded to ``normalize_pointcloud_t``.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Collect ground-truth / predicted points and validity masks per view.

        Returns:
            ``(gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring)``.
            The first two and ``valids`` are lists with one entry per view;
            a factor is ``None`` when the matching normalization was skipped;
            ``monitoring`` is an empty dict here.
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])

        gt_pts = []
        valids = []
        pr_pts = []

        for i, gt in enumerate(gts):
            # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()

            if dist_clip is not None:
                # points that are too far-away == invalid
                # NOTE(review): the distance is measured on the untransformed
                # gt["pts3d"], not on the points moved into view1's frame --
                # confirm this is intended.
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)

            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))

        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None

        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Compute per-frame regression losses plus a scale-factor penalty.

        Returns:
            ``(Sum(losses, masks, confs), details, factor_loss)`` where
            ``factor_loss`` is the mean excess of the predicted normalization
            factor over the ground-truth one (0.0 when there is no excess or
            when either factor is unavailable).
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )

        # NOTE(review): this method expects `pred_pts` to be a
        # (left_list, right_list) pair and `preds[i]` to be indexable by 0/1,
        # whereas get_all_pts3d_t above builds one flat list and reads
        # `preds[i]` as a dict. The two look out of sync -- confirm against
        # the callers before relying on this method.
        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        for i in range(len(gt_pts)):

            # Left (Reference)
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()

            # Right (Target)
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()

        # Penalize only the samples whose predicted factor exceeds the GT one.
        # The difference is taken element-wise *before* selecting, so every
        # prediction is compared with the factor of its own sample.
        factor_loss = 0.0
        if pr_factor is not None and gt_factor is not None:
            excess = pr_factor - gt_factor
            excess = excess[excess > 0]
            if excess.numel() > 0:
                factor_loss = excess.mean()

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss


class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.
        Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        conf_loss = x * conf - alpha * log(conf), e.g.
        conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        conf = 10  ==> conf_loss = x * 10 - alpha*log(10)

        alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # per-element losses are needed to weight them by confidence
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return ``(x, log(x))``."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Weight the per-pixel losses of ``pixel_loss`` by the confidences.

        Returns:
            ``(loss, details, loss_factor)``: the mean over frames of the
            (doubled) confidence-weighted losses, a dict of monitoring
            values, and the factor loss passed through from ``pixel_loss``.
        """
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i, pixel_loss in enumerate(losses):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            weighted = pixel_loss * conf - self.alpha * log_conf
            if weighted.numel() > 0:
                conf_sum = conf_sum + conf.mean()
                conf_losses.append(weighted.mean())
            else:
                # Frame without any valid pixel: contribute a zero *tensor*
                # (torch.stack below rejects Python numbers) and leave the
                # confidence average untouched instead of adding NaN.
                conf_losses.append(weighted.new_zeros(()))

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )


class Regr3D_t_ShiftInv(Regr3D_t):
    """Same as Regr3D_t but invariant to a shift along the depth (z) axis."""

    def get_all_pts3d_t(self, gts, preds, **kw):
        """Return the parent's points with the joint median depth removed.

        Extra keyword arguments (e.g. ``dist_clip``) are forwarded to the
        parent implementation. Ground truth and predictions are each shifted
        by their own median z, computed over all views and valid pixels.
        """
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds, **kw)
        )

        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]

        # compute median depth
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]

        # Subtract the median depth on copies: the incoming tensors may be
        # the raw network outputs (when normalization is disabled) and must
        # not be modified behind the caller's back.
        gt_pts = [self._shift_z(pt, gt_shift_z) for pt in gt_pts]
        pred_pts = [self._shift_z(pt, pred_shift_z) for pt in pred_pts]

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring

    @staticmethod
    def _shift_z(pts, shift):
        """Return a copy of ``pts`` with ``shift`` subtracted from its z channel."""
        out = pts.clone()
        out[..., 2] -= shift
        return out


class Regr3D_t_ScaleInv(Regr3D_t):
    """Same as Regr3D_t but invariant to the scene scale.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds, **kw):
        """Return the parent's points after removing the scale ambiguity.

        Extra keyword arguments (e.g. ``dist_clip``) are forwarded to the
        parent implementation. The scale of a point set is the median
        distance of its valid points to their median center, over all views.
        """
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds, **kw)
        )

        # measure scene scale
        _, gt_scene_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scene_scale = get_joint_pointcloud_center_scale(pred_pts, masks)

        # prevent predictions to be in a ridiculous range
        pred_scene_scale = pred_scene_scale.clip(min=1e-3, max=1e3)

        # Rescale out of place so the incoming tensors (possibly the raw
        # network outputs) are left untouched.
        if self.gt_scale:
            # bring the prediction to the scale of the ground truth
            ratio = gt_scene_scale / pred_scene_scale
            pred_pts = [pt * ratio for pt in pred_pts]
        else:
            # bring both point sets to unit scale, each by its own scale
            gt_pts = [pt / gt_scene_scale for pt in gt_pts]
            pred_pts = [pt / pred_scene_scale for pt in pred_pts]

        monitoring = dict(
            monitoring,
            gt_scale=gt_scene_scale.mean(),
            pred_scale=pred_scene_scale.mean().detach(),
        )

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring


class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    """Shift- and scale-invariant regression.

    With this base-class order, ``Regr3D_t_ScaleInv.get_all_pts3d_t`` runs
    and its ``super()`` call reaches ``Regr3D_t_ShiftInv``: the depth shift
    is removed first, then the scale is handled.
    """
