import torch
import torch.nn as nn


class AstroNormWithLayerNorm(nn.Module):
    """Weight a memory tensor by a decaying retention factor and layer-normalize it.

    Each call to :meth:`forward` multiplies ``mem`` by a retention weight that
    depends on the segment index, adds the weighted memory to a running sum,
    and applies ``nn.LayerNorm(d_model)`` to both the weighted memory and the
    running sum. The weights are normalized so that they sum to 1 over the
    integer segments ``0 .. total_segments``.

    Results are not returned; they are exposed through the attributes
    ``memory_retention`` and ``memory_retention_sum``.

    Args:
        d_model: Size of the last dimension of ``mem`` (passed to ``nn.LayerNorm``).
        total_segments: Largest segment index; must be positive.

    Raises:
        ValueError: If ``total_segments`` is not positive.
    """

    # Constant offset of the retention curve; numerically about exp(-3).
    _RETENTION_OFFSET = 0.049787

    def __init__(self, d_model, total_segments):
        super().__init__()
        # x_max / 2 is used as a divisor and arange(0, total_segments + 1) must
        # be non-empty, otherwise the weights degenerate to inf/nan.
        if total_segments <= 0:
            raise ValueError(
                f"total_segments must be positive, got {total_segments!r}"
            )
        self.layer_norm = nn.LayerNorm(d_model)
        self.total_segments = total_segments
        # Layer-normalized running sum of the weighted memory (set by forward).
        self.memory_retention_sum = None
        # Layer-normalized weighted memory of the latest call (set by forward).
        self.memory_retention = None

    def calculate_area(self, device):
        """Return the sum of the raw retention curve over segments 0..total_segments."""
        integer_x_values = torch.arange(0, self.total_segments + 1, device=device)
        return torch.sum(
            self.memory_retention_function(integer_x_values, self.total_segments)
        )

    def scaling_factor(self, device):
        """Return the factor that makes the retention weights sum to 1."""
        return 1 / self.calculate_area(device)

    def memory_retention_function(self, x, x_max):
        """Evaluate the unnormalized retention curve at ``x``.

        Computes ``offset + exp(-x / s) - exp(-(x / s + 1))`` with
        ``s = x_max / 2``.
        """
        scaled_x = x / (x_max / 2)
        return (
            self._RETENTION_OFFSET
            + torch.exp(-scaled_x)
            - torch.exp(-(scaled_x + 1))
        )

    def scaled_memory_retention(self, x, x_max, device):
        """Evaluate the retention curve at ``x`` scaled by :meth:`scaling_factor`."""
        k = self.scaling_factor(device)
        return k * self.memory_retention_function(x, x_max)

    def forward(self, mem, segment):
        """Update ``memory_retention`` and ``memory_retention_sum`` for one segment.

        Args:
            mem: Memory tensor whose last dimension equals ``d_model``.
                ``requires_grad`` is switched on in place if it is off.
            segment: Segment index, as a Python number or a single-element
                tensor. A value of 0 resets the running sum.

        Returns:
            None. Read the results from ``self.memory_retention`` and
            ``self.memory_retention_sum``.

        Note:
            The stored running sum is the layer-normalized sum, and the next
            call accumulates on top of that normalized value.
        """
        device = mem.device

        if not isinstance(segment, torch.Tensor):
            segment = torch.tensor(segment, dtype=mem.dtype, device=device)
        elif not segment.is_floating_point():
            # Integer tensors cannot require gradients; promote them first.
            segment = segment.to(device=device, dtype=mem.dtype)

        if not mem.requires_grad:
            mem.requires_grad_(True)
        if not segment.requires_grad:
            segment.requires_grad_(True)

        retention_factor = self.scaled_memory_retention(
            segment, self.total_segments, device
        )
        self.memory_retention = retention_factor * mem

        # Module.to() moves the parameters in place.
        self.layer_norm.to(device)

        # Start from zeros on segment 0, and also when no sum exists yet so a
        # sequence that does not begin at segment 0 no longer fails.
        if self.memory_retention_sum is None or segment.item() == 0:
            self.memory_retention_sum = torch.zeros_like(mem)

        # Out-of-place add: the tensor exposed by the previous call stays intact.
        self.memory_retention_sum = self.memory_retention_sum + self.memory_retention
        self.memory_retention = self.layer_norm(self.memory_retention)
        self.memory_retention_sum = self.layer_norm(self.memory_retention_sum)

        return None


# Example usage and testing
if __name__ == "__main__":
    d_model = 512
    total_segments = 10

    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create an instance of the AstroNormWithLayerNorm class
    astro_norm_layer = AstroNormWithLayerNorm(d_model, total_segments)

    # Create dummy memory tensor on the selected device
    mem = torch.randn(5, d_model).to(device)  # Example memory tensor

    # Simulate processing multiple segments sequentially in a loop
    for segment in [0, 5, 10]:
        astro_norm_layer(mem, segment)
        print(
            f"Output for segment {segment}:\nMemory Retention:\n{astro_norm_layer.memory_retention}\nMemory Retention Sum:\n{astro_norm_layer.memory_retention_sum}\n")
