import os
import torch
from utils import *
from omegaconf import OmegaConf

def _parse_iteration(pt_path):
    """Return the iteration number encoded in a checkpoint name like ``<name>_iter<N>.pt``.

    The text after the *last* ``_iter`` in the file stem is parsed, so a stem
    that happens to contain ``_iter`` more than once still yields the trailing
    number.

    Raises:
        ValueError: if the stem has no ``_iter`` marker, or the text after it
            is not an integer.
    """
    stem = os.path.splitext(os.path.basename(pt_path))[0]
    _, marker, suffix = stem.rpartition('_iter')
    if not marker:
        raise ValueError(
            f"cannot infer iteration from {pt_path!r}: "
            "expected a file name of the form '<name>_iter<N>.pt'"
        )
    return int(suffix)


def main(args):
    """Evaluate a saved latent/label checkpoint and log the results.

    Loads ``pt['latent']`` and ``pt['label_syn']`` from ``args.pt`` and the
    autoencoder described by ``args.ae_config`` / ``args.ae_ckpt``. It also
    gets the dataset info and test loader from ``get_dataset``. It then runs
    ``eval_and_save`` (with ``save=False``) over the model pool returned by
    ``get_eval_pool``.

    Progress and the returned ``best_acc`` / ``best_std`` are logged to
    ``<directory of args.pt>/cross_arch_log_iter<N>.txt``. ``N`` is parsed
    from the checkpoint file name.

    Args:
        args: argparse namespace with the CLI options. This function also sets
            ``device``, ``dsa_param``, ``dsa``, ``latent_size``, ``channel``,
            ``im_size``, ``num_classes`` and ``f`` on it in place, because the
            same namespace is handed to ``get_dataset`` and ``eval_and_save``.

    Raises:
        ValueError: if the ``args.pt`` file name does not end in ``_iter<N>``.
    """
    torch.set_num_threads(args.torch_num_threads)
    # Seed every RNG the evaluation may draw from so runs are reproducible.
    torch.random.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    # --dsa arrives as the string 'True'/'False'; a 'none' strategy disables it regardless.
    args.dsa = args.dsa == 'True'
    if args.dsa_strategy in ('none', 'None'):
        args.dsa = False

    # The log is written next to the checkpoint. A bare file name has an empty
    # dirname, which os.makedirs rejects, so only create a real directory.
    save_dir = os.path.dirname(args.pt)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    it = _parse_iteration(args.pt)
    logger = Logger(os.path.join(save_dir, f'cross_arch_log_iter{it}.txt'))

    ae_config = OmegaConf.load(args.ae_config)
    ae_model = load_autoencoder_from_config(ae_config, args.ae_ckpt).to(args.device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # Moving to args.device (rather than .cuda()) keeps the tensors on the
    # same device as ae_model.
    pt = torch.load(args.pt, map_location=args.device)
    latent_syn = pt['latent'].to(args.device)
    label_syn = pt['label_syn'].to(args.device)
    # Last two dims of the latent tensor; presumably its spatial (H, W) -- confirm.
    args.latent_size = latent_syn.shape[-2:]
    args.channel, args.im_size, args.num_classes, _, class_map, _, _, _, dst_train, _, testloader, _ = get_dataset(
        args.dataset, args.data_path, args.batch_real, args.res, args=args
    )
    # Ratio of image height to latent height (presumably the autoencoder's downsampling factor).
    args.f = args.im_size[0] // args.latent_size[0]

    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    logger.log('%s evaluation begins' % get_time())
    # Pass one pre-formatted string, like the other log calls here.
    logger.log('Evaluation model pool: ' + str(model_eval_pool))
    logger.log(f'Dataset info: {args.dataset}, {args.channel} * {args.im_size[0]} * {args.im_size[1]}, {args.num_classes} classes')
    logger.log('Args: ' + str(args.__dict__))

    best_acc, best_std = eval_and_save(
        args, latent_syn, label_syn, ae_model, logger,
        testloader=testloader, model_eval_pool=model_eval_pool,
        it=it, save=False, verbose=True, use_lr_net=args.use_lr_net,
    )

    logger.log(best_acc)
    logger.log(best_std)

if __name__ == '__main__':
    # Project-local parser factory; imported here, so it is only loaded when run as a script.
    import shared_args

    parser = shared_args.add_shared_args()
    # required=True: otherwise a missing --pt reaches os.path.dirname(None) in
    # main() and fails with an opaque TypeError instead of a usage message.
    parser.add_argument('--pt', type=str, required=True,
                        help='path to the saved pt file (file name must end in _iter<N>)')
    parser.add_argument('--use_lr_net', action='store_true',
                        help='forwarded to eval_and_save as use_lr_net')
    args = parser.parse_args()
    main(args)