# -*- coding: utf-8 -*-
"""
Anderson Type-I Solver - Anderson 가속 솔버

이 모듈은 Anderson Type-I 가속화를 통해 수렴을 개선합니다.
기존 models/solvers/anderson_type1_solver.py를 기반으로 
new_models 아키텍처에 맞게 재구현했습니다.
"""

import torch
import time
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from .base_solver import BaseSolver


class AndersonSolver(BaseSolver):
    """
    Anderson Type-I 가속 솔버
    
    이전 m개의 iteration 정보를 사용하여 더 나은 수렴 방향을 찾습니다.
    메모리 사용량이 크지만 수렴 속도가 빠릅니다.
    """
    
    def __init__(self, args):
        """Configure the solver from ``args`` and start with empty Anderson state.

        Every option is read with ``getattr``, so an ``args`` object that lacks
        an attribute falls back to the default written next to it.
        """
        super().__init__(args)

        # Anderson hyper-parameters.
        self.m = getattr(args, 'anderson_m', 6)  # max entries kept per layer history
        self.lam = getattr(args, 'anderson_lam', 1e-4)  # added to the normal-equation diagonal
        self.tau = getattr(args, 'anderson_tau', 1.0)  # weight on the mixed residual

        # Iteration index from which solve() switches to Anderson steps.
        self.min_anderson_iter = getattr(args, 'min_anderson_iter', 2)

        # Stopping tolerance and optional per-iteration state recording.
        self.convergence_threshold = getattr(args, 'convergence_threshold', 1e-4)
        self.track_trajectory = getattr(args, 'track_trajectory', False)

        # Residual history, state history and mixing coefficients (per layer);
        # the histories are allocated later by initialize_memory().
        self.G_history = self.Z_history = self.alpha_coeffs = None
        
    def supports_acceleration(self) -> bool:
        return True
        
    def solve(self, latent_updater, zs_init: List[torch.Tensor], 
              x: torch.Tensor, y: torch.Tensor, 
              loop_scheduler=None, **kwargs) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        Anderson 가속으로 잠재 변수를 수렴시킵니다.
        """
        # 메모리 초기화
        batch_size = x.shape[0]
        self.initialize_memory([z.shape for z in zs_init])
        
        # 초기화
        zs = [z.clone() for z in zs_init]
        self.trajectory = []
        convergence_history = []
        anderson_active = False
        
        # 필요한 추가 파라미터 추출
        backbone_module_list = kwargs.get('backbone_module_list')
        zhs = kwargs.get('zhs', None)
        
        # 타이밍 시작
        start_time = time.time()
        
        # 반복 수렴
        for t in range(self.T):
            # 현재 상태 저장
            if self.track_trajectory:
                self.trajectory.append([z.clone().detach() for z in zs])
            
            # Delta 계산 (residual)
            delta_zs = latent_updater.compute_delta_zs(
                zs, x, y, 
                zhs=zhs,
                backbone_module_list=backbone_module_list
            )
            
            # Anderson 가속 적용 여부 결정
            if t >= self.min_anderson_iter and not anderson_active:
                anderson_active = True
                
            if anderson_active:
                # Anderson 가속 적용 (Loop Scheduler 고려)
                if loop_scheduler is not None:
                    zs_new = self._anderson_step_with_scheduler(
                        zs, delta_zs, t, loop_scheduler, latent_updater, x, y, 
                        zhs=zhs, backbone_module_list=backbone_module_list
                    )
                else:
                    zs_new = self._anderson_step(zs, delta_zs, t)
            else:
                # 표준 업데이트
                zs_new = self._vanilla_step(zs, delta_zs, loop_scheduler)
            
            # 수렴 체크
            delta_norm = sum(torch.norm(d).item() for d in delta_zs if d is not None)
            convergence_history.append(delta_norm)
            
            zs = zs_new
            
            if self.check_convergence(delta_zs, self.convergence_threshold):
                break
        
        # 타이밍 종료
        end_time = time.time()
        
        # 최종 에너지 계산
        final_energy = None
        if backbone_module_list is not None:
            final_energy = self.compute_energy(zs, x, backbone_module_list)
        
        # 수렴 정보 저장
        self.convergence_info = {
            'iterations': t + 1,
            'converged': t < self.T - 1,
            'final_delta_norm': convergence_history[-1] if convergence_history else 0.0,
            'final_energy': final_energy,
            'time_elapsed': end_time - start_time,
            'convergence_history': convergence_history,
            'anderson_activated_at': self.min_anderson_iter if anderson_active else None
        }
        
        return zs, self.convergence_info
    
    def _vanilla_step(self, zs: List[torch.Tensor], delta_zs: List[torch.Tensor],
                     loop_scheduler=None) -> List[torch.Tensor]:
        """표준 업데이트 스텝"""
        zs_new = [z.clone() for z in zs]
        
        if loop_scheduler is not None:
            n_layers = len(zs)
            schedule = loop_scheduler.get_update_schedule(n_layers)
            
            for layer_group in schedule:
                for layer_idx in layer_group:
                    if 0 < layer_idx < len(zs) - 1:  # hidden layers만
                        zs_new[layer_idx] = (zs[layer_idx] - self.eta * delta_zs[layer_idx]).detach()
        else:
            # 기본: Jacobi 업데이트
            for idx in range(1, len(zs) - 1):
                zs_new[idx] = (zs[idx] - self.eta * delta_zs[idx]).detach()
        
        return zs_new
    
    def _anderson_step(self, zs: List[torch.Tensor], delta_zs: List[torch.Tensor],
                      iteration: int) -> List[torch.Tensor]:
        """Anderson 가속 스텝"""
        zs_new = []
        
        for layer_idx in range(len(zs)):
            if layer_idx == 0 or layer_idx == len(zs) - 1:
                # Input/output layer는 그대로
                zs_new.append(zs[layer_idx])
                continue
                
            # Hidden layer에 Anderson 가속 적용
            z_current = zs[layer_idx]
            residual = -self.eta * delta_zs[layer_idx]
            
            # History 업데이트
            if self.G_history[layer_idx] is None:
                # 첫 번째 Anderson iteration
                self.G_history[layer_idx] = [residual.clone()]
                self.Z_history[layer_idx] = [z_current.clone()]
                z_new = z_current + residual
            else:
                # Anderson extrapolation
                self.G_history[layer_idx].append(residual.clone())
                self.Z_history[layer_idx].append(z_current.clone())
                
                # Memory limit 적용
                if len(self.G_history[layer_idx]) > self.m:
                    self.G_history[layer_idx] = self.G_history[layer_idx][-self.m:]
                    self.Z_history[layer_idx] = self.Z_history[layer_idx][-self.m:]
                
                z_new = self._compute_anderson_extrapolation(layer_idx, z_current, residual)
                
            zs_new.append(z_new.detach())
        
        return zs_new
    
    def _anderson_step_with_scheduler(self, zs: List[torch.Tensor], delta_zs: List[torch.Tensor],
                                    iteration: int, loop_scheduler, latent_updater, x: torch.Tensor, 
                                    y: torch.Tensor, **kwargs) -> List[torch.Tensor]:
        """
        Loop Scheduler와 함께 사용하는 Anderson 가속 스텝
        
        Block Sweep의 핵심: 각 layer group별로 순차적 업데이트 후 delta 재계산
        """
        zs_new = [z.clone() for z in zs]
        n_layers = len(zs)
        schedule = loop_scheduler.get_update_schedule(n_layers)
        
        # 필요한 파라미터 추출
        zhs = kwargs.get('zhs', None)
        backbone_module_list = kwargs.get('backbone_module_list')
        current_delta_zs = [d.clone() for d in delta_zs]
        
        # Block Sweep: 각 그룹별로 순차적 Anderson 가속 적용
        for group_idx, layer_group in enumerate(schedule):
            # 현재 그룹에 해당하는 layers에 Anderson 가속 적용
            for layer_idx in layer_group:
                if layer_idx == 0 or layer_idx == len(zs) - 1:
                    # Input/output layer는 그대로
                    continue
                    
                # Hidden layer에 Anderson 가속 적용
                z_current = zs_new[layer_idx]
                residual = -self.eta * current_delta_zs[layer_idx]
                
                # History 업데이트 및 Anderson extrapolation
                if self.G_history[layer_idx] is None:
                    # 첫 번째 Anderson iteration
                    self.G_history[layer_idx] = [residual.clone()]
                    self.Z_history[layer_idx] = [z_current.clone()]
                    z_new = z_current + residual
                else:
                    # Anderson extrapolation
                    self.G_history[layer_idx].append(residual.clone())
                    self.Z_history[layer_idx].append(z_current.clone())
                    
                    # Memory limit 적용
                    if len(self.G_history[layer_idx]) > self.m:
                        self.G_history[layer_idx] = self.G_history[layer_idx][-self.m:]
                        self.Z_history[layer_idx] = self.Z_history[layer_idx][-self.m:]
                    
                    z_new = self._compute_anderson_extrapolation(layer_idx, z_current, residual)
                
                zs_new[layer_idx] = z_new.detach()
            
            # 다음 그룹을 위해 delta 재계산 (Block Sweep의 핵심)
            if group_idx < len(schedule) - 1:  # 마지막 그룹이 아니면
                current_delta_zs = latent_updater.compute_delta_zs(
                    zs_new, x, y,
                    zhs=zhs,
                    backbone_module_list=backbone_module_list
                )
        
        return zs_new
    
    def _compute_anderson_extrapolation(self, layer_idx: int, 
                                       z_current: torch.Tensor, 
                                       residual_current: torch.Tensor) -> torch.Tensor:
        """Anderson extrapolation 계산"""
        G_hist = self.G_history[layer_idx]
        Z_hist = self.Z_history[layer_idx]
        
        if len(G_hist) < 2:
            return z_current + residual_current
        
        # Build difference matrices
        m_k = len(G_hist) - 1
        batch_size = z_current.shape[0]
        z_dim = z_current.shape[1:]
        
        # Flatten for easier computation
        G_flat = [g.view(batch_size, -1) for g in G_hist]
        Z_flat = [z.view(batch_size, -1) for z in Z_hist]
        
        # Compute G differences: ΔG_k = G_k - G_{k-1}
        DG_matrix = []
        for i in range(1, len(G_flat)):
            dg = G_flat[i] - G_flat[i-1]
            DG_matrix.append(dg)
        
        if not DG_matrix:
            return z_current + residual_current
        
        # Stack to form matrix [batch_size, feature_dim, m_k]
        DG_matrix = torch.stack(DG_matrix, dim=-1)  # [B, D, m_k]
        
        try:
            # Solve least squares problem: min ||DG α + G_k||^2
            G_current_flat = G_flat[-1].unsqueeze(-1)  # [B, D, 1]
            
            # Solve for each batch element
            alphas = []
            for b in range(batch_size):
                DG_b = DG_matrix[b]  # [D, m_k]
                G_b = G_current_flat[b]  # [D, 1]
                
                # Normal equation: (DG^T DG + λI) α = -DG^T G
                A = DG_b.t() @ DG_b + self.lam * torch.eye(m_k, device=DG_b.device)
                b_vec = -DG_b.t() @ G_b.squeeze(-1)
                
                # Solve
                alpha_b = torch.linalg.solve(A, b_vec)
                alphas.append(alpha_b)
            
            alpha = torch.stack(alphas, dim=0)  # [B, m_k]
            
            # Compute Anderson extrapolation
            z_anderson_flat = Z_flat[-1].clone()  # Current state
            g_anderson_flat = G_flat[-1].clone()  # Current residual
            
            # Add weighted differences
            for i in range(m_k):
                weight = alpha[:, i:i+1]  # [B, 1]
                z_anderson_flat += weight * (Z_flat[i+1] - Z_flat[i])
                g_anderson_flat += weight * (G_flat[i+1] - G_flat[i])
            
            # Apply damping
            z_result_flat = z_anderson_flat + self.tau * g_anderson_flat
            
            # Reshape back
            z_result = z_result_flat.view(z_current.shape)
            
            return z_result
            
        except Exception:
            # Fallback to vanilla update if Anderson fails
            return z_current + residual_current
    
    def initialize_memory(self, zs_shapes: List[tuple]):
        """Anderson 메모리 초기화"""
        n_layers = len(zs_shapes)
        
        self.G_history = [None] * n_layers  # Residual history
        self.Z_history = [None] * n_layers  # State history
        
        # Input/output layer는 메모리 불필요
        for i in [0, n_layers - 1]:
            self.G_history[i] = None
            self.Z_history[i] = None
            
        # Hidden layer만 초기화
        for i in range(1, n_layers - 1):
            self.G_history[i] = []
            self.Z_history[i] = []
        
        self.memory_initialized = True
    
    def reset_memory(self):
        """Anderson 메모리 완전 초기화"""
        self.G_history = None
        self.Z_history = None
        self.alpha_coeffs = None
        self.trajectory = []
        self.convergence_info = {}
        self.memory_initialized = False
    
    def get_memory_requirements(self, batch_size: int, 
                              layer_shapes: List[tuple]) -> Dict[str, int]:
        """Anderson 메모리 요구사항 계산"""
        base_memory = 0
        history_memory = 0
        
        # Hidden layer에 대한 메모리 계산
        for i, shape in enumerate(layer_shapes):
            if i == 0 or i == len(layer_shapes) - 1:
                continue  # Skip input/output
                
            layer_size = batch_size * torch.numel(torch.zeros(shape[1:]))  # Exclude batch dim
            
            # G_history와 Z_history 각각 m개씩
            history_memory += layer_size * self.m * 2 * 4  # float32
        
        trajectory_memory = 0
        if self.track_trajectory:
            total_params = sum(batch_size * torch.numel(torch.zeros(shape[1:])) for shape in layer_shapes)
            trajectory_memory = total_params * self.T * 4  # float32
        
        return {
            'base_memory': base_memory,
            'history_memory': history_memory,
            'auxiliary_memory': 0,
            'trajectory_memory': trajectory_memory
        }