import torch

class LossMetrics:
    """Geometric regularisation losses computed from predicted sampling grids.

    Every loss is derived from three corner points of each grid (top-left,
    top-right and bottom-left, see ``get_corners``). ``calc_loss`` combines
    the area, aspect-ratio, direction and height terms per grid and averages
    the result over all grids it is given.
    """

    def __init__(self, blank_symbol, timesteps, image_size, area_loss_factor=0,
                 aspect_ratio_loss_factor=0, uses_original_data=False,
                 area_scaling_factor=2):
        """Store the loss configuration.

        Args:
            blank_symbol: Stored as ``self.blank_symbol``; not read by any
                method of this class.
            timesteps: Stored as ``self.num_timesteps``; not read by any
                method of this class.
            image_size: Indexable whose elements ``0`` and ``1`` are multiplied
                to obtain the image area used to normalise the area loss.
            area_loss_factor: Weight of the area loss term in ``calc_loss``.
            aspect_ratio_loss_factor: Weight of the aspect-ratio loss term in
                ``calc_loss``.
            uses_original_data: Accepted for interface compatibility; unused.
            area_scaling_factor: Stored as ``self.area_scaling_factor``; not
                read by any method of this class.
        """
        self.aspect_ratio_loss_factor = aspect_ratio_loss_factor
        self.blank_symbol = blank_symbol
        self.image_size = image_size
        self.num_timesteps = timesteps
        self.base_area_loss_factor = area_loss_factor
        self.area_scaling_factor = area_scaling_factor
        # The effective factor starts out equal to the configured base factor.
        self.area_loss_factor = self.base_area_loss_factor

    def get_bbox_side_lengths(self, grids):
        """Return the Euclidean side lengths of every box described by *grids*.

        Args:
            grids: 5-D tensor accepted by ``get_corners``.

        Returns:
            Tuple ``(width, height)`` of 1-D tensors. ``width`` is the distance
            between the top-left and top-right corner, ``height`` the distance
            between the top-left and bottom-left corner.
        """
        x0, x1, x2, y0, y1, y2 = self.get_corners(grids)

        width = torch.sqrt(torch.square(x1 - x0) + torch.square(y1 - y0))
        height = torch.sqrt(torch.square(x2 - x0) + torch.square(y2 - y0))
        return width, height

    def get_corners(self, xy_cordinates):
        """Extract three corner points from each grid.

        Args:
            xy_cordinates: 5-D tensor. The first two axes are flattened into
                one; the last axis is indexed with ``0`` for x and ``1`` for y.

        Returns:
            Tuple ``(top_left_x, top_right_x, bottom_left_x, top_left_y,
            top_right_y, bottom_left_y)`` of 1-D tensors whose length is the
            product of the first two axes of the input.
        """
        batch_size, num_times, width, height, num_coords = xy_cordinates.shape
        # reshape (instead of view) also accepts non-contiguous inputs such as
        # transposed tensors; for contiguous inputs it still returns a view.
        flat = xy_cordinates.reshape(batch_size * num_times, width, height, num_coords)
        top_left_x = flat[:, 0, 0, 0]
        top_left_y = flat[:, 0, 0, 1]
        top_right_x = flat[:, -1, 0, 0]
        top_right_y = flat[:, -1, 0, 1]
        bottom_left_x = flat[:, 0, -1, 0]
        bottom_left_y = flat[:, 0, -1, 1]
        return top_left_x, top_right_x, bottom_left_x, top_left_y, top_right_y, bottom_left_y

    def calc_direction_loss(self, grids):
        """Penalise grids whose corners are flipped.

        Args:
            grids: 5-D tensor accepted by ``get_corners``.

        Returns:
            Scalar tensor: mean positive part of ``top_left_y - bottom_left_y``
            plus mean positive part of ``top_left_x - top_right_x``.
        """
        top_left_x, top_right_x, _, top_left_y, _, bottom_left_y = self.get_corners(grids)

        # penalize upside down images
        distance = top_left_y - bottom_left_y
        up_down_loss = torch.mean(torch.maximum(distance, torch.zeros_like(distance)))

        # penalize images that are vertically mirrored
        distance = top_left_x - top_right_x
        left_right_loss = torch.mean(torch.maximum(distance, torch.zeros_like(distance)))

        return up_down_loss + left_right_loss

    def calc_height_loss(self, height, min_height=10):
        """Penalise boxes that are lower than *min_height*.

        Args:
            height: Tensor of box heights.
            min_height: Heights below this value are penalised by their
                shortfall. Defaults to 10, the value the original comment
                describes as the minimum pixel height needed to contain text.

        Returns:
            Scalar tensor: mean of ``max(min_height - height, 0)``.
        """
        shifted_height = height - min_height
        # Keep only the (negative) shortfall and flip its sign.
        shortfall = -torch.minimum(shifted_height, torch.zeros_like(shifted_height))
        return torch.mean(shortfall)

    def calc_area_loss(self, width, height):
        """Return the mean box area as a fraction of the image area.

        Args:
            width: 1-D tensor of box widths.
            height: 1-D tensor of box heights.

        Returns:
            Tensor holding the mean over the first axis of
            ``width * height / (image_size[0] * image_size[1])``; zero for
            empty input.
        """
        loc_area = width * height
        loc_ratio = loc_area / (self.image_size[0] * self.image_size[1])
        # Reduce on the tensor instead of iterating its elements in Python.
        return loc_ratio.sum(dim=0) / max(len(loc_ratio), 1)

    def calc_aspect_ratio_loss(self, width, height, label_lengths=None):
        """Penalise boxes that are too tall or, optionally, too wide.

        Args:
            width: 1-D tensor of box widths.
            height: 1-D tensor of box heights.
            label_lengths: Optional tensor broadcastable against ``height``.
                When given, any width exceeding ``label_lengths * height`` is
                penalised by the excess.

        Returns:
            Tensor holding the mean loss over the first axis; zero for empty
            input.
        """
        # penalize aspect ratios that are higher than wide, and penalize aspect ratios that are too wide
        aspect_ratio = height / torch.maximum(width, torch.ones_like(width))
        # do not give an incentive to bboxes with a width that is 2x the height of the box
        aspect_loss = torch.maximum(aspect_ratio - 0.5, torch.zeros_like(aspect_ratio))

        # penalize very long bboxes (based on the underlying word), by assuming that a single letter
        # has a max width of its height, if the width of the bbox is too large it will be penalized
        if label_lengths is not None:
            max_width = label_lengths * height
            width_ratio = width - max_width
            width_threshold = torch.maximum(width_ratio, torch.zeros_like(width_ratio))
            # NOTE(review): this uses the raw aspect_ratio and discards the
            # thresholded aspect_loss computed above. Behaviour is kept as-is;
            # confirm whether `aspect_loss + width_threshold` was intended.
            aspect_loss = aspect_ratio + width_threshold

        return aspect_loss.sum(dim=0) / max(len(aspect_loss), 1)

    def calc_loss(self, grids):
        """Return the combined geometric loss averaged over all grids.

        For every element of *grids* the loss is
        ``area_loss_factor * area + aspect_ratio_loss_factor * aspect_ratio
        + direction + height``.

        Args:
            grids: Iterable of 5-D tensors, each accepted by ``get_corners``.

        Returns:
            Scalar tensor with the mean of the per-grid losses.

        Raises:
            ValueError: If *grids* yields no elements.
        """
        losses = []
        for sub_grid in grids:
            width, height = self.get_bbox_side_lengths(sub_grid)
            # Build the sum out-of-place so no intermediate tensor is mutated.
            loss = self.area_loss_factor * self.calc_area_loss(width, height)
            loss = loss + self.aspect_ratio_loss_factor * self.calc_aspect_ratio_loss(width, height)
            loss = loss + self.calc_direction_loss(sub_grid)
            loss = loss + self.calc_height_loss(height)
            losses.append(loss)

        if not losses:
            raise ValueError("grids must contain at least one grid")
        return sum(losses) / len(losses)