import argparse
import os
# Restrict the process to GPU 0. This runs before the project imports below;
# presumably it must precede any CUDA initialization triggered by those
# modules (e.g. via torch) to take effect — TODO confirm.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES the
# caller already exported; consider os.environ.setdefault if that is unwanted.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from libs.utils import set_seed
from model import build_model
from dataset import build_dataloader
from libs.ConfTS import ConformalTemperatureScaling
from libs.predictor import Predictor


def main(argv=None) -> argparse.Namespace:
    """Parse the ConfTS command-line options.

    Despite its name, this function only builds and runs the argument
    parser; the experiment itself is driven from the ``__main__`` block.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            original no-argument behavior.

    Returns:
        An ``argparse.Namespace`` with attributes ``seed``, ``trials``,
        ``model``, ``data_dir``, ``conformal`` and ``alpha``.

    Raises:
        SystemExit: raised by argparse for malformed arguments, or via
            ``parser.error`` when ``--alpha`` is not strictly between 0 and 1.
    """
    parser = argparse.ArgumentParser(description='ConfTS')
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('--trials', type=int, default=1, help='number of trials')
    parser.add_argument('--model', type=str, default='resnet50', help='model')
    parser.add_argument('--data_dir', '-s', type=str, default='/data/dataset/',
                        help='root directory of the dataset.')
    parser.add_argument('--conformal', type=str, default='aps', help='conformal prediction')
    parser.add_argument('--alpha', type=float, default=0.1, help="error rate")
    args = parser.parse_args(argv)
    # alpha is a miscoverage/error rate; reject out-of-range values up front
    # instead of letting them fail obscurely during calibration.
    if not 0.0 < args.alpha < 1.0:
        parser.error(f"--alpha must be in the open interval (0, 1), got {args.alpha}")
    return args


if __name__ == "__main__":
    # Script entry point: parse options, seed RNGs, build the model and data
    # loaders, calibrate a ConfTS-preprocessed conformal predictor, and print
    # its evaluation result on the test loader.
    args = main()
    set_seed(args.seed)

    if args.trials != 1:
        # --trials is accepted by the parser, but this script performs exactly
        # one calibrate/evaluate run; warn rather than silently ignore it.
        import warnings
        warnings.warn(
            f"--trials={args.trials} is not implemented; running a single trial.",
            stacklevel=1,
        )

    model = build_model(args.model)
    model = model.cuda()
    # build_dataloader yields three loaders: two calibration loaders (both are
    # handed to Predictor.calibrate) and the test loader. Presumably one tunes
    # the ConfTS temperature and the other fits the conformal threshold —
    # TODO confirm against libs.predictor.
    calib_calibloader, conf_calibloader, testloader = build_dataloader(args.data_dir)
    preprocessor = ConformalTemperatureScaling(model, args.alpha)
    predictor = Predictor(model, preprocessor, args.conformal, args.alpha)
    predictor.calibrate(calib_calibloader, conf_calibloader)
    result = predictor.evaluate(testloader)
    print(result)
