# -*- coding: utf-8 -*-
"""
Loop Scheduler - 업데이트 순서 관리자

이 모듈은 PCN의 잠재 변수 업데이트 순서를 관리합니다.
기존의 Jacobi, Gauss-Seidel, Block Sweep을 개별 구현하지 않고,
스케줄링 패턴으로 통합하여 하나의 업데이트 엔진으로 처리합니다.
"""


try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

from typing import List, Dict, Any
from enum import Enum


class UpdatePattern(Enum):
    """업데이트 패턴 정의"""
    JACOBI = "jacobi"
    GAUSS_SEIDEL = "gs"
    BLOCK_SWEEP = "block_sweep_gs"


class LoopScheduler:
    """
    잠재 변수 업데이트 순서를 관리하는 스케줄러

    핵심 아이디어:
    - Jacobi: [[1,2,...,n-2]] - 모든 레이어 동시 업데이트
    - GS: [[n-2],[n-3],...,[1]] - 레이어별 순차 업데이트
    - Block Sweep: [even_layers, odd_layers] - 짝수/홀수 번갈아 업데이트
    """

    def __init__(self, pattern: str = "jacobi"):
        """
        Args:
            pattern: 업데이트 패턴 ("jacobi", "gs", "block_sweep_gs")
        """
        self._pattern_enum = UpdatePattern(pattern)

    @property
    def pattern(self) -> str:
        """pattern 속성을 문자열로 반환 (사용자 친화적)"""
        return self._pattern_enum.value

    def get_update_schedule(self, n_layers: int) -> List[List[int]]:
        """
        주어진 레이어 수에 대한 업데이트 스케줄 생성

        Args:
            n_layers: 전체 레이어 수 (input + hidden + output)

        Returns:
            schedule: 업데이트 순서 [[layer_indices], [layer_indices], ...]
        """
        if n_layers < 3:
            raise ValueError(f"최소 3개 레이어 필요 (input, hidden, output), 현재: {n_layers}")

        hidden_layers = list(range(1, n_layers - 1))  # 1부터 n-2까지 (hidden layers만)

        if self._pattern_enum == UpdatePattern.JACOBI:
            return self._get_jacobi_schedule(hidden_layers)
        elif self._pattern_enum == UpdatePattern.GAUSS_SEIDEL:
            return self._get_gauss_seidel_schedule(hidden_layers)
        elif self._pattern_enum == UpdatePattern.BLOCK_SWEEP:
            return self._get_block_sweep_schedule(hidden_layers)
        else:
            raise ValueError(f"지원되지 않는 패턴: {self._pattern_enum}")

    def _get_jacobi_schedule(self, hidden_layers: List[int]) -> List[List[int]]:
        """Jacobi: 모든 hidden layer 동시 업데이트"""
        return [hidden_layers]  # 한 번에 모든 레이어

    def _get_gauss_seidel_schedule(self, hidden_layers: List[int]) -> List[List[int]]:
        """Gauss-Seidel: output에서 input 방향으로 순차 업데이트"""
        return [[layer] for layer in reversed(hidden_layers)]  # 역순으로 하나씩

    def _get_block_sweep_schedule(self, hidden_layers: List[int]) -> List[List[int]]:
        """
        Block Sweep: 짝수/홀수 레이어 번갈아 업데이트

        주의: 기존 pcn_base.py의 버그 수정
        - 잘못된 공식: len(zs_state_t) - idx % 2
        - 올바른 공식: (len(zs_state_t) - idx) % 2
        """
        if len(hidden_layers) == 0:
            return []

        n_total = len(hidden_layers) + 2  # input + hidden + output

        even_layers = []
        odd_layers = []

        for idx in hidden_layers:
            # 올바른 연산자 우선순위: 괄호 추가
            if (n_total - idx) % 2 == 0:
                even_layers.append(idx)
            else:
                odd_layers.append(idx)

        # 비어있지 않은 그룹들만 반환
        schedule = []
        if even_layers:
            schedule.append(even_layers)
        if odd_layers:
            schedule.append(odd_layers)

        return schedule

    def apply_updates(self, delta_zs, zs, schedule):
        """
        스케줄에 따라 업데이트 적용

        Args:
            delta_zs: 각 레이어별 delta 값들 (List of tensors or mock objects)
            zs: 현재 잠재 상태들 (List of tensors or mock objects)
            schedule: 업데이트 스케줄

        Returns:
            updated_zs: 업데이트된 잠재 상태들
        """
        if TORCH_AVAILABLE and hasattr(zs[0], 'clone'):
            # PyTorch tensors
            updated_zs = [z.clone() for z in zs]

            for layer_group in schedule:
                for layer_idx in layer_group:
                    if 0 <= layer_idx < len(delta_zs):
                        updated_zs[layer_idx] = (updated_zs[layer_idx] + delta_zs[layer_idx]).detach()
        else:
            # Mock objects or other types - just return copy
            updated_zs = list(zs)  # Simple copy for testing

        return updated_zs

    def validate_schedule(self, schedule: List[List[int]], n_layers: int) -> bool:
        """스케줄 검증"""
        if not schedule:
            return True  # 빈 스케줄도 유효

        # 모든 hidden layer가 정확히 한 번씩 업데이트되는지 확인
        expected_layers = set(range(1, n_layers - 1))
        scheduled_layers = set()

        for group in schedule:
            for layer_idx in group:
                if layer_idx in scheduled_layers:
                    return False  # 중복 업데이트
                scheduled_layers.add(layer_idx)

        return scheduled_layers == expected_layers

    def get_pattern_info(self) -> Dict[str, Any]:
        """현재 패턴 정보 반환"""
        info = {
            'pattern': self.pattern,
            'parallel': self._pattern_enum == UpdatePattern.JACOBI,
            'sequential': self._pattern_enum == UpdatePattern.GAUSS_SEIDEL,
            'block_based': self._pattern_enum == UpdatePattern.BLOCK_SWEEP
        }

        if self._pattern_enum == UpdatePattern.JACOBI:
            info['description'] = "모든 hidden layer 동시 업데이트 (병렬성 높음, 수렴성 낮음)"
        elif self._pattern_enum == UpdatePattern.GAUSS_SEIDEL:
            info['description'] = "레이어별 순차 업데이트 (병렬성 낮음, 수렴성 높음)"
        elif self._pattern_enum == UpdatePattern.BLOCK_SWEEP:
            info['description'] = "짝수/홀수 레이어 번갈아 업데이트 (절충안)"

        return info
