import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import nms


class DecodeBox_full(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def box_cxcywh_to_xyxy(self, x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)

    @torch.no_grad()
    def forward(self, outputs, target_sizes, confidence=1e-7):
        full_out_logits, full_out_bbox = outputs['full_pred_logits'], outputs['full_pred_boxes']
        full_results = []
        for i in range(full_out_logits.shape[0]):
            out_logits, out_bbox = full_out_logits[i], full_out_bbox[i]
            assert len(out_logits) == len(target_sizes)
            assert target_sizes.shape[1] == 2

            prob = F.softmax(out_logits, -1)
            scores, labels = prob[..., :-1].max(-1)

            # convert to [x0, y0, x1, y1] format
            boxes = self.box_cxcywh_to_xyxy(out_bbox)
            # and from relative [0, 1] to absolute [0, height] coordinates
            img_h, img_w = target_sizes.unbind(1)
            img_h = img_h.float()
            img_w = img_w.float()
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
            boxes = boxes * scale_fct[:, None, :]

            outputs = torch.cat([
                torch.unsqueeze(boxes[:, :, 0], -1),
                torch.unsqueeze(boxes[:, :, 1], -1),
                torch.unsqueeze(boxes[:, :, 2], -1),
                torch.unsqueeze(boxes[:, :, 3], -1),
                torch.unsqueeze(scores, -1),
                torch.unsqueeze(labels.float(), -1),
            ], -1)

            results = []
            for output in outputs:
                results.append(output[output[:, 4] > confidence])
        # results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
            full_results.append(results)
        return full_results

class DecodeBox(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    def box_cxcywh_to_xyxy(self, x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
            (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)
    
    @torch.no_grad()
    def forward(self, outputs, target_sizes, confidence=1e-7):
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)
        
        # convert to [x0, y0, x1, y1] format
        boxes = self.box_cxcywh_to_xyxy(out_bbox)
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w    = target_sizes.unbind(1)
        img_h           = img_h.float()
        img_w           = img_w.float()
        scale_fct       = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes           = boxes * scale_fct[:, None, :]

        outputs = torch.cat([
                torch.unsqueeze(boxes[:, :, 0], -1),
                torch.unsqueeze(boxes[:, :, 1], -1),
                torch.unsqueeze(boxes[:, :, 2], -1),
                torch.unsqueeze(boxes[:, :, 3], -1),
                torch.unsqueeze(scores, -1),
                torch.unsqueeze(labels.float(), -1),    
            ], -1)
        
        results = []
        for output in outputs:
            results.append(output[output[:, 4] > confidence])
        # results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
        return results