# -*- coding: utf-8 -*-
"""
Vanilla PC Updater - 표준 예측 코딩 업데이터
"""

import torch
from typing import List, Optional


class VanillaPCUpdater:
    """
    Vanilla PC (standard predictive coding) updater.

    Updates latent variables by gradient descent on the free energy of a
    standard predictive coding network (PCN).
    Loss: ∑(z_{l+1} - f_l(z_l))^2 + CrossEntropy(f_final(z_L), y)
    """

    def __init__(self, args):
        # Step size for latent updates; falls back to 0.2 when ``args`` has
        # no ``eta`` attribute.
        self.eta = getattr(args, 'eta', 0.2)

    def compute_delta_zs(self, zs: List[torch.Tensor], x: torch.Tensor, y: torch.Tensor,
                        zhs: Optional[List[torch.Tensor]] = None, **kwargs) -> List[torch.Tensor]:
        """
        Compute one Vanilla PC update step for the latent states.

        Follows the Jacobi update logic of the original pcn_base.py, but
        implements pure Vanilla PC only (the Meta-PC part is excluded).

        Gradients are taken with ``torch.autograd.grad`` with respect to the
        hidden latents only, so the ``.grad`` buffers of the backbone module
        parameters (and of ``x``) are left untouched.

        Args:
            zs: Current latent states. ``zs[0]`` is compared against ``x``;
                for every non-final module ``backbone_module_list[idx]``,
                ``zs[idx]`` is its input and ``zs[idx + 1]`` its target.
            x: Input batch; ``x.shape[0]`` is used as the batch size.
            y: Targets passed to ``cross_entropy`` together with the final
                module's output.
            zhs: Unused; accepted for interface compatibility.
            **kwargs: Must contain ``backbone_module_list``; other keys are
                ignored.

        Returns:
            One tensor per entry of ``zs``. Entries for hidden latents
            (indices 1 .. len(zs) - 2) hold ``-eta * batch_size * dLoss/dz``.
            A hidden latent that receives no gradient gets zeros. The first
            and last entries are always zeros.

        Raises:
            ValueError: If ``backbone_module_list`` is not supplied.
        """
        backbone_module_list = kwargs.get('backbone_module_list')
        if backbone_module_list is None:
            raise ValueError("backbone_module_list required for vanilla PC update")

        batch_size = x.shape[0]
        last_module_idx = len(backbone_module_list) - 1

        # Build a graph even if the caller is running under torch.no_grad().
        with torch.enable_grad():
            # Fresh leaf copies of the latents: gradients are taken w.r.t. these
            # without touching the caller's tensors or their autograd graphs.
            zs_state_t = [z.detach().clone().requires_grad_(True) for z in zs]

            # Input term (zs[0] vs x); it does not depend on any hidden latent.
            pc_loss = torch.sum((zs_state_t[0] - x) ** 2) / batch_size

            for idx, backbone_module in enumerate(backbone_module_list):
                if idx != last_module_idx:
                    # Hidden-layer prediction error: z_{idx+1} vs f_idx(z_idx).
                    pred = backbone_module(zs_state_t[idx])
                    pc_loss = pc_loss + torch.sum((zs_state_t[idx + 1] - pred) ** 2) / batch_size
                else:
                    # Classification error of the final module, applied to zs[-2].
                    # NOTE(review): this assumes len(zs) == len(backbone_module_list) + 1,
                    # so that zs[-2] is the last module's input — confirm with callers.
                    final_pred = backbone_module(zs_state_t[-2])
                    pc_loss = pc_loss + torch.nn.functional.cross_entropy(final_pred, y)

            delta_zs_t = [torch.zeros_like(z) for z in zs_state_t]

            # Only the hidden latents (excluding first and last) are updated.
            hidden_zs = zs_state_t[1:-1]
            if hidden_zs:  # autograd.grad rejects an empty input list
                grads = torch.autograd.grad(pc_loss, hidden_zs, allow_unused=True)
                for idx, grad in enumerate(grads, start=1):
                    if grad is not None:
                        # Multiplying by batch_size undoes the per-batch averaging
                        # in the loss terms (including cross_entropy's mean reduction).
                        delta_zs_t[idx] = -self.eta * grad * batch_size

        return delta_zs_t
