# -*- coding: utf-8 -*-
"""
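New PCN Model - modular, unified PCN model

This module provides a single model class that supports all PCN variants.
Variant- and solver-specific behaviour lives in pluggable components
(latent updater, loop scheduler, solver, normalization manager), which keeps
the forward pass short.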
"""

try:
    import torch
    import torch.nn as nn
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

from typing import List, Dict, Any, Optional, Tuple
from .config import PCNConfig
from .latent_updaters.loop_scheduler import LoopScheduler
from .solvers.base_solver import BaseSolver, create_solver
from .normalization.normalization_wrapper import NormalizationManager



class NewPCNModel(nn.Module):
    """
    Modular, unified PCN model.

    A single model class that supports the PCN variants (Vanilla PC and
    Meta-PC, selected by ``config.use_meta_pc``) combined with the solver
    selected by ``config.solver_type`` (the original documentation lists
    Vanilla, Anderson and Broyden solvers). Each concern lives in its own
    component so it can be maintained separately:

    * ``latent_updater`` - VanillaPCUpdater or MetaPCUpdater
    * ``loop_scheduler`` - LoopScheduler passed to the solver
    * ``solver``         - built by ``create_solver``
    * ``norm_manager``   - NormalizationManager applied to the backbone
    """

    def __init__(self, backbone_model, config: PCNConfig):
        """
        Args:
            backbone_model: PyTorch backbone model (the backbone part of an
                existing PCN model). ``forward`` accesses its
                ``backbone_module_list`` unconditionally.
            config: PCNConfig object.

        Raises:
            ImportError: if PyTorch could not be imported.
        """
        # NOTE(review): the class statement itself dereferences ``nn``, so if
        # the torch import failed this module raises NameError at import time
        # and this guard is never reached.
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required")

        super().__init__()

        self.config = config
        self.backbone_model = backbone_model

        # Build components, then normalize the backbone (which needs
        # norm_manager from the previous step).
        self._init_components()
        self._apply_normalization()

        # Marks that construction finished.
        self._initialized = True

    def _init_components(self):
        """Initialize all components in dependency order."""
        # 1. latent updater
        self._init_latent_updater()
        # 2. loop scheduler
        self._init_loop_scheduler()
        # 3. solver
        self._init_solver()
        # 4. normalization manager
        self._init_normalization_manager()

    def _init_latent_updater(self):
        """Create the Meta-PC or Vanilla-PC latent updater based on ``config.use_meta_pc``."""
        if self.config.use_meta_pc:
            from .latent_updaters.meta_pc_updater import MetaPCUpdater
            self.latent_updater = MetaPCUpdater(self.config)
        else:
            from .latent_updaters.vanilla_pc_updater import VanillaPCUpdater
            self.latent_updater = VanillaPCUpdater(self.config)

    def _init_loop_scheduler(self):
        """Create the LoopScheduler.

        An explicitly set (truthy) ``config.loop_scheduler`` takes precedence;
        otherwise ``config.update_latent`` is used.
        """
        scheduler_type = getattr(self.config, 'loop_scheduler', None) or self.config.update_latent
        # Remember the resolved type so get_components_info reports what was
        # actually built instead of always showing config.update_latent.
        self._loop_scheduler_type = scheduler_type
        self.loop_scheduler = LoopScheduler(scheduler_type)

    def _init_solver(self):
        """Build the solver selected by ``config.solver_type``.

        The per-solver settings from ``config.get_solver_config()`` are
        exposed as attributes of a simple namespace object before being
        handed to ``create_solver``.
        """
        from types import SimpleNamespace

        solver_args = SimpleNamespace(**self.config.get_solver_config())
        self.solver = create_solver(self.config.solver_type, solver_args)

    def _init_normalization_manager(self):
        """Create the NormalizationManager from the config."""
        self.norm_manager = NormalizationManager(self.config)

    def _apply_normalization(self):
        """Apply normalization to the backbone and set up weight clipping.

        Both steps only happen when the backbone exposes
        ``backbone_module_list``.
        """
        if not hasattr(self.backbone_model, 'backbone_module_list'):
            return

        self.norm_manager.apply_normalization(
            self.backbone_model,
            self.config.norm_type
        )

        # Combined weight clipping is only set up when config.clip is truthy.
        if getattr(self.config, 'clip', False):
            from improved_weight_clipping import WeightClipper
            clip_value = getattr(self.config, 'clip_value', 1.0)
            self._weight_clipper = WeightClipper(self.backbone_model, clip_value)

    def forward(self, x, y, optimizer=None, return_info=False):
        """
        PCN forward pass.

        Args:
            x: input data
            y: target data
            optimizer: optimizer forwarded to the parameter-update method
                (kept for compatibility with the parent implementation);
                weight clipping only runs when it is not None.
            return_info: whether to include the solver's convergence info.

        Returns:
            dict with ``'loss'`` (result of the parameter update / loss
            computation) and ``'pred'`` (the detached last latent state of
            the initial feedforward pass). When ``return_info`` is True it
            also contains ``'convergence_info'``.
        """
        # 1. Feedforward initialization of the latent states.
        zs_init = self._init_latent_states(x, y)

        # Keep the initial feedforward prediction (used for Acc_forward).
        initial_pred = zs_init[-1].detach()

        # 2. Solver arguments; Meta-PC additionally needs frozen predictions.
        solver_kwargs = {
            'loop_scheduler': self.loop_scheduler,
            'backbone_module_list': self.backbone_model.backbone_module_list
        }
        if self.config.use_meta_pc:
            solver_kwargs['zhs'] = self.latent_updater.prepare_frozen_predictions(zs_init)

        # 3. Converge the latent states with the solver.
        zs_final, convergence_info = self.solver.solve(
            self.latent_updater,
            zs_init,
            x, y,
            **solver_kwargs
        )

        # 4. Parameter update and loss. zhs is None when Meta-PC is off,
        # which selects the plain update path in the helper.
        loss = self._update_parameters_and_compute_loss(
            zs_final, x, y, optimizer, zhs=solver_kwargs.get('zhs')
        )

        # 5. Combined weight clipping (TorchDEQ + direct), only when an
        # optimizer was supplied.
        if hasattr(self, '_weight_clipper') and optimizer is not None:
            self._weight_clipper.clip_weights()

        # 6. Reset normalization after the parameter update.
        self._reset_normalization_if_needed()

        # Same return structure as the parent branch.
        result = {
            'loss': loss,
            'pred': initial_pred
        }
        if return_info:
            result['convergence_info'] = convergence_info

        return result

    def _init_latent_states(self, x, y):
        """Initialize the latent states with a feedforward pass.

        Uses the backbone's own ``init_zs_ff`` when available. Otherwise it
        returns ``[x, f1(x), f2(f1(x)), ...]``, computed without gradients
        and detached.
        """
        if hasattr(self.backbone_model, 'init_zs_ff'):
            return self.backbone_model.init_zs_ff(x, y)

        zs = [x]
        with torch.no_grad():
            for module in self.backbone_model.backbone_module_list:
                z_next = module(zs[-1])
                zs.append(z_next.detach())
        return zs

    def _update_parameters_and_compute_loss(self, zs_final, x, y, optimizer=None, zhs=None):
        """Update parameters and compute the loss.

        Dispatches to the backbone's ``_update_params_<param_update_method>``
        method (the existing pcn_base.py update methods) when it exists.
        Otherwise it falls back to ``_compute_basic_loss``. Frozen predictions
        ``zhs`` are forwarded only for the ``'pred_freeze'`` method.
        """
        method_name = f"_update_params_{self.config.param_update_method}"

        if not hasattr(self.backbone_model, method_name):
            # Fallback: plain loss computation without a parameter update.
            return self._compute_basic_loss(zs_final, x, y)

        update_method = getattr(self.backbone_model, method_name)

        if self.config.param_update_method == 'pred_freeze' and zhs is not None:
            # Meta-PC: use the frozen predictions.
            return update_method(zs_final, x, y, optimizer=optimizer, zhs=zhs)
        return update_method(zs_final, x, y, optimizer)

    def _compute_basic_loss(self, zs, x, y):
        """Fallback PCN loss.

        Sums squared errors for the input reconstruction and for every
        intermediate layer's prediction. The last module's output on
        ``zs[-2]`` is scored with cross-entropy against ``y``.
        """
        module_list = self.backbone_model.backbone_module_list
        last_index = len(module_list) - 1
        total_loss = 0.0

        # Input reconstruction loss.
        if zs[0] is not None:
            total_loss += torch.sum((zs[0] - x) ** 2)

        for i, module in enumerate(module_list):
            if i != last_index:
                # Intermediate layers: squared prediction error.
                pred = module(zs[i])
                total_loss += torch.sum((zs[i + 1] - pred) ** 2)
            else:
                # Last layer: classification loss (pcn_base.py layout).
                # NOTE(review): cross_entropy defaults to reduction='mean'
                # while the terms above are sums - confirm this is intended.
                logits = module(zs[-2])
                total_loss += torch.nn.functional.cross_entropy(logits, y)

        return total_loss

    def _reset_normalization_if_needed(self):
        """Reset the backbone's normalization if the manager reports it as applied."""
        if self.norm_manager.is_applied():
            self.norm_manager.reset_normalization(self.backbone_model)

    def get_components_info(self):
        """Return a dict summarizing the configured components."""
        scheduler_type = getattr(self, '_loop_scheduler_type', self.config.update_latent)
        return {
            'latent_updater': self.latent_updater.__class__.__name__,
            'loop_scheduler': f"LoopScheduler({scheduler_type})",
            'solver': self.solver.__class__.__name__,
            'normalization': f"Applied: {self.norm_manager.is_applied()}",
            'config_summary': self.config.get_summary()
        }

    def predict_forward(self, x):
        """Feedforward prediction without PCN iterations.

        Runs the backbone modules in sequence. No ``torch.no_grad`` is used
        here, so autograd tracking follows the caller's context.
        """
        zs = [x]
        for module in self.backbone_model.backbone_module_list:
            z_next = module(zs[-1])
            zs.append(z_next)
        return {'pred': zs[-1]}

    def reset_solver_memory(self):
        """Reset the solver's memory (prevents interference between samples)."""
        self.solver.reset_memory()

    def __repr__(self):
        return (f"NewPCNModel(\n"
                f"  config={self.config.get_summary()},\n"
                f"  components={self.get_components_info()}\n"
                f")")


def create_new_pcn_model(backbone_model, args=None, config=None):
    """
    Convenience factory for NewPCNModel.

    Args:
        backbone_model: PyTorch backbone model.
        args: legacy argparse.Namespace; only used to build a PCNConfig
            when ``config`` is None (backward compatibility).
        config: PCNConfig object; takes precedence over ``args``.

    Returns:
        NewPCNModel: the constructed model.
    """
    effective_config = PCNConfig(args) if config is None else config
    return NewPCNModel(backbone_model, effective_config)
