from __future__ import annotations

import math

import torch
import torch.nn.functional as F  
from torch import nn


class RepeatLinear(nn.Module):
    """Scale features by a learned vector, pool over dim 1, then project.

    The forward pass multiplies the input element-wise by a learned vector of
    length ``in_dim``, applies ReLU, averages over dimension 1 and feeds the
    result through a linear layer producing ``out_dim`` features.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Number of features produced by the final linear layer.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
    ) -> None:
        super().__init__()
        # Created on the default device like every other parameter; callers
        # place the module with ``.to(device)`` / ``.cuda()``.  Forcing CUDA
        # here would fail on CPU-only hosts and would leave ``w`` on a
        # different device from ``self.linear``.
        self.w = nn.Parameter(torch.randn(in_dim))
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled projection of ``x``.

        Args:
            x: Input tensor; the in-file caller passes a 3-D tensor whose last
                dimension is ``in_dim``.

        Returns:
            For a 3-D input of shape ``(B, T, in_dim)``, a tensor of shape
            ``(B, out_dim)``.
        """
        # (in_dim,) -> (B, 1, in_dim) as a broadcast view: same shape as the
        # former ``unsqueeze(0).repeat(B, 1, 1)`` but without copying the
        # weight once per batch element.
        w = self.w.expand(x.size(0), 1, -1)

        x = torch.relu(w * x)
        x = torch.mean(x, dim=1)
        return self.linear(x)



class GroupLinearLayer(nn.Module):
    """A single ``nn.Linear`` whose weight and bias start uniform in ``[-a, a]``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        a: Half-width of the uniform initialisation range.  When ``None`` it
            defaults to ``1 / sqrt(out_dim)``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        a=None,
    ) -> None:
        super().__init__()
        bound = 1.0 / math.sqrt(out_dim) if a is None else a
        self.linear = nn.Linear(in_dim, out_dim)
        # Weight first, then bias, so the random draws happen in a fixed order.
        for tensor in (self.linear.weight, self.linear.bias):
            nn.init.uniform_(tensor, -bound, bound)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        return self.linear(x)


class MemoryModule(nn.Module):
    """Gated memory updated by causally masked attention from the inputs.

    Each forward pass (1) lets the inputs attend over the memory, followed by
    a residual MLP, for ``num_blocks`` rounds, (2) computes sigmoid input and
    forget gates from the inputs and the previous memory, and (3) blends the
    attended memory with the previous memory through those gates.

    Queries are projected from ``hidden_dim`` to ``mem_slots`` features; keys
    and values are projected from ``mem_slots`` to ``mem_slots`` features, so
    the last dimension of ``memory`` must be ``mem_slots``.

    NOTE(review): the input-gate projector is sized with ``mem_slots`` but is
    applied to ``inputs``, and ``forward`` adds ``inputs`` to the new memory,
    so ``hidden_dim`` appears to have to match ``mem_slots`` in practice —
    confirm against callers.

    Args:
        mem_slots: Feature width used by the key/value projections, the MLP,
            the layer norms and the gates.
        head_size: Stored on the instance; not otherwise used by this class.
        hidden_dim: Feature width of the inputs fed to the query projection.
        attn_drop: Dropout probability applied to the attention weights.
            NOTE(review): the default of 0.9 drops most weights in training
            mode — confirm this is intended.
        num_heads: Number of attention heads; the projected width
            (``mem_slots``) is split evenly across them.
        num_blocks: Number of attention + MLP rounds per forward pass.
        forget_bias: Initial value of the learnable bias added to the forget
            gate before the sigmoid.
        input_bias: Initial value of the learnable bias added to the input
            gate before the sigmoid.
        attention_mlp_layers: Number of linear layers in the post-attention
            MLP.
        use_topk: If true, keep only the ``topk`` largest attention weights
            per query and zero the others.
        topk: Number of attention weights kept when ``use_topk`` is set.

    Raises:
        ValueError: If ``num_blocks`` is smaller than 1.
    """

    def __init__(
        self,
        mem_slots: int,
        head_size: int,
        hidden_dim: int,
        attn_drop: float = 0.9,
        num_heads: int = 1,
        num_blocks: int = 1,
        forget_bias: float = 1.0,
        input_bias: float = 0.0,
        attention_mlp_layers: int = 2,
        use_topk: bool = False,
        topk: int = 3,
    ) -> None:
        super().__init__()

        if num_blocks < 1:
            msg = f"num blocks must be >= 1. Got: {num_blocks}"
            raise ValueError(msg)

        self.mem_slots = mem_slots
        self.head_size = head_size
        self.hidden_dim = hidden_dim
        self.n_heads = num_heads
        self.use_topk = use_topk
        self.topk = topk
        self.attn_drop = nn.Dropout(attn_drop)
        self.num_blocks = num_blocks
        self.num_atten_mlp_layers = attention_mlp_layers

        self.query_proj = nn.Linear(self.hidden_dim, self.mem_slots)
        self.key_proj = nn.Linear(self.mem_slots, self.mem_slots)
        self.value_proj = nn.Linear(self.mem_slots, self.mem_slots)

        # Build one independent Linear per layer.  ``[layer] * n`` would put
        # the *same* module object in the list n times, silently tying the
        # weights of every MLP layer together.
        self.attention_mlp = nn.ModuleList(
            [
                nn.Linear(self.mem_slots, self.mem_slots)
                for _ in range(self.num_atten_mlp_layers)
            ]
        )
        self.attended_memory_layernorm = nn.LayerNorm(self.mem_slots)
        self.attended_memory_layernorm2 = nn.LayerNorm(self.mem_slots)

        # One input gate and one forget gate, each ``calculate_gate_size()`` wide.
        self.num_gates = 2 * self.calculate_gate_size()

        # Gate contribution computed from the inputs (pooled over dim 1).
        self.input_gate_projector = RepeatLinear(
            in_dim=self.mem_slots, out_dim=self.num_gates
        )
        # Gate contribution computed from the memory.
        self.memory_gate_projector = GroupLinearLayer(
            in_dim=self.mem_slots, out_dim=self.num_gates
        )

        # Learnable scalar offsets added to the gate logits.
        self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32))
        self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32))

    def multi_head_attention(
        self,
        ipts: torch.Tensor,
        memory: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Attend from ``ipts`` (queries) over ``memory`` (keys and values).

        A lower-triangular causal mask is always applied, so query position
        ``i`` only sees memory positions ``0..i``.

        Args:
            ipts: 3-D tensor ``(B, T, hidden_dim)``.
            memory: 3-D tensor ``(B, M, mem_slots)``; ``M`` must equal ``T``.
            attention_mask: Optional extra mask, or ``None``.  It is cast to
                bool and OR-ed with the causal mask, so it must broadcast
                against the ``(B, heads, T, M)`` score tensor; true entries
                are blocked.

        Returns:
            Tensor of shape ``(B, T, mem_slots)``.

        Raises:
            ValueError: If the memory length differs from the sequence length.
        """
        b, t, _ = ipts.size()
        _, m, _ = memory.size()

        # Fail fast.  Resampling ``ipts`` to length ``m`` first (as was done
        # previously) was wasted work, because this check rejected every such
        # call afterwards anyway.
        if m != t:
            raise ValueError(f"Memory length M {m} must be equal sequence length T {t} for causal masking.")

        q = self.query_proj(ipts)
        k = self.key_proj(memory)
        v = self.value_proj(memory)

        # Split the feature axis into heads: (B, L, C) -> (B, heads, L, C / heads).
        q = q.reshape(b, t, self.n_heads, -1).transpose(1, 2)
        k = k.reshape(k.size(0), k.size(1), self.n_heads, -1).transpose(1, 2)
        v = v.reshape(v.size(0), v.size(1), self.n_heads, -1).transpose(1, 2)

        # Scaled dot-product scores, shape (B, heads, T, M).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # True marks a blocked position: everything above the diagonal.
        blocked = ~torch.tril(torch.ones((t, m), dtype=torch.bool, device=att.device))
        blocked = blocked.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            blocked = blocked | attention_mask.to(dtype=torch.bool)

        # Masking the scores directly keeps their dtype, unlike adding a
        # separately built float32 bias tensor.
        att = att.masked_fill(blocked, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        if self.use_topk:
            # Clamp k so a short memory does not make ``torch.topk`` raise.
            k_keep = min(self.topk, att.size(-1))
            top_idx = torch.topk(att, k=k_keep, dim=-1).indices
            keep = torch.zeros_like(att).scatter_(-1, top_idx, 1.0)
            att = att * keep

        output = att @ v
        # Merge the heads back: (B, heads, T, d) -> (B, T, heads * d).
        return output.transpose(1, 2).contiguous().view(b, t, self.n_heads * v.size(-1))

    def attend_over_memory(
        self,
        inputs: torch.Tensor,
        memory: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``num_blocks`` rounds of attention followed by a residual MLP.

        Each round adds the attention output to the memory and layer-normalises
        it, then passes the result through the ReLU MLP, adds that back and
        layer-normalises again.

        Args:
            inputs: Query source, forwarded to ``multi_head_attention``.
            memory: Current memory, used as keys/values and residual stream.
            attention_mask: Optional extra mask forwarded to the attention.

        Returns:
            The updated memory, same shape as ``memory``.
        """
        for _ in range(self.num_blocks):
            attended = self.multi_head_attention(inputs, memory, attention_mask)
            memory = self.attended_memory_layernorm(memory + attended)

            hidden = memory
            for layer in self.attention_mlp:
                hidden = F.relu(layer(hidden))

            memory = self.attended_memory_layernorm2(memory + hidden)
        return memory

    def forward(
        self,
        inputs: torch.Tensor,
        memory: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Update the memory from ``inputs`` and return the new state.

        Args:
            inputs: Tensor with at least two dimensions; everything after the
                first two is flattened into one feature axis.
            memory: Previous memory, ``(B, M, mem_slots)``.
            attention_mask: Optional extra attention mask, or ``None``.

        Returns:
            ``(inputs + next_memory, next_memory)`` where ``next_memory`` is
            ``input_gate * tanh(attended_memory) + forget_gate * memory``.
        """
        # ``reshape`` rather than ``view`` so non-contiguous inputs work too.
        inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], -1)

        next_memory = self.attend_over_memory(inputs, memory, attention_mask)

        # The gates are computed from the *previous* memory, not the attended one.
        input_gate, forget_gate = self.create_gates(inputs, memory)
        next_memory = input_gate * torch.tanh(next_memory) + forget_gate * memory

        return inputs + next_memory, next_memory

    def calculate_gate_size(self) -> int:
        """Return the width of a single gate, which is ``mem_slots``."""
        return self.mem_slots

    def create_gates(
        self, inputs: torch.Tensor, memory: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the sigmoid input and forget gates.

        Args:
            inputs: 3-D input tensor.
            memory: Previous memory; ``tanh`` is applied before projecting it.

        Returns:
            ``(input_gate, forget_gate)``, the two halves of the summed gate
            logits along dim 2, each offset by its learnable bias and squashed
            with a sigmoid.

        Raises:
            ValueError: If ``inputs`` is not 3-D, or if the gate logits cannot
                be split into two equal halves.
        """
        if inputs.dim() != 3:
            msg = f"input shape of create_gate function is {inputs.shape}, expects 3"
            raise ValueError(msg)

        memory = torch.tanh(memory)

        # The input projection pools over dim 1, so a singleton axis is
        # re-inserted there to broadcast against the per-position memory
        # projection.
        gate_inputs = self.input_gate_projector(inputs).unsqueeze(1)
        gate_memory = self.memory_gate_projector(memory)

        gates = gate_memory + gate_inputs

        input_gate, forget_gate = gates.chunk(2, dim=2)
        if input_gate.shape[2] != forget_gate.shape[2]:
            msg = (
                "gate logits must split into two equal halves, got widths "
                f"{input_gate.shape[2]} and {forget_gate.shape[2]}"
            )
            raise ValueError(msg)

        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)

        return input_gate, forget_gate
