import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
import pytorch3d.ops


def l1_loss(network_output, gt):
    """Return the mean absolute difference between ``network_output`` and ``gt``.

    The mean is taken over every element of the (broadcast) difference, so the
    result is a scalar tensor.
    """
    residual = network_output - gt
    return residual.abs().mean()


def l2_loss(network_output, gt):
    """Return the mean squared difference between ``network_output`` and ``gt``.

    The mean is taken over every element of the (broadcast) difference, so the
    result is a scalar tensor.
    """
    residual = network_output - gt
    return residual.pow(2).mean()


def gaussian(window_size, sigma):
    """Build a normalized 1-D Gaussian kernel.

    Args:
        window_size: Number of taps in the kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A 1-D tensor of length ``window_size`` whose entries follow a Gaussian
        centred on index ``window_size // 2`` and sum to one.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor(
        [exp(-((tap - center) ** 2) / two_sigma_sq) for tap in range(window_size)]
    )
    return weights / weights.sum()


def create_window(window_size, channel):
    """Create a per-channel 2-D Gaussian kernel for a grouped convolution.

    The 2-D kernel is the outer product of a 1-D Gaussian (sigma fixed at 1.5)
    with itself, replicated once per channel.

    Args:
        window_size: Side length of the square kernel.
        channel: Number of channels to replicate the kernel for.

    Returns:
        A contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the column vector with itself gives the 2-D kernel.
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    per_channel = kernel_4d.expand(channel, 1, window_size, window_size)
    return Variable(per_channel.contiguous())


def ssim(img1, img2, window_size=11, size_average=True):
    """Compute SSIM between two images using a Gaussian window.

    Args:
        img1: First image tensor; its third-from-last dimension is taken as
            the channel count.
        img2: Second image tensor, compared against ``img1``.
        window_size: Side length of the Gaussian window.
        size_average: Forwarded to ``_ssim``; selects between a single global
            mean and a per-sample mean.

    Returns:
        Whatever ``_ssim`` returns for the prepared window.
    """
    num_channels = img1.size(-3)
    kernel = create_window(window_size, num_channels)

    # Put the kernel on the same device and dtype as the input image.
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, num_channels, size_average)


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = (
            F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    )
    sigma2_sq = (
            F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    )
    sigma12 = (
            F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
            - mu1_mu2
    )

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
            (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def pixelwise_l1_with_mask(img1, img2, pixel_mask, weight_matrix):
    """Return the weighted, masked per-pixel absolute difference of two images.

    Args:
        img1: Image tensor, (3, H, W) per the original author's notes.
        img2: Image tensor with the same layout as ``img1``.
        pixel_mask: (H, W) boolean mask per the original author's notes; a
            leading dimension is added before it is multiplied in, so that
            only the touched pixels contribute.
        weight_matrix: Per-pixel weights multiplied into the absolute error.

    Returns:
        The unreduced loss map (no mean or sum is taken here).
    """
    abs_diff = (img1 - img2).abs()
    weighted = abs_diff * weight_matrix
    return weighted * pixel_mask.unsqueeze(0)


def pixelwise_ssim_with_mask(img1, img2, pixel_mask, weight_matrix=None):
    """Return the masked (and optionally weighted) per-pixel SSIM map.

    An 11x11 Gaussian window is used. The returned map holds the SSIM value
    itself at every pixel (it is not converted to ``1 - SSIM`` here) and is
    not reduced.

    Args:
        img1: First image tensor; its third-from-last dimension is taken as
            the channel count.
        img2: Second image tensor, compared against ``img1``.
        pixel_mask: Mask multiplied into the map after a leading dimension is
            added.
        weight_matrix: Optional per-pixel weights multiplied into the map
            before the mask.

    Returns:
        The unreduced, masked SSIM map.
    """
    window_size = 11
    half = window_size // 2

    num_channels = img1.size(-3)
    kernel = create_window(window_size, num_channels)
    # Put the kernel on the same device and dtype as the input image.
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    def smooth(x):
        # Windowed local average, computed independently for each channel.
        return F.conv2d(x, kernel, padding=half, groups=num_channels)

    mean_a = smooth(img1)
    mean_b = smooth(img2)

    mean_a_sq = mean_a.pow(2)
    mean_b_sq = mean_b.pow(2)
    mean_ab = mean_a * mean_b

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y].
    var_a = smooth(img1 * img1) - mean_a_sq
    var_b = smooth(img2 * img2) - mean_b_sq
    cov_ab = smooth(img1 * img2) - mean_ab

    # Stabilizing constants of the SSIM formula.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    numerator = (2 * mean_ab + c1) * (2 * cov_ab + c2)
    denominator = (mean_a_sq + mean_b_sq + c1) * (var_a + var_b + c2)
    ssim_map = numerator / denominator

    mask = pixel_mask.unsqueeze(0)
    if weight_matrix is None:
        return ssim_map * mask
    return ssim_map * weight_matrix * mask


def pixelwise_depth_with_mask(pred_depth, gt_depth, pixel_mask=None, loss_type='l2', max_depth=80,
                              weight_matrix=None):
    """Return the element-wise depth loss over pixels with valid ground truth.

    A pixel is valid when ``0.01 < gt_depth < max_depth``. Both depths are
    gathered at the valid pixels (which flattens them to 1-D), divided by
    ``max_depth`` and clamped to ``[0, 1]`` before the loss is evaluated.

    Args:
        pred_depth: Predicted depth tensor.
        gt_depth: Ground-truth depth tensor, same shape as ``pred_depth``.
        pixel_mask: Optional mask; it is indexed with the validity mask after
            the latter's leading dimension is squeezed out.
        loss_type: One of ``"smooth_l1"``, ``"l1"`` or ``"l2"``.
        max_depth: Upper bound for valid depth and the normalization scale.
        weight_matrix: Optional weights, indexed with the validity mask and
            multiplied into the loss in place.

    Returns:
        The unreduced loss over the valid pixels. It is 1-D when
        ``pixel_mask`` is ``None``; when a mask is given, the mask carries an
        extra leading dimension, so the result broadcasts to ``(1, N)``.

    Raises:
        NotImplementedError: If ``loss_type`` is not a supported name.
    """
    valid_mask = (gt_depth > 0.01) & (gt_depth < max_depth)

    def to_unit_range(depth):
        return torch.clamp(depth / max_depth, 0.0, 1.0)

    pred_norm = to_unit_range(pred_depth[valid_mask])
    gt_norm = to_unit_range(gt_depth[valid_mask])

    criteria = {
        "smooth_l1": F.smooth_l1_loss,
        "l1": F.l1_loss,
        "l2": F.mse_loss,
    }
    criterion = criteria.get(loss_type)
    if criterion is None:
        raise NotImplementedError(f"Unknown loss type: {loss_type}")
    loss = criterion(pred_norm, gt_norm, reduction="none")

    if weight_matrix is not None:
        loss *= weight_matrix[valid_mask]
    if pixel_mask is not None:
        selected_mask = pixel_mask[valid_mask.squeeze(0)]
        loss = loss * selected_mask.unsqueeze(0)

    return loss


def pixelwise_semantic_with_mask(semantic_render, semantic_gt, pixel_mask=None):
    """Return the per-pixel absolute semantic error, optionally masked.

    Args:
        semantic_render: Rendered semantic tensor.
        semantic_gt: Ground-truth semantic tensor.
        pixel_mask: Optional mask; a leading dimension is added before it is
            multiplied into the error. When ``None`` the unmasked error is
            returned.

    Returns:
        The unreduced absolute-error map.
    """
    semantic_loss = torch.abs(semantic_render - semantic_gt)
    # Without a mask there is nothing to filter out; return the raw error
    # rather than referencing a variable that was never assigned.
    if pixel_mask is None:
        return semantic_loss
    return semantic_loss * pixel_mask.unsqueeze(0)


def cal_bit_groups(max_cls_id):
    """Greedily split ``max_cls_id`` into a list of bit-group widths.

    On every step a width ``b`` is chosen as one less than the bit length of
    the remaining value (but never below 1), and ``2 ** b - 1`` is subtracted
    from the remainder. The width is always ``bit_length - 1`` because
    ``2 ** bit_length`` is strictly greater than any positive integer of that
    bit length.

    Args:
        max_cls_id: Non-negative Python integer to decompose.

    Returns:
        The list of chosen widths, in the order they were chosen. Empty when
        ``max_cls_id`` is not positive.
    """
    widths = []
    remaining = max_cls_id
    while remaining > 0:
        width = max(remaining.bit_length() - 1, 1)
        widths.append(width)
        remaining -= 2 ** width - 1
    return widths


def entity_binary_convert_tensor(entity_gt, bit_groups=[7, 4, 3, 2, 1]):
    """Encode integer entity ids as stacked groups of binary planes.

    Groups are filled from the last entry of ``bit_groups`` to the first. A
    group of ``b`` bits absorbs at most ``2 ** b - 1`` of the id; that portion
    is written as ``b`` binary planes (most significant bit first) and the
    rest is carried on to the next group. The planes are concatenated in the
    same order as ``bit_groups``.

    The input tensor is not modified; the running remainder lives in a copy.

    Args:
        entity_gt: Integer id tensor. It is squeezed on dimension 0 and then
            permuted as a 3-D tensor, so a ``(1, H, W)`` layout is expected.
        bit_groups: Bit width of each group. The default list is only read,
            never mutated.

    Returns:
        A tensor with ``sum(bit_groups)`` planes along dimension 0.

    Raises:
        AssertionError: If some id exceeds the total capacity of the groups.
    """
    # Work on a copy: the running remainder must not clobber the caller's
    # ground-truth tensor, which may be reused on later iterations.
    remaining = entity_gt.clone()
    group_planes = []
    for bit in bit_groups[::-1]:
        capacity = (1 << bit) - 1
        # The share of the id stored in this group: the remainder, capped at
        # the largest value the group can represent.
        portion = remaining.clamp(max=capacity)
        planes = torch.stack(
            [(portion >> shift) & 1 for shift in range(bit - 1, -1, -1)], dim=-1
        )
        group_planes.append(planes.squeeze(0).permute(2, 0, 1))
        remaining = remaining - portion
    assert remaining.sum() == 0, "the number of max class don't match"
    # Groups were produced last-to-first; restore the bit_groups order.
    return torch.cat(group_planes[::-1], dim=0)


def pixelwise_entity_with_mask(entity_render, entity_gt, pixel_mask=None, weight_matrix=None,
                               bit_groups=[7, 4, 3, 2, 1], use_truncated_binary=False):
    """Return the per-pixel, per-bit binary cross-entropy of an entity map.

    The integer ground truth is expanded into one binary target per rendered
    channel, either with the grouped encoding of
    ``entity_binary_convert_tensor`` or with a plain binary expansion (most
    significant bit first, one bit per channel of ``entity_render``).

    Args:
        entity_render: Rendered per-bit values with channels on dimension 0.
            ``BCELoss`` requires them to be probabilities in ``[0, 1]``.
        entity_gt: Integer entity ids; dimension 0 is squeezed out.
        pixel_mask: Optional mask; a trailing dimension is added before it is
            multiplied into the loss.
        weight_matrix: Optional weights with channels on dimension 0; permuted
            to channels-last and multiplied into the loss. Skipped when
            ``None``.
        bit_groups: Group widths for the grouped encoding. The default list is
            only read, never mutated.
        use_truncated_binary: Use the grouped encoding instead of the plain
            binary expansion.

    Returns:
        The unreduced loss in channels-last layout.
    """
    binary_loss = torch.nn.BCELoss(reduction='none')
    if use_truncated_binary:
        # Pass a copy so the conversion can never modify the caller's
        # ground-truth tensor.
        target = entity_binary_convert_tensor(entity_gt.clone(), bit_groups=bit_groups)
        target = target.to(torch.float32).permute(1, 2, 0).contiguous()
    else:
        # One shift amount per rendered channel, most significant bit first.
        shifts = torch.arange(entity_render.shape[0] - 1, -1, -1,
                              dtype=torch.int32,
                              device=entity_render.device)
        target = (entity_gt.squeeze(0).unsqueeze(-1) >> shifts) & 1
        target = target.to(torch.float32)

    entity_loss = binary_loss(entity_render.permute(1, 2, 0).contiguous(), target)
    # weight_matrix defaults to None, so it must be optional here.
    if weight_matrix is not None:
        entity_loss = entity_loss * weight_matrix.permute(1, 2, 0)
    if pixel_mask is not None:
        entity_loss = entity_loss * pixel_mask.unsqueeze(-1)
    return entity_loss


def entity_semantic_loss(features, entity_binary_ids, entity_semantic_centroid, unique_entity, min_entity_num=50):
    """Average L1 distance between each entity's features and its centroid.

    For every row of ``unique_entity``, the rows of ``entity_binary_ids`` that
    match it in every column are selected. Entities with fewer than
    ``min_entity_num`` matches are skipped.

    Args:
        features: Feature rows, indexed by the matching row positions.
        entity_binary_ids: Per-row entity code, compared column-wise.
        entity_semantic_centroid: One centroid per row of ``unique_entity``,
            in the same order.
        unique_entity: The entity codes to evaluate.
        min_entity_num: Minimum number of matching rows for an entity to
            contribute.

    Returns:
        The mean over contributing entities of their mean absolute deviation
        from the centroid, or the float ``0.`` if no entity contributed.
    """
    total = 0.
    used = 0
    for centroid_idx, entity_code in enumerate(unique_entity):
        row_matches = (entity_binary_ids == entity_code.unsqueeze(0)).all(dim=1)
        member_idx = row_matches.nonzero().squeeze()
        if member_idx.numel() >= min_entity_num:
            centroid = entity_semantic_centroid[centroid_idx].unsqueeze(0)
            total += torch.mean(torch.abs(features[member_idx] - centroid))
            used += 1

    return total / used if used else 0.


def semantic_3d_loss(features, predictions, k=5, lambda_val=2.0, max_points=10000, sample_size=1000):
    """Penalize dissimilar predictions among nearest neighbours in feature space.

    A random subset of points is sampled; for each, its ``k`` nearest points
    (by pairwise distance between ``features`` rows) are found, the closest
    one is dropped, and the cosine similarity between the sample's prediction
    and each remaining neighbour's prediction is mapped from ``[-1, 1]`` to a
    loss in ``[0, 1]`` via ``1 - (sim + 1) / 2``.

    NOTE(review): the previous body called ``F.l1_loss(..., dim=-1)``, which
    is not a valid call (``l1_loss`` has no ``dim`` argument) and raised on
    every invocation. Cosine similarity is used because it is the measure
    that both the ``dim=-1`` argument and the ``1 - (x + 1) / 2`` mapping fit;
    confirm this matches the intended loss.

    Args:
        features: ``(N, D)`` rows used for the neighbour search.
        predictions: ``(N, C)`` rows compared between neighbours.
        k: Number of nearest points to fetch, including the closest one that
            is discarded (presumably the sample itself). Clamped to ``N``.
        lambda_val: Unused; kept for interface compatibility.
        max_points: If ``N`` exceeds this, a random subset of this size is
            used for everything that follows.
        sample_size: Maximum number of points the loss is evaluated on.

    Returns:
        A scalar tensor. Zero when there are too few points to have any
        neighbour.
    """
    # Conditionally downsample if points exceed max_points.
    if features.size(0) > max_points:
        keep = torch.randperm(features.size(0))[:max_points]
        features = features[keep]
        predictions = predictions[keep]

    num_points = features.size(0)
    # topk cannot return more entries than there are points.
    k_eff = min(k, num_points)
    if k_eff < 2:
        # No neighbour remains once the closest point is dropped.
        return predictions.new_zeros(())

    # Randomly sample the points for which the loss is computed.
    picked = torch.randperm(num_points)[:sample_size]
    sample_features = features[picked]
    sample_preds = predictions[picked]

    # k nearest points of every sample, smallest distance first.
    dists = torch.cdist(sample_features, features)
    _, neighbor_idx = dists.topk(k_eff, largest=False)

    # Drop the first column (the closest point) and gather the predictions.
    neighbor_preds = predictions[neighbor_idx[:, 1:]]

    similarity = F.cosine_similarity(
        neighbor_preds, sample_preds.unsqueeze(1).expand_as(neighbor_preds), dim=-1
    )
    return (1 - (similarity + 1) / 2).mean()


def entity_3d_loss(features, predictions, k=5, lambda_val=2.0, max_points=50000, sample_size=1000, use_kl=True):
    """Penalize divergent predictions among nearest neighbours in feature space.

    A random subset of points is sampled; for each, its ``k`` nearest points
    (by pairwise distance between ``features`` rows) are found, the closest
    one is dropped, and a symmetric divergence between the sample's prediction
    and each remaining neighbour's prediction is averaged.

    Args:
        features: ``(N, D)`` rows used for the neighbour search.
        predictions: ``(N, C)`` rows compared between neighbours. The BCE
            variant requires values in ``[0, 1]``.
        k: Number of nearest points to fetch, including the closest one that
            is discarded (presumably the sample itself). Clamped to ``N``.
        lambda_val: Unused; kept for interface compatibility.
        max_points: If ``N`` exceeds this, a random subset of this size is
            used for everything that follows.
        sample_size: Maximum number of points the loss is evaluated on.
        use_kl: If true, use the mean of forward and reverse KL terms divided
            by ``C``; otherwise use the mean of forward and reverse binary
            cross-entropy summed over ``C``.

    Returns:
        A scalar tensor. Zero when there are too few points to have any
        neighbour.
    """
    # Conditionally downsample if points exceed max_points.
    if features.size(0) > max_points:
        keep = torch.randperm(features.size(0))[:max_points]
        features = features[keep]
        predictions = predictions[keep]

    num_points = features.size(0)
    # topk cannot return more entries than there are points.
    k_eff = min(k, num_points)
    if k_eff < 2:
        # No neighbour remains once the closest point is dropped; averaging
        # over an empty set would yield NaN.
        return predictions.new_zeros(())

    # Randomly sample the points for which the loss is computed.
    picked = torch.randperm(num_points)[:sample_size]
    sample_features = features[picked]
    sample_preds = predictions[picked]

    # k nearest points of every sample, smallest distance first.
    dists = torch.cdist(sample_features, features)
    _, neighbor_idx = dists.topk(k_eff, largest=False)

    # Drop the first column (the closest point) and gather the predictions.
    neighbor_preds = predictions[neighbor_idx[:, 1:]]

    if use_kl:
        anchor = sample_preds.unsqueeze(1)
        # The small epsilon keeps log() finite for zero entries.
        log_anchor = torch.log(anchor + 1e-10)
        log_neighbor = torch.log(neighbor_preds + 1e-10)
        forward_kl = anchor * (log_anchor - log_neighbor)
        reverse_kl = neighbor_preds * (log_neighbor - log_anchor)
        loss = (forward_kl.sum(dim=-1).mean() + reverse_kl.sum(dim=-1).mean()) / 2
        # Scale by the number of prediction columns.
        return loss / predictions.size(1)

    binary_loss = torch.nn.BCELoss(reduction='none')
    anchor = sample_preds.unsqueeze(1).expand_as(neighbor_preds)
    forward_bce = binary_loss(anchor, neighbor_preds)
    reverse_bce = binary_loss(neighbor_preds, anchor)
    return (forward_bce.sum(dim=-1).mean() + reverse_bce.sum(dim=-1).mean()) / 2
