"""Defines the dataclass for the model training arguments.

The model training arguments are stored in the dataclass TrainArgs, which is
passed to the trainer module.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Callable

    from torch import Tensor
    from torch.nn import Module
    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LRScheduler
    from torch.utils.data import DataLoader


@dataclass(frozen=True)
class TrainArgs:
    """Store the hyperparameters and settings for model training.

    Instances are immutable (``frozen=True``): assigning to an attribute
    raises ``dataclasses.FrozenInstanceError``. Use ``dataclasses.replace``
    to derive a modified copy.

    Attributes
    ----------
    model : Module
        The initialized PyTorch model.
    optim : type[Optimizer]
        The *uninitialized* PyTorch optimizer class used for training. Extra
        optimizer keyword arguments go in ``optim_kwargs``.
    fn_loss : Callable[[Tensor, Tensor], Tensor]
        The initialized loss function used for training.
    lr : float
        The learning rate used for training.
    dataloaders : tuple[DataLoader, DataLoader]
        A pair of PyTorch data loaders. The first element must be the
        training loader; the second is the validation loader.
    perform_metrics : dict[str, type] | None, default None
        Performance metrics computed and monitored during training. Keys are
        metric names. Each value is an *uninitialized* class providing two
        methods:

        - ``compute(x, y)``: compute the metric between ``x`` and ``y``;
        - ``finalize()``: finalize the metric value before it is returned
          (for example, by taking an average).
    max_epochs : int, default -1
        Duration of training. If left at -1, the trainer uses ``patience``
        as the stopping criterion instead. If set, it overrides any
        ``patience`` value.
    wandb_kwargs : dict[str, Any] | None, default None
        Keyword arguments required by the Weights & Biases library.
    ckpt_name : str | None, default None
        Name used for model checkpoints.
    save_folder : str | None, default None
        Folder in which model checkpoints are saved.
    save_final : bool, default False
        When training for ``max_epochs``, save a checkpoint of the model at
        ``max_epochs``.
    optim_kwargs : dict[str, Any] | None, default None
        Optional keyword arguments passed to the optimizer.
    lr_sched : type[LRScheduler] | None, default None
        An *uninitialized* PyTorch learning-rate scheduler class.
    lr_sched_kwargs : dict[str, Any] | None, default None
        Required or optional keyword arguments passed to the learning-rate
        scheduler.
    patience : int, default 7
        Early-stopping patience, used when not training for ``max_epochs``:
        the number of epochs the trainer waits on its stopping criterion
        before it stops training. The criterion itself is defined by the
        trainer.
    monitor : str | None, default "train"
        The dataset monitored for the ``patience`` criterion.
        NOTE(review): the accepted values are defined by the trainer; confirm
        there.
    warm_up : int, default -1
        Number of epochs for which the ``patience`` mechanism is paused.
    device : str, default "cuda"
        The device used for training.
    sweep : bool, default False
        Tells the trainer that a hyperparameter sweep is being run. In this
        mode the trainer does not save checkpoints.
    sw_run : object | None, default None
        During a sweep, the initialized ``wandb.init`` run object passed to
        the trainer.
    """

    # Required fields. Their order is part of the positional constructor
    # signature, so it must not change.
    model: Module
    optim: type[Optimizer]
    fn_loss: Callable[[Tensor, Tensor], Tensor]
    lr: float
    dataloaders: tuple[DataLoader, DataLoader]

    # Optional fields. All defaults are immutable, so no default_factory is
    # needed.
    perform_metrics: dict[str, type] | None = None
    max_epochs: int = -1
    wandb_kwargs: dict[str, Any] | None = None
    ckpt_name: str | None = None
    save_folder: str | None = None
    save_final: bool = False
    optim_kwargs: dict[str, Any] | None = None
    lr_sched: type[LRScheduler] | None = None
    lr_sched_kwargs: dict[str, Any] | None = None
    patience: int = 7
    monitor: str | None = "train"
    warm_up: int = -1
    device: str = "cuda"
    sweep: bool = False
    sw_run: object | None = None
