import math
import cv2
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.cluster import adjusted_rand_score

def get_iou(pred, gt, pred_class, gt_class):
    """
    Per-frame IoU between the region `pred_class` of `pred` and the region
    `gt_class` of `gt`.

    Frames in which neither region is present (union of 0 pixels) get an
    IoU of -1 and are left out of the mean.

    :arg pred: (#frames, H, W)
    :arg gt: (#frames, H, W)
    :arg pred_class: int
    :arg gt_class: int

    :return IoUs: (#frames), with -1 marking frames without either region
    :return mIoU: scalar, mean of IoUs over the frames that are not -1
    """

    # Binary masks of the two regions, both (#frames, H, W).
    pred_mask = (pred == pred_class)
    gt_mask = (gt == gt_class)

    # Pixel counts per frame, both (#frames).
    overlap = (pred_mask & gt_mask).sum(dim=-1).sum(dim=-1)
    combined = (pred_mask | gt_mask).sum(dim=-1).sum(dim=-1)

    empty_frames = (combined == 0)
    has_valid_frame = bool((~empty_frames).any())
    assert has_valid_frame, "GT mask has non-assigned label"

    # Mark empty frames with IoU = -1 / 1 = -1 (this also avoids 0 / 0).
    overlap[empty_frames] = -1
    combined[empty_frames] = 1

    IoUs = overlap / combined
    valid = (IoUs != -1)
    mIoU = IoUs[valid].mean()

    return IoUs, mIoU

def swap_indices(pred, row_ind, col_ind, bg_class):
    """
    Relabels `pred`: every pixel whose ID is row_ind[k] receives col_ind[k].
    Pixels whose ID does not occur in row_ind become bg_class.

    :arg pred: (#frames, H, W)
    :arg row_ind: (#ids)
    :arg col_ind: (#ids)
    :arg bg_class: int

    :return: (#frames, H, W)
    """

    # Start from an all-background canvas of the same shape and device.
    relabeled = torch.ones_like(pred, dtype=torch.long) * bg_class

    # Masks are always taken from the untouched `pred`, so one mapping
    # can never overwrite the result of another.
    for k, source_id in enumerate(row_ind):
        target_id = col_ind[k]
        relabeled[pred == source_id] = target_id

    return relabeled


# matches object masks (not bg)
def hungarian_matching(pred, gt, bg_class=0):
    """
    Matches predicted region IDs to GT region IDs (GT background excluded)
    so that the summed mIoU of the matched pairs is maximal.
    Predicted regions that are left unmatched are all set to bg_class.
    Returns the prediction mask relabeled with the matched GT IDs.

    :arg pred: (#frames, H, W)
    :arg gt: (#frames, H, W)
    :arg bg_class: int

    :return: (#frames, H, W)
    """

    # All region IDs; the GT background is not a matching target.
    pred_classes = torch.unique(pred)
    gt_classes = torch.unique(gt)
    gt_classes = gt_classes[gt_classes != bg_class]

    gt_class_num = len(gt_classes)
    pred_class_num = len(pred_classes)

    # Pad with all-zero rows when there are fewer predicted regions than
    # GT regions; padded rows do not correspond to any predicted ID.
    row_num = max(pred_class_num, gt_class_num)

    # miou_res[j, i] = mIoU between predicted region j and GT region i.
    miou_res = torch.zeros(row_num, gt_class_num, device=pred.device)
    for i, gt_ID in enumerate(gt_classes):
        for j, pred_ID in enumerate(pred_classes):
            miou_res[j, i] = get_iou(pred, gt, pred_ID, gt_ID)[1]

    all_metrics = miou_res.cpu().numpy()

    # linear_sum_assignment minimizes cost, hence the negated mIoU.
    # row_ind / col_ind are POSITIONS in pred_classes / gt_classes,
    # not the region IDs themselves.
    row_ind, col_ind = linear_sum_assignment(-all_metrics)

    pred_ids = pred_classes.tolist()
    gt_ids = gt_classes.tolist()

    # Translate positions back to real IDs. This stays correct when the
    # IDs are not contiguous or when bg_class is not 0.
    matched_pred_ids = []
    matched_gt_ids = []
    for row, col in zip(row_ind.tolist(), col_ind.tolist()):
        if row >= pred_class_num:
            # Padded row: no predicted region behind it.
            continue
        matched_pred_ids.append(pred_ids[row])
        matched_gt_ids.append(gt_ids[col])

    new_pred = swap_indices(pred, matched_pred_ids, matched_gt_ids, bg_class)

    return new_pred


def get_ari(prediction_masks, gt_masks, bg_class):
    """
    Mean adjusted Rand index over frames, evaluated only on the pixels
    whose GT label differs from bg_class.

    Frames without any such pixel are skipped; if every frame is skipped
    the result is None.

    :args prediction_masks: predicted slot mask, (#frames, H, W)
    :args gt_masks: gt instance mask, (#frames, H, W)
    :args bg_class: GT label treated as background
    """

    # Flatten the spatial dimensions: both become (#frames, H * W).
    prediction_masks = prediction_masks.flatten(start_dim=1).cpu().numpy().astype(int)
    gt_masks = gt_masks.flatten(start_dim=1).cpu().numpy().astype(int)

    assert prediction_masks.shape == gt_masks.shape, f"prediction_masks.shape: {prediction_masks.shape} gt_masks.shape: {gt_masks.shape}"

    # foreground.shape = (#frames, H * W), True where GT is not background.
    foreground = np.not_equal(gt_masks, bg_class)

    scores = []
    for pred_frame, gt_frame, fg_frame in zip(prediction_masks, gt_masks, foreground):
        if not fg_frame.any():
            # Nothing to score in a frame without foreground.
            continue
        scores.append(adjusted_rand_score(gt_frame[fg_frame], pred_frame[fg_frame]))

    if not scores:
        return None
    return sum(scores) / len(scores)



class Evaluator:
    """
    Accumulates segmentation metrics (mIoU with and without background,
    foreground ARI) over successive calls to `update`.
    """

    def __init__(self, bg_class=0):
        """
        :arg bg_class: label treated as background
        """
        self.bg_class = bg_class
        self.reset()

    def reset(self):
        """Clears all accumulated statistics."""
        # gt_masks.shape[1] of every video passed to `update`
        self.object_numbers = []
        # one mIoU per GT class, concatenated over all videos
        self.mious = []
        # same as `mious`, without the entries of bg_class
        self.mious_wo_bg = []
        # one foreground ARI per video that has foreground
        self.fg_aris = []
        # number of entries each `calculate_miou` call added to `mious`;
        # used to cut `mious` back into per-video groups
        self._miou_counts = []

    @staticmethod
    def _mean(values):
        """Mean of a list, or NaN when the list is empty."""
        if not values:
            return float("nan")
        return sum(values) / len(values)

    def calculate_miou(self, pred, gt):
        """
        Computes the mIoU of every class present in `gt` against the region
        of `pred` carrying the same ID, and stores the results.

        :arg pred: prediction mask already relabeled to GT IDs, (#frames, H, W)
        :arg gt: gt mask, (#frames, H, W)
        """
        # get all region IDs
        pred_classes = torch.unique(pred).sort()[0]
        gt_classes = torch.unique(gt).sort()[0]

        assert len(pred_classes) <= len(gt_classes) + 1, f"pred_classes: {pred_classes}, gt_classes: {gt_classes}"

        mious = []
        mious_wo_bg = []

        for gt_class in gt_classes:
            _, miou = get_iou(pred, gt, gt_class, gt_class)
            mious.append(miou)

            if gt_class != self.bg_class:
                mious_wo_bg.append(miou)

        self.mious.extend(mious)
        self.mious_wo_bg.extend(mious_wo_bg)
        # The number of classes present can be smaller than the number of
        # GT channels, so remember how many entries belong to this video.
        self._miou_counts.append(len(mious))

    def calculate_fg_ari(self, pred, gt):
        """
        Computes the foreground ARI of one video and stores it.
        Nothing is stored when `get_ari` returns None.

        :arg pred: prediction mask, (#frames, H, W)
        :arg gt: gt mask, (#frames, H, W)
        """
        # get all region IDs
        pred_classes = torch.unique(pred)
        gt_classes = torch.unique(gt)

        assert len(pred_classes) <= len(gt_classes) + 1, f"pred_classes: {pred_classes}, gt_classes: {gt_classes}"

        aris = get_ari(pred, gt, self.bg_class)
        if aris is not None:
            self.fg_aris.append(aris)

    def update(self, prediction_masks, gt_masks):
        """
        Adds one video to the statistics and returns the running results.

        :args prediction_masks: predicted slot mask, (#frames, H, W)
        :args gt_masks: gt instance mask, (#frames, #objects, H, W)

        :return: (miou, miou_wo_bg, fg_ari, miou_per_video), see `get_results`
        """

        object_num = gt_masks.shape[1]
        self.object_numbers.append(object_num)

        # Collapse the object dimension into a label map:
        # gt_masks.shape = (#frames, H, W)
        gt_masks = torch.argmax(gt_masks, dim=1)
        assert gt_masks.shape == prediction_masks.shape, f"gt_masks.shape: {gt_masks.shape}, prediction_masks.shape: {prediction_masks.shape}"

        prediction_masks = hungarian_matching(prediction_masks, gt_masks, self.bg_class)

        self.calculate_miou(prediction_masks, gt_masks)
        self.calculate_fg_ari(prediction_masks, gt_masks)

        return self.get_results(reset=False)

    def get_results(self, reset=True):
        """
        Aggregates everything accumulated so far.

        A metric without any accumulated value is reported as NaN instead
        of raising ZeroDivisionError (e.g. fg_ari when no video so far
        contained foreground).

        :arg reset: if True, clears the accumulated statistics afterwards

        :return miou: mean over all per-class mIoUs
        :return miou_wo_bg: same, without the bg_class entries
        :return fg_ari: mean foreground ARI over videos
        :return miou_per_video: mean over videos of each video's class-mean mIoU
        """

        miou = self._mean(self.mious)
        miou_wo_bg = self._mean(self.mious_wo_bg)
        fg_ari = self._mean(self.fg_aris)

        # Cut `mious` into per-video groups using the number of entries
        # each video actually contributed.
        mious_per_video = []
        start = 0
        for count in self._miou_counts:
            mious_per_video.append(self._mean(self.mious[start:start + count]))
            start += count
        miou_per_video = self._mean(mious_per_video)

        if reset:
            self.reset()

        return miou, miou_wo_bg, fg_ari, miou_per_video