import os
import argparse

import cv2
from PIL import Image
import numpy as np
import torch
import torch.nn


def parse_args():
  """Parse the benchmark's command-line options from ``sys.argv``.

  Returns:
    An ``argparse.Namespace`` with attributes ``pred_dir``, ``gt_dir``,
    ``save_dir`` (str, default ''), ``mass_threshold`` (float, default 0.6)
    and ``num_classes`` (int, default 2).
  """
  cli = argparse.ArgumentParser(
    description='Benchmark segmentation predictions'
  )
  # Each entry: (flag, value type, default value, help text).
  option_specs = (
    ('--pred_dir', str, '', '/path/to/prediction.'),
    ('--gt_dir', str, '', '/path/to/ground-truths'),
    ('--save_dir', str, '', '/path/to/saved-results'),
    ('--mass_threshold', float, 0.6, 'mass threshold'),
    ('--num_classes', int, 2, 'number of segmentation classes'),
  )
  for flag, value_type, fallback, text in option_specs:
    cli.add_argument(flag, type=value_type, default=fallback, help=text)

  return cli.parse_args()


def iou_stats(pred, target, num_classes=21, background=0):
  """Computes statistics of true positive (TP), false negative (FN) and
  false positive (FP).

  Only pixels whose ``target`` label lies in ``[0, num_classes)`` are
  evaluated; pixels with any other target label (e.g. an ignore label) are
  excluded entirely. Among the evaluated pixels, predictions outside
  ``[0, num_classes)`` are not attributed to any class.

  Args:
    pred: A numpy array of predicted class indices.
    target: A numpy array of ground-truth class indices, same shape as pred.
    num_classes: A number indicating the number of valid classes.
    background: Unused; kept only for backward compatibility of the
      signature.

  Returns:
    Three num_classes-D vectors holding, per class, (TP+FN), (TP+FP) and TP.
  """
  bins = np.arange(num_classes + 1)

  # Exclude pixels whose ground-truth label is not a valid class.
  locs = np.logical_and(target > -1, target < num_classes)

  # true positive + false negative
  tp_fn, _ = np.histogram(target[locs], bins=bins)

  # true positive + false positive. np.histogram's last bin is closed
  # ([num_classes-1, num_classes]), so without this mask a prediction equal
  # to num_classes would be miscounted as class num_classes-1.
  pred_locs = pred[locs]
  pred_locs = pred_locs[np.logical_and(pred_locs >= 0,
                                       pred_locs < num_classes)]
  tp_fp, _ = np.histogram(pred_locs, bins=bins)

  # true positive
  tp_locs = np.logical_and(locs, pred == target)
  tp, _ = np.histogram(target[tp_locs], bins=bins)

  return tp_fn, tp_fp, tp


def threshold_by_mass(attentions, threshold):
  """Binarize each attention head, keeping the top ``threshold`` of its mass.

  Follows DINO's implementation:
  https://github.com/facebookresearch/dino/blob/main/visualize_attention.py

  Within each head, values are sorted ascending and normalized to sum to 1.
  A position is kept (set to 1) when the cumulative normalized mass up to
  and including it exceeds ``1 - threshold``. The kept positions are
  therefore the highest-valued ones, which together hold roughly
  ``threshold`` of the head's total mass. All other positions are set to 0.

  Args:
    attentions: A torch tensor of shape (num_heads, height, width).
    threshold: Fraction of the attention mass to keep.

  Returns:
    A float32 tensor of shape (num_heads, height, width) with values in
    {0, 1}.
  """
  nh, h_featmap, w_featmap = attentions.shape
  flat = attentions.reshape(nh, -1)

  # Sort each head ascending and normalize so its values sum to 1.
  val, idx = torch.sort(flat, dim=1)
  val = val / torch.sum(val, dim=1, keepdim=True)
  cumval = torch.cumsum(val, dim=1)
  keep_sorted = (cumval > (1 - threshold)).float()

  # Undo the sort: original position j sits at sorted index argsort(idx)[j].
  inverse = torch.argsort(idx, dim=1)
  th_attn = torch.gather(keep_sorted, 1, inverse)

  # Restore the input's (height, width) layout.
  return th_attn.reshape(nh, h_featmap, w_featmap)


def _class_names(num_classes):
  """Return display names for a supported class count, else raise."""
  if num_classes == 2:
    # Potsdam
    return ['Background', 'Foreground']
  raise NotImplementedError()


def _best_head(th_attn, gt, num_classes):
  """Select the attention head whose mask gives the highest mean IoU.

  Each head's mask is resized (nearest neighbour) to ``gt``'s spatial size
  and then scored with ``iou_stats``.

  Args:
    th_attn: numpy array of per-head masks, indexed by head along axis 0.
    gt: numpy array of ground-truth class indices, shape (H, W).
    num_classes: Number of classes used for scoring.

  Returns:
    Tuple ``(tp_fn, tp_fp, tp, pred)`` for the winning head, where ``pred``
    is its resized mask. If ``th_attn`` has no heads, the statistics are
    zero vectors and ``pred`` is all zeros.
  """
  # Start below any achievable IoU so the first head is always selected;
  # otherwise an image where every head scores 0 would contribute no
  # TP+FN pixels and inflate the final IoU.
  best_iou = -np.inf
  best = (np.zeros(num_classes), np.zeros(num_classes),
          np.zeros(num_classes), np.zeros_like(gt))
  for head_mask in th_attn:
    resized = cv2.resize(head_mask, (gt.shape[1], gt.shape[0]),
                         interpolation=cv2.INTER_NEAREST)
    head_tp_fn, head_tp_fp, head_tp = iou_stats(
        resized, gt, num_classes=num_classes, background=0)
    head_iou = head_tp / (head_tp_fn + head_tp_fp - head_tp + 1e-12) * 100.0
    mean_iou = head_iou.sum() / num_classes
    if mean_iou > best_iou:
      best_iou = mean_iou
      best = (head_tp_fn, head_tp_fp, head_tp, resized)
  return best


def main():
  """Score per-image best attention heads against ground-truth masks.

  For every ``.npy`` file under ``--pred_dir``, the ``'level3'`` entry of
  the stored dict is processed as follows:
  - Each attention head is binarized with ``threshold_by_mass`` using
    ``--mass_threshold``.
  - The head with the best mean IoU against the ``.png`` at the same
    relative path under ``--gt_dir`` is selected. Ground truth is treated
    as foreground wherever the grayscale value is > 0.
  - When ``--save_dir`` is non-empty, the selected mask is saved there.
  - Its statistics are accumulated.

  Per-class IoU, mean IoU and pixel accuracy (all in percent) are printed
  at the end.

  Raises:
    NotADirectoryError: If ``--pred_dir`` or ``--gt_dir`` is not a directory.
    NotImplementedError: If ``--num_classes`` has no class-name table.
  """
  args = parse_args()

  for directory in (args.pred_dir, args.gt_dir):
    if not os.path.isdir(directory):
      raise NotADirectoryError(directory)
  # Fail before the (potentially long) evaluation loop, not after it.
  class_names = _class_names(args.num_classes)
  print(args.pred_dir)

  tp_fn = np.zeros(args.num_classes, dtype=np.float64)
  tp_fp = np.zeros(args.num_classes, dtype=np.float64)
  tp = np.zeros(args.num_classes, dtype=np.float64)
  for dirpath, _, filenames in os.walk(args.pred_dir):
    for filename in filenames:
      if not filename.endswith('.npy'):
        continue
      predname = os.path.join(dirpath, filename)
      # Ground truth and saved results mirror the prediction's relative
      # path, with the extension switched to .png.
      relname = os.path.splitext(
          os.path.relpath(predname, args.pred_dir))[0] + '.png'
      gtname = os.path.join(args.gt_dir, relname)

      # NOTE: allow_pickle=True can execute code embedded in a crafted
      # file; only evaluate prediction files from a trusted source.
      pred = np.load(predname, allow_pickle=True).item()
      # threshold_by_mass expects a 3-D (heads, height, width) tensor.
      attn = torch.from_numpy(pred['level3'][0])
      th_attn = threshold_by_mass(attn, args.mass_threshold)
      th_attn = th_attn.cpu().numpy().astype(np.int32)

      with Image.open(gtname) as gt_image:
        gt = np.asarray(gt_image.convert(mode='L'), dtype=np.uint8)
      gt = (gt > 0).astype(np.int32)

      best_tp_fn, best_tp_fp, best_tp, best_pred = _best_head(
          th_attn, gt, args.num_classes)

      # Save best prediction (skipped when no --save_dir is given).
      if args.save_dir:
        savename = os.path.join(args.save_dir, relname)
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        # A 2-D uint8 array is written as an 8-bit grayscale ('L') image.
        Image.fromarray((best_pred * 255).astype(np.uint8)).save(savename)

      tp_fn += best_tp_fn
      tp_fp += best_tp_fp
      tp += best_tp

  iou = tp / (tp_fn + tp_fp - tp + 1e-12) * 100.0

  for i, (name, class_iou) in enumerate(zip(class_names, iou)):
    print('class {:10s}: {:02d}, acc: {:4.4f}%'.format(name, i, class_iou))
  mean_iou = iou.sum() / args.num_classes
  print('mean IOU: {:4.4f}%'.format(mean_iou))

  # Fraction of evaluated ground-truth pixels predicted correctly, scaled to
  # percent to match the printed '%'.
  mean_pixel_acc = tp.sum() / (tp_fn.sum() + 1e-12) * 100.0
  print('mean Pixel Acc: {:4.4f}%'.format(mean_pixel_acc))

if __name__ == "__main__":
  # Run the benchmark only when executed as a script, not when imported.
  main()
