import numpy as np


def log_single_suggestion_metrics(writer, trajectory, total_traj_cnt,
        areas_by_timestep, policy_coverage_by_timestep):
    """
    Log per-timestep suggestion metrics for a single trajectory.

    For every timestep whose suggestion is non-empty and whose ``is_invalid``
    flag equals 0, this writes the policy entropy, the suggestion area, the
    first element of the suggestion (as text) and the policy coverage. It also
    appends the area and the coverage to the per-timestep accumulators, which
    are later used for batch statistics.

    Args:
        writer: TensorBoard SummaryWriter (anything providing ``add_scalar``
            and ``add_text``).
        trajectory: Dictionary with per-timestep sequences under the keys
            "suggestions", "policy_coverages", "policy_entropies",
            "is_invalid" and "suggestion_areas".
        total_traj_cnt: Current total trajectory count for logging x-axis.
        areas_by_timestep: Mapping from timestep index to a list. The
            suggestion area of each logged timestep is appended in place.
            ``areas_by_timestep[t]`` must already yield a list for every
            logged timestep, e.g. a ``defaultdict(list)``.
        policy_coverage_by_timestep: Same as ``areas_by_timestep``, but
            accumulates the policy coverage of each logged timestep.
    """
    suggestions = trajectory["suggestions"]
    policy_coverages = trajectory["policy_coverages"]
    policy_entropies = trajectory["policy_entropies"]
    is_invalid = trajectory["is_invalid"]
    suggestion_areas = trajectory['suggestion_areas']

    for t, sugg in enumerate(suggestions):
        # Skip timesteps that have no suggestion or are flagged invalid.
        # len() is used instead of truthiness so array-like suggestions work.
        if sugg is None or len(sugg) == 0 or is_invalid[t] != 0:
            continue

        prefix = f'suggestions/macro_{t:03d}'
        policy_coverage = policy_coverages[t]
        entropy = policy_entropies[t]
        suggestion_area = suggestion_areas[t]

        writer.add_scalar(f'{prefix}/policy_entropy', entropy, total_traj_cnt)

        writer.add_scalar(f'{prefix}/suggestion_area', suggestion_area,
                          total_traj_cnt)
        areas_by_timestep[t].append(suggestion_area)

        # Only the first element of the suggestion is logged as coordinates.
        writer.add_text(f'{prefix}/suggestion_coords', str(sugg[0]),
                        total_traj_cnt)

        policy_coverage_by_timestep[t].append(policy_coverage)
        writer.add_scalar(f'{prefix}/policy_coverage', policy_coverage,
                          total_traj_cnt)


def log_batch_suggestion_metrics(writer, areas_by_timestep,
        policy_coverage_by_timestep, num_iterations):
    """
    Log summary statistics of accumulated suggestion areas and policy coverage.

    Writes mean, median and standard deviation for two groupings:

    - overall, pooling every timestep's values, under
      ``suggestions/{avg,median,std}/<metric>``;
    - per timestep (macro), under
      ``suggestions/macro_XXX/{avg,median,std}/<metric>``.

    Timesteps with no recorded values are skipped. If nothing at all was
    recorded for a metric, its overall statistics are skipped too.

    Args:
        writer: TensorBoard SummaryWriter (anything providing ``add_scalar``).
        areas_by_timestep: Mapping from integer timestep index to a list of
            suggestion areas collected for that timestep.
        policy_coverage_by_timestep: Mapping from integer timestep index to a
            list of policy coverage values collected for that timestep.
        num_iterations: Step value used as the logging x-axis.
    """
    # Pool the values of all timesteps for the overall statistics.
    all_areas = [area for areas in areas_by_timestep.values()
                 for area in areas]
    all_coverages = [coverage for coverages in
                     policy_coverage_by_timestep.values()
                     for coverage in coverages]

    # Area statistics: overall first, then per macro.
    _log_summary_stats(writer, all_areas, "suggestions", "area",
                       num_iterations)
    for timestep, areas in areas_by_timestep.items():
        _log_summary_stats(writer, areas, f"suggestions/macro_{timestep:03d}",
                           "area", num_iterations)

    # Policy coverage statistics: overall first, then per macro.
    _log_summary_stats(writer, all_coverages, "suggestions",
                       "policy_coverage", num_iterations)
    for timestep, coverages in policy_coverage_by_timestep.items():
        _log_summary_stats(writer, coverages,
                           f"suggestions/macro_{timestep:03d}",
                           "policy_coverage", num_iterations)


def _log_summary_stats(writer, values, prefix, metric, step):
    """
    Write the mean, median and std of ``values`` as scalars.

    The tags are ``{prefix}/avg/{metric}``, ``{prefix}/median/{metric}`` and
    ``{prefix}/std/{metric}``. ``np.std`` is used with its default
    ``ddof=0``, i.e. the population standard deviation. Nothing is written
    when ``values`` is empty.
    """
    if not values:
        return
    writer.add_scalar(f"{prefix}/avg/{metric}", np.mean(values), step)
    writer.add_scalar(f"{prefix}/median/{metric}", np.median(values), step)
    writer.add_scalar(f"{prefix}/std/{metric}", np.std(values), step)
