import pytorch_lightning as pl

from configs import get_config
from models import construct_model
from data import construct_data_module
from utils import init_logger, construct_callbacks
import torch
import numpy as np
import random
import os
from pytorch_lightning.callbacks import EarlyStopping
# Process-wide runtime configuration, applied as soon as this module is imported.
# NOTE(review): these values are hard-coded for one machine (physical GPU 1 and a
# proxy on localhost:7890) and overwrite anything already set in the environment.
# Consider moving them into the config.
os.environ.update(
    {
        # Expose only physical GPU 1; this only takes effect if CUDA has not been
        # initialised yet in this process.
        "CUDA_VISIBLE_DEVICES": "1",
        "HTTP_PROXY": "http://127.0.0.1:7890",
        "HTTPS_PROXY": "http://127.0.0.1:7890",
    }
)
# Trade float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')
def set_seed(seed):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to Lightning, PyTorch (CPU and CUDA
            generators), NumPy's global RNG and Python's ``random`` module.

    Note:
        ``PYTHONHASHSEED`` is exported as well, but Python reads it only at
        interpreter startup. It therefore does not change hashing in the
        current process, only in processes started afterwards.
    """
    pl.seed_everything(seed)

    # Torch generators: CPU, then the current CUDA device, then every CUDA device.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Host-side RNGs.
    np.random.seed(seed)
    random.seed(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)



if __name__ == "__main__":
    # Evaluate a trained EC-CEM checkpoint on the test split.
    config = get_config()
    set_seed(config["seed"])

    data_module = construct_data_module(config)
    # Build the model from the same config and class-imbalance weights used for
    # training; the checkpoint below only supplies the learned weights.
    ec_cem_model = construct_model(config, data_module.imbalance_weight)

    # NOTE(review): CUDA_VISIBLE_DEVICES is forced to "1" at import time, so only
    # one GPU is visible to this process. Device indices here are relative to the
    # visible set, so confirm that config["device"] is 0.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=[config["device"]],
    )

    # NOTE(review): the checkpoint is pinned to seed 15 regardless of config["seed"].
    ckpt_path = 'checkpoints/CUB/EC_CEM/best_acc_seed_15.ckpt'
    # `load_from_checkpoint` is a classmethod. Calling it on an instance either
    # raises TypeError (recent Lightning releases) or builds a *new* model from the
    # checkpoint's saved hyper-parameters, discarding `ec_cem_model`.
    # Passing `ckpt_path` instead restores the checkpoint weights into the model
    # constructed above.
    trainer.test(ec_cem_model, datamodule=data_module, ckpt_path=ckpt_path)

