from AbstractModels.Config import Config
from AbstractModels.ConvolutionModel import ConvolutionModel
from SNN.util.energy_consumption import approximate_energy_consumption

import torch

import numpy as np

def main() -> None:
    """Entry point: seed all RNGs, build the model from ``Config`` and dispatch.

    Exactly one action runs, in this priority order:

    1. ``config.calculate_energy_consumption()`` is truthy -> estimate and
       print the model's energy consumption.
    2. ``config.validate()`` is truthy -> run validation.
    3. Otherwise -> train the model.
    """
    import random  # local import keeps this fix self-contained

    config: Config = Config()

    seed = config.get_seed()

    # Seed every RNG the pipeline may draw from. Python's ``random`` was
    # previously left unseeded, so any code path using it (e.g. data
    # augmentation, if present) was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible CUDA device (superset of torch.cuda.manual_seed).
    torch.cuda.manual_seed_all(seed)
    # ``deterministic`` alone is not sufficient: the cuDNN benchmark autotuner
    # may select different algorithms between runs, so pin it off explicitly.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model: ConvolutionModel = config.initialize_model()

    if config.calculate_energy_consumption():
        energy: float = calculate_energy_consumption(model, config)
        print(f"Estimated Energy Consumption of Model: {energy : 0.6f} mJ")
    elif config.validate():
        validate(model, config)
    else:
        train(model, config)
    
def train(model: ConvolutionModel, config: Config) -> None:
    """Train ``model`` with the components and hyper-parameters provided by ``config``.

    The optimizer is built first and handed to the scheduler factory. The
    remaining ``model.fit`` arguments (epochs, train/val loaders, criterion,
    model directory, gradient steps) are all read from ``config``. The
    training loader uses ``train=True`` and the validation loader uses
    ``train=False``.
    """
    opt = config.get_optimizer()
    lr_schedule = config.get_scheduler(optimizer=opt)

    # Entries are evaluated top-to-bottom, matching the call order of a direct
    # keyword-argument call to ``model.fit``.
    fit_kwargs = dict(
        epochs=config.get_epochs(),
        train_loader=config.get_dataloader(train=True),
        val_loader=config.get_dataloader(train=False),
        optimizer=opt,
        criterion=config.get_criterion(),
        model_dir=config.get_model_dir(),
        config=config,
        scheduler=lr_schedule,
        gradient_steps=config.get_gradient_steps(),
    )
    model.fit(**fit_kwargs)

def validate(model: ConvolutionModel, config: Config) -> tuple[float, float, float, float]:
    """Evaluate ``model`` on the non-training split described by ``config``.

    Returns:
        The 4-tuple produced by ``model.validate``. The meaning of each element
        is defined by that method, not here.
    """
    loader = config.get_dataloader(train=False)
    loss_fn = config.get_criterion()
    metrics = model.validate(val_loader=loader, criterion=loss_fn, config=config)
    return metrics

def calculate_energy_consumption(model: ConvolutionModel, config: Config) -> float:
    """Estimate the energy consumption of ``model`` over the non-training split.

    If ``config.load_weights()`` returns a truthy value, that value is passed
    to ``model.load`` (together with the model directory and
    ``config.resume()``) before estimating.

    Returns:
        The value computed by ``approximate_energy_consumption``, using the
        ``train=False`` dataloader and ``config.get_timesteps()``. Its units are
        defined by that helper; this script's ``main`` labels them as mJ.
    """
    # Evaluate once so the value tested for truthiness is exactly the value
    # handed to ``model.load``.
    weights = config.load_weights()
    if weights:
        # NOTE(review): the two positional ``None`` arguments are carried over
        # unchanged; presumably optimizer/scheduler slots not needed for
        # inference — confirm against ``ConvolutionModel.load``.
        model.load(config.get_model_dir(), weights, None, None, config.resume())
    return approximate_energy_consumption(
        model=model,
        dataset=config.get_dataloader(train=False),
        timesteps=config.get_timesteps(),
    )

if __name__ == "__main__":
    # Run the entry point only when executed as a script, not on import.
    main()