from tqdm import tqdm
from utils.utils_bbox import DecodeBox
import torch
import numpy as np
from nets.ops import box_cxcywh_to_xyxy
import time
class Acc_calculate(object):
    """Top-1 localisation accuracy (acc@0.25 / acc@0.50) over a dataloader.

    For every evaluated image, the single prediction with the highest value in
    column 4 (the score) is compared with the ground-truth boxes. The image
    counts as a hit at threshold t when the best IoU exceeds t.
    """

    def __init__(self, model, dataloader, input_shape, batch_size):
        """Store the evaluation setup.

        Args:
            model: network called as ``model(q_img, reimg, q_mask, q_xy)``.
            dataloader: iterable of batches; entries 0, 1, 2, 3 and 5 are read
                (query image, reference image, query mask, per-image target
                dicts holding a ``'boxes'`` tensor, query xy).
            input_shape: pair of sizes passed to ``DecodeBox`` and used to
                scale the ground-truth boxes.
            batch_size: stored for reference; ``calculate`` does not use it.
        """
        self.model = model
        self.dataloader = dataloader
        self.batch_size = batch_size
        self.input_shape = np.array((input_shape[0], input_shape[1]))
        self.DecodeBox = DecodeBox()
        # (1, 2) tensor on the GPU; handed as-is to DecodeBox.
        self.input_shape = torch.unsqueeze(torch.from_numpy(self.input_shape), 0).cuda()

    def calculate(self):
        """Run the model over the dataloader and report top-1 accuracy.

        Returns:
            ``(acc_25, acc_50)``: the fraction of evaluated images whose best
            prediction has IoU > 0.25 and IoU > 0.5 with the ground truth.
            Returns ``(0.0, 0.0)`` when nothing was evaluated (e.g. an empty
            dataloader) instead of dividing by zero.
        """
        steps = len(self.dataloader)
        print('Evaluation begins...')
        pbar = tqdm(total=steps, desc='Evaluating...', postfix=dict, mininterval=0.3)

        acc_50 = 0
        acc_25 = 0
        n_samples = 0

        # NOTE(review): the model's train/eval mode is left untouched here;
        # callers presumably switch to eval() before calling this.
        for batch in self.dataloader:
            # Only the model inputs and the targets are needed; the other batch
            # fields (masks for the reference image, features) are unused, so
            # they are no longer copied to the GPU.
            q_img, reimg, q_mask, targets, q_xy = batch[0], batch[1], batch[2], batch[3], batch[5]
            with torch.no_grad():
                outputs = self.model(q_img.cuda(), reimg.cuda(), q_mask.cuda(), q_xy.cuda())
                outputs = self.DecodeBox(outputs, self.input_shape)

                for preds, target in self._pair_predictions(outputs, targets):
                    iou = self._top1_iou(preds, target['boxes'].cuda())
                    n_samples += 1
                    if iou > 0.5:
                        acc_50 += 1
                    if iou > 0.25:
                        acc_25 += 1

            pbar.update(1)
        pbar.close()

        # Normalise by images actually scored, not by number of batches.
        acc_25_ratio = acc_25 / n_samples if n_samples else 0.0
        acc_50_ratio = acc_50 / n_samples if n_samples else 0.0
        print("\nacc@0.25: {:.2f}%, acc@0.50: {:.2f}%\n".format(acc_25_ratio * 100, acc_50_ratio * 100))

        print('Evaluation completes!')
        return acc_25_ratio, acc_50_ratio

    @staticmethod
    def _pair_predictions(outputs, targets):
        """Pair each image's decoded predictions with its target dict.

        When ``outputs`` yields one prediction tensor per image (same length as
        ``targets``), every image is scored on its own. Otherwise, fall back to
        the legacy behaviour: pool all predictions and score them against
        ``targets[0]``.
        """
        preds = list(outputs)
        if len(preds) == len(targets):
            return list(zip(preds, targets))
        return [(torch.cat(preds), targets[0])]

    def _top1_iou(self, preds, gt_boxes):
        """Return the best IoU between the top-scoring prediction and ``gt_boxes``.

        ``preds`` is a 2-D tensor whose first four columns are read as corner
        coordinates by ``bbox_iou`` and whose column 4 is the score.
        ``gt_boxes`` is in (cx, cy, w, h), presumably normalised to the input
        size (TODO confirm). Returns 0.0, i.e. a miss, when either side is
        empty.
        """
        if preds.numel() == 0 or gt_boxes.numel() == 0:
            return 0.0
        # argmax returns the first maximal index, like argwhere(...)[0] did.
        best = preds[torch.argmax(preds[:, 4])].unsqueeze(0)

        # Scale a copy so the caller's target tensor is not modified in place.
        gt = gt_boxes.clone()
        # NOTE(review): x/w are scaled by input_shape[0] and y/h by
        # input_shape[1]; if input_shape is (height, width) this is swapped for
        # non-square inputs. Confirm against DecodeBox's convention.
        gt[:, [0, 2]] = gt[:, [0, 2]] * self.input_shape[0][0]
        gt[:, [1, 3]] = gt[:, [1, 3]] * self.input_shape[0][1]
        gt = box_cxcywh_to_xyxy(gt)

        # max (not sum) so several GT boxes cannot push the IoU above 1.
        return self.bbox_iou(best, gt).max().item()

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """Element-wise IoU between rows of ``box1`` and ``box2`` (broadcast).

        Args:
            box1, box2: 2-D tensors whose first four columns are boxes, either
                (x1, y1, x2, y2) when ``x1y1x2y2`` is True, or (cx, cy, w, h).
            x1y1x2y2: coordinate format of both inputs.

        Returns:
            1-D tensor of IoU values; a tiny epsilon avoids division by zero.
        """
        if x1y1x2y2:
            # Get the coordinates of bounding boxes
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
        else:
            # Transform from center and width to exact coordinates
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2

        # Coordinates of the intersection rectangle
        inter_rect_x1 = torch.max(b1_x1, b2_x1)
        inter_rect_y1 = torch.max(b1_y1, b2_y1)
        inter_rect_x2 = torch.min(b1_x2, b2_x2)
        inter_rect_y2 = torch.min(b1_y2, b2_y2)
        # Intersection area (clamped so disjoint boxes give 0)
        inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, 0) * torch.clamp(inter_rect_y2 - inter_rect_y1, 0)
        # Union area
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

        return inter_area / (b1_area + b2_area - inter_area + 1e-16)