import torch
import torch.nn as nn
import timm
from timm.models.registry import register_model
from spiking_layer import MixedLIF, LIFt


class BarlowTwins(nn.Module):
    """Barlow Twins wrapper: a backbone followed by a spiking projector head.

    The projector is ``Linear(in_dim, hidden_dim) -> BatchNorm1d -> LIF ->
    Linear(hidden_dim, out_dim)``. The linear / batch-norm layers are applied
    with the time and batch axes flattened together.

    Args:
        backbone: module whose forward returns a 2-tuple; the second element
            is used as a 3-D ``[T, B, in_dim]`` feature tensor.
        act_func_lif: LIF activation class (e.g. ``LIFt`` or ``MixedLIF``);
            it is constructed with ``tau``, ``v_threshold``, ``detach_reset``
            and ``backend`` keyword arguments.
        in_dim: feature width of the backbone output (last axis).
        out_dim: width of the projector output.
        hidden_dim: width of the projector hidden layer.
        T: stored as ``self.T``; this class's forward pass does not read it
            and instead takes the time dimension from the backbone output.
        lif_tau: ``tau`` passed to ``act_func_lif``.
        lif_v_threshold: ``v_threshold`` passed to ``act_func_lif``.
        lif_detach_reset: ``detach_reset`` passed to ``act_func_lif``.
        lif_backend: ``backend`` passed to ``act_func_lif``. The default
            ``'cupy'`` is the previously hard-coded value; which other values
            are accepted depends on ``act_func_lif``.
    """

    def __init__(self, backbone, act_func_lif=LIFt, in_dim=512, out_dim=1024, hidden_dim=2048, T=4,
                 *, lif_tau=4.0, lif_v_threshold=0.5, lif_detach_reset=True, lif_backend='cupy'):
        super(BarlowTwins, self).__init__()

        self.T = T
        self.backbone = backbone
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim))
        self.lif = act_func_lif(tau=lif_tau, v_threshold=lif_v_threshold,
                                detach_reset=lif_detach_reset, backend=lif_backend)
        self.linear = nn.Linear(hidden_dim, out_dim)

    def forward_one(self, x):
        """Run the backbone and the spiking projector on one batch.

        Args:
            x: input batch for ``self.backbone``.

        Returns:
            ``(feature, z)``: ``feature`` is the backbone's second output,
            shaped ``[T, B, in_dim]``; ``z`` is the projector output, shaped
            ``[T, B, out_dim]``.
        """
        _, feature = self.backbone(x)
        # The unpack also enforces that the backbone output is 3-D.
        T, B, _ = feature.shape
        # Linear / BatchNorm1d take 2-D input, so fold time into the batch
        # axis, then restore [T, B, ...] before the LIF layer (presumably so
        # it can step over the leading time axis).
        z = self.projector(feature.flatten(0, 1)).reshape(T, B, -1).contiguous()
        z = self.lif(z)
        z = self.linear(z.flatten(0, 1)).reshape(T, B, -1).contiguous()
        return feature, z

    def forward(self, x1, x2):
        """Encode two views in a single backbone pass and split the results.

        ``x1`` and ``x2`` are concatenated along dim 0 and are assumed to have
        the same batch size: the outputs are split at the midpoint of their
        batch axis (dim 1).

        Returns:
            ``(f1, f2, z1, z2)``: backbone features and projector outputs for
            ``x1`` and ``x2`` respectively.
        """
        x = torch.cat((x1, x2), dim=0)
        f, z = self.forward_one(x)

        b = f.shape[1] // 2
        # fetch each path
        f1, f2 = f[:, :b, ...], f[:, b:, ...]
        z1, z2 = z[:, :b, ...], z[:, b:, ...]
        return f1, f2, z1, z2

@register_model
def barlow_twins_spikformer(**kwargs):
    """Build a ``BarlowTwins`` model on top of a timm ``'spikformer'`` backbone.

    Keyword Args:
        act_func: ``'MixedLIF'`` selects ``MixedLIF``; any other value, or
            omitting the key, selects ``LIFt``. The same class is used for the
            backbone and for the projector's LIF layer.
        embed_dims: required; forwarded to the backbone and used as the
            projector's ``in_dim``.
        T: optional; if present it is forwarded to the backbone and also
            recorded on the wrapper (which otherwise defaults to 4).
        **kwargs: every key except ``pretrained``, ``pretrained_cfg`` and
            ``act_func`` is forwarded to ``timm.create_model``.

    Returns:
        BarlowTwins: the wrapper, with the backbone's ``head`` replaced by
        ``nn.Identity()``.

    Raises:
        KeyError: if ``embed_dims`` is not provided.
    """
    # Look this up first so a missing key fails before the backbone is built.
    in_dim = kwargs['embed_dims']
    # .get() so an omitted 'act_func' falls back to LIFt instead of raising.
    act_func = MixedLIF if kwargs.get('act_func') == 'MixedLIF' else LIFt

    # 'pretrained' / 'pretrained_cfg' are presumably injected by timm's
    # create_model; 'act_func' is this factory's own option. None of them are
    # forwarded to the backbone.
    excluded = ('pretrained', 'pretrained_cfg', 'act_func')
    backbone = timm.create_model(
        'spikformer',
        act_func_lif=act_func,
        **{k: v for k, v in kwargs.items() if k not in excluded}
    )
    backbone.head = nn.Identity()
    model = BarlowTwins(
        backbone=backbone,
        act_func_lif=act_func,
        in_dim=in_dim,
        T=kwargs.get('T', 4),
    )
    return model
