import logging
import time

import lightning.pytorch as pl
import torch


class MeasureSamplingTimeCallback(pl.Callback):
    """
    PyTorch Lightning callback to measure and log sampling time.

    This callback tracks:
    - Total sampling time
    - Number of graphs sampled
    - Time per sample
    - Time per batch

    The metrics are logged both during and after sampling.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Reset all time tracking metrics."""
        self.total_sampling_time = 0.0
        self.total_num_graphs = 0
        self.batch_times = []
        self.batch_graph_counts = []
        self.start_time = None

    def on_sample_start(self, fabric):
        """Called when sampling starts."""
        self.reset()

    def on_before_sample_batch(self, fabric, batch_idx):
        """Mark the start time before sampling a batch."""
        self.start_time = time.perf_counter()

    def on_after_sample_batch(self, sample, fabric, batch_idx):
        """
        Calculate and log sampling time after a batch is sampled.

        Args:
            sample: List of sample dictionaries
            sampler: The sampler object
        """
        if self.start_time is None:
            return

        # Calculate time elapsed
        end_time = time.perf_counter()
        time_elapsed = end_time - self.start_time

        # Count the number of graphs in this batch
        num_graphs = sum(sample["xhat_traj"].shape[1] for sample in sample)

        # Update totals
        self.total_sampling_time += time_elapsed
        self.total_num_graphs += num_graphs

        # Store batch information
        self.batch_times.append(time_elapsed)
        self.batch_graph_counts.append(num_graphs)

        # Log metrics
        fabric.log("sampler/batch_sampling_time", time_elapsed, step=batch_idx)
        fabric.log("sampler/batch_num_graphs", num_graphs, step=batch_idx)
        fabric.log("sampler/batch_time_per_graph", time_elapsed / num_graphs, step=batch_idx)
        fabric.log("sampler/total_sampling_time", self.total_sampling_time, step=batch_idx)
        fabric.log("sampler/total_num_graphs", self.total_num_graphs, step=batch_idx)
        fabric.log("sampler/avg_time_per_graph", self.total_sampling_time / self.total_num_graphs, step=batch_idx)

        # Log to console
        py_logger = logging.getLogger(__name__)
        py_logger.info(
            f"Sampled batch {batch_idx} with {num_graphs} samples in {time_elapsed:.4f} seconds "
            f"({time_elapsed / num_graphs:.4f} seconds per sample)."
        )

        # Reset start time
        self.start_time = None

    def on_sample_end(self, fabric):
        """Log final statistics when sampling is complete."""
        if self.total_num_graphs == 0:
            return

        avg_time = self.total_sampling_time / self.total_num_graphs

        # Log final metrics.
        fabric.log("sampler/final_total_sampling_time", self.total_sampling_time)
        fabric.log("sampler/final_total_num_graphs", self.total_num_graphs)
        fabric.log("sampler/final_avg_time_per_graph", avg_time)

        # Calculate additional statistics.
        if self.batch_times:
            fabric.log("sampler/min_batch_time", min(self.batch_times))
            fabric.log("sampler/max_batch_time", max(self.batch_times))
            fabric.log("sampler/avg_batch_time", sum(self.batch_times) / len(self.batch_times))
            fabric.log("sampler/std_batch_time", torch.std(torch.tensor(self.batch_times)).item())

        # Log to console.
        py_logger = logging.getLogger(__name__)
        py_logger.info(
            f"Total sampling time: {self.total_sampling_time:.4f} seconds "
            f"for {self.total_num_graphs} samples "
            f"({avg_time:.4f} seconds per sample)."
        )
