# -*- coding: utf-8 -*-
"""
Base Solver - 솔버 추상 인터페이스

이 모듈은 PCN의 잠재 변수 수렴을 위한 솔버 인터페이스를 정의합니다.
솔버는 delta_zs를 받아 zs를 업데이트하는 역할만 담당합니다.
"""

try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple


def create_solver(solver_name: str, args):
    """
    솔버 생성 함수
    
    Args:
        solver_name: 솔버 이름 ('vanilla', 'anderson', 'broyden')
        args: 설정 객체
        
    Returns:
        초기화된 솔버 인스턴스
    """
    if solver_name == 'vanilla':
        from .vanilla_solver import VanillaSolver
        return VanillaSolver(args)
    elif solver_name in ['anderson', 'anderson_type1']:
        from .anderson_solver import AndersonSolver
        return AndersonSolver(args)
    elif solver_name in ['broyden', 'broyden_uv']:
        from .broyden_solver import BroydenSolver
        return BroydenSolver(args)
    else:
        available = ['vanilla', 'anderson', 'broyden']
        raise ValueError(f"Unknown solver: {solver_name}. Available: {available}")


class BaseSolver(ABC):
    """
    잠재 변수 수렴을 위한 추상 기본 클래스
    
    솔버는 latent updater로부터 delta_zs를 받아 
    이를 사용해 zs를 업데이트하는 역할만 수행합니다.
    가속화 알고리즘(Anderson, Broyden)은 이 과정을 최적화합니다.
    """
    
    def __init__(self, args):
        """
        Args:
            args: 설정 객체 (T, eta, solver 관련 파라미터 등)
        """
        self.args = args
        self.T = getattr(args, 'T', 10)  # 최대 반복 횟수
        self.eta = getattr(args, 'eta', 0.2)  # 학습률
        
        # 수렴 추적
        self.trajectory = []
        self.convergence_info = {}
        
        # 메모리 관리
        self.memory_initialized = False
        
    @abstractmethod
    def solve(self, latent_updater, zs_init, 
              x, y, 
              loop_scheduler=None, **kwargs):
        """
        잠재 변수 수렴을 수행합니다.
        
        Args:
            latent_updater: delta_zs를 계산하는 updater (VanillaPCUpdater 또는 MetaPCUpdater)
            zs_init: 초기 잠재 상태들
            x: 입력 데이터
            y: 타겟 데이터
            loop_scheduler: 업데이트 순서 스케줄러 (optional)
            **kwargs: 추가 파라미터 (zhs for Meta-PC 등)
            
        Returns:
            zs_final: 수렴된 잠재 상태들
            info: 수렴 정보 (iterations, energy, timing 등)
        """
        pass
    
    @abstractmethod
    def initialize_memory(self, zs_shapes):
        """
        새로운 inference를 위한 메모리 초기화
        
        Args:
            zs_shapes: 각 레이어의 형태
        """
        pass
    
    @abstractmethod  
    def reset_memory(self):
        """
        Inference 종료 후 메모리 완전 초기화
        
        다음 샘플과의 간섭을 방지하기 위해 
        모든 history와 auxiliary 변수를 제거합니다.
        """
        pass
    
    def get_trajectory(self):
        """
        수렴 궤적 반환
        
        Returns:
            각 iteration에서의 zs 상태들
        """
        return self.trajectory
    
    def get_convergence_info(self):
        """
        수렴 정보 반환
        
        Returns:
            iterations, final_energy, timing 등의 정보
        """
        return self.convergence_info
    
    def compute_energy(self, zs, x, backbone_module_list):
        """
        현재 상태의 에너지 계산 (수렴 모니터링용)
        
        Args:
            zs: 현재 잠재 상태들
            x: 입력 데이터
            backbone_module_list: 백본 모듈들
            
        Returns:
            에너지 값
        """
        if not TORCH_AVAILABLE:
            return 0.0
            
        total_energy = 0.0
        
        # Input reconstruction energy
        if zs[0] is not None and x is not None:
            total_energy += torch.sum((zs[0] - x) ** 2).item()
        
        # Prediction energies
        for idx, backbone_module in enumerate(backbone_module_list):
            if idx < len(zs) - 1:
                with torch.no_grad():
                    pred = backbone_module(zs[idx])
                    total_energy += torch.sum((zs[idx + 1] - pred) ** 2).item()
        
        return total_energy
    
    def check_convergence(self, delta_zs, threshold=1e-4):
        """
        수렴 여부 확인
        
        Args:
            delta_zs: 현재 업데이트 delta 값들
            threshold: 수렴 임계값
            
        Returns:
            수렴 여부
        """
        if not TORCH_AVAILABLE:
            return True  # Mock convergence when PyTorch unavailable
            
        total_delta_norm = 0.0
        for delta_z in delta_zs:
            if delta_z is not None:
                total_delta_norm += torch.norm(delta_z).item()
        
        return total_delta_norm < threshold
    
    def get_solver_type(self):
        """솔버 타입 반환"""
        return self.__class__.__name__.lower().replace('solver', '')
    
    def supports_acceleration(self):
        """가속화 지원 여부"""
        return False  # 기본적으로 지원하지 않음
    
    def get_memory_requirements(self, batch_size, layer_shapes):
        """
        메모리 요구사항 계산
        
        Args:
            batch_size: 배치 크기
            layer_shapes: 각 레이어의 형태
            
        Returns:
            메모리 사용량 정보 (bytes)
        """
        return {
            'base_memory': 0,
            'history_memory': 0,
            'auxiliary_memory': 0
        }