from typing import List
from math import floor

import numpy as np
import torch
import torch.nn.functional as F

from .utils.normal import get_triangle_normal_and_valid
from .utils.downsample import downsample
from .utils.misc import reformat_as_torch_tensor

# Default settings for the multi-scale relative-normal metric.
DEFAULT_CONFIG = dict(
    # down-sample factors at which the metric is evaluated
    scales=[1, 2, 4, 8],
    # number of pixel pairs sampled per scale
    num_sample=1_000_000,
    # pair radius (pixels) at scale 1; divided by the scale factor at coarser scales
    radius=32,
    # lower bound on the per-scale pair radius
    min_radius=3,
    # treatment of pairs valid in GT but not in the prediction: 'penalty' or 'ignore'
    invalid='penalty',
)

@torch.no_grad()
def get_angle_between(n1: torch.Tensor, n2: torch.Tensor) -> torch.Tensor:
    '''
    Angle (radians, in [0, pi]) between corresponding vectors along the last axis.

    :param n1: shape (..., 3), norm > 0
    :param n2: shape (..., 3), norm > 0
    :return: shape (...)
    '''
    unit_a = F.normalize(n1, dim=-1)
    unit_b = F.normalize(n2, dim=-1)
    cosine = torch.sum(unit_a * unit_b, dim=-1)
    # Rounding can push the dot product of unit vectors slightly outside [-1, 1],
    # where acos would return NaN.
    return torch.acos(torch.clamp(cosine, -1, 1))

@torch.no_grad()
def get_pair_pxl(H: int, W: int, num_sample: int, radius: int, device):
    '''
    Sample ``num_sample`` pixel pairs ((i1, j1), (i2, j2)) inside an H x W grid.

    Points are drawn from a 4-D Sobol sequence: the first two coordinates give the
    first pixel, the last two give a per-axis offset in [-radius, radius). Both
    pixels are floored to integer coordinates. Pairs whose second pixel falls
    outside the grid are rejected, and drawing continues until ``num_sample``
    pairs are collected. A fresh (unscrambled) engine is created on each call.

    :param H: grid height
    :param W: grid width
    :param num_sample: number of pairs to return
    :param radius: maximum per-axis offset; clipped to max(H, W)
    :param device: torch device for the returned tensors
    :return: i1, j1, i2, j2 -- long tensors of shape (num_sample,)
    :raises ValueError: if pairs are requested from an empty grid
    '''
    # An empty grid can never yield a valid point; without this guard the
    # rejection loop below would spin forever.
    if num_sample > 0 and (H <= 0 or W <= 0):
        raise ValueError(f'cannot sample pixel pairs from an empty {H}x{W} grid')

    radius = min(radius, max(H, W))
    i1 = torch.empty((num_sample,), dtype=torch.long, device=device)
    j1 = torch.empty((num_sample,), dtype=torch.long, device=device)
    i2 = torch.empty((num_sample,), dtype=torch.long, device=device)
    j2 = torch.empty((num_sample,), dtype=torch.long, device=device)

    n = 0
    s = torch.quasirandom.SobolEngine(4)
    while n < num_sample:
        # Over-draw by 10% so that, after rejection, one round usually suffices.
        samples = s.draw(floor(num_sample * 1.1)).to(device)
        samples[:, 0] *= H
        samples[:, 1] *= W
        samples[:, 2] *= radius * 2
        samples[:, 2] -= radius
        samples[:, 3] *= radius * 2
        samples[:, 3] -= radius
        points = torch.cat([samples[:, :2], samples[:, :2] + samples[:, 2:]], dim=1)
        points = torch.floor(points)

        # Keep pairs whose both endpoints lie in [0, H) x [0, W).
        valid = (points[:, [0, 2]] < H).all(dim=-1) & (points[:, [1, 3]] < W).all(dim=-1) \
            & (0 <= points[:, [0, 2]]).all(dim=-1) & (0 <= points[:, [1, 3]]).all(dim=-1)
        points = points[valid]
        m = min(len(points), num_sample - n)
        i1[n:n + m] = points[:m, 0]
        j1[n:n + m] = points[:m, 1]
        i2[n:n + m] = points[:m, 2]
        j2[n:n + m] = points[:m, 3]
        n += m

    return i1, j1, i2, j2


@torch.no_grad()
def get_rel_normal_err_heatmap_idx(gt_xyz: torch.Tensor, gt_valid: torch.Tensor,
                               pred_xyz: torch.Tensor, pred_valid: torch.Tensor,
                               num_sample: int, radius: int):
    '''
    Relative-normal error over randomly sampled pixel pairs.

    Normals are computed for GT and prediction with ``get_triangle_normal_and_valid``.
    Pixel pairs are sampled with ``get_pair_pxl`` over the GT normal map's first two
    dimensions. For each pair, the angle between the two normals is measured in GT and
    in the prediction, and the error is the absolute difference of those angles.

    :param gt_xyz: ground-truth point map
    :param gt_valid: ground-truth validity mask
    :param pred_xyz: predicted point map
    :param pred_valid: predicted validity mask
    :param num_sample: number of pixel pairs to sample
    :param radius: maximum per-axis pair offset
    :return: rel_normal_err, gt_pair_valid, pred_pair_valid, (i1, j1, i2, j2)
        rel_normal_err: shape (num_sample,), values in [0, pi]
        gt_pair_valid: shape (num_sample,), both GT normals of the pair valid
        pred_pair_valid: shape (num_sample,), both predicted normals of the pair valid
        (i1, j1, i2, j2): pixel indices of the sampled pairs
    '''
    gt_normal, gt_normal_valid = get_triangle_normal_and_valid(gt_xyz, gt_valid, flatten=False)
    pred_normal, pred_normal_valid = get_triangle_normal_and_valid(pred_xyz, pred_valid, flatten=False)

    height, width = gt_normal.shape[:2]
    i1, j1, i2, j2 = get_pair_pxl(height, width, num_sample, radius, gt_xyz.device)

    def gather_pair(normal, normal_valid):
        # Angle between the pair's normals, and whether both endpoints are valid.
        angle = get_angle_between(normal[i1, j1], normal[i2, j2])
        both_valid = normal_valid[i1, j1] & normal_valid[i2, j2]
        return angle, both_valid

    gt_rel_normal, gt_pair_valid = gather_pair(gt_normal, gt_normal_valid)
    pred_rel_normal, pred_pair_valid = gather_pair(pred_normal, pred_normal_valid)
    rel_normal_err = (gt_rel_normal - pred_rel_normal).abs()
    return rel_normal_err, gt_pair_valid, pred_pair_valid, (i1, j1, i2, j2)



def get_multi_scale_rel_normal_err(gt_xyz: torch.Tensor, gt_valid: torch.Tensor,
                                   pred_xyz: torch.Tensor, pred_valid: torch.Tensor,
                                   scales: List[int], num_sample: int, radius: int, min_radius: int, invalid: str):
    '''
    Average relative-normal error at each down-sample scale.

    For each scale ``sc``, the inputs are down-sampled with ``downsample``. Then
    ``num_sample`` pixel pairs are evaluated with a pair radius of
    ``max(radius // sc, min_radius)``.

    :param gt_xyz: ground-truth point map
    :param gt_valid: ground-truth validity mask
    :param pred_xyz: predicted point map
    :param pred_valid: predicted validity mask
    :param scales: list of down-sample scales
    :param num_sample: number of pixel pairs per scale
    :param radius: pair radius at scale 1
    :param min_radius: lower bound on the per-scale pair radius
    :param invalid: 'penalty' -- pairs valid in GT but not in the prediction score pi;
        'ignore' -- only pairs valid in both GT and prediction are averaged
    :return: list of avg relative normal errors (floats), one per scale; NaN for a
        scale where no pair survives masking (mean of an empty tensor)
    :raises ValueError: if ``invalid`` is not 'penalty' or 'ignore'
    '''
    # Validate up front so a bad mode fails before any expensive work.
    if invalid not in ('penalty', 'ignore'):
        raise ValueError(f"invalid must be 'penalty' or 'ignore', got {invalid!r}")

    ret = []
    for sc in scales:
        ds_gt_valid, ds_gt_xyz, ds_pred_valid, ds_pred_xyz = downsample(sc, gt_valid, [gt_xyz, pred_valid, pred_xyz])
        err, gt_pair_valid, pred_pair_valid, _ = get_rel_normal_err_heatmap_idx(
            ds_gt_xyz, ds_gt_valid, ds_pred_xyz, ds_pred_valid, num_sample, max(radius // sc, min_radius))
        if invalid == 'penalty':
            # Pairs the prediction failed to cover get the maximal error (pi).
            err = torch.where(gt_pair_valid & ~pred_pair_valid, torch.pi, err)
            err = err[gt_pair_valid]
        else:  # 'ignore'
            err = err[gt_pair_valid & pred_pair_valid]

        ret.append(err.mean().item())
    return ret

def get_multi_scale_rel_normal_heatmap(gt_xyz: torch.Tensor, gt_valid: torch.Tensor,
                                   pred_xyz: torch.Tensor, pred_valid: torch.Tensor,
                                   scales: List[int], num_sample: int, radius: int, min_radius: int, invalid: str):
    '''
    Per-pixel relative-normal error heatmaps, one per down-sample scale.

    For each scale ``sc``, the inputs are down-sampled with ``downsample``, and
    ``num_sample`` pixel pairs are evaluated with a pair radius of
    ``max(radius // sc, min_radius)``. Each kept pair's error is split equally
    between its two endpoint pixels.

    :param gt_xyz: ground-truth point map
    :param gt_valid: ground-truth validity mask
    :param pred_xyz: predicted point map
    :param pred_valid: predicted validity mask
    :param scales: list of down-sample scales
    :param num_sample: number of pixel pairs per scale
    :param radius: pair radius at scale 1
    :param min_radius: lower bound on the per-scale pair radius
    :param invalid: 'penalty' -- pairs valid in GT but not in the prediction score pi;
        'ignore' -- only pairs valid in both GT and prediction are kept
    :return: list of heatmaps, one per scale, each shaped like the down-sampled
        ``gt_valid``. Each is scaled so its mean over pixels equals the mean kept-pair
        error. A scale with no kept pairs yields an all-zero heatmap.
    :raises ValueError: if ``invalid`` is not 'penalty' or 'ignore'
    '''
    # Validate up front so a bad mode fails before any expensive work.
    if invalid not in ('penalty', 'ignore'):
        raise ValueError(f"invalid must be 'penalty' or 'ignore', got {invalid!r}")

    ret = []
    for sc in scales:
        ds_gt_valid, ds_gt_xyz, ds_pred_valid, ds_pred_xyz = downsample(sc, gt_valid, [gt_xyz, pred_valid, pred_xyz])
        err, gt_pair_valid, pred_pair_valid, (i1, j1, i2, j2) = get_rel_normal_err_heatmap_idx(
            ds_gt_xyz, ds_gt_valid, ds_pred_xyz, ds_pred_valid, num_sample, max(radius // sc, min_radius))

        if invalid == 'penalty':
            # Pairs the prediction failed to cover get the maximal error (pi).
            err = torch.where(gt_pair_valid & ~pred_pair_valid, torch.pi, err)
            keep = gt_pair_valid.flatten()
        else:  # 'ignore'
            keep = (gt_pair_valid & pred_pair_valid).flatten()

        # Mask errors and all four index tensors with the same selection
        # (the original masked (i1, j2, i2, j2), clobbering j1).
        err = err[keep]
        i1, j1, i2, j2 = i1[keep], j1[keep], i2[keep], j2[keep]

        H, W = ds_gt_valid.shape
        heatmap = torch.zeros(H, W, device=err.device, dtype=err.dtype)
        flat = heatmap.view(-1)
        half_err = err / 2
        # Each endpoint of a pair receives half of that pair's error.
        flat.index_add_(0, i1 * W + j1, half_err)
        flat.index_add_(0, i2 * W + j2, half_err)

        # Normalize so the per-pixel mean equals the mean pair error; skip when no
        # pair was kept (previously a ZeroDivisionError).
        if err.numel() > 0:
            heatmap *= H * W / err.numel()

        ret.append(heatmap)
    return ret

def rel_normal(gt_xyz, gt_valid, pred_xyz, pred_valid, cfg=None, **kwargs):
    '''
    Relative-normal error averaged over all configured scales.

    :param gt_xyz: ground-truth point map (converted with ``reformat_as_torch_tensor``)
    :param gt_valid: ground-truth validity mask
    :param pred_xyz: predicted point map
    :param pred_valid: predicted validity mask
    :param cfg: full configuration dict (the keys of ``DEFAULT_CONFIG``, plus an
        optional 'device' forwarded to ``reformat_as_torch_tensor``). When given,
        ``kwargs`` are ignored. The caller's dict is never modified.
    :param kwargs: overrides applied on top of ``DEFAULT_CONFIG`` when ``cfg`` is None
    :return: ``np.mean`` of the per-scale errors from ``get_multi_scale_rel_normal_err``
    '''
    # Work on a private copy: popping 'device' below must not mutate the
    # caller's config (a reused cfg would otherwise lose its device).
    cfg = dict(DEFAULT_CONFIG | kwargs if cfg is None else cfg)
    device_args = {'device': cfg.pop('device')} if 'device' in cfg else {}

    gt_xyz = reformat_as_torch_tensor(gt_xyz, **device_args)
    gt_valid = reformat_as_torch_tensor(gt_valid, **device_args)
    pred_xyz = reformat_as_torch_tensor(pred_xyz, **device_args)
    pred_valid = reformat_as_torch_tensor(pred_valid, **device_args)
    return np.mean(get_multi_scale_rel_normal_err(gt_xyz, gt_valid, pred_xyz, pred_valid, **cfg))

def rel_normal_heatmap(gt_xyz, gt_valid, pred_xyz, pred_valid, cfg=None, **kwargs):
    '''
    Relative-normal error heatmap at the resolution of ``gt_valid``.

    Per-scale heatmaps from ``get_multi_scale_rel_normal_heatmap`` are resized to
    ``gt_valid``'s shape with nearest-neighbour interpolation and averaged.

    :param gt_xyz: ground-truth point map (converted with ``reformat_as_torch_tensor``)
    :param gt_valid: ground-truth validity mask; its shape fixes the output (H, W)
    :param pred_xyz: predicted point map
    :param pred_valid: predicted validity mask
    :param cfg: full configuration dict (the keys of ``DEFAULT_CONFIG``, plus an
        optional 'device' forwarded to ``reformat_as_torch_tensor``, as in ``rel_normal``).
        When given, ``kwargs`` are ignored. The caller's dict is never modified.
    :param kwargs: overrides applied on top of ``DEFAULT_CONFIG`` when ``cfg`` is None
    :return: tensor of shape (H, W)
    '''
    # Same config handling as rel_normal: copy, then strip 'device' so it is not
    # forwarded to get_multi_scale_rel_normal_heatmap (which would raise TypeError).
    cfg = dict(DEFAULT_CONFIG | kwargs if cfg is None else cfg)
    device_args = {'device': cfg.pop('device')} if 'device' in cfg else {}

    gt_xyz = reformat_as_torch_tensor(gt_xyz, **device_args)
    gt_valid = reformat_as_torch_tensor(gt_valid, **device_args)
    pred_xyz = reformat_as_torch_tensor(pred_xyz, **device_args)
    pred_valid = reformat_as_torch_tensor(pred_valid, **device_args)
    heatmaps = get_multi_scale_rel_normal_heatmap(gt_xyz, gt_valid, pred_xyz, pred_valid, **cfg)

    H, W = gt_valid.shape
    heatmap = torch.zeros(H, W, dtype=heatmaps[0].dtype, device=heatmaps[0].device)

    for h in heatmaps:
        # Bring each scale back to full resolution before averaging.
        heatmap += F.interpolate(h[None, None], size=(H, W), mode='nearest')[0, 0]
    heatmap /= len(heatmaps)

    return heatmap
