
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone, FrozenBatchNorm2d
from .backbone_re import build_backbone as build_backbone_re
from nets.ops import NestedTensor, nested_tensor_from_tensor_list, unused
from torchvision.transforms import Resize
from .Prompt_encoder import Prompt_Encoder
import matplotlib.pyplot as plt
from .transformer_recurrent import build_transformer


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Return "same"-style padding for a convolution.

    Args:
        k: kernel size, an int or a sequence of ints (one per spatial dim).
        p: explicit padding; when not ``None`` it is returned unchanged.
        d: dilation; when ``d > 1`` the effective kernel size
            ``d * (k - 1) + 1`` is used instead of ``k``.

    Returns:
        ``p`` if given, otherwise half of the (effective) kernel size, as an
        int for an int ``k`` or as a list for a sequence ``k``. For odd
        effective kernel sizes at stride 1 this keeps the spatial size.
    """
    scalar_kernel = isinstance(k, int)
    if d > 1:
        # Dilation spreads the taps out; pad for the footprint, not the tap count.
        if scalar_kernel:
            k = d * (k - 1) + 1
        else:
            k = [d * (size - 1) + 1 for size in k]
    if p is not None:
        return p
    if scalar_kernel:
        return k // 2
    return [size // 2 for size in k]


class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size (int or per-dimension sequence).
        s: stride.
        p: padding; ``None`` lets ``autopad`` choose "same"-style padding.
        g: groups.
        d: dilation.
        act: ``True`` uses ``Conv.default_act``; an ``nn.Module`` is used
            as-is; any other value disables the activation (``nn.Identity``).
    """

    # Class-level default; every Conv built with act=True shares this one SiLU instance.
    default_act = nn.SiLU()

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            activation = self.default_act
        elif isinstance(act, nn.Module):
            activation = act
        else:
            activation = nn.Identity()
        self.act = activation

    def forward(self, x):
        """Apply convolution, batch norm and activation, in that order."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def forward_fuse(self, x):
        """Apply convolution and activation only, skipping batch norm.

        Presumably meant for use after BN has been folded into the conv
        weights — confirm with the caller.
        """
        y = self.conv(x)
        return self.act(y)


class SpatialGate(nn.Module):
    """Spatial attention gate.

    The input's per-position channel max and channel mean are stacked into a
    2-channel map. A 7x7 Conv (+BatchNorm) with a sigmoid activation turns it
    into a 1-channel weight map, and the input is multiplied by that map
    (broadcast over channels).
    """

    def __init__(self):
        super().__init__()
        ksize = 7
        self.compress = ChannelPool()
        # Padding (ksize - 1) // 2 keeps the spatial size for the odd 7x7 kernel.
        self.spatial = Conv(2, 1, ksize, 1, (ksize - 1) // 2, act=nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` scaled by its sigmoid spatial weight map."""
        pooled = self.compress(x)
        weights = self.spatial(pooled)
        return x * weights


class ChannelPool(nn.Module):
    """Pool across the channel axis (dim 1) into a 2-channel map: [max, mean]."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``cat([max over dim 1, mean over dim 1], dim=1)`` with dim 1 kept."""
        # .max(dim).values (not amax) so gradients flow to a single index on ties.
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat((channel_max, channel_mean), dim=1)


class TargetAware(nn.Module):
    """Thin wrapper that applies a ``SpatialGate`` to its input."""

    def __init__(self):
        super().__init__()
        self.spatialAttLayer = SpatialGate()

    def forward(self, x):
        """Return ``x`` re-weighted by the spatial gate."""
        return self.spatialAttLayer(x)


class spatial_attention(nn.Module):
    """Encode features with a 3x3 then a 1x1 ``Conv`` and apply a ``TargetAware`` gate.

    Args:
        c: channel count; input and output both have ``c`` channels.
    """

    def __init__(self, c):
        super().__init__()
        # Same layer order as before, so state_dict keys are object_encoder.0 / .1.
        encoder_layers = [Conv(c, c, 3, p=1), Conv(c, c, 1)]
        self.object_encoder = nn.Sequential(*encoder_layers)
        self.object_aware = TargetAware()

    def forward(self, x):
        """Return the spatially gated encoding of ``x``."""
        encoded = self.object_encoder(x)
        return self.object_aware(encoded)


class MHSA_layer(nn.Module):
    """Post-norm self-attention block with an optional feed-forward sub-layer.

    Computes ``h = norm1(src + SelfAttn(src))``. When ``require_linear`` is
    true, it then computes ``norm2(h + linear2(act(linear1(h))))``. Tensors
    use the ``nn.MultiheadAttention`` default (sequence, batch, embed) layout.

    Args:
        d_model: embedding width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability, applied inside the attention layer only.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
        require_linear: whether to build and apply the feed-forward sub-layer.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 require_linear=True):
        super().__init__()
        # Creation order (attn, linears, norms) is kept for identical seeded init.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.require_linear = require_linear
        if require_linear:
            self.linear1 = nn.Linear(d_model, dim_feedforward)
            self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        if activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()

    def forward(self, src):
        """Apply self-attention (+ optional FFN) with post-norm residuals."""
        attended, _ = self.self_attn(src, src, src)
        out = self.norm1(src + attended)
        if not self.require_linear:
            return out
        ffn_out = self.linear2(self.activation(self.linear1(out)))
        return self.norm2(out + ffn_out)


class MHCA_layer(nn.Module):
    """Post-norm block: self-attention over ``resrc``, cross-attention to ``qsrc``, then FFN.

    Tensors use the ``nn.MultiheadAttention`` default (sequence, batch, embed)
    layout. ``dropout`` is applied inside the two attention layers only.
    ``activation == "relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__()
        # Creation order is kept for identical seeded init and state_dict keys.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        if activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.GELU()

    def forward(self, resrc, qsrc):
        """Return the updated ``resrc`` sequence after attending to ``qsrc``."""
        self_out, _ = self.self_attn(resrc, resrc, resrc)
        hidden = self.norm1(resrc + self_out)
        # NOTE(review): the cross-attention query is the raw self-attention
        # output (self_out), not the residual+norm1 result ``hidden``. A standard
        # post-norm DETR decoder layer queries with the normalised tensor. This is
        # kept unchanged because trained weights depend on it — confirm it is intended.
        cross_out, _ = self.multihead_attn(query=self_out, key=qsrc, value=qsrc)
        hidden = self.norm2(hidden + cross_out)
        ffn_out = self.linear2(self.activation(self.linear1(hidden)))
        return self.norm3(hidden + ffn_out)


class RFEM_part(nn.Module):
    """Fuse query features, prompt embeddings and reference features with attention.

    Learned tokens are prepended to the flattened query feature map. Each of
    ``n_times`` iterations does the following:

    1. Self-attention runs over the query sequence.
    2. The tokens are encoded by an MLP and pooled into one descriptor per
       batch item.
    3. Cosine similarity between that descriptor and every ``x_re32``
       position, resized to ``x_re16``'s grid and passed through a sigmoid,
       gates ``x_re16``.
    4. The query sequence cross-attends to the prompts.
    5. ``x_re16`` cross-attends to the query sequence.

    Args:
        in_c: channel / embedding width shared by every input.
        num_heads: number of heads in every attention layer.
        num_token: number of learned tokens prepended to the query sequence.
        n_times: number of fusion iterations.
    """

    def __init__(self, in_c, num_heads, num_token, n_times=1):
        super().__init__()
        self.n_times = n_times
        self.num_token = num_token
        # One query self-attention / query->prompt / reference->query layer per iteration.
        self.MHSA_q = nn.ModuleList([MHSA_layer(in_c, num_heads, require_linear=False) for _ in range(self.n_times)])

        self.MHCA_qp = nn.ModuleList([MHCA_layer(in_c, num_heads) for _ in range(self.n_times)])
        self.MHCA_req = nn.ModuleList([MHCA_layer(in_c, num_heads) for _ in range(self.n_times)])
        # NOTE(review): MHCA_pq and MHSA_p are never used in forward(). They are kept
        # because removing them would change the state_dict keys (breaking strict
        # checkpoint loading) and the parameter-initialisation order.
        self.MHCA_pq = nn.ModuleList([MHCA_layer(in_c, num_heads) for _ in range(1)])
        self.MHSA_p = nn.ModuleList([MHSA_layer(in_c, num_heads) for _ in range(1)])

        # Stride-2 conv applied to the fused x_re16 map before self.focus.
        self.conv = nn.Conv2d(in_c, in_c, kernel_size=3, stride=2, padding=1)
        # Predicts a 1-channel map (x_KD) over the query positions.
        self.q_head = nn.Sequential(nn.Conv2d(in_c, in_c // 2, kernel_size=3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(in_c // 2, 1, kernel_size=1, stride=1, padding=0))

        self.focus = spatial_attention(in_c)
        # Token MLP: in_c -> 4*in_c -> in_c, ReLU after each Linear.
        self.encode_token = nn.Sequential(nn.Linear(in_c, in_c * 4),
                                          nn.ReLU(),
                                          nn.Linear(in_c * 4, in_c),
                                          nn.ReLU())

    def forward(self, x_q, prompts, x_re16, x_re32, token):
        """Run the fusion.

        Args:
            x_q: query feature map, (bs, c, qh, qw).
            prompts: prompt embeddings. They are permuted (1, 0, 2) into a
                sequence-first layout, so presumably (bs, n_prompts, c) — confirm
                against Prompt_Encoder.
            x_re16: reference feature map, (bs, c, reh, rew).
            x_re32: second reference feature map, (bs, c, h32, w32). Any spatial
                size is accepted because the similarity map is resized to
                (reh, rew). Callers pass a coarser map.
            token: learned tokens, (num_token, bs, c), sequence-first.

        Returns:
            x: ``self.focus`` applied to the stride-2 conv of the fused x_re16,
                shape (bs, c, ceil(reh / 2), ceil(rew / 2)).
            x_KD: (bs, 1, qh, qw) map predicted by ``q_head`` from the query positions.
            token: updated tokens, (num_token, bs, c).
        """
        bs, c, qh, qw = x_q.size()
        bs, c, reh, rew = x_re16.size()
        # Use x_re32's real size instead of assuming (reh // 2, rew // 2). That
        # assumption made .view() fail whenever reh or rew was odd.
        h32, w32 = x_re32.shape[-2:]

        # Flatten to the sequence-first (positions, bs, c) layout of nn.MultiheadAttention.
        x_q = x_q.permute(2, 3, 0, 1).contiguous().view(qh * qw, bs, c)

        prompts = prompts.permute(1, 0, 2).contiguous()
        x_re16 = x_re16.permute(2, 3, 0, 1).contiguous().view(reh * rew, bs, c)
        x_re32 = x_re32.permute(2, 3, 0, 1).contiguous().view(h32 * w32, bs, c)
        # The learned tokens occupy the first num_token positions of the query sequence.
        x_q = torch.cat((token, x_q), dim=0)

        # x_re32 is not modified inside the loop, so normalise it only once.
        x_re32_unit = F.normalize(x_re32, dim=-1)

        for i in range(self.n_times):
            x_q = self.MHSA_q[i](x_q)

            T = self.encode_token(x_q[:self.num_token])

            # Pool tokens into one unit descriptor per batch item: (bs, c).
            # The softmax weights are taken over the token axis, per channel.
            descriptor = F.normalize((T * T.softmax(dim=0)).sum(dim=0), dim=-1)
            # Cosine similarity of every x_re32 position with the descriptor: (bs, h32 * w32).
            Retrieval_Map = torch.einsum('mbc,bc->bm', x_re32_unit, descriptor)
            # Resize onto x_re16's grid. This is identical to the former scale_factor=2
            # when reh and rew are even, and also works when they are odd.
            Retrieval_Map = F.interpolate(Retrieval_Map.contiguous().view(bs, 1, h32, w32), size=(reh, rew),
                                          mode="bilinear")
            Retrieval_Map = Retrieval_Map.permute(2, 3, 0, 1).contiguous().view(reh * rew, bs, 1)

            # Gate the reference positions by the sigmoid of their similarity.
            x_re16 = x_re16 * Retrieval_Map.sigmoid()
            # Replace the token slots with their encoded version, then attend.
            x_q = torch.cat((T, x_q[self.num_token:]), dim=0)
            x_q = self.MHCA_qp[i](x_q, prompts)
            x_re16 = self.MHCA_req[i](x_re16, x_q)

        token = x_q[:self.num_token]
        x_s = self.conv(x_re16.permute(1, 2, 0).contiguous().view(bs, c, reh, rew))

        x = self.focus(x_s)
        x_KD = self.q_head(x_q[self.num_token:].permute(1, 2, 0).contiguous().view(bs, c, qh, qw))
        return x, x_KD, token


class MLP(nn.Module):
    """Stack of ``num_layers`` Linear layers with ReLU between them (none after the last).

    Layer widths are ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``.
    ``num_layers <= 1`` yields a single ``Linear(input_dim, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(widths[j], widths[j + 1]) for j in range(len(widths) - 1)]
        )

    def forward(self, x):
        """Apply every layer, with ReLU after all but the final one."""
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < self.num_layers - 1:
                x = F.relu(x)
        return x


def draw_image(x, name):
    """Save an inverted, channel-averaged, min-max normalised view of ``x[0]`` to ``name``.

    Debug helper. Steps:

    1. ``x`` is detached, copied to host memory and converted to NumPy.
    2. Element 0 along the first axis is taken (presumably the first batch
       item of a (B, C, H, W) tensor — confirm).
    3. It is min-max normalised to [0, 1] over all its values and averaged
       over its first axis.
    4. It is inverted (``1 - v``) and written with ``plt.imsave``.

    It also draws on the current pyplot figure (``plt.imshow``) and turns that
    figure's axes off.

    Args:
        x: tensor to visualise.
        name: output path passed to ``plt.imsave``.
    """
    vis = x.detach().cpu().numpy()[0]
    vis = vis - vis.min()
    peak = vis.max()
    # A constant map has zero range. Dividing by it would fill the image with
    # NaN, so leave it at 0 instead (non-constant inputs are unaffected).
    if peak > 0:
        vis = vis / peak
    vis = vis.mean(0)
    vis = 1 - vis
    plt.imshow(vis)
    plt.axis('off')
    plt.imsave(name, vis)


class ReCOT(nn.Module):
    """Query/reference model with DETR-style prediction heads.

    Query and reference images share one backbone. Query features, prompt
    embeddings and reference features are fused by ``RFEM_part``, refined by
    the recurrent transformer, and decoded into per-slot class logits
    (``num_classes + 1`` outputs) and sigmoid boxes (4 values) for
    ``num_queries`` learned query slots.

    Args:
        mode: stored as ``self.mode``; not read inside this class.
        backbone, position_embedding, pretrained: forwarded to ``build_backbone_re``.
        hidden_dim: projection / fusion width. The heads use ``self.transformer.d_model``.
        num_classes: the class head outputs ``num_classes + 1`` logits.
        num_queries: number of learned query embeddings / fusion tokens.
        query_img_shape: its first two entries are passed to ``Prompt_Encoder``.
        recurrent_steps: forwarded to ``build_transformer``.
        aux_loss: when true, ``forward`` also returns ``'aux_outputs'``.
    """

    def __init__(self, mode, backbone, position_embedding, hidden_dim, num_classes, num_queries, query_img_shape,
                 recurrent_steps, aux_loss=False,
                 pretrained=False):
        super().__init__()
        self.mode = mode
        self.backbone = build_backbone_re(backbone, position_embedding, hidden_dim, pretrained=pretrained)
        self.prompt_encoder = Prompt_Encoder(hidden_dim, (query_img_shape[0], query_img_shape[1]))
        # input_proj serves the 'out32' maps of both images. input_proj2 serves the
        # reference 'out16' map, assumed to have half the backbone channels — confirm.
        self.input_proj = nn.Conv2d(self.backbone.num_channels, hidden_dim, kernel_size=1)
        self.input_proj2 = nn.Conv2d(self.backbone.num_channels // 2, hidden_dim, kernel_size=1)
        self.fusion = RFEM_part(hidden_dim, 8, num_queries, n_times=1)

        # NOTE(review): input_coding is not used in forward(). It is kept for state_dict compatibility.
        self.input_coding = nn.Sequential(
            nn.Conv2d(1, hidden_dim * 4, 3, 1, 1),
            nn.BatchNorm2d(hidden_dim * 4),
            nn.ReLU(),
            nn.Conv2d(hidden_dim * 4, hidden_dim, 1, 1, 0),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU())
        self.transformer = build_transformer(hidden_dim=hidden_dim, recurrent_steps=recurrent_steps, pre_norm=False)
        # The heads are sized by the transformer's width.
        hidden_dim = self.transformer.d_model

        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        self.num_queries = num_queries
        self.aux_loss = aux_loss

    def forward(self, qimgs: NestedTensor, reimgs: NestedTensor, qmasks, q_xys):
        """Run the full model.

        Args:
            qimgs: query images, as a NestedTensor or a list/Tensor (converted).
            reimgs: reference images, as a NestedTensor or a list/Tensor (converted).
            qmasks: 4-D tensor indexed ``[:, 1, :, :]``. That channel is returned
                in the first ``'aux_out'`` entry.
            q_xys: passed to ``self.prompt_encoder``.

        Returns:
            dict with:
            - ``'pred_logits'`` / ``'pred_boxes'``: the last entry along the first
              axis of the head outputs (presumably the final transformer step — confirm).
            - ``'full_pred_logits'`` / ``'full_pred_boxes'``: the full head outputs.
            - ``'aux_out'``: ``[[None, qmasks channel 1 as (bs, 1, H, W), x_KD], *aux2]``.
            - ``'aux_outputs'``: only when ``self.aux_loss`` is true.
        """
        if isinstance(qimgs, (list, torch.Tensor)):
            qimgs = nested_tensor_from_tensor_list(qimgs)
        if isinstance(reimgs, (list, torch.Tensor)):
            reimgs = nested_tensor_from_tensor_list(reimgs)

        # Both image sets go through the same (shared-weight) backbone.
        qfeatures, qpos = self.backbone(qimgs)
        refeatures, repos = self.backbone(reimgs)

        qsrc, qmask = qfeatures['out32'].decompose()

        resrc16, remask16 = refeatures['out16'].decompose()
        resrc32, remask32 = refeatures['out32'].decompose()

        assert remask16 is not None
        assert qmask is not None

        # (The former unused `Resize([h, w], interpolation=0)` built on every call,
        # with a deprecated int interpolation mode, and the unused bs/h/w are removed.)

        # One copy of the learned query embeddings per reference image: (num_queries, bs, d).
        token = self.query_embed.weight.unsqueeze(1).repeat(1, resrc16.size(0), 1)

        resrc, x_KD, token = self.fusion(
            self.input_proj(qsrc) + qpos['out32'],
            self.prompt_encoder(q_xys),
            self.input_proj2(resrc16) + repos['out16'],
            self.input_proj(resrc32) + repos['out32'],
            token)
        aux = [None, qmasks[:, 1, :, :].unsqueeze(1), x_KD]
        auxes = [aux]

        hs, aux2 = self.transformer(resrc, remask32, token, repos['out32'])
        auxes += aux2
        outputs_class = self.class_embed(hs)

        outputs_coord = self.bbox_embed(hs).sigmoid()

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'aux_out': auxes,
               'full_pred_logits': outputs_class, 'full_pred_boxes': outputs_coord}

        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair every head output except the last into ``{'pred_logits', 'pred_boxes'}`` dicts."""
        return [{'pred_logits': a, 'pred_boxes': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    def freeze_bn(self):
        """Put every BatchNorm2d / FrozenBatchNorm2d submodule in eval mode.

        A later ``self.train()`` switches them back to training mode.
        """
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, FrozenBatchNorm2d)):
                m.eval()
