"""
Modular IntervalNet Architecture

A complete implementation of the IntervalNet continual learning architecture
that can be used in a modular fashion. This file contains all the essential
components from the original codebase combined into a single, self-contained module.

The architecture includes built-in loss computation, so you can simply call 
model.compute_loss(x, y) to get the interval-specific loss.

Usage:
    from modular_intervalnet import IntervalNet
    
    # Create model
    model = IntervalNet.create_mlp(
        input_size=784, hidden_dim=512, output_classes=10,
        heads=1, radius_multiplier=1.0, max_radius=1.0
    )
    
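    # Training loop for one task. before_task() / before_epoch() drive the
    # mode switches: task_id >= 1 freezes the previous task and enters
    # CONTRACTION_SHIFT, and from epoch (num_epochs - contraction_epochs)
    # on the model trains in CONTRACTION_SCALE.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    model.before_task(task_id)
    for epoch in range(num_epochs):
        model.before_epoch(epoch, num_epochs)
        for batch_idx, (data, targets) in enumerate(dataloader):
            loss = model.compute_loss(data, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.after_step()  # Clamp radii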
"""

import math
from abc import ABC
from enum import Enum
from typing import cast, Optional, Any, Dict, List, Tuple, Sequence
from collections import deque
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.optim import Optimizer


# Smallest value an interval radius may take; every radius clamp in this
# module uses it as its lower limit, so radii are never negative.
RADIUS_MIN: float = 0.0


class Mode(Enum):
    """Training phase of an IntervalNet model.

    The interval layers use the phase to decide which tensors to unfreeze:

    * ``VANILLA`` -- the interval centres (weights and biases).
    * ``EXPANSION`` -- the interval radii (IntervalLinear only).
    * ``CONTRACTION_SHIFT`` -- the shifts of the interval centres.
    * ``CONTRACTION_SCALE`` -- the scales that shrink the radii.
    """

    VANILLA = 0            # centre training
    EXPANSION = 1          # radius training
    CONTRACTION_SHIFT = 2  # centre-shift training
    CONTRACTION_SCALE = 3  # radius-scale training


class IntervalModuleWithWeights(nn.Module, ABC):
    """Common interface of layers whose parameters are intervals.

    IntervalNet finds its interval layers by checking for this type and
    forwards mode switches, task freezing and radius clamping to them.
    Every hook below must be overridden; the defaults only raise
    ``NotImplementedError``. (``__init__`` is inherited from nn.Module.)
    """

    def switch_mode(self, mode: Mode) -> None:
        """Make the tensors that ``mode`` trains trainable and freeze the rest."""
        raise NotImplementedError

    def freeze_task(self) -> None:
        """Fold the learned contraction into the base interval parameters."""
        raise NotImplementedError

    def clamp_radii(self) -> None:
        """Clamp the raw radius parameters back into their valid range."""
        raise NotImplementedError


class PointLinear(nn.Module):
    """Linear layer with point (non-interval) weights applied to interval inputs.

    The weights are single values, but the input still carries
    (lower, middle, upper) bounds. They are propagated with interval
    arithmetic by splitting the weight matrix into its positive and negative
    parts. All inputs must be non-negative.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = Parameter(torch.empty((out_features, in_features)))
        self.bias = Parameter(torch.empty(out_features))
        self.reset_parameters()
        self.mode = Mode.VANILLA

    def reset_parameters(self) -> None:
        """Kaiming-uniform weights (a=sqrt(5)); bias uniform in +-1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        with torch.no_grad():
            fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
            limit = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Map bounds named ("N", "bounds", "features") to output bounds with the same names.

        The "bounds" dimension holds (lower, middle, upper). Assertion errors
        are raised for negative inputs or unordered bounds.
        """
        x = x.refine_names("N", "bounds", "features")
        assert (x.rename(None) >= 0.0).all(), "All input features must be non-negative."

        lo, mid, hi = (part.rename(None) for part in x.unbind("bounds"))
        assert (lo <= mid).all(), "Lower bound must be less than or equal to middle bound."
        assert (mid <= hi).all(), "Middle bound must be less than or equal to upper bound."

        # Transposed positive / negative parts of the weight matrix.
        pos_t = self.weight.clamp(min=0).t()
        neg_t = self.weight.clamp(max=0).t()

        # Positive weights take the same-side input bound, negative weights the opposite one.
        out_lo = lo @ pos_t + hi @ neg_t + self.bias
        out_mid = mid @ pos_t + mid @ neg_t + self.bias
        out_hi = hi @ pos_t + lo @ neg_t + self.bias

        assert (out_lo <= out_mid).all(), "Lower bound must be less than or equal to middle bound."
        assert (out_mid <= out_hi).all(), "Middle bound must be less than or equal to upper bound."

        return torch.stack((out_lo, out_mid, out_hi), dim=1).refine_names("N", "bounds", "features")

    def switch_mode(self, mode: Mode) -> None:
        """Record ``mode``; weight and bias stay trainable only in VANILLA."""
        self.mode = mode
        trainable = mode == Mode.VANILLA
        for param in (self.weight, self.bias):
            param.requires_grad_(trainable)


class IntervalLinear(IntervalModuleWithWeights):
    """Fully connected layer whose weights and biases are intervals.

    Each weight is a centre (``weight``) plus a non-negative ``radius``.
    During the contraction phases the interval is re-parameterised by a
    ``shift`` of the centre in (-1, 1), measured in radii, and a ``scale`` in
    (0, 1) that shrinks the radius. ``freeze_task`` folds both into
    ``weight``/``_radius``. ``switch_mode`` decides which tensors are
    trainable.

    ``forward`` propagates (lower, middle, upper) input bounds through the
    layer with interval arithmetic and requires non-negative inputs.
    """

    def __init__(
        self, in_features: int, out_features: int,
        radius_multiplier: float, max_radius: float, bias: bool = True,
        normalize_shift: bool = True, normalize_scale: bool = False,
        scale_init: float = 5.0
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.radius_multiplier = radius_multiplier
        self.max_radius = max_radius
        self.normalize_shift = normalize_shift
        self.normalize_scale = normalize_scale
        self.scale_init = scale_init

        assert self.radius_multiplier > 0
        assert self.max_radius > 0

        # Interval centre; the only weight tensor trained in VANILLA mode.
        self.weight = Parameter(torch.empty((out_features, in_features)))

        # Raw interval parameters; the radius/shift/scale properties transform them.
        self._radius = Parameter(torch.empty((out_features, in_features)), requires_grad=False)
        self._shift = Parameter(torch.empty((out_features, in_features)), requires_grad=False)
        self._scale = Parameter(torch.empty((out_features, in_features)), requires_grad=False)

        # Bias centre and its raw interval parameters (same layout as the weights).
        if bias:
            self.bias = Parameter(torch.empty(out_features), requires_grad=True)
            self._bias_radius = Parameter(torch.empty_like(self.bias), requires_grad=False)
            self._bias_shift = Parameter(torch.empty_like(self.bias), requires_grad=False)
            self._bias_scale = Parameter(torch.empty_like(self.bias), requires_grad=False)
        else:
            self.bias = None

        self.mode: Mode = Mode.VANILLA
        self.reset_parameters()

    def radius_transform(self, params: Tensor):
        """Map raw radius parameters to radii.

        Multiplies by ``radius_multiplier`` and clamps to
        [RADIUS_MIN, max_radius + 0.1].
        """
        # NOTE(review): IntervalConv2d clamps at max_radius without the 0.1
        # slack; confirm which upper limit is intended.
        return (params * torch.tensor(self.radius_multiplier)).clamp(min=RADIUS_MIN, max=self.max_radius + 0.1)

    @property
    def radius(self) -> Tensor:
        """Weight radii (transformed ``_radius``)."""
        return self.radius_transform(self._radius)

    @property
    def bias_radius(self) -> Tensor:
        """Bias radii (transformed ``_bias_radius``)."""
        return self.radius_transform(self._bias_radius)

    @property
    def shift(self) -> Tensor:
        """Contracted interval middle shift in (-1, 1), in units of the radius.

        With ``normalize_shift`` the raw value is divided by the radius
        (floored at 1e-8) before the tanh.
        """
        if self.normalize_shift:
            eps = torch.tensor(1e-8).to(self._shift.device)
            return (self._shift / torch.max(self.radius, eps)).tanh()
        else:
            return self._shift.tanh()

    @property
    def bias_shift(self) -> Tensor:
        """Contracted bias middle shift in (-1, 1), in units of the bias radius."""
        if self.normalize_shift:
            eps = torch.tensor(1e-8).to(self._bias_shift.device)
            return (self._bias_shift / torch.max(self.bias_radius, eps)).tanh()
        else:
            return self._bias_shift.tanh()

    @property
    def scale(self) -> Tensor:
        """Contracted interval scale in (0, 1).

        A sigmoid of the raw value, optionally divided by the radius, times
        (1 - |shift|). The scale bound keeps the shifted, scaled interval
        inside the original [weight - radius, weight + radius].
        """
        if self.normalize_scale:
            eps = torch.tensor(1e-8).to(self._scale.device)
            scale = (self._scale / torch.max(self.radius, eps)).sigmoid()
        else:
            scale = self._scale.sigmoid()
        return scale * (1.0 - torch.abs(self.shift))

    @property
    def bias_scale(self) -> Tensor:
        """Contracted bias scale in (0, 1); the bias analogue of ``scale``."""
        if self.normalize_scale:
            eps = torch.tensor(1e-8).to(self._bias_scale.device)
            # Normalise by the *bias* radius: the weight radius has shape
            # (out, in) and does not correspond to the per-output biases.
            scale = (self._bias_scale / torch.max(self.bias_radius, eps)).sigmoid()
        else:
            scale = self._bias_scale.sigmoid()
        return scale * (1.0 - torch.abs(self.bias_shift))

    def clamp_radii(self) -> None:
        """Clamp raw radii in place to [RADIUS_MIN, max_radius / radius_multiplier]."""
        with torch.no_grad():
            max_val = self.max_radius / self.radius_multiplier
            self._radius.clamp_(min=RADIUS_MIN, max=max_val)
            if self.bias is not None:
                self._bias_radius.clamp_(min=RADIUS_MIN, max=max_val)

    def reset_parameters(self) -> None:
        """Initialise the parameters.

        Kaiming-uniform weights, radii at ``max_radius``, zero shift, and raw
        scale at ``scale_init``. Biases get the same treatment, with the
        bias centre at 0.
        """
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            self._radius.fill_(self.max_radius)
            self._shift.zero_()
            self._scale.fill_(self.scale_init)
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                # NOTE(review): the uniform draw is immediately overwritten by
                # zero_(), so the bias always starts at 0. The draw is kept only
                # because removing it would change the random stream seen by
                # later initialisations.
                nn.init.uniform_(self.bias, -bound, bound)
                self.bias.zero_()
                self._bias_radius.fill_(self.max_radius)
                self._bias_shift.zero_()
                self._bias_scale.fill_(self.scale_init)

    def switch_mode(self, mode: Mode) -> None:
        """Enter ``mode`` and make only the tensors it trains require grad.

        VANILLA trains weight/bias, EXPANSION trains the radii (after
        resetting them to ``max_radius``), CONTRACTION_SHIFT trains the shifts
        and CONTRACTION_SCALE the scales. Gradients of frozen tensors are
        cleared.
        """
        self.mode = mode

        def enable(params: List[Parameter]):
            for p in params:
                p.requires_grad_(True)

        def disable(params: List[Parameter]):
            for p in params:
                p.requires_grad_(False)
                p.grad = None

        # Disable all parameters first
        all_params = [self.weight, self._radius, self._shift, self._scale]
        if self.bias is not None:
            all_params.extend([self.bias, self._bias_radius, self._bias_shift, self._bias_scale])
        disable(all_params)

        # Enable specific parameters based on mode
        if mode == Mode.VANILLA:
            enable([self.weight])
            if self.bias is not None:
                enable([self.bias])
        elif mode == Mode.EXPANSION:
            with torch.no_grad():
                self._radius.fill_(self.max_radius)
                if self.bias is not None:
                    self._bias_radius.fill_(self.max_radius)
            enable([self._radius])
            if self.bias is not None:
                enable([self._bias_radius])
        elif mode == Mode.CONTRACTION_SHIFT:
            enable([self._shift])
            if self.bias is not None:
                enable([self._bias_shift])
        elif mode == Mode.CONTRACTION_SCALE:
            enable([self._scale])
            if self.bias is not None:
                enable([self._bias_scale])

    def freeze_task(self) -> None:
        """Fold the learned shift/scale into the centre and radius.

        The contracted interval becomes the new base interval; shift is then
        reset to 0 and the raw scale to ``scale_init``.
        """
        with torch.no_grad():
            # Evaluation order matters: scale/shift are read before the raw
            # tensors they depend on are overwritten.
            self.weight.copy_(self.weight + self.shift * self.radius)
            self._radius.copy_(self.scale * self._radius)
            self._shift.zero_()
            self._scale.fill_(self.scale_init)
            if self.bias is not None:
                self.bias.copy_(self.bias + self.bias_shift * self.bias_radius)
                self._bias_radius.copy_(self.bias_scale * self._bias_radius)
                self._bias_shift.zero_()
                self._bias_scale.fill_(self.scale_init)

    def forward(self, x: Tensor) -> Tensor:
        """Propagate interval bounds through the layer.

        Args:
            x: tensor refinable to names ("N", "bounds", "features"). Its
               "bounds" dimension holds (lower, middle, upper), and all
               entries must be non-negative.

        Returns:
            Output bounds named ("N", "bounds", "features").

        Raises:
            AssertionError: on negative inputs, unordered bounds, or (in the
                contraction modes) a shift/scale outside its range.
        """
        x = x.refine_names("N", "bounds", "features")
        assert (x.rename(None) >= 0.0).all(), "All input features must be non-negative."

        x_lower, x_middle, x_upper = map(lambda x_: cast(Tensor, x_.rename(None)), x.unbind("bounds"))
        assert (x_lower <= x_middle).all(), "Lower bound must be less than or equal to middle bound."
        assert (x_middle <= x_upper).all(), "Middle bound must be less than or equal to upper bound."

        radius = self.radius
        if self.mode in [Mode.VANILLA, Mode.EXPANSION]:
            # Full, unshifted interval around the centre.
            w_middle: Tensor = self.weight
            w_lower = self.weight - radius
            w_upper = self.weight + radius
        else:
            assert self.mode in [Mode.CONTRACTION_SHIFT, Mode.CONTRACTION_SCALE]
            scale = self.scale
            shift = self.shift
            assert (0.0 <= scale).all() and (scale <= 1.0).all(), "Scale must be in [0, 1] range."
            assert (-1.0 <= shift).all() and (shift <= 1.0).all(), "Shift must be in [-1, 1] range."

            w_middle = self.weight + shift * radius
            w_lower = w_middle - scale * radius
            w_upper = w_middle + scale * radius

        w_lower_pos = w_lower.clamp(min=0)
        w_lower_neg = w_lower.clamp(max=0)
        w_upper_pos = w_upper.clamp(min=0)
        w_upper_neg = w_upper.clamp(max=0)
        w_middle_pos = w_middle.clamp(min=0)
        w_middle_neg = w_middle.clamp(max=0)

        # With non-negative inputs, the smallest product uses the lower weight
        # with the lower input where the weight is positive, and with the
        # upper input where it is negative (and symmetrically for the largest).
        lower = x_lower @ w_lower_pos.t() + x_upper @ w_lower_neg.t()
        upper = x_upper @ w_upper_pos.t() + x_lower @ w_upper_neg.t()
        middle = x_middle @ w_middle_pos.t() + x_middle @ w_middle_neg.t()

        if self.bias is not None:
            b_radius = self.bias_radius
            b_middle = self.bias + self.bias_shift * b_radius
            b_spread = self.bias_scale * b_radius
            b_lower = b_middle - b_spread
            b_upper = b_middle + b_spread
            lower = lower + b_lower
            upper = upper + b_upper
            middle = middle + b_middle

        assert (lower <= middle).all(), "Lower bound must be less than or equal to middle bound."
        assert (middle <= upper).all(), "Middle bound must be less than or equal to upper bound."

        return torch.stack([lower, middle, upper], dim=1).refine_names("N", "bounds", "features")


class IntervalConv2d(nn.Conv2d, IntervalModuleWithWeights):
    """2-D convolution whose kernels and biases are intervals.

    The kernel centre is ``weight`` (inherited from nn.Conv2d).
    ``_radius``/``_shift``/``_scale`` and their bias counterparts
    parameterise the interval, with the same radius/shift/scale properties
    as IntervalLinear. Unlike IntervalLinear, radii start at RADIUS_MIN and
    are only set to ``max_radius`` when switching straight from VANILLA to
    CONTRACTION_SCALE. EXPANSION enables no parameters here.

    NOTE(review): ``scale_init`` is stored, but ``init_parameters`` and
    ``freeze_task`` always use a raw scale of 5; confirm this is intended.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        radius_multiplier: float,
        max_radius: float,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        normalize_shift: bool = True,
        normalize_scale: bool = False,
        scale_init: float = -5.0
    ) -> None:
        IntervalModuleWithWeights.__init__(self)
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.radius_multiplier = radius_multiplier
        self.max_radius = max_radius
        self.normalize_shift = normalize_shift
        self.normalize_scale = normalize_scale
        self.scale_init = scale_init

        assert self.radius_multiplier > 0
        assert self.max_radius > 0

        # Raw interval parameters, shaped like the kernel / bias.
        self._radius = Parameter(torch.empty_like(self.weight), requires_grad=False)
        self._shift = Parameter(torch.empty_like(self.weight), requires_grad=False)
        self._scale = Parameter(torch.empty_like(self.weight), requires_grad=False)

        if bias:
            self._bias_radius = Parameter(torch.empty_like(self.bias), requires_grad=False)
            self._bias_shift = Parameter(torch.empty_like(self.bias), requires_grad=False)
            self._bias_scale = Parameter(torch.empty_like(self.bias), requires_grad=False)

        self.mode: Mode = Mode.VANILLA
        self.init_parameters()

    def radius_transform(self, params: Tensor):
        """Multiply raw radii by ``radius_multiplier`` and clamp to [RADIUS_MIN, max_radius]."""
        return (params * torch.tensor(self.radius_multiplier)).clamp(min=RADIUS_MIN, max=self.max_radius)

    @property
    def radius(self) -> Tensor:
        """Kernel radii (transformed ``_radius``)."""
        return self.radius_transform(self._radius)

    @property
    def bias_radius(self) -> Tensor:
        """Bias radii (transformed ``_bias_radius``)."""
        return self.radius_transform(self._bias_radius)

    @property
    def shift(self) -> Tensor:
        """Kernel centre shift in (-1, 1), in units of the radius."""
        if self.normalize_shift:
            eps = torch.tensor(1e-8).to(self._shift.device)
            return (self._shift / torch.max(self.radius, eps)).tanh()
        else:
            return self._shift.tanh()

    @property
    def bias_shift(self) -> Tensor:
        """Bias centre shift in (-1, 1), in units of the bias radius."""
        if self.normalize_shift:
            eps = torch.tensor(1e-8).to(self._bias_shift.device)
            return (self._bias_shift / torch.max(self.bias_radius, eps)).tanh()
        else:
            return self._bias_shift.tanh()

    @property
    def scale(self) -> Tensor:
        """Kernel scale in (0, 1), bounded by (1 - |shift|)."""
        if self.normalize_scale:
            eps = torch.tensor(1e-8).to(self._scale.device)
            scale = (self._scale / torch.max(self.radius, eps)).sigmoid()
        else:
            scale = self._scale.sigmoid()
        return scale * (1.0 - torch.abs(self.shift))

    @property
    def bias_scale(self) -> Tensor:
        """Bias scale in (0, 1), bounded by (1 - |bias_shift|)."""
        if self.normalize_scale:
            eps = torch.tensor(1e-8).to(self._bias_scale.device)
            # Normalise by the bias radius; the kernel radius has a different shape.
            scale = (self._bias_scale / torch.max(self.bias_radius, eps)).sigmoid()
        else:
            scale = self._bias_scale.sigmoid()
        return scale * (1.0 - torch.abs(self.bias_shift))

    def clamp_radii(self) -> None:
        """Clamp raw radii in place to [RADIUS_MIN, max_radius / radius_multiplier]."""
        with torch.no_grad():
            max_val = self.max_radius / self.radius_multiplier
            self._radius.clamp_(min=RADIUS_MIN, max=max_val)
            if self.bias is not None:
                self._bias_radius.clamp_(min=RADIUS_MIN, max=max_val)

    def init_parameters(self) -> None:
        """Initialise the interval parameters: zero radius and shift, raw scale 5.

        The kernel/bias centres keep nn.Conv2d's own initialisation.
        """
        with torch.no_grad():
            self._radius.fill_(RADIUS_MIN)
            self._shift.zero_()
            self._scale.fill_(5)
            if self.bias is not None:
                self._bias_radius.fill_(RADIUS_MIN)
                self._bias_shift.zero_()
                self._bias_scale.fill_(5)

    def switch_mode(self, mode: Mode) -> None:
        """Enter ``mode`` and make only the tensors it trains require grad.

        VANILLA trains weight/bias, CONTRACTION_SHIFT the shifts and
        CONTRACTION_SCALE the scales. Any other mode (EXPANSION) leaves
        everything frozen. Gradients of frozen tensors are cleared.
        """
        if self.mode == Mode.VANILLA and mode == Mode.CONTRACTION_SCALE:
            # Radii start at RADIUS_MIN; open them to max_radius whenever
            # contraction starts straight from vanilla training. no_grad keeps
            # the in-place fill out of autograd.
            with torch.no_grad():
                self._radius.fill_(self.max_radius)
                if self.bias is not None:
                    self._bias_radius.fill_(self.max_radius)

        self.mode = mode

        def enable(params: List[Parameter]):
            for p in params:
                p.requires_grad_(True)

        def disable(params: List[Parameter]):
            for p in params:
                p.requires_grad_(False)
                p.grad = None

        all_params = [self.weight, self._radius, self._shift, self._scale]
        if self.bias is not None:
            all_params.extend([self.bias, self._bias_radius, self._bias_shift, self._bias_scale])
        disable(all_params)

        if mode == Mode.VANILLA:
            enable([self.weight])
            if self.bias is not None:
                enable([self.bias])
        elif mode == Mode.CONTRACTION_SHIFT:
            enable([self._shift])
            if self.bias is not None:
                enable([self._bias_shift])
        elif mode == Mode.CONTRACTION_SCALE:
            enable([self._scale])
            if self.bias is not None:
                enable([self._bias_scale])

    def freeze_task(self) -> None:
        """Fold the learned shift/scale into the centre and radius, then reset them."""
        with torch.no_grad():
            # scale/shift are read before the raw tensors they depend on change.
            self.weight.copy_(self.weight + self.shift * self.radius)
            self._radius.copy_(self.scale * self._radius)
            self._shift.zero_()
            self._scale.fill_(5)
            if self.bias is not None:
                self.bias.copy_(self.bias + self.bias_shift * self.bias_radius)
                self._bias_radius.copy_(self.bias_scale * self._bias_radius)
                self._bias_shift.zero_()
                self._bias_scale.fill_(5)

    def _interval_conv(self, x: Tensor, weight: Tensor) -> Tensor:
        """Bias-free convolution with this layer's stride/padding/dilation/groups."""
        return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)

    def forward(self, x: Tensor) -> Tensor:
        """Propagate interval bounds through the convolution.

        Args:
            x: tensor refinable to names ("N", "bounds", ...) whose "bounds"
               dimension holds (lower, middle, upper) feature maps.

        Returns:
            Output bounds stacked along dim 1 and named ("N", "bounds", ...).

        The kernels are split into positive and negative parts so that lower
        and upper are valid interval bounds. As in interval matrix products,
        this is only guaranteed for non-negative inputs; that is NOT checked
        here.
        """
        x = x.refine_names("N", "bounds", ...)
        x_lower, x_middle, x_upper = map(lambda x_: cast(Tensor, x_.rename(None)), x.unbind("bounds"))

        radius = self.radius
        if self.mode in [Mode.VANILLA, Mode.EXPANSION]:
            w_middle = self.weight
            w_lower = self.weight - radius
            w_upper = self.weight + radius
        else:
            scale = self.scale
            w_middle = self.weight + self.shift * radius
            w_lower = w_middle - scale * radius
            w_upper = w_middle + scale * radius

        conv = self._interval_conv
        # Positive kernel entries pair with the same-side input bound,
        # negative entries with the opposite one.
        lower = conv(x_lower, w_lower.clamp(min=0)) + conv(x_upper, w_lower.clamp(max=0))
        upper = conv(x_upper, w_upper.clamp(min=0)) + conv(x_lower, w_upper.clamp(max=0))
        # Same decomposition for the middle, so degenerate intervals give
        # bitwise-identical bounds.
        middle = conv(x_middle, w_middle.clamp(min=0)) + conv(x_middle, w_middle.clamp(max=0))

        if self.bias is not None:
            b_radius = self.bias_radius
            b_middle = self.bias + self.bias_shift * b_radius
            b_spread = self.bias_scale * b_radius
            # Per-output-channel biases broadcast over (N, C, H, W).
            lower = lower + (b_middle - b_spread).view(1, -1, 1, 1)
            middle = middle + b_middle.view(1, -1, 1, 1)
            upper = upper + (b_middle + b_spread).view(1, -1, 1, 1)

        return torch.stack([lower, middle, upper], dim=1).refine_names("N", "bounds", ...)


class IntervalNet(nn.Module):
    """Base class for IntervalNet models.

    Owns the continual-learning mode (``Mode``) and forwards it to every
    ``IntervalModuleWithWeights`` submodule. It also computes the
    mode-dependent interval loss, keeps the last 10 batch accuracies, and
    offers prediction helpers.

    Subclasses must set ``output_classes`` and implement
    ``forward(x, task_labels)``. ``forward`` must return a dict whose
    ``"last"`` entry is a named tensor with a ``"bounds"`` dimension at index 1
    holding (lower, middle, upper), with classes in the last dimension.
    """

    def __init__(self, radius_multiplier: float = 1.0, max_radius: float = 1.0):
        super().__init__()
        self.mode: Mode = Mode.VANILLA
        self._radius_multiplier = radius_multiplier
        self._max_radius = max_radius
        self.output_classes = None
        self.criterion = nn.CrossEntropyLoss()

        # Most recent batch accuracies, newest first (filled by compute_loss).
        self._accuracy = deque(maxlen=10)
        self._robust_accuracy = deque(maxlen=10)

        # In CONTRACTION_SCALE the robust loss is applied only while robust
        # accuracy is below this fraction of the plain accuracy.
        self.robust_accuracy_threshold = 0.8

    def interval_children(self) -> List[IntervalModuleWithWeights]:
        """Return all interval submodules (including nested ones)."""
        return [m for m in self.modules() if isinstance(m, IntervalModuleWithWeights)]

    def named_interval_children(self) -> List[Tuple[str, IntervalModuleWithWeights]]:
        """Return (qualified name, module) pairs for all interval submodules."""
        return [(n, m) for n, m in self.named_modules() if isinstance(m, IntervalModuleWithWeights)]

    def switch_mode(self, mode: Mode) -> None:
        """Set the training mode on the model and every interval submodule."""
        self.mode = mode
        for m in self.interval_children():
            m.switch_mode(mode)

    def freeze_task(self) -> None:
        """Fold each interval submodule's learned contraction into its base interval."""
        for m in self.interval_children():
            m.freeze_task()

    def clamp_radii(self) -> None:
        """Clamp the raw radii of every interval submodule."""
        for m in self.interval_children():
            m.clamp_radii()

    @property
    def radius_multiplier(self):
        return self._radius_multiplier

    @radius_multiplier.setter
    def radius_multiplier(self, value: float):
        """Set the multiplier here and on every interval submodule."""
        self._radius_multiplier = value
        for m in self.interval_children():
            m.radius_multiplier = value

    @property
    def max_radius(self):
        return self._max_radius

    @max_radius.setter
    def max_radius(self, value: float) -> None:
        """Set the maximum radius here and on every interval submodule."""
        self._max_radius = value
        for m in self.interval_children():
            m.max_radius = value

    def compute_loss(self, x: torch.Tensor, y: torch.Tensor, task_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the mode-dependent interval loss for one batch.

        Every mode except CONTRACTION_SCALE uses cross-entropy on the middle
        bound. CONTRACTION_SCALE uses cross-entropy on the worst-case logits
        while robust accuracy is below ``robust_accuracy_threshold`` times the
        accuracy; otherwise it returns a zero loss that is still attached to
        the graph.

        Side effect: records this batch's accuracy and robust accuracy, which
        the CONTRACTION_SCALE check above reads.
        """
        bounds = self.forward(x, task_labels)["last"]

        middle_logits = bounds[:, 1, :].rename(None)
        vanilla_loss = self.criterion(middle_logits, y)

        robust_logits = self._get_robust_output(bounds, y)
        robust_loss = self.criterion(robust_logits, y)

        self._accuracy.appendleft(self._compute_accuracy(middle_logits, y))
        self._robust_accuracy.appendleft(self._compute_accuracy(robust_logits, y))

        if self.mode != Mode.CONTRACTION_SCALE:
            return vanilla_loss
        if self.robust_accuracy() < (self.robust_accuracy_threshold * self.accuracy()):
            return robust_loss
        return robust_loss * 0.0  # Effectively zero loss

    def _get_robust_output(self, output_bounds: Tensor, targets: Tensor) -> Tensor:
        """Worst-case logits: lower bound for the target class, upper bound for all others."""
        output_lower, _, output_upper = output_bounds.unbind("bounds")
        y_oh = F.one_hot(targets, num_classes=self.output_classes)
        return torch.where(y_oh.bool(), output_lower.rename(None), output_upper.rename(None))

    def _compute_accuracy(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """Fraction of rows whose argmax equals the target, as a 0-dim tensor."""
        pred_classes = predictions.argmax(dim=-1)
        return (pred_classes == targets).float().mean()

    @staticmethod
    def _mean_recent(history: deque, n_last: int) -> float:
        """Mean of the ``n_last`` newest entries of ``history``, or 0.0 if it is empty."""
        if not history:
            return 0.0
        return torch.stack(list(history)[:n_last]).mean().item()

    def accuracy(self, n_last: int = 1) -> float:
        """Moving average of batch accuracy over the last ``n_last`` batches."""
        return self._mean_recent(self._accuracy, n_last)

    def robust_accuracy(self, n_last: int = 1) -> float:
        """Moving average of robust batch accuracy over the last ``n_last`` batches."""
        return self._mean_recent(self._robust_accuracy, n_last)

    def before_task(self, task_id: int):
        """Prepare for task ``task_id``.

        For tasks after the first, switch to CONTRACTION_SHIFT and freeze the
        previous task's intervals. Always reset the accuracy history.
        """
        if task_id >= 1:
            self.switch_mode(Mode.CONTRACTION_SHIFT)
            self.freeze_task()

        self._accuracy.clear()
        self._robust_accuracy.clear()

    def before_epoch(self, epoch: int, total_epochs: int, contraction_epochs: int = 5):
        """Switch to CONTRACTION_SCALE at the first of the last ``contraction_epochs`` epochs.

        Epochs are assumed to be numbered from 0.
        """
        # Clamp at 0: asking for at least as many contraction epochs as there
        # are epochs contracts from the first epoch instead of never.
        contraction_start = max(total_epochs - contraction_epochs, 0)

        if self.mode in [Mode.VANILLA, Mode.CONTRACTION_SHIFT] and epoch == contraction_start:
            self.switch_mode(Mode.CONTRACTION_SCALE)

    def after_step(self):
        """Called after each optimization step; clamps all radii."""
        self.clamp_radii()

    def _forward_inference(self, x: torch.Tensor, task_labels: Optional[torch.Tensor]):
        """Run ``forward`` in eval mode without autograd.

        Every submodule's train/eval flag is restored afterwards, so calling a
        prediction helper mid-training does not leave dropout/batch-norm
        switched off.
        """
        saved_flags = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(x, task_labels)
        finally:
            for module, was_training in saved_flags:
                module.training = was_training

    def predict(self, x: torch.Tensor, task_labels: Optional[torch.Tensor] = None, return_bounds: bool = False) -> torch.Tensor:
        """Get model predictions (eval mode, no autograd; train/eval state is restored).

        Args:
            x: Input tensor
            task_labels: Task labels for multi-task scenarios
            return_bounds: If True, returns all bounds (lower, middle, upper);
                if False, returns only the middle bound (standard predictions)

        Returns:
            Predictions tensor. Shape depends on return_bounds:
            - If return_bounds=False: (batch_size, num_classes) - middle bound predictions
            - If return_bounds=True: (batch_size, 3, num_classes) - all bounds [lower, middle, upper]
        """
        predictions = self._forward_inference(x, task_labels)["last"]
        if return_bounds:
            return predictions.rename(None)
        return predictions[:, 1, :].rename(None)

    def predict_proba(self, x: torch.Tensor, task_labels: Optional[torch.Tensor] = None, return_bounds: bool = False) -> torch.Tensor:
        """Softmax over the class dimension of ``predict``'s output.

        Args:
            x: Input tensor
            task_labels: Task labels for multi-task scenarios
            return_bounds: If True, softmax is applied to each bound separately;
                if False, only to the middle bound

        Returns:
            Probability predictions with softmax applied along the last dimension
        """
        return torch.softmax(self.predict(x, task_labels, return_bounds), dim=-1)

    def predict_classes(self, x: torch.Tensor, task_labels: Optional[torch.Tensor] = None, use_robust: bool = False) -> torch.Tensor:
        """Get predicted class labels (eval mode, no autograd; train/eval state is restored).

        Args:
            x: Input tensor
            task_labels: Task labels for multi-task scenarios
            use_robust: If True, takes the argmax of the worst-case logits for
                the middle bound's predicted class; if False, the argmax of the
                middle bound

        Returns:
            Predicted class indices (batch_size,)
        """
        if not use_robust:
            return self.predict(x, task_labels, return_bounds=False).argmax(dim=-1)

        bounds = self._forward_inference(x, task_labels)["last"]
        # Use the middle bound's predictions as the "targets" of the robust view.
        predicted_classes = bounds[:, 1, :].rename(None).argmax(dim=-1)
        robust_output = self._get_robust_output(bounds, predicted_classes)
        return robust_output.argmax(dim=-1)

    def get_prediction_bounds(self, x: torch.Tensor, task_labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Get detailed prediction bounds information.

        Args:
            x: Input tensor
            task_labels: Task labels for multi-task scenarios

        Returns:
            Dictionary containing:
            - 'lower': Lower bound predictions
            - 'middle': Middle bound predictions (standard predictions)
            - 'upper': Upper bound predictions
            - 'width': Interval width (upper - lower)
            - 'uncertainty': Mean interval width over classes, per sample (not normalised)
        """
        bounds = self.predict(x, task_labels, return_bounds=True)

        lower, middle, upper = bounds[:, 0, :], bounds[:, 1, :], bounds[:, 2, :]
        width = upper - lower

        return {
            'lower': lower,
            'middle': middle,
            'upper': upper,
            'width': width,
            'uncertainty': width.mean(dim=-1),
        }

    @classmethod
    def create_mlp(
        cls,
        input_size: int,
        hidden_dim: int,
        output_classes: int,
        heads: int = 1,
        radius_multiplier: float = 1.0,
        max_radius: float = 1.0,
        bias: bool = True,
        normalize_shift: bool = True,
        normalize_scale: bool = False,
        scale_init: float = 5.0,
    ) -> "IntervalMLP":
        """Create an IntervalMLP model"""
        return IntervalMLP(
            input_size=input_size,
            hidden_dim=hidden_dim,
            output_classes=output_classes,
            heads=heads,
            radius_multiplier=radius_multiplier,
            max_radius=max_radius,
            bias=bias,
            normalize_shift=normalize_shift,
            normalize_scale=normalize_scale,
            scale_init=scale_init,
        )

    @classmethod
    def create_cnn(
        cls,
        in_channels: int,
        output_classes: int,
        heads: int = 1,
        radius_multiplier: float = 1.0,
        max_radius: float = 1.0,
        normalize_shift: bool = True,
        normalize_scale: bool = False,
        scale_init: float = -5.0,
    ) -> "IntervalAlexNet":
        """Create an IntervalAlexNet model (ReLU activations)"""
        return IntervalAlexNet(
            in_channels=in_channels,
            output_classes=output_classes,
            heads=heads,
            radius_multiplier=radius_multiplier,
            max_radius=max_radius,
            normalize_shift=normalize_shift,
            normalize_scale=normalize_scale,
            scale_init=scale_init,
            act_fn="relu"
        )


class IntervalMLP(IntervalNet):
    """Multi-layer perceptron with interval weights.

    Layout: input flattened and tiled into three identical bounds
    (lower / middle / upper) -> ``IntervalLinear`` + ReLU -> ``IntervalLinear``
    + ReLU -> one output head per task. With ``heads > 1`` the heads are
    ``PointLinear`` layers; with a single head it is an ``IntervalLinear``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_dim: int,
        output_classes: int,
        heads: int = 1,
        radius_multiplier: float = 1.0,
        max_radius: float = 1.0,
        bias: bool = True,
        normalize_shift: bool = True,
        normalize_scale: bool = False,
        scale_init: float = 5.0,
    ):
        super().__init__(radius_multiplier=radius_multiplier, max_radius=max_radius)

        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.output_classes = output_classes
        self.normalize_shift = normalize_shift
        self.normalize_scale = normalize_scale
        self.num_classes = 0

        # Hidden layers
        self.fc1 = IntervalLinear(
            self.input_size, self.hidden_dim,
            radius_multiplier=radius_multiplier, max_radius=max_radius,
            bias=bias, normalize_shift=normalize_shift, normalize_scale=normalize_scale,
            scale_init=scale_init
        )
        self.fc2 = IntervalLinear(
            self.hidden_dim, self.hidden_dim,
            radius_multiplier=radius_multiplier, max_radius=max_radius,
            bias=bias, normalize_shift=normalize_shift, normalize_scale=normalize_scale,
            scale_init=scale_init,
        )

        # Output heads (multi-task support): point heads when several tasks
        # each get their own head, a single interval head otherwise.
        if heads > 1:
            self.last = nn.ModuleList([
                PointLinear(self.hidden_dim, self.output_classes) for _ in range(heads)
            ])
        else:
            self.last = nn.ModuleList([
                IntervalLinear(
                    self.hidden_dim,
                    self.output_classes,
                    radius_multiplier=radius_multiplier,
                    max_radius=max_radius,
                    bias=bias,
                    normalize_shift=normalize_shift,
                    normalize_scale=normalize_scale,
                    scale_init=scale_init,
                )])

    def forward(
        self,
        x: torch.Tensor,
        task_labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the network, routing each sample to the head of its task.

        Args:
            x: input batch; every dimension after the first is flattened.
            task_labels: per-sample head indices (1-D integer tensor), a plain
                ``int`` applied to the whole batch, or ``None`` (every sample
                uses head 0).

        Returns:
            Activations keyed ``"fc1"``, ``"fc2"`` and ``"last"``, each named
            ``("N", "bounds", "features")``, in the original sample order.
        """
        if task_labels is None:
            task_labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        if isinstance(task_labels, int):
            return self.forward_single_task(x, task_labels)

        unique_tasks = torch.unique(task_labels)
        if unique_tasks.numel() == 0:
            # Empty batch: there is no task id to dispatch on. Fall back to
            # head 0 so callers still get the usual keys (with N == 0)
            # instead of an empty dict.
            return self.forward_single_task(x, 0)
        if unique_tasks.numel() == 1:
            return self.forward_single_task(x, int(unique_tasks[0]))

        # Mixed batch: run each task's sub-batch through its own head and
        # scatter the results back into full-batch buffers.
        full_out: Dict[str, Tensor] = {}
        for task in unique_tasks:
            task_mask = task_labels == task
            out_task = self.forward_single_task(x[task_mask], int(task))
            for key, val in out_task.items():
                val = val.rename(None)
                if key not in full_out:
                    # Match the activation's dtype; torch.empty would otherwise
                    # default to float32 and silently cast non-float32 models.
                    full_out[key] = torch.empty(
                        (x.shape[0], *val.shape[1:]),
                        dtype=val.dtype,
                        device=val.device,
                    )
                full_out[key][task_mask] = val

        return {
            key: val.refine_names("N", "bounds", "features")
            for key, val in full_out.items()
        }

    def forward_base(self, x: Tensor) -> Dict[str, Tensor]:
        """Forward pass through the shared hidden layers.

        The input is flattened per sample and tiled three times along a new
        ``bounds`` axis, i.e. the interval starts with lower = middle = upper.

        Returns:
            ``{"fc1": ..., "fc2": ...}`` post-ReLU activations.
        """
        x = x.refine_names("N", ...)
        x = x.rename(None)
        x = x.flatten(1)  # (N, features)
        x = x.unflatten(1, (1, -1))  # (N, bounds, features)
        x = x.tile((1, 3, 1))  # Create interval bounds

        x = x.refine_names("N", "bounds", "features")

        fc1 = F.relu(self.fc1(x))
        fc2 = F.relu(self.fc2(fc1))

        return {
            "fc1": fc1,
            "fc2": fc2,
        }

    def forward_single_task(self, x: Tensor, task_id: int) -> Dict[str, Tensor]:
        """Forward pass for a batch that belongs entirely to ``task_id``.

        Returns the hidden activations from ``forward_base`` plus ``"last"``,
        the output of head ``self.last[task_id]``.
        """
        activation_dict = self.forward_base(x)
        activation_dict["last"] = self.last[task_id](activation_dict["fc2"])
        return activation_dict

    @property
    def device(self):
        """Device of the first layer's weights."""
        return self.fc1.weight.device


class IntervalAlexNet(IntervalNet):
    """AlexNet-style CNN with interval weights.

    Five ``IntervalConv2d`` layers with three ``IntervalMaxPool2d`` stages
    feed two 4096-unit ``IntervalLinear`` layers and one output head per task.
    The first classifier layer is hard-coded to ``256 * 2 * 2`` input
    features, so the feature extractor must end in a 2x2 map. This presumably
    means 32x32 inputs (e.g. CIFAR); TODO confirm against the data pipeline.
    """

    def __init__(
        self,
        in_channels: int,
        output_classes: int,
        heads: int = 1,
        radius_multiplier: float = 1.0,
        max_radius: float = 1.0,
        normalize_shift: bool = True,
        normalize_scale: bool = False,
        scale_init: float = -5.0,
        act_fn: str = "relu"
    ):
        super().__init__(radius_multiplier=radius_multiplier, max_radius=max_radius)

        self.output_classes = output_classes
        self.normalize_shift = normalize_shift
        self.normalize_scale = normalize_scale
        self.scale_init = scale_init

        # A single activation module instance is reused at every position below.
        if act_fn == "relu":
            self.act_fn = nn.ReLU(inplace=True)
        elif act_fn == "relu6":
            self.act_fn = nn.ReLU6()
        else:
            raise NotImplementedError(f"Activation function {act_fn} not supported")

        # Feature extractor
        self.features = nn.Sequential(
            IntervalConv2d(in_channels, 64, kernel_size=3, stride=2, padding=1,
                           radius_multiplier=radius_multiplier, max_radius=max_radius,
                           normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                           scale_init=scale_init),
            self.act_fn,
            IntervalMaxPool2d(kernel_size=2),
            IntervalConv2d(64, 192, kernel_size=3, padding=1,
                           radius_multiplier=radius_multiplier, max_radius=max_radius,
                           normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                           scale_init=scale_init),
            self.act_fn,
            IntervalMaxPool2d(kernel_size=2),
            IntervalConv2d(192, 384, kernel_size=3, padding=1,
                           radius_multiplier=radius_multiplier, max_radius=max_radius,
                           normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                           scale_init=scale_init),
            self.act_fn,
            IntervalConv2d(384, 256, kernel_size=3, padding=1,
                           radius_multiplier=radius_multiplier, max_radius=max_radius,
                           normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                           scale_init=scale_init),
            self.act_fn,
            IntervalConv2d(256, 256, kernel_size=3, padding=1,
                           radius_multiplier=radius_multiplier, max_radius=max_radius,
                           normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                           scale_init=scale_init),
            self.act_fn,
            IntervalMaxPool2d(kernel_size=2),
        )

        # Classifier; in_features must match the flattened 256x2x2 feature map.
        self.classifier = nn.Sequential(
            IntervalDropout(),
            IntervalLinear(256 * 2 * 2, 4096, bias=True,
                          radius_multiplier=radius_multiplier, max_radius=max_radius,
                          normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                          scale_init=scale_init),
            self.act_fn,
            IntervalDropout(),
            IntervalLinear(4096, 4096, bias=True,
                          radius_multiplier=radius_multiplier, max_radius=max_radius,
                          normalize_shift=normalize_shift, normalize_scale=normalize_scale,
                          scale_init=scale_init),
            self.act_fn,
        )

        # Output heads: point heads for multi-head setups, one interval head otherwise.
        if heads > 1:
            self.last = nn.ModuleList([
                PointLinear(4096, self.output_classes) for _ in range(heads)
            ])
        else:
            self.last = nn.ModuleList([
                IntervalLinear(
                    4096,
                    output_classes,
                    radius_multiplier=radius_multiplier,
                    max_radius=max_radius,
                    bias=True,
                    normalize_shift=normalize_shift,
                    normalize_scale=normalize_scale,
                    scale_init=scale_init,
                )])

    def forward_single_task(self, x: Tensor, task_id: int) -> Dict[str, Tensor]:
        """Run a batch that belongs entirely to ``task_id`` through the network.

        Args:
            x: image batch of rank 4, presumably ``(N, C, H, W)`` (TODO
                confirm). A ``bounds`` axis of size 3 is inserted at dim 1 by
                tiling, so the interval starts with lower = middle = upper.
            task_id: index of the head in ``self.last``.

        Returns:
            ``{"last": out}``, where ``out`` is the selected head's output.
        """
        x = x.unsqueeze(1).tile(1, 3, 1, 1, 1)
        x = x.refine_names("N", "bounds", ...)
        x = self.features(x)
        # Collapse (C, H, W) into one features axis for the linear layers.
        x = x.rename(None).flatten(2).refine_names("N", "bounds", "features")
        x = self.classifier(x)
        return {"last": self.last[task_id](x)}

    def forward(self, x: Tensor, task_labels: Optional[Tensor] = None) -> Dict[str, Tensor]:
        """Run the network, routing each sample to the head of its task.

        Args:
            x: image batch (see ``forward_single_task``).
            task_labels: per-sample head indices (1-D integer tensor), a plain
                ``int`` applied to the whole batch, or ``None`` (every sample
                uses head 0, matching ``IntervalMLP.forward``).

        Returns:
            ``{"last": out}`` in the original sample order; for mixed batches
            the output is named ``("N", "bounds", ...)``.
        """
        if task_labels is None:
            task_labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        if isinstance(task_labels, int):
            return self.forward_single_task(x, task_labels)

        unique_tasks = torch.unique(task_labels)
        if unique_tasks.numel() == 0:
            # Empty batch: nothing to dispatch on; use head 0 so callers still
            # receive the "last" key (with N == 0) rather than an empty dict.
            return self.forward_single_task(x, 0)
        if unique_tasks.numel() == 1:
            # Pass a Python int (not a 0-d tensor) as the head index.
            return self.forward_single_task(x, int(unique_tasks[0]))

        # Mixed batch: run each task's sub-batch through its own head and
        # scatter the results back into full-batch buffers.
        full_out: Dict[str, Tensor] = {}
        for task in unique_tasks:
            task_mask = task_labels == task
            out_task = self.forward_single_task(x[task_mask], int(task))
            for key, val in out_task.items():
                val = val.rename(None)
                if key not in full_out:
                    # Match the activation's dtype; torch.empty would otherwise
                    # default to float32 and silently cast non-float32 models.
                    full_out[key] = torch.empty(
                        (x.shape[0], *val.shape[1:]),
                        dtype=val.dtype,
                        device=val.device,
                    )
                full_out[key][task_mask] = val

        return {key: val.refine_names("N", "bounds", ...) for key, val in full_out.items()}


# Utility classes for interval operations
class IntervalMaxPool2d(nn.MaxPool2d):
    """2-D max pooling applied independently to each bound of an interval tensor.

    The input must carry a ``bounds`` axis of size exactly 3 at dim 1
    (lower, middle, upper). Each bound is pooled with the ordinary
    ``nn.MaxPool2d`` forward, and the results are restacked along dim 1.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )

    def forward(self, x):
        """Pool every bound and return a tensor named ``("N", "bounds", ...)``."""
        named = x.refine_names("N", "bounds", ...)
        # Unpacking into exactly three names rejects any other bounds count.
        lower, middle, upper = (cast(Tensor, bound.rename(None)) for bound in named.unbind("bounds"))
        pool_one = super().forward
        pooled = [pool_one(lower), pool_one(middle), pool_one(upper)]
        return torch.stack(pooled, dim=1).refine_names("N", "bounds", ...)


class IntervalDropout(nn.Module):
    """Inverted dropout for interval tensors.

    One Bernoulli mask is drawn per forward pass and shared by all three
    bounds (lower, middle, upper). A dropped unit is therefore zeroed in
    every bound, and kept units are scaled by the same non-negative factor,
    so an ordering lower <= middle <= upper on the input survives dropout.
    """

    def __init__(self, p=0.5):
        """
        Args:
            p: probability of dropping each unit, in ``[0, 1]``.

        Raises:
            ValueError: if ``p`` lies outside ``[0, 1]`` (same check as
                ``torch.nn.Dropout``).
        """
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p
        # Rescale kept units by 1 / (1 - p). With p == 1 every unit is dropped,
        # so the factor is irrelevant; use 0 to avoid a ZeroDivisionError.
        self.scale = 1. / (1 - self.p) if self.p < 1 else 0.0

    def forward(self, x):
        """Apply dropout in training mode; return ``x`` unchanged otherwise.

        The input must have a ``bounds`` axis of size exactly 3 at dim 1. In
        training mode the output keeps the input's dimension names (after
        refining dims 0 and 1 to ``"N"`` and ``"bounds"``), matching what
        eval mode returns.
        """
        if not self.training:
            return x

        x = x.refine_names("N", "bounds", ...)
        out_names = x.names
        x_lower, x_middle, x_upper = (cast(Tensor, b.rename(None)) for b in x.unbind("bounds"))

        # True where the unit is dropped; a single mask for all bounds.
        drop = torch.bernoulli(torch.full_like(x_middle, self.p)).bool()
        keep = ~drop
        masked = [
            bound.where(keep, torch.zeros_like(bound)) * self.scale
            for bound in (x_lower, x_middle, x_upper)
        ]
        return torch.stack(masked, dim=1).refine_names(*out_names)


# Example usage
if __name__ == "__main__":
    # Smoke test: build a single-head MLP for flattened 28x28 inputs
    # (784 features, e.g. MNIST) and print a short summary of it.
    model = IntervalNet.create_mlp(
        input_size=784,
        hidden_dim=512,
        output_classes=10,
        heads=1,
        radius_multiplier=1.0,
        max_radius=1.0,
    )

    print("Modular IntervalNet created successfully!")
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model has {n_params} parameters")
    print(f"Current mode: {model.mode}")

    # Sketch of a multi-task training loop. It is not executed: `num_tasks`,
    # `num_epochs` and `dataloader` (e.g. MNIST loaded via continuum) must be
    # supplied by the caller.
    #
    # import torch.optim as optim
    #
    # optimizer = optim.Adam(model.parameters(), lr=0.001)
    #
    # # Training multiple tasks
    # for task_id in range(num_tasks):
    #     model.before_task(task_id)
    #
    #     for epoch in range(num_epochs):
    #         model.before_epoch(epoch, num_epochs)
    #
    #         for batch_idx, (data, targets) in enumerate(dataloader):
    #             # Compute loss directly
    #             loss = model.compute_loss(data, targets, task_labels=torch.full((data.size(0),), task_id))
    #
    #             # Standard PyTorch training
    #             optimizer.zero_grad()
    #             loss.backward()
    #             optimizer.step()
    #
    #             # Post-step operations
    #             model.after_step()
    # print("Training completed!")
