

class AC(nn.Module):
    """Adapter wrapper around a CLIP-style model ``cm``.

    The wrapped model's vision tower (``cm.visual``) and text transformer
    (``cm.transformer``) are run block by block.  After each of the first
    ``iau`` vision blocks and the first ``tau`` text blocks, the output of a
    ``SimpleAdapter`` is rescaled to the token norm of the block output and
    blended back into the hidden state.

    NOTE(review): the attribute names used on ``cm`` (``visual``,
    ``token_embedding``, ``positional_embedding``, ``ln_final``, ``attn_mask``,
    ``encode_text``) look like the open_clip ``CLIP`` layout, with residual
    blocks that return a ``(hidden, extra)`` tuple -- confirm against the
    model actually passed in.  Widths are hard-coded to 1024 (vision) and
    768 (text / projection target).

    Args:
        cm: The wrapped CLIP-style model.
        tw: Blend weight of the text adapters (adapter output gets ``tw``,
            the block output gets ``1 - tw``).
        iw: Blend weight of the image adapters (same convention).
        tau: Number of leading text blocks that receive an adapter.
        iau: Number of leading vision blocks that receive an adapter.
        lvs: 1-based vision block indices after which patch tokens are
            tapped.  Defaults to ``[6, 12, 18, 24]``.
        r: Forwarded as the third positional argument of every vision-side
            ``SimpleProj``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, cm, tw=0.1, iw=0.1, tau=3, iau=6, lvs=None, r=True, **kwargs):
        super().__init__()
        self.cm = cm
        self.enc_i = cm.visual
        self.tau = tau
        self.iau = iau
        self.tw = tw
        self.iw = iw
        # Copy so instances never share (or alias) the caller's / default list.
        self.lvs = [6, 12, 18, 24] if lvs is None else list(lvs)

        layer_adapters = nn.ModuleList([SimpleAdapter(1024, 1024) for _ in range(iau)])
        # One projector per tapped level, used in the order levels are reached.
        level_projs = nn.ModuleList([SimpleProj(1024, 768, r) for _ in self.lvs])
        pooled_proj = SimpleProj(1024, 768, r)
        self.adpt_i = nn.ModuleDict({
            "la": layer_adapters,
            "sp": level_projs,
            "dp": pooled_proj,
            "tpa": TGA(grid_size=4, temperature=1.0),
        })
        # ``tau`` per-block adapters followed by one final projection (index -1).
        self.adpt_t = nn.ModuleList(
            [SimpleAdapter(768, 768) for _ in range(tau)] + [SimpleProj(768, 768, relu=True)]
        )
        self._iw_()

    def _iw_(self):
        """Xavier-uniform initialise every adapter parameter with more than one dim."""
        for group in (self.adpt_i, self.adpt_t):
            for p in group.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    @staticmethod
    def _blend(d, o, w):
        """Rescale ``o`` to the last-dim L2 norm of ``d``, then return ``w*o + (1-w)*d``."""
        o = o * d.norm(dim=-1, keepdim=True) / o.norm(dim=-1, keepdim=True)
        return w * o + (1 - w) * d

    def fwd(self, d):
        """Run the adapted vision tower and return multi-level token features.

        Args:
            d: Input passed straight to ``self.enc_i.conv1``
                (presumably an image batch ``(batch, channels, H, W)`` --
                TODO confirm).

        Returns:
            A 4-tuple ``(st, dt, arf, aw)``:

            * ``st`` -- list with one entry per tapped level: patch tokens
              (class token removed) after ``ln_post``, the level's ``"sp"``
              projector and L2 normalisation over the last dim.
            * ``dt`` -- the last tapped level passed through the ``"dp"``
              projector, L2-normalised, then averaged over the token axis.
            * ``arf``, ``aw`` -- the two outputs of the ``"tpa"`` module
              applied to ``st[-1]`` (their meaning is defined by ``TGA``).
        """
        vis = self.enc_i
        d = vis.conv1(d)
        # (batch, width, h, w) -> (batch, h*w, width)
        d = d.reshape(d.shape[0], d.shape[1], -1).permute(0, 2, 1)
        # Prepend one class token per sample.
        cls_tok = vis.class_embedding.to(d.dtype).view(1, 1, -1).expand(d.shape[0], -1, -1)
        d = torch.cat([cls_tok, d], dim=1)
        d = d + vis.positional_embedding.to(d.dtype)
        d = vis.patch_dropout(d)
        # The residual blocks are fed (seq, batch, width).
        d = vis.ln_pre(d).permute(1, 0, 2)

        tapped = []
        for i, block in enumerate(vis.transformer.resblocks):
            d, _ = block(d, attn_mask=None)
            if i < self.iau:
                d = self._blend(d, self.adpt_i["la"][i](d), self.iw)
            if i + 1 in self.lvs:
                # Drop the class token (sequence position 0).
                tapped.append(d[1:, :, :])

        # Back to (batch, tokens, width), then the tower's final layer norm.
        tapped = [vis.ln_post(tok.permute(1, 0, 2)) for tok in tapped]
        st = [
            F.normalize(self.adpt_i["sp"][k](tok), dim=-1)
            for k, tok in enumerate(tapped)
        ]
        dt = F.normalize(self.adpt_i["dp"](tapped[-1]), dim=-1).mean(1)
        arf, aw = self.adpt_i["tpa"](st[-1])
        return st, dt, arf, aw

    def enc_t(self, t, adapt=True):
        """Encode token ids ``t``, optionally through the text adapters.

        Args:
            t: Integer token ids, indexed as ``(batch, seq)``.
            adapt: When false, defer entirely to ``self.cm.encode_text(t)``.

        Returns:
            With ``adapt=True``: one vector per sequence, taken at the
            position of the largest token id (in CLIP tokenisation this is
            presumably the end-of-text token -- TODO confirm) after
            ``ln_final`` and the final text projection ``adpt_t[-1]``.
        """
        if not adapt:
            return self.cm.encode_text(t)

        cast_dtype = self.cm.transformer.get_cast_dtype()
        d = self.cm.token_embedding(t).to(cast_dtype)
        d = d + self.cm.positional_embedding.to(cast_dtype)
        d = d.permute(1, 0, 2)  # (batch, seq, width) -> (seq, batch, width)
        for i, block in enumerate(self.cm.transformer.resblocks):
            d, _ = block(d, attn_mask=self.cm.attn_mask)
            if i < self.tau:
                # Text branch mixes with ``tw`` on both sides so the weights sum to 1.
                d = self._blend(d, self.adpt_t[i](d), self.tw)
        d = self.cm.ln_final(d.permute(1, 0, 2))
        picked = d[torch.arange(d.shape[0]), t.argmax(dim=-1)]
        return self.adpt_t[-1](picked)