"""
Aggregate projector outputs for the unified memory
"""

import torch
import torch.nn as nn
from torch import Tensor
from transformers import Cache, DynamicCache
import copy
from typing import Optional, Tuple, List
import math


from rosetta.utils.registry import (
    create_registry,
    capture_init_args,
    save_object,
    load_object,
)

# Aggregator registry (case-insensitive, kept that way for backward compatibility).
# `create_registry` returns three objects, bound here as:
#   AGGREGATOR_REGISTRY  - presumably the underlying registry container (verify in rosetta.utils.registry)
#   register_model       - applied as a class decorator to the aggregator classes in this module
#   get_aggregator_class - handed to `load_object` by `load_aggregator`; presumably resolves a name to a class
AGGREGATOR_REGISTRY, register_model, get_aggregator_class = create_registry(
    "aggregator", case_insensitive=True
)


class Aggregator(nn.Module):
    """
    Abstract base for unified-memory aggregators.

    An aggregator reduces the outputs of several projectors to a single
    (key, value) pair. Concrete subclasses override `forward`.
    """

    def forward(
        self,
        source_kv_list: List[Tuple[Tensor, Tensor]],
        target_kv: Tuple[Tensor, Tensor],
        projected_kv_list: List[Tuple[Tensor, Tensor]],
    ) -> Tuple[Tensor, Tensor]:
        """
        Combine the projected key-value tensors into one key-value pair.

        Args:
            source_kv_list: List of (key, value) tensors, each (..., D_s), with arbitrary leading dimensions
            target_kv: Tuple of (key, value) tensors, each (..., D_t), with arbitrary leading dimensions
            projected_kv_list: List of (key, value) tensors, each (..., D_t), with arbitrary leading dimensions
        Returns:
            Tuple of (key, value) tensors, each (..., D_t), with the same leading dimensions as the input
        Raises:
            NotImplementedError: always; the base class provides no implementation
        """
        message = "Subclasses must implement forward method"
        raise NotImplementedError(message)


@register_model
@capture_init_args
class FirstAggregator(Aggregator):
    """
    Pass-through aggregator that keeps only the first projected KV pair.
    A sensible default when a single projector is configured.
    """
    def __init__(self):
        super().__init__()

    def forward(
        self,
        source_kv_list: List[Tuple[Tensor, Tensor]],
        target_kv: Tuple[Tensor, Tensor],
        projected_kv_list: List[Tuple[Tensor, Tensor]],
    ) -> Tuple[Tensor, Tensor]:
        """
        Return the first projected (key, value) pair.

        Args:
            source_kv_list: Unused by this aggregator
            target_kv: Returned unchanged when `projected_kv_list` is empty
            projected_kv_list: Candidate (key, value) pairs; only index 0 is read
        Returns:
            `projected_kv_list[0]`, or `target_kv` if the list is empty
        """
        has_projection = len(projected_kv_list) > 0
        return projected_kv_list[0] if has_projection else target_kv


@register_model
@capture_init_args
class MeanAggregator(Aggregator):
    """
    Element-wise arithmetic mean over all projected KV pairs.
    Falls back to target_kv when no projected pairs are supplied.
    """
    def __init__(self):
        super().__init__()

    def forward(
        self,
        source_kv_list: List[Tuple[Tensor, Tensor]],
        target_kv: Tuple[Tensor, Tensor],
        projected_kv_list: List[Tuple[Tensor, Tensor]],
    ) -> Tuple[Tensor, Tensor]:
        """
        Average the projected keys and the projected values separately.

        Args:
            source_kv_list: Unused by this aggregator
            target_kv: Returned unchanged when `projected_kv_list` is empty
            projected_kv_list: (key, value) pairs to average
        Returns:
            Tuple (mean_key, mean_value), or `target_kv` if the list is empty
        """
        if len(projected_kv_list) == 0:
            return target_kv

        # Seed the running totals with the first pair, then fold in the rest
        # left-to-right so the additions happen in list order.
        pairs = iter(projected_kv_list)
        key_total, value_total = next(pairs)
        for next_key, next_value in pairs:
            key_total = key_total + next_key
            value_total = value_total + next_value

        n_pairs = len(projected_kv_list)
        return (key_total / n_pairs, value_total / n_pairs)


@register_model
@capture_init_args
class WeightedAggregator(Aggregator):
    """
    Learnable softmax-weighted aggregator over multiple projected KV pairs.

    Holds one learnable gate logit per option (`num_options` in total) and
    soft-selects among the projected KV pairs with a temperature-scaled
    softmax. If fewer than `num_options` inputs are provided at runtime, only
    the first k gate logits are used; providing more inputs than there are
    gate logits is an error.

    NOTE: `forward` always scales the logits by `final_temperature`.
    `update_temperature` only maintains the `current_temperature` buffer,
    which `forward` does not read.
    """
    def __init__(self,
                 num_options: int,
                 final_temperature: float = 1.0,
                 anneal_steps: int = 440,
                 initial_temperature: Optional[float] = None,
                 dtype=torch.float32):
        """
        Args:
            num_options: Number of gate logits to allocate; must be >= 1
            final_temperature: Softmax temperature used by `forward`, and the end point of the annealing schedule
            anneal_steps: Number of steps over which `update_temperature` anneals
            initial_temperature: Start of the annealing schedule; defaults to `final_temperature` when None
            dtype: dtype of the `gate_logits` parameter
        Raises:
            ValueError: if `num_options` is smaller than 1
        """
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would let an invalid size through unchecked.
        if num_options < 1:
            raise ValueError("num_options must be >= 1")
        self.num_options = num_options

        # Annealing configuration
        self.anneal_steps = anneal_steps
        self.initial_temperature = float(initial_temperature if initial_temperature is not None else final_temperature)
        self.final_temperature = float(final_temperature)

        # Track the current effective temperature as a buffer for easy device moves and checkpointing
        self.register_buffer("current_temperature", torch.tensor(self.initial_temperature, dtype=torch.float32))
        # Zero-initialised logits give uniform softmax weights before any training.
        self.gate_logits = nn.Parameter(torch.zeros(num_options, dtype=dtype))

    def update_temperature(self, step: int):
        """
        Update the `current_temperature` buffer using an exponential annealing schedule.

        The temperature is `initial_temperature * (final_temperature / initial_temperature) ** ratio`
        where `ratio = step / anneal_steps` clamped to [0, 1] (`anneal_steps` is floored at 1).

        Args:
            step: Current training step; negative values are treated as 0
        Raises:
            ZeroDivisionError: if `initial_temperature` is 0
        """
        ratio = min(max(float(step), 0.0) / max(float(self.anneal_steps), 1.0), 1.0)
        # Exponential (geometric) interpolation between the two temperatures
        temp = self.initial_temperature * ((self.final_temperature / self.initial_temperature) ** ratio)
        self.current_temperature.fill_(float(temp))

    def forward(self, source_kv_list: List[Tuple[Tensor, Tensor]], target_kv: Tuple[Tensor, Tensor], projected_kv_list: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]:
        """
        Blend the projected KV pairs with softmax weights derived from the gate logits.

        Args:
            source_kv_list: Unused by this aggregator
            target_kv: Returned unchanged when `projected_kv_list` is empty
            projected_kv_list: (key, value) pairs to blend; at most as many as there are gate logits
        Returns:
            Tuple (weighted_key, weighted_value), or `target_kv` if the list is empty
        Raises:
            IndexError: if more pairs are supplied than there are gate logits
        """
        k = len(projected_kv_list)
        if k == 0:
            return target_kv

        # Fail fast with a descriptive message; otherwise the loop below would
        # hit an opaque out-of-bounds index on `weights` part-way through.
        available = self.gate_logits.shape[0]
        if k > available:
            raise IndexError(
                f"WeightedAggregator received {k} projected KV pairs but only has "
                f"{available} gate logits (num_options={self.num_options})"
            )

        # Use fixed temperature defined by final_temperature (the annealed
        # `current_temperature` buffer is intentionally not consulted here).
        # The clamp keeps a zero or negative temperature from dividing by <= 0.
        effective_temp = float(self.final_temperature)
        logits = self.gate_logits[:k] / max(effective_temp, 1e-6)
        weights = torch.softmax(logits, dim=0)

        # Weighted sum of keys and values
        weighted_key = None
        weighted_value = None
        for idx, (key_tensor, value_tensor) in enumerate(projected_kv_list):
            # Give the scalar weight one singleton axis per key dimension so it
            # broadcasts over the whole tensor.
            # NOTE(review): as a dimensioned tensor the weight takes part in dtype
            # promotion, so e.g. float32 logits with half-precision KV tensors yield
            # a float32 result — confirm callers expect this.
            w = weights[idx].reshape((1,) * key_tensor.dim())
            if weighted_key is None:
                weighted_key = w * key_tensor
                weighted_value = w * value_tensor
            else:
                weighted_key = weighted_key + w * key_tensor
                weighted_value = weighted_value + w * value_tensor

        return (weighted_key, weighted_value)


# Convenience helpers for persistence
def save_aggregator(obj: Aggregator, file_path: str) -> None:
    """
    Persist an aggregator by delegating to `save_object` from `rosetta.utils.registry`.

    Args:
        obj: Aggregator instance to save
        file_path: Destination path, passed through unchanged to `save_object`

    NOTE(review): the on-disk format is defined entirely by `save_object`;
    presumably it relies on the init args recorded by `capture_init_args` — confirm.
    """
    save_object(obj, file_path)


def load_aggregator(file_path: str, override_args: Optional[dict] = None) -> Aggregator:
    """
    Load an aggregator by delegating to `load_object` from `rosetta.utils.registry`.

    Args:
        file_path: Path to read from, passed through unchanged to `load_object`
        override_args: Optional dict forwarded to `load_object`; presumably replaces
            saved constructor arguments — verify against `load_object`
    Returns:
        Whatever `load_object` returns, using `get_aggregator_class` as the class lookup
    """
    return load_object(file_path, get_aggregator_class, override_args)