from pathlib import Path
import json
from functools import reduce

import torch
import numpy as np
import cv2 as cv
from typing import NamedTuple

# Clipping bounds applied to ground-truth depth and to aligned predictions in
# analyze_file (same units as the stored depth maps).
MIN_DEPTH = 0.1
MAX_DEPTH = 1000.0

def reformat_input(x):
    """Return ``x`` as a torch tensor suitable for the metric functions.

    Non-tensor inputs are wrapped with ``torch.from_numpy`` (so they must be
    numpy arrays). Boolean tensors are returned unchanged so they can still be
    used as masks; every other dtype is cast to ``float32``.
    """
    tensor = x if isinstance(x, torch.Tensor) else torch.from_numpy(x)
    if tensor.dtype == torch.bool:
        return tensor
    return tensor.to(torch.float32)

def absrel(pred, target, mask):
    """Mean absolute relative error ``|t - p| / t`` over the masked pixels.

    Args:
        pred: predicted depth tensor.
        target: reference depth tensor, same shape as ``pred``.
        mask: boolean tensor selecting the pixels to evaluate.

    Returns:
        The error as a Python float, or 0.0 when ``mask`` selects nothing.
    """
    if mask.sum() == 0:
        return 0.
    t_sel = target[mask]  # 1-D after boolean indexing
    p_sel = pred[mask]
    relative = torch.abs(t_sel - p_sel) / (t_sel + 1e-10)
    mean_err = relative.sum() / t_sel.shape[0]
    assert torch.isfinite(mean_err)
    return mean_err.item()


def rmse(pred, target, mask):
    """Root-mean-square error between ``pred`` and ``target`` on masked pixels.

    Args:
        pred: predicted depth tensor.
        target: reference depth tensor, same shape as ``pred``.
        mask: boolean tensor selecting the pixels to evaluate.

    Returns:
        The RMSE as a Python float, or 0.0 when ``mask`` selects nothing.
    """
    if mask.sum() == 0:
        return 0.
    t_sel, p_sel = target[mask], pred[mask]
    mse = ((t_sel - p_sel) ** 2).sum() / t_sel.shape[0]
    root = torch.sqrt(mse)
    assert torch.isfinite(root)
    return root.item()


def rmse_log(pred, target, mask):
    """RMSE between ``log10(pred)`` and ``log10(target)`` on masked pixels.

    A ``1e-10`` offset is added before each logarithm to avoid ``log10(0)``.

    Args:
        pred: predicted depth tensor.
        target: reference depth tensor, same shape as ``pred``.
        mask: boolean tensor selecting the pixels to evaluate.

    Returns:
        The log-RMSE as a Python float, or 0.0 when ``mask`` selects nothing.
    """
    if mask.sum() == 0:
        return 0.
    t_sel, p_sel = target[mask], pred[mask]
    log_diff = torch.log10(p_sel + 1e-10) - torch.log10(t_sel + 1e-10)
    root = torch.sqrt((log_diff ** 2).sum() / t_sel.shape[0])
    assert torch.isfinite(root)
    return root.item()


def delta(pred, target, mask):
    """Threshold accuracies (delta metrics) over the masked pixels.

    For every masked pixel the ratio ``max(t / p, p / t)`` is compared against
    ``1.25 ** k`` for ``k`` in (0.125, 0.25, 0.5, 1, 2, 3). The fraction of
    pixels strictly below each threshold is returned in that order.

    Args:
        pred: predicted depth tensor.
        target: reference depth tensor, same shape as ``pred``.
        mask: boolean tensor selecting the pixels to evaluate.

    Returns:
        A 6-tuple of Python floats, or six zeros when ``mask`` selects nothing.
    """
    if mask.sum() == 0:
        return 0., 0., 0., 0., 0., 0.
    t_sel = target[mask]  # 1-D after boolean indexing
    p_sel = pred[mask]
    ratio = torch.maximum(t_sel / (p_sel + 1e-10), p_sel / (t_sel + 1e-10))
    count = t_sel.shape[0]
    accuracies = []
    for exponent in (0.125, 0.25, 0.5, 1, 2, 3):
        fraction = torch.sum(ratio < 1.25 ** exponent) / count
        assert torch.isfinite(fraction)
        accuracies.append(fraction.item())
    return tuple(accuracies)


def log10(pred, target, mask):
    """Mean absolute ``log10`` error between ``pred`` and ``target`` on masked pixels.

    A ``1e-10`` offset is added before each logarithm to avoid ``log10(0)``.

    Args:
        pred: predicted depth tensor.
        target: reference depth tensor, same shape as ``pred``.
        mask: boolean tensor selecting the pixels to evaluate.

    Returns:
        The error as a Python float, or 0.0 when ``mask`` selects nothing.
    """
    if mask.sum() == 0:
        return 0.
    t_sel, p_sel = target[mask], pred[mask]
    abs_log_diff = torch.abs(torch.log10(p_sel + 1e-10) - torch.log10(t_sel + 1e-10))
    mean_err = abs_log_diff.sum() / t_sel.shape[0]
    assert torch.isfinite(mean_err)
    return mean_err.item()

def rel_depth(pred, target, mask, num_samples=int(1e7), generator=None):
    """Estimate ordinal (relative-depth) agreement between ``pred`` and ``target``.

    Draws ``num_samples`` pixel pairs ``(i, j)`` uniformly with replacement from
    the masked pixels. Returns the fraction for which ``pred[i] < pred[j]``
    equals ``target[i] < target[j]``.

    Because ``i`` and ``j`` are drawn independently, pairs with ``i == j`` occur
    with probability ``1 / N`` (N = number of masked pixels). Such pairs always
    count as agreeing, which biases the score upward for small masks.

    Args:
        pred: 2-D predicted depth (numpy array or tensor).
        target: 2-D reference depth, same shape as ``pred``.
        mask: 2-D mask; pixels with value > 0.5 are evaluated.
        num_samples: number of random pairs to draw. Defaults to the previously
            hard-coded ``1e7``.
        generator: optional ``torch.Generator`` for reproducible sampling. It
            must live on the same device as the inputs. ``None`` uses torch's
            global RNG, as before.

    Returns:
        Agreement fraction in [0, 1] as a Python float, or 0.0 when the mask
        selects no pixels.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    pred, target, mask = reformat_input(pred), reformat_input(target), reformat_input(mask)
    assert pred.dim() == 2 and target.dim() == 2 and mask.dim() == 2
    mask = mask > .5
    if mask.sum() == 0:
        return 0.
    p, t = pred[mask], target[mask]
    n = p.shape[0]
    i = torch.randint(0, n, (num_samples,), device=p.device, dtype=torch.long, generator=generator)
    j = torch.randint(0, n, (num_samples,), device=p.device, dtype=torch.long, generator=generator)
    agree = (p[i] < p[j]) == (t[i] < t[j])
    return agree.float().mean().item()

def align_disparity(true, pred, mask=None, scale_only=False):
    """Fit a disparity-space prediction to ground-truth depth and return depth.

    Solves least squares for ``scale`` and ``shift`` minimising
    ``||scale * pred + shift - 1 / true||`` over the selected pixels. Returns
    ``1 / (scale * pred + shift)`` evaluated on the full ``pred`` array.

    Args:
        true: ground-truth depth array.
        pred: predicted disparity array, same shape as ``true``.
        mask: optional boolean array selecting the pixels used for the fit.
            ``None`` uses every pixel; this default matches ``align_depth``.
        scale_only: unsupported here; must be False.

    Returns:
        Aligned depth array with the shape of ``pred``.

    Raises:
        AssertionError: if ``scale_only`` is set or the selected values are
            not all finite.
    """
    assert not scale_only  # not needed for models selected
    if mask is None:
        target_disp = 1 / true.reshape(-1)
        source = pred.reshape(-1)
    else:
        # Invert only the pixels used by the fit, so zero depth outside the
        # mask does not trigger divide-by-zero warnings.
        target_disp = 1 / true[mask]
        source = pred[mask]
    design = np.stack([source, np.ones_like(target_disp)], axis=1)

    assert np.isfinite(design).all() and np.isfinite(target_disp).all(), (
        (~np.isfinite(design)).sum(), (~np.isfinite(target_disp)).sum())
    scale, shift = np.linalg.lstsq(design, target_disp, rcond=None)[0]
    return 1 / (pred * scale + shift)

def align_depth(true, pred, mask=None, scale_only=False):
    """Affinely align ``pred`` to ``true`` via least squares.

    Fits ``scale`` and ``shift`` minimising ``||scale * pred + shift - true||``
    over the selected pixels, then applies them to the full ``pred`` array.
    With ``scale_only`` the work is delegated to ``align_depth_scale``.

    Args:
        true: reference depth array.
        pred: predicted depth array, same shape as ``true``.
        mask: optional boolean array of pixels used for the fit (None = all).
        scale_only: fit a scale without a shift.

    Raises:
        AssertionError: if the selected values are not all finite.
    """
    if scale_only:
        return align_depth_scale(true, pred, mask)
    target = true.reshape(-1) if mask is None else true[mask]
    source = pred.reshape(-1) if mask is None else pred[mask]
    design = np.stack([source, np.ones_like(target)], axis=1)

    assert np.isfinite(design).all() and np.isfinite(target).all(), (
        (~np.isfinite(design)).sum(), (~np.isfinite(target)).sum())
    scale, shift = np.linalg.lstsq(design, target, rcond=None)[0]
    return pred * scale + shift

def align_depth_scale(true, pred, mask=None):
    """Scale-only least-squares alignment of ``pred`` to ``true``.

    Finds ``s`` minimising ``||s * pred - true||`` over the selected pixels and
    returns ``pred * s``. ``s`` is kept as the length-1 array returned by
    ``lstsq``.

    Args:
        true: reference depth array.
        pred: predicted depth array, same shape as ``true``.
        mask: optional boolean array of pixels used for the fit (None = all).

    Raises:
        AssertionError: if the selected values are not all finite.
    """
    if mask is None:
        target, column = true.reshape(-1), pred.reshape(-1, 1)
    else:
        target, column = true[mask], pred[mask].reshape(-1, 1)

    assert np.isfinite(column).all() and np.isfinite(target).all(), (
        (~np.isfinite(column)).sum(), (~np.isfinite(target)).sum())
    coef = np.linalg.lstsq(column, target, rcond=None)[0]
    return pred * coef

def rotate_image_f(img, theta):
    """Rotate ``img`` about its centre by ``theta`` radians, keeping its size.

    Uses nearest-neighbour sampling so label/mask values are not blended.
    Pixels mapped from outside the source use ``cv.warpAffine``'s default
    border handling.
    """
    height, width = img.shape[:2]
    angle_deg = theta * 180 / np.pi
    matrix = cv.getRotationMatrix2D((width / 2, height / 2), angle_deg, 1.0)
    return cv.warpAffine(img, matrix, (width, height), flags=cv.INTER_NEAREST)

class Instance(NamedTuple):
    root_folder: str
    obj: str
    variation_type: str
    scene: str
    id: str

    @property
    def depth_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "depth.npy"

    def get_depth(self):
        d = np.load(self.depth_file)
        return d

    @property
    def image_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "image.png"

    @property
    def segmentation_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "segmentation.npy"

    def get_segmentation(self):
        return np.load(self.segmentation_file)

    @property
    def variation_data_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "variation.json"

    @property
    def material_data_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "material.json"

    @property
    def material_data(self):
        return json.load(self.material_data_file.open())

    @property
    def material_segmentation_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "material_segmentation.npy"

    def get_material_segmentation(self):
        return np.load(self.material_segmentation_file)

    @property
    def variation_data(self):
        return json.load(self.variation_data_file.open())

    @property
    def obj_data_file(self):
        return Path(self.root_folder) / self.obj / self.variation_type / self.scene / self.id / "obj.json"

    @property
    def obj_data(self):
        return json.load(self.obj_data_file.open())

    def get_obj_mask(self, rotate=False):
        obj_index = self.variation_data["obj_index"]
        mask = self.get_segmentation() == obj_index
        if rotate:
            theta = self.variation_data['theta']
            mask = rotate_image_f(mask.astype(np.uint8), theta).astype(bool)
        return mask


    @property
    def base_included_in_folder(self):
        return self.variation_type == 'rotate' or self.scene == '2ff05f79'

    def get_base(self) -> "Instance":
        if self.scene == '2ff05f79':
            frames = sorted(list((Path(self.root_folder) / self.obj / self.variation_type / self.scene).iterdir()))
            first_id = frames[0].name
            return Instance(self.root_folder, self.obj, self.variation_type, self.scene, first_id)

        if self.variation_type == 'rotate_camera':
            frames = sorted(list((Path(self.root_folder) / self.obj / self.variation_type / self.scene).iterdir()))
            min_theta_abs = np.inf
            min_theta_inst = None
            assert len(frames) > 0
            for f in frames:
                I = Instance(self.root_folder, self.obj, self.variation_type, self.scene, f.name)
                theta = I.variation_data['theta']
                theta_abs = min(abs(theta), abs(theta - 2 * np.pi))
                if theta_abs < min_theta_abs:
                    min_theta_abs = theta_abs
                    min_theta_inst = I
            return min_theta_inst

        return Instance(self.root_folder, self.obj, 'base', self.scene, '00')

# Per-scene depth thresholds for the 'fishes' object. Pixels whose depth.npy
# value exceeds the threshold are treated as background.
# NOTE(review): 'a9065c4' has 7 characters while the other scene ids have 8;
# confirm it is not a truncated key (a mismatch raises KeyError at lookup).
FISHES_BG_THRESH = {
    '7049eca4': 55,
    '7e1e4f38': 65,
    '6a699cef': 30,
    'a9065c4': 55,
    '6640b05c': 60,
    '3674e167': 40,
    '6cfb10ec': 65,
    '5bf62fad': 35,
}


def get_background(instance: "Instance"):
    """Return a boolean mask of pixels that are excluded as background.

    A pixel is background if any of the following holds:

    - its object name contains 'window' AND its material name contains 'glass';
    - its object name contains 'liquid';
    - both the object and the material segmentation are 0;
    - for the 'fishes' object only, its depth exceeds
      ``FISHES_BG_THRESH[instance.scene]``.

    Raises:
        KeyError: for a 'fishes' scene missing from ``FISHES_BG_THRESH``.
    """
    mat = instance.material_data
    mat_seg = instance.get_material_segmentation()
    obj = instance.obj_data
    obj_seg = instance.get_segmentation()

    glass_indices = [v['pass_index'] for k, v in mat.items() if 'glass' in k.lower()]
    window_indices = [v['object_index'] for k, v in obj.items() if 'window' in k.lower()]
    # NOTE(review): liquids are matched via the object's 'pass_index' while
    # windows use 'object_index'. Confirm which key obj_seg is rendered from.
    liquid_indices = [v['pass_index'] for k, v in obj.items() if 'liquid' in k.lower() and 'pass_index' in v]

    # np.isin always yields a bool mask (all False for an empty index list) and
    # also works for float-typed segmentation maps.
    window_mask = np.isin(obj_seg, window_indices) & np.isin(mat_seg, glass_indices)
    liquid_mask = np.isin(obj_seg, liquid_indices)
    unlabeled_mask = (obj_seg == 0) & (mat_seg == 0)
    background = window_mask | liquid_mask | unlabeled_mask

    if instance.obj == 'fishes':
        background = background | (instance.get_depth() > FISHES_BG_THRESH[instance.scene])

    return background

def analyze_file(model, align_method, instance: Instance, full_scene=False, scale_only=False, compare_with_base=False):
    """Evaluate ``model`` on one instance and return a row of depth metrics.

    Args:
        model: object with ``predict(instance)`` (returns a prediction array)
            and a ``name`` attribute.
        align_method: callable ``(true, pred, mask, scale_only)`` returning
            the aligned prediction, e.g. ``align_depth`` or ``align_disparity``.
        instance: the Instance to evaluate.
        full_scene: evaluate all pixels instead of only the object mask.
            Background and non-finite pixels are still excluded.
        scale_only: forwarded to ``align_method``.
        compare_with_base: use the model's own median-normalised prediction
            on ``instance.get_base()`` as the reference, instead of
            ground-truth depth. Only pixels where the object is visible in
            both frames are evaluated.

    Returns:
        Dict with 'scene', 'obj', 'variation_type', 'model', 'id' and the
        metric values. Returns None when ``compare_with_base`` is set and the
        instance is its own base.
    """
    # Rotated-camera frames are only un-rotated when comparing to the base frame.
    rotate = compare_with_base and instance.variation_type == 'rotate_camera'
    obj_mask = instance.get_obj_mask(rotate)
    if compare_with_base:
        base = instance.get_base()
        if base == instance:
            return None
        obj_mask = obj_mask & base.get_obj_mask(rotate)

    # use erosion kernel size 2 to account for border ambiguity
    kernel = np.ones((2, 2), np.uint8)
    obj_mask = cv.erode(obj_mask.astype(np.uint8), kernel, iterations=1).astype(bool)

    if full_scene:
        obj_mask = np.ones_like(obj_mask)

    background_mask = get_background(instance)

    # dilate background mask to likewise avoid boundary ambiguity
    background_mask = cv.dilate(background_mask.astype(np.uint8), kernel, iterations=1).astype(bool)
    if rotate:
        background_mask = rotate_image_f(background_mask.astype(np.uint8), instance.variation_data['theta']).astype(bool)

    pred_depth = model.predict(instance)
    if rotate:
        pred_depth = rotate_image_f(pred_depth, instance.variation_data['theta'])

    if compare_with_base:
        base_pred = model.predict(base)
        # NOTE(review): np.median propagates NaN. A single NaN in base_pred
        # turns every reference pixel into NaN, so all metrics fall back to 0.
        # Confirm predictions are NaN-free, or consider np.nanmedian.
        true_depth = base_pred / np.median(base_pred)
        if rotate:
            true_depth = rotate_image_f(true_depth, base.variation_data['theta'])
    else:
        true_depth = np.clip(instance.get_depth(), MIN_DEPTH, MAX_DEPTH)

    valid_mask = np.isfinite(true_depth) & np.isfinite(pred_depth)
    eval_mask = obj_mask & ~background_mask & valid_mask

    pred_depth = align_method(true_depth, pred_depth, eval_mask, scale_only)
    if not compare_with_base:
        pred_depth = np.clip(pred_depth, MIN_DEPTH, MAX_DEPTH)

    # Convert each array once; the metric functions only read these tensors.
    pred_t = reformat_input(pred_depth)
    true_t = reformat_input(true_depth)
    mask_t = reformat_input(eval_mask)
    delta_names = ['delta_0125', 'delta_025', 'delta_05', 'delta_1', 'delta_2', 'delta_3']

    metrics = {
        'absrel': absrel(pred_t, true_t, mask_t),
        'rmse': rmse(pred_t, true_t, mask_t),
        'rmse_log': rmse_log(pred_t, true_t, mask_t),
        **dict(zip(delta_names, delta(pred_t, true_t, mask_t))),
        'log10': log10(pred_t, true_t, mask_t),
        'rel_depth': rel_depth(pred_t, true_t, mask_t),
    }

    return {
        'scene': instance.scene,
        'obj': instance.obj,
        'variation_type': instance.variation_type,
        'model': model.name,
        'id': instance.id,
        **metrics,
    }