
from abc import ABC
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Literal, Tuple, Union

import torch
from torch import nn


class LayerType(Enum):
    """Kinds of layers handled by this module; each value is a lowercase string tag."""

    conv1D = 'conv1d'  # 1-D convolution
    conv2D = 'conv2d'  # 2-D convolution
    transformer = 'transformer'  # transformer block


@dataclass
class LayerConfig:
    """Base configuration record shared by all layer-specific config classes."""

    layer_type: LayerType  # which kind of layer this configuration describes


@dataclass
class TransformerConfig(LayerConfig):
    """
    Size parameters of a transformer layer.

    The fields follow the inherited ``layer_type``, so the positional constructor
    order is (layer_type, hidden_size, head_size, num_heads, intermediate_dimension).
    """

    hidden_size: int  # model hidden width
    head_size: int  # width of a single attention head
    num_heads: int  # number of attention heads
    intermediate_dimension: int  # hidden width of the feed-forward sub-block


@dataclass
class Conv1dConfig(LayerConfig):
    """Configuration of a 1-D convolution layer: the inherited ``layer_type`` plus its output width."""

    output_channels: int  # number of output channels


# === 1. Schema definitions ===

@dataclass
class LayerSpec:
    """
    Point at one parameter of a layer and state the axis along which it is pruned.
    """

    module_name: str  # presumably the submodule path inside the layer (e.g. "attention.self.query") -- confirm with callers
    attribute: str  # attribute name of the parameter on that submodule (e.g. "weight") -- assumed, confirm
    prune_type: Literal["row", "column"]  # prune rows or columns of the parameter

@dataclass
class LayerSchema(ABC):
    """
    Common base for the per-layer schema dataclasses.

    ``layer_name`` is compared against module class names by
    ``resolve_instance_schema_map``. NOTE: this class declares no abstract
    methods, so despite deriving from ``ABC`` it can be instantiated directly.
    """

    layer_name: str  # class name of the modules this schema applies to
    layer_type: LayerType  # kind of layer described by the schema
    
@dataclass
class TransformerLayerSchema(LayerSchema):
    """Schema for a transformer layer: normalisation placement plus its prunable parameters."""

    norm_type: Literal["pre_norm", "post_norm"]  # where layer normalisation sits in the block
    layers: Dict[str, LayerSpec]  # LayerSpec per key; key meaning (e.g. a role like "query") not visible here -- confirm

@dataclass
class ConvLayerSchema(LayerSchema):
    """Schema for a convolution layer."""

    conv_type: Literal["conv1d", "conv2d"]  # same tags as the LayerType.conv1D / conv2D values
    prune_type: Literal["out_channel", "in_channel"]  # which channel axis is pruned


def resolve_instance_schema_map(model: nn.Module, schema_list: List[LayerSchema]) -> List[Tuple[str, nn.Module, LayerSchema]]:
    """
    Pair every submodule of ``model`` with the schema that describes its class.

    A schema applies to a module when ``schema.layer_name`` equals the module's
    class name. When several schemas share a ``layer_name``, the first one in
    ``schema_list`` is used. Modules with no matching schema are omitted.

    Returns:
        A list of ``(module_name, module_instance, layer_schema)`` tuples, in
        ``model.named_modules()`` order.
    """
    pairs = []
    for qualified_name, submodule in model.named_modules():
        class_name = submodule.__class__.__name__
        # First schema in list order wins (one schema per layer type is assumed).
        schema = next((s for s in schema_list if s.layer_name == class_name), None)
        if schema is not None:
            pairs.append((qualified_name, submodule, schema))
    return pairs


def get_attr(obj: object, attributes: Union[str, List[str]], default=None):
    """
    Return the first attribute of ``obj`` that exists and is not ``None``.

    The candidate names are tried in order. An attribute that exists but is set
    to ``None`` is treated the same as a missing one, and the next candidate is
    tried.

    Args:
        obj: The object to read the attribute from.
        attributes: Candidate attribute names in priority order (e.g. aliases
            such as ``['hidden_size', 'hidden_dim']``). A single name may also
            be passed as a plain string.
        default: The value to return if no candidate yields a non-``None`` value.

    Returns:
        The value of the first matching attribute, or ``default``.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(attributes, str):
        attributes = [attributes]
    for attr in attributes:
        value = getattr(obj, attr, None)
        if value is not None:
            return value
    return default

def get_layer_config(config, layer_type, layer_name) -> LayerConfig:
    """
    Build the configuration object for a specific layer type and name.

    For transformer layers, sizes are read from ``config`` through ``get_attr``
    using several common attribute aliases:

    * hidden size: ``hidden_size`` / ``hidden_dim`` / ``hidden_dimension`` (required)
    * head size: ``head_size`` / ``head_dim`` / ``head_dimension``
    * number of heads: ``num_heads`` / ``num_attention_heads``
    * intermediate size: ``intermediate_size`` / ``intermediate_dim`` /
      ``intermediate_dimension`` (required)

    At least one of head size / number of heads must be present. Whichever one
    is missing is derived as ``hidden_size // <the other>``.

    Args:
        config: The configuration object containing layer settings.
        layer_type (LayerType): The type of the layer (e.g., convolution, transformer).
        layer_name (str): The name of the layer (used in error messages).

    Returns:
        LayerConfig: A ``TransformerConfig`` when ``layer_type`` is
        ``LayerType.transformer``.

    Raises:
        AssertionError: If a required size cannot be found in ``config``.
        ValueError: If ``layer_type`` is not supported.
    """
    if layer_type != LayerType.transformer:
        raise ValueError(f"Unsupported layer type: {layer_type}")

    # Checks raise explicitly instead of using `assert` so they still run under
    # `python -O`; AssertionError is kept so existing callers see the same type.
    hidden_size = get_attr(config, ['hidden_size', 'hidden_dim', 'hidden_dimension'], None)
    if hidden_size is None:
        raise AssertionError(f'Cannot find "hidden_size", "hidden_dim", or "hidden_dimension" in config for transformer layer {layer_name}.')

    head_size = get_attr(config, ['head_size', 'head_dim', 'head_dimension'], None)
    num_heads = get_attr(config, ['num_heads', 'num_attention_heads'], None)
    if head_size is None and num_heads is None:
        raise AssertionError(f'Cannot find "head_size", "head_dim", "head_dimension", "num_heads", or "num_attention_heads" in config for transformer layer {layer_name}.')
    # Many configs (e.g. Hugging Face BertConfig) define only the head count, and
    # others only the head width; derive whichever one is absent.
    if head_size is None:
        head_size = hidden_size // num_heads
    if num_heads is None:
        num_heads = hidden_size // head_size

    intermediate_dimension = get_attr(config, ['intermediate_size', 'intermediate_dim', 'intermediate_dimension'], None)
    if intermediate_dimension is None:
        raise AssertionError(f'Cannot find "intermediate_size", "intermediate_dim", or "intermediate_dimension" in config for transformer layer {layer_name}.')

    return TransformerConfig(
        layer_type=layer_type,
        hidden_size=hidden_size,
        head_size=head_size,
        num_heads=num_heads,
        intermediate_dimension=intermediate_dimension,
    )


def extract_layers_by_class(model: nn.Module, target_classes: list) -> list:
    """
    Extract the submodules of ``model`` whose class name is in ``target_classes``.

    Modules are visited via ``model.named_modules()``, i.e. in registration
    order (depth-first, each parent before its children, ``model`` itself
    first). This often matches the order of the forward pass, but it is not
    guaranteed to, because ``forward`` may call submodules in any order.

    Args:
        model (torch.nn.Module): PyTorch model.
        target_classes (list of str): Class names to keep (e.g. ["Linear", "Conv2d"]).
            Matching is by exact class name, so subclasses with a different
            name are not matched. A single class name may also be passed as a
            plain string.

    Returns:
        list: ``(index, layer_name, layer)`` tuples. ``index`` counts the matched
        modules from 0 in traversal order. ``layer_name`` is the dotted name
        yielded by ``named_modules()``, which is ``""`` for ``model`` itself.
    """
    # A bare string would otherwise be matched by substring ("Linear" in "NonLinear").
    if isinstance(target_classes, str):
        target_classes = [target_classes]
    wanted = set(target_classes)

    matches = (
        (name, layer)
        for name, layer in model.named_modules()
        if layer.__class__.__name__ in wanted
    )
    return [(index, name, layer) for index, (name, layer) in enumerate(matches)]


def calculate_selected_block_indices_transformer(layer, selected_heads, selected_ffn_dims):
    """
    Calculate flattened indices for selected blocks in a transformer layer.

    The offsets assume a flattened parameter space made of the flattened query,
    key and value weights (in that order) followed by the flattened FFN
    up-projection weight ``layer.intermediate.dense.weight``. Biases and other
    parameters are not included.
    NOTE(review): that layout is implied only by this arithmetic; confirm it
    matches how callers flatten the parameters.

    Args:
        layer (torch.nn.Module): The transformer layer of interest. It must
            expose ``attention.self.{query,key,value}`` linear layers,
            ``attention.self.num_attention_heads`` and ``intermediate.dense``.
        selected_heads (iterable of int): Indices of selected heads in MHA.
        selected_ffn_dims (iterable of int): Indices of selected FFN dimensions
            (rows of ``intermediate.dense.weight``).

    Returns:
        list of int: For each selected head, in the given order, the indices of
        its rows in Q, then K, then V. These are followed by the indices of each
        selected FFN row. Returns an empty list when nothing is selected.

    Raises:
        ValueError: If a head or FFN index is out of range.
    """
    attn = layer.attention.self
    qkv_weights = [attn.query.weight, attn.key.weight, attn.value.weight]
    num_heads = attn.num_attention_heads
    # Rows owned by one head: heads are stacked along the output (row) dimension.
    head_size = attn.query.weight.size(0) // num_heads

    ffn_weight = layer.intermediate.dense.weight
    intermediate_size = ffn_weight.size(0)
    ffn_row_len = ffn_weight.size(1)

    # Start offset of Q, K and V in the flattened space. This is computed up
    # front so the FFN offset is defined even when no heads are selected.
    qkv_offsets = []
    offset = 0
    for weight in qkv_weights:
        qkv_offsets.append(offset)
        offset += weight.numel()
    ffn_start_idx = offset  # FFN weights start after the QKV weights

    selected_indices = []
    for head in selected_heads:
        if not 0 <= head < num_heads:
            raise ValueError(f"Head index {head} out of range for {num_heads} heads.")
        for base, weight in zip(qkv_offsets, qkv_weights):
            block_len = head_size * weight.size(1)  # head_size full rows
            start_idx = base + head * block_len
            selected_indices.extend(range(start_idx, start_idx + block_len))

    for ffn_dim in selected_ffn_dims:
        if not 0 <= ffn_dim < intermediate_size:
            raise ValueError(f"FFN dimension {ffn_dim} out of range for intermediate size {intermediate_size}.")
        start_idx = ffn_start_idx + ffn_dim * ffn_row_len
        selected_indices.extend(range(start_idx, start_idx + ffn_row_len))

    return selected_indices


def calculate_selected_block_indices_convolution(conv_layer, selected_output_channels):
    """
    Calculate the flattened indices of the selected output channels of a convolution.

    A PyTorch convolution weight has shape
    ``(out_channels, in_channels / groups, *kernel_size)``. After flattening,
    each output channel therefore owns one contiguous run of entries. Only the
    leading output-channel dimension matters here, so ``Conv1d``, ``Conv2d`` and
    ``Conv3d`` layers (grouped or not) are all supported.

    Args:
        conv_layer (torch.nn.Conv1d | torch.nn.Conv2d | torch.nn.Conv3d): The
            convolutional layer whose ``weight`` is indexed.
        selected_output_channels (iterable of int): Output channel indices to
            select; their index runs appear in this order.

    Returns:
        torch.Tensor: A 1-D ``torch.long`` tensor of flattened indices
        corresponding to the selected output channels. It is empty when no
        channel is selected.

    Raises:
        ValueError: If a channel index is outside ``[0, out_channels)``.
    """
    weight = conv_layer.weight
    out_channels = weight.size(0)
    # Entries per output channel: (in_channels / groups) * prod(kernel_size),
    # regardless of how many spatial dimensions the kernel has.
    per_channel = weight.shape[1:].numel()

    selected_indices = []
    for out_channel in selected_output_channels:
        if not 0 <= out_channel < out_channels:
            raise ValueError(f"Output channel {out_channel} out of range for {out_channels} output channels.")
        start_index = out_channel * per_channel
        selected_indices.extend(range(start_index, start_index + per_channel))

    return torch.tensor(selected_indices, dtype=torch.long)
