import torch
import torch.nn as nn
from torchvision.models.vision_transformer import VisionTransformer
import torch.nn.functional as F

from .forward import forward_pruned_layers, forward_pruned_encoder

import types 

from torch import Tensor
from typing import Optional, List, Tuple

class ViTEnsembleLinear(nn.Module):
    def __init__(
            self,
            num_models: int,
            weights: Tensor,
            bias: Optional[Tensor] = None
        ):
        """
        A linear layer for an ensemble of Vision Transformers (ViTs) where
        the main difference is that it supports multiple models with their own weights and biases.

        Args:
            num_models (int): Number of ensemble models.
            weights (Tensor): Weight tensor of shape (num_models, out_features, in_features).
            bias (Optional[Tensor]): Optional bias tensor of shape (num_models, out_features)
                or (num_models, out_features, 1). It is stored with shape
                (1, num_models, out_features, 1) so that it broadcasts over the
                batch and token axes in `forward`.
        """
        super().__init__()
        self.num_models = num_models
        # Read the feature sizes from their documented positions. Guessing them
        # with min()/max() swaps them whenever in_features > out_features.
        self.out_features = weights.shape[-2]
        self.in_features = weights.shape[-1]

        self.weight = nn.Parameter(weights)

        if bias is None:
            # Always define the attribute so `forward` can test it safely.
            self.register_parameter("bias", None)
        else:
            if bias.dim() == 2:
                # (num_models, out_features) -> (num_models, out_features, 1)
                bias = bias.unsqueeze(-1)
            self.bias = nn.Parameter(bias.unsqueeze(0))

    def forward(
            self, 
            x: Tensor
        ) -> Tensor:
        """
        Assumes input x is of shape (B, num_models, in_features, T) and returns output 
        of shape (B, num_models, T, out_features): the projection is computed
        feature-major and the last two axes are swapped before returning.
        """
        # Per-model matmul: for every model m, weight[m] @ x[:, m] -> [B, num_models, out_features, T]
        out = torch.einsum('bmft,mof->bmot', x, self.weight)
        if self.bias is not None:
            out = out + self.bias
        return out.transpose(-1, -2)

    def __repr__(self):
        return f"ViTEnsembleLinear(num_models={self.num_models},in_features={self.in_features}, out_features={self.out_features})"
    
class EnsembleMHAPrunedSeq(nn.Module):
    """
    Applies a sequence of pruned Multihead Attention (MHA) blocks to segments of an input sequence,
    each handled by a separate MHA block in an ensemble. 

    The input is split along dim 1 into `num_models` equal chunks; chunk `i` is used
    as the query of block `i`. Keys and values default to the full, unsplit input.
    The per-block outputs are concatenated back along dim 1.

    Args:
        attention_blocks (List[nn.Module]): List of attention blocks, each expected to take a chunk of the sequence.
            Each block is called as `block(query, key, value, need_weights=...)` and must
            return an `(output, attention_weights)` pair.
    Warning:
        This is a simplified and old implementation and it probably won't be used in the future.
        It is kept mostly for reference.
    """
    def __init__(self, attention_blocks):
        super().__init__()
        self.num_models = len(attention_blocks) 
        self.attention_blocks = nn.ModuleList(attention_blocks)

    def forward(self, x, k=None, v=None, need_weights=False):
        """
        Args:
            x: Input whose dim 1 is split into `num_models` chunks. If its length is not
                divisible by `num_models`, the trailing remainder is not processed.
            k: Keys passed to every block; defaults to `x`.
            v: Values passed to every block; defaults to `x`.
            need_weights: Forwarded to every attention block.

        Returns:
            A pair `(output, attn_probs)`. `output` is the per-block outputs concatenated
            along dim 1. `attn_probs` stacks the per-block attention weights along a new
            dim 1, or is `None` when the blocks return no weights.
        """
        seq_length = x.shape[1] // self.num_models 
        # Use x as k and v if not provided
        k = k if k is not None else x
        v = v if v is not None else x

        outputs_ensemble = []
        attn_probs_ensemble = []
        for i, attn in enumerate(self.attention_blocks):
            mha_block_input = x[:, i * seq_length:(i + 1) * seq_length, :]
            out, attn_probs = attn(mha_block_input, k, v, need_weights=need_weights)
            outputs_ensemble.append(out)
            # Blocks following the nn.MultiheadAttention convention return None
            # for the weights when need_weights is False.
            if attn_probs is not None:
                attn_probs_ensemble.append(attn_probs.unsqueeze(1))

        # Concatenate the per-block outputs back into one sequence
        output = torch.concat(outputs_ensemble, dim=1)
        attn_probs = torch.concat(attn_probs_ensemble, dim=1) if attn_probs_ensemble else None
        return output, attn_probs 
    
class EnsembleMHAPruned(nn.Module):
    def __init__(
            self, 
            attention_blocks: List[nn.Module]
        ):
        """
        Multi-head attention ensemble module that applies pruned linear projections
        from multiple pre-trained attention blocks.

        Args:
            attention_blocks (List[nn.Module]): List of attention blocks (e.g., from Pruned MultiheadAttention),
                each having `in_proj_weight`, `in_proj_bias`, and `out_proj` layers.
                All blocks must share the same parameter shapes; `num_heads` and
                `head_dim` are read from the first block.

        Raises:
            ValueError: If `attention_blocks` is empty or the stacked projection
                weights do not have the expected shapes.
        """
        super().__init__()
        if not attention_blocks:
            raise ValueError("EnsembleMHAPruned requires at least one attention block")
        self.num_models = len(attention_blocks) 

        # Assume shared heads and dimensions
        first_block = attention_blocks[0]
        self.num_heads = first_block.num_heads
        self.head_dim = first_block.head_dim
        self.pruned_embed_dim = first_block.in_proj_weight.shape[0] // 3 # stacked dim / 3 matrices (qkv) 
        self.embed_dim = first_block.in_proj_weight.shape[1]

        # Stack per-model parameters along a new leading "model" axis.
        # torch.stack rejects blocks of differing shapes, whereas cat + reshape
        # could silently mix rows across models when the totals happen to match.
        in_proj_weight = torch.stack([attn.in_proj_weight for attn in attention_blocks], dim=0)
        in_proj_bias = torch.stack([attn.in_proj_bias for attn in attention_blocks], dim=0).unsqueeze(-1)
        out_proj_weight = torch.stack([attn.out_proj.weight for attn in attention_blocks], dim=0)
        out_proj_bias = torch.stack([attn.out_proj.bias for attn in attention_blocks], dim=0).unsqueeze(-1)

        expected_in = (self.num_models, self.pruned_embed_dim * 3, self.embed_dim)
        expected_out = (self.num_models, self.embed_dim, self.pruned_embed_dim)
        if tuple(in_proj_weight.shape) != expected_in:
            raise ValueError(
                f"in_proj_weight has shape {tuple(in_proj_weight.shape)}, expected {expected_in}"
            )
        if tuple(out_proj_weight.shape) != expected_out:
            raise ValueError(
                f"out_proj weight has shape {tuple(out_proj_weight.shape)}, expected {expected_out}"
            )

        self.in_proj = ViTEnsembleLinear(self.num_models, in_proj_weight, in_proj_bias)
        self.out_proj = ViTEnsembleLinear(self.num_models, out_proj_weight, out_proj_bias)
        
    def forward(
            self,
            x: Tensor,
            k: Optional[Tensor] = None,
            v: Optional[Tensor] = None,
            need_weights: bool = False
        ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x (Tensor): Input tensor of shape (B, num_models, T, D)
            k (Optional[Tensor]): Ignored, projections are handled internally
            v (Optional[Tensor]): Ignored, projections are handled internally
            need_weights (bool): Ignored, always returns empty attention weights

        Returns:
            Tuple[Tensor, Tensor]:
                - Output tensor of shape (B, num_models, T, embed_dim)
                - Empty tensor for compatibility (attention weights placeholder)
        """
        B, _, T, _ = x.shape

        # The ensemble linear layers take feature-major input (B, num_models, D, T)
        # and return token-major output (B, num_models, T, 3 * pruned_embed_dim).
        qkv_blocks = self.in_proj(x.transpose(-1, -2))
        qkv_blocks = qkv_blocks.view(B, self.num_models, T, 3, self.num_heads, self.head_dim)
        # (3, B, num_models, num_heads, seq_length, head_dim), then split q / k / v
        q, k, v = qkv_blocks.permute(3, 0, 1, 4, 2, 5).unbind(0)
        
        # pytorch will scale it automatically
        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0) 
        
        # Merge heads back into one feature axis: (B, num_models, num_heads * head_dim, T)
        attn_output = attn_output.permute(0, 1, 2, 4, 3).reshape(B, self.num_models, -1, T)
        
        out = self.out_proj(attn_output)
        return out, torch.empty(0) 
    
    def __repr__(self):
        return (
            f"EnsembleMHAPruned(num_models={self.num_models}, "
            f"embed_dim={self.embed_dim}, num_heads={self.num_heads}, "
            f"pruned_embed_dim={self.pruned_embed_dim})"
        )
        
class ViTEnsemblePruned(VisionTransformer):
    def __init__(
            self,
            num_heads_to_prune: int,
            num_models: int,
            image_size: int = 224,
            patch_size: int = 16,
            num_layers: int = 12,
            num_heads: int = 12,
            hidden_dim: int = 768,
            mlp_dim: int = 3072,
        ):
        """
        Vision Transformer ensemble with pruned multi-head self-attention.

        Each transformer layer replaces its MHA block with an ensemble of pruned attention blocks
        taken from several ViT models. MLPs are reused from one of the models (assumed identical).

        Args:
            num_heads_to_prune (int): Number of attention heads pruned per model.
            num_models (int): Number of pruned ViT models in the ensemble.
            image_size (int): Input image size (default 224).
            patch_size (int): Patch size (default 16).
            num_layers (int): Number of transformer layers (default 12).
            num_heads (int): Number of attention heads per layer (default 12).
            hidden_dim (int): Hidden embedding dimension (default 768).
            mlp_dim (int): MLP intermediate dimension (default 3072).
        """
        super().__init__(image_size, patch_size, num_layers,
                         num_heads, hidden_dim, mlp_dim)
        # Total number of heads kept across the whole ensemble
        self.num_heads_pruned = (num_heads - num_heads_to_prune) * num_models
        self.num_heads_to_prune = num_heads_to_prune
        # Total number of heads across the ensemble before pruning
        self.num_heads_ensemble = num_heads * num_models
        self.num_models = num_models

    def create_ensemble(
            self, 
            pruned_models: List[VisionTransformer]
        ) -> None:
        """
        Builds attention ensembles by collecting pruned attention blocks from provided models.
        The MLP of each layer is assumed to be the same across all models, so the MLP
        module of the first model is reused (shared, not copied).

        Args:
            pruned_models (List[VisionTransformer]): List of ViT models with pruned MHA layers.

        Raises:
            ValueError: If the number of models differs from `num_models`, or a model
                does not have the same number of encoder layers as this one.
        """
        # `forward` reshapes the encoder output using self.num_models, so the
        # ensemble must be built from exactly that many models.
        if len(pruned_models) != self.num_models:
            raise ValueError(
                f"Expected {self.num_models} pruned models, got {len(pruned_models)}"
            )
        num_layers = len(self.encoder.layers)
        for model_idx, model in enumerate(pruned_models):
            if len(model.encoder.layers) != num_layers:
                raise ValueError(
                    f"Model {model_idx} has {len(model.encoder.layers)} encoder layers, "
                    f"expected {num_layers}"
                )

        for idx, layer in enumerate(self.encoder.layers):
            # Overwrite original attention with ensemble
            layer.self_attention = EnsembleMHAPruned(
                [model.encoder.layers[idx].self_attention for model in pruned_models]
            )
            layer.mlp = pruned_models[0].encoder.layers[idx].mlp # assume all models have the same MLP block
            # Bind the ensemble-aware forward onto this layer instance
            layer.forward = types.MethodType(forward_pruned_encoder, layer)

        # Bind the ensemble-aware forward onto the layer container itself
        self.encoder.layers.forward = types.MethodType(forward_pruned_layers, 
                                                       self.encoder.layers)
    
    def forward(
            self, 
            x: torch.Tensor
            ) -> torch.Tensor:
        """
        Forward pass through the ensemble of pruned Vision Transformers.
        It differs from the original one only for the last 2 lines of the classification head:
        the class token of every model is classified separately and the results are
        averaged over the model axis.
        """
        x = self._process_input(x)
        B, T, _ = x.shape

        # Prepend the class token to every sequence in the batch
        batch_class_token = self.class_token.expand(B, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Get class tokens per model
        # NOTE(review): assumes the patched encoder returns num_models * (T + 1)
        # tokens per sample, grouped by model - confirm against forward_pruned_layers.
        x = x.reshape(B, self.num_models, T + 1, -1)[:, :, 0]

        x = self.heads(x).mean(dim=1)

        return x