# Patch-based memory module (PMM)
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F

class ReconstructionHead(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Maps the last dimension of its input from ``input_dim`` to ``output_dim``
    through one hidden layer of width ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # Wrapped in one Sequential named `model` so the parameter names in
        # state_dict stay `model.0.*` / `model.2.*`.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must equal ``input_dim``)."""
        head = self.model
        return head(x)
    
class MemoryModule(nn.Module):
    """Patch-level memory with item selection, gated update and query refinement.

    The memory tensor ``M`` is owned by the caller: it is passed into
    ``forward``/``pred`` and ``forward`` returns an updated copy. From the
    indexing below, ``M`` is ``[G, P, D]`` (G memory items, P patches,
    D features) and queries ``Q`` are ``[B, P, D]``.

    Args:
        d_model: feature size D of queries and memory items.
        patch_len: number of values reconstructed per patch.
        num_patch: number of patches per sequence.
        top_k: number of memory items mixed per sample (clamped to G at run time).
        tau: temperature of the patch-level attention softmaxes.
        tau_dom: temperature of the memory-item ("domain") selection softmax.
    """

    def __init__(self, d_model=1024, patch_len=8, num_patch=64, top_k=3, tau=0.3, tau_dom=1.0):
        super().__init__()
        self.input_dim = 2 * d_model
        self.hidden_dim = d_model // 2
        # Decoder input is [original query ; refined query], hence 2 * d_model.
        self.dec = ReconstructionHead(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            output_dim=patch_len,
        )
        self.patch_len = patch_len
        self.num_patch = num_patch
        self.k = top_k
        self.tau = tau
        # Read through getattr(..., 1) below so instances created before this
        # attribute existed keep working.
        self.tau_dom = tau_dom

        # Projection weights for the memory-update gate.
        self.Wm = nn.Parameter(torch.empty(d_model, d_model))
        self.Wq = nn.Parameter(torch.empty(d_model, d_model))
        nn.init.xavier_uniform_(self.Wm)
        nn.init.xavier_uniform_(self.Wq)

    def forward(self, Q, M, patch_masks):
        """Reconstruct every sample's patches and return the updated memory.

        Samples are processed one after another; the memory written by sample
        ``i`` is what sample ``i + 1`` reads.

        Args:
            Q: queries, ``[B, P, D]``.
            M: memory, ``[G, P, D]``; never modified in place.
            patch_masks: per-sample patch validity; ``patch_masks[i].squeeze(0)``
                must be a length-P mask (nonzero = valid patch).

        Returns:
            outputs: ``[B, num_patch * patch_len]``. Each reconstructed patch
                sits at its own patch position and invalid patches are zero. If
                P < num_patch, zero rows are prepended.
            updated_M: memory after all samples' gradient-free updates.
        """
        device, dtype = Q.device, Q.dtype
        B, G = Q.shape[0], M.shape[0]
        tau = torch.tensor(self.tau, device=device, dtype=dtype)
        tau_dom = torch.tensor(getattr(self, "tau_dom", 1), device=device, dtype=dtype)
        # torch.topk raises when k exceeds the number of memory items.
        k = min(self.k, G)

        def l2norm(x, dim=-1, eps=1e-8):
            return F.normalize(x, p=2, dim=dim, eps=eps)

        outputs = []
        updated_M = M

        for i in range(B):
            pmask = patch_masks[i].squeeze(0).bool()
            q_sel = Q[i][pmask, :].contiguous()                 # [p, D], valid patches only
            if q_sel.numel() == 0:
                # No valid patch: emit zeros and leave memory untouched.
                outputs.append(torch.zeros(self.num_patch * self.patch_len, device=device, dtype=dtype))
                continue

            # (0) Normalize queries; restrict memory to the same patches.
            qn = l2norm(q_sel, dim=-1)
            M_sel = updated_M[:, pmask, :].contiguous()         # [G, p, D]

            # (1) Domain selection: score each memory item against the whole
            #     flattened query and keep the k most probable items.
            qf = qn.reshape(1, -1)
            Mf = M_sel.reshape(G, -1)
            logits = (qf @ Mf.T) / tau_dom                      # [1, G]
            logits = logits - logits.max(dim=1, keepdim=True).values
            Dprob = F.softmax(logits, dim=1)

            topk_val, topk_idx = torch.topk(Dprob, k, dim=1)
            topk_val = topk_val.squeeze(0)                      # [k]
            topk_idx = topk_idx.squeeze(0)                      # [k]

            Mk = M_sel.index_select(0, topk_idx).contiguous()   # [k, p, D]

            # (2) Memory update. Vm[k, m, j] = softmax over slots m of
            #     <Mk[k, m], q_j> / tau, and VmQ[k, m] = sum_j Vm[k, m, j] * q_j.
            Vm = torch.matmul(Mk, qn.T.unsqueeze(0)) / tau      # [k, p, p]
            Vm = Vm - Vm.amax(dim=1, keepdim=True)
            Vm = F.softmax(Vm, dim=1)
            VmQ = torch.matmul(Vm, qn.unsqueeze(0))             # [k, p, D]

            # Gated blend of the old memory items and the aggregated queries.
            Mw = torch.matmul(Mk, self.Wm)
            Vw = torch.matmul(VmQ, self.Wq)
            Wm_gate = torch.sigmoid(Mw + Vw)
            updated_Mk = (1.0 - Wm_gate) * Mk + Wm_gate * VmQ

            # (3) Query refinement against the updated memory, then a
            #     topk_val-weighted mix over the k selected items.
            # NOTE(review): the attention values are the queries (qn), not the
            # memory items; `pred` does the same -- confirm this is intended.
            updated_Mk_n = l2norm(updated_Mk, dim=-1)
            A2 = torch.matmul(qn.unsqueeze(0), updated_Mk_n.transpose(-1, -2)) / tau  # [k, p, p]
            A2 = A2 - A2.amax(dim=-1, keepdim=True)
            Wq = F.softmax(A2, dim=-1)
            WqQ = torch.matmul(Wq, qn.unsqueeze(0))             # [k, p, D]
            q_ = torch.einsum("k,kpd->pd", topk_val, WqQ).contiguous()

            # (4) Decode [original query ; refined query] and put each
            #     reconstructed patch back at its own patch position, matching
            #     the layout produced by `pred`.
            q2 = torch.cat([q_sel, q_], dim=-1).contiguous()
            recon = self.dec(q2)                                # [p, patch_len]
            outputs.append(self._place_patches(recon, pmask).reshape(-1))

            # (5) Write the normalized updated items back, without grad and on
            #     fresh copies: M itself is never modified, and after the first
            #     update later samples read a no-grad copy of the memory.
            with torch.no_grad():
                M_sel_out = M_sel.clone()
                M_sel_out[topk_idx] = l2norm(updated_Mk.detach(), dim=-1)
                up_M = updated_M.clone()
                up_M[:, pmask, :] = M_sel_out
                updated_M = up_M

        return torch.stack(outputs), updated_M

    def pred(self,
            Q, M,
            test_data, test_mask, patch_mask,
            tau=None,
            p_chunk: int = 128,
            use_amp: bool = True):
        """Reconstruct a batch and score it without updating memory (test time).

        Args:
            Q: queries, ``[B, P, D]`` with ``P == num_patch``.
            M: memory, ``[G, P, D]``; only read.
            test_data: targets; flattened per sample, then truncated or
                zero-padded to ``num_patch * patch_len``.
            test_mask: element mask for ``test_data`` (True/nonzero = scored).
                It is flattened and fitted to the same length, independently
                of ``test_data``.
            patch_mask: patch validity, expected ``[B, 1, P]`` (squeezed on dim 1).
            tau: optional override of ``self.tau`` for the patch attention.
            p_chunk: query patches per attention chunk; bounds the size of the
                ``[B, k, chunk, P]`` logits.
            use_amp: run the einsum stages under CUDA autocast (bf16 when
                supported, else fp16). It has no effect for non-CUDA tensors.

        Returns:
            recon: ``[B, num_patch * patch_len]`` reconstructions, zero on
                invalid patches.
            mse: ``[B, num_patch * patch_len]`` squared errors, zero where the
                test mask is off.
        """
        device, dtype = Q.device, Q.dtype
        B, P, D = Q.shape
        tau = float(self.tau if tau is None else tau)
        tau_dom = float(getattr(self, "tau_dom", 1))
        # torch.topk raises when k exceeds the number of memory items.
        k = min(self.k, M.shape[0])
        assert P == self.num_patch and self.patch_len > 0

        pm = patch_mask.squeeze(1).to(torch.bool)               # [B, P]
        pm_f = pm.to(dtype)
        Qm = Q * pm_f.unsqueeze(-1)                              # invalid patches zeroed
        Qn = F.normalize(Qm, dim=-1)

        # Targets and element mask are each fitted to num_patch * patch_len
        # from their OWN length (the two may differ), then split into patches.
        total = self.num_patch * self.patch_len
        msk = test_mask if test_mask.dtype == torch.bool else (test_mask != 0)
        tgt = self._fit_length(test_data.reshape(B, -1), total)
        msk = self._fit_length(msk.reshape(B, -1), total)
        tgt = tgt.reshape(B, self.num_patch, self.patch_len).to(dtype)
        msk = msk.reshape(B, self.num_patch, self.patch_len)

        # Domain selection: masked patch-wise similarity summed per memory item.
        with self._amp_context(device, use_amp):
            sim_bgP = torch.einsum('bpd,gpd->bgp', Qn, M)
        sim_bgP = sim_bgP.to(dtype)

        sim_sum = (sim_bgP * pm.unsqueeze(1)).sum(dim=2)        # [B, G]
        logits = sim_sum / max(tau_dom, 1e-6)
        logits = logits - logits.max(dim=1, keepdim=True).values
        Dprob = torch.softmax(logits, dim=1)
        topk_val, topk_idx = torch.topk(Dprob, k, dim=1)        # [B, k]

        # Gather the selected memory items per sample: [B, k, P, D].
        Mk = M.index_select(0, topk_idx.reshape(-1)).view(B, k, P, D).contiguous()

        # Chunked attention over query patches; keys are the valid patches.
        q_ = torch.zeros(B, P, D, device=device, dtype=dtype)
        key_mask = pm.reshape(B, 1, 1, P)
        inv_tau = 1.0 / max(tau, 1e-6)

        with self._amp_context(device, use_amp):
            for s in range(0, P, p_chunk):
                e = min(P, s + p_chunk)
                Qc_n = Qn[:, s:e, :]
                logits = torch.einsum('bpd,bkqd->bkpq', Qc_n, Mk) * inv_tau
                logits = logits.masked_fill(~key_mask, float('-inf'))
                attn = torch.softmax(logits, dim=-1)
                ctx = torch.einsum('bkpq,bqd->bkpd', attn, Qn)
                q_chunk = (topk_val.unsqueeze(-1).unsqueeze(-1) * ctx).sum(dim=1)
                q_[:, s:e, :] = q_chunk.to(dtype)

        # Reconstruction and per-element squared error.
        q2 = torch.cat([Qm, q_], dim=-1)
        recon = self.dec(q2.reshape(B * P, -1)).reshape(B, P, self.patch_len)
        # masked_fill rather than multiplying by the mask: a sample with no
        # valid patch has every key masked, its softmax is NaN, and NaN * 0
        # would stay NaN.
        recon = recon.masked_fill(~pm.unsqueeze(-1), 0.0)
        mse = ((recon - tgt) ** 2).masked_fill(~msk, 0.0)

        return recon.reshape(B, -1), mse.reshape(B, -1)

    def _place_patches(self, recon, pmask):
        """Scatter valid-patch rows of ``recon`` back to their positions in ``pmask``.

        Returns ``[max(P, num_patch), patch_len]`` with zeros on invalid
        patches; zero rows are prepended when the mask is shorter than
        ``num_patch``.
        """
        # The mask may live on a different device than the decoder output.
        idx = pmask.nonzero(as_tuple=True)[0].to(recon.device)
        placed = recon.new_zeros(pmask.shape[0], self.patch_len).index_copy(0, idx, recon)
        pad_rows = self.num_patch - placed.shape[0]
        if pad_rows > 0:
            placed = torch.cat([placed.new_zeros(pad_rows, self.patch_len), placed], dim=0)
        return placed

    @staticmethod
    def _fit_length(x, total):
        """Truncate, or right-pad with zeros/False, the second dim of 2-D ``x`` to ``total``."""
        n = x.size(1)
        if n >= total:
            return x[:, :total]
        return torch.cat([x, x.new_zeros(x.size(0), total - n)], dim=1)

    @staticmethod
    def _amp_context(device, use_amp):
        """Return CUDA autocast (bf16 if supported, else fp16) for CUDA work, else a no-op."""
        if not use_amp or device.type != "cuda":
            return contextlib.nullcontext()
        amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
