import os
import copy
from abc import ABC, abstractmethod
from typing import Optional, TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    from hr2r.model.recurrent_transformer import HR2RForCausalLM

from hr2r.model.registry import register_output_updater, get_output_updater_class, capture_init_args


class OutputUpdater(nn.Module, ABC):
    """
    Abstract interface for combining logits produced across iterations.

    Implementations receive tensors shaped (..., vocab_size). Any number of
    leading dimensions (batch, sequence, or others) may be present; they are
    expected to pass through unchanged, with the combination happening
    element-wise over the trailing vocabulary axis.
    """

    @abstractmethod
    def forward(
        self,
        logits: torch.Tensor,
        prev_logits: Optional[torch.Tensor] = None,
        iter_depth: int = 0,
        **kwargs
    ) -> torch.Tensor:
        """
        Merge this iteration's logits into the running accumulation.

        Args:
            logits: Logits from the current iteration, shape (..., vocab_size).
            prev_logits: Logits accumulated over earlier iterations, same shape
                as ``logits``; None when this is the first iteration.
            iter_depth: Zero-based index of the current iteration.
            **kwargs: Extra implementation-specific arguments.

        Returns:
            The new accumulated logits, shape (..., vocab_size), with every
            leading dimension kept exactly as in the inputs.
        """


@register_output_updater
@capture_init_args
class NoneUpdater(OutputUpdater):
    """
    Pass-through updater that performs no accumulation.

    Each iteration's logits are returned as-is, so the output is always the
    most recent logits. This is the default behavior, kept for backward
    compatibility.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        prev_logits: Optional[torch.Tensor] = None,
        iter_depth: int = 0,
        **kwargs
    ) -> torch.Tensor:
        """Ignore ``prev_logits``/``iter_depth`` and hand back ``logits`` unchanged."""
        current = logits
        return current


@register_output_updater
@capture_init_args
class AdditiveLogitsUpdater(OutputUpdater):
    """
    Output updater that keeps a running sum of logits across iterations.

    The first call (no ``prev_logits``) passes the current logits through;
    every later call adds the current logits onto the accumulated total, so
    each iteration contributes a residual correction to the output.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        prev_logits: Optional[torch.Tensor] = None,
        iter_depth: int = 0,
        **kwargs
    ) -> torch.Tensor:
        """
        Add the current logits onto the running total.

        Args:
            logits: Logits from this iteration, shape (..., vocab_size).
            prev_logits: Running total so far (..., vocab_size), or None on
                the first iteration.
            iter_depth: Zero-based iteration index (unused here).

        Returns:
            ``logits`` itself when ``prev_logits`` is None, otherwise the
            element-wise sum ``prev_logits + logits`` (..., vocab_size).
        """
        return logits if prev_logits is None else prev_logits + logits


@register_output_updater
@capture_init_args
class WeightedAdditiveLogitsUpdater(OutputUpdater):
    """
    Output updater that blends accumulated and current logits with weights.

    Produces ``alpha * prev_logits + beta * logits``. The two scalar weights
    are stored as trainable ``nn.Parameter``s when ``learnable`` is True, and
    as non-trainable registered buffers otherwise.
    """

    def __init__(
        self, 
        alpha: float = 1.0, 
        beta: float = 1.0, 
        learnable: bool = False,
        dtype: torch.dtype = torch.bfloat16
    ):
        super().__init__()
        self.learnable = learnable

        alpha_value = torch.tensor(alpha, dtype=dtype)
        beta_value = torch.tensor(beta, dtype=dtype)
        if not learnable:
            # Fixed weights: buffers move with the module but receive no gradients.
            self.register_buffer('alpha', alpha_value)
            self.register_buffer('beta', beta_value)
        else:
            # Trainable weights.
            self.alpha = nn.Parameter(alpha_value)
            self.beta = nn.Parameter(beta_value)

    def forward(
        self,
        logits: torch.Tensor,
        prev_logits: Optional[torch.Tensor] = None,
        iter_depth: int = 0,
        **kwargs
    ) -> torch.Tensor:
        """
        Blend the running logits with the current ones.

        Args:
            logits: Logits from this iteration, shape (..., vocab_size).
            prev_logits: Accumulated logits (..., vocab_size), or None on the
                first iteration.
            iter_depth: Zero-based iteration index (unused here).

        Returns:
            ``beta * logits`` on the first iteration, otherwise
            ``alpha * prev_logits + beta * logits`` (..., vocab_size).
        """
        scaled_current = self.beta * logits
        if prev_logits is None:
            return scaled_current
        return self.alpha * prev_logits + scaled_current


def save_output_updater(updater: "OutputUpdater", save_directory: str):
    """Serialize an output updater to ``<save_directory>/output_updater.bin``.

    The file holds a dict with three entries:
        "class": the updater's class name, used to look it up again on load;
        "state_dict": the module's state_dict with every tensor moved to CPU;
        "init_args": the constructor arguments stored on ``updater._init_args``
            (an empty dict when that attribute is absent).

    Args:
        updater: The updater module to save.
        save_directory: Target directory. It is created if it does not exist;
            an empty string means the current working directory.
    """
    # Constructor arguments captured by the decorator, if present.
    init_args = getattr(updater, '_init_args', {})

    # Move tensors to CPU so the saved file does not depend on the device the
    # updater happened to live on.
    state_dict = {k: v.cpu() for k, v in updater.state_dict().items()}
    data = {
        "class": updater.__class__.__name__,
        "state_dict": state_dict,
        "init_args": init_args,
    }

    # torch.save does not create missing parent directories.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
    save_path = os.path.join(save_directory, "output_updater.bin")
    print(f"Saving output updater with {len(state_dict)} parameters to {save_path}")
    torch.save(data, save_path)


def load_output_updater(load_directory: str, class_name: Optional[str] = None, init_args: Optional[dict] = None) -> OutputUpdater:
    """Load an output updater from ``<load_directory>/output_updater.bin``.

    Args:
        load_directory: Directory containing ``output_updater.bin``.
        class_name: Registered updater class to instantiate. Defaults to the
            "class" entry stored in the file.
        init_args: Constructor keyword arguments. When omitted, the "init_args"
            stored in the file are used and the saved state_dict is loaded in
            full. When given explicitly, they act as overrides: any state_dict
            entry whose key matches one of these arguments is skipped so the
            caller's value is kept.

    Returns:
        The reconstructed updater with its saved state loaded (non-strictly).

    Raises:
        FileNotFoundError: If ``output_updater.bin`` does not exist.
        ValueError: If no class name is given and none is stored in the file.
    """
    path = os.path.join(load_directory, "output_updater.bin")

    if not os.path.isfile(path):
        raise FileNotFoundError(f"No output updater found at {path}")

    data = torch.load(path, map_location="cpu")
    if class_name is None:
        class_name = data.get("class")

    if not class_name:
        raise ValueError("No output updater class specified in saved data")

    # Only caller-supplied init_args are treated as overrides. The saved
    # init_args are merely the original constructor values; filtering against
    # them would throw away trained state (e.g. learnable alpha/beta) on every
    # reload.
    if init_args is None:
        override_keys = set()
        init_args = data.get("init_args") or {}
    else:
        override_keys = set(init_args)

    # Create updater instance using registry with proper arguments
    updater_class = get_output_updater_class(class_name)
    updater = updater_class(**init_args)

    state_dict = data.get("state_dict") or {}
    if state_dict:
        filtered_state_dict = {}
        for key, value in state_dict.items():
            if key in override_keys:
                print(f"Skipping state_dict key '{key}' as it conflicts with init_args")
            else:
                filtered_state_dict[key] = value

        print(f"Loading output updater state dict with {len(filtered_state_dict)} parameters (filtered from {len(state_dict)})")
        if filtered_state_dict:
            # strict=False tolerates the keys skipped above (they show up as
            # missing); surface unexpected keys, which usually mean a class
            # mismatch, instead of dropping them silently.
            result = updater.load_state_dict(filtered_state_dict, strict=False)
            if result.unexpected_keys:
                print(f"Ignoring unexpected output updater state_dict keys: {result.unexpected_keys}")

    return updater
