import torch
import torch.nn as nn
from einops import rearrange



class BasicConv(nn.Module):
    """Conv2d followed by optional BatchNorm2d, ReLU and 2x2 MaxPool2d stages.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels (also stored as ``out_channels``).
        kernel_size, stride, padding, dilation, groups, bias: Forwarded
            unchanged to ``nn.Conv2d``.
        relu: Append a ReLU activation when true.
        bn: Append a BatchNorm2d (eps=1e-5, momentum=0.01, affine) when true.
        max: Append ``nn.MaxPool2d(2)`` when true.  The name shadows the
            builtin but is part of the public signature, so it is kept.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False, max=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Disabled stages are stored as None so forward() can skip them.
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None
        if max:
            self.max = nn.MaxPool2d(2)
        else:
            self.max = None

    def forward(self, x):
        """Apply the convolution, then every enabled stage in bn -> relu -> max order."""
        out = self.conv(x)
        for stage in (self.bn, self.relu, self.max):
            if stage is not None:
                out = stage(out)
        return out

class Feature_compression(nn.Module):
    """Collapse a feature map into one vector per sample.

    The input goes through a single ``BasicConv`` block and a global max
    pool; the pooled vector is then blended with an MLP transform of itself
    using the fixed weight ``alpha`` (0.3).

    Args:
        resnet: Truthy selects the 640-wide variant (1x1 conv, MLP with a
            leading BatchNorm1d); falsy selects the 64-wide variant (3x3 conv
            with bias, MLP without normalisation).
        in_c: Number of channels of the incoming feature map.
    """

    def __init__(self, resnet, in_c):
        super().__init__()
        self.max = nn.AdaptiveMaxPool2d((1, 1))
        self.alpha = 0.3
        self.feature_size = 640 if resnet else 64
        width = self.feature_size

        # Layers are created in conv -> mlp order in both variants.
        if resnet:
            conv = BasicConv(in_c, width, kernel_size=1, stride=1, padding=0, relu=True)
            mlp_layers = [nn.BatchNorm1d(width)]
        else:
            conv = BasicConv(in_c, width, kernel_size=3, stride=1, padding=1, relu=True, bias=True)
            mlp_layers = []
        mlp_layers.append(nn.Linear(width, width))
        mlp_layers.append(nn.ELU(inplace=True))

        self.conv_block = nn.Sequential(conv)
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Return the blended descriptor, one row per input sample."""
        pooled = self.max(self.conv_block(x))
        vec = pooled.view(pooled.size(0), -1)
        transformed = self.mlp(vec)
        return self.alpha * vec + (1 - self.alpha) * transformed

class Attention(nn.Module):
    """Multi-head attention of query feature maps over query + support tokens.

    A shared 3x3 convolution produces q/k/v for both the query map and the
    prompted support map.  Query positions attend over the concatenation of
    query and support positions, and the result is projected back to ``dim``
    channels with a 1x1 convolution.

    Args:
        dim: Number of input and output channels.
        num_heads: Number of attention heads.
        head_dim_ratio: Multiplier applied to ``dim // num_heads`` to get the
            per-head width.
        qkv_bias: Whether the q/k/v convolution has a bias term.
        qk_scale: Exponent applied to ``head_dim`` to obtain the factor that
            multiplies q and k individually; ``None`` means ``-0.25``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.
    """

    def __init__(self, dim, num_heads=2, head_dim_ratio=1., qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = round(dim // num_heads * head_dim_ratio)
        self.head_dim = head_dim
        # q and k are each multiplied by `scale`, so the default exponent of
        # -0.25 scales their product by head_dim ** -0.5.
        qk_scale_factor = qk_scale if qk_scale is not None else -0.25
        self.scale = head_dim ** qk_scale_factor

        self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 3, stride=1, padding=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
        self.proj_drop = nn.Dropout(proj_drop)

        # Not read by forward(), which uses the prompt passed in by the caller.
        # Kept registered so the module's state_dict keys stay unchanged.
        self.TG_prompt = nn.Parameter(torch.randn((1, dim, 5, 5)))

    def _split_heads(self, x):
        """Split a fused qkv map into q, k, v of shape [B, heads, H*W, head_dim]."""
        qkv = rearrange(x, 'b (x y z) h w -> x b y (h w) z', x=3, y=self.num_heads, z=self.head_dim)
        return qkv[0], qkv[1], qkv[2]

    def forward(self, q, s, TG_prompt):
        """Attend from query positions to query and prompted-support positions.

        Args:
            q: Query feature map, [B, C, H, W].
            s: Support feature map; must broadcast against ``TG_prompt``.
            TG_prompt: Prompt tensor added to ``s`` before it is encoded.

        Returns:
            Tensor of shape [B, dim, H, W].
        """
        _, _, H, W = q.shape

        support = s + TG_prompt

        q_tok, k_tok, v_tok = self._split_heads(self.qkv(q))
        # Support positions serve as keys/values only.  Their own attention
        # rows are never returned, and softmax is row-wise, so skipping them
        # leaves the query rows unchanged while halving the attention work.
        _, k_sup, v_sup = self._split_heads(self.qkv(support))

        keys = torch.cat([k_tok, k_sup], dim=2)    # [B, heads, HW + support tokens, d]
        values = torch.cat([v_tok, v_sup], dim=2)  # [B, heads, HW + support tokens, d]

        attn = (q_tok * self.scale) @ (keys.transpose(-2, -1) * self.scale)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ values  # [B, heads, HW, d]

        out = rearrange(out, 'b y (h w) z -> b (y z) h w', h=H, w=W)

        out = self.proj(out)
        out = self.proj_drop(out)

        return out

class GKPM(nn.Module):
    """Build support prototypes and prompt-attended query descriptors.

    Every query feature map is paired with every class centroid, passed
    through ``Attention`` together with a learnable prompt, averaged over the
    classes and max-pooled.  That vector is added, scaled by the learnable
    ``lamda``, to the query's compressed descriptor from
    ``Feature_compression``.

    Args:
        resnet: Truthy selects 640 channels, falsy selects 64.
    """

    def __init__(self, resnet):
        super().__init__()
        self.lamda = nn.Parameter(torch.FloatTensor([0.5]), requires_grad=True)
        self.max = nn.AdaptiveMaxPool2d((1, 1))
        self.num_channel = 640 if resnet else 64
        # The prompt has a fixed 5x5 spatial size, so the feature maps given
        # to forward() must broadcast against it.
        self.TG_prompt = nn.Parameter(torch.randn((1, self.num_channel, 5, 5)))
        self.attn = Attention(self.num_channel, num_heads=1)
        self.fc_l = Feature_compression(resnet, in_c=self.num_channel)

    def forward(self, F_l, way, shot):
        """Compute class prototypes and enhanced query descriptors.

        Args:
            F_l: Feature maps [N, C, H, W].  The first ``way * shot`` entries
                are the support samples, grouped by class; the rest are
                queries.
            way: Number of classes.
            shot: Number of support samples per class.

        Returns:
            Tuple ``(support_fl, query_fl)``: per-class mean of the compressed
            support descriptors, shape [way, D], and the query descriptors,
            shape [num_queries, D].
        """
        _, c, h, w = F_l.shape
        n_support = way * shot

        # Class centroids of the raw support maps: [way, C, H, W].
        centroids = F_l[:n_support].reshape(way, shot, c, h, w).mean(dim=1)
        queries = F_l[n_support:]
        query_num = queries.shape[0]

        # Pair every centroid with every query.  expand() keeps the input's
        # device and dtype, so this also runs on CPU or a non-default GPU
        # (the previous zero-tensor broadcast was pinned to `.cuda()`).
        pair_shape = (way, query_num, c, h, w)
        centroid_pairs = centroids.unsqueeze(1).expand(pair_shape).reshape(-1, c, h, w)
        query_pairs = queries.unsqueeze(0).expand(pair_shape).reshape(-1, c, h, w)

        attended = self.attn(query_pairs, centroid_pairs, self.TG_prompt)
        # Average each query's attended map over all classes.
        attended = attended.reshape(way, query_num, c, h, w).mean(0)
        query_l = self.max(attended).view(query_num, -1)

        f_l = self.fc_l(F_l)

        support_fl = f_l[:n_support].view(way, shot, -1).mean(1)
        query_fl = f_l[n_support:] + self.lamda * query_l

        return support_fl, query_fl
