from collections import OrderedDict
from utils import utils
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models.detection import faster_rcnn

import numpy as np
import logging


def load_backbone(backbone_type, weight_path=None):
    """Build a feature backbone and report the size of the features it emits.

    ``backbone_type`` is ``"<model>"`` or ``"<model>_<weights>"`` where

    * ``<model>`` is ``resnet18``, ``resnet50``, ``resnet101``, ``resnet152``
      or ``faster50``;
    * ``<weights>`` is ``""`` (train from scratch), ``"pt"`` (pretrained, then
      fine-tuned) or ``"pt_frz"`` (pretrained, every parameter frozen).

    :param backbone_type: model / weight selector as described above
    :param weight_path: optional checkpoint file. When given together with a
        pretrained weight type it is loaded instead of the default torch
        weights; it is ignored (with a warning) for the from-scratch type.
    :return: ``(backbone, feat_dim)``; for resnets ``feat_dim`` is the input
        width of the removed ``fc`` layer, for ``faster50`` it is the
        backbone's ``representation_size``
    :raises AssertionError: if the model name or weight type is not supported
    """
    logger = logging.getLogger("load_backbone")
    logger.info(f"Loading {backbone_type} backbone from {weight_path}")

    # "resnet50_pt_frz" -> ("resnet50", "pt_frz"); "resnet50" -> ("resnet50", "")
    model_name, _, weight_type = backbone_type.partition("_")

    assert model_name in [
        "resnet18", "resnet50", "resnet101", "resnet152",
        "faster50",
    ], model_name

    assert weight_type in ["", "pt", "pt_frz"], weight_type
    # ""        = from scratch
    # "pt"      = pretrained + finetune
    # "pt_frz"  = pretrained + freeze

    wants_pretrained = weight_type in ["pt", "pt_frz"]
    # Only fall back to the published torch weights when no checkpoint is given.
    use_torch_weight = wants_pretrained and (weight_path is None)

    if weight_path and not wants_pretrained:
        logger.warning(
            "weight_path %s is ignored because backbone type %s trains from scratch",
            weight_path, backbone_type)

    if "resnet" in model_name:
        # e.g.  torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
        backbone = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=use_torch_weight)
        feat_dim = backbone.fc.in_features
        # Drop the classifier so the network returns the features that fed it.
        backbone.fc = nn.Sequential()

    elif model_name == "faster50":
        backbone = FasterRCNN50Backbone(pretrained=use_torch_weight)
        feat_dim = backbone.representation_size

    else:
        # Only reachable when asserts are disabled (python -O).
        raise NotImplementedError()

    # load weight with provided checkpoint
    if wants_pretrained and weight_path:
        # map_location="cpu" lets checkpoints saved on a GPU load on any host;
        # load_state_dict copies the values onto the parameters' own device.
        checkpoint = torch.load(weight_path, map_location="cpu")
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        # FasterRCNN50Backbone keeps its convolutional trunk in ``.backbone``;
        # hub resnets have no such attribute and are loaded directly.
        # NOTE(review): a full resnet checkpoint still carrying ``fc.*`` keys
        # would be rejected by the strict load below -- confirm checkpoint format.
        target = getattr(backbone, "backbone", backbone)
        target.load_state_dict(state_dict)

    # freeze weight
    if weight_type == "pt_frz":
        for param in backbone.parameters():
            param.requires_grad = False

    return backbone, feat_dim


def onehot(x, depth, device):
    """Return float32 one-hot rows of width ``depth`` selected by ``x``.

    :param x: integer index or indices (tensor, list or int); each value
        picks one row of a ``depth x depth`` identity matrix
    :param depth: number of classes, i.e. the length of every one-hot row
    :param device: device the result is moved to
    :return: tensor shaped like ``x`` with a trailing ``depth`` axis
    """
    # Build the identity on the same device as the index tensor: recent
    # PyTorch refuses to index a CPU tensor with an index that lives on CUDA.
    index_device = x.device if isinstance(x, torch.Tensor) else None
    identity = torch.eye(depth, device=index_device)
    return identity[x].to(device).float()


def build_counterfactual(causal, num_attr, num_aff):
    '''
    Group causal triples by their (instance, attribute) pair.

    :param causal: [ N, 3 ] rows of (inst_id, attr_id, aff_id)
    :param num_attr: width of the attribute one-hot rows
    :param num_aff: width of the affordance one-hot rows
    :return:
         counterfactual_inst_id : tensor [ M ]  inst_id of every unique (inst_id, attr_id) pair
         counterfactual_attr_mask: tensor [ M, num_attr ]  one-hot of the pair's attr_id
         counterfactual_aff_mask: tensor [ M, num_aff ]  per pair, the summed one-hots of the
                                  aff_id of every causal row that carries that pair
    '''
    inst_attr = causal[:, :2]
    unique_pairs = torch.unique(inst_attr, dim=0)

    counterfactual_inst_id = unique_pairs[:, 0]
    counterfactual_attr_mask = onehot(unique_pairs[:, 1], num_attr, causal.device)

    # [ M, N ] membership matrix: entry (m, n) is 1 when causal row n carries
    # unique pair m. Broadcasting [ M, 1, 2 ] against [ 1, N, 2 ] compares
    # every pair with every row.
    row_belongs_to_pair = (
        inst_attr.unsqueeze(0) == unique_pairs.unsqueeze(1)
    ).all(dim=2).float()

    # Summing the affordance one-hots of the member rows gives the per-pair mask.
    aff_onehot = onehot(causal[:, 2], num_aff, causal.device)
    counterfactual_aff_mask = row_belongs_to_pair @ aff_onehot

    return counterfactual_inst_id, counterfactual_attr_mask, counterfactual_aff_mask


class Aggregator(nn.Module):
    """Collapse the ``n`` axis of a ``bz * n * dim`` tensor into ``bz * dim``.

    Supported methods are ``sum``, ``mean``, ``max`` and ``concat``; the
    latter flattens the ``n * dim`` axes and compresses them with a bias-free
    linear layer followed by a ReLU.
    """

    def __init__(self, method, args=None, num_para=None):
        """
        :param method: one of ``'sum'``, ``'mean'``, ``'max'``, ``'concat'``
        :param args: only read for ``'concat'``; must provide
            ``parallel_attr_rep_dim`` and ``aggr_rep_dim``
        :param num_para: only read for ``'concat'``; number of slots ``n``
        :raises NotImplementedError: if ``method`` is not supported
        """
        super().__init__()
        self.support = ['sum', 'mean', 'max', 'concat']
        self.method = method

        if method not in self.support:
            raise NotImplementedError(
                'Not supported aggregation method [%s].\nWe only support: %s' % (method, self.support))

        if method == "concat":
            self.compression = nn.Linear(args.parallel_attr_rep_dim * num_para, args.aggr_rep_dim, bias=False)
            self.relu = nn.ReLU(inplace=True)

    def forward(self, tensor, mask=None, mask_method="zero"):
        """
        :param tensor:  bz * n * dim
        :param mask:    bz * n (a mask of any other rank also gets a leading
                        singleton axis, so a plain ``n`` mask is shared by the batch)
        :param mask_method: ``"zero"`` multiplies the tensor by the mask;
                        ``"random"`` keeps entries whose mask is non-zero and
                        replaces the others with standard normal noise
        :return:        bz * dim
        :raises NotImplementedError: on an unknown ``mask_method`` or ``method``
        """
        if mask is not None:
            mask = mask.unsqueeze(-1)
            if mask.dim() != 3:
                mask = mask.unsqueeze(0)

            if mask_method == "zero":
                tensor = tensor * mask
            elif mask_method == "random":
                noise = torch.randn_like(tensor)
                # torch.where needs a boolean condition; casting lets the same
                # 0/1 float masks that work with "zero" be used here as well.
                tensor = torch.where(mask.bool().expand_as(tensor), tensor, noise)
            else:
                raise NotImplementedError(mask_method)

        if self.method == 'sum':
            return tensor.sum(1)
        if self.method == 'mean':
            return tensor.mean(1)
        if self.method == 'max':
            return tensor.max(1).values
        if self.method == 'concat':
            flat = tensor.reshape(tensor.shape[0], -1)
            return self.relu(self.compression(flat))

        # Only reachable if ``self.method`` was changed after construction.
        raise NotImplementedError(self.method)



class FasterRCNN50Backbone(nn.Module):
    """ResNet-50 FPN trunk + multi-scale RoI pooling + two-layer box head.

    The pieces are taken from torchvision's Faster R-CNN implementation; with
    ``pretrained=True`` the trunk and the box head are initialised from the
    ``fasterrcnn_resnet50_fpn_coco`` checkpoint.
    """

    def __init__(self, pretrained=False):
        super().__init__()

        self.backbone = faster_rcnn.resnet_fpn_backbone('resnet50', pretrained)
        self.box_roi_pool = faster_rcnn.MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2)

        # The box head consumes the flattened pooled map: channels * side^2 values.
        self.representation_size = 1024
        pooled_side = self.box_roi_pool.output_size[0]
        flattened_dim = self.backbone.out_channels * pooled_side ** 2
        self.box_head = faster_rcnn.TwoMLPHead(flattened_dim, self.representation_size)

        if not pretrained:
            return

        detector_weights = faster_rcnn.load_state_dict_from_url(
            faster_rcnn.model_urls['fasterrcnn_resnet50_fpn_coco'],
            progress=True)

        # The checkpoint holds the whole detector; keep only the sub-module
        # entries and strip their prefix so the keys match each sub-module.
        self.backbone.load_state_dict(
            self._without_prefix(detector_weights, "backbone."))
        self.box_head.load_state_dict(
            self._without_prefix(detector_weights, "roi_heads.box_head."))

    @staticmethod
    def _without_prefix(state_dict, prefix):
        """Select the entries whose key starts with ``prefix`` and drop that prefix."""
        cut = len(prefix)
        return {
            key[cut:]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def forward(self, images, bboxes):
        """
        :param images: sequence of image tensors; they are stacked along a new
            first axis, so all of them must share one shape
        :param bboxes: boxes handed unchanged to the RoI pooler (presumably a
            Tensor[K, 5] or a list of Tensor[L, 4], as in torchvision's RoI
            ops -- confirm against the caller)
        :return: output of the box head for the pooled boxes
        """
        image_sizes = [img.shape[-2:] for img in images]
        batch = torch.stack(images, 0)

        feature_maps = self.backbone(batch)
        if isinstance(feature_maps, torch.Tensor):
            # The pooler looks feature maps up by name, so wrap a lone tensor.
            feature_maps = OrderedDict([('0', feature_maps)])

        pooled = self.box_roi_pool(feature_maps, bboxes, image_sizes)
        return self.box_head(pooled)


class ParallelLinear(nn.Module):
    """A bank of ``num_para`` independent linear maps evaluated in one einsum.

    The input's last two axes are ``(num_para, in_dim)``; slot ``i`` is
    multiplied by its own ``in_dim x out_dim`` matrix, giving
    ``(..., num_para, out_dim)``. Weights start as standard normal samples,
    biases (when enabled) as zeros.
    """

    def __init__(self, in_dim, out_dim, num_para, bias=True):
        super().__init__()
        self.bias = bias

        weight_shape = (num_para, in_dim, out_dim)
        self.__weight = nn.Parameter(torch.randn(*weight_shape))
        if bias:
            self.__bias = nn.Parameter(torch.zeros(num_para, out_dim))

    def forward(self, x):
        """Apply every slot's own matrix (and bias, if any) to that slot of ``x``."""
        projected = torch.einsum('...ij, ijk -> ...ik', x, self.__weight)
        if not self.bias:
            return projected
        return projected + self.__bias



class ParallelMLP(nn.Module):
    """Stack of ``ParallelLinear`` layers, one independent MLP per slot.

    Every hidden layer is ``ParallelLinear`` -> optional ``LayerNorm`` ->
    ``ReLU``. The output layer is either a further ``ParallelLinear`` or, with
    ``share_last_fc``, one ordinary ``nn.Linear`` shared by all slots; it may
    be followed by a ReLU. Passing ``hidden_layers=None`` builds no layers at
    all, which turns the module into an identity.
    """

    def __init__(self, inp_dim, out_dim, num_para, hidden_layers=[], layernorm=True, bias=True, share_last_fc=False, out_relu=False):
        super().__init__()

        layers = []
        if hidden_layers is not None:
            # widths[i] -> widths[i + 1] is the shape of hidden layer i
            widths = [inp_dim, *hidden_layers]
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                layers.append(ParallelLinear(fan_in, fan_out, num_para, bias=bias))
                if layernorm:
                    layers.append(nn.LayerNorm(fan_out))
                layers.append(nn.ReLU(inplace=True))

            head_in = widths[-1]
            if share_last_fc:
                head = nn.Linear(head_in, out_dim, bias=bias)
            else:
                head = ParallelLinear(head_in, out_dim, num_para, bias=bias)
            layers.append(head)

            if out_relu:
                layers.append(nn.ReLU(inplace=True))

        self.mod = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.mod(x)


class MLP(nn.Module):
    """Multi-layer perceptron; a single linear layer unless hidden sizes are given.

    Hidden layers are ``Linear`` -> optional ``BatchNorm1d`` -> ``ReLU``; the
    hidden ``Linear`` layers drop their bias whenever batch norm follows them.
    The output ``Linear`` is followed by batch norm / ReLU only when ``out_bn``
    / ``out_relu`` ask for it. ``hidden_layers=None`` builds no layers at all,
    which turns the module into an identity.
    """

    def __init__(self, inp_dim, out_dim, hidden_layers=[], batchnorm=True, bias=True, out_relu=False, out_bn=False):
        super().__init__()

        # Batch norm brings its own shift, so a bias in front of it is dropped.
        hidden_bias = bias and (not batchnorm)

        layers = []
        if hidden_layers is not None:
            widths = [inp_dim, *hidden_layers]
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                layers.append(nn.Linear(fan_in, fan_out, bias=hidden_bias))
                if batchnorm:
                    layers.append(nn.BatchNorm1d(fan_out))
                layers.append(nn.ReLU(inplace=True))

            layers.append(nn.Linear(widths[-1], out_dim, bias=bias))
            if out_bn:
                layers.append(nn.BatchNorm1d(out_dim))
            if out_relu:
                layers.append(nn.ReLU(inplace=True))

        self.mod = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.mod(x)


class Distance(nn.Module):
    """Distance between ``x`` and ``y`` measured along their last axis.

    ``metric`` selects the measure: ``"L2"`` and ``"L1"`` are the respective
    norms of ``x - y``; ``"cos"`` is one minus the cosine similarity.

    The metric is implemented as a regular method instead of a lambda stored
    on the instance, because a lambda attribute makes the module impossible
    to pickle (e.g. when saving the whole module object).
    """

    _SUPPORTED_METRICS = ("L2", "L1", "cos")

    def __init__(self, metric):
        """
        :param metric: ``"L2"``, ``"L1"`` or ``"cos"``
        :raises NotImplementedError: for any other metric
        """
        super(Distance, self).__init__()

        if metric not in self._SUPPORTED_METRICS:
            raise NotImplementedError("Unsupported distance metric: %s" % metric)
        self.metric = metric

    def metric_func(self, x, y):
        """Return the un-reduced distance; the last axis of the inputs is consumed."""
        if self.metric == "L2":
            return torch.norm(x - y, 2, dim=-1)
        if self.metric == "L1":
            return torch.norm(x - y, 1, dim=-1)
        return 1 - F.cosine_similarity(x, y, dim=-1)

    def forward(self, x, y):
        """Return the distance between ``x`` and ``y`` under the chosen metric."""
        return self.metric_func(x, y)


class DistanceLoss(Distance):
    """Loss flavour of ``Distance``: the distances are averaged into one scalar."""

    def forward(self, x, y):
        """Return the mean of the distances between ``x`` and ``y``."""
        return self.metric_func(x, y).mean()

class CrossEntropyLossWithProb(nn.Module):
    """Cross entropy for inputs that are already probabilities, not logits.

    Probabilities are clamped from below at ``clip_thres`` before the
    logarithm so that an exact zero cannot produce ``-inf``; the log values
    are then scored with ``nn.NLLLoss``.
    """

    def __init__(self, weight=None, clip_thres=1e-8):
        """
        :param weight: forwarded to ``nn.NLLLoss``
        :param clip_thres: lower bound applied to the probabilities
        """
        super().__init__()
        self.nll = nn.NLLLoss(weight)
        self.clip_thres = clip_thres

    def forward(self, probs, labels):
        """Return the NLL of ``labels`` under the clamped ``probs``."""
        log_probs = torch.clamp(probs, min=self.clip_thres).log()
        return self.nll(log_probs, labels)


class CounterfactualHingeLoss(nn.Module):
    """Hinge loss comparing counterfactual and original probabilities.

    Where ``gt_label == 1`` the counterfactual probability is penalised for
    the amount by which it exceeds ``orig_prob - margin``; everywhere else it
    is penalised for the amount by which it falls short of
    ``orig_prob + margin``.
    """

    def __init__(self, margin=0.1):
        super().__init__()
        self.margin = margin

    def forward(self, cf_prob, orig_prob, gt_label, cf_label_mask):
        """
        :param cf_prob: probabilities after the counterfactual intervention
        :param orig_prob: probabilities before the intervention
        :param gt_label: ground-truth labels; entries equal to 1 pick the first penalty
        :param cf_label_mask: multiplied into the loss, zero entries are ignored
        :return: scalar, the loss averaged over axis 0 and summed over the rest
        """
        penalty_if_positive = cf_prob - (orig_prob - self.margin)
        penalty_otherwise = (orig_prob + self.margin) - cf_prob

        signed = torch.where(gt_label == 1, penalty_if_positive, penalty_otherwise)
        # Hinge: only violations (positive values) contribute.
        hinge = torch.relu(signed)

        masked = hinge * cf_label_mask
        return masked.mean(0).sum()
