from __future__ import annotations

import math
import os
import random
import re
from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaModel,
)

from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm

from src.utils import CfgNode, print0

class StackMemory(nn.Module):
    """Differentiable multi-head stack that is updated token by token.

    For every timestep the module projects the hidden state to one key per
    head, predicts a soft (push, pop, no-op) distribution per head, blends the
    three candidate stacks accordingly, and reads the stack back with a
    softmax gate. The read-out is added to the input hidden state.

    Shapes used throughout (B = batch, T = sequence length, H = heads,
    S = stack slots, D = head dim):
        hidden_states: (B, T, H * D)
        stack:         (B, H, S, D)   slot 0 is the top of the stack
        mask:          (B, H, S)      soft occupancy of each slot, in [0, 1]
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.config = config
        self.num_mem_heads = config.num_attention_heads
        self.stack_slots = getattr(config, 'stack_slots', 4)
        self.head_dim = config.hidden_size // self.num_mem_heads
        # Registered for checkpoint compatibility; not read by forward().
        self.temperature = nn.Parameter(torch.ones(1))

        # Per-head key written to the stack on a push.
        self.k_proj = nn.Linear(config.hidden_size, self.num_mem_heads * self.head_dim)
        # Three logits (push, pop, no-op) per head.
        self.action_head = nn.Linear(config.hidden_size, 3 * self.num_mem_heads)
        # Scalar read-out score for every stack slot.
        self.gate_proj = nn.Linear(self.head_dim, 1)

        # Learnable scale of the memory read-out before the residual add.
        self.res_weight = nn.Parameter(torch.ones(1))

    def _update_stack(self, stack, mask, actions, k_t):
        """Apply one soft push/pop/no-op step to the stack and its mask.

        Args:
            stack: (B, H, S, D) current stack contents.
            mask: (B, H, S) current soft slot occupancy.
            actions: (B, H, 3) probabilities ordered (push, pop, no-op).
            k_t: key to push, either (B, H, D) or (B, H, 1, D).

        Returns:
            Tuple ``(new_stack, new_mask)`` with the shapes of ``stack`` and
            ``mask``.
        """
        # Accept the key with or without an explicit slot axis so the tensor
        # handed to torch.cat always has the same rank as the stack.
        if k_t.dim() == stack.dim() - 1:
            k_t = k_t.unsqueeze(2)

        a_push, a_pop, a_noop = actions.unbind(-1)
        # (B, H) -> (B, H, 1, 1) so the weights broadcast over slots and dims.
        a_push = a_push.unsqueeze(-1).unsqueeze(-1)
        a_pop = a_pop.unsqueeze(-1).unsqueeze(-1)
        a_noop = a_noop.unsqueeze(-1).unsqueeze(-1)

        # Push: new key becomes slot 0, everything shifts down, last slot drops.
        push_stack = torch.cat([k_t, stack[:, :, :-1]], dim=2)
        # Pop: everything shifts up, the freed bottom slot is zero-filled.
        pop_stack = torch.cat([stack[:, :, 1:], torch.zeros_like(stack[:, :, :1])], dim=2)

        new_stack = push_stack * a_push + pop_stack * a_pop + stack * a_noop

        # The occupancy mask is shifted exactly like the stack contents.
        push_mask = torch.cat([torch.ones_like(mask[:, :, :1]), mask[:, :, :-1]], dim=2)
        pop_mask = torch.cat([mask[:, :, 1:], torch.zeros_like(mask[:, :, :1])], dim=2)
        new_mask = (
            push_mask * a_push.squeeze(-1)
            + pop_mask * a_pop.squeeze(-1)
            + mask * a_noop.squeeze(-1)
        ).clamp(0, 1)

        return new_stack, new_mask

    def forward(
        self,
        hidden_states: Tensor,
        stack_memory: Tensor,
        stack_mask: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Run the stack over a sequence and add its read-out to the input.

        Args:
            hidden_states: (B, T, hidden_size) input activations.
            stack_memory: (B, H, S, D) initial stack contents.
            stack_mask: (B, H, S) initial soft slot occupancy.

        Returns:
            Tuple ``(new_hidden, final_stack, final_mask)`` where
            ``new_hidden`` has the shape of ``hidden_states``.
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Nothing to process: torch.stack would fail on an empty list.
        if seq_len == 0:
            return hidden_states, stack_memory, stack_mask

        # The projections do not depend on the recurrent state, so compute
        # them once for the whole sequence instead of once per timestep.
        keys = self.k_proj(hidden_states).view(
            batch_size, seq_len, self.num_mem_heads, self.head_dim
        )
        action_logits = self.action_head(hidden_states) / math.sqrt(self.head_dim)
        action_probs = torch.softmax(
            action_logits.view(batch_size, seq_len, self.num_mem_heads, 3), dim=-1
        )

        all_outputs = []
        current_stack = stack_memory
        current_mask = stack_mask

        for t in range(seq_len):
            current_stack, current_mask = self._update_stack(
                current_stack, current_mask, action_probs[:, t], keys[:, t]
            )

            # Score each slot from its occupancy-weighted contents. A fully
            # empty slot therefore scores exactly the gate_proj bias.
            scores = self.gate_proj(current_stack * current_mask.unsqueeze(-1)).squeeze(-1)
            gate = torch.softmax(scores, dim=-1)

            weighted_sum = (current_stack * gate.unsqueeze(-1)).sum(dim=2)
            # Normalise by total occupancy; the clamp avoids division by zero
            # while the stack is still empty.
            sum_mask = current_mask.sum(dim=-1, keepdim=True).clamp(min=1e-6)
            output_t = weighted_sum / sum_mask

            # Merge heads back to hidden_size and add the residual.
            output_t = output_t.reshape(batch_size, -1) * self.res_weight + hidden_states[:, t]
            all_outputs.append(output_t)

        new_hidden = torch.stack(all_outputs, dim=1)
        return new_hidden, current_stack, current_mask

class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with a StackMemory read/write step before attention.

    The layer normalises the incoming hidden states, runs them through the
    stack memory, adds the memory output back to the hidden states, and then
    applies the usual attention and MLP sub-blocks with their residuals.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
        super().__init__(config, layer_idx)
        self.stack_memory = StackMemory(config)
        self.memory_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply memory, self-attention and MLP to ``hidden_states``.

        Args:
            hidden_states: input activations; the first dimension is the batch.
            attention_mask: forwarded unchanged to ``self.self_attn``.
            position_ids: forwarded unchanged to ``self.self_attn``.
            memory: stack contents handed to ``StackMemory``; zero-initialised
                when omitted.
            memory_mask: stack occupancy handed to ``StackMemory``;
                zero-initialised when omitted.
            position_embeddings: forwarded unchanged to ``self.self_attn``.
            **kwargs: accepted for signature compatibility and ignored.

        Returns:
            Tuple ``(hidden_states, new_memory, new_mask)``.
        """
        # The defaults are None, but StackMemory needs real tensors; start
        # from an empty stack rather than failing on ``None.size()``.
        stack_module = self.stack_memory
        batch_size = hidden_states.size(0)
        if memory is None:
            memory = hidden_states.new_zeros(
                batch_size,
                stack_module.num_mem_heads,
                stack_module.stack_slots,
                stack_module.head_dim,
            )
        if memory_mask is None:
            memory_mask = hidden_states.new_zeros(
                batch_size,
                stack_module.num_mem_heads,
                stack_module.stack_slots,
            )

        mem_output, new_memory, new_mask = stack_module(
            self.memory_norm(hidden_states),
            memory,
            memory_mask,
        )

        # Memory branch residual.
        hidden_states = hidden_states + mem_output

        # Self-attention sub-block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )
        # The attention module returns a tuple whose length differs between
        # transformers releases; the attention output is always element 0.
        attn_output = attn_outputs[0]

        hidden_states = residual + attn_output

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, new_memory, new_mask

class CustomLlamaModel(LlamaModel):
    """LlamaModel whose decoder layers are ``CustomLlamaDecoderLayer``.

    A single stack memory (contents + occupancy mask) is threaded through the
    layers in order: each layer receives the memory returned by the previous
    one, and the memory after the last layer is returned to the caller.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Replace the stock decoder layers with the memory-augmented ones.
        self.layers = nn.ModuleList([
            CustomLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ])

        self.memory = None
        self.memory_mask = None

    def _build_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        batch_size: int,
        seq_length: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Turn an optional 2-D padding mask into a 4-D additive causal mask.

        A mask that is already 4-D is returned untouched. For the
        ``flash_attention_2`` implementation the input is also returned as is.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            return attention_mask
        if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
            # NOTE(review): flash-attention is expected to take the 2-D padding
            # mask and handle causality itself - confirm for the installed version.
            return attention_mask

        min_value = torch.finfo(dtype).min
        # Strict upper triangle (future positions) gets the large negative value.
        causal = torch.full(
            (seq_length, seq_length), min_value, dtype=dtype, device=device
        ).triu(diagonal=1)
        causal = causal[None, None, :, :].expand(batch_size, 1, seq_length, seq_length)
        if attention_mask is not None:
            # Zero entries in the 2-D mask mark keys that must not be attended to.
            padded_keys = attention_mask[:, None, None, :].to(device) == 0
            causal = causal.masked_fill(padded_keys, min_value)
        return causal

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[BaseModelOutputWithPast, torch.Tensor, torch.Tensor]:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: (batch, seq) token ids. Takes precedence over
                ``inputs_embeds`` when both are given.
            attention_mask: optional 2-D padding mask (non-zero = keep) or a
                ready-made 4-D additive mask.
            position_ids: optional position indices; defaults to
                ``0 .. seq - 1`` shared by the whole batch.
            memory: optional initial stack contents; zeros when omitted.
            memory_mask: optional initial stack occupancy; zeros when omitted.
            inputs_embeds: optional pre-computed embeddings, used only when
                ``input_ids`` is None.
            **kwargs: accepted for signature compatibility and ignored.

        Returns:
            Tuple ``(output, memory, memory_mask)`` where ``output`` is a
            ``BaseModelOutputWithPast`` carrying only ``last_hidden_state``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
        elif inputs_embeds is None:
            raise ValueError("You must specify either input_ids or inputs_embeds")

        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device
        dtype = inputs_embeds.dtype

        if position_ids is None:
            position_ids = torch.arange(seq_length, device=device).unsqueeze(0)

        num_heads = self.config.num_attention_heads
        head_dim = self.config.hidden_size // num_heads
        # Same fallback as StackMemory, so a plain LlamaConfig also works.
        stack_slots = getattr(self.config, "stack_slots", 4)

        # Allocate in the activation dtype so mixed-precision models do not
        # mix float32 memory with lower-precision hidden states.
        if memory is None:
            memory = torch.zeros(
                batch_size, num_heads, stack_slots, head_dim, device=device, dtype=dtype
            )
        if memory_mask is None:
            memory_mask = torch.zeros(
                batch_size, num_heads, stack_slots, device=device, dtype=dtype
            )

        # (cos, sin) rotary tables shared by every layer.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
        layer_attention_mask = self._build_attention_mask(
            attention_mask, batch_size, seq_length, dtype, device
        )

        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states, memory, memory_mask = layer(
                hidden_states=hidden_states,
                attention_mask=layer_attention_mask,
                position_ids=position_ids,
                memory=memory,
                memory_mask=memory_mask,
                position_embeddings=position_embeddings,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        ), memory, memory_mask

class LlamaMem(LlamaForCausalLM):
    """Causal language model built on ``CustomLlamaModel``.

    The forward pass returns the language-model output together with the
    stack memory and its mask as produced by the backbone.
    """

    def __init__(self, config: LlamaConfig, tokenizer: AutoTokenizer) -> None:
        super().__init__(config)
        self.tokenizer = tokenizer
        # Swap the stock backbone for the memory-augmented one, then re-run
        # the post-init hook on the final module tree.
        self.model = CustomLlamaModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        memory: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[CausalLMOutputWithPast, torch.Tensor, torch.Tensor]:
        """Compute next-token logits and, when labels are given, the LM loss.

        Args:
            input_ids: token ids forwarded to the backbone.
            attention_mask: forwarded to the backbone.
            position_ids: forwarded to the backbone.
            memory: initial stack contents forwarded to the backbone.
            memory_mask: initial stack occupancy forwarded to the backbone.
            labels: optional target ids with the same shape as the logits'
                leading dimensions; positions equal to -100 are ignored.
            **kwargs: forwarded to the backbone.

        Returns:
            Tuple ``(output, memory, memory_mask)`` where ``output`` is a
            ``CausalLMOutputWithPast``; ``output.loss`` is None unless
            ``labels`` is provided.
        """
        outputs, memory, memory_mask = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            memory=memory,
            memory_mask=memory_mask,
            **kwargs,
        )

        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1, so drop the last logit and the
            # first label before flattening for cross-entropy.
            shift_logits = logits[..., :-1, :].contiguous().float()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        ), memory, memory_mask

    def configure_optimizers(self, train_config):
        """Return an AdamW optimizer over all parameters.

        The learning rate is read from ``train_config.learning_rate``.
        """
        return torch.optim.AdamW(self.parameters(), lr=train_config.learning_rate)

class CustomLlamaConfig(LlamaConfig):
    """LlamaConfig carrying three extra memory-related attributes.

    After the base-class initialisation, the following attributes are set from
    the keyword arguments, falling back to the listed default when absent:
        stack_slots   -> 4
        memory_layers -> 4
        memory_init   -> "zeros"
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # (attribute name, default) pairs, assigned in this order.
        memory_defaults = (
            ("stack_slots", 4),
            ("memory_layers", 4),
            ("memory_init", "zeros"),
        )
        for attr_name, fallback in memory_defaults:
            setattr(self, attr_name, kwargs.get(attr_name, fallback))

if __name__ == "__main__":
    # Smoke test: build a small model and push one random batch through it.
    tokenizer = AutoTokenizer.from_pretrained("/mnt/bd/arxivdata/TinyLlama-1.1B-Chat-v1.0")
    config = CustomLlamaConfig(
        num_hidden_layers=12,
        num_attention_heads=8,
        hidden_size=512,
        stack_slots=8,
        memory_layers=4,
    )
    model = LlamaMem(config, tokenizer)

    # Draw ids from the configured vocabulary so they always index a valid
    # embedding row, whatever vocab_size the config ends up with.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    # Only the output shape is inspected, so skip gradient tracking.
    with torch.no_grad():
        outputs, memory, memory_mask = model(input_ids)
    print(outputs.logits.shape)