"""
Entropy-Adaptive Dual-Stream Language Model
Anonymous ICML 2026 Submission

This module implements the IdeaGatedModel architecture with:
- System 1: Quantized base LLM with LoRA adapters (syntactic stream)
- System 2: Semantic projection head (idea stream)
- Dynamic gating mechanism with entropy-based weighting
"""

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training


class IdeaGatedModel(nn.Module):
    """
    Dual-stream language model with entropy-adaptive gating.

    System 1 (syntactic stream) is a 4-bit NF4-quantized causal LM wrapped
    with LoRA adapters on its ``q_proj`` / ``v_proj`` modules. System 2
    (idea stream) is a two-layer MLP that projects the base model's last
    hidden state to vocabulary-sized logits. ``forward`` combines the two by
    adding ``alpha * log(sigmoid(centered idea logits) + 1e-8)`` to the
    System 1 logits.

    Args:
        model_name (str): HuggingFace model identifier
        device (str): Device to load model on ('cuda' or 'cpu')
        alpha_max (float): Maximum gating intensity (default: 0.5). It is
            only stored on the instance; ``forward`` neither reads nor
            enforces it, so callers are responsible for bounding ``alpha``.
    """

    def __init__(self, model_name, device, alpha_max=0.5):
        super().__init__()
        self.device = device
        self.alpha_max = alpha_max

        print(f"Loading Base Model: {model_name}...")

        # 1. Quantization Configuration (4-bit NF4 for memory efficiency)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        # 2. Load Base Model (System 1)
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time; only use with trusted model ids.
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map={"": self.device},  # place the whole model on one device
            trust_remote_code=True,
        )

        # Prepare for LoRA training
        self.base_model = prepare_model_for_kbit_training(self.base_model)

        # 3. Add LoRA Adapters
        print("Injecting LoRA Adapters...")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=8,            # Rank
            lora_alpha=32,  # Scaling factor
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],  # Attention projections
        )
        self.base_model = get_peft_model(self.base_model, peft_config)
        self.base_model.print_trainable_parameters()

        # 4. Initialize Idea Head (System 2)
        hidden_size = self.base_model.config.hidden_size
        vocab_size = self.base_model.config.vocab_size

        print(f"Initializing Idea Head ({hidden_size} -> {vocab_size})...")

        # The head is kept in bfloat16 (same as the quantization compute
        # dtype above) and lives on the same device as the base model.
        self.idea_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, vocab_size),
        ).to(device=self.device, dtype=torch.bfloat16)

    def forward(self, input_ids, alpha=0.0, boost=1.0, return_s1=False):
        """
        Forward pass with dynamic gating.

        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            alpha (float): Gating intensity. Values <= 0 disable gating and
                the System 1 logits are returned unchanged (0.5 = balanced).
            boost (float): Scale alignment factor applied to the centered
                idea logits before the sigmoid (default 1.0 = no scaling)
            return_s1 (bool): If True, also returns the raw System 1 logits
                (e.g. for entropy calculation)

        Returns:
            tuple: ``(final_logits, idea_logits)``, or
            ``(final_logits, idea_logits, token_logits)`` when ``return_s1``
            is True.

            - final_logits (torch.Tensor): Gated output logits
            - idea_logits (torch.Tensor): Raw (un-centered) System 2 logits
            - token_logits (torch.Tensor): System 1 logits
        """
        # 1. Forward pass through Base Model (System 1)
        outputs = self.base_model(
            input_ids=input_ids,
            output_hidden_states=True,
            return_dict=True,
        )

        # Syntactic Stream
        token_logits = outputs.logits
        last_hidden_state = outputs.hidden_states[-1]

        # 2. Semantic Stream (System 2)
        # The head's weights have a fixed dtype (bfloat16 when built by
        # __init__), while the base model may hand back its hidden state in
        # another dtype. nn.Linear rejects mixed dtypes outside autocast, so
        # align the input with the head's parameters first. This is a no-op
        # when the dtypes already match.
        head_dtype = next(self.idea_head.parameters()).dtype
        idea_logits = self.idea_head(last_hidden_state.to(head_dtype))

        # 3. Dynamic Gating Logic
        if alpha > 0:
            # Mean-center idea logits over the vocabulary for stable gating
            idea_centered = idea_logits - idea_logits.mean(dim=-1, keepdim=True)

            # Variance alignment (disabled during training with boost=1.0,
            # enabled during inference with pre-computed global γ ≈ 3.42)
            if boost != 1.0:
                idea_centered = idea_centered * boost

            # Soft gating via sigmoid; the epsilon keeps log() finite when
            # the sigmoid underflows to zero. The gate is <= 0, so it only
            # ever down-weights tokens relative to System 1.
            p_idea = torch.sigmoid(idea_centered)
            gate = alpha * torch.log(p_idea + 1e-8)

            final_logits = token_logits + gate
        else:
            final_logits = token_logits

        if return_s1:
            return final_logits, idea_logits, token_logits

        return final_logits, idea_logits
