import time
from lightning.pytorch.callbacks import Callback


class TimingCallback(Callback):
    """Log training and validation throughput in samples per second.

    Each batch is timed with ``time.perf_counter`` from its ``*_batch_start``
    hook to its ``*_batch_end`` hook. Only these intervals are summed, so any
    time spent between the hooks of consecutive batches is not counted.

    Logged metrics:
        * ``train/samples_per_sec_live`` -- every ``log_every_n_batches``
          training batches, over the batches seen since the previous live log.
        * ``train/samples_per_sec`` -- at the end of each training epoch.
        * ``val/samples_per_sec`` -- at the end of each validation epoch.

    Batches must be either a dict holding an ``"x0"`` entry or an object with
    a ``shape`` attribute; the leading dimension is taken as the sample count.

    Args:
        log_every_n_batches: Number of training batches between two live
            throughput logs. Must be at least 1.

    Raises:
        ValueError: If ``log_every_n_batches`` is smaller than 1.
    """

    def __init__(self, log_every_n_batches: int = 100) -> None:
        super().__init__()
        # Fail fast: a value of 0 would otherwise raise ZeroDivisionError on
        # the modulo check at the end of the first training batch.
        if log_every_n_batches < 1:
            raise ValueError(
                f"log_every_n_batches must be >= 1, got {log_every_n_batches!r}"
            )
        self.log_every_n_batches = log_every_n_batches
        self.training_step_times = []
        self.validation_step_times = []
        self.train_samples = 0
        self.val_samples = 0
        # Epoch start timestamps; recorded but not read by this class.
        self.train_start_time = None
        self.val_start_time = None
        # Start timestamp of the batch currently being timed (train or val).
        self.batch_start_time = None
        # Samples and seconds accumulated since the last live log.
        self._window_samples = 0
        self._window_time = 0.0

    @staticmethod
    def _batch_size(batch) -> int:
        """Return the leading dimension of ``batch`` (or of ``batch["x0"]``)."""
        reference = batch["x0"] if isinstance(batch, dict) else batch
        return reference.shape[0]

    @staticmethod
    def _samples_per_sec(samples, seconds):
        """Return ``samples / seconds``, or 0 when no time was measured."""
        return samples / seconds if seconds > 0 else 0

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        """Reset all per-epoch training statistics."""
        self.train_start_time = time.perf_counter()
        self.training_step_times = []
        self.train_samples = 0
        self._window_samples = 0
        self._window_time = 0.0

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None:
        """Record the start time of the training batch."""
        self.batch_start_time = time.perf_counter()

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx
    ) -> None:
        """Accumulate batch timing and periodically log live throughput."""
        batch_time = time.perf_counter() - self.batch_start_time
        self.training_step_times.append(batch_time)
        batch_size = self._batch_size(batch)
        self.train_samples += batch_size

        # Track the real sample count of the window instead of assuming every
        # batch has the size of the current one (the last batch of an epoch
        # is often smaller, and batch sizes may vary in general).
        self._window_samples += batch_size
        self._window_time += batch_time

        # Log samples/sec every N batches
        if (batch_idx + 1) % self.log_every_n_batches == 0:
            speed = self._samples_per_sec(self._window_samples, self._window_time)
            pl_module.log(
                "train/samples_per_sec_live",
                speed,
                on_step=True,
                prog_bar=True,
                sync_dist=True,
            )
            self._window_samples = 0
            self._window_time = 0.0

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Log the training throughput measured over the whole epoch."""
        train_duration = sum(self.training_step_times)
        pl_module.log_dict(
            {
                "train/samples_per_sec": self._samples_per_sec(
                    self.train_samples, train_duration
                ),
            },
            on_epoch=True,
            sync_dist=True,
        )

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        """Reset all per-epoch validation statistics."""
        self.val_start_time = time.perf_counter()
        self.validation_step_times = []
        self.val_samples = 0

    def on_validation_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx=0
    ) -> None:
        """Record the start time of the validation batch.

        ``dataloader_idx`` is accepted (and ignored) to match the Lightning
        ``Callback`` hook signature.
        """
        self.batch_start_time = time.perf_counter()

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ) -> None:
        """Accumulate the duration and sample count of the validation batch.

        ``dataloader_idx`` is accepted (and ignored) to match the Lightning
        ``Callback`` hook signature.
        """
        batch_time = time.perf_counter() - self.batch_start_time
        self.validation_step_times.append(batch_time)
        self.val_samples += self._batch_size(batch)

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        """Log the validation throughput measured over the whole epoch."""
        val_duration = sum(self.validation_step_times)
        # Guarded like the training path: with no timed validation batch the
        # duration is 0 and an unguarded division would raise.
        pl_module.log_dict(
            {
                "val/samples_per_sec": self._samples_per_sec(
                    self.val_samples, val_duration
                ),
            },
            on_epoch=True,
            sync_dist=True,
        )
