import numpy as np
import torch
from torch import nn


class RelClassifier(nn.Module):
    """Score every (target box, anchor box) pair with a small MLP."""

    def __init__(self):
        """Build the 12 -> 128 -> 128 -> 1 scoring network."""
        super().__init__()
        layers = [
            nn.Linear(12, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        ]
        self.f_net = nn.Sequential(*layers)

    def forward(self, target_boxes, anchor_boxes):
        """
        Compute one score per target/anchor pair; boxes are [xyz1, xyz2].

        Inputs:
            target_boxes (tensor): (B, N, 6)
            anchor_boxes (tensor): (B, N2, 6)

        Returns:
            tensor: (B, N, N2, 1), the network output for each pair
        """
        # Insert singleton axes so the subtractions below broadcast
        # over all (target, anchor) combinations.
        tgt = target_boxes.unsqueeze(2)  # (B, N, 1, 6)
        anc = anchor_boxes.unsqueeze(1)  # (B, 1, N2, 6)
        # Anchor with its two xyz halves swapped: [xyz2, xyz1]
        anc_swapped = anc[..., [3, 4, 5, 0, 1, 2]]
        # 12 features per pair: same-corner and cross-corner offsets
        pair_feats = torch.cat((tgt - anc, tgt - anc_swapped), dim=3)
        return self.f_net(pair_feats)


def _safe_ratio(num, den):
    """Return num / den, or 0.0 when den is zero."""
    return num / den if den != 0 else 0.0


def iou_2d(box0, box1):
    """
    Compute 2d IoU and area ratios for two axis-aligned boxes.

    Inputs:
        box0 (array): [xc, yc, w, h], box center and size
        box1 (array): [xc, yc, w, h], box center and size

    Returns a 3-tuple:
        iou: intersection area / union area
        [inter / area0, inter / area1]: covered fraction of each box
        [area0 / area1, area1 / area0]: relative box sizes

    Any ratio whose denominator is zero (a zero-area box) is reported
    as 0.0 instead of propagating nan/inf.
    """
    # Convert center/size to corner format [x_min, y_min, x_max, y_max]
    box_a = np.concatenate(
        [box0[:2] - box0[2:] / 2.0, box0[:2] + box0[2:] / 2.0])
    box_b = np.concatenate(
        [box1[:2] - box1[2:] / 2.0, box1[:2] + box1[2:] / 2.0])

    # Intersection rectangle; clamped to zero when boxes do not overlap
    xA = max(box_a[0], box_b[0])
    yA = max(box_a[1], box_b[1])
    xB = min(box_a[2], box_b[2])
    yB = min(box_a[3], box_b[3])
    inter_area = max(0, xB - xA) * max(0, yB - yA)
    # Areas
    box_a_area = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    box_b_area = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union_area = box_a_area + box_b_area - inter_area
    # Return IoU and area ratios
    return (
        _safe_ratio(inter_area, union_area),  # iou
        [
            _safe_ratio(inter_area, box_a_area),
            _safe_ratio(inter_area, box_b_area),
        ],
        [
            _safe_ratio(box_a_area, box_b_area),
            _safe_ratio(box_b_area, box_a_area),
        ],
    )


def on_classifier(box1, box2):
    """
    Decide whether box1 rests on box2.

    Inputs:
        box1: [xc, yc, zc, w, h, d]
        box2: [xc, yc, zc, w, h, d]

    Both conditions must hold:
        - the bottom face of box1 is within 0.3 (in z) of the top
          face of box2
        - more than 0.3 of box1's xy footprint overlaps box2's
    """
    # Vertical gap between the underside of box1 and the top of box2
    upper_bottom_z = box1[2] - box1[5] / 2.0
    lower_top_z = box2[2] + box2[5] / 2.0
    faces_touch = abs(upper_bottom_z - lower_top_z) < 0.3

    # xy footprints as [xc, yc, w, h]
    upper_footprint = np.array([box1[0], box1[1], box1[3], box1[4]])
    lower_footprint = np.array([box2[0], box2[1], box2[3], box2[4]])
    # Second return value holds intersection / area for each box;
    # only the fraction of box1's footprint is needed here.
    covered_fraction = iou_2d(upper_footprint, lower_footprint)[1][0]

    return faces_touch and covered_fraction > 0.3
