import pytorch_lightning as pl

from configs import get_config
from models import construct_model
from data import construct_data_module
from utils import init_logger, construct_callbacks
import torch
import numpy as np
import random
import os
from pytorch_lightning.callbacks import EarlyStopping
# Process-wide runtime settings, applied at import time.
#
# The environment values below are machine-specific defaults (physical GPU 1,
# a local proxy on port 7890). ``setdefault`` installs them only when the
# caller has not already exported the variable, so e.g.
# ``CUDA_VISIBLE_DEVICES=0 python <this script>`` is respected instead of
# being silently overwritten.
# Setting CUDA_VISIBLE_DEVICES here (after ``import torch``) still takes effect
# only because nothing has initialised CUDA yet; keep it before any CUDA call.
# NOTE(review): with a single visible GPU, CUDA renumbers it as index 0, so the
# Trainer's ``devices=[config["device"]]`` presumably needs config["device"] == 0
# — confirm against the configs.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
# Allow lower-precision (e.g. TF32) float32 matmuls for speed.
torch.set_float32_matmul_precision('medium')
os.environ.setdefault('HTTP_PROXY', 'http://127.0.0.1:7890')
os.environ.setdefault('HTTPS_PROXY', 'http://127.0.0.1:7890')
def set_seed(seed):
    """Seed every RNG used during training and make cuDNN deterministic.

    Calls Lightning's ``seed_everything`` first, then re-seeds torch (CPU,
    current CUDA device, all CUDA devices), NumPy and Python's ``random``
    explicitly with the same value. It also forces deterministic cuDNN
    kernels with autotuning disabled, and exports ``PYTHONHASHSEED``.

    Args:
        seed: integer seed shared by all generators.
    """
    pl.seed_everything(seed)

    # Applied in this fixed order; every seeder receives the same value.
    rng_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_rng in rng_seeders:
        seed_rng(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this does not
    # change hash randomisation of the running process; it only influences
    # Python subprocesses started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)



if __name__ == "__main__":
    # Every component is derived from the single experiment config.
    config = get_config()
    set_seed(config["seed"])

    data_module = construct_data_module(config)
    # The model receives the data module's class-imbalance weighting.
    ec_cem_model = construct_model(config, data_module.imbalance_weight)
    logger = init_logger(config)

    # Early stopping watches the same metric/mode used for checkpointing;
    # patience and min_delta are fixed here rather than read from config.
    callbacks = construct_callbacks(config)
    callbacks.append(
        EarlyStopping(
            monitor=config['ckpt_save_monitor'],
            mode=config['ckpt_save_mode'],
            patience=15,
            min_delta=1e-3,
        )
    )

    # NOTE(review): config["device"] indexes the GPUs left visible by
    # CUDA_VISIBLE_DEVICES at the top of this file — confirm they agree.
    trainer_kwargs = dict(
        accelerator="gpu",
        devices=[config["device"]],
        max_epochs=config["max_epochs"],
        check_val_every_n_epoch=config["val_every_n_epochs"],
        log_every_n_steps=5,
        logger=logger,
        callbacks=callbacks,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(ec_cem_model, datamodule=data_module)
    # 'best' reloads the best checkpoint before testing; presumably this relies
    # on a ModelCheckpoint among construct_callbacks' output — confirm.
    trainer.test(ec_cem_model, datamodule=data_module, ckpt_path='best')
    

