import torch
import timm
import torch.nn as nn
import sys

class InterpretableConvNextV2(nn.Module):
    """Wrapper around a pretrained timm ConvNeXt model with a masked forward pass.

    In normal mode the wrapped model is called unchanged. In explanation mode
    a binary mask is propagated through the stem and every stage next to the
    activations: each convolution's receptive field is applied to the mask,
    and activations at positions whose propagated mask is zero are zeroed.

    A second, unmasked activation tensor (``running_norm_variant``) is carried
    along; the depthwise convolution and normalisation of every block are
    evaluated on that tensor, so the normalisation statistics are not
    computed from masked activations.

    NOTE(review): the class name says ConvNextV2, but the checkpoint loaded in
    ``__init__`` is ``convnext_tiny.in12k_ft_in1k_384`` - confirm which
    architecture is intended.

    Args:
        add_bias: If True, the convolution bias is added back at positions
            whose propagated mask is zero (stem convolution and every block's
            depthwise convolution).
    """

    def __init__(self, add_bias=False):
        super().__init__()

        self.add_bias = add_bias
        # Load the pretrained timm checkpoint together with its eval-time
        # preprocessing pipeline.
        self.model = timm.create_model('convnext_tiny.in12k_ft_in1k_384', pretrained=True)
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)
        self.model.eval()

    @staticmethod
    def _propagate_mask(conv, mask):
        """Push a binary mask through the receptive field of ``conv``.

        The convolution is temporarily given constant positive weights and a
        zero bias, so an output position is non-zero exactly when at least one
        mask value in its receptive field is non-zero. The original
        parameters are always restored, even if the convolution raises.

        Args:
            conv: Convolution module exposing ``weight`` and ``bias``.
            mask: Mask tensor accepted by ``conv`` (non-zero means "keep").

        Returns:
            A 0/1 tensor with the shape of ``conv``'s output.
        """
        weight, bias = conv.weight, conv.bias
        # The temporary kernel has the dtype of the real weights, so the mask
        # must match it for the convolution to run.
        mask = mask.to(weight.dtype)
        try:
            conv.weight = nn.Parameter(torch.ones_like(weight) / torch.numel(weight[0]))
            if bias is not None:
                conv.bias = nn.Parameter(torch.zeros_like(bias))
            with torch.no_grad():
                response = conv(mask)
        finally:
            # Restore the original parameter objects on every exit path.
            conv.weight = weight
            if bias is not None:
                conv.bias = bias
        return (response != 0).to(response.dtype)

    @staticmethod
    def _masked_bias(mask, bias):
        """Return ``bias[c]`` where ``mask`` is zero and 0 elsewhere.

        ``mask`` is channels-first with the channel on axis 1; the result has
        the same shape as ``mask``.
        """
        per_channel = bias.reshape(1, -1, 1, 1).to(mask.dtype)
        return torch.where(mask == 0, per_channel, torch.zeros_like(per_channel))

    def masked_stage_forward(self, input, stage, mask, running_norm_variant=None):
        """Run one stage while propagating the mask.

        Args:
            input: Masked activations entering the stage (channels-first).
            stage: Stage module exposing ``downsample`` and ``blocks``.
            mask: 0/1 mask with the same shape as ``input``.
            running_norm_variant: Unmasked activations entering the stage.
                Defaults to ``input`` when omitted.

        Returns:
            Tuple ``(output, mask, running_norm_variant)`` after the stage.
        """
        output = input
        if running_norm_variant is None:
            running_norm_variant = input

        if not isinstance(stage.downsample, nn.Identity):
            down_norm, down_conv = stage.downsample[0], stage.downsample[1]

            # The mask only follows the convolution geometry, not the norm.
            mask = self._propagate_mask(down_conv, mask)

            running_norm_variant = down_conv(down_norm(running_norm_variant))
            output = down_conv(down_norm(output))
            output = torch.where(mask != 0, output, torch.zeros_like(output))

        for block in stage.blocks:
            shortcut = output
            shortcut_mask = mask
            running_norm_shortcut = running_norm_variant

            dw_conv = block.conv_dw
            mask = self._propagate_mask(dw_conv, mask)

            bias_vals = None
            if self.add_bias and dw_conv.bias is not None:
                # Built channels-first, then moved to channels-last because it
                # is added to the normalised (channels-last) activations.
                bias_vals = self._masked_bias(mask, dw_conv.bias).permute(0, 2, 3, 1)

            # Depthwise conv and norm are evaluated on the unmasked tensor;
            # the masked branch is obtained by zeroing the normalised result.
            running_norm_variant = dw_conv(running_norm_variant)
            normed = block.norm(running_norm_variant.permute(0, 2, 3, 1))

            mask_last = mask.permute(0, 2, 3, 1)
            output = torch.where(mask_last != 0.0, normed, torch.zeros_like(normed))
            if bias_vals is not None:
                output = output + bias_vals

            # Back to channels-first after the MLP.
            output = block.mlp(output).permute(0, 3, 1, 2)
            running_norm_variant = block.mlp(normed).permute(0, 3, 1, 2)

            if not isinstance(block.shortcut, nn.Identity):
                shortcut = block.shortcut(shortcut)
                # NOTE(review): the unmasked shortcut is taken from the masked
                # path here rather than block.shortcut(running_norm_shortcut);
                # kept as in the original - confirm this is intended.
                running_norm_shortcut = shortcut.clone()
                shortcut = torch.where(shortcut_mask != 0.0, shortcut, torch.zeros_like(shortcut))

            if not isinstance(block.drop_path, nn.Identity):
                output = block.drop_path(output)
                running_norm_variant = block.drop_path(running_norm_variant)
                output = torch.where(mask != 0.0, output, torch.zeros_like(output))

            if block.gamma is not None:
                scale = block.gamma.reshape(1, -1, 1, 1)
                output = output.mul(scale)
                running_norm_variant = running_norm_variant.mul(scale)
                output = torch.where(mask != 0.0, output, torch.zeros_like(output))

            output = output + shortcut
            running_norm_variant = running_norm_variant + running_norm_shortcut

        return output, mask, running_norm_variant

    def forward(self, x,
                explanation_mode=False,
                masking_value=None,
                explanation_mask=None):
        """Run the model, optionally with mask propagation.

        Args:
            x: Input batch (channels-first).
            explanation_mode: If False, ``self.model(x)`` is returned
                unchanged. If True, the masked forward pass is used.
            masking_value: If given, positions where channel 1 of ``x`` equals
                this value are treated as masked. Takes precedence over
                ``explanation_mask``.
            explanation_mask: Optional mask (non-zero means "keep"). Accepted
                with or without a channel axis; a missing or single channel
                is broadcast to the channels of ``x``. If neither this nor
                ``masking_value`` is given, nothing is masked.

        Returns:
            The output of the model head. Gradients are never tracked.
        """
        if not explanation_mode:
            with torch.no_grad():
                return self.model(x)

        with torch.no_grad():
            if masking_value is not None:
                explanation_mask = (x[:, 1:2, :, :] != masking_value).to(x.dtype)
            elif explanation_mask is None:
                explanation_mask = torch.ones_like(x)

            # Bring the mask to the same channels-first shape as x.
            if explanation_mask.dim() == 3:
                explanation_mask = explanation_mask.unsqueeze(1)
            if explanation_mask.shape[1] == 1:
                explanation_mask = explanation_mask.expand(-1, x.shape[1], -1, -1)
            explanation_mask = explanation_mask.to(x.dtype).contiguous()

            stem_conv = self.model.stem[0]
            explanation_mask = self._propagate_mask(stem_conv, explanation_mask)

            # Unmasked stem activations, kept so later norms see real values.
            # NOTE(review): self.model.stem[1] is not applied to the tensors
            # passed on to the stages (the original computed it and discarded
            # the result) - confirm that skipping it is intended.
            running_norm_variant = stem_conv(x)

            output = torch.where(explanation_mask == 0,
                                 torch.zeros_like(running_norm_variant),
                                 running_norm_variant)

            # The bias is non-zero, so masked positions either stay at zero
            # or receive the bias a convolution over zeros would produce.
            if self.add_bias and stem_conv.bias is not None:
                output = output + self._masked_bias(explanation_mask, stem_conv.bias)

            for stage in self.model.stages:
                output, explanation_mask, running_norm_variant = self.masked_stage_forward(
                    input=output,
                    stage=stage,
                    mask=explanation_mask,
                    running_norm_variant=running_norm_variant,
                )

            output = self.model.norm_pre(output)
            return self.model.head(output)
