import sys
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import wandb
from omegaconf import OmegaConf


@dataclass
class Torch_Profiler:
    """Placeholder for torch-profiler settings; currently defines no fields."""


@dataclass
class ProfilerParams:
    """Configuration controlling the optional profiling run in ``profile``."""

    # When False, ``profile`` returns immediately without doing anything.
    use_profiler: bool = False
    # When True (and profiling ran), ``profile`` finishes wandb and exits the process.
    exit_after: bool = True

    # A dataclass instance is unhashable (eq=True sets __hash__ = None), so it
    # cannot be a plain field default: Python 3.11+ raises ValueError at class
    # definition, and older versions would share one instance across all
    # ProfilerParams objects. default_factory builds a fresh one per instance.
    torch_profiler: Torch_Profiler = field(default_factory=Torch_Profiler)


def profile(profiler_params: ProfilerParams, trainer, model, data_module):
    if not profiler_params.use_profiler:
        return
    memlab_report(profiler_params, trainer, model, data_module)
    if profiler_params.exit_after:
        wandb.finish()
        quit()


def memlab_report(profiler_params: ProfilerParams, trainer, model, data_module):
    """Fit ``model`` under a pytorch_memlab ``MemReporter`` and print its report.

    The reporter is attached to ``model`` before training starts. Its report
    is printed after ``trainer.fit`` returns, with bracketed progress markers
    printed around each phase.

    Args:
        profiler_params: Accepted for interface compatibility; not read here.
        trainer: Object exposing ``fit(model, datamodule=...)``.
        model: Model whose tensors the reporter tracks and that gets trained.
        data_module: Passed to ``trainer.fit`` as ``datamodule``.
    """
    # Imported inside the function, so pytorch_memlab is only needed when this runs.
    from pytorch_memlab import MemReporter as _MemReporter

    mem_reporter = _MemReporter(model)

    print("[start profiling]")
    trainer.fit(model, datamodule=data_module)

    print("[Start printing profiler]")
    mem_reporter.report()

    print("[end profiling]")
