
from utils import comet
from config import cf, get_config
from data import get_gradual_domains
from model import get_trained_model
from method import get_method

from utils.log_acc import acclog

# Executed at import time: announce the device configured on `cf`.
print("Using Device: {}".format(cf.device))

def main(data_name, model_name, method_name, domain_num, corruption=None):
    """Run one gradual-adaptation experiment and log it to Comet.

    The run parameters are first written onto the global ``cf`` config so
    that ``cf.get_dict()`` (logged to Comet and passed to ``acclog``) reflects
    the run that was actually executed.

    Args:
        data_name: Dataset identifier forwarded to ``get_gradual_domains``
            and ``get_trained_model``.
        model_name: Model identifier forwarded to ``get_trained_model``.
        method_name: Adaptation method name forwarded to ``get_method``. Its
            lower-cased form selects ``method/<name>.py`` for code logging and
            the method's sub-config attribute on ``method.cf``.
        domain_num: Number of gradual domains, forwarded to
            ``get_gradual_domains``.
        corruption: Optional corruption name forwarded unchanged to
            ``get_gradual_domains`` (default ``None``).

    Returns:
        None. The accuracy returned by ``method.gradual_adapt`` is logged to
        Comet as ``"final acc"`` and recorded via ``acclog``.
    """
    cf.data_name = data_name
    cf.model_name = model_name
    cf.method_name = method_name
    cf.domain_num = domain_num
    # Mirror the corruption onto cf like the other run parameters; otherwise
    # the logged config can disagree with what get_gradual_domains received
    # when main() is called directly rather than from the CLI entry point.
    cf.corruption = corruption

    method_key = method_name.lower()

    comet.start(
        name=f"{domain_num} {method_name} {model_name} {cf.time}",
        tags=[data_name, model_name, method_name, str(domain_num)],
    )
    # Always close the Comet experiment, even if loading or adaptation
    # raises, so a failed run does not leave it open.
    try:
        comet.log_parameters(cf.get_dict())
        comet.log_code(f"method/{method_key}.py")

        gra_domains = get_gradual_domains(data_name, domain_num, corruption=corruption)
        model = get_trained_model(data_name, model_name).to(cf.device)
        method = get_method(method_name, model)
        acc = method.gradual_adapt(gra_domains)

        # The method-specific hyper-parameters live on method.cf under the
        # lower-cased method name; fetch them once and reuse them.
        method_params = getattr(method.cf, method_key).get_dict()
        comet.log_parameters({method_key: method_params})
        acclog(acc, cf.get_dict(), method_params)
        comet.log_metrics({"final acc": acc})
    finally:
        comet.end()
    
    
if __name__ == "__main__":
    # Every run setting is read from the global config object `cf`.
    main(
        cf.data_name,
        cf.model_name,
        cf.method_name,
        cf.domain_num,
        corruption=cf.corruption,
    )

    # Example of a direct call with explicit values instead of the config:
    # main(data_name = "imagenet", model_name = "default", method_name = "GST", domain_num = 2, corruption = "gaussian_noise")
"""
Example runs:
python main.py --data_name rotate_mnist --model_name resnet --method_name GMMA --domain_num 2
python main.py --data_name portraits --model_name resnet --method_name GST --domain_num 2
python main.py --data_name rotate_mnist --model_name resnet --method_name GOAT --domain_num 2
python main.py --data_name covertype --model_name fc --method_name GST --domain_num 2
To pin a specific GPU, prefix a command with CUDA_VISIBLE_DEVICES, e.g.:
CUDA_VISIBLE_DEVICES=5 python main.py --data_name rotate_mnist --model_name resnet --method_name GMMA --domain_num 2
"""

"""
"gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur", 
"glass_blur", "motion_blur", "zoom_blur", "snow", "frost", "fog", 
"brightness", "contrast", "elastic_transform", "pixelate", "jpeg_compression"
Pass one of the corruption names listed above to --corruption, e.g.:
python main.py --data_name cifar10 --method_name GST --domain_num 6 --corruption gaussian_noise
python main.py --data_name cifar10 --method_name GOAT --domain_num 6 --corruption gaussian_noise
python main.py --data_name cifar10 --method_name GDO --domain_num 6 --corruption gaussian_noise
python main.py --data_name cifar10 --method_name GMMA --domain_num 6 --corruption brightness

python main.py --data_name cifar10 --method_name GMMA --domain_num 6 --corruption <corruption_name>
"""
