import dataclasses
import typing as tp
from typing import Callable

import torch
import torch.utils.data
from tqdm.auto import tqdm


@dataclasses.dataclass
class TrainState:
    """Bundle of the objects a training step operates on.

    Attributes:
      model: The module being optimised.
      optimizer: Optimizer that ``zero_grad`` and ``step`` delegate to.
      criterion: Loss module; only stored here, never called by this class.
    """

    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    criterion: torch.nn.Module

    def zero_grad(self) -> None:
        """Forward to ``optimizer.zero_grad()``."""
        optimizer = self.optimizer
        optimizer.zero_grad()

    def step(self) -> None:
        """Forward to ``optimizer.step()``."""
        optimizer = self.optimizer
        optimizer.step()


@dataclasses.dataclass
class SimpleTrainer:
    """Minimal epoch-based training loop with optional validation.

    Attributes:
      state: Model, optimizer and criterion used by the step functions.
      dataloader: Iterable of training batches, walked once per epoch.
      num_epochs: Number of passes over ``dataloader``.
      val_dataloader: Optional iterable of validation batches. Validation runs
        after every training epoch whenever this is not ``None``.
      train_step: Optional ``(state, batch) -> (prediction, loss)`` callable
        used instead of ``default_train_step``.
      val_step: Optional ``(state, batch) -> (prediction, loss)`` callable used
        instead of ``default_val_step``.
      show_progress: When true, wrap the epoch and batch iterators in ``tqdm``.
    """

    state: TrainState
    dataloader: torch.utils.data.DataLoader
    num_epochs: int
    val_dataloader: tp.Optional[torch.utils.data.DataLoader] = None
    train_step: tp.Optional[
        Callable[[TrainState, dict[str, torch.Tensor]], tuple[torch.Tensor, float]]
    ] = None
    val_step: tp.Optional[
        Callable[[TrainState, dict[str, torch.Tensor]], tuple[torch.Tensor, float]]
    ] = None
    show_progress: bool = True

    @staticmethod
    def _forward_with_loss(
        state: TrainState, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model on ``batch`` and score it against ``batch["admissible"]``."""
        y_pred = state.model(batch).squeeze()
        target = batch["admissible"]
        # A bare squeeze() also removes a batch axis of length 1 (e.g. the last,
        # partial batch), leaving a prediction whose shape no longer matches the
        # target. Restore the target's shape only when the two differ purely in
        # singleton axes, so genuine shape mismatches still surface as errors.
        if y_pred.shape != target.shape and y_pred.shape == target.squeeze().shape:
            y_pred = y_pred.reshape(target.shape)
        return y_pred, state.criterion(y_pred, target)

    def _run_loader(
        self,
        step_fn: Callable[
            [TrainState, dict[str, torch.Tensor]], tuple[torch.Tensor, float]
        ],
        loader: tp.Iterable[dict[str, torch.Tensor]],
        desc: str,
    ) -> float:
        """Apply ``step_fn`` to every batch of ``loader`` and return the mean loss.

        Returns ``nan`` when ``loader`` yields no batches, instead of dividing
        by zero.
        """
        batches = loader
        if self.show_progress:
            batches = tqdm(batches, desc=desc, leave=False)

        losses = []
        for batch in batches:
            _, loss = step_fn(self.state, batch)
            losses.append(loss)
            if self.show_progress:
                batches.set_postfix({"loss": f"{loss:.4f}"})

        if not losses:
            return float("nan")
        return sum(losses) / len(losses)

    def default_train_step(
        self, state: TrainState, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, float]:
        """Default training step that assumes model takes entire batch as input.

        Clears gradients, runs the forward pass, back-propagates the loss
        computed against ``batch["admissible"]`` and applies one optimizer step.

        Returns:
          The prediction (squeezed, and reshaped to the target's shape when they
          differ only in singleton axes) and the loss as a Python float.
        """
        state.zero_grad()
        y_pred, loss = self._forward_with_loss(state, batch)
        loss.backward()
        state.step()
        return y_pred, loss.item()

    def default_val_step(
        self, state: TrainState, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, float]:
        """Default validation step that assumes model takes entire batch as input.

        Runs the forward pass and the loss under ``torch.no_grad()``; no
        parameters are updated.

        Returns:
          The prediction (squeezed, and reshaped to the target's shape when they
          differ only in singleton axes) and the loss as a Python float.
        """
        with torch.no_grad():
            y_pred, loss = self._forward_with_loss(state, batch)
        return y_pred, loss.item()

    def run(self) -> dict[str, tp.Optional[list[float]]]:
        """Run training and validation loops.

        Returns:
          Dictionary containing training and validation metrics:
          ``"train_losses"`` holds the mean training loss of each epoch and
          ``"val_losses"`` the mean validation loss of each epoch, or ``None``
          when no validation dataloader was given. An epoch whose loader
          yields no batches is recorded as ``nan``.
        """
        train_step_fn = (
            self.train_step if self.train_step is not None else self.default_train_step
        )
        val_step_fn = (
            self.val_step if self.val_step is not None else self.default_val_step
        )
        # Compare against None rather than relying on truthiness: bool() on a
        # DataLoader calls len(), which raises TypeError for an iterable-style
        # dataset without __len__ and is False for an empty loader.
        has_validation = self.val_dataloader is not None

        metrics = {
            "train_losses": [],
            "val_losses": [] if has_validation else None,
        }

        epoch_iterator = range(self.num_epochs)
        if self.show_progress:
            epoch_iterator = tqdm(epoch_iterator, desc="Training", unit="epoch")

        for epoch in epoch_iterator:
            # Training loop
            self.state.model.train()
            avg_train_loss = self._run_loader(
                train_step_fn, self.dataloader, f"Epoch {epoch}"
            )
            metrics["train_losses"].append(avg_train_loss)
            summary = {"train_loss": f"{avg_train_loss:.4f}"}

            # Validation loop
            if has_validation:
                self.state.model.eval()
                avg_val_loss = self._run_loader(
                    val_step_fn, self.val_dataloader, "Validation"
                )
                metrics["val_losses"].append(avg_val_loss)
                summary["val_loss"] = f"{avg_val_loss:.4f}"

            if self.show_progress:
                epoch_iterator.set_postfix(summary)

        return metrics
