# -*- coding: utf-8 -*-
"""
Broyden Solver - Broyden 준-뉴턴 방법 솔버

이 모듈은 Broyden의 준-뉴턴 방법을 사용하여 수렴을 가속화합니다.
기존 models/solvers/broyden_uv_solver.py를 기반으로
UV 분해를 통한 메모리 효율적인 구현을 제공합니다.
"""

import math
import time
import warnings
from typing import List, Dict, Any, Optional, Tuple

import torch

from .base_solver import BaseSolver


class BroydenSolver(BaseSolver):
    """
    Broyden quasi-Newton solver.

    For every hidden layer it keeps a limited-memory, low-rank approximation of
    the inverse Jacobian of the step map ``f(z) = -eta * delta(z)``. It then
    moves the latents along the resulting quasi-Newton direction.

    The inverse Jacobian is represented as ``H = -I + U @ V``, where
    ``U: [B, D, memory_size]``, ``V: [B, memory_size, D]`` and ``D`` is the
    flattened per-sample size. While ``U``/``V`` are zero, the step ``-H f``
    reduces to the vanilla step ``f``. Only the two thin factors are stored,
    never a dense ``D x D`` matrix.
    """

    def __init__(self, args):
        super().__init__(args)
        # Number of rank-1 secant updates kept per layer (ring buffer slots).
        self.memory_size = getattr(args, 'broyden_memory_size', 50)
        # Sign-preserving shift added to the Broyden denominator s^T H y.
        self.regularization = getattr(args, 'broyden_reg', 1e-6)
        self.convergence_threshold = getattr(args, 'convergence_threshold', 1e-4)
        self.track_trajectory = getattr(args, 'track_trajectory', False)

        # Per-layer Broyden state. initialize_memory() creates the lists; the
        # tensors themselves are allocated lazily in _compute_broyden_direction().
        self.U_matrices = None  # per layer: [B, D, memory_size] tensor or None
        self.V_matrices = None  # per layer: [B, memory_size, D] tensor or None
        self.s_vectors = None   # per layer: next ring-buffer write position

    def supports_acceleration(self) -> bool:
        """Report that this solver applies quasi-Newton acceleration."""
        return True

    @staticmethod
    def _residual_norm(delta_zs: List[Optional[torch.Tensor]]) -> float:
        """Sum of per-layer L2 norms of ``delta_zs``, ignoring ``None`` entries."""
        return sum(torch.norm(d).item() for d in delta_zs if d is not None)

    @staticmethod
    def _clone_all(tensors: List[Optional[torch.Tensor]]) -> List[Optional[torch.Tensor]]:
        """Clone every tensor in ``tensors``, passing ``None`` entries through."""
        return [t.clone() if t is not None else None for t in tensors]

    def solve(self, latent_updater, zs_init: List[torch.Tensor],
              x: torch.Tensor, y: torch.Tensor,
              loop_scheduler=None, **kwargs) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        Drive the latent variables towards a fixed point using Broyden steps.

        Iteration 0 is always a vanilla step, because a secant pair needs two
        points. Every later iteration takes a Broyden step. With a
        ``loop_scheduler``, layer groups are swept in its order and the
        residual is recomputed between groups.

        Args:
            latent_updater: object exposing
                ``compute_delta_zs(zs, x, y, zhs=..., backbone_module_list=...)``.
            zs_init: initial latent tensors, one per layer. The first and last
                entries (input/output) are never modified.
            x, y: forwarded unchanged to ``latent_updater``.
            loop_scheduler: optional object exposing ``get_update_schedule(n_layers)``.
            **kwargs: ``backbone_module_list`` and ``zhs`` are forwarded to the
                updater. ``backbone_module_list`` also enables the final
                ``compute_energy`` call.

        Returns:
            ``(zs, info)``. ``info`` is also stored in ``self.convergence_info``
            and contains:
            - ``iterations``: number of steps taken, including the vanilla step.
            - ``converged``: whether ``check_convergence`` fired.
            - ``convergence_history``: residual norm at the start of each step.
            - ``final_delta_norm``, ``final_energy``, ``time_elapsed``, ``broyden_updates``.
        """
        self.initialize_memory([z.shape for z in zs_init])

        zs = [z.clone() for z in zs_init]
        zs_prev = [z.clone() for z in zs_init]
        self.trajectory = []
        convergence_history = []
        broyden_updates = 0
        converged = False
        iterations = 1  # the initial vanilla step counts as one iteration

        backbone_module_list = kwargs.get('backbone_module_list')
        zhs = kwargs.get('zhs', None)

        start_time = time.time()

        # Iteration 0: vanilla step. It also provides the first secant point.
        delta_zs_prev = latent_updater.compute_delta_zs(
            zs, x, y,
            zhs=zhs,
            backbone_module_list=backbone_module_list
        )
        convergence_history.append(self._residual_norm(delta_zs_prev))
        zs = self._vanilla_step(zs, delta_zs_prev, loop_scheduler)

        for t in range(1, self.T):
            iterations = t + 1
            if self.track_trajectory:
                self.trajectory.append([z.clone().detach() for z in zs])

            # Residual at the current point.
            delta_zs = latent_updater.compute_delta_zs(
                zs, x, y,
                zhs=zhs,
                backbone_module_list=backbone_module_list
            )

            if loop_scheduler is not None:
                zs_new = self._broyden_step_with_scheduler(
                    zs, zs_prev, delta_zs, delta_zs_prev, t, loop_scheduler,
                    latent_updater, x, y, zhs=zhs, backbone_module_list=backbone_module_list
                )
            else:
                zs_new = self._broyden_step(zs, zs_prev, delta_zs, delta_zs_prev, t)
            broyden_updates += 1

            convergence_history.append(self._residual_norm(delta_zs))

            # Shift the secant window for the next iteration.
            zs_prev = [z.clone() for z in zs]
            delta_zs_prev = self._clone_all(delta_zs)
            zs = zs_new

            # check_convergence is inherited (not defined in this class).
            if self.check_convergence(delta_zs, self.convergence_threshold):
                converged = True
                break

        end_time = time.time()

        final_energy = None
        if backbone_module_list is not None:
            final_energy = self.compute_energy(zs, x, backbone_module_list)

        self.convergence_info = {
            'iterations': iterations,
            'converged': converged,
            'final_delta_norm': convergence_history[-1] if convergence_history else 0.0,
            'final_energy': final_energy,
            'time_elapsed': end_time - start_time,
            'convergence_history': convergence_history,
            'broyden_updates': broyden_updates
        }

        return zs, self.convergence_info

    def _vanilla_step(self, zs: List[torch.Tensor], delta_zs: List[torch.Tensor],
                      loop_scheduler=None) -> List[torch.Tensor]:
        """
        Apply the plain step ``z <- z - eta * delta`` to hidden layers.

        With a scheduler, only layers listed in its schedule are updated. Every
        update reads the original ``zs``, so there is no recomputation within a
        sweep. Without a scheduler, all hidden layers are updated at once
        (Jacobi style). The first and last layers are never touched.
        """
        zs_new = [z.clone() for z in zs]
        last = len(zs) - 1

        if loop_scheduler is not None:
            schedule = loop_scheduler.get_update_schedule(len(zs))
            layer_ids = [idx for group in schedule for idx in group]
        else:
            layer_ids = range(1, last)

        for idx in layer_ids:
            if 0 < idx < last:  # hidden layers only
                zs_new[idx] = (zs[idx] - self.eta * delta_zs[idx]).detach()

        return zs_new

    def _broyden_step(self, zs_current: List[torch.Tensor],
                      zs_prev: List[torch.Tensor],
                      delta_zs_current: List[torch.Tensor],
                      delta_zs_prev: List[torch.Tensor],
                      iteration: int) -> List[torch.Tensor]:
        """
        Take one Broyden step on every hidden layer, all at the same time.

        ``iteration`` is accepted for interface symmetry and is unused.
        The first and last layers are returned unchanged (same objects).
        """
        n_layers = len(zs_current)
        zs_new = []

        for layer_idx, z_curr in enumerate(zs_current):
            if not 0 < layer_idx < n_layers - 1:
                zs_new.append(z_curr)
                continue

            direction = self._compute_broyden_direction(
                layer_idx, z_curr, zs_prev[layer_idx],
                -self.eta * delta_zs_current[layer_idx],  # f at the current point
                -self.eta * delta_zs_prev[layer_idx],     # f at the previous point
            )
            zs_new.append((z_curr + direction).detach())

        return zs_new

    def _broyden_step_with_scheduler(self, zs_current: List[torch.Tensor],
                                     zs_prev: List[torch.Tensor],
                                     delta_zs_current: List[torch.Tensor],
                                     delta_zs_prev: List[torch.Tensor],
                                     iteration: int, loop_scheduler, latent_updater,
                                     x: torch.Tensor, y: torch.Tensor, **kwargs) -> List[torch.Tensor]:
        """
        Take a Broyden step group by group, following ``loop_scheduler`` (block sweep).

        After each group except the last, the residual is recomputed at the
        partially updated latents. Later groups therefore see the effect of
        earlier ones. ``kwargs`` may carry ``zhs`` and ``backbone_module_list``,
        which are forwarded to ``latent_updater``. ``iteration`` is unused.
        """
        zs_new = [z.clone() for z in zs_current]
        n_layers = len(zs_current)
        schedule = loop_scheduler.get_update_schedule(n_layers)

        zhs = kwargs.get('zhs', None)
        backbone_module_list = kwargs.get('backbone_module_list')
        # Both residual lists are only read here, so no defensive copies are needed.
        current_delta_zs = delta_zs_current

        for group_idx, layer_group in enumerate(schedule):
            for layer_idx in layer_group:
                if not 0 < layer_idx < n_layers - 1:
                    continue  # input/output layers stay fixed

                z_curr = zs_new[layer_idx]
                direction = self._compute_broyden_direction(
                    layer_idx, z_curr, zs_prev[layer_idx],
                    -self.eta * current_delta_zs[layer_idx],
                    -self.eta * delta_zs_prev[layer_idx],
                )
                zs_new[layer_idx] = (z_curr + direction).detach()

            # Recompute the residual for the next group (core of block sweep).
            if group_idx < len(schedule) - 1:
                current_delta_zs = latent_updater.compute_delta_zs(
                    zs_new, x, y,
                    zhs=zhs,
                    backbone_module_list=backbone_module_list
                )

        return zs_new

    def _compute_broyden_direction(self, layer_idx: int,
                                   z_curr: torch.Tensor, z_prev: torch.Tensor,
                                   f_curr: torch.Tensor, f_prev: torch.Tensor) -> torch.Tensor:
        """
        Compute the quasi-Newton direction for one hidden layer and update its memory.

        The inverse Jacobian of ``f`` is approximated as ``H = -I + U @ V``.
        The secant pair is ``s = z_curr - z_prev`` and ``y = f_curr - f_prev``.
        This applies the "good" Broyden inverse update

            H <- H + (s - H y) (s^T H) / (s^T H y)

        and stores it as one column of ``U`` plus one row of ``V`` in a ring
        buffer of ``memory_size`` slots. It then returns the Newton-like step
        ``-H f_curr``. Samples with a vanishing step (``|s|^2 <= 1e-12``) or a
        non-finite update write zeros into the slot instead.

        If there is nothing to factorise, or on any numerical failure, it falls
        back to the vanilla step ``f_curr``. A warning is issued on exceptions.

        Returns:
            A tensor with the shape of ``z_curr``.
        """
        if self.memory_size <= 0 or z_curr.dim() == 0 or z_curr.numel() == 0:
            return f_curr

        batch_size = z_curr.shape[0]
        try:
            # Every caller detaches the result, so no autograd graph is needed.
            # Building one would also chain U/V across iterations.
            with torch.no_grad():
                # reshape (not view) so that non-contiguous inputs work too.
                s_flat = (z_curr - z_prev).reshape(batch_size, -1)
                y_flat = (f_curr - f_prev).reshape(batch_size, -1)
                f_flat = f_curr.reshape(batch_size, -1)

                U = self.U_matrices[layer_idx]  # [B, D, memory_size]
                V = self.V_matrices[layer_idx]  # [B, memory_size, D]
                if U is None or V is None:
                    z_dim = s_flat.shape[1]
                    U = s_flat.new_zeros(batch_size, z_dim, self.memory_size)
                    V = s_flat.new_zeros(batch_size, self.memory_size, z_dim)
                    self.U_matrices[layer_idx] = U
                    self.V_matrices[layer_idx] = V
                    self.s_vectors[layer_idx] = 0

                def apply_h(vec: torch.Tensor) -> torch.Tensor:
                    # H @ vec = -vec + U (V vec), batched over dim 0.
                    return -vec + torch.bmm(U, torch.bmm(V, vec.unsqueeze(-1))).squeeze(-1)

                def apply_h_t(vec: torch.Tensor) -> torch.Tensor:
                    # H^T @ vec = -vec + V^T (U^T vec), batched over dim 0.
                    return -vec + torch.bmm(
                        V.transpose(1, 2), torch.bmm(U.transpose(1, 2), vec.unsqueeze(-1))
                    ).squeeze(-1)

                h_y = apply_h(y_flat)      # H y
                h_t_s = apply_h_t(s_flat)  # (s^T H)^T
                denom = (h_t_s * y_flat).sum(dim=1, keepdim=True)  # s^T H y, [B, 1]
                # Shift away from zero while keeping the sign (s^T H y may be negative).
                ones = torch.ones_like(denom)
                denom = denom + self.regularization * torch.where(denom >= 0, ones, -ones)

                u_vec = (s_flat - h_y) / denom
                v_vec = h_t_s

                valid = (s_flat * s_flat).sum(dim=1, keepdim=True) > 1e-12
                valid = valid & torch.isfinite(u_vec).all(dim=1, keepdim=True)
                valid = valid & torch.isfinite(v_vec).all(dim=1, keepdim=True)
                zeros = torch.zeros_like(s_flat)

                slot = self.s_vectors[layer_idx] % self.memory_size
                U[:, :, slot] = torch.where(valid, u_vec, zeros)
                V[:, slot, :] = torch.where(valid, v_vec, zeros)
                self.s_vectors[layer_idx] = slot + 1

                # apply_h sees the in-place update above, i.e. the new H.
                direction = -apply_h(f_flat)
                if not bool(torch.isfinite(direction).all()):
                    return f_curr
                return direction.reshape(z_curr.shape)

        except Exception as err:  # acceleration is best-effort: never abort the solve
            warnings.warn(
                f"Broyden step failed on layer {layer_idx} ({err!r}); "
                "falling back to the vanilla step.",
                RuntimeWarning,
            )
            return f_curr

    def initialize_memory(self, zs_shapes: List[tuple]):
        """
        Reset the per-layer Broyden state for ``len(zs_shapes)`` layers.

        U/V tensors are allocated lazily on first use. The first and last
        layers are never updated, so their write counters are ``None``.
        """
        n_layers = len(zs_shapes)

        self.U_matrices = [None] * n_layers
        self.V_matrices = [None] * n_layers
        self.s_vectors = [0] * n_layers

        if n_layers:
            self.s_vectors[0] = None
            self.s_vectors[-1] = None

        self.memory_initialized = True

    def reset_memory(self):
        """Drop all Broyden state, the trajectory and the last convergence info."""
        self.U_matrices = None
        self.V_matrices = None
        self.s_vectors = None
        self.trajectory = []
        self.convergence_info = {}
        self.memory_initialized = False

    def get_memory_requirements(self, batch_size: int,
                                layer_shapes: List[tuple]) -> Dict[str, int]:
        """
        Estimate memory use in bytes, assuming float32 storage (4 bytes per element).

        ``layer_shapes`` entries include a leading batch dimension, which is
        replaced by ``batch_size``. ``history_memory`` covers U and V for hidden
        layers only. ``trajectory_memory`` is non-zero only when
        ``track_trajectory`` is set, and counts ``self.T`` snapshots of all layers.
        """
        bytes_per_elem = 4  # float32

        history_memory = 0
        for shape in layer_shapes[1:-1]:  # hidden layers only
            layer_size = batch_size * math.prod(shape[1:])
            # U: [B, D, memory_size] + V: [B, memory_size, D]
            history_memory += layer_size * self.memory_size * 2 * bytes_per_elem

        trajectory_memory = 0
        if self.track_trajectory:
            total_params = sum(batch_size * math.prod(shape[1:]) for shape in layer_shapes)
            trajectory_memory = total_params * self.T * bytes_per_elem

        return {
            'base_memory': 0,
            'history_memory': history_memory,
            'auxiliary_memory': 0,
            'trajectory_memory': trajectory_memory
        }