# -*- coding: utf-8 -*-
"""
Pred-Freeze Backbone Wrapper
Wrapper class implementing the parent branch's _update_params_pred_freeze logic.
"""

import torch
import torch.nn as nn
from typing import List, Optional



class PredFreezeBackbone(nn.Module):
    """
    Backbone wrapper supporting the pred-freeze parameter update.

    Wraps a sequence of backbone modules that are applied one after another
    and reproduces the parent branch's ``_update_params_standard`` and
    ``_update_params_pred_freeze`` loss logic.

    Notation (``L`` = number of backbone modules):
      * ``zs_t`` -- target latent states. ``zs_t[0]`` is compared with the
        input ``x`` and ``zs_t[i + 1]`` is the regression target of module
        ``i`` for every module except the last. Presumably
        ``len(zs_t) == L + 1`` (TODO confirm against the caller).
      * ``zhs`` -- frozen per-module inputs used by pred-freeze. Module ``i``
        is applied to ``zhs[i]``, so at least ``L`` entries are required.
    """

    def __init__(self, backbone_module_list):
        super().__init__()
        # NOTE(review): the modules' parameters are only registered with this
        # wrapper when an nn.Module container (e.g. nn.ModuleList) is passed;
        # a plain Python list would hide them from .parameters()/.to()/
        # state_dict(). Confirm the caller passes an nn.ModuleList.
        self.backbone_module_list = backbone_module_list

    def forward(self, x):
        """Apply every backbone module in order and return the final output."""
        for module in self.backbone_module_list:
            x = module(x)
        return x

    def _validate_zhs(self, zhs) -> None:
        """Raise ValueError unless ``zhs`` supplies one input per backbone module."""
        if zhs is None:
            raise ValueError("pred_freeze requires zhs (frozen predictions)")
        num_modules = len(self.backbone_module_list)
        if len(zhs) < num_modules:
            # Fail early with a clear message instead of an IndexError deep in
            # the loss loop (and, for the optimizer variant, after zero_grad()).
            raise ValueError(
                f"pred_freeze requires one zhs entry per backbone module "
                f"(got {len(zhs)}, need {num_modules})"
            )

    def _pred_freeze_loss(self, zs_t: List[torch.Tensor], x: torch.Tensor,
                          y: torch.Tensor,
                          zhs: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the pred-freeze PC loss shared by both pred-freeze entry points.

        The loss is the sum of:
          * ``sum((zs_t[0] - x) ** 2) / batch``,
          * for every module ``i`` except the last,
            ``sum((zs_t[i + 1] - module_i(zhs[i])) ** 2) / batch``,
          * ``cross_entropy(module_last(zhs[L - 1]), y)``.

        Modules are called in order on their frozen inputs ``zhs[i]`` (not on
        the current ``zs_t``), so gradients reach the module parameters through
        predictions made from the frozen inputs. The caller must enable grad.
        """
        num_modules = len(self.backbone_module_list)
        batch_size = x.shape[0]
        pc_loss = torch.sum((zs_t[0] - x) ** 2) / batch_size
        for idx, backbone_module in enumerate(self.backbone_module_list):
            pred = backbone_module(zhs[idx])
            if idx != num_modules - 1:
                # Hidden layer: regress the target state onto the frozen prediction.
                pc_loss += torch.sum((zs_t[idx + 1] - pred) ** 2) / batch_size
            else:
                # Output layer: classification loss on the frozen prediction.
                pc_loss += torch.nn.functional.cross_entropy(pred, y)
        return pc_loss

    def update_params_standard(self, zs_t: List[torch.Tensor], x: torch.Tensor,
                               y: torch.Tensor) -> torch.Tensor:
        """
        Compute the standard PC loss (same as the parent branch's
        ``_update_params_standard``). No optimizer step is taken.

        Each non-final module ``i`` predicts ``zs_t[i + 1]`` from ``zs_t[i]``
        (per-sample summed squared error). The final module is applied to
        ``zs_t[-2]`` and scored with cross-entropy against ``y``. A
        reconstruction term ``sum((zs_t[0] - x) ** 2) / batch`` is added first.

        Returns:
            Scalar loss tensor attached to the autograd graph.
        """
        with torch.enable_grad():
            pc_loss = torch.sum((zs_t[0] - x)**2) / x.shape[0]
            for idx, backbone_module in enumerate(self.backbone_module_list):
                if idx != len(self.backbone_module_list) - 1:
                    pc_loss += torch.sum((zs_t[idx+1] - backbone_module(zs_t[idx]))**2) / x.shape[0]
                else:
                    pc_loss += torch.nn.functional.cross_entropy(backbone_module(zs_t[-2]), y)
        return pc_loss

    def update_params_pred_freeze(self, zs_t: List[torch.Tensor], x: torch.Tensor,
                                  y: torch.Tensor,
                                  zhs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute the pred-freeze PC loss (same as the parent branch's
        ``_update_params_pred_freeze``) without stepping an optimizer.

        Module ``i`` is applied to the frozen input ``zhs[i]``, and its output
        is compared with ``zs_t[i + 1]`` (or with ``y`` via cross-entropy for
        the last module).

        Raises:
            ValueError: if ``zhs`` is None or has fewer entries than there are
                backbone modules.

        Returns:
            Scalar loss tensor attached to the autograd graph.
        """
        self._validate_zhs(zhs)
        with torch.enable_grad():
            return self._pred_freeze_loss(zs_t, x, y, zhs)

    def update_params_pred_freeze_with_optimizer(self, zs_t: List[torch.Tensor],
                                                 x: torch.Tensor, y: torch.Tensor,
                                                 optimizer,
                                                 zhs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Perform one pred-freeze parameter update, as in the parent branch.

        Runs ``optimizer.zero_grad()``, computes the pred-freeze loss (see
        ``update_params_pred_freeze``), calls ``backward()`` and then
        ``optimizer.step()``.

        Raises:
            ValueError: if ``zhs`` is None or too short. This is checked before
                any gradient is touched.

        Returns:
            The loss tensor. Its graph has already been consumed by
            ``backward()``, so use ``.item()``/``.detach()`` rather than
            backpropagating through it again.
        """
        self._validate_zhs(zhs)

        optimizer.zero_grad()
        with torch.enable_grad():
            pc_loss = self._pred_freeze_loss(zs_t, x, y, zhs)
            pc_loss.backward()
        optimizer.step()

        return pc_loss

    def _update_params_pred_freeze(self, zs_final: List[torch.Tensor],
                                   x: torch.Tensor, y: torch.Tensor,
                                   optimizer,
                                   zhs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Provide the method name presumably expected by NewPCNModel.

        It has the same signature as the parent branch, including the
        optimizer, and delegates to ``update_params_pred_freeze_with_optimizer``.
        """
        return self.update_params_pred_freeze_with_optimizer(zs_final, x, y, optimizer, zhs)

    def _update_params_standard(self, zs_final: List[torch.Tensor],
                                x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Provide the standard-mode method name presumably expected by NewPCNModel.

        It delegates to ``update_params_standard``, which returns the loss and
        does not step an optimizer.
        """
        return self.update_params_standard(zs_final, x, y)

    def update_params_and_compute_loss(self, zs_final: List[torch.Tensor],
                                       x: torch.Tensor, y: torch.Tensor,
                                       zhs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute the PC loss, dispatching on whether frozen inputs are given.

        With ``zhs`` the pred-freeze loss is used; otherwise the standard loss
        is used. Despite the name, no optimizer step is performed here. The
        caller is responsible for ``backward()`` and ``step()``.
        """
        if zhs is not None:
            # Meta-PC + pred_freeze mode
            return self.update_params_pred_freeze(zs_final, x, y, zhs)
        # Standard mode
        return self.update_params_standard(zs_final, x, y)
