from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional
import numpy as np
import swanlab

from toy_example.problems import ToyExampleSol


class BaseLogger:
    """
    Base logger for algorithm metrics.

    Each call to :meth:`log` does four things:

    1. Evaluates the objective and the clipped constraint violation at the
       current iterate.
    2. Optionally compares the iterate against a known exact solution.
    3. Writes a one-line summary through Python ``logging``.
    4. Forwards the same numbers to ``swanlab.log`` when a SwanLab run is
       active.
    """

    # A named logger is used instead of the root-level ``logging.info`` helper.
    # The module-level helpers implicitly call ``logging.basicConfig()`` when
    # the root logger has no handlers. That silently turns a later
    # ``basicConfig`` call made by the application into a no-op. Records
    # emitted here still propagate to the root logger's handlers.
    _logger = logging.getLogger(__name__)

    def __init__(self, objective_fn, constraint_fn, exact_sol: Optional[ToyExampleSol] = None):
        """
        Args:
            objective_fn: Callable ``(x, y)`` returning a scalar-like value
                (number, NumPy scalar, or single-element array).
            constraint_fn: Callable ``(x, y)`` returning an array-like.
                Only its positive entries count as violations; negative
                entries are clipped to 0 before taking the norm.
            exact_sol: Optional reference solution exposing ``x`` and ``y``.
                When given, relative errors against it are logged too.
        """
        self.objective_fn = objective_fn
        self.constraint_fn = constraint_fn
        self.exact_sol = exact_sol

    @staticmethod
    def _as_float(value: Any) -> float:
        """Convert a scalar-like value to a plain Python ``float``.

        Accepts Python numbers, NumPy scalars, and 0-d or single-element
        arrays. Single-element arrays of rank >= 1 cannot be used with a
        ``:.2e`` format spec directly, so they are unwrapped via ``.item()``.

        Raises:
            ValueError: If ``value`` holds more than one element.
        """
        return float(np.asarray(value).item())

    def _constraint_rel_err(self, state: Any) -> float:
        """Return ``||max(g(x, y), 0)|| / (1 + ||x|| + ||y||)`` at ``state``."""
        violation = np.clip(self.constraint_fn(state.x, state.y), 0.0, None)
        scale = 1.0 + np.linalg.norm(state.x) + np.linalg.norm(state.y)
        return float(np.linalg.norm(violation) / scale)

    def _exact_sol_errors(self, state: Any) -> Dict[str, float]:
        """Return relative errors of ``state`` w.r.t. ``self.exact_sol``.

        The keys are, in order: aggregate error, x error, y error. Each error
        is ``||diff|| / (1 + ||exact||)``. The aggregate sums both numerators
        and both exact norms.
        """
        exact_x = np.asarray(self.exact_sol.x)
        exact_y = np.asarray(self.exact_sol.y)
        diff_x = np.linalg.norm(np.asarray(state.x) - exact_x)
        diff_y = np.linalg.norm(np.asarray(state.y) - exact_y)
        norm_x = np.linalg.norm(exact_x)
        norm_y = np.linalg.norm(exact_y)
        return {
            "metrics/rel_err": float((diff_x + diff_y) / (1.0 + norm_x + norm_y)),
            "metrics/x_rel_err": float(diff_x / (1.0 + norm_x)),
            "metrics/y_rel_err": float(diff_y / (1.0 + norm_y)),
        }

    def log(self, step: int, state: Any, extra_metrics: Optional[Dict[str, float]] = None, extra_log_items: Optional[List[str]] = None) -> None:
        """
        Log metrics for the current step.

        Args:
            step: Iteration counter, printed zero-padded to 4 digits.
            state: Object exposing the current iterate as ``state.x`` and
                ``state.y``.
            extra_metrics: Additional ``{key: value}`` pairs merged into the
                SwanLab payload. On key collisions they override the base
                metrics, and are in turn overridden by exact-solution metrics.
            extra_log_items: Pre-formatted strings inserted into the text line
                right after the ``[Step ...]`` tag. They usually mirror
                ``extra_metrics``.
        """
        # Evaluate the objective before the constraint so the call order of
        # the user-supplied callables is unchanged.
        obj_val = self._as_float(self.objective_fn(state.x, state.y))
        con_rel_err = self._constraint_rel_err(state)

        log_dict: Dict[str, float] = {
            "metrics/objective": obj_val,
            "metrics/con_rel_err": con_rel_err,
        }
        log_items = [
            f"[Step {step:04d}]",
            *(extra_log_items or []),
            f"obj={obj_val:.2e}",
            f"con_rel_err={con_rel_err:.2e}",
        ]

        if extra_metrics:
            log_dict.update(extra_metrics)

        if self.exact_sol is not None:
            exact_errs = self._exact_sol_errors(state)
            log_dict.update(exact_errs)
            # The text label is the metric key without its "metrics/" prefix.
            log_items.extend(
                f"{key.rsplit('/', 1)[-1]}={value:.2e}"
                for key, value in exact_errs.items()
            )

        self._logger.info(", ".join(log_items))

        # Only forward to SwanLab when a run has been initialised.
        if swanlab.get_run():
            swanlab.log(log_dict)
