"""
RAM++ with ADE20K Adapter

This module implements an adapter that maps RAM++ (4585 classes) to ADE20K (150 classes)
using a multi-layer perceptron. The original RAM++ backbone is frozen during training.
"""

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Optional, Union
import numpy as np

# Add the directory two levels above this file's directory to sys.path so the RAM modules can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from ram.models import ram_plus
from ram import get_transform


class RAM_plus_ADE20K(nn.Module):
    """
    RAM++ model with ADE20K adapter

    Architecture:
    RAM++ (frozen) → [4585 logits] → MLP Adapter → [150 ADE20K logits] → predictions

    Attributes:
        ram_plus: The wrapped RAM++ backbone, or None if none was supplied.
        adapter: MLP mapping RAM++ tag logits to ADE20K logits.
        ade20k_classes: ADE20K class names; position i names output logit i.
        num_classes: Number of ADE20K classes (length of ade20k_classes).
        ram_num_classes: Width of the adapter input (number of RAM++ logits).
        threshold: Default probability cut-off used by predict().
        freeze_backbone: Whether the backbone is treated as frozen.
    """

    # Adapter input width used when the backbone does not expose `num_class`
    # (or when no backbone is supplied at construction time).
    DEFAULT_RAM_NUM_CLASSES = 4585

    def __init__(self,
                 ram_plus_model: Optional[nn.Module] = None,
                 freeze_backbone: bool = True,
                 adapter_dropout: float = 0.3,
                 threshold: float = 0.5):
        """
        Initialize RAM++ ADE20K adapter model

        Args:
            ram_plus_model: Pre-trained RAM++ model. If None, the adapter is
                still built but forward() raises until `ram_plus` is assigned.
            freeze_backbone: Whether to freeze RAM++ parameters
            adapter_dropout: Dropout rate for adapter layers
            threshold: Classification threshold for predictions
        """
        super().__init__()

        self.threshold = threshold
        self.freeze_backbone = freeze_backbone

        # May be None; _get_ram_logits() checks for that before use.
        self.ram_plus = ram_plus_model

        # Freeze RAM++ parameters if specified
        if self.ram_plus is not None and freeze_backbone:
            for param in self.ram_plus.parameters():
                param.requires_grad = False
            # A frozen backbone should also behave deterministically, so put
            # it in eval mode (train() below keeps it there).
            self.ram_plus.eval()
            print("RAM++ backbone frozen")

        # ADE20K class names (150 classes)
        self.ade20k_classes = [
            'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'windowpane', 'grass',
            'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair',
            'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field',
            'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',
            'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace',
            'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway',
            'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench',
            'counter top', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel',
            'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
            'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet',
            'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool',
            'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball',
            'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher',
            'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan',
            'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'
        ]

        self.num_classes = len(self.ade20k_classes)

        # Size the adapter input from the backbone when it reports its class
        # count, so the first Linear layer always matches the logits produced
        # by _get_ram_logits(); otherwise fall back to the default width.
        ram_num_classes = getattr(self.ram_plus, "num_class", None)
        if ram_num_classes is None:
            ram_num_classes = self.DEFAULT_RAM_NUM_CLASSES
        self.ram_num_classes = int(ram_num_classes)

        # MLP Adapter: ram_num_classes → 2048 → 1024 → num_classes
        self.adapter = nn.Sequential(
            nn.Linear(self.ram_num_classes, 2048),
            nn.GELU(),
            nn.Dropout(adapter_dropout),
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(adapter_dropout),
            nn.Linear(1024, self.num_classes)
        )

        print("Initialized RAM++ ADE20K adapter")
        print(f"Input dimension: {self.ram_num_classes} (RAM++ classes)")
        print(f"Output dimension: {self.num_classes} (ADE20K classes)")
        print(f"Adapter dropout: {adapter_dropout}")
        print(f"Classification threshold: {threshold}")

    def train(self, mode: bool = True) -> "RAM_plus_ADE20K":
        """
        Set training mode, keeping a frozen backbone in eval mode

        nn.Module.train() switches every submodule, which would otherwise
        re-enable any dropout-style layers inside the frozen backbone.

        Args:
            mode: True for training mode, False for evaluation mode

        Returns:
            self
        """
        super().train(mode)
        if self.freeze_backbone and self.ram_plus is not None:
            self.ram_plus.eval()
        return self

    def _get_ram_logits(self, image: torch.Tensor) -> torch.Tensor:
        """
        Extract logits from RAM++ model

        Args:
            image: Input image tensor [batch_size, 3, 384, 384]

        Returns:
            RAM++ logits [batch_size, ram_plus.num_class]

        Raises:
            RuntimeError: If no RAM++ backbone has been attached.
        """
        if self.ram_plus is None:
            raise RuntimeError(
                "RAM++ model not loaded. Pass ram_plus_model to the constructor "
                "or use load_ram_plus_ade20k_pretrained()."
            )
        ram = self.ram_plus

        # Get image embeddings
        image_embeds = ram.image_proj(ram.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image.device)

        # First token of the sequence is used as the global image embedding.
        image_cls_embeds = image_embeds[:, 0, :]
        bs = image_cls_embeds.shape[0]

        # Multi-tag description reweighting: label_embed stacks
        # des_per_class description embeddings for each class.
        des_per_class = ram.label_embed.shape[0] // ram.num_class
        embed_dim = ram.label_embed.shape[-1]

        image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
        reweight_scale = ram.reweight_scale.exp()
        logits_per_image = (reweight_scale * image_cls_embeds @ ram.label_embed.t())
        logits_per_image = logits_per_image.view(bs, -1, des_per_class)

        # Per-image softmax weights over the descriptions of each class.
        weight_normalized = F.softmax(logits_per_image, dim=2)

        # The view does not depend on the sample, so build it once.
        descriptions = ram.label_embed.view(-1, des_per_class, embed_dim)
        label_embed_reweight = torch.empty(
            bs, ram.num_class, embed_dim, device=image.device, dtype=image.dtype
        )
        # Processed one sample at a time so only a single
        # [num_class, des_per_class, embed_dim] product exists at once.
        for i in range(bs):
            weighted = weight_normalized[i].unsqueeze(-1) * descriptions
            label_embed_reweight[i] = weighted.sum(dim=1)

        label_embed = F.relu(ram.wordvec_proj(label_embed_reweight))

        # Get tagging logits
        tagging_embed = ram.tagging_head(
            encoder_embeds=label_embed,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False,
            mode='tagging',
        )

        logits = ram.fc(tagging_embed[0]).squeeze(-1)
        return logits

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            image: Input image tensor [batch_size, 3, 384, 384]

        Returns:
            ADE20K logits [batch_size, 150]
        """
        # With a frozen backbone no autograd graph is recorded for RAM++;
        # otherwise the caller's grad mode is left untouched.
        if self.freeze_backbone:
            with torch.no_grad():
                ram_logits = self._get_ram_logits(image)
        else:
            ram_logits = self._get_ram_logits(image)

        # Apply adapter to map to ADE20K classes
        ade20k_logits = self.adapter(ram_logits)
        return ade20k_logits

    def predict(self,
                image: torch.Tensor,
                threshold: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Make predictions on images

        This does not switch the module to eval mode or disable gradients;
        callers control both.

        Args:
            image: Input image tensor [batch_size, 3, 384, 384]
            threshold: Classification threshold (uses self.threshold if None)

        Returns:
            predictions: Binary predictions [batch_size, 150]
            probabilities: Sigmoid probabilities [batch_size, 150]
        """
        if threshold is None:
            threshold = self.threshold

        # Get logits and convert to probabilities
        logits = self.forward(image)
        probabilities = torch.sigmoid(logits)

        # Apply threshold to get binary predictions (strictly greater than)
        predictions = (probabilities > threshold).float()

        return predictions, probabilities

    def predict_classes(self,
                       image: torch.Tensor,
                       threshold: Optional[float] = None) -> List[List[str]]:
        """
        Predict class names for each image in batch

        Args:
            image: Input image tensor [batch_size, 3, 384, 384]
            threshold: Classification threshold (uses self.threshold if None)

        Returns:
            List of predicted class names for each image
        """
        predictions, _ = self.predict(image, threshold)

        # tolist() converts each row's indices to Python ints in one step
        # instead of indexing the name list with 0-d tensors.
        return [
            [self.ade20k_classes[idx] for idx in torch.where(row == 1)[0].tolist()]
            for row in predictions
        ]


def load_ram_plus_ade20k_pretrained(ram_plus_checkpoint: str,
                                   ade20k_adapter_checkpoint: Optional[str] = None,
                                   image_size: int = 384,
                                   vit: str = 'swin_l',
                                   freeze_backbone: bool = True,
                                   device: str = 'cuda') -> RAM_plus_ADE20K:
    """
    Load pretrained RAM++ ADE20K model

    Builds the RAM++ backbone, wraps it in RAM_plus_ADE20K, optionally
    restores adapter weights, then moves the model to `device` in eval mode.

    Accepted adapter checkpoint layouts (checked in this order):
      1. a dict with key 'adapter_state_dict';
      2. a dict with key 'state_dict' holding a full-model state dict, from
         which entries prefixed 'adapter.' are extracted;
      3. the adapter state dict itself.

    Args:
        ram_plus_checkpoint: Path to RAM++ pretrained weights
        ade20k_adapter_checkpoint: Path to trained adapter weights (optional).
            If the path is given but does not exist, a warning is printed and
            the adapter keeps its random initialization.
        image_size: Input image size
        vit: Vision transformer architecture
        freeze_backbone: Whether to freeze RAM++ backbone
        device: Device to load model on

    Returns:
        Loaded RAM++ ADE20K model
    """
    print(f"Loading RAM++ model from {ram_plus_checkpoint}")

    # Load RAM++ model
    ram_model = ram_plus(
        pretrained=ram_plus_checkpoint,
        image_size=image_size,
        vit=vit
    )
    ram_model.eval()

    # Create ADE20K adapter model
    model = RAM_plus_ADE20K(
        ram_plus_model=ram_model,
        freeze_backbone=freeze_backbone
    )

    # Load adapter weights if provided
    if not ade20k_adapter_checkpoint:
        print("No adapter checkpoint provided - using randomly initialized adapter")
    elif not os.path.exists(ade20k_adapter_checkpoint):
        # A path was requested but is missing: say so explicitly instead of
        # reporting that no checkpoint was provided.
        print(f"WARNING: adapter checkpoint not found at {ade20k_adapter_checkpoint} "
              "- using randomly initialized adapter")
    else:
        print(f"Loading adapter weights from {ade20k_adapter_checkpoint}")
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True where the installed
        # torch version supports it).
        checkpoint = torch.load(ade20k_adapter_checkpoint, map_location='cpu')

        if 'adapter_state_dict' in checkpoint:
            model.adapter.load_state_dict(checkpoint['adapter_state_dict'])
        elif 'state_dict' in checkpoint:
            # Try to load from full model state dict: keep only adapter
            # entries and strip exactly the leading prefix.
            prefix = 'adapter.'
            adapter_state_dict = {
                key[len(prefix):]: value
                for key, value in checkpoint['state_dict'].items()
                if key.startswith(prefix)
            }
            model.adapter.load_state_dict(adapter_state_dict)
        else:
            # Assume checkpoint contains adapter state dict directly
            model.adapter.load_state_dict(checkpoint)

        print("Adapter weights loaded successfully")

    model = model.to(device)
    model.eval()

    print(f"Model loaded on device: {device}")
    return model


# Test function
if __name__ == "__main__":
    # Smoke test: build the model without a RAM++ backbone and run the
    # adapter alone on random input.
    print("Testing RAM++ ADE20K model initialization...")

    # Create model without loading RAM++
    model = RAM_plus_ADE20K()
    print("Model created successfully")
    print(f"Number of ADE20K classes: {model.num_classes}")
    print(f"Adapter architecture: {model.adapter}")

    # Test adapter forward pass with dummy input. The feature width is read
    # from the adapter's first layer so it always matches what the adapter
    # expects.
    adapter_in_features = model.adapter[0].in_features
    dummy_ram_logits = torch.randn(2, adapter_in_features)  # Batch size 2
    ade20k_logits = model.adapter(dummy_ram_logits)
    print(f"Adapter test - Input shape: {dummy_ram_logits.shape}")
    print(f"Adapter test - Output shape: {ade20k_logits.shape}")

    # Fail loudly if the adapter output does not have one logit per class.
    if ade20k_logits.shape != (2, model.num_classes):
        raise RuntimeError(f"Unexpected adapter output shape: {tuple(ade20k_logits.shape)}")

    print("✅ Model initialization test passed!")