# -*- coding: utf-8 -*-
"""
PCN Compatible Model - Implementation of exactly the same logic as the parent branch
"""

import torch
import torch.nn as nn
from typing import Dict, Any


class PCNCompatibleModel:
    """
    Predictive-coding style training wrapper around a layered backbone.

    The original author describes this class as reproducing the logic of the
    parent branch ``PCN_jacobi``. That branch is not visible from this file,
    so the documentation below describes only what this code does.

    Requirements read off the code:

    * ``backbone_model.backbone_module_list`` is iterated and passed to
      ``len()``; its entries are called one after another on the previous
      entry's output.
    * ``config.T`` is used as ``range(config.T + 1)`` and compared with the
      iteration index; ``config.eta`` scales the state updates.

    NOTE(review): the class does not inherit from ``nn.Module``, so
    ``forward`` has to be called explicitly -- confirm this is intended.
    """

    def __init__(self, backbone_model, config):
        """Store the backbone and the configuration; no other work is done."""
        self.backbone_model = backbone_model
        self.config = config

    def init_zs_ff(self, x, y=None):
        """
        Build the initial state list with a plain feedforward pass.

        Args:
            x: Input fed to the first backbone module.
            y: Unused; accepted so the signature matches the caller.

        Returns:
            ``[x, f_0(x), f_1(f_0(x)), ...]`` -- one entry more than there
            are backbone modules.
        """
        zs = [x]
        for backbone_module in self.backbone_model.backbone_module_list:
            zs.append(backbone_module(zs[-1]))
        return zs

    def predict_forward(self, x) -> Dict[str, Any]:
        """
        Run the backbone modules in sequence on ``x``.

        Returns:
            ``{'pred': output_of_last_module}`` (``x`` itself if the module
            list is empty).
        """
        r = x
        for backbone_module in self.backbone_model.backbone_module_list:
            r = backbone_module(r)
        return {'pred': r}

    def forward(self, x, y, optimizer, **kwargs) -> Dict[str, Any]:
        """
        Run ``config.T + 1`` relaxation iterations and one parameter update.

        Each iteration records the per-layer energies of the current states,
        then calls :meth:`forward_train` and adds the returned deltas to the
        states. :meth:`forward_train` steps the optimizer only on the last
        iteration (``t == config.T``).

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size.
            y: Targets passed to ``cross_entropy`` for the last module.
            optimizer: Object providing ``zero_grad()`` and ``step()``.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with:
              * ``'loss'`` / ``'pc_loss'``: loss tensor of the last iteration.
              * ``'pred'``: last state. The first and last states always
                receive a zero delta, so this stays numerically equal to the
                feedforward output.
              * ``'initial_pred'``: detached feedforward output.
              * ``'pc_loss_list'``: loss value (float) per iteration.
              * ``'pc_energy_metric_list'``: per-iteration list of per-layer
                energy floats.

        Raises:
            ValueError: If ``config.T`` is negative, i.e. no iteration would
                run and there would be no loss to return.
        """
        num_iterations = self.config.T + 1
        if num_iterations < 1:
            # Without this guard the return statement below would fail with
            # an UnboundLocalError on ``pc_loss``.
            raise ValueError(
                f"config.T must be >= 0, got {self.config.T}"
            )

        zs_t = self.init_zs_ff(x, y)

        # Store initial feedforward prediction for Acc_forward
        initial_pred = zs_t[-1].detach()

        pc_loss = None
        pc_loss_list = []
        pc_energy_metric_list = []

        for t in range(num_iterations):
            pc_energy_metric_list.append(
                self.calculate_pc_energy_metric(zs_t, x, y)
            )

            delta_zs_t, pc_loss = self.forward_train(zs_t, x, y, t, optimizer)
            zs_t = [z + delta_z for z, delta_z in zip(zs_t, delta_zs_t)]
            pc_loss_list.append(pc_loss.item())

        return {
            'loss': pc_loss,  # Main loss for training
            'pred': zs_t[-1],
            'initial_pred': initial_pred,  # Initial feedforward prediction for Acc_forward
            'pc_loss': pc_loss,
            'pc_loss_list': pc_loss_list,
            'pc_energy_metric_list': pc_energy_metric_list
        }

    def forward_train(self, zs_t, x, y, t, optimizer):
        """
        Perform one relaxation iteration on detached copies of the states.

        Gradients are cleared, the summed energy is back-propagated, and a
        gradient step of size ``config.eta`` is computed for every interior
        state. The optimizer is stepped only when ``t == config.T``.

        Args:
            zs_t: Current list of states (left unmodified).
            x: Input batch; ``x.shape[0]`` is used as the batch size.
            y: Targets for the cross-entropy term of the last module.
            t: Index of the current iteration.
            optimizer: Object providing ``zero_grad()`` and ``step()``.

        Returns:
            ``(delta_zs_t, pc_loss)``: one delta per state (zeros for the
            first and the last state) and the scalar loss tensor.
        """
        optimizer.zero_grad()
        batch_size = x.shape[0]

        with torch.enable_grad():
            # Fresh leaf tensors so ``.grad`` is populated per state and the
            # backward pass does not reach into whatever produced ``zs_t``.
            zs_state_t = [z.clone().detach().requires_grad_(True) for z in zs_t]

            # Accumulate left to right, in the same order the terms are
            # produced, so the floating-point result is reproducible.
            energy_terms = self._energy_terms(zs_state_t, x, y)
            pc_loss = energy_terms[0]
            for term in energy_terms[1:]:
                pc_loss = pc_loss + term

            pc_loss.backward()

            # Only interior states move; multiplying by the batch size undoes
            # the 1/batch_size factor applied inside the energy terms.
            delta_zs_t = [torch.zeros_like(z) for z in zs_state_t]
            for idx in range(1, len(zs_state_t) - 1):
                delta_zs_t[idx] = -self.config.eta * zs_state_t[idx].grad * batch_size

        # Parameter update only in the last iteration
        if t == self.config.T:
            optimizer.step()

        return delta_zs_t, pc_loss

    def calculate_pc_energy_metric(self, zs_t, x, y):
        """
        Return the per-layer energies of ``zs_t`` as plain floats.

        The values are only read through ``.item()``, so the computation runs
        under ``torch.no_grad()`` to avoid building an autograd graph on
        every iteration.

        Returns:
            List of floats: the input term followed by one term per backbone
            module (see :meth:`_energy_terms`).
        """
        with torch.no_grad():
            return [term.item() for term in self._energy_terms(zs_t, x, y)]

    def _energy_terms(self, zs, x, y):
        """
        Compute the per-layer energy terms for the states ``zs``.

        Term 0 is ``sum((zs[0] - x)**2) / batch``. Every module except the
        last contributes ``sum((zs[i+1] - f_i(zs[i]))**2) / batch``. The last
        module contributes ``cross_entropy(f_last(zs[-2]), y)``.

        Returns:
            List of scalar tensors, in layer order.
        """
        modules = self.backbone_model.backbone_module_list
        last_idx = len(modules) - 1
        batch_size = x.shape[0]

        terms = [torch.sum((zs[0] - x) ** 2) / batch_size]
        for idx, backbone_module in enumerate(modules):
            if idx != last_idx:
                prediction = backbone_module(zs[idx])
                terms.append(torch.sum((zs[idx + 1] - prediction) ** 2) / batch_size)
            else:
                # ``zs[-2]`` is the input state of the last module when
                # ``len(zs) == len(modules) + 1``.
                terms.append(
                    torch.nn.functional.cross_entropy(backbone_module(zs[-2]), y)
                )
        return terms

    def reset_solver_memory(self):
        """No-op kept for interface compatibility."""
        pass