import torch
from torch import nn
from typing import Tuple, Optional
from .net import (
    get_resnet18_modules, get_resnet34_modules, get_resnet50_modules,
    get_vgg5_modules, get_vgg7_modules, get_vgg9_modules,
    get_vgg11_modules, get_vgg13_modules, get_vgg16_modules, get_vgg19_modules,
    get_vgg2_5_modules, get_vgg2_7_modules, get_vgg2_9_modules,
    get_vgg2_11_modules, get_vgg2_13_modules, get_vgg2_16_modules, get_vgg2_19_modules,
    get_vgg3_5_modules, get_vgg3_7_modules, get_vgg3_9_modules,
    get_vgg3_11_modules, get_vgg3_13_modules, get_vgg3_16_modules, get_vgg3_19_modules,
    get_mlp_modules, get_deep_mlp_modules, get_wide_mlp_modules,
    get_pinchetti_mlp_modules, get_goemaere_mlp_modules, get_goemaere_deep_mlp_modules,
    get_salvatori_mlp_modules, get_goemaere_mlp_with_sigmoid_modules, get_custom_mlp_modules,
    get_shallow_cnn_modules, get_deep_cnn_modules, get_wide_cnn_modules
)


def get_backbone_module_list(backbone_name: str,
                             input_channels: int = 3,
                             num_classes: int = 10,
                             input_size: Tuple[int, int] = (32, 32),
                             input_dim: Optional[int] = None) -> nn.ModuleList:
    """
    Factory function to create the backbone module list for a PCN.

    Args:
        backbone_name: Name of the backbone architecture (the same names
            returned by ``get_available_backbones()``).
        input_channels: Number of input channels (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output classes
        input_size: Input image size (height, width). Passed to the ResNet,
            VGG and CNN builders; for MLP backbones it is only used to derive
            ``input_dim`` when that argument is omitted.
        input_dim: Input dimension for MLP architectures (flattened image size).
            When ``None``, it defaults to
            ``input_channels * input_size[0] * input_size[1]``.
            Ignored by non-MLP backbones.

    Returns:
        nn.ModuleList: List of sequential modules for PCN

    Raises:
        ValueError: If ``backbone_name`` is not a supported backbone.

    Note:
        The general guideline is that pooling belongs at the beginning of the
        next module rather than the end of the current one, to prevent sudden
        latent shape changes. The ``vgg2_*`` variants are the exception: they
        place MaxPool at the end of the current module.
    """
    # Builders for image-shaped inputs, called as
    # builder(input_channels, num_classes, input_size).
    spatial_builders = {
        # ResNet
        'resnet18': get_resnet18_modules,
        'resnet34': get_resnet34_modules,
        'resnet50': get_resnet50_modules,
        # VGG (Qi, 2025)
        'vgg5': get_vgg5_modules,
        'vgg7': get_vgg7_modules,
        'vgg9': get_vgg9_modules,
        'vgg11': get_vgg11_modules,
        'vgg13': get_vgg13_modules,
        'vgg16': get_vgg16_modules,
        'vgg19': get_vgg19_modules,
        # VGG2 (MaxPool at end of current module)
        'vgg2_5': get_vgg2_5_modules,
        'vgg2_7': get_vgg2_7_modules,
        'vgg2_9': get_vgg2_9_modules,
        'vgg2_11': get_vgg2_11_modules,
        'vgg2_13': get_vgg2_13_modules,
        'vgg2_16': get_vgg2_16_modules,
        'vgg2_19': get_vgg2_19_modules,
        # VGG3 (MaxPool at beginning of each block)
        'vgg3_5': get_vgg3_5_modules,
        'vgg3_7': get_vgg3_7_modules,
        'vgg3_9': get_vgg3_9_modules,
        'vgg3_11': get_vgg3_11_modules,
        'vgg3_13': get_vgg3_13_modules,
        'vgg3_16': get_vgg3_16_modules,
        'vgg3_19': get_vgg3_19_modules,
        # CNN
        'shallow_cnn': get_shallow_cnn_modules,
        'deep_cnn': get_deep_cnn_modules,
        'wide_cnn': get_wide_cnn_modules,
    }

    # Builders for flattened inputs, called as builder(input_dim, num_classes).
    flat_builders = {
        # Generic MLP
        'mlp': get_mlp_modules,
        'deep_mlp': get_deep_mlp_modules,
        'wide_mlp': get_wide_mlp_modules,
        # Paper-specific MLP
        'pinchetti_mlp': get_pinchetti_mlp_modules,
        'goemaere_mlp': get_goemaere_mlp_modules,
        'goemaere_deep_mlp': get_goemaere_deep_mlp_modules,
        'salvatori_mlp': get_salvatori_mlp_modules,
        'goemaere_mlp_sigmoid': get_goemaere_mlp_with_sigmoid_modules,
    }

    if backbone_name in spatial_builders:
        return spatial_builders[backbone_name](input_channels, num_classes, input_size)

    if backbone_name in flat_builders:
        if input_dim is None:
            # Flattened image size: channels * height * width.
            input_dim = input_channels * input_size[0] * input_size[1]
        return flat_builders[backbone_name](input_dim, num_classes)

    raise ValueError(f"Invalid backbone name: {backbone_name}. "
                     "Supported backbones: resnet18, resnet34, resnet50, "
                     "vgg5, vgg7, vgg9, vgg11, vgg13, vgg16, vgg19, "
                     "vgg2_5, vgg2_7, vgg2_9, vgg2_11, vgg2_13, vgg2_16, vgg2_19, "
                     "vgg3_5, vgg3_7, vgg3_9, vgg3_11, vgg3_13, vgg3_16, vgg3_19, "
                     "mlp, deep_mlp, wide_mlp, pinchetti_mlp, goemaere_mlp, goemaere_deep_mlp, "
                     "salvatori_mlp, goemaere_mlp_sigmoid, shallow_cnn, deep_cnn, wide_cnn")


def get_available_backbones() -> list:
    """
    List the supported backbone names.

    Returns:
        list: Backbone names grouped by family, in the fixed order ResNet,
        VGG, VGG2, VGG3, generic MLP, paper-specific MLP, CNN. A fresh list
        is built on every call, so callers may mutate it freely.
    """
    resnet = ('resnet18', 'resnet34', 'resnet50')
    # VGG (Qi, 2025)
    vgg = ('vgg5', 'vgg7', 'vgg9', 'vgg11', 'vgg13', 'vgg16', 'vgg19')
    # VGG2: MaxPool at end of current module
    vgg2 = ('vgg2_5', 'vgg2_7', 'vgg2_9', 'vgg2_11', 'vgg2_13', 'vgg2_16', 'vgg2_19')
    # VGG3: MaxPool at beginning of each block
    vgg3 = ('vgg3_5', 'vgg3_7', 'vgg3_9', 'vgg3_11', 'vgg3_13', 'vgg3_16', 'vgg3_19')
    generic_mlp = ('mlp', 'deep_mlp', 'wide_mlp')
    paper_mlp = ('pinchetti_mlp', 'goemaere_mlp', 'goemaere_deep_mlp',
                 'salvatori_mlp', 'goemaere_mlp_sigmoid')
    cnn = ('shallow_cnn', 'deep_cnn', 'wide_cnn')

    names = []
    for family in (resnet, vgg, vgg2, vgg3, generic_mlp, paper_mlp, cnn):
        names.extend(family)
    return names


def get_backbone_info(backbone_name: str) -> dict:
    """
    Describe a backbone architecture by name.

    Args:
        backbone_name: Name of the backbone architecture

    Returns:
        dict: Keys 'name', 'type', 'description' and 'parameters' (an
        approximate parameter count as a string). Unrecognized names are not
        an error: they keep type 'unknown', an empty description and
        parameters 'unknown'.
    """
    info = {
        'name': backbone_name,
        'type': 'unknown',
        'description': '',
        'parameters': 'unknown',
    }

    # Entries with a fixed description: name -> (type, description, parameters).
    fixed_entries = {
        # ResNet
        'resnet18': ('ResNet', 'ResNet-18 with 18 layers', '~11M'),
        'resnet34': ('ResNet', 'ResNet-34 with 34 layers', '~21M'),
        'resnet50': ('ResNet', 'ResNet-50 with 50 layers', '~25M'),
        # Generic MLP
        'mlp': ('MLP', 'Generic MLP with 2 hidden layers', '~1M'),
        'deep_mlp': ('MLP', 'Deep MLP with 5 hidden layers', '~2M'),
        'wide_mlp': ('MLP', 'Wide MLP with 2 wide hidden layers', '~3M'),
        # Paper-specific MLP
        'pinchetti_mlp': ('Paper-specific MLP',
                          'Pinchetti 2025 MLP: 3 layers, 128 neurons', '~100K'),
        'goemaere_mlp': ('Paper-specific MLP',
                         'Goemaere 2025 MLP: 4 layers, 128 neurons, GELU', '~200K'),
        'goemaere_deep_mlp': ('Paper-specific MLP',
                              'Goemaere 2025 Deep MLP: 20 layers, 128 neurons, GELU', '~1M'),
        'salvatori_mlp': ('Paper-specific MLP',
                          'Salvatori 2023 MLP: 2 hidden layers, 64 neurons', '~50K'),
        'goemaere_mlp_sigmoid': ('Paper-specific MLP',
                                 'Goemaere 2025 MLP with Sigmoid output', '~200K'),
        # CNN
        'shallow_cnn': ('CNN', 'Shallow CNN with 2 conv layers', '~500K'),
        'deep_cnn': ('CNN', 'Deep CNN with 4 conv layers', '~1M'),
        'wide_cnn': ('CNN', 'Wide CNN with 2 wide conv layers', '~2M'),
    }

    # VGG families: name -> parameters. Their descriptions are generated from
    # the depth embedded in the name.
    vgg_params = {
        'vgg5': '~1M', 'vgg7': '~2M', 'vgg9': '~3M', 'vgg11': '~9M',
        'vgg13': '~9M', 'vgg16': '~14M', 'vgg19': '~20M',
    }
    vgg2_params = {
        'vgg2_5': '~1M', 'vgg2_7': '~2M', 'vgg2_9': '~3M', 'vgg2_11': '~9M',
        'vgg2_13': '~9M', 'vgg2_16': '~14M', 'vgg2_19': '~20M',
    }
    vgg3_params = {
        'vgg3_5': '~1M', 'vgg3_7': '~2M', 'vgg3_9': '~3M', 'vgg3_11': '~15M',
        'vgg3_13': '~15M', 'vgg3_16': '~20M', 'vgg3_19': '~25M',
    }

    if backbone_name in fixed_entries:
        info['type'], info['description'], info['parameters'] = fixed_entries[backbone_name]
    elif backbone_name in vgg_params:
        # VGG (Qi, 2025): depth follows the 'vgg' prefix.
        info['type'] = 'VGG'
        info['description'] = f'VGG-{backbone_name[3:]} with MaxPool at beginning of next module'
        info['parameters'] = vgg_params[backbone_name]
    elif backbone_name in vgg2_params:
        depth = backbone_name.split('_')[1]
        info['type'] = 'VGG2'
        info['description'] = f'VGG-{depth} with MaxPool at end of current module and adaptive FC'
        info['parameters'] = vgg2_params[backbone_name]
    elif backbone_name in vgg3_params:
        depth = backbone_name.split('_')[1]
        info['type'] = 'VGG3'
        info['description'] = f'VGG-{depth} classic style with traditional pooling and classifier'
        info['parameters'] = vgg3_params[backbone_name]

    return info