import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel

class ViTUNetForImageReconstruction(nn.Module):
    """ViT encoder with a U-Net-style convolutional decoder.

    Token sequences are taken from three intermediate encoder layers and from
    the final layer, reshaped into 2-D feature maps, and resampled to fixed
    grid sizes (60x60, 30x30, 15x15). The decoder upsamples the deepest map
    step by step (15 -> 30 -> 60 -> 120 -> 240 -> 449) and adds a 1x1-projected
    skip map after each of the first three stages.

    Args:
        pretrained_model_name: Identifier or path forwarded to
            ``ViTModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

        # Read the token width from the loaded checkpoint rather than assuming
        # the 768 of a "base" model, so other ViT sizes build matching convs.
        self.hidden_dim = self.vit.config.hidden_size

        # Adaptive pooling resamples each token grid to a fixed spatial size,
        # whatever the grid's native resolution is.
        self.adaptive_pool1 = nn.AdaptiveAvgPool2d((60, 60))  # early features
        self.adaptive_pool2 = nn.AdaptiveAvgPool2d((30, 30))  # middle features
        self.adaptive_pool3 = nn.AdaptiveAvgPool2d((15, 15))  # late features
        self.adaptive_pool_final = nn.AdaptiveAvgPool2d((15, 15))  # final features

        # 1x1 projections that match each skip map to the channel count of the
        # decoder stage it is added to.
        self.skip_conn1 = nn.Conv2d(self.hidden_dim, 128, kernel_size=1)  # joins the 120x120 stage
        self.skip_conn2 = nn.Conv2d(self.hidden_dim, 256, kernel_size=1)  # joins the 60x60 stage
        self.skip_conn3 = nn.Conv2d(self.hidden_dim, 512, kernel_size=1)  # joins the 30x30 stage

        # Decoder stage 1: 15x15 -> 30x30
        self.up1 = nn.Sequential(
            nn.Upsample(size=(30, 30), mode='bilinear', align_corners=True),
            nn.Conv2d(self.hidden_dim, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Decoder stage 2: 30x30 -> 60x60
        self.up2 = nn.Sequential(
            nn.Upsample(size=(60, 60), mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Decoder stage 3: 60x60 -> 120x120
        self.up3 = nn.Sequential(
            nn.Upsample(size=(120, 120), mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # Output head: 120x120 -> 240x240 -> 449x449, three channels squashed
        # into [0, 1] by the sigmoid.
        self.up4 = nn.Sequential(
            nn.Upsample(size=(240, 240), mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(size=(449, 449), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def extract_intermediate_features(self, pixel_values):
        """Run the ViT encoder and return token sequences from four depths.

        Args:
            pixel_values: Image batch passed straight to the ViT encoder.

        Returns:
            Tuple ``(early, middle, late, final)`` holding
            ``hidden_states[3]``, ``hidden_states[6]``, ``hidden_states[9]``
            and ``last_hidden_state`` of the encoder output.
        """
        outputs = self.vit(pixel_values, output_hidden_states=True)

        # Output of the final encoder layer.
        final_layer = outputs.last_hidden_state

        # Intermediate taps; the indices can be tuned to whichever layers
        # work best for the task.
        early_layer = outputs.hidden_states[3]
        middle_layer = outputs.hidden_states[6]
        late_layer = outputs.hidden_states[9]

        return early_layer, middle_layer, late_layer, final_layer

    def reshape_features(self, features):
        """Drop the leading (CLS) token and fold the rest into a square map.

        Args:
            features: Tensor of shape ``(batch, 1 + side * side, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim, side, side)``.

        Raises:
            ValueError: If the token count minus one is not a positive
                perfect square, i.e. the patches do not form a square grid.
        """
        batch_size, num_tokens, hidden_dim = features.shape
        num_patches = num_tokens - 1  # first token is CLS, not a patch

        # Round instead of truncating so float error in the square root can
        # never shave one off the side length, then verify the result exactly.
        height_width = int(round(num_patches ** 0.5)) if num_patches > 0 else 0
        if num_patches < 1 or height_width * height_width != num_patches:
            raise ValueError(
                f"Cannot arrange {num_patches} patch tokens into a square grid "
                f"(features has {num_tokens} tokens including CLS)."
            )

        # (batch, patches, hidden) -> (batch, hidden, patches) -> 2-D grid.
        spatial_features = features[:, 1:, :].permute(0, 2, 1)
        return spatial_features.reshape(batch_size, hidden_dim, height_width, height_width)

    @staticmethod
    def _add_skip(x, skip, projection):
        """Project ``skip`` with ``projection``, resize it to ``x`` and add it."""
        projected = projection(skip)
        resized = F.interpolate(projected, size=x.shape[-2:], mode='bilinear', align_corners=True)
        return x + resized

    def forward(self, pixel_values):
        """Reconstruct an image from ``pixel_values``.

        Args:
            pixel_values: Image batch accepted by the wrapped ViT encoder.

        Returns:
            Tensor of shape ``(batch, 3, 449, 449)`` with values in ``[0, 1]``.
        """
        # Multi-level token features from the encoder.
        early_features, middle_features, late_features, final_features = (
            self.extract_intermediate_features(pixel_values)
        )

        # Token sequences -> (batch, channels, height, width) maps at fixed sizes.
        skip1 = self.adaptive_pool1(self.reshape_features(early_features))    # 60x60
        skip2 = self.adaptive_pool2(self.reshape_features(middle_features))   # 30x30
        skip3 = self.adaptive_pool3(self.reshape_features(late_features))     # 15x15
        x = self.adaptive_pool_final(self.reshape_features(final_features))   # 15x15

        # Each skip map sits one resolution step below the decoder stage it
        # joins, so it is upsampled to that stage's size before the addition.
        x = self._add_skip(self.up1(x), skip3, self.skip_conn3)  # 30x30
        x = self._add_skip(self.up2(x), skip2, self.skip_conn2)  # 60x60
        x = self._add_skip(self.up3(x), skip1, self.skip_conn1)  # 120x120

        # Output head: 120x120 -> 449x449.
        return self.up4(x)

# Test the model
if __name__ == "__main__":
    # Smoke test: one forward/backward/optimizer step on random data.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained encoder and build the decoder on the chosen device.
    model = ViTUNetForImageReconstruction(pretrained_model_name='google/vit-base-patch16-224-in21k')
    model.to(device)

    # Dummy input (match ViT input size: 224x224), batch of 2 images.
    dummy_input = torch.randn(2, 3, 224, 224, device=device)

    # Dummy target (match final output size: 449x449). Drawn uniformly from
    # [0, 1) because the model ends in a sigmoid and cannot produce values
    # outside that range.
    dummy_target = torch.rand(2, 3, 449, 449, device=device)

    # Define loss and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Forward pass
    output = model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Target shape: {dummy_target.shape}")

    # Compute loss
    loss = criterion(output, dummy_target)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("Backpropagation ran successfully. Loss:", loss.item())