# -*- coding: utf-8 -*-
"""
Vanilla Solver - 표준 반복 솔버

이 모듈은 가속화 없는 표준 iterative update를 구현합니다.
가장 단순하지만 안정적인 수렴 방법입니다.
"""

import torch
import time
from typing import List, Dict, Any, Optional, Tuple
from .base_solver import BaseSolver


class VanillaSolver(BaseSolver):
    """
    표준 반복 솔버

    delta_zs를 직접 적용하여 zs를 업데이트합니다.
    메모리를 사용하지 않는 stateless 솔버입니다.
    """

    def __init__(self, args):
        super().__init__(args)
        self.convergence_threshold = getattr(args, 'convergence_threshold', 1e-4)
        self.track_trajectory = getattr(args, 'track_trajectory', False)

    def solve(self, latent_updater, zs_init: List[torch.Tensor],
              x: torch.Tensor, y: torch.Tensor,
              loop_scheduler=None, **kwargs) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
        """
        표준 반복으로 잠재 변수를 수렴시킵니다.

        Args:
            latent_updater: delta_zs를 계산하는 updater
            zs_init: 초기 잠재 상태들
            x: 입력 데이터
            y: 타겟 데이터
            loop_scheduler: 업데이트 순서 스케줄러 (optional)
            **kwargs: 추가 파라미터 (backbone_module_list, zhs for Meta-PC 등)

        Returns:
            zs_final: 수렴된 잠재 상태들
            info: 수렴 정보
        """
        # 초기화
        zs = [z.clone() for z in zs_init]
        self.trajectory = []
        convergence_history = []

        # 필요한 추가 파라미터 추출
        backbone_module_list = kwargs.get('backbone_module_list')
        zhs = kwargs.get('zhs', None)  # Meta-PC용 frozen predictions

        # 타이밍 시작
        start_time = time.time()

        # 반복 수렴
        for t in range(self.T):
            # 현재 상태 저장 (trajectory tracking)
            if self.track_trajectory:
                self.trajectory.append([z.clone().detach() for z in zs])

            # # Delta 계산
            # delta_zs = latent_updater.compute_delta_zs(
                # zs, x, y,
                # zhs=zhs,
                # backbone_module_list=backbone_module_list
            # )

            # 스케줄에 따른 업데이트 적용
            if loop_scheduler is not None:
                # 스케줄러가 있으면 패턴에 따라 순차적 업데이트
                n_layers = len(zs)
                schedule = loop_scheduler.get_update_schedule(n_layers)

                # Block Sweep: 각 그룹별로 순차적 업데이트
                for layer_group in schedule:
                    # 그룹 내 레이어들 업데이트
                    delta_zs = latent_updater.compute_delta_zs(
                        zs, x, y,
                        zhs=zhs,
                        backbone_module_list=backbone_module_list
                    )
                    # print(schedule)
                    # print(layer_group)
                    # import ipdb; ipdb.set_trace()

                    for layer_idx in layer_group:
                        if 0 < layer_idx < len(zs) - 1:  # hidden layers만
                            zs[layer_idx] = (zs[layer_idx] + delta_zs[layer_idx]).detach()

                    # # 다음 그룹을 위해 delta 재계산 (Block Sweep의 핵심)
                    # if layer_group != schedule[-1]:  # 마지막 그룹이 아니면
                        # delta_zs = latent_updater.compute_delta_zs(
                            # zs, x, y,
                            # zhs=zhs,
                            # backbone_module_list=backbone_module_list
                    #     )
            else:
                # 기본: 모든 hidden layer 동시 업데이트 (Jacobi)

                delta_zs = latent_updater.compute_delta_zs(
                    zs, x, y,
                    zhs=zhs,
                    backbone_module_list=backbone_module_list
                )
                for idx in range(1, len(zs) - 1):  # hidden layers만
                    zs[idx] = (zs[idx] + delta_zs[idx]).detach()

            # 수렴 체크
            delta_norm = sum(torch.norm(d).item() for d in delta_zs if d is not None)
            convergence_history.append(delta_norm)

            if self.check_convergence(delta_zs, self.convergence_threshold):
                break

        # 타이밍 종료
        end_time = time.time()

        # 최종 에너지 계산
        final_energy = None
        if backbone_module_list is not None:
            final_energy = self.compute_energy(zs, x, backbone_module_list)

        # 수렴 정보 저장
        self.convergence_info = {
            'iterations': t + 1,
            'converged': t < self.T - 1,
            'final_delta_norm': convergence_history[-1] if convergence_history else 0.0,
            'final_energy': final_energy,
            'time_elapsed': end_time - start_time,
            'convergence_history': convergence_history
        }

        return zs, self.convergence_info

    def initialize_memory(self, zs_shapes: List[tuple]):
        """
        Vanilla 솔버는 메모리가 필요 없음 (stateless)
        """
        self.memory_initialized = True

    def reset_memory(self):
        """
        Vanilla 솔버는 메모리가 없으므로 리셋할 것도 없음
        """
        self.trajectory = []
        self.convergence_info = {}
        self.memory_initialized = False

