from dataclasses import dataclass, field
from pathlib import Path
from time import time

from utils.enums import Datasets, Losses, Optims, LR_Schedulers, Devices


@dataclass
class RunSetup:
    """
    Class that contains parameters for one model train run.

    Fields without a default (``dataset`` through ``dataset_path``) must be
    supplied by the caller; everything else falls back to the defaults below.
    Use :meth:`as_dict` to obtain a plain-dict view in which enum members and
    paths are converted to strings.
    """

    # main params
    dataset: Datasets
    model_name: str
    loss: Losses
    optim: Optims
    lr: float
    lr_scheduler: LR_Schedulers
    # when the model reaches this gradient norm value, it stops
    target_norm_grad: float

    # Path that will contain checkpoints and run logs
    path: Path
    # Path that contains the dataset files. If dataset is not present, it will be automatically generated
    dataset_path: Path

    # random seed for training-related random processes
    seed: int = 42
    device: Devices = Devices.CPU

    # bounds on the number of training epochs
    min_num_epochs: int = 10_000
    max_num_epochs: int = 200_000
    # save model every X epoch
    save_model_every: int = 100
    # compute Lipschitz every X epoch
    compute_L_every: int = 100
    # compute Lipschitz for the first X epochs
    compute_L_for_first: int = 1000

    # more dataset parameters
    # scale of the gaussian noise applied to the dataset
    dataset_noise: float = 0.0
    batch_size: int = 512
    batch_size_test: int = 512
    # fraction of the shuffling in the dataset
    alpha_shuffle: float = 0.0
    # num of parallel processes that give out batches
    data_num_workers: int = 1

    # use this to continue training
    load_model: bool = False
    load_from_epoch: int = -1
    # The unique id for the run, used to differentiate runs with the same parameters.
    # Leave as is if you start a new run. Set it to the runtimestamp of the experiment
    # you want to continue training for, if load_model = True.
    # The default is computed when the instance is created (not when the module is
    # imported), so several RunSetup objects built in one process do not all share
    # the import-time value. Resolution is one second.
    runtimestamp: int = field(default_factory=lambda: int(time()))

    def as_dict(self) -> dict:
        """
        Return the run parameters as a plain dictionary.

        The result is a shallow copy of the instance attributes in which the
        enum-valued fields (``dataset``, ``loss``, ``optim``, ``lr_scheduler``,
        ``device``) are replaced by their ``.name`` and the path fields
        (``path``, ``dataset_path``) are replaced by their resolved absolute
        path as a string. The instance itself is not modified.
        """
        d = self.__dict__.copy()

        # enum members -> their names
        for key in ("dataset", "loss", "optim", "lr_scheduler", "device"):
            d[key] = d[key].name

        # Paths -> absolute path strings; Path(...) also tolerates a plain str,
        # since the dataclass does not enforce the annotated type
        for key in ("path", "dataset_path"):
            d[key] = str(Path(d[key]).resolve())

        return d
