from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from omegaconf import DictConfig

from feature_extractor.cutie.cutie.utils.point_features import calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness
from feature_extractor.cutie.cutie.utils.tensor_utils import cls_to_one_hot


@torch.jit.script
def ce_loss(logits: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between point logits and their targets.

    Args:
        logits: unnormalised class scores laid out as T*C*num_points.
        soft_gt: targets handed to ``F.cross_entropy`` as-is; callers in this
            file pass per-class soft labels shaped like ``logits``.

    Returns:
        A scalar: the element-wise loss is summed along dim 0 (the temporal
        dimension) and the result is averaged over the remaining entries.
    """
    # keep the element-wise loss so the temporal axis can be reduced separately
    per_element = F.cross_entropy(logits, soft_gt, reduction='none')
    summed_over_time = per_element.sum(dim=0)
    return summed_over_time.mean()


@torch.jit.script
def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor:
    """Smoothed dice loss over every channel except the first.

    Args:
        mask: predicted per-class scores, T*C*num_points.
        soft_gt: targets with the same layout, T*C*num_points; cast to float
            before use.

    Returns:
        A scalar: ``1 - (2 * overlap + 1) / (pred_sum + gt_sum + 1)`` is
        computed per (frame, channel), summed along dim 0 (the temporal
        dimension) and averaged over the remaining channels.
    """
    # channel 0 is the background and is ignored
    fg_pred = mask[:, 1:].flatten(start_dim=2)
    fg_target = soft_gt[:, 1:].float().flatten(start_dim=2)

    overlap = (fg_pred * fg_target).sum(-1)
    total = fg_pred.sum(-1) + fg_target.sum(-1)

    # the +1 terms smooth the ratio and keep it defined when both sums are zero
    dice = (2 * overlap + 1) / (total + 1)
    return (1 - dice).sum(0).mean()


class LossComputer:
    """Computes the segmentation training losses from a dict of model outputs.

    The main loss is a point-sampled cross-entropy + dice pair; the same pair
    is additionally computed for the optional auxiliary ``sensory_logits`` and
    ``q_logits`` outputs, scaled by their configured weights.
    """

    def __init__(self, cfg: DictConfig, stage_cfg: DictConfig):
        """Read point-sampling settings from ``stage_cfg`` and aux weights from ``cfg``."""
        super().__init__()
        # point-sampling settings of the current training stage
        self.point_supervision = stage_cfg.point_supervision
        self.num_points = stage_cfg.train_num_points
        self.oversample_ratio = stage_cfg.oversample_ratio
        self.importance_sample_ratio = stage_cfg.importance_sample_ratio

        # multipliers applied to the auxiliary losses
        self.sensory_weight = cfg.model.aux_loss.sensory.weight
        self.query_weight = cfg.model.aux_loss.query.weight

    def mask_loss(self, logits: torch.Tensor,
                  soft_gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute cross-entropy and dice losses on sampled points.

        Args:
            logits: stacked per-frame class logits; ``compute`` passes a
                tensor whose dim 0 is time and dim 1 is the class channel.
            soft_gt: targets for ``logits``, sampled at the same points.

        Returns:
            ``(loss_ce, loss_dice)`` as produced by ``ce_loss`` and
            ``dice_loss`` on the sampled points.

        Raises:
            AssertionError: if ``point_supervision`` is disabled; only the
                point-sampled loss is implemented here.
        """
        assert self.point_supervision, 'mask_loss requires stage_cfg.point_supervision'

        # point selection and target sampling must not contribute gradients
        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                logits, calculate_uncertainty, self.num_points, self.oversample_ratio,
                self.importance_sample_ratio)
            # get gt labels
            point_labels = point_sample(soft_gt, point_coords, align_corners=False)
        # logits are sampled outside no_grad so the losses can backpropagate
        point_logits = point_sample(logits, point_coords, align_corners=False)
        # point_labels and point_logits: B*C*num_points
        # NOTE(review): when called from compute(), dim 0 is the stacked frame axis

        loss_ce = ce_loss(point_logits, point_labels)
        loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels)

        return loss_ce, loss_dice

    def compute(self, data: Dict[str, torch.Tensor],
                num_objects: List[int]) -> Dict[str, torch.Tensor]:
        """Accumulate batch-averaged losses over all frames except the first.

        Args:
            data: model outputs and targets. Reads ``'rgb'`` (its first two
                dims give batch size and number of frames), ``'cls_gt'``,
                and, for every frame index ``ti >= 1``, ``'logits_{ti}'`` and
                ``'aux_{ti}'``. Each aux entry may contain
                ``'sensory_logits'`` and/or ``'q_logits'``.
            num_objects: number of objects for each batch element; only the
                first ``num_objects[bi] + 1`` channels are used.

        Returns:
            A ``defaultdict(float)`` holding ``'loss_ce'``, ``'loss_dice'``,
            any ``'aux_sensory_*'`` / ``'aux_query_*_l{level}'`` terms that
            were computed, and ``'total_loss'``, the sum of all of them.
        """
        batch_size, num_frames = data['rgb'].shape[:2]
        losses = defaultdict(float)
        t_range = range(1, num_frames)

        # the aux dicts do not depend on the batch index, so look them up once
        aux = [data[f'aux_{ti}'] for ti in t_range]

        for bi in range(batch_size):
            # number of channels in use for this sample
            num_channels = num_objects[bi] + 1

            logits = torch.stack([data[f'logits_{ti}'][bi, :num_channels] for ti in t_range],
                                 dim=0)
            cls_gt = data['cls_gt'][bi, 1:]  # remove gt for the first frame
            soft_gt = cls_to_one_hot(cls_gt, num_objects[bi])

            loss_ce, loss_dice = self.mask_loss(logits, soft_gt)
            losses['loss_ce'] += loss_ce / batch_size
            losses['loss_dice'] += loss_dice / batch_size

            if 'sensory_logits' in aux[0]:
                sensory_log = torch.stack(
                    [a['sensory_logits'][bi, :num_channels] for a in aux], dim=0)
                loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt)
                losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight
                losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight
            if 'q_logits' in aux[0]:
                # dim 2 of q_logits indexes the levels; each gets its own loss entry
                num_levels = aux[0]['q_logits'].shape[2]

                for level in range(num_levels):
                    query_log = torch.stack(
                        [a['q_logits'][bi, :num_channels, level] for a in aux], dim=0)
                    loss_ce, loss_dice = self.mask_loss(query_log, soft_gt)
                    losses[f'aux_query_ce_l{level}'] += loss_ce / batch_size * self.query_weight
                    losses[f'aux_query_dice_l{level}'] += (loss_dice / batch_size *
                                                           self.query_weight)

        # total_loss is assigned after summing, so it is not counted in itself
        losses['total_loss'] = sum(losses.values())

        return losses
