import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field
import copy
import types
from enum import Enum

@dataclass
class BlockHadamardHiRAConfig:
    """
    Configuration class for Block Hadamard HiRA (block-wise Hadamard high-rank adaptation).

    Block Hadamard HiRA formulation:
        ∆W^{(i,j)} = W_0^{(i,j)} ⊙ (alpha / r · B^{(i,j)} · A^{(i,j)})
    where:
    - W_0^{(i,j)} is the (i,j)-th block of the frozen pre-trained weight matrix
    - A^{(i,j)} (r × in_block_size) and B^{(i,j)} (out_block_size × r) are the
      low-rank factors of block (i,j)
    - ⊙ denotes the element-wise (Hadamard) product
    - The weight matrix is divided into num_blocks × num_blocks blocks

    Implementation notes:
    - All per-block products B·A are computed with one batched GEMM (torch.bmm)
    - The Hadamard product is applied to all blocks in one vectorized operation
    - Blocks are extracted / reassembled with view/permute operations

    NOTE(review): ``bias``, ``modules_to_save``, ``init_lora_weights`` and
    ``use_fast_inference`` are stored (and serialized) but do not appear to be
    consulted when the adapter layers are built — confirm before relying on them.

    Raises:
        ValueError: if ``target_modules`` is None, ``r`` < 1, ``num_blocks`` < 1,
            ``dropout`` is outside [0, 1], or ``block_arrangement`` is not "square".
    """
    r: int = 8  # Rank of the low-rank factorization of every block
    alpha: float = 8  # Scaling numerator; the block products are scaled by alpha / r
    dropout: float = 0.0
    target_modules: Optional[Union[List[str], str]] = None
    bias: str = "none"
    modules_to_save: Optional[List[str]] = None
    init_lora_weights: bool = True
    num_blocks: int = 4  # Number of blocks per dimension (e.g., 4x4 = 16 blocks total)
    block_arrangement: str = "square"  # "square" for num_blocks x num_blocks arrangement
    use_fast_inference: bool = True  # Enable optimized inference path

    def __post_init__(self):
        if self.target_modules is None:
            raise ValueError("target_modules cannot be None")
        # r is used as a divisor (alpha / r) and as a tensor dimension.
        if self.r < 1:
            raise ValueError(f"r must be a positive integer, got {self.r}")
        if self.num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")
        # Same range nn.Dropout accepts; fail early instead of at layer creation.
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
        if self.block_arrangement not in ["square"]:
            raise ValueError(f"block_arrangement must be 'square', got '{self.block_arrangement}'")


class BlockHadamardHiRALayer(nn.Module):
    """
    Implementation of Block Hadamard HiRA layer.

    Block Hadamard HiRA extends HiRA by applying the Hadamard product adaptation
    at the block level, enabling more fine-grained control over different regions
    of the weight matrix.

    For a weight matrix W_0 of shape [out_features, in_features], we divide it into
    num_blocks × num_blocks blocks and apply HiRA independently to each block:

    ∆W^{(i,j)} = W_0^{(i,j)} ⊙ (scaling · B^{(i,j)} · A^{(i,j)})

    with A^{(i,j)} of shape [r, in_block_size], B^{(i,j)} of shape
    [out_block_size, r] and scaling = alpha / r. All blocks are computed at once
    with a single batched matrix multiplication.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        r: int = 8,
        alpha: float = 8,
        dropout: float = 0.0,
        num_blocks: int = 4,
        block_arrangement: str = "square",
        merge_weights: bool = False,
    ):
        """
        Wrap ``base_layer``, which must expose a 2-D ``weight``.

        No adapter is created here; call ``update_layer`` to add one. ``r`` and
        ``alpha`` are accepted for API compatibility, but the values passed to
        ``update_layer`` are the ones used.

        Raises:
            ValueError: if ``base_layer`` has no ``weight``, ``num_blocks`` < 1, or
                a weight dimension is not divisible by ``num_blocks``.
        """
        super().__init__()

        self.base_layer = base_layer

        if not hasattr(base_layer, "weight"):
            raise ValueError("Layer doesn't have a weight attribute")
        self.out_features, self.in_features = base_layer.weight.shape

        # Guard before the integer divisions below.
        if num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")
        self.num_blocks = num_blocks
        self.block_arrangement = block_arrangement

        self.out_block_size = self.out_features // num_blocks
        self.in_block_size = self.in_features // num_blocks
        self._validate_block_dimensions()

        # Shares the base layer's bias Parameter object (not a copy).
        if getattr(base_layer, "bias", None) is not None:
            self.bias = base_layer.bias
        else:
            self.bias = None

        # Per-adapter hyperparameters, keyed by adapter name.
        self.r = {}
        self.alpha = {}
        self.scaling = {}

        # Per-adapter factors stored as 4-D tensors so all blocks can be batched:
        #   A: [num_blocks, num_blocks, r, in_block_size]
        #   B: [num_blocks, num_blocks, out_block_size, r]
        self.block_lora_A = nn.ParameterDict({})
        self.block_lora_B = nn.ParameterDict({})

        # One dropout module shared by every adapter of this layer.
        self.dropout = nn.Dropout(dropout)

        self.merge_weights = merge_weights
        self.merged = False
        # Copy of the base weight taken by ``merge`` so that ``unmerge`` can restore
        # it exactly. Registered as a non-persistent buffer: it follows ``.to()`` but
        # is not written to the state dict.
        self.register_buffer("_unmerged_weight", None, persistent=False)

        self.active_adapter = None
        self.disable_adapters = False

    def update_layer(self, adapter_name, r, alpha, dropout):
        """
        Add (or re-create) adapter ``adapter_name`` with freshly initialized factors.

        Args:
            adapter_name: key under which the adapter is stored.
            r: rank of every per-block factorization (must be >= 1).
            alpha: scaling numerator; B·A is scaled by ``alpha / r``.
            dropout: accepted for API compatibility; ``forward`` applies the
                layer-wide dropout configured in ``__init__``.

        The factors are allocated on the base weight's device and, when that
        weight is floating point, in its dtype. This way the Hadamard product
        and the adapter matmul in ``forward`` operate on matching tensors.
        The first adapter added becomes the active one.

        Raises:
            ValueError: if ``r`` < 1.
        """
        if r < 1:
            raise ValueError(f"r must be a positive integer, got {r}")

        self.r[adapter_name] = r
        self.alpha[adapter_name] = alpha
        self.scaling[adapter_name] = alpha / r  # Standard LoRA scaling

        base_weight = self.base_layer.weight
        factory_kwargs = {"device": base_weight.device}
        if base_weight.is_floating_point():
            factory_kwargs["dtype"] = base_weight.dtype

        nb = self.num_blocks
        # Re-assigning replaces any previous tensors, so a changed rank takes effect.
        self.block_lora_A[adapter_name] = nn.Parameter(
            torch.zeros((nb, nb, r, self.in_block_size), **factory_kwargs)
        )
        self.block_lora_B[adapter_name] = nn.Parameter(
            torch.zeros((nb, nb, self.out_block_size, r), **factory_kwargs)
        )

        self.reset_parameters(adapter_name)

        if self.active_adapter is None:
            self.active_adapter = adapter_name

    def reset_parameters(self, adapter_name):
        """
        Initialize adapter ``adapter_name`` so that ∆W starts at exactly zero.

        Each block's A^{(i,j)} receives LoRA's Kaiming-uniform init (``a=sqrt(5)``)
        with fan_in = in_block_size, and B is zeroed. Does nothing if the adapter
        does not exist.
        """
        if adapter_name not in self.block_lora_A:
            return
        with torch.no_grad():
            # Initialize through a [num_blocks*num_blocks*r, in_block_size] view so the
            # fan-in is that of a single block. On the raw 4-D tensor, kaiming would use
            # num_blocks * r * in_block_size and shrink the init.
            nn.init.kaiming_uniform_(
                self.block_lora_A[adapter_name].view(-1, self.in_block_size),
                a=math.sqrt(5),
            )
            nn.init.zeros_(self.block_lora_B[adapter_name])

    def _suggest_num_blocks(self):
        """Return up to five values in 1..16 that divide both weight dimensions."""
        return [
            i
            for i in range(1, 17)
            if self.out_features % i == 0 and self.in_features % i == 0
        ][:5]

    def _validate_block_dimensions(self):
        """
        Check that both weight dimensions are divisible by ``num_blocks``.

        Raises:
            ValueError: naming the offending dimension and, when any exist,
                suggesting compatible ``num_blocks`` values.
        """
        for label, size in (("out_features", self.out_features), ("in_features", self.in_features)):
            if size % self.num_blocks != 0:
                suggested_blocks = self._suggest_num_blocks()
                suggestion = f" Suggested num_blocks values: {suggested_blocks}" if suggested_blocks else ""
                raise ValueError(
                    f"{label} ({size}) must be divisible by num_blocks ({self.num_blocks}).{suggestion}"
                )

    def get_block_indices(self, block_i, block_j):
        """Return ``(out_start, out_end, in_start, in_end)`` slice bounds of block (i, j)."""
        out_start = block_i * self.out_block_size
        in_start = block_j * self.in_block_size
        return out_start, out_start + self.out_block_size, in_start, in_start + self.in_block_size

    def _compute_batched_adaptation(self, W0, adapter_name):
        """
        Compute the block-wise Hadamard adaptation for every block at once.

        Args:
            W0: base weight matrix [out_features, in_features].
            adapter_name: adapter whose factors are used.

        Returns:
            [out_features, in_features] matrix whose (i,j) block is
            W0^{(i,j)} ⊙ (scaling · B^{(i,j)} · A^{(i,j)}).
        """
        rank = self.r[adapter_name]
        batch_size = self.num_blocks * self.num_blocks

        # Flatten the two block axes into one batch axis.
        A_batched = self.block_lora_A[adapter_name].reshape(batch_size, rank, self.in_block_size)
        B_batched = self.block_lora_B[adapter_name].reshape(batch_size, self.out_block_size, rank)

        # [batch, out_block, r] @ [batch, r, in_block] -> [batch, out_block, in_block]
        BA_batched = torch.bmm(B_batched, A_batched) * self.scaling[adapter_name]
        BA_blocks = BA_batched.view(
            self.num_blocks, self.num_blocks, self.out_block_size, self.in_block_size
        )

        W0_blocks = self._extract_weight_blocks_vectorized(W0)
        return self._reconstruct_weight_matrix_vectorized(W0_blocks * BA_blocks)

    def _extract_weight_blocks_vectorized(self, W):
        """
        Split ``W`` [out_features, in_features] into blocks
        [num_blocks, num_blocks, out_block_size, in_block_size], where ``[i, j]`` is block (i, j).
        """
        W_reshaped = W.reshape(self.num_blocks, self.out_block_size, self.num_blocks, self.in_block_size)
        return W_reshaped.permute(0, 2, 1, 3)

    def _reconstruct_weight_matrix_vectorized(self, blocks):
        """Inverse of ``_extract_weight_blocks_vectorized``: reassemble blocks into [out_features, in_features]."""
        return blocks.permute(0, 2, 1, 3).contiguous().view(self.out_features, self.in_features)

    def merge(self):
        """
        Fold the active adapter into the base weight.

        For each block (i, j): W_new^{(i,j)} = W_0^{(i,j)} + W_0^{(i,j)} ⊙ (scaling · B·A)^{(i,j)}.
        A copy of W_0 is kept so that ``unmerge`` can restore it exactly.
        Does nothing if already merged or if no adapter is active.
        """
        if self.merged or self.active_adapter is None:
            return

        adapter_name = self.active_adapter
        if adapter_name not in self.block_lora_A:
            return

        weight = self.base_layer.weight
        # Merging is a pure weight edit; avoid building an autograd graph.
        with torch.no_grad():
            W_0 = weight.data
            self._unmerged_weight = W_0.clone()
            adaptation_weight = self._compute_batched_adaptation(W_0, adapter_name)
            # Keep the base weight's dtype even if the adapter dtype differs.
            weight.data = (W_0 + adaptation_weight).to(W_0.dtype)
        self.merged = True

    def unmerge(self):
        """
        Restore the base weight saved by ``merge``.

        The Hadamard update is not reliably invertible, so the exact pre-merge
        weight is restored from the saved copy. If no copy is available, a
        warning is printed and the layer stays merged. This avoids applying the
        adapter a second time on top of merged weights.
        """
        if not self.merged or self.active_adapter is None:
            return

        if self._unmerged_weight is None:
            print("Warning: original weights unavailable; cannot unmerge Block Hadamard HiRA")
            return

        with torch.no_grad():
            self.base_layer.weight.data = self._unmerged_weight
        self._unmerged_weight = None
        self.merged = False

    def forward(self, x):
        """
        Return ``base_layer``-style output plus the dropout-ed adapter output.

        The result is ``F.linear(x, W0, bias) + dropout(F.linear(x, ∆W))``. When
        adapters are disabled, none is active, or the weights are merged, only
        the base ``F.linear`` is computed.
        """
        if hasattr(self.base_layer, 'weight'):
            W0, bias = self.base_layer.weight, self.base_layer.bias
        else:
            # Handle wrapped layers by descending through ``base_layer`` attributes.
            original_layer = self.base_layer
            while hasattr(original_layer, 'base_layer') and not hasattr(original_layer, 'weight'):
                original_layer = original_layer.base_layer

            if hasattr(original_layer, 'weight'):
                W0, bias = original_layer.weight, original_layer.bias
            else:
                raise AttributeError(f"Cannot find weight attribute in base_layer of type {type(self.base_layer)}")

        if (self.disable_adapters
            or self.active_adapter is None
            or self.merged):
            return F.linear(x, W0, bias)

        adapter_name = self.active_adapter

        y = F.linear(x, W0, bias)
        adaptation_weight = self._compute_batched_adaptation(W0, adapter_name)
        adapter_out = self.dropout(F.linear(x, adaptation_weight))
        return y + adapter_out

    def set_adapter(self, adapter_name):
        """
        Make ``adapter_name`` the active adapter.

        Note: this does not touch merged weights. Call ``unmerge`` before switching
        if the layer is merged.

        Raises:
            ValueError: if the adapter was never added with ``update_layer``.
        """
        if adapter_name in self.block_lora_A:
            self.active_adapter = adapter_name
        else:
            raise ValueError(f"Adapter {adapter_name} not found")

    def get_performance_info(self, adapter_name=None):
        """
        Summarize block layout, parameter counts and memory of an adapter.

        Args:
            adapter_name: adapter to report on; defaults to the active adapter.

        Returns:
            dict of statistics, or ``{}`` if the adapter does not exist. Memory
            figures use the actual element sizes of the base weight and adapter
            tensors.
        """
        if adapter_name is None:
            adapter_name = self.active_adapter

        if adapter_name not in self.block_lora_A:
            return {}

        A = self.block_lora_A[adapter_name]
        B = self.block_lora_B[adapter_name]
        total_blocks = self.num_blocks * self.num_blocks

        total_adapter_params = A.numel() + B.numel()
        params_per_block = total_adapter_params // total_blocks

        # One batched GEMM replaces total_blocks sequential block GEMMs.
        theoretical_speedup = total_blocks / 1

        base_memory = self.out_features * self.in_features * self.base_layer.weight.element_size()
        adapter_memory = total_adapter_params * A.element_size()

        return {
            'total_blocks': total_blocks,
            'block_size': (self.out_block_size, self.in_block_size),
            'total_adapter_params': total_adapter_params,
            'params_per_block': params_per_block,
            'rank_per_block': self.r[adapter_name],
            'theoretical_speedup': f"{theoretical_speedup:.1f}x",
            'base_memory_mb': base_memory / (1024 * 1024),
            'adapter_memory_mb': adapter_memory / (1024 * 1024),
            'memory_overhead': f"{(adapter_memory / base_memory) * 100:.2f}%",
            'uses_batched_gemm': True,
            'vectorized_hadamard': True
        }


def get_adapter_state_dict(model, adapter_name):
    """
    Gather the block factors of adapter ``adapter_name`` from every adapted layer.

    Returns:
        dict mapping ``"<module>.block_lora_A.<adapter>"`` and
        ``"<module>.block_lora_B.<adapter>"`` to detached CPU copies of the
        corresponding parameters. Layers without the adapter are skipped.
    """
    collected = {}
    for module_name, layer in model.named_modules():
        if not isinstance(layer, BlockHadamardHiRALayer):
            continue
        if adapter_name not in layer.block_lora_A:
            continue
        collected[f"{module_name}.block_lora_A.{adapter_name}"] = layer.block_lora_A[adapter_name].data.cpu()
        collected[f"{module_name}.block_lora_B.{adapter_name}"] = layer.block_lora_B[adapter_name].data.cpu()
    return collected


def set_adapter_state_dict(model, adapter_state_dict, adapter_name):
    """
    Load adapter ``adapter_name``'s block factors from ``adapter_state_dict``.

    Keys follow the format produced by ``get_adapter_state_dict``:
    ``"<module>.block_lora_A.<adapter>"`` / ``"<module>.block_lora_B.<adapter>"``.
    Only layers that already hold ``adapter_name`` are updated, and missing keys
    are skipped. Values are copied in place into the existing parameters. The
    parameters therefore keep their device, dtype and identity, so optimizer
    references stay valid.

    Raises:
        ValueError: if a stored tensor's shape differs from the parameter's.
            Replacing the parameter with a differently shaped tensor would
            desynchronize it from the layer's recorded rank.
    """
    for name, module in model.named_modules():
        if not (isinstance(module, BlockHadamardHiRALayer) and adapter_name in module.block_lora_A):
            continue
        for factor_name, factors in (("block_lora_A", module.block_lora_A), ("block_lora_B", module.block_lora_B)):
            key = f"{name}.{factor_name}.{adapter_name}"
            if key not in adapter_state_dict:
                continue
            value = adapter_state_dict[key]
            param = factors[adapter_name]
            if tuple(value.shape) != tuple(param.shape):
                raise ValueError(
                    f"Shape mismatch for {key}: expected {tuple(param.shape)}, got {tuple(value.shape)}"
                )
            with torch.no_grad():
                # copy_ converts device/dtype to match the existing parameter.
                param.copy_(value)


def apply_block_hadamard_hira(model, config, adapter_name="default"):
    """
    Apply Block Hadamard HiRA adapter ``adapter_name`` to ``model`` in place.

    Every ``nn.Linear`` whose qualified name contains one of
    ``config.target_modules`` (substring match) is replaced by a
    ``BlockHadamardHiRALayer`` wrapping it. If a matching module is already a
    ``BlockHadamardHiRALayer`` (e.g. when a second adapter is applied), the
    adapter is added to that layer and made active there. Modules nested inside
    an adapted layer (such as its ``base_layer``) are never wrapped again.

    The model also receives ``peft_config``, ``active_adapter``,
    ``adapter_layers`` and the helper methods ``set_adapter``,
    ``save_pretrained``, ``merge_and_unload`` and
    ``mark_only_adapters_as_trainable``. Finally all parameters except this
    adapter's are frozen.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if an already adapted layer uses a different ``num_blocks``
            than ``config.num_blocks`` (or, from the layer constructor, if a
            target's dimensions are not divisible by ``num_blocks``).
    """
    # Adapter bookkeeping is stored directly on the model.
    model.peft_config = getattr(model, "peft_config", {})
    model.active_adapter = adapter_name
    model.adapter_layers = getattr(model, "adapter_layers", set())
    model.peft_config[adapter_name] = config

    target_modules = config.target_modules
    if isinstance(target_modules, str):
        target_modules = [target_modules]

    # Names of Block Hadamard HiRA layers seen so far; anything below them must
    # not be wrapped again (their ``base_layer`` also matches the target names).
    adapted_prefixes = []

    # Snapshot the tree first: the loop replaces modules in place.
    for name, module in list(model.named_modules()):
        if any(name.startswith(f"{prefix}.") for prefix in adapted_prefixes):
            continue

        is_hira_layer = isinstance(module, BlockHadamardHiRALayer)
        if is_hira_layer:
            adapted_prefixes.append(name)

        if not any(target_module in name for target_module in target_modules):
            continue

        if is_hira_layer:
            # Existing adapted layer: add the new adapter instead of re-wrapping.
            if adapter_name not in module.block_lora_A:
                if module.num_blocks != config.num_blocks:
                    raise ValueError(
                        f"Layer '{name}' uses num_blocks={module.num_blocks}; cannot add "
                        f"adapter '{adapter_name}' with num_blocks={config.num_blocks}"
                    )
                module.update_layer(
                    adapter_name=adapter_name,
                    r=config.r,
                    alpha=config.alpha,
                    dropout=config.dropout,
                )
            # Keep the layer consistent with model.active_adapter set above.
            module.set_adapter(adapter_name)
            model.adapter_layers.add(name)
            continue

        # Other modules that already carry adapter state are left untouched.
        if hasattr(module, "active_adapter"):
            continue

        if not isinstance(module, nn.Linear):
            continue

        block_hira_layer = BlockHadamardHiRALayer(
            module,
            r=config.r,
            alpha=config.alpha,
            dropout=config.dropout,
            num_blocks=config.num_blocks,
            block_arrangement=config.block_arrangement,
            merge_weights=False,
        )

        parent_name, target_name = get_submodules(model, name)
        parent = model
        for name_part in parent_name.split("."):
            if name_part:
                parent = getattr(parent, name_part)
        setattr(parent, target_name, block_hira_layer)

        block_hira_layer.update_layer(
            adapter_name=adapter_name,
            r=config.r,
            alpha=config.alpha,
            dropout=config.dropout,
        )

        model.adapter_layers.add(name)
        adapted_prefixes.append(name)

    # Attach helper methods bound to this model instance.
    model.set_adapter = types.MethodType(set_adapter, model)
    model.save_pretrained = types.MethodType(save_pretrained, model)
    model.merge_and_unload = types.MethodType(merge_and_unload, model)
    model.mark_only_adapters_as_trainable = types.MethodType(mark_only_adapters_as_trainable, model)

    # Freeze base model weights and enable only this adapter's parameters.
    mark_only_adapters_as_trainable(model, adapter_name)

    return model


def get_submodules(model, key):
    """
    Split a dotted module path into ``(parent_path, leaf_name)``.

    ``model`` is accepted for API compatibility and is not used. A key without
    a dot yields ``("", key)``.
    """
    parent_name, _separator, target_name = key.rpartition(".")
    return parent_name, target_name


def set_adapter(self, adapter_name):
    """
    Make ``adapter_name`` the model's active Block Hadamard HiRA adapter.

    Every adapted layer that holds this adapter switches to it. Layers that do
    not hold it keep their current adapter.

    Raises:
        ValueError: if ``adapter_name`` is not registered in ``self.peft_config``.
    """
    if adapter_name in self.peft_config:
        self.active_adapter = adapter_name
    else:
        raise ValueError(f"Adapter {adapter_name} not found.")

    holders = (
        module
        for _, module in self.named_modules()
        if isinstance(module, BlockHadamardHiRALayer) and adapter_name in module.block_lora_A
    )
    for layer in holders:
        layer.set_adapter(adapter_name)


def merge_and_unload(self):
    """
    Merge the active adapter and return a deep copy with plain ``nn.Linear`` layers.

    Side effect: every Block Hadamard HiRA layer of ``self`` whose active
    adapter equals ``self.active_adapter`` is merged in place. The returned
    model is a deep copy of ``self``. In it, each path in ``self.adapter_layers``
    that still holds a Block Hadamard HiRA layer is replaced by a new
    ``nn.Linear``, which reuses that layer's (merged) base weight and bias data.
    """
    layers_to_merge = [
        module
        for _, module in self.named_modules()
        if isinstance(module, BlockHadamardHiRALayer) and module.active_adapter == self.active_adapter
    ]
    for layer in layers_to_merge:
        layer.merge()

    unloaded = copy.deepcopy(self)

    for path in self.adapter_layers:
        parent_path, leaf = get_submodules(unloaded, path)
        owner = unloaded
        for part in filter(None, parent_path.split(".")):
            owner = getattr(owner, part)

        hira_layer = getattr(owner, leaf)
        if not isinstance(hira_layer, BlockHadamardHiRALayer):
            continue

        has_bias = hira_layer.bias is not None
        replacement = nn.Linear(hira_layer.in_features, hira_layer.out_features, bias=has_bias)
        replacement.weight.data = hira_layer.base_layer.weight.data
        if has_bias:
            replacement.bias.data = hira_layer.bias.data

        setattr(owner, leaf, replacement)

    return unloaded


def mark_only_adapters_as_trainable(self, adapter_name):
    """
    Freeze every parameter, then unfreeze only ``adapter_name``'s block factors.

    Layers that do not hold ``adapter_name`` end up fully frozen.
    """
    for _, parameter in self.named_parameters():
        parameter.requires_grad = False

    for _, module in self.named_modules():
        if not isinstance(module, BlockHadamardHiRALayer):
            continue
        if adapter_name not in module.block_lora_A:
            continue
        for factors in (module.block_lora_A, module.block_lora_B):
            factors[adapter_name].requires_grad = True


def save_pretrained(self, save_directory, **kwargs):
    """
    Write the active Block Hadamard HiRA adapter into ``save_directory``.

    When an active adapter is set, ``adapter_config.json`` is written with that
    adapter's config fields: sets become lists, Enum members become their
    ``.value``, and ``"peft_type": "Block_Hadamard_HiRA"`` is added.
    ``adapter_model.bin`` is always written with ``torch.save`` of
    ``get_adapter_state_dict``. Extra keyword arguments are ignored.
    """
    import os
    import json

    os.makedirs(save_directory, exist_ok=True)

    active = getattr(self, "active_adapter", None)
    if active:
        def _to_jsonable(value):
            # JSON has no set or Enum types.
            if isinstance(value, set):
                return list(value)
            if isinstance(value, Enum):
                return value.value
            return value

        raw_fields = self.peft_config[active].__dict__
        config_dict = {key: _to_jsonable(value) for key, value in raw_fields.items()}
        config_dict["peft_type"] = "Block_Hadamard_HiRA"

        config_path = os.path.join(save_directory, "adapter_config.json")
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=2)

    weights = get_adapter_state_dict(self, self.active_adapter)
    torch.save(weights, os.path.join(save_directory, "adapter_model.bin"))


def get_block_hadamard_hira_model(model, config, adapter_name="default"):
    """
    Public entry point: add Block Hadamard HiRA adapter ``adapter_name`` (built
    from ``config``) to ``model`` and return the same, modified model object.
    """
    adapted_model = apply_block_hadamard_hira(model, config, adapter_name=adapter_name)
    return adapted_model