import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from gym import spaces
from habitat import logger
from habitat_baselines.rl.ddppo.policy import resnet
from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder
import torchvision
import clip
from typing import Tuple


class ConvAdapter(nn.Module):
    """
    Residual bottleneck adapter made of 1x1 convolutions.

    The input is normalised by a single-group GroupNorm, squeezed to
    ``mid_ch`` channels, passed through GELU, expanded back to ``in_ch``
    channels, dropped out channel-wise and finally added to the input.
    The expanding convolution starts with all-zero weights, so a freshly
    constructed adapter returns its input unchanged.

    Args:
        in_ch: Number of input (and output) channels.
        mid_ch: Number of channels in the bottleneck.
        p_drop: Probability used by the channel-wise ``Dropout2d``.
    """

    def __init__(self, in_ch: int, mid_ch: int = 64, p_drop: float = 0.1):
        super().__init__()
        # Submodule names and creation order are kept fixed: they define the
        # state_dict keys and the order in which random initialisation is drawn.
        self.norm = nn.GroupNorm(1, in_ch)
        self.down = nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False)
        self.act = nn.GELU()
        self.up = nn.Conv2d(mid_ch, in_ch, kernel_size=1, bias=False)
        self.drop = nn.Dropout2d(p_drop)

        # Zero the expanding projection so the residual branch contributes
        # nothing until training moves these weights away from zero.
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        branch = x
        for stage in (self.norm, self.down, self.act, self.up):
            branch = stage(branch)
        return x + self.drop(branch)


class VlnResnetDepthEncoder(nn.Module):
    """
    Depth encoder built on habitat_baselines' ``ResNetEncoder`` with optional
    residual ``ConvAdapter`` modules attached to backbone blocks.

    Args:
        observation_space: Observation space; only its ``"depth"`` entry is
            passed to the ResNet encoder.
        output_size: Stored in ``output_shape``. NOTE(review): the spatial
            branch always projects to 128 channels, independent of this value.
        checkpoint: Path of a DD-PPO checkpoint whose ``visual_encoder``
            weights are loaded, or ``"NONE"`` to keep the fresh weights.
        backbone: Name of the ``habitat_baselines`` resnet factory to build.
        resnet_baseplanes: Base width of the ResNet (``ngroups`` is set to
            half of it).
        normalize_visual_inputs: Forwarded to ``ResNetEncoder``.
        trainable: Initial ``requires_grad`` of the base encoder parameters.
        spatial_output: If True, ``forward`` adds learned positional
            embeddings, applies a 1x1 projection to 128 channels and
            average-pools to ``[BATCH, 128]``.
        use_depth_adapters: Attach a ``ConvAdapter`` after each block of the
            selected stages.
        depth_adapter_width: Bottleneck width of each adapter.
        depth_adapter_dropout: Dropout probability inside each adapter.
        depth_adapter_stages: Backbone stages (1-4) that receive adapters.
        train_ln_with_adapters: Whether ``freeze_base_parameters`` leaves the
            adapters' normalization parameters trainable.
    """

    # Attribute names of the backbone stages, in order (stage 1 .. stage 4).
    _STAGE_NAMES = ("layer1", "layer2", "layer3", "layer4")

    # Fallback block output widths, used only when a block's width cannot be
    # read from its conv layers. NOTE(review): these assume the habitat ResNet
    # with baseplanes=32 and 4x bottleneck expansion (a debug run recorded
    # 256 channels after layer2); confirm before relying on other configs.
    _FALLBACK_STAGE_CHANNELS = {1: 128, 2: 256, 3: 512, 4: 1024}

    # Output width of the spatial projection, forced to 128 to match the
    # ETP model's expectations.
    _SPATIAL_PROJECTION_CHANNELS = 128

    def __init__(
        self,
        observation_space,
        output_size=128,
        checkpoint="NONE",
        backbone="resnet50",
        resnet_baseplanes=32,
        normalize_visual_inputs=False,
        trainable=False,
        spatial_output: bool = False,
        # Adapter parameters
        use_depth_adapters: bool = False,
        depth_adapter_width: int = 64,
        depth_adapter_dropout: float = 0.1,
        depth_adapter_stages: Tuple[int, ...] = (2, 3),
        train_ln_with_adapters: bool = True,
    ):
        super().__init__()
        self.use_depth_adapters = use_depth_adapters
        self.train_ln_with_adapters = train_ln_with_adapters

        self.visual_encoder = ResNetEncoder(
            spaces.Dict({"depth": observation_space.spaces["depth"]}),
            baseplanes=resnet_baseplanes,
            ngroups=resnet_baseplanes // 2,
            make_backbone=getattr(resnet, backbone),
            normalize_visual_inputs=normalize_visual_inputs,
        )

        # Store original trainable setting for compatibility. Adapters are
        # created after this loop, so they stay trainable regardless of it.
        self.original_trainable = trainable
        for param in self.visual_encoder.parameters():
            param.requires_grad_(trainable)

        if use_depth_adapters:
            self._add_adapters(depth_adapter_stages, depth_adapter_width, depth_adapter_dropout)
            self._register_adapter_hooks()
            logger.info(f"Added depth adapters to stages {depth_adapter_stages} with width {depth_adapter_width}")

        if checkpoint != "NONE":
            self._load_ddppo_checkpoint(checkpoint)

        self.spatial_output = spatial_output

        if not self.spatial_output:
            self.output_shape = (output_size,)
        else:
            # One learned 64-dim embedding per cell of a 4x4 feature map.
            self.spatial_embeddings = nn.Embedding(4 * 4, 64)
            self.projection_layer = self._build_initial_projection()
            self.output_shape = (output_size, 4, 4)

    def _build_initial_projection(self):
        """Create the spatial 1x1 projection up front when possible.

        Building it here (rather than lazily in ``forward``) registers its
        parameters before any optimizer is constructed, and lets checkpoints
        that contain ``projection_layer`` weights load. Returns None when the
        encoder exposes no usable ``output_shape``; ``forward`` then creates
        the layer on first use.
        """
        encoder_shape = getattr(self.visual_encoder, "output_shape", None)
        if not isinstance(encoder_shape, (tuple, list)) or len(encoder_shape) == 0:
            return None
        in_channels = int(encoder_shape[0]) + self.spatial_embeddings.embedding_dim
        return nn.Conv2d(in_channels, self._SPATIAL_PROJECTION_CHANNELS, kernel_size=1)

    def _load_ddppo_checkpoint(self, checkpoint: str) -> None:
        """Load ``visual_encoder`` weights from a DD-PPO checkpoint file.

        Keys of the checkpoint's ``state_dict`` are expected to look like
        ``<a>.<b>.visual_encoder.<rest>``. The first two components are
        dropped, and ``<rest>`` is loaded into ``self.visual_encoder``.

        Raises:
            RuntimeError: if base-encoder weights are missing or unexpected.
                Adapter parameters are the only keys allowed to be absent.
        """
        # Map to CPU so checkpoints saved on a GPU can be read on any host;
        # load_state_dict copies the tensors into the module's own parameters.
        ddppo_weights = torch.load(checkpoint, map_location="cpu")

        weights_dict = {}
        for key, value in ddppo_weights["state_dict"].items():
            parts = key.split(".")[2:]
            if len(parts) < 2 or parts[0] != "visual_encoder":
                continue
            weights_dict[".".join(parts[1:])] = value
        del ddppo_weights

        if not self.use_depth_adapters:
            self.visual_encoder.load_state_dict(weights_dict, strict=True)
            return

        # The checkpoint has no adapter parameters, so their keys are allowed
        # to be missing; every other key must still match exactly.
        result = self.visual_encoder.load_state_dict(weights_dict, strict=False)
        missing = [k for k in result.missing_keys if ".depth_adapter." not in k]
        unexpected = list(result.unexpected_keys)
        if missing or unexpected:
            raise RuntimeError(
                "Depth encoder checkpoint does not match the encoder: "
                f"missing={missing}, unexpected={unexpected}"
            )

    @staticmethod
    def _block_out_channels(block):
        """Return a block's output width read from its last conv, or None."""
        # Standard Bottleneck (conv3) first, then BasicBlock (conv2).
        for name in ("conv3", "conv2"):
            conv = getattr(block, name, None)
            if conv is not None and hasattr(conv, "out_channels"):
                return conv.out_channels
        # habitat_baselines blocks keep their convolutions in a `convs` Sequential.
        convs = getattr(block, "convs", None)
        if isinstance(convs, nn.Sequential):
            for module in reversed(convs):
                if isinstance(module, nn.Conv2d):
                    return module.out_channels
        return None

    def _add_adapters(self, stages: Tuple[int, ...], width: int, dropout: float):
        """Attach a ``ConvAdapter`` as ``block.depth_adapter`` to every block of ``stages``."""
        backbone = self.visual_encoder.backbone
        stage_layers = {
            idx: getattr(backbone, name) for idx, name in enumerate(self._STAGE_NAMES, start=1)
        }

        for stage_idx in stages:
            if stage_idx not in stage_layers:
                logger.warning(f"Invalid adapter stage {stage_idx}, skipping")
                continue

            for i, block in enumerate(stage_layers[stage_idx]):
                out_channels = self._block_out_channels(block)
                if out_channels is None:
                    out_channels = self._FALLBACK_STAGE_CHANNELS.get(stage_idx)

                if out_channels is None:
                    logger.warning(f"Could not determine output channels for block {i} in stage {stage_idx}")
                    continue

                block.depth_adapter = ConvAdapter(out_channels, width, dropout)
                logger.debug(f"Added adapter to stage {stage_idx}, block {i} with {out_channels} channels")

    def _adapter_blocks(self):
        """Yield every backbone block that carries a ``depth_adapter``."""
        if not self.use_depth_adapters:
            return
        backbone = self.visual_encoder.backbone
        for name in self._STAGE_NAMES:
            for block in getattr(backbone, name):
                if hasattr(block, "depth_adapter"):
                    yield block

    @staticmethod
    def _apply_depth_adapter_hook(module, inputs, output):
        """Forward hook: replace a block's output with its adapter's output."""
        adapter = getattr(module, "depth_adapter", None)
        return output if adapter is None else adapter(output)

    def _register_adapter_hooks(self):
        """Register forward hooks to apply adapters after block outputs."""
        for block in self._adapter_blocks():
            block.register_forward_hook(self._apply_depth_adapter_hook)

    def _apply_adapters_to_block_output(self, x, block):
        """Apply adapter to block output if it exists."""
        if hasattr(block, 'depth_adapter') and self.use_depth_adapters:
            return block.depth_adapter(x)
        return x

    def freeze_base_parameters(self):
        """Freeze all base ResNet parameters, keep only adapters trainable.

        Adapter normalization parameters remain trainable only when
        ``train_ln_with_adapters`` is True.
        """
        # Adapters live inside visual_encoder, so this freezes them too.
        for param in self.visual_encoder.parameters():
            param.requires_grad_(False)

        for block in self._adapter_blocks():
            adapter = block.depth_adapter
            for param in adapter.parameters():
                param.requires_grad_(True)
            if not self.train_ln_with_adapters and hasattr(adapter, "norm"):
                for param in adapter.norm.parameters():
                    param.requires_grad_(False)

        logger.info("Froze base depth encoder parameters, kept adapters trainable")

    def get_adapter_parameters(self):
        """Get all adapter parameters for optimizer."""
        adapter_params = []
        for block in self._adapter_blocks():
            adapter_params.extend(block.depth_adapter.parameters())
        return adapter_params

    def count_adapter_parameters(self):
        """Count trainable adapter parameters."""
        return sum(p.numel() for p in self.get_adapter_parameters() if p.requires_grad)

    def forward(self, observations):
        """
        Encode the depth observation.

        Args:
            observations: Dict of observation tensors. If it contains
                ``"depth_features"``, those precomputed features are used
                as-is. Otherwise the dict is passed to ``self.visual_encoder``.
        Returns:
            ``[BATCH, 128]`` pooled features when ``spatial_output`` is True,
            otherwise the raw encoder output.
        """
        if "depth_features" in observations:
            x = observations["depth_features"]
        else:
            x = self.visual_encoder(observations)

        if not self.spatial_output:
            return x

        b, _, h, w = x.size()
        positions = torch.arange(
            0,
            self.spatial_embeddings.num_embeddings,
            device=x.device,
            dtype=torch.long,
        )
        # The [16, 64] embedding table is reinterpreted as [1, 64, h, w]
        # without transposing (kept so trained weights keep their meaning).
        # This expects a 4x4 feature map.
        spatial_features = (
            self.spatial_embeddings(positions)
            .view(1, -1, h, w)
            .expand(b, self.spatial_embeddings.embedding_dim, h, w)
        )
        combined_features = torch.cat([x, spatial_features], dim=1)

        in_channels = combined_features.size(1)
        if self.projection_layer is None or self.projection_layer.in_channels != in_channels:
            if self.projection_layer is not None:
                logger.warning(
                    f"Re-creating depth projection layer: input channels changed from "
                    f"{self.projection_layer.in_channels} to {in_channels}; previous weights are discarded"
                )
            self.projection_layer = nn.Conv2d(
                in_channels,
                self._SPATIAL_PROJECTION_CHANNELS,
                kernel_size=1,
            ).to(combined_features.device)

        projected_features = self.projection_layer(combined_features)
        # Global average pooling to [B, 128].
        return F.adaptive_avg_pool2d(projected_features, (1, 1)).view(projected_features.size(0), -1)


class TorchVisionResNet50(nn.Module):
    r"""
    Takes in observations and produces an embedding of the rgb component
    using a frozen, ImageNet-pretrained torchvision ResNet-50.

    Args:
        observation_space: The observation_space of the agent; the shape of
            its ``"rgb"`` entry (if present) is read as [HEIGHT, WIDTH, CHANNELS]
        output_size: Channel count of the spatial projection; also stored in
            ``output_shape``
        device: torch.device (stored on ``self.device``)
        spatial_output: If True, pool the feature map to 4x4, append learned
            positional embeddings and project with a 1x1 conv to ``output_size``
        use_rgb_adapters: Whether to attach ``ConvAdapter`` modules after the
            blocks of the selected ResNet stages
        rgb_adapter_width: Bottleneck width for RGB adapters
        rgb_adapter_dropout: Dropout probability for RGB adapters
        rgb_adapter_stages: Which ResNet stages (1-4) to add adapters to
        train_ln_with_adapters: Whether ``freeze_base_parameters`` leaves the
            adapters' own normalization parameters trainable
        use_rgb_fusion_gate: If True, scale the feature map by a learned scalar
            gate in [0, 1] computed from the globally pooled features
        train_rgb_norm_with_adapters: Alias for ``train_ln_with_adapters``;
            overrides it when not None
    """

    # Output channels of torchvision ResNet-50 stages layer1..layer4.
    _STAGE_CHANNELS = (256, 512, 1024, 2048)

    def __init__(
        self,
        observation_space,
        output_size,
        device,
        spatial_output: bool = False,
        # RGB Adapter parameters
        use_rgb_adapters: bool = False,
        rgb_adapter_width: int = 64,
        rgb_adapter_dropout: float = 0.1,
        rgb_adapter_stages: Tuple[int, ...] = (2, 3),
        train_ln_with_adapters: bool = True,
        # Optional corruption-aware fusion gate and alias for norm flag
        use_rgb_fusion_gate: bool = False,
        train_rgb_norm_with_adapters: bool = None,
    ):
        super().__init__()
        self.device = device
        self.resnet_layer_size = 2048
        self.use_rgb_adapters = use_rgb_adapters
        self.train_ln_with_adapters = train_ln_with_adapters
        # Back-compat alias: allow config to supply train_rgb_norm_with_adapters
        if train_rgb_norm_with_adapters is not None:
            self.train_ln_with_adapters = train_rgb_norm_with_adapters
        self.use_rgb_fusion_gate = use_rgb_fusion_gate
        self.spatial_output = spatial_output

        if "rgb" in observation_space.spaces:
            rgb_shape = observation_space.spaces["rgb"].shape
            self._n_input_rgb = rgb_shape[2]
            if rgb_shape[0] != 224 or rgb_shape[1] != 224:
                logger.warning(
                    "TorchVisionResNet50: observation size is not conformant to expected ResNet input size [3x224x224]"
                )
        else:
            self._n_input_rgb = 0

        if self.is_blind:
            self.cnn = nn.Sequential()
            return

        rgb_resnet = models.resnet50(pretrained=True)
        # Keep everything except the final avgpool and fc. These are the same
        # module objects as in rgb_resnet, so adapters and hooks attached via
        # self.backbone also act when self.cnn runs.
        self.cnn = torch.nn.Sequential(*list(rgb_resnet.children())[:-2])
        self.backbone = rgb_resnet

        # Disable gradients for the pretrained weights.
        for param in self.cnn.parameters():
            param.requires_grad_(False)
        # NOTE(review): a later .train() on a parent module switches self.cnn
        # (including its BatchNorm layers) back to training mode.
        self.cnn.eval()

        if use_rgb_adapters:
            self._add_rgb_adapters(rgb_adapter_stages, rgb_adapter_width, rgb_adapter_dropout)
            self._register_rgb_adapter_hooks()
            logger.info(f"Added RGB adapters to stages {rgb_adapter_stages} with width {rgb_adapter_width}")

        # Optional corruption-aware fusion gate (scales RGB feature map by [0,1])
        if self.use_rgb_fusion_gate:
            self.rgb_gate = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.resnet_layer_size, 1, 1, bias=True),
                nn.Sigmoid(),
            )

        if not self.spatial_output:
            self.output_shape = (output_size,)
        else:
            # Assigning modules as attributes of an nn.Sequential appends them,
            # so self.cnn now ends with a 4x4 adaptive average pool followed by
            # an empty (identity) Sequential.
            self.cnn.avgpool = nn.AdaptiveAvgPool2d((4, 4))
            self.cnn.fc = nn.Sequential()
            self.spatial_embeddings = nn.Embedding(4 * 4, 64)
            # 1x1 conv mapping [ResNet features + positional embedding] to output_size.
            self.projection_layer = nn.Conv2d(
                self.resnet_layer_size + self.spatial_embeddings.embedding_dim,
                output_size,
                kernel_size=1,
            )
            self.output_shape = (output_size, 4, 4)

        from torchvision import transforms

        # Float conversion followed by ImageNet normalization statistics.
        self.rgb_transform = torch.nn.Sequential(
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        )

    def _stage_layers(self):
        """Return the backbone's residual stages layer1..layer4, in order."""
        backbone = self.backbone
        return [backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4]

    def _rgb_adapter_blocks(self):
        """Yield every backbone block that carries an ``rgb_adapter``."""
        if not self.use_rgb_adapters:
            return
        for layer in self._stage_layers():
            for block in layer:
                if hasattr(block, "rgb_adapter"):
                    yield block

    def _add_rgb_adapters(self, stages: Tuple[int, ...], adapter_width: int, dropout: float):
        """Add RGB ConvAdapter modules to every block of the given ResNet stages."""
        if not self.use_rgb_adapters:
            return

        stage_layers = self._stage_layers()
        adapter_count = 0
        for stage_idx in stages:
            if not 1 <= stage_idx <= 4:
                logger.warning(f"Invalid RGB adapter stage {stage_idx}, skipping")
                continue
            channels = self._STAGE_CHANNELS[stage_idx - 1]
            for block in stage_layers[stage_idx - 1]:
                block.rgb_adapter = ConvAdapter(channels, adapter_width, dropout)
                adapter_count += 1

        logger.info(f"Added {adapter_count} RGB adapters across {len(stages)} stages")

    def _register_rgb_adapter_hooks(self):
        """Register forward hooks that route block outputs through their RGB adapter."""
        if not self.use_rgb_adapters:
            return

        def hook_fn(module, input_tensor, output_tensor):
            # use_rgb_adapters is re-read on every call, so clearing the flag
            # later disables the adapters at runtime.
            adapter = getattr(module, "rgb_adapter", None)
            if adapter is not None and self.use_rgb_adapters:
                return adapter(output_tensor)
            return output_tensor

        for block in self._rgb_adapter_blocks():
            block.register_forward_hook(hook_fn)

    def freeze_base_parameters(self):
        """Freeze all base ResNet parameters, keep RGB adapters, projection and gate trainable.

        Adapter normalization parameters remain trainable only when
        ``train_ln_with_adapters`` is True.
        """
        for param in self.cnn.parameters():
            param.requires_grad_(False)
        # Also freeze backbone parameters (including the unused fc layer).
        for param in self.backbone.parameters():
            param.requires_grad_(False)

        # Keep the projection trainable; it adapts ResNet output to the expected format.
        if hasattr(self, "projection_layer"):
            for param in self.projection_layer.parameters():
                param.requires_grad_(True)

        for block in self._rgb_adapter_blocks():
            adapter = block.rgb_adapter
            for param in adapter.parameters():
                param.requires_grad_(True)
            if not self.train_ln_with_adapters and hasattr(adapter, "norm"):
                for param in adapter.norm.parameters():
                    param.requires_grad_(False)

        if hasattr(self, "rgb_gate"):
            for param in self.rgb_gate.parameters():
                param.requires_grad_(True)

        logger.info("Froze base RGB encoder parameters, kept adapters trainable")

    def get_rgb_adapter_parameters(self):
        """Get RGB adapter (plus projection and fusion-gate) parameters for the optimizer."""
        adapter_params = []
        if self.use_rgb_adapters:
            for block in self._rgb_adapter_blocks():
                adapter_params.extend(block.rgb_adapter.parameters())
            # Projection parameters are included only alongside adapters.
            if hasattr(self, "projection_layer"):
                adapter_params.extend(self.projection_layer.parameters())

        if hasattr(self, "rgb_gate"):
            adapter_params.extend(self.rgb_gate.parameters())

        return adapter_params

    def count_rgb_adapter_parameters(self):
        """Count trainable RGB adapter parameters (0 when adapters are disabled)."""
        if not self.use_rgb_adapters:
            return 0
        return sum(p.numel() for p in self.get_rgb_adapter_parameters() if p.requires_grad)

    @property
    def is_blind(self):
        return self._n_input_rgb == 0

    def forward(self, observations):
        r"""Encode the rgb observation with the frozen ResNet-50.

        Args:
            observations: Dict of tensors. ``"rgb_features"`` (a precomputed
                feature map) is used when present. Otherwise ``"rgb"``, laid out
                [BATCH, HEIGHT, WIDTH, CHANNEL], is normalized and run through
                ``self.cnn``.
        Returns:
            With ``spatial_output``, the 1x1-projected map
            [BATCH, output_size, H, W] (4x4 after the built-in pooling).
            Otherwise the (optionally gated) ResNet feature map.
        """
        if "rgb_features" in observations:
            resnet_output = observations["rgb_features"]
        else:
            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT x WIDTH]
            rgb_observations = observations["rgb"].permute(0, 3, 1, 2)
            rgb_observations = self.rgb_transform(rgb_observations)
            resnet_output = self.cnn(rgb_observations.contiguous())

        # Optional fusion gate: a [B, 1, 1, 1] factor in [0, 1]; shape is preserved.
        if self.use_rgb_fusion_gate and hasattr(self, "rgb_gate"):
            resnet_output = resnet_output * self.rgb_gate(resnet_output)

        if not self.spatial_output:
            return resnet_output

        b, _, h, w = resnet_output.size()
        positions = torch.arange(
            0,
            self.spatial_embeddings.num_embeddings,
            device=resnet_output.device,
            dtype=torch.long,
        )
        spatial_features = (
            self.spatial_embeddings(positions)
            .view(1, -1, h, w)
            .expand(b, self.spatial_embeddings.embedding_dim, h, w)
        )

        # Concatenate ResNet features with positional features, then project.
        combined_features = torch.cat([resnet_output, spatial_features], dim=1)
        return self.projection_layer(combined_features)


class CLIPEncoder(nn.Module):
    r"""
    Frozen CLIP image encoder for the rgb component of an observation.

    All CLIP parameters have ``requires_grad`` disabled, and the model is put
    in eval mode at construction time.

    Args:
        device: Device passed to ``clip.load`` when loading the model.
        model_name: CLIP model to load. Defaults to ``"ViT-B/32"``, the model
            this encoder previously hard-coded.
    """

    def __init__(
        self, device, model_name: str = "ViT-B/32",
    ):
        super().__init__()
        # clip.load also returns a preprocessing pipeline. It is unused because
        # observations are normalized by self.rgb_transform instead.
        self.model, _ = clip.load(model_name, device=device)
        for param in self.model.parameters():
            param.requires_grad_(False)
        self.model.eval()

        from torchvision import transforms

        # Convert to float (uint8 inputs are rescaled to [0, 1]), then apply
        # the CLIP normalization statistics.
        self.rgb_transform = torch.nn.Sequential(
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]),
        )

    def forward(self, observations):
        r"""Encode ``observations["rgb"]`` with the frozen CLIP image encoder.

        Args:
            observations: Dict whose ``"rgb"`` tensor is laid out
                [BATCH, HEIGHT, WIDTH, CHANNEL]; it is permuted to channels-first
                here. NOTE(review): no resizing is done, so the spatial size
                presumably must already match the CLIP model's input size.
                Confirm against the caller.
        Returns:
            The output of ``encode_image``, cast to float32.
        """
        rgb_observations = observations["rgb"].permute(0, 3, 1, 2)
        rgb_observations = self.rgb_transform(rgb_observations)
        output = self.model.encode_image(rgb_observations.contiguous())

        # Cast to fp32; the loaded CLIP weights may be half precision.
        return output.float()