import torch.nn as nn
import torch
import torch.nn.functional as F
from transformers import ViTImageProcessor, ViTForImageClassification, ViTConfig
from einops import repeat, rearrange


class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention followed by a residual add and LayerNorm.

    The attention module is built with ``batch_first=True``, so every tensor
    is laid out as ``[batch, seq_len, d_model]``.

    Args:
        d_model: Embedding width shared by query, key and value.
        nhead: Number of attention heads.
        dropout: Dropout probability, used both inside the attention module
            and on the attention output before the residual add.
    """

    def __init__(self, d_model, nhead=8, dropout=0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value`` (post-norm residual).

        Args:
            query: ``[batch, seq_len_q, d_model]``.
            key: ``[batch, seq_len_kv, d_model]``.
            value: ``[batch, seq_len_kv, d_model]``.

        Returns:
            Tensor shaped like ``query``.
        """
        # The attention module returns (output, weights); only the output is used.
        attended = self.multihead_attn(query, key, value)[0]
        residual = query + self.dropout(attended)
        return self.norm(residual)


class EmbeddingAdapter(nn.Module):
    """Residual adapter that mixes an embedding with a camera-derived term.

    Two small MLPs (Linear -> GELU -> Linear) are used: ``fc`` transforms the
    embedding itself and ``fc2`` lifts the camera vector to the embedding
    width. Both results are scaled by ``gamma`` and added to the input.

    Args:
        emb_dim: Width of the embedding (input and output of both MLPs).
        camera_dim: Width of the camera vector fed to ``fc2``.
    """

    def __init__(self, emb_dim=768, camera_dim=25):
        super().__init__()

        # Embedding -> embedding residual branch.
        self.fc = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

        # Camera vector -> embedding residual branch.
        self.fc2 = nn.Sequential(
            nn.Linear(camera_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, texts, cameras, gamma):
        """Return ``texts`` plus the two gamma-scaled residual branches.

        Args:
            texts: Embeddings; the einops pattern below requires the summed
                result to be 2-D, ``[n, c]``.
            cameras: Camera vectors whose last dimension is ``camera_dim``.
            gamma: Indexable scale; ``gamma[0]`` multiplies the embedding
                branch and ``gamma[1]`` multiplies the camera branch.

        Returns:
            Tensor of shape ``[n, 1, c]`` (a singleton axis is inserted at
            position 1).
        """
        text_residual = self.fc(texts)
        camera_residual = self.fc2(cameras)
        adapted = texts + gamma[0] * text_residual + gamma[1] * camera_residual
        return repeat(adapted, 'n c -> n b c', b=1)


class CEA(nn.Module):
    """Cost-volume Enhanced Attention (CEA) embedding module.

    For every view, a frozen ViT classifier produces logits and patch tokens.
    The logits are mapped to a soft mixture over learnable class embeddings,
    the patch tokens are (optionally) re-weighted by the matching entropy of
    a cost volume, and a cross-attention layer refines the class embedding
    against those tokens. The per-view results are then summarised by a set
    of learnable latent queries through a second cross-attention layer.

    Args:
        emb_dim: Width of the learnable class embeddings.
        train_from_scratch: If True, build the ViT from a fresh ``ViTConfig``
            instead of loading pretrained weights. The ViT parameters are
            frozen either way.
        no_of_classes: Number of learnable class embeddings.
        feature_dim: Width of the attention features. ``W_K``/``W_V`` are
            applied directly to the ViT patch tokens, so this has to equal
            the ViT hidden size.
        num_latent_queries: Number of latent queries in the cross-view stage.
    """

    def __init__(self, 
        emb_dim: int=768, 
        train_from_scratch: bool=False, 
        no_of_classes: int=100,
        feature_dim: int=768,
        num_latent_queries: int=64
    ):
        super().__init__()
        VIT_MODEL = 'google/vit-base-patch16-224'
        self.vit_processor = ViTImageProcessor.from_pretrained(VIT_MODEL, resume_download=True)
        if train_from_scratch:
            vit_config = ViTConfig(num_labels=1000)
            self.vit_model = ViTForImageClassification(vit_config)
        else:
            self.vit_model = ViTForImageClassification.from_pretrained(VIT_MODEL, resume_download=True)
        # The ViT is used as a fixed feature extractor.
        for param in self.vit_model.parameters():
            param.requires_grad = False
        
        # Maps the 1000 ViT logits to scores over the learnable classes.
        self.fc = nn.Sequential(
            nn.Linear(1000, 400),
            nn.GELU(),
            nn.Linear(400, no_of_classes)
        )
        self.dim = emb_dim
        self.feature_dim = feature_dim
        self.m = nn.Softmax(dim=1)
        
        self.embeddings = nn.Parameter(torch.randn(no_of_classes, self.dim))
        # NOTE(review): `embedding_adapter` and `gamma` are created here but
        # are not referenced by `forward` / `_process_single_view`.
        self.embedding_adapter = EmbeddingAdapter(emb_dim=self.dim)
        
        # CEA specific components
        # Linear projections for Q, K, V
        self.W_Q = nn.Linear(self.dim, self.feature_dim)
        self.W_K = nn.Linear(self.feature_dim, self.feature_dim) 
        self.W_V = nn.Linear(self.feature_dim, self.feature_dim)
        
        # Cross-attention for entropy-aware refinement
        self.entropy_cross_attn = CrossAttentionLayer(self.feature_dim)
        
        # Learnable latent queries for cross-view attention
        self.num_latent_queries = num_latent_queries
        self.latent_queries = nn.Parameter(torch.randn(num_latent_queries, self.feature_dim))
        
        # Cross-view Perceiver attention
        self.cross_view_attn = CrossAttentionLayer(self.feature_dim)
        
        self.gamma = nn.Parameter(torch.ones(2, self.dim) * 1e-4)
    
    def forward(self, images, costvolume, cameras=None):
        """Compute CEA embeddings for single- or multi-view input.

        Args:
            images: One of
                * a list/tuple of per-view tensors, each ``[B, 3, H, W]``;
                * a 5-D tensor ``[V, B, 3, H, W]``;
                * a 4-D tensor ``[B, 3, H, W]`` (treated as one view).
            costvolume: Cost volume(s) matching ``images``: a per-view
                sequence / ``[V, B, D, H, W]`` tensor for multi-view input,
                or ``[B, D, H, W]`` for a single view. ``None`` disables
                entropy weighting for every view.
            cameras: Optional camera parameters. They are forwarded unchanged
                to ``_process_single_view``, which does not consume them yet.

        Returns:
            Tensor ``[B, num_latent_queries, feature_dim]``.

        Raises:
            ValueError: If the number of cost volumes differs from the number
                of views.
        """
        # Normalise every accepted input layout to parallel per-view sequences.
        if isinstance(images, (list, tuple)):
            view_images = list(images)
            view_costs = costvolume
        elif images.dim() == 5:
            view_images = list(images.unbind(dim=0))
            view_costs = costvolume
        else:
            view_images = [images]
            view_costs = [costvolume]

        if view_costs is None:
            view_costs = [None] * len(view_images)
        elif len(view_costs) != len(view_images):
            raise ValueError(
                f'Got {len(view_images)} views but {len(view_costs)} cost volumes.'
            )

        per_view_embeddings = torch.stack(
            [
                self._process_single_view(view_img, view_cost, cameras)
                for view_img, view_cost in zip(view_images, view_costs)
            ],
            dim=1,
        )  # [B, V, feature_dim]

        # Cross-view Perceiver-style attention: the latent queries attend over
        # the V per-view embeddings of each batch element.
        batch = per_view_embeddings.shape[0]
        latent_queries = self.latent_queries.unsqueeze(0).expand(batch, -1, -1)

        return self.cross_view_attn(
            query=latent_queries,
            key=per_view_embeddings,
            value=per_view_embeddings,
        )  # [B, num_latent_queries, feature_dim]

    def _process_single_view(self, image, cost_volume, cameras=None):
        """Compute the entropy-aware refined class embedding of one view.

        Args:
            image: ``[B, 3, H, W]``; ``pad_to_make_square`` maps values from
                [-1, 1] to [0, 255], so inputs are presumably normalised to
                [-1, 1] (confirm against the data pipeline).
            cost_volume: ``[B, D, H, W]`` or ``None``. When ``None`` the ViT
                patch tokens are used without entropy weighting.
            cameras: Accepted for interface compatibility with the callers;
                currently unused.

        Returns:
            Refined class embedding ``[B, feature_dim]``.

        Raises:
            ValueError: If the patch tokens must be resampled to the cost
                volume resolution but do not form a square grid.
        """
        # 1. Frozen ViT: logits plus patch tokens.
        y = self.pad_to_make_square(image)

        with torch.no_grad():
            inputs = self.vit_processor(images=y, return_tensors="pt").to(image.device)
            # The classification head output only carries per-layer hidden
            # states when they are requested explicitly.
            vit_outputs = self.vit_model(**inputs, output_hidden_states=True)
            vit_logits = vit_outputs.logits
            # Last entry = output of the final encoder layer; drop the CLS token.
            # NOTE(review): presumably taken before the backbone's final
            # LayerNorm - confirm this is the desired feature level.
            vit_features = vit_outputs.hidden_states[-1][:, 1:]  # [B, N, hidden]

        # 2. Soft class assignment -> mixture of learnable class embeddings.
        class_probs = self.m(self.fc(vit_logits))
        class_embeddings = class_probs @ self.embeddings  # [B, emb_dim]

        # 3./4. Entropy weighting of the patch tokens.
        if cost_volume is not None:
            weights = self._entropy_weights(cost_volume)  # [B, H, W]
            height, width = weights.shape[-2:]
            tokens = self._resample_tokens(vit_features, height, width)  # [B, HW, hidden]
            weighted_features = weights.flatten(start_dim=1).unsqueeze(-1) * tokens
        else:
            weighted_features = vit_features

        # 5. Project to Q, K, V for cross-attention.
        class_queries = self.W_Q(class_embeddings).unsqueeze(1)  # [B, 1, feature_dim]
        feature_keys = self.W_K(weighted_features)
        feature_values = self.W_V(weighted_features)

        # 6. Refine the class embedding against the weighted tokens.
        refined_embedding = self.entropy_cross_attn(
            query=class_queries,
            key=feature_keys,
            value=feature_values
        ).squeeze(1)  # [B, feature_dim]

        return refined_embedding

    @staticmethod
    def _entropy_weights(cost_volume):
        """Turn a cost volume ``[B, D, H, W]`` into spatial weights ``[B, H, W]``.

        A softmax over the depth axis gives a per-pixel posterior; its
        entropy is then softmax-normalised over all pixels, so the weights of
        each batch element sum to 1.
        """
        depth_posterior = F.softmax(cost_volume, dim=1)
        eps = 1e-8  # keeps log() finite where the posterior underflows to 0
        entropy = -torch.sum(depth_posterior * torch.log(depth_posterior + eps), dim=1)
        normalized = F.softmax(entropy.flatten(start_dim=1), dim=1)  # [B, HW]
        return normalized.reshape(entropy.shape)

    @staticmethod
    def _resample_tokens(tokens, height, width):
        """Bilinearly resample patch tokens ``[B, N, C]`` to ``[B, height*width, C]``.

        Tokens are returned untouched when ``N == height * width``; otherwise
        they are interpreted as a square grid and interpolated.
        """
        num_tokens = tokens.shape[1]
        if num_tokens == height * width:
            return tokens

        grid = int(round(num_tokens ** 0.5))
        if grid * grid != num_tokens:
            raise ValueError(
                f'Cannot arrange {num_tokens} patch tokens into a square grid.'
            )
        spatial = tokens.transpose(1, 2).reshape(tokens.shape[0], -1, grid, grid)
        spatial = F.interpolate(spatial, size=(height, width), mode='bilinear', align_corners=False)
        return spatial.flatten(start_dim=2).transpose(1, 2)

    @staticmethod
    def pad_to_make_square(x):
        """Rescale ``x`` from [-1, 1] to [0, 255] and zero-pad it to a square.

        Args:
            x: ``[bs, 3, h, w]`` tensor.

        Returns:
            Integer tensor ``[bs, s, s, 3]`` (channels last) with
            ``s = max(h, w)``; padding is appended at the bottom when
            ``w > h`` and on the right otherwise.
        """
        y = 255 * ((x + 1) / 2)
        y = torch.permute(y, (0, 2, 3, 1))  # (bs, h, w, 3)
        bs, _, h, w = x.shape
        if w > h:
            patch = torch.zeros(bs, w - h, w, 3, device=x.device, dtype=y.dtype)
            y = torch.cat([y, patch], dim=1)
        else:
            # Also covers h == w, where the patch has zero width.
            patch = torch.zeros(bs, h, h - w, 3, device=x.device, dtype=y.dtype)
            y = torch.cat([y, patch], dim=2)
        return y.to(torch.int)

class DiffusionDecoder(nn.Module):
    """Upsampling decoder: stacked transposed convolutions plus a conv head.

    Each deconv stage is ``ConvTranspose2d(stride=2) -> BatchNorm2d -> ReLU``
    and doubles the spatial resolution for the supported kernel sizes
    (2, 3, 4). The head is ``Conv2d(3x3) -> BatchNorm2d -> ReLU``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the conv head.
        num_deconv: Number of deconv stages to build.
        num_filters: Output channels per deconv stage; needs at least
            ``num_deconv`` entries.
        deconv_kernels: Kernel size per deconv stage; needs at least
            ``num_deconv`` entries.
    """

    def __init__(self, 
        in_channels, 
        out_channels,
        num_deconv=2,
        num_filters=(32, 32),
        deconv_kernels=(2, 2)
    ):
        super().__init__()
        self.deconv = num_deconv
        self.in_channels = in_channels
        self.deconv_layers = self._make_deconv_layer(
            num_deconv,
            num_filters,
            deconv_kernels,
        )

        # The head consumes whatever the last *built* deconv stage emits.
        # Using num_filters[-1] would be wrong when fewer stages than filter
        # entries are built, or when no stage is built at all.
        head_in_channels = num_filters[num_deconv - 1] if num_deconv > 0 else in_channels
        self.conv_layers = nn.Sequential(
            nn.Conv2d(
                in_channels=head_in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.init_weights()

    def forward(self, conv_feats):
        """Decode the first feature map of ``conv_feats``.

        Args:
            conv_feats: Indexable collection of feature maps; only
                ``conv_feats[0]`` is used.

        Returns:
            Output of the conv head applied to the upsampled features.
        """
        out = self.deconv_layers(conv_feats[0])
        out = self.conv_layers(out)
        return out

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build ``num_layers`` stride-2 deconv stages as one ``nn.Sequential``."""
        layers = []
        in_planes = self.in_channels
        for i in range(num_layers):
            kernel, padding, output_padding = self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_planes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            in_planes = planes

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a deconv kernel size.

        Raises:
            ValueError: If ``deconv_kernel`` is not 2, 3 or 4.
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def init_weights(self):
        """Initialize conv/deconv weights ~N(0, 0.001^2) and BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
