# -*- coding: utf-8 -*-
"""
Simple PCN Config - 간소화된 PCN 설정 관리
"""

class PCNConfig:
    """Simplified PCN configuration container.

    All hyperparameters read by the PCN models/solvers are initialised to
    defaults in ``__init__`` and may be overridden from an ``args`` source
    (an attribute object such as ``argparse.Namespace``, or a mapping).
    For two-stage inference, ``get_stage1_config`` / ``get_stage2_config``
    derive independent single-stage configs from this one.
    """

    # Settings that both derived stage configs inherit from the parent config
    # (task definition, normalisation, convergence/solver internals), so that
    # user overrides are not silently reset to defaults in two-stage mode.
    _SHARED_STAGE_ATTRS = (
        'dataset', 'num_classes',
        'norm_type', 'norm_clip', 'norm_clip_value', 'norm_learnable_scale',
        'norm_target_norm', 'norm_filter_out',
        'convergence_threshold', 'track_trajectory',
        'anderson_m', 'broyden_memory_size',
        'anderson_beta', 'anderson_lam', 'anderson_safeguard_tau',
        'anderson_gamma_clip', 'anderson_warmup',
    )

    def __init__(self, args=None):
        """Set defaults, then apply overrides from ``args``.

        Args:
            args: Optional override source. Either a mapping of
                ``name -> value`` (string keys not starting with ``_``), or
                any object whose public attributes are copied onto this
                config (e.g. an ``argparse.Namespace`` or another
                ``PCNConfig``). Methods defined on ``args``' class are not
                copied.
        """
        # Set default values
        self.model_type = 'pcn_jacobi'
        self.use_meta_pc = False
        self.update_latent = 'block_sweep_gs'
        self.solver_type = 'vanilla'
        self.loop_scheduler = None  # jacobi, gauss_seidel, block_sweep
        self.param_update_method = 'standard'
        self.norm_type = 'variance'
        self.norm_clip = False
        self.norm_clip_value = 1.0
        self.norm_learnable_scale = False
        self.norm_target_norm = 0.9
        self.norm_filter_out = None
        self.eta = 0.2
        self.T = 20
        self.dataset = 'MNIST'
        self.num_classes = 10
        self.pred_freeze = False
        self.convergence_threshold = 1e-4
        self.track_trajectory = False
        self.anderson_m = 5
        self.broyden_memory_size = 10

        # Anderson Type-II specific hyperparameters
        self.anderson_beta = 1.0              # Damping factor (βΔF_k term)
        self.anderson_lam = 1e-6              # Ridge regularization
        self.anderson_safeguard_tau = 0.1     # Monotonicity check threshold (0 = disabled)
        self.anderson_gamma_clip = None       # ||γ|| constraint (None = unlimited)
        self.anderson_warmup = 2              # Warmup steps before AA activation

        # Two-stage inference system settings
        self.enable_two_stage = False          # Two-stage mode activation
        self.T_meta = 5                        # Stage 1 max iterations
        self.T_total = 20                      # Total max iterations
        self.stage2_solver = 'anderson'        # Stage 2 solver type
        self.eta_meta = 0.2                    # Stage 1 learning rate
        self.stage1_use_block_sweep = False    # Stage 1 Block Sweep usage
        self.quality_threshold = 0.1           # Warm Start quality threshold

        # Override from args
        if args is not None:
            self._apply_overrides(args)

        # Reconcile loop_scheduler / update_latent_rule / update_latent.
        # This works on the already-copied attributes, so it covers both
        # the args path and attributes injected by factory kwargs.
        self._apply_loop_scheduler_mapping()

        # Meta-PC always runs with frozen prediction parameters.
        if self.use_meta_pc:
            self.param_update_method = 'pred_freeze'
            self.pred_freeze = True

    def _apply_overrides(self, args):
        """Copy public settings from ``args`` (a mapping or attribute object) onto self."""
        from collections.abc import Mapping

        if isinstance(args, Mapping):
            for name, value in args.items():
                # setattr requires string names; private names are skipped
                # to match the attribute-object path below.
                if isinstance(name, str) and not name.startswith('_'):
                    setattr(self, name, value)
            return

        args_type = type(args)
        for name in dir(args):
            if name.startswith('_'):
                continue
            # Skip methods (and other callables) defined on args' class, e.g.
            # when args is another PCNConfig; copying them would bind this
            # instance's method names to the *other* instance.
            if callable(getattr(args_type, name, None)):
                continue
            setattr(self, name, getattr(args, name))

    def _apply_loop_scheduler_mapping(self):
        """Keep ``loop_scheduler`` and ``update_latent`` consistent.

        An explicitly set ``loop_scheduler`` takes priority and is mirrored
        into ``update_latent``. Otherwise a non-None ``update_latent_rule``
        (possibly set via factory kwargs) is mapped onto both. If neither is
        set, ``update_latent`` keeps its current value.
        """
        if getattr(self, 'loop_scheduler', None) is not None:
            self.update_latent = self.loop_scheduler
            return

        rule = getattr(self, 'update_latent_rule', None)
        if rule is not None:
            self.loop_scheduler = rule
            self.update_latent = rule

    def get_solver_config(self):
        """Return the solver-related settings as a plain dict.

        Returns:
            dict with iteration count, step size, convergence settings and
            the Anderson / Broyden hyperparameters.
        """
        return {
            'T': self.T,
            'eta': self.eta,
            'convergence_threshold': self.convergence_threshold,
            'track_trajectory': self.track_trajectory,
            'anderson_m': self.anderson_m,
            'broyden_memory_size': self.broyden_memory_size,
            # Anderson Type-II hyperparameters
            'anderson_beta': self.anderson_beta,
            'anderson_lam': self.anderson_lam,
            'anderson_safeguard_tau': self.anderson_safeguard_tau,
            'anderson_gamma_clip': self.anderson_gamma_clip,
            'anderson_warmup': self.anderson_warmup
        }

    def _copy_shared_settings(self, config):
        """Copy the settings common to both inference stages from self onto ``config``."""
        for name in self._SHARED_STAGE_ATTRS:
            if hasattr(self, name):
                setattr(config, name, getattr(self, name))

    def get_stage1_config(self):
        """Build the stage-1 (Meta-PC, vanilla solver) single-stage config.

        Returns:
            A new ``PCNConfig`` with ``T = T_meta``, ``eta = eta_meta``, frozen
            prediction parameters, and two-stage mode disabled.
        """
        config = PCNConfig()
        self._copy_shared_settings(config)

        # Stage 1 specific settings
        config.use_meta_pc = True
        config.solver_type = 'vanilla'
        # An explicit user scheduler wins; otherwise stage1_use_block_sweep decides.
        if self.loop_scheduler:
            scheduler = self.loop_scheduler
        else:
            scheduler = 'block_sweep_gs' if self.stage1_use_block_sweep else 'jacobi'
        config.update_latent = scheduler
        config.loop_scheduler = scheduler
        config.T = self.T_meta
        config.eta = self.eta_meta
        config.param_update_method = 'pred_freeze'
        config.pred_freeze = True
        config.enable_two_stage = False  # Operate as single model

        return config

    def get_stage2_config(self):
        """Build the stage-2 (vanilla PC + ``stage2_solver``) single-stage config.

        Returns:
            A new ``PCNConfig`` using ``stage2_solver`` for the remaining
            ``T_total - T_meta`` iterations (clamped at 0), standard parameter
            updates, and two-stage mode disabled.
        """
        config = PCNConfig()
        self._copy_shared_settings(config)

        # Stage 2 specific settings
        config.use_meta_pc = False
        config.solver_type = self.stage2_solver
        # An explicit user scheduler wins; otherwise Jacobi pairs with the advanced solver.
        scheduler = self.loop_scheduler if self.loop_scheduler else 'jacobi'
        config.update_latent = scheduler
        config.loop_scheduler = scheduler
        # Guard against T_meta > T_total producing a negative iteration budget.
        config.T = max(0, self.T_total - self.T_meta)
        config.eta = self.eta
        config.param_update_method = 'standard'
        config.pred_freeze = False
        config.enable_two_stage = False  # Operate as single model

        return config

    def get_summary(self):
        """Return a compact dict summarising the key settings.

        Two-stage fields are included only when ``enable_two_stage`` is truthy.
        """
        summary = {
            'model_type': self.model_type,
            'use_meta_pc': self.use_meta_pc,
            'update_latent': self.update_latent,
            'solver_type': self.solver_type,
            'param_update_method': self.param_update_method,
            'norm_type': self.norm_type,
            'iterations': self.T,
            'learning_rate': self.eta
        }

        # Additional info for two-stage mode
        if self.enable_two_stage:
            summary.update({
                'enable_two_stage': True,
                'T_meta': self.T_meta,
                'T_total': self.T_total,
                'stage2_solver': self.stage2_solver,
                'eta_meta': self.eta_meta,
                'stage1_use_block_sweep': self.stage1_use_block_sweep
            })

        return summary