import torch
from merlin_arthur_framework.stochastic_frank_wolfe import SFW, PositiveKSparsePolytope
from typing import Optional


class SFWFeatureSelector(torch.nn.Module):
    """Feature selector that optimizes a block mask with Stochastic Frank-Wolfe.

    In ``"merlin"`` mode the mask is optimized to minimize the classifier's loss
    on the masked input; in ``"morgana"`` mode the sign of the loss is flipped so
    the mask maximizes it. The mask is kept feasible w.r.t. a
    ``PositiveKSparsePolytope`` built with ``k = mask_size``.
    """

    def __init__(
        self,
        mask_size: int,
        mode: str = "merlin",
        lr: float = 0.1,
        binary_classification: bool = False,
        l1_penalty_coefficient: Optional[float] = 0.1,
        idk_class: int = 3,
        num_blocks: int = 16,
        sfw_max_iterations: int = 350,
        sfw_patience: int = 10,
        enc_type: str = "one_hot_padded",
        sfw_momentum: float = 0.9,
        sfw_min_delta: float = 1e-5,
    ) -> None:
        """
        Simple feature selector using Stochastic Frank-Wolfe optimization.

        Args:
            mask_size: Number of features to select (sparsity level ``k``).
            mode: "merlin" (minimize loss) or "morgana" (maximize loss).
            lr: Learning rate (step size) passed to the SFW optimizer.
            binary_classification: If True, use ``BCEWithLogitsLoss`` on squeezed logits.
            l1_penalty_coefficient: Weight of the mean-absolute-value mask penalty;
                ``None`` disables the penalty.
            idk_class: Index of the IDK class, used by ``MorganaCriterion`` in
                multi-class "morgana" mode.
            num_blocks: Number of equal blocks the last input dimension is split into.
            sfw_max_iterations: Maximum number of SFW iterations.
            sfw_patience: Iterations without sufficient loss improvement before stopping.
            enc_type: Input encoding; only "one_hot_padded" is supported.
            sfw_momentum: Momentum passed to the SFW optimizer.
            sfw_min_delta: Minimum loss decrease that counts as an improvement for
                early stopping.
        """
        super().__init__()
        assert mode in ["merlin", "morgana"], "Mode must be 'merlin' or 'morgana'"
        assert mask_size > 0, "Mask size must be greater than 0"

        self.mask_size = mask_size
        self.mode = mode
        self.lr = lr
        self.binary_classification = binary_classification
        self.l1_penalty_coefficient = l1_penalty_coefficient
        self.idk_class = idk_class
        self.num_blocks = num_blocks
        self.sfw_max_iterations = sfw_max_iterations
        self.sfw_patience = sfw_patience
        self.enc_type = enc_type
        self.sfw_momentum = sfw_momentum
        self.sfw_min_delta = sfw_min_delta
        # Kept for backward compatibility; forward() places the mask on x.device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set criterion based on classification type
        if binary_classification:
            self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss() if mode == "merlin" else MorganaCriterion(self.idk_class)

    def forward(self, x, y, classifier, init_mask=None):
        """
        Optimize a block mask for ``x`` using SFW.

        The classifier's parameters are frozen and it is put in eval mode during
        optimization. Its original ``requires_grad`` flags and per-module
        train/eval flags are restored afterwards, even if optimization raises.

        Args:
            x: Input tensor of shape ``[batch_size, num_slots, num_blocks * block_width]``.
            y: Targets. These are class indices for the multi-class criteria, or
                labels converted with ``.float()`` for binary classification.
            classifier: ``torch.nn.Module`` mapping the masked input to logits.
            init_mask: Optional initial mask of shape
                ``[batch_size, num_slots, num_blocks]``; uniform random if None.

        Returns:
            Detached continuous mask of shape ``[batch_size, num_slots, num_blocks]``
            as it stands after the last iteration. This is not necessarily the
            best-loss iterate. Use ``get_binary_mask`` for a 0/1 version.
        """
        batch_size, num_slots, _ = x.shape
        # The mask is multiplied elementwise with x, so it must live on x's device.
        device = x.device

        if init_mask is None:
            # Sample on the default (CPU) generator, then move, as before.
            init_mask = torch.rand(batch_size, num_slots, self.num_blocks).to(device)

        constraint = PositiveKSparsePolytope(
            n=num_slots * self.num_blocks,  # number of maskable features per sample
            bs=batch_size,
            k=self.mask_size,
        )
        init_mask = constraint.shift_inside(init_mask)

        mask = torch.nn.Parameter(init_mask.to(device), requires_grad=True)
        optimizer = SFW([mask], learning_rate=self.lr, momentum=self.sfw_momentum)

        # Snapshot the classifier state so it can be restored exactly afterwards.
        saved_grad_flags = [(p, p.requires_grad) for p in classifier.parameters()]
        saved_modes = [(m, m.training) for m in classifier.modules()]
        for param, _ in saved_grad_flags:
            param.requires_grad = False
        classifier.eval()

        try:
            self._run_sfw(x, y, classifier, mask, optimizer, constraint)
        finally:
            for param, flag in saved_grad_flags:
                param.requires_grad = flag
            for module, was_training in saved_modes:
                module.training = was_training

        return mask.detach()

    def _run_sfw(self, x, y, classifier, mask, optimizer, constraint):
        """Run the SFW loop, updating ``mask`` in place, with loss-based early stopping."""
        # Target conversion does not depend on the iteration; do it once.
        y_tensor = y.float() if self.binary_classification else y

        best_loss = float("inf")
        patience_counter = 0

        for _ in range(self.sfw_max_iterations):
            optimizer.zero_grad()

            logits = classifier(self.apply_mask(x, mask))
            if self.binary_classification:
                logits = logits.squeeze(1)

            distortion = self.criterion(logits, y_tensor)
            if self.mode == "morgana":
                # Morgana maximizes the classifier loss; the L1 term keeps its sign.
                distortion = -distortion

            if self.l1_penalty_coefficient is not None:
                l1_penalty = self.l1_penalty_coefficient * torch.mean(torch.abs(mask))
            else:
                l1_penalty = 0.0

            loss = distortion + l1_penalty
            loss.backward()
            optimizer.step(constraints=[constraint])

            # Early stopping on the loss computed before this step.
            loss_value = loss.item()
            if loss_value < best_loss - self.sfw_min_delta:
                best_loss = loss_value
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.sfw_patience:
                    break

    def apply_mask(self, x, mask):
        """Scale each block of the last dimension of ``x`` by its mask entry.

        The last dimension of ``x`` is split into ``self.num_blocks`` equal
        blocks. ``mask`` supplies one value per block and must broadcast against
        ``x.shape[:-1] + (num_blocks,)``. The block and within-block dimensions
        are then merged back into one.

        Raises:
            ValueError: If ``self.enc_type`` is not "one_hot_padded".
        """
        if self.enc_type != "one_hot_padded":
            raise ValueError(f"Unsupported input type: {self.enc_type!r}")
        blocks = x.reshape(*x.shape[:-1], self.num_blocks, -1)
        masked = blocks * mask.unsqueeze(-1)  # trailing dim broadcasts over block width
        return masked.reshape(*masked.shape[:-2], -1)

    def get_binary_mask(self, continuous_mask):
        """Return a 0/1 mask keeping the ``mask_size`` largest-magnitude entries per sample."""
        flat = continuous_mask.flatten(start_dim=1)
        top_indices = torch.topk(flat.abs(), k=self.mask_size, dim=1).indices
        binary = torch.zeros_like(flat).scatter_(1, top_indices, 1.0)
        return binary.reshape(continuous_mask.shape)


class ModelFeatureSelector(torch.nn.Module):
    """Feature selector whose forward pass delegates to a wrapped model.

    It shares the block-masking helpers (``apply_mask``, ``get_binary_mask``)
    and criterion selection with the SFW-based selector.
    """

    def __init__(
        self,
        mask_size: int,
        mode: str = "merlin",
        binary_classification: bool = False,
        idk_class: int = 100,
        num_blocks: int = 16,
        enc_type: str = "one_hot_padded",
        model: Optional[torch.nn.Module] = None
    ) -> None:
        """
        Args:
            mask_size: Number of features kept by ``get_binary_mask``.
            mode: "merlin" or "morgana"; selects the multi-class criterion.
            binary_classification: If True, use ``BCEWithLogitsLoss``.
            idk_class: Index of the IDK class for ``MorganaCriterion``.
            num_blocks: Number of equal blocks the last input dimension is split into.
            enc_type: Input encoding; only "one_hot_padded" is supported.
            model: Module called by ``forward``. It must be set before
                ``forward`` is used.
        """
        super().__init__()

        assert mode in ["merlin", "morgana"], "Mode must be 'merlin' or 'morgana'"
        assert mask_size > 0, "Mask size must be greater than 0"

        self.mask_size = mask_size
        self.mode = mode
        # Stored for parity with the SFW-based selector so callers can inspect it.
        self.binary_classification = binary_classification
        self.idk_class = idk_class
        self.num_blocks = num_blocks
        self.enc_type = enc_type
        self.model = model

        # Set criterion based on classification type
        if binary_classification:
            self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss() if mode == "merlin" else MorganaCriterion(self.idk_class)

    def forward(self, x):
        """Return ``self.model(x)``."""
        return self.model(x)

    def apply_mask(self, x, mask):
        """Scale each block of the last dimension of ``x`` by its mask entry.

        The last dimension of ``x`` is split into ``self.num_blocks`` equal
        blocks. ``mask`` supplies one value per block and must broadcast against
        ``x.shape[:-1] + (num_blocks,)``. The block and within-block dimensions
        are then merged back into one.

        Raises:
            ValueError: If ``self.enc_type`` is not "one_hot_padded".
        """
        if self.enc_type != "one_hot_padded":
            raise ValueError(f"Unsupported input type: {self.enc_type!r}")
        blocks = x.reshape(*x.shape[:-1], self.num_blocks, -1)
        masked = blocks * mask.unsqueeze(-1)  # trailing dim broadcasts over block width
        return masked.reshape(*masked.shape[:-2], -1)

    def get_binary_mask(self, continuous_mask):
        """Return a 0/1 mask keeping the ``mask_size`` largest-magnitude entries per sample."""
        flat = continuous_mask.flatten(start_dim=1)
        top_indices = torch.topk(flat.abs(), k=self.mask_size, dim=1).indices
        binary = torch.zeros_like(flat).scatter_(1, top_indices, 1.0)
        return binary.reshape(continuous_mask.shape)


class MorganaCriterion(torch.nn.Module):
    """Loss minimized by Arthur and maximized by Morgana, with an "I don't know" class.

    For each sample, the true-class logit ``z_t`` is compared with the IDK logit
    ``z_idk`` (column ``idk_class``). If ``z_idk`` wins the argmax, the
    cross-entropy target for that sample becomes ``idk_class``. The returned
    loss is that cross-entropy plus the correction term
    ``-log(1 + exp(-|z_t - z_idk|))``, both reduced with ``reduction``.
    """

    def __init__(self, idk_class, weight: Optional[torch.Tensor] = None, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction
        self.weight = weight  # optional per-class weights forwarded to CrossEntropyLoss
        self.idk_class = idk_class  # column index of the IDK class in the logits

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Returns the loss that is minimized by Arthur and maximized by Morgana.

        Args:
            logits (torch.Tensor): Arthur's output, shape ``[batch, num_classes]``.
            target (torch.Tensor): True class indices (integer), shape ``[batch]``.

        Raises:
            ValueError: Reduction assertion, possible values are `mean`, `sum` and `none`.

        Returns:
            torch.Tensor: A scalar for "mean"/"sum", or a ``[batch]`` tensor for "none".
        """
        # Validate before doing any work.
        if self.reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unexpected value for reduction, got `{self.reduction}`")

        logits_wrt_true_class = torch.gather(logits, dim=1, index=target.unsqueeze(1))  # [batch, 1]
        logits_idk = logits[:, self.idk_class].unsqueeze(1)  # [batch, 1], column idk_class
        logits_concatenated = torch.cat((logits_wrt_true_class, logits_idk), 1)

        # argmax returns the first maximum, so ties keep the true class as target.
        target_cloned = torch.clone(target)
        target_cloned[torch.argmax(logits_concatenated, dim=1) == 1] = self.idk_class
        criterion = torch.nn.CrossEntropyLoss(weight=self.weight, reduction=self.reduction)

        # diff <= 0, so softplus(diff) = log(1 + exp(diff)) computed without precision loss.
        diff = -torch.abs(logits_wrt_true_class - logits_idk)
        per_sample_correction = -torch.nn.functional.softplus(diff).squeeze(1)  # [batch]

        if self.reduction == "mean":
            correction_term = per_sample_correction.mean()
        elif self.reduction == "sum":
            correction_term = per_sample_correction.sum()
        else:  # "none"
            correction_term = per_sample_correction

        return criterion(logits, target_cloned) + correction_term