
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """Residual unit made of two 3x3 convolutions.

    Batch normalisation followed by ReLU is applied *before* each
    convolution. If the stride differs from 1 or the channel count changes,
    a bias-free 1x1 convolution is registered as ``shortcut`` and applied to
    the normalised/activated input; otherwise no ``shortcut`` attribute is
    created and the raw block input is used as the skip connection.

    Args:
        in_planes: Channel count of the incoming tensor.
        planes: Channel count produced by both convolutions.
        stride: Stride of the first convolution (and of the 1x1 shortcut
            convolution when one is created).
    """

    # Multiplier between ``planes`` and the number of channels emitted.
    expansion = 1

    def __init__(self, in_planes, planes, stride):
        super().__init__()
        out_planes = self.expansion * planes

        # Submodules are created in a fixed order so that parameter names
        # and random initialisation stay reproducible.
        self.n1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.n2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=3, stride=1, padding=1, bias=False,
        )

        shape_preserved = stride == 1 and in_planes == out_planes
        if not shape_preserved:
            # The skip path must be projected to match the main path.
            self.shortcut = nn.Conv2d(
                in_planes, out_planes,
                kernel_size=1, stride=stride, bias=False,
            )

    def forward(self, x):
        """Return ``conv2(relu(n2(conv1(a)))) + skip`` with ``a = relu(n1(x))``.

        ``skip`` is ``shortcut(a)`` when the projection exists, else ``x``.
        """
        activated = F.relu(self.n1(x))

        if hasattr(self, 'shortcut'):
            residual = self.shortcut(activated)
        else:
            residual = x

        hidden = self.conv1(activated)
        hidden = F.relu(self.n2(hidden))
        result = self.conv2(hidden)
        # In-place add on the freshly produced convolution output.
        result += residual
        return result


class Bottleneck(nn.Module):
    """Residual unit made of a 1x1, a 3x3 and a 1x1 convolution.

    Batch normalisation followed by ReLU is applied *before* each of the
    three convolutions. The last convolution emits ``expansion * planes``
    channels. If the stride differs from 1 or the input channel count is not
    ``expansion * planes``, a bias-free 1x1 convolution is registered as
    ``shortcut`` and applied to the normalised/activated input; otherwise no
    ``shortcut`` attribute is created and the raw input is the skip path.

    Args:
        in_planes: Channel count of the incoming tensor.
        planes: Channel count of the two inner convolutions.
        stride: Stride of the 3x3 convolution (and of the 1x1 shortcut
            convolution when one is created).
    """

    # Multiplier between ``planes`` and the number of channels emitted.
    expansion = 4

    def __init__(self, in_planes, planes, stride):
        super().__init__()
        out_planes = self.expansion * planes

        # Submodules are created in a fixed order so that parameter names
        # and random initialisation stay reproducible.
        self.n1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.n2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.n3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        shape_preserved = stride == 1 and in_planes == out_planes
        if not shape_preserved:
            # The skip path must be projected to match the main path.
            self.shortcut = nn.Conv2d(
                in_planes, out_planes,
                kernel_size=1, stride=stride, bias=False,
            )

    def forward(self, x):
        """Run the three norm-ReLU-conv stages and add the skip connection.

        The skip connection is ``shortcut(relu(n1(x)))`` when the projection
        exists, else ``x`` itself.
        """
        activated = F.relu(self.n1(x))

        if hasattr(self, 'shortcut'):
            residual = self.shortcut(activated)
        else:
            residual = x

        hidden = self.conv1(activated)
        hidden = self.conv2(F.relu(self.n2(hidden)))
        result = self.conv3(F.relu(self.n3(hidden)))
        # In-place add on the freshly produced convolution output.
        result += residual
        return result


class ResNet(nn.Module):
    """Four-stage residual network with a linear classification head.

    The network is a 3x3 stem convolution, four stages built by
    ``_make_layer`` (strides 1, 2, 2, 2), a final batch norm + ReLU, global
    average pooling and a fully connected layer.

    Args:
        data_shape: Indexable input shape; only ``data_shape[0]`` is read and
            used as the number of input channels of the stem convolution.
        hidden_size: Indexable with four entries giving the ``planes`` value
            of each stage.
        block: Callable invoked as ``block(in_planes, planes, stride)`` that
            returns a module; it must expose an ``expansion`` attribute.
        num_blocks: Indexable with four entries giving the block count of
            each stage.
        target_size: Output dimension of the final linear layer.
    """

    def __init__(self, data_shape, hidden_size, block, num_blocks, target_size):
        super().__init__()
        stem_width = hidden_size[0]
        # Running channel count consumed and updated by ``_make_layer``.
        self.in_planes = stem_width
        self.conv1 = nn.Conv2d(
            data_shape[0], stem_width,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, hidden_size[3], num_blocks[3], stride=2)

        feature_dim = hidden_size[3] * block.expansion
        self.n4 = nn.BatchNorm2d(feature_dim)
        self.fc = nn.Linear(feature_dim, target_size)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise convolutions, batch norms and linear biases.

        Convolution weights get Kaiming-normal (fan-out, ReLU) values, batch
        norm layers get weight 1 / bias 0, and linear layers get bias 0
        (their weights keep the default initialisation).
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage as an ``nn.Sequential`` of ``block`` instances.

        The first block uses ``stride``; every following block uses stride 1.
        ``self.in_planes`` is advanced to ``planes * block.expansion`` after
        each block so the next one receives the right input width.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def f(self, x):
        """Return the classifier output for ``x`` (pooled features -> ``fc``)."""
        return self.fc(self.get_feature(x))

    def get_feature(self, x):
        """Return globally average-pooled features flattened to ``N * C``."""
        feature_map = self.get_raw_feature(x)             # N * C * (H, W)
        pooled = F.adaptive_avg_pool2d(feature_map, 1)    # N * C * (1, 1)
        return pooled.view(pooled.size(0), -1)            # N * C

    def get_raw_feature(self, x):
        """Return the spatial feature map after the final batch norm + ReLU.

        Pooling and flattening are deliberately left to ``get_feature``.
        """
        # shape: B, C, H, W
        stages = (self.conv1, self.layer1, self.layer2, self.layer3, self.layer4)
        for stage in stages:
            x = stage(x)
        return F.relu(self.n4(x))

    def forward(self, input):
        """Module entry point; delegates to ``f``."""
        return self.f(input)

def resnet9(data_shape=(3, 32, 32), target_size=10):
    """Build a ``ResNet`` of ``Block`` units with one block per stage.

    Stage widths are fixed to 64, 128, 256 and 512.

    Args:
        data_shape: Indexable input shape forwarded to ``ResNet``. The
            default is an immutable tuple so that no mutable object is
            shared between calls.
        target_size: Output dimension forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    return ResNet(data_shape, [64, 128, 256, 512], Block, [1, 1, 1, 1], target_size)


def resnet18(data_shape=(3, 32, 32), target_size=10):
    """Build a ``ResNet`` of ``Block`` units with two blocks per stage.

    Stage widths are fixed to 64, 128, 256 and 512.

    Args:
        data_shape: Indexable input shape forwarded to ``ResNet``. The
            default is an immutable tuple so that no mutable object is
            shared between calls.
        target_size: Output dimension forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    return ResNet(data_shape, [64, 128, 256, 512], Block, [2, 2, 2, 2], target_size)




# class Attention(nn.Module):
#     def __init__(self, dim, num_heads=4, qkv_bias=True, attn_drop=0., proj_drop=0.):
#         super().__init__()
#         # print(f'Attention: dim={dim}, num_heads={num_heads}, qkv_bias={qkv_bias}, attn_drop={attn_drop}, proj_drop={proj_drop}')
#         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
#         self.num_heads = num_heads
#         head_dim = dim // num_heads
#         self.scale = head_dim ** -0.5

#         self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
#         self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
#         self.attn_drop = nn.Dropout(attn_drop)
#         self.proj_drop = nn.Dropout(proj_drop)

#     def forward(self, x, query_embed):
#         B, N, C = x.shape
#         K = query_embed.size(1)
        
#         q = self.linear_q(query_embed).expand(B, -1, -1).reshape(B, K, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
#         k = self.linear_k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
#         v = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

#         attn = (q @ k.transpose(-2, -1)) * self.scale
#         attn = attn.softmax(dim=-1)
#         attn = self.attn_drop(attn)

#         x = (attn @ v).transpose(1, 2).reshape(B, K, C)
#         x = self.proj_drop(x)
#         return x        
    

# class LayerScale(nn.Module):
#     def __init__(self, dim, init_values=1e-5, inplace=False):
#         super().__init__()
#         self.inplace = inplace
#         self.gamma = nn.Parameter(init_values * torch.ones(dim)) # type: ignore

#     def forward(self, x):
#         return x.mul_(self.gamma) if self.inplace else x * self.gamma
    
# def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
#     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

#     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
#     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
#     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
#     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
#     'survival rate' as the argument.

#     """
#     if drop_prob == 0. or not training:
#         return x
#     keep_prob = 1 - drop_prob
#     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
#     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
#     if keep_prob > 0.0 and scale_by_keep:
#         random_tensor.div_(keep_prob)
#     return x * random_tensor


# class DropPath(nn.Module):
#     """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
#     """
#     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
#         super(DropPath, self).__init__()
#         self.drop_prob = drop_prob
#         self.scale_by_keep = scale_by_keep

#     def forward(self, x):
#         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

#     def extra_repr(self):
#         return f'drop_prob={round(self.drop_prob,3):0.3f}'
    
# class BDMatch_Net(nn.Module):
#     def __init__(self, base, num_classes, num_heads=4, qkv_bias=True, attn_drop=0., drop=0.,
#                  init_values=None, drop_path=0.2, use_rot=False):
#         super(BDMatch_Net, self).__init__()
#         self.backbone = base
#         self.num_features = base.fc.in_features 
        
#         # Multi-head dot-product attention module to extract label-specific features
#         self.attn = Attention(self.num_features, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
#         self.ls1 = LayerScale(self.num_features, init_values=init_values) if init_values else nn.Identity()
#         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
#         self.norm = nn.LayerNorm(self.num_features, eps=1e-6)
        
#         # initialize the label embedding
#         self.query_embed = nn.Parameter(torch.zeros(1, num_classes, self.num_features)) # type: ignore
#         nn.init.normal_(self.query_embed)
        
#         self.fc2 = nn.Linear(self.num_features, num_classes)
#         nn.init.xavier_normal_(self.fc2.weight.data)
#         self.fc2.bias.data.zero_()
        
#         # Standard classifiers used in the dual-branch architecture
#         self.st_fc1 = nn.Linear(self.num_features, num_classes)
#         nn.init.xavier_normal_(self.st_fc1.weight.data)
#         self.st_fc1.bias.data.zero_()
#         self.st_fc2 = nn.Linear(self.num_features, num_classes)
#         nn.init.xavier_normal_(self.st_fc2.weight.data)
#         self.st_fc2.bias.data.zero_()
    
#     def forward(self, x, st_pred=False):
#         feat = self.backbone.get_raw_feature(x)
#         out = F.adaptive_avg_pool2d(feat, 1)
#         out = out.view(-1, 1, self.num_features)
#         feat = feat.reshape((feat.size(0), feat.size(1), -1)).permute(0, 2, 1)
#         feat = out + self.drop_path1(self.ls1(self.attn(feat, self.query_embed)))
#         feat = self.norm(feat)

#         print(f'feat.shape: {feat.shape}')
        
#         logits = self.head_forward(feat)
#         if not st_pred:
#             return {'logits':logits}
#         st_logits = self.head_forward_st(feat)
#         return {'logits':logits, 'st_logits':st_logits}
    
#     def head_forward(self, x):
#         print(f'x.shape: {x.shape}, self.backbone.fc.weight.shape: {self.backbone.fc.weight.shape}, self.fc2.weight.shape: {self.fc2.weight.shape}')
#         logits = (x * self.backbone.fc.weight).sum(dim=-1) + self.backbone.fc.bias
#         logits2 = (x * self.fc2.weight).sum(dim=-1) + self.fc2.bias
#         return logits - logits2
    
#     def head_forward_st(self, x):
#         logits = (x * self.st_fc1.weight).sum(dim=-1) + self.st_fc1.bias
#         logits2 = (x * self.st_fc2.weight).sum(dim=-1) + self.st_fc2.bias
#         return logits - logits2

#     def group_matcher(self, coarse=False):
#         matcher = self.backbone.group_matcher(coarse, prefix='backbone.')
#         return matcher
    
#     def no_weight_decay(self):
#         nwd = []
#         for n, _ in self.named_parameters():
#             if 'bn' in n or 'bias' in n:
#                 nwd.append(n)
#         return nwd
    

# class FedNet(nn.Module):
#     def __init__(self, base, num_classes, num_heads=4, qkv_bias=True, attn_drop=0., drop=0.,
#                  init_values=None, drop_path=0.2, use_rot=False):
#         super(FedNet, self).__init__()
#         self.backbone = base
#         self.num_features = base.fc.in_features 
        
#         # Multi-head dot-product attention module to extract label-specific features
#         self.attn = Attention(self.num_features, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
#         self.ls1 = LayerScale(self.num_features, init_values=init_values) if init_values else nn.Identity()
#         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
#         self.norm = nn.LayerNorm(self.num_features, eps=1e-6)
        
#         # initialize the label embedding
#         self.query_embed = nn.Parameter(torch.zeros(1, num_classes, self.num_features)) # type: ignore
#         nn.init.normal_(self.query_embed)

#         self.fc_feat = nn.Sequential(nn.Linear(self.num_features, self.num_features),
#                                     nn.ReLU(),
#                                     nn.Linear(self.num_features, self.num_features),
#                                     nn.ReLU())
#         self.ova_pos = nn.Linear(self.num_features, num_classes, bias=False)
#         self.ova_neg = nn.Linear(self.num_features, num_classes, bias=False)
#         self.fc = nn.Linear(self.num_features, num_classes, bias=False)
        
#         nn.init.xavier_normal_(self.fc.weight.data)
#         nn.init.xavier_normal_(self.ova_pos.weight.data)
#         nn.init.xavier_normal_(self.ova_neg.weight.data)

    
#     def forward(self, x):
#         feat = self.backbone.get_raw_feature(x)     # N * C * (H, W)
#         out = F.adaptive_avg_pool2d(feat, 1)        # N * C * (1, 1)
#         out = out.view(-1, 1, self.num_features)    # N * 1 * C
#         feat = feat.reshape((feat.size(0), feat.size(1), -1)).permute(0, 2, 1) # N * HW * C
#         feat = out + self.drop_path1(self.ls1(self.attn(feat, self.query_embed)))
#         ova_feat = self.norm(feat)
#         fc_feat = out.view(x.size(0), -1)
#         logits = self.fc(self.fc_feat(fc_feat)) 

#         logits_ova_pos = (ova_feat * self.ova_pos.weight).sum(-1)
#         logits_ova_neg = (ova_feat * self.ova_neg.weight).sum(-1)

#         logits_ova = torch.cat([logits_ova_neg, logits_ova_pos], dim=1)

#         return_dict = {
#             'logits': logits, 
#             'logits_ova': logits_ova
#         }
#         return return_dict


# if __name__ == '__main__':
#     import torch
#     model = FedNet(resnet18(), num_classes=10, use_rot=False)
#     x = torch.randn(128, 3, 32, 32)
#     y = model(x)
#     print(y['logits'].shape)