# -*- coding: utf-8 -*-
"""
Factory - Simple PCN component creation functions

Avoids excessive factory patterns and provides only simple creation functions.
"""

try:
    import torch
    import torch.nn as nn
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

from .config import PCNConfig
from .new_pcn_model import NewPCNModel
# from .two_stage_inference import TwoStageInference  # Temporarily disabled - this module not in SUCCESS branch


def create_simple_backbone(layer_sizes, num_classes):
    """
    Create a simple fully-connected PCN backbone.

    Args:
        layer_sizes: Layer size list [input_size, hidden1, hidden2, ...].
            Must contain at least the input size.
        num_classes: Number of classification classes (output width of the
            final classifier layer).

    Returns:
        An nn.Module exposing ``backbone_module_list`` together with the
        ``init_zs_ff`` and ``_update_params_*`` helpers defined below.

    Raises:
        ImportError: If PyTorch could not be imported.
        ValueError: If ``layer_sizes`` is empty.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch is required")
    # Fail early with a clear message instead of an IndexError on layer_sizes[-1].
    if len(layer_sizes) == 0:
        raise ValueError("layer_sizes must contain at least the input size")

    class SimpleBackbone(nn.Module):
        """One backbone stage: a single linear layer."""

        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, out_dim)

        def forward(self, x):
            return self.linear(x)

    class SimplePCNBackbone(nn.Module):
        """Stack of linear stages followed by a linear classifier."""

        def __init__(self, layer_sizes, num_classes):
            super().__init__()

            # Backbone modules (including final classifier) - maintain existing PCN structure
            self.backbone_module_list = nn.ModuleList(
                SimpleBackbone(in_dim, out_dim)
                for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
            )

            # Add classifier at the end (same as existing structure)
            self.backbone_module_list.append(
                SimpleBackbone(layer_sizes[-1], num_classes)
            )

        def init_zs_ff(self, x, y):
            """
            Feed-forward initialization of the layer states.

            Returns a list starting with ``x`` followed by the detached
            output of every module applied in sequence. ``y`` is unused.
            """
            zs = [x]
            for module in self.backbone_module_list:
                z_next = module(zs[-1])
                zs.append(z_next.detach())
            return zs

        def _update_params_standard(self, zs, x, y, optimizer=None):
            """
            Compute the standard loss - follows existing pcn_base.py structure.

            The loss is the squared input reconstruction error, plus the
            squared prediction error of every intermediate module, plus the
            cross-entropy of the final module's logits against ``y``.
            ``optimizer`` is accepted for interface compatibility but unused.
            """
            total_loss = 0.0

            # Input reconstruction
            total_loss += torch.sum((zs[0] - x) ** 2)

            # Backbone modules: prediction losses + final classification
            last_idx = len(self.backbone_module_list) - 1
            for i, module in enumerate(self.backbone_module_list):
                if i != last_idx:
                    # Intermediate layers: prediction loss
                    pred = module(zs[i])
                    total_loss += torch.sum((zs[i + 1] - pred) ** 2)
                else:
                    # Last layer: classification loss, fed with zs[-2]
                    logits = module(zs[-2])
                    total_loss += torch.nn.functional.cross_entropy(logits, y)

            return total_loss

        def _update_params_pred_freeze(self, zs, x, y, optimizer=None, zhs=None):
            """
            Compute the loss using predictions made from the ``zhs`` states.

            Falls back to ``_update_params_standard`` when ``zhs`` is None.
            Otherwise every intermediate module predicts from ``zhs[i]`` and
            is compared against ``zs[i + 1]``; the final module classifies
            ``zhs[-2]``. No input reconstruction term is added in this mode.
            """
            if zhs is None:
                return self._update_params_standard(zs, x, y, optimizer)

            total_loss = 0.0

            # Each module is evaluated exactly once; the stray debugger
            # breakpoint that used to sit here has been removed.
            last_idx = len(self.backbone_module_list) - 1
            for i, module in enumerate(self.backbone_module_list):
                if i != last_idx:
                    # Intermediate layers: frozen prediction loss
                    total_loss += torch.sum((zs[i + 1] - module(zhs[i])) ** 2)
                else:
                    # Last layer: classification with frozen predictions
                    logits = module(zhs[-2])
                    total_loss += torch.nn.functional.cross_entropy(logits, y)

            return total_loss

    return SimplePCNBackbone(layer_sizes, num_classes)


def create_pcn_from_args(args, layer_sizes=None, num_classes=None):
    """
    Build a NewPCNModel from parsed command-line arguments.

    Args:
        args: argparse.Namespace object handed to PCNConfig
        layer_sizes: Layer sizes; when omitted, [784, 256, 128] is used if
            the config's dataset is 'MNIST', otherwise [100, 64, 32]
        num_classes: Number of classes; when omitted, taken from the config

    Returns:
        NewPCNModel wrapping a simple linear backbone
    """
    cfg = PCNConfig(args)

    # Fill in whichever pieces the caller left unspecified.
    if layer_sizes is None:
        layer_sizes = [784, 256, 128] if cfg.dataset == 'MNIST' else [100, 64, 32]
    if num_classes is None:
        num_classes = cfg.num_classes

    return NewPCNModel(create_simple_backbone(layer_sizes, num_classes), cfg)


def create_pcn_from_config(config, layer_sizes, num_classes):
    """
    Build a NewPCNModel from an already-constructed config object.

    Args:
        config: PCNConfig object passed through to NewPCNModel
        layer_sizes: Layer sizes for the simple backbone
        num_classes: Number of classes for the backbone's classifier

    Returns:
        NewPCNModel
    """
    return NewPCNModel(create_simple_backbone(layer_sizes, num_classes), config)


def create_mnist_pcn(model_type='pcn_jacobi', solver_type='vanilla', use_meta_pc=False):
    """
    Build a NewPCNModel preconfigured for MNIST.

    Args:
        model_type: Model type stored on the config
        solver_type: Solver type stored on the config
        use_meta_pc: Whether to use Meta-PC (stored on the config)

    Returns:
        NewPCNModel with a 784 -> 512 -> 256 backbone and 10 output classes
    """
    mnist_classes = 10

    cfg = PCNConfig()
    # Applied in a fixed order: caller choices first, then dataset facts.
    for attr_name, attr_value in (
        ('model_type', model_type),
        ('solver_type', solver_type),
        ('use_meta_pc', use_meta_pc),
        ('dataset', 'MNIST'),
        ('num_classes', mnist_classes),
    ):
        setattr(cfg, attr_name, attr_value)

    # 28*28 flattened input followed by two hidden widths.
    mnist_layers = [784, 512, 256]
    return create_pcn_from_config(cfg, mnist_layers, mnist_classes)


# Convenience functions
def create_vanilla_pcn(args):
    """
    Create a Vanilla PC model.

    Builds a config from ``args``, disables Meta-PC and selects the
    'vanilla' solver before constructing the model.

    Args:
        args: argparse.Namespace object handed to PCNConfig

    Returns:
        NewPCNModel
    """
    cfg = PCNConfig(args)
    cfg.use_meta_pc = False
    cfg.solver_type = 'vanilla'

    if cfg.dataset == 'MNIST':
        sizes = [784, 256, 128]
    else:
        sizes = [100, 64, 32]

    return NewPCNModel(create_simple_backbone(sizes, cfg.num_classes), cfg)


def create_meta_pcn(args):
    """
    Create a Meta-PC model.

    Builds a config from ``args`` and switches on both ``use_meta_pc`` and
    ``pred_freeze`` before constructing the model.

    Args:
        args: argparse.Namespace object handed to PCNConfig

    Returns:
        NewPCNModel
    """
    cfg = PCNConfig(args)
    cfg.use_meta_pc = True
    cfg.pred_freeze = True

    if cfg.dataset == 'MNIST':
        sizes = [784, 256, 128]
    else:
        sizes = [100, 64, 32]

    return NewPCNModel(create_simple_backbone(sizes, cfg.num_classes), cfg)


def create_accelerated_pcn(args, solver_type='anderson'):
    """
    Create a PCN model that uses an accelerated solver.

    Args:
        args: argparse.Namespace object handed to PCNConfig
        solver_type: Solver type written onto the config

    Returns:
        NewPCNModel
    """
    cfg = PCNConfig(args)
    cfg.solver_type = solver_type

    if cfg.dataset == 'MNIST':
        sizes = [784, 256, 128]
    else:
        sizes = [100, 64, 32]

    return NewPCNModel(create_simple_backbone(sizes, cfg.num_classes), cfg)


def create_simple_backbone_wrapper(backbone_module_list):
    """
    Existing approach: wrap a module list in a simple sequential module.

    Args:
        backbone_module_list: Iterable of modules applied in order; the last
            one is treated as the classifier when computing the loss.

    Returns:
        A torch.nn.Module exposing ``backbone_module_list``, ``forward``,
        ``update_params_and_compute_loss`` and ``_update_params_default``.
    """
    class SimpleWrapper(torch.nn.Module):
        """Sequential container that also implements the basic PC loss step."""

        def __init__(self, module_list):
            super().__init__()
            self.backbone_module_list = module_list

        def forward(self, x):
            """Apply every module in order and return the final output."""
            out = x
            for layer in self.backbone_module_list:
                out = layer(out)
            return out

        def update_params_and_compute_loss(self, zs_final, x, y, optimizer=None, zhs=None):
            """
            Compute the basic PCN loss, backpropagate, and optionally step.

            The loss is the batch-averaged squared input reconstruction
            error, plus the batch-averaged squared prediction error of each
            intermediate module, plus the cross-entropy of the last module
            applied to ``zs_final[-2]``. ``backward()`` always runs; the
            optimizer is zeroed and stepped only when one is supplied.
            ``zhs`` is accepted but unused (no pred_freeze here).
            """
            step_optimizer = optimizer is not None
            if step_optimizer:
                optimizer.zero_grad()

            layers = self.backbone_module_list
            final_idx = len(layers) - 1
            batch_size = x.shape[0]

            # Basic PCN loss (without pred_freeze)
            with torch.enable_grad():
                pc_loss = torch.sum((zs_final[0] - x) ** 2) / batch_size
                for idx, layer in enumerate(layers):
                    if idx == final_idx:
                        logits = layer(zs_final[-2])
                        pc_loss += torch.nn.functional.cross_entropy(logits, y)
                    else:
                        residual = zs_final[idx + 1] - layer(zs_final[idx])
                        pc_loss += torch.sum(residual ** 2) / batch_size

            # Backward pass, followed by the optimizer step when available
            pc_loss.backward()
            if step_optimizer:
                optimizer.step()

            return pc_loss

        def _update_params_default(self, zs_final, x, y, optimizer=None):
            """Default parameter update - called from new_pcn_model.py."""
            return self.update_params_and_compute_loss(zs_final, x, y, optimizer)

    return SimpleWrapper(backbone_module_list)

def create_backbone_from_external(backbone_name, num_classes=None):
    """
    Build a backbone model from the external ``backbones`` package.

    Args:
        backbone_name: Backbone name (e.g. 'vgg13', 'shallow_cnn')
        num_classes: Optional number of classes. When given it is forwarded
            to ``get_backbone_module_list`` as a keyword argument; when None
            the external function is called with the name only.

    Returns:
        External backbone model (includes ``backbone_module_list``)

    Raises:
        ImportError: If the external backbone module cannot be imported.
        ValueError: If anything else fails while building the backbone.
            Both are chained to the original exception.
    """
    try:
        # Fetch the module-list builder from the external backbone package
        from backbones.get_backbone_module_list import get_backbone_module_list

        class ExternalBackbone(nn.Module):
            """Wraps an externally provided module list with PCN helpers."""

            def __init__(self, backbone_name, num_classes=None):
                super().__init__()

                # Obtain backbone_module_list from the external package
                if num_classes is None:
                    self.backbone_module_list = get_backbone_module_list(backbone_name)
                else:
                    self.backbone_module_list = get_backbone_module_list(
                        backbone_name, num_classes=num_classes
                    )

            def init_zs_ff(self, x, y=None):
                """FF initialization (compatible with the existing PCN_Base)."""
                zs = [x]
                for module in self.backbone_module_list:
                    z_next = module(zs[-1])
                    zs.append(z_next.detach())
                return zs

            def _update_params_standard(self, zs, x, y, optimizer=None):
                """Standard loss computation (follows the existing PCN_Base structure)."""
                total_loss = 0.0

                # Input reconstruction
                total_loss += torch.sum((zs[0] - x) ** 2)

                # Backbone modules: prediction losses + final classification
                last_idx = len(self.backbone_module_list) - 1
                for i, module in enumerate(self.backbone_module_list):
                    if i != last_idx:
                        # Intermediate layers: prediction loss
                        pred = module(zs[i])
                        total_loss += torch.sum((zs[i + 1] - pred) ** 2)
                    else:
                        # Last layer: classification loss, fed with zs[-2]
                        logits = module(zs[-2])
                        total_loss += torch.nn.functional.cross_entropy(logits, y)

                return total_loss

            def _update_params_origin(self, zs, x, y, optimizer=None, zs_origin=None):
                """Origin parameter update; delegates to the standard loss."""
                return self._update_params_standard(zs, x, y, optimizer)

            def _update_params_pred_freeze(self, zs, x, y, optimizer=None, zhs=None):
                """Frozen prediction loss; falls back to standard when zhs is None."""
                if zhs is None:
                    return self._update_params_standard(zs, x, y, optimizer)

                total_loss = 0.0

                # Use frozen predictions
                last_idx = len(self.backbone_module_list) - 1
                for i, module in enumerate(self.backbone_module_list):
                    if i != last_idx:
                        # Intermediate layers: frozen prediction loss
                        zhs_pred = module(zhs[i])
                        total_loss += torch.sum((zs[i + 1] - zhs_pred) ** 2)
                    else:
                        # Last layer: classification with frozen predictions
                        logits = module(zhs[-2])
                        total_loss += torch.nn.functional.cross_entropy(logits, y)

                return total_loss

        return ExternalBackbone(backbone_name, num_classes)

    except ImportError as e:
        # Chain explicitly so the original import failure stays attached.
        raise ImportError(f"External backbone 모듈을 가져올 수 없습니다: {e}") from e
    except Exception as e:
        raise ValueError(f"Backbone '{backbone_name}' creation 실패: {e}") from e


def create_pcn_with_external_backbone(backbone_name, model_type='pcn_jacobi',
                                    solver_type='vanilla', use_meta_pc=False,
                                    num_classes=10, **kwargs):
    """
    Create a PCN model that uses an external backbone.

    Args:
        backbone_name: External backbone name (e.g. 'vgg13', 'shallow_cnn')
        model_type: Model type
        solver_type: Solver type
        use_meta_pc: Whether to use Meta-PC
        num_classes: Number of classes forwarded to the external
            ``get_backbone_module_list``
        **kwargs: Additional settings. Keys that already exist as attributes
            on the config overwrite them; unknown keys are ignored. The
            special key ``norm_kwargs`` (a dict, or None) is popped and every
            entry in it is set on the config unconditionally
            (clipping-related settings).

    Returns:
        NewPCNModel (using the external backbone)

    Raises:
        ImportError: If the external backbone module cannot be imported
            (chained to the original error).
    """
    # Basic configuration creation
    config = PCNConfig()
    config.model_type = model_type
    config.solver_type = solver_type
    config.use_meta_pc = use_meta_pc

    # norm_kwargs handling (clipping settings); tolerate an explicit None
    norm_kwargs = kwargs.pop('norm_kwargs', None) or {}

    # Apply additional settings from kwargs, only for known config attributes
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)

    # Add norm_kwargs entries to the config (clipping related)
    for key, value in norm_kwargs.items():
        setattr(config, key, value)

    # Re-apply Meta-PC settings after the attributes above may have changed
    if config.use_meta_pc:
        config.param_update_method = 'pred_freeze'
        config.pred_freeze = True

    # Fetch the external backbone module list directly
    try:
        from backbones.get_backbone_module_list import get_backbone_module_list
        backbone_module_list = get_backbone_module_list(backbone_name, num_classes=num_classes)
    except ImportError as e:
        # Chain explicitly so the original import failure stays attached.
        raise ImportError(f"External backbone 모듈을 가져올 수 없습니다: {e}") from e

    # Meta-PC + pred_freeze uses a dedicated backbone wrapper
    if config.use_meta_pc and config.param_update_method == 'pred_freeze':
        from .pred_freeze_backbone import PredFreezeBackbone
        backbone = PredFreezeBackbone(backbone_module_list)
    else:
        # Existing approach: simple sequential wrapper
        backbone = create_simple_backbone_wrapper(backbone_module_list)

    return NewPCNModel(backbone, config)


def create_pcn_from_backbone(backbone_name, **kwargs):
    """
    General-purpose PCN model factory supporting various backbones.

    Thin alias that forwards everything to
    ``create_pcn_with_external_backbone``.

    Args:
        backbone_name: Backbone name (e.g. 'vgg13', 'shallow_cnn')
        **kwargs: PCN settings (model_type, solver_type, use_meta_pc, T, eta, ...)

    Returns:
        NewPCNModel
    """
    model = create_pcn_with_external_backbone(backbone_name, **kwargs)
    return model


def create_compatible_pcn_from_backbone(backbone_name, **kwargs):
    """
    Create a PCN model intended to be fully compatible with the parent branch.

    Args:
        backbone_name: Backbone name (e.g. 'vgg13', 'shallow_cnn')
        **kwargs: PCN settings; only ``T`` (default 10) and ``eta``
            (default 0.2) are read here

    Returns:
        PCNCompatibleModel built from the external backbone and a minimal
        config fixed to model_type 'pcn_jacobi' and solver_type 'vanilla'
    """
    from .pcn_compatible_model import PCNCompatibleModel

    settings = kwargs

    class CompatibleConfig:
        """Minimal config object carrying only T, eta, model_type, solver_type."""

        def __init__(self):
            self.T = settings.get('T', 10)
            self.eta = settings.get('eta', 0.2)
            self.model_type = 'pcn_jacobi'
            self.solver_type = 'vanilla'

    compat_config = CompatibleConfig()

    # Build the external backbone, then pair it with the compatible config
    external_backbone = create_backbone_from_external(backbone_name)
    return PCNCompatibleModel(external_backbone, compat_config)


def create_vgg13_pcn(model_type='pcn_jacobi', solver_type='vanilla', use_meta_pc=False, **kwargs):
    """
    Create a PCN model that uses the VGG13 backbone.

    Args:
        model_type: Model type
        solver_type: Solver type
        use_meta_pc: Whether to use Meta-PC
        **kwargs: Additional settings (dataset, T, eta, ...)

    Returns:
        NewPCNModel (VGG13 backbone)
    """
    model = create_pcn_with_external_backbone(
        'vgg13',
        model_type=model_type,
        solver_type=solver_type,
        use_meta_pc=use_meta_pc,
        **kwargs
    )
    return model


def create_shallow_cnn_pcn(model_type='pcn_jacobi', solver_type='vanilla', use_meta_pc=False, **kwargs):
    """
    Create a PCN model that uses the Shallow CNN backbone.

    Args:
        model_type: Model type
        solver_type: Solver type
        use_meta_pc: Whether to use Meta-PC
        **kwargs: Additional settings

    Returns:
        NewPCNModel (Shallow CNN backbone)
    """
    model = create_pcn_with_external_backbone(
        'shallow_cnn',
        model_type=model_type,
        solver_type=solver_type,
        use_meta_pc=use_meta_pc,
        **kwargs
    )
    return model


# ============================================================================
# Two-Stage Inference Factory Functions
# ============================================================================

def create_two_stage_pcn_from_args(args, layer_sizes=None, num_classes=None):
    """
    Create a "two-stage" PCN model from an argparse.Namespace.

    Args:
        args: argparse.Namespace object handed to PCNConfig
        layer_sizes: Layer sizes (optional); defaults to [784, 256, 128] for
            'MNIST' and [32*32*3, 128, 64] otherwise
        num_classes: Number of classes (optional); defaults to the config's

    Returns:
        NewPCNModel (TwoStageInference disabled)
    """
    # TwoStageInference is unavailable, so NewPCNModel is used in its place
    from .config import PCNConfig
    from .new_pcn_model import NewPCNModel

    cfg = PCNConfig(args)

    if layer_sizes is None:
        if cfg.dataset == 'MNIST':
            layer_sizes = [784, 256, 128]
        else:
            layer_sizes = [32*32*3, 128, 64]
    if num_classes is None:
        num_classes = cfg.num_classes

    return NewPCNModel(create_simple_backbone(layer_sizes, num_classes), cfg)


def create_two_stage_pcn_from_config(config, layer_sizes=None, num_classes=None):
    """
    Create a "two-stage" PCN model from a PCNConfig.

    Args:
        config: PCNConfig object
        layer_sizes: Layer sizes (optional); defaults to [784, 256, 128] for
            'MNIST' and [32*32*3, 128, 64] otherwise
        num_classes: Number of classes (optional); defaults to the config's

    Returns:
        NewPCNModel (TwoStageInference disabled)
    """
    # TwoStageInference is unavailable, so NewPCNModel is used in its place
    from .new_pcn_model import NewPCNModel

    if layer_sizes is None:
        if config.dataset == 'MNIST':
            layer_sizes = [784, 256, 128]
        else:
            layer_sizes = [32*32*3, 128, 64]
    if num_classes is None:
        num_classes = config.num_classes

    return NewPCNModel(create_simple_backbone(layer_sizes, num_classes), config)


def create_two_stage_pcn_with_external_backbone(backbone_name, dataset='CIFAR10', num_classes=10,
                                               T_meta=5, T_total=20, stage2_solver='anderson',
                                               eta_meta=0.2, stage1_use_block_sweep=False, **kwargs):
    """
    Create a "two-stage" PCN model using an external backbone.

    Two-stage inference is disabled: the result is a single-stage
    NewPCNModel that runs ``T_total`` iterations with ``stage2_solver`` and
    ``eta_meta``.

    Args:
        backbone_name: Backbone name ('vgg13' or 'shallow_cnn')
        dataset: Dataset name
        num_classes: Number of classes (stored on the config)
        T_meta: Stage 1 maximum iterations (currently not used here)
        T_total: Total maximum iterations (stored as config.T)
        stage2_solver: Stage 2 solver type (stored as config.solver_type)
        eta_meta: Stage 1 learning rate (stored as config.eta)
        stage1_use_block_sweep: Whether Stage 1 uses Block Sweep (currently
            not used here)
        **kwargs: Additional settings; only keys that already exist on the
            config are applied

    Returns:
        NewPCNModel (TwoStageInference disabled)

    NOTE(review): num_classes is written to the config but is not passed to
    the backbone construction - confirm whether that is intended.
    """
    # TwoStageInference is unavailable, so NewPCNModel is used in its place
    from .config import PCNConfig
    from .new_pcn_model import NewPCNModel

    external_backbone = create_backbone_from_external(backbone_name)

    cfg = PCNConfig()
    for attr_name, attr_value in (
        ('enable_two_stage', False),   # force single-stage mode
        ('dataset', dataset),
        ('num_classes', num_classes),
        ('T', T_total),                # use the total iteration count
        ('solver_type', stage2_solver),
        ('eta', eta_meta),
    ):
        setattr(cfg, attr_name, attr_value)

    # Apply additional settings, skipping keys the config does not know
    for key in kwargs:
        if not hasattr(cfg, key):
            continue
        setattr(cfg, key, kwargs[key])

    return NewPCNModel(external_backbone, cfg)


def create_two_stage_vgg13_pcn(T_meta=5, T_total=20, stage2_solver='anderson',
                              eta_meta=0.2, stage1_use_block_sweep=False, **kwargs):
    """
    Create a "two-stage" PCN model that uses the VGG13 backbone.

    Args:
        T_meta: Stage 1 maximum iterations
        T_total: Total maximum iterations
        stage2_solver: Stage 2 solver type
        eta_meta: Stage 1 learning rate
        stage1_use_block_sweep: Whether Stage 1 uses Block Sweep
        **kwargs: Additional settings

    Returns:
        NewPCNModel (VGG13 backbone, TwoStageInference disabled)
    """
    stage_options = dict(
        T_meta=T_meta,
        T_total=T_total,
        stage2_solver=stage2_solver,
        eta_meta=eta_meta,
        stage1_use_block_sweep=stage1_use_block_sweep,
    )
    return create_two_stage_pcn_with_external_backbone(
        'vgg13', **stage_options, **kwargs
    )


def create_two_stage_shallow_cnn_pcn(T_meta=5, T_total=20, stage2_solver='anderson',
                                    eta_meta=0.2, stage1_use_block_sweep=False, **kwargs):
    """
    Create a "two-stage" PCN model that uses the Shallow CNN backbone.

    Args:
        T_meta: Stage 1 maximum iterations
        T_total: Total maximum iterations
        stage2_solver: Stage 2 solver type
        eta_meta: Stage 1 learning rate
        stage1_use_block_sweep: Whether Stage 1 uses Block Sweep
        **kwargs: Additional settings

    Returns:
        NewPCNModel (Shallow CNN backbone, TwoStageInference disabled)
    """
    stage_options = dict(
        T_meta=T_meta,
        T_total=T_total,
        stage2_solver=stage2_solver,
        eta_meta=eta_meta,
        stage1_use_block_sweep=stage1_use_block_sweep,
    )
    return create_two_stage_pcn_with_external_backbone(
        'shallow_cnn', **stage_options, **kwargs
    )
