# -*- coding: utf-8 -*-
"""
Meta-PC Updater - Meta-PC latent update

This module implements the latent-variable update for Meta-PC using frozen feedforward predictions.
Performs linearized updates using the gradient replacement logic from existing pcn_base.py.
"""

import torch
from typing import List, Optional


class MetaPCUpdater:
    """
    Meta-PC updater

    Performs a linearized gradient update of the latent states using frozen
    feedforward predictions. Core idea: snapshot the latent states once at
    the start of the inference loop, evaluate every backbone module on that
    snapshot, and update each hidden latent with a composite gradient
    (gradient w.r.t. the current latent + gradient w.r.t. the frozen copy).
    """

    def __init__(self, args):
        """
        Args:
            args: configuration object; ``args.eta`` (latent step size) is
                read if present, otherwise 0.2 is used.
        """
        self.eta = getattr(args, 'eta', 0.2)

    def prepare_frozen_predictions(self, zs_init: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Snapshot the initial latent states for the whole inference loop.

        This is Meta-PC's core: call it once at the start of the inference
        loop and pass the result as ``zhs`` to every ``compute_delta_zs``
        call, so the feedforward predictions are always evaluated at these
        fixed values.

        Important: only ``clone()`` is used, same as the parent branch (no
        ``.detach()``), so the copies stay attached to whatever autograd
        graph ``zs_init`` belongs to. ``compute_delta_zs`` detaches its own
        working copies, so this does not affect the update itself.

        Args:
            zs_init: latent states at the start of inference.

        Returns:
            A new list holding a clone of each tensor in ``zs_init``.
        """
        # frozen copy - no detach, same as parent
        return [z.clone() for z in zs_init]

    def compute_delta_zs(self, zs: List[torch.Tensor], x: torch.Tensor, y: torch.Tensor,
                        zhs: Optional[List[torch.Tensor]] = None, **kwargs) -> List[torch.Tensor]:
        """
        Compute the Meta-PC latent updates.

        Reuses the gradient-replacement logic of pcn_base.py:143-173. The
        energy minimized is

            sum((zs[0] - x)**2) / B
            + sum over i < L-1 of sum((zs[i+1] - f_i(zhs[i]))**2) / B
            + cross_entropy(f_{L-1}(zhs[L-1]), y)

        where ``f_i = backbone_module_list[i]``, ``L`` is the number of
        modules and ``B = x.shape[0]``. Predictions are taken at the frozen
        states ``zhs`` instead of the current ``zs``, which linearizes the
        update around the frozen point.

        Args:
            zs: current latent states; only ``zs[1:-1]`` receive updates.
            x: input batch; ``x.shape[0]`` is used as the batch size.
            y: targets passed to ``cross_entropy``.
            zhs: frozen states from ``prepare_frozen_predictions``.
            **kwargs: must contain ``backbone_module_list``, the callables
                f_0 .. f_{L-1}.

        Returns:
            One tensor per entry of ``zs``. For ``1 <= i < len(zs) - 1`` the
            entry is ``-eta * B * (dE/dzs[i] + dE/dzhs[i])``; the first and
            last entries are zeros.

        Raises:
            ValueError: if ``zhs`` or ``backbone_module_list`` is missing.

        NOTE(review): ``pc_loss.backward()`` also accumulates gradients into
        any trainable parameters of the backbone modules; callers presumably
        zero them before the weight update -- confirm.
        """
        if zhs is None:
            raise ValueError("Meta-PC requires frozen predictions (zhs)")

        backbone_module_list = kwargs.get('backbone_module_list')
        if backbone_module_list is None:
            raise ValueError("backbone_module_list required for Meta-PC update")

        batch_size = x.shape[0]
        last_idx = len(backbone_module_list) - 1

        with torch.enable_grad():
            # 1. Fresh leaf copies of the current latents, so their .grad is isolated
            #    from any graph the caller's tensors belong to.
            zs_state_t = [z.clone().detach().requires_grad_(True) for z in zs]

            # 2. Fresh leaf copies of the frozen latents; the gradient flowing back
            #    through the predictions lands on these.
            zhs_state_t = [z.clone().detach().requires_grad_(True) for z in zhs]

            # 3. Feedforward predictions evaluated at the frozen states.
            predictions = [
                module(zhs_state_t[idx]) for idx, module in enumerate(backbone_module_list)
            ]

            # 4. Energy: input clamp term, squared prediction errors for hidden
            #    layers, cross-entropy for the final (classification) module.
            pc_loss = torch.sum((zs_state_t[0] - x) ** 2) / batch_size
            for idx, pred_value in enumerate(predictions):
                if idx != last_idx:
                    pc_loss += torch.sum((zs_state_t[idx + 1] - pred_value) ** 2) / batch_size
                else:
                    pc_loss += torch.nn.functional.cross_entropy(pred_value, y)

            pc_loss.backward()

            # 5. Composite gradient (same combination as the parent branch).
            #    Multiplying by batch_size undoes the 1/B normalization of the loss.
            delta_zs_t = [torch.zeros_like(z) for z in zs_state_t]
            for idx in range(1, len(zs_state_t) - 1):
                grad_value = zs_state_t[idx].grad + zhs_state_t[idx].grad
                delta_zs_t[idx] = -self.eta * grad_value * batch_size

        return delta_zs_t
