import torch
import torch.nn as nn
import torch.nn.functional as F
from merlin_arthur_framework.stochastic_frank_wolfe import SFW, PositiveKSparsePolytope
from merlin_arthur_framework.feature_selectors import MorganaCriterion

class SFWPixelFeatureSelector(torch.nn.Module):
    """Per-batch pixel-mask feature selector optimised with Stochastic Frank-Wolfe.

    For every call to ``forward`` a fresh mask of shape ``(batch, 1, H, W)`` is
    optimised against a frozen classifier, constrained to a
    ``PositiveKSparsePolytope`` with sparsity ``mask_size``. In "merlin" mode the
    classification loss is minimised; in "morgana" mode it is maximised (the
    negated criterion is minimised).
    """

    def __init__(
        self,
        mask_size: int,
        mode: str = "merlin",
        lr: float = 0.1,
        binary_classification: bool = False,
        l1_penalty_coefficient: float = 0.1,
        idk_class: int = 3,
        sfw_max_iterations: int = 350,
        sfw_patience: int = 10
    ) -> None:
        """
        Simple feature selector using Stochastic Frank-Wolfe optimization.

        Args:
            mask_size: Number of features (pixels) to select (sparsity level ``k``).
            mode: "merlin" (minimize loss) or "morgana" (maximize loss).
            lr: Learning rate (step size) passed to the SFW optimizer.
            binary_classification: If True, the classifier is expected to emit a
                single logit per sample and ``BCEWithLogitsLoss`` is used in both
                modes; otherwise ``CrossEntropyLoss`` (merlin) or
                ``MorganaCriterion`` (morgana) is used.
            l1_penalty_coefficient: Weight of the mean-absolute-value penalty on
                the mask; ``None`` disables the penalty.
            idk_class: Index of the IDK class, forwarded to ``MorganaCriterion``.
            sfw_max_iterations: Maximum number of SFW iterations per call.
            sfw_patience: Number of consecutive non-improving iterations
                (improvement < 1e-5) after which optimisation stops early.
        """
        super().__init__()
        assert mode in ["merlin", "morgana"], "Mode must be 'merlin' or 'morgana'"
        assert mask_size > 0, "Mask size must be greater than 0"

        self.mask_size = mask_size
        self.mode = mode
        self.lr = lr
        self.binary_classification = binary_classification
        self.l1_penalty_coefficient = l1_penalty_coefficient
        self.idk_class = idk_class
        self.sfw_max_iterations = sfw_max_iterations
        self.sfw_patience = sfw_patience
        # Kept for backward compatibility; forward() places the mask on the
        # input's device instead so that inputs on any device work.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set criterion based on classification type
        if binary_classification:
            self.criterion = torch.nn.BCEWithLogitsLoss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss() if mode == "merlin" else MorganaCriterion(self.idk_class)

    def forward(self, x, y, classifier, init_mask=None):
        """
        Optimize a pixel mask for a batch of images using SFW.

        The classifier is frozen (``requires_grad=False``, ``eval()``) during the
        optimisation, and its previous per-parameter ``requires_grad`` flags and
        per-module training flags are restored afterwards, even on error.

        Args:
            x: Input images [batch_size, channels, height, width].
            y: Target labels [batch_size].
            classifier: Model used for predictions on the masked input.
            init_mask: Initial mask [batch_size, 1, height, width]; uniform random
                if None.

        Returns:
            The optimized mask, detached, on ``x.device``.
        """
        batch_size = y.shape[0]
        device = x.device

        if init_mask is None:
            # Drawn on the CPU generator and then moved, as before, so seeded runs
            # keep the same random stream.
            init_mask = torch.rand(x.shape[0], 1, x.shape[2], x.shape[3]).to(device)

        constraint = PositiveKSparsePolytope(
            n=x.shape[2] * x.shape[3],  # Total number of pixels in image
            bs=batch_size,
            k=self.mask_size
        )
        init_mask = constraint.shift_inside(init_mask)

        mask = torch.nn.Parameter(init_mask.to(device), requires_grad=True)
        optimizer = SFW([mask], learning_rate=self.lr, momentum=0.9, rescale=None)

        saved_state = self._freeze_classifier(classifier)
        try:
            best_loss = float('inf')
            patience_counter = 0
            for _ in range(self.sfw_max_iterations):
                optimizer.zero_grad()
                loss = self._objective(x, y, mask, classifier)
                loss.backward()
                optimizer.step(constraints=[constraint])

                # Early stopping: require an improvement of at least 1e-5.
                loss_value = loss.item()
                if loss_value < best_loss - 1e-5:
                    best_loss = loss_value
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= self.sfw_patience:
                        break
        finally:
            self._restore_classifier(classifier, saved_state)

        return mask.detach()

    def _objective(self, x, y, mask, classifier):
        """Return the scalar loss SFW minimises for the current ``mask``."""
        logits = classifier(self.apply_mask(x, mask))

        if self.binary_classification:
            # Single-logit output: drop the class dim, BCE needs float targets.
            logits = logits.squeeze(1)
            target = y.float()
        else:
            target = y

        distortion = self.criterion(logits, target)
        if self.mode == "morgana":
            # Morgana maximises the criterion, i.e. minimises its negation.
            distortion = -distortion

        if self.l1_penalty_coefficient is None:
            return distortion
        return distortion + self.l1_penalty_coefficient * torch.mean(torch.abs(mask))

    @staticmethod
    def _freeze_classifier(classifier):
        """Disable gradients and switch to eval mode; return the state to restore."""
        saved = (
            [module.training for module in classifier.modules()],
            [param.requires_grad for param in classifier.parameters()],
        )
        for param in classifier.parameters():
            param.requires_grad = False
        classifier.eval()
        return saved

    @staticmethod
    def _restore_classifier(classifier, saved):
        """Undo ``_freeze_classifier`` using the state it returned."""
        training_flags, grad_flags = saved
        for module, flag in zip(classifier.modules(), training_flags):
            module.training = flag
        for param, flag in zip(classifier.parameters(), grad_flags):
            param.requires_grad = flag

    def apply_mask(self, x, mask):
        """Blend input and noise: ``mask * x + (1 - mask) * U[0, 1)`` noise."""
        # torch.rand_like draws uniform (not gaussian) noise in [0, 1).
        x_masked = mask * x + (1 - mask) * torch.rand_like(x)
        return x_masked

    def get_binary_mask(self, continuous_mask):
        """Convert a continuous mask to a {0, 1} mask keeping the ``mask_size``
        entries of largest absolute value per sample."""
        v = torch.zeros_like(continuous_mask).flatten(start_dim=1)
        max_indices = torch.topk(torch.abs(continuous_mask.flatten(start_dim=1)), k=self.mask_size).indices.to(continuous_mask.device)
        v.scatter_(1, max_indices, 1.0)
        return v.reshape(continuous_mask.shape)

# Feature selector whose mask comes from a learned network (e.g. a U-Net)
class ModelPixelFeatureSelector(torch.nn.Module):
    """Feature selector whose mask is produced by a network instead of SFW.

    ``forward`` simply runs ``model`` on the input. The helper methods turn a
    continuous mask into a masked input, a binary top-k mask, or an
    L1-budgeted mask.

    Args:
        mask_size: Number of pixels kept by ``get_binary_mask`` (top-k).
        mode: "merlin" (``CrossEntropyLoss`` criterion) or "morgana"
            (``MorganaCriterion``).
        idk_class: Index of the IDK class, forwarded to ``MorganaCriterion``.
        model: Network mapping an input batch to a mask. Defaults to ``None``,
            but ``forward`` requires it to be set.
    """

    def __init__(
        self,
        mask_size: int,
        mode: str = "merlin",
        idk_class: int = 3,
        model: torch.nn.Module = None
    ) -> None:
        super().__init__()

        assert mode in ["merlin", "morgana"], "Mode must be 'merlin' or 'morgana'"
        assert mask_size > 0, "Mask size must be greater than 0"

        self.mask_size = mask_size
        self.mode = mode
        self.idk_class = idk_class
        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss() if mode == "merlin" else MorganaCriterion(self.idk_class)

    def forward(self, x):
        """Return ``self.model(x)``.

        Raises:
            TypeError: If no model was provided to the constructor.
        """
        if self.model is None:
            # Same exception type as calling None, but with an actionable message.
            raise TypeError("ModelPixelFeatureSelector.model is None; pass a model to the constructor")
        return self.model(x)

    def apply_mask(self, x, mask):
        """Blend input and noise: ``mask * x + (1 - mask) * U[0, 1)`` noise."""
        # torch.rand_like draws uniform (not gaussian) noise in [0, 1).
        x_masked = mask * x + (1 - mask) * torch.rand_like(x)
        return x_masked

    def get_binary_mask(self, continuous_mask):
        """Convert a continuous mask to a {0, 1} mask keeping the ``mask_size``
        entries of largest absolute value per sample."""
        v = torch.zeros_like(continuous_mask).flatten(start_dim=1)
        max_indices = torch.topk(torch.abs(continuous_mask.flatten(start_dim=1)), k=self.mask_size).indices.to(continuous_mask.device)
        v.scatter_(1, max_indices, 1.0)
        return v.reshape(continuous_mask.shape)

    def normalize_l1(self, input: torch.Tensor, mask_size: int):
        """Scale each map so its L1 norm over dims 2 and 3 is at most ``mask_size``.

        Maps whose L1 norm is already within the budget are returned unchanged
        (the scale factor is clamped to 1); the 1e-7 guards against division by 0.
        """
        factor = torch.clamp(mask_size / (1e-7 + torch.norm(input, p=1, dim=(2, 3), keepdim=True)), max=1)  # type: ignore
        return factor * input

# U-Net model used by the U-Net-based feature selector
class SimpleNet(nn.Module):
    """U-Net producing a single-channel map with the input's spatial size.

    Args:
        n_channels: Number of channels of the input images.
        bilinear: If True, decoder stages upsample with parameter-free bilinear
            interpolation; otherwise with learned transposed convolutions.
        apply_sigmoid: If True, ``forward`` returns ``sigmoid(logits)`` instead
            of the raw logits.
    """

    def __init__(self, n_channels: int, bilinear: bool = True, apply_sigmoid: bool = False):
        super().__init__()
        self.n_channels = n_channels
        self.bilinear = bilinear
        self.apply_sigmoid = apply_sigmoid

        self.inc = DoubleConv(self.n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear upsampling keeps the channel count, so the bottleneck and
        # decoder widths are halved to make each skip concatenation add up to
        # the ``in_channels`` the next Up stage expects.
        factor = 2 if self.bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, self.bilinear)
        self.up2 = Up(512, 256 // factor, self.bilinear)
        self.up3 = Up(256, 128 // factor, self.bilinear)
        self.up4 = Up(128, 64, self.bilinear)
        self.out_conv = OutConv(64, 1)
        # NOTE(review): ``lin`` is not used in forward(); it is presumably kept so
        # that existing checkpoints containing ``lin.*`` still load strictly.
        self.lin = nn.Linear(32 * 32, 784)

    def forward(self, x):
        """Return the (N, 1, H, W) output map, sigmoid-activated if requested."""
        # Comments give (channels, height x width) for bilinear=True and a
        # 28x28 input. With bilinear=False, x5 and the outputs of up1-up3 have
        # twice as many channels.
        x1 = self.inc(x)           # 64, 28x28
        x2 = self.down1(x1)        # 128, 14x14
        x3 = self.down2(x2)        # 256, 7x7
        x4 = self.down3(x3)        # 512, 3x3
        x5 = self.down4(x4)        # 512, 1x1
        x = self.up1(x5, x4)       # 256, 3x3
        x = self.up2(x, x3)        # 128, 7x7
        x = self.up3(x, x2)        # 64, 14x14
        x = self.up4(x, x1)        # 64, 28x28
        logits = self.out_conv(x)  # 1, 28x28
        return torch.sigmoid(logits) if self.apply_sigmoid else logits


class DoubleConv(nn.Module):
    """Two stacked (3x3 convolution -> BatchNorm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged. ``mid_channels`` is the
    width of the intermediate feature map; when it is falsy (``None`` or 0)
    ``out_channels`` is used instead.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        def conv_bn_relu(c_in, c_out):
            # One stage; built in the same order so layer indices are stable.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        layers = conv_bn_relu(in_channels, hidden) + conv_bn_relu(hidden, out_channels)
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample 2x, pad to the skip tensor's size, concatenate, DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            # Learned 2x upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Parameter-free upsampling; the DoubleConv's middle layer reduces
            # the channel count instead.
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Tensors are NCHW: dim 2 is height, dim 3 is width. Pad the upsampled
        # map (roughly centred) so it matches the skip connection x2.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        left = dx // 2
        top = dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution projecting features to ``out_channels`` channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected