import argparse
import os
import shutil
import sys
import time
from datetime import datetime

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from tensorboardX import SummaryWriter
from tqdm import tqdm

# Local imports keep their original relative order: later star imports
# override names exported by earlier ones.
from dataloaders.dataloader import TextDataLoader
from models.losses import *
from models.text_classifier import Classifier
from models.image_classifier import ImageClassifier
from models.image_models import *
from models.model import *
from utils import post_process_depth, flip_lr


def convert_arg_line_to_args(arg_line):
    """Split one line of an ``@``-prefixed argument file into tokens.

    Installed below as the parser's ``convert_arg_line_to_args`` hook. The
    default argparse hook treats a whole line as one argument. This one lets a
    single line hold a flag together with its value(s), separated by any
    whitespace.

    Args:
        arg_line: One raw line read from the argument file.

    Yields:
        Each whitespace-separated token of the line, in order. Blank lines
        yield nothing.
    """
    tokens = arg_line.split()
    # str.split() with no separator already discards empty strings, so every
    # token here is non-blank.
    yield from tokens


# Command-line interface.
# Arguments can be given directly, or the script can be launched with a single
# argument-file path (`python main.py args.txt`). That path is expanded through
# argparse's '@' fromfile mechanism, with whitespace-separated tokens per line
# (see convert_arg_line_to_args).
parser = argparse.ArgumentParser(description='TRIDE PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--mode',                                       type=str,   help='train or test', default='train')
parser.add_argument('--model_name',                                 type=str,   help='model name', default='TRIDE')
parser.add_argument('--main_path',                                  type=str,   help='main path of data', required=True)
parser.add_argument('--no_radar',                                               help='if set, train the text+image model', action='store_true')

# Training data
parser.add_argument('--train_text_feature_general_path',            type=str,   help='path of training text feature general', required=True)
parser.add_argument('--train_text_feature_left_path',               type=str,   help='path of training text feature left', required=True)
parser.add_argument('--train_text_feature_mid_left_path',           type=str,   help='path of training text feature middle left', required=True)
parser.add_argument('--train_text_feature_mid_right_path',          type=str,   help='path of training text feature middle right', required=True)
parser.add_argument('--train_text_feature_right_path',              type=str,   help='path of training text feature right', required=True)
parser.add_argument('--train_radar_path',                           type=str,   help='path of training radar', required=True)
# The misspelled flag is kept so existing argument files still parse. The
# correctly spelled alias writes to the same attribute.
parser.add_argument('--train_weather_consition_path', '--train_weather_condition_path',
                    dest='train_weather_consition_path',            type=str,   help='path of training sample weather conditions', required=True)
parser.add_argument('--train_image_path',                           type=str,   help='path of training image', required=True)
parser.add_argument('--train_ground_truth_path',                    type=str,   help='path of D', required=True)
parser.add_argument('--train_ground_truth_nointer_path',            type=str,   help='path of D_s', required=True)
parser.add_argument('--train_lidar_path',                           type=str,   help='path of single lidar depth', required=True)

# Validation data
parser.add_argument('--validation_image_path',                      type=str,   help='path of validation image', required=True)
parser.add_argument('--validation_text_feature_general_path',       type=str,   help='path of validation text feature general', required=True)
parser.add_argument('--validation_text_feature_left_path',          type=str,   help='path of validation text feature left', required=True)
parser.add_argument('--validation_text_feature_mid_left_path',      type=str,   help='path of validation text feature middle left', required=True)
parser.add_argument('--validation_text_feature_mid_right_path',     type=str,   help='path of validation text feature middle right', required=True)
parser.add_argument('--validation_text_feature_right_path',         type=str,   help='path of validation text feature right', required=True)
parser.add_argument('--validation_weather_condition_path',          type=str,   help='path of validation sample weather conditions', required=True)
parser.add_argument('--validation_ground_truth_path',               type=str,   help='path of validation ground truth', required=True)
parser.add_argument('--validation_radar_path',                      type=str,   help='path of validation radar', required=True)

# Model
parser.add_argument('--k',                                          type=int,   help='k nearest neighbor', default=4)
parser.add_argument('--encoder_radar',                              type=str,   help='type of encoder of radar channels, e.g. resnet18 or resnet34', default='resnet18')
parser.add_argument('--radar_input_channels',                       type=int,   help='number of input radar channels', default=4)
parser.add_argument('--radar_gcn_channel_in',                       type=int,   help='input channels',  default=6)
parser.add_argument('--encoder',                                    type=str,   help='type of encoder', default='resnet34_bts')
# nargs='+' so a list of filter counts can be given (the default is a list).
parser.add_argument('--n_filters_decoder',                          type=int,   help='number of decoder filters', nargs='+', default=[256, 256, 128, 64, 32])
parser.add_argument('--input_height',                               type=int,   help='input height', default=352)
parser.add_argument('--input_width',                                type=int,   help='input width',  default=704)
parser.add_argument('--max_depth',                                  type=float, help='maximum depth in estimation', default=100)

parser.add_argument('--fuse',                                       type=str,   help='fusion of radar & image', default='wafb')
parser.add_argument('--text_fuse',                                  type=str,   help='fusion of text & image', default='cross_attention')
parser.add_argument('--text_encode_mode',                           type=str,   help='text feature processing mode', default='average')
parser.add_argument('--text_hidden_dim',                            type=int,   help='hidden dimension of text feature',  default=128)
parser.add_argument('--use_img_feat',                                           help='if set, during weather classification, also use img feature', action='store_true')
parser.add_argument('--point_hidden_dim',                           type=int,   help='hidden dimension of radar point feature',  default=128)

# Log and save
parser.add_argument('--log_directory',                              type=str,   help='directory to save checkpoints and summaries', default='')
parser.add_argument('--checkpoint_path',                            type=str,   help='path to a checkpoint to load', default='')
parser.add_argument('--log_freq',                                   type=int,   help='Logging frequency in global steps', default=100)
parser.add_argument('--save_freq',                                  type=int,   help='Checkpoint saving frequency in global steps', default=500)

# Training
parser.add_argument('--weight_decay',                               type=float, help='weight decay factor for optimization', default=1e-2)
parser.add_argument('--retrain',                                                help='if used with checkpoint_path, will restart training from step zero', action='store_true')
parser.add_argument('--adam_eps',                                   type=float, help='epsilon in Adam optimizer', default=1e-6)
parser.add_argument('--reg_loss',                                   type=str,   help='loss function for depth regression - l1/l2/silog', default='l1')
parser.add_argument('--w_smoothness',                               type=float, help='Weight of local smoothness loss', default=0.00)

parser.add_argument('--variance_focus',                             type=float, help='lambda in paper: [0, 1], higher value more focus on minimizing variance of error', default=0.85)
parser.add_argument('--batch_size',                                 type=int,   help='batch size', default=4)
parser.add_argument('--num_epochs',                                 type=int,   help='number of epochs', default=100)
parser.add_argument('--learning_rate',                              type=float, help='initial learning rate', default=1e-4)
parser.add_argument('--end_learning_rate',                          type=float, help='end learning rate', default=-1)

# Hardware and evaluation
parser.add_argument('--gpu',                                        type=int,   help='GPU id to use.', default=None)
parser.add_argument('--num_threads',                                type=int,   help='number of threads to use for data loading', default=1)
parser.add_argument('--do_online_eval',                                         help='if set, perform online eval in every eval_freq steps', action='store_true')
parser.add_argument('--min_depth_eval',                             type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval',                             type=float, help='maximum depth for evaluation', default=80)
parser.add_argument('--eval_freq',                                  type=int,   help='Online evaluation frequency in global steps', default=500)
parser.add_argument('--eval_summary_directory',                     type=str,   help='output directory for eval summary, '
                                                                                     'if empty outputs to <log_directory>/eval', default='')


if len(sys.argv) == 2 and not sys.argv[1].startswith('-'):
    # A single non-flag argument is the path of an argument file.
    args = parser.parse_args(['@' + sys.argv[1]])
else:
    args = parser.parse_args()


# Depth metric names, in the order returned by compute_errors().
eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'mae', 'd1', 'd2', 'd3']


def compute_errors(gt, pred):
    """Compute the standard depth-estimation error metrics.

    Args:
        gt: Ground-truth depths. Ratios and logarithms are taken, so values are
            assumed strictly positive (the caller filters invalid pixels).
        pred: Predicted depths with the same shape as ``gt``, also assumed
            positive.

    Returns:
        ``[silog, abs_rel, log10, rms, sq_rel, log_rms, mae, d1, d2, d3]``,
        matching the order of ``eval_metrics``. If the inputs are empty, every
        entry is NaN.
    """
    gt = np.asarray(gt)
    pred = np.asarray(pred)
    if gt.size == 0:
        # Nothing to compare: return NaNs explicitly instead of letting
        # np.mean warn on empty slices.
        return [np.nan] * 10

    thresh = np.maximum((gt / pred), (pred / gt))
    d1 = (thresh < 1.25).mean()
    d2 = (thresh < 1.25 ** 2).mean()
    d3 = (thresh < 1.25 ** 3).mean()

    diff = gt - pred
    mae = np.mean(np.abs(diff))
    rms = np.sqrt(np.mean(diff ** 2))

    log_rms = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))

    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean((diff ** 2) / gt)

    err = np.log(pred) - np.log(gt)
    # Var = E[x^2] - E[x]^2 can come out slightly negative from float
    # cancellation when the log error is (nearly) constant. Clamp it so sqrt
    # does not return NaN.
    variance = max(np.mean(err ** 2) - np.mean(err) ** 2, 0.0)
    silog = np.sqrt(variance) * 100

    log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))

    return [silog, abs_rel, log10, rms, sq_rel, log_rms, mae, d1, d2, d3]

def online_eval(model, dataloader_eval, gpu, post_process=False):
    """Evaluate ``model`` on ``dataloader_eval`` and print classification and depth metrics.

    Reads the module-level ``args`` for ``no_radar``, ``min_depth_eval`` and
    ``max_depth_eval``.

    Args:
        model: Network called as ``model(image, radar_channels, radar_points,
            text_feature_general, text_feature_left, text_feature_mid_left,
            text_feature_mid_right, text_feature_right, text_mask, text_length)``
            and returning ``(pred_depth, class_pred)``. With ``args.no_radar``
            the radar arguments are ``None``.
        dataloader_eval: Object whose ``.data`` iterates over batches given as
            dicts of tensors.
        gpu: Device index forwarded to ``Tensor.cuda``.
        post_process: If True, also run the model on horizontally flipped
            inputs and merge both predictions with ``post_process_depth``.

    Returns:
        A CPU tensor of 11 values. The first 10 are the depth metrics in
        ``eval_metrics`` order, averaged over evaluated batches; the last is
        the batch counter divided by itself. Returns ``None`` if no batch had
        any valid ground-truth pixel.
    """
    correct = total = 0
    correct_rain = total_rain = 0
    correct_night = total_night = 0

    def _to_gpu(batch, key):
        return batch[key].cuda(gpu, non_blocking=True)

    # Running sums of the 10 depth metrics; the last slot counts batches.
    eval_measures = torch.zeros(11).cuda(device=gpu)
    with torch.no_grad():
        for eval_sample_batched in tqdm(dataloader_eval.data):
            image = _to_gpu(eval_sample_batched, 'image')
            text_mask = _to_gpu(eval_sample_batched, 'text_mask')
            text_feature_general = _to_gpu(eval_sample_batched, 'text_feature_general')
            text_feature_left = _to_gpu(eval_sample_batched, 'text_feature_left')
            text_feature_mid_left = _to_gpu(eval_sample_batched, 'text_feature_mid_left')
            text_feature_mid_right = _to_gpu(eval_sample_batched, 'text_feature_mid_right')
            text_feature_right = _to_gpu(eval_sample_batched, 'text_feature_right')
            text_length = _to_gpu(eval_sample_batched, 'text_length')
            if args.no_radar:
                # Bind explicitly so the forward call below never hits an
                # unbound name in the text+image setting.
                radar_channels = None
                radar_points = None
            else:
                radar_channels = _to_gpu(eval_sample_batched, 'radar_channels')
                radar_points = _to_gpu(eval_sample_batched, 'radar_points')

            label = _to_gpu(eval_sample_batched, 'label')
            gt_depth = eval_sample_batched['depth']

            pred_depth, class_pred = model(image, radar_channels, radar_points, text_feature_general,
                                           text_feature_left, text_feature_mid_left, text_feature_mid_right,
                                           text_feature_right, text_mask, text_length)
            if post_process:
                image_flipped = flip_lr(image)
                radar_flipped = flip_lr(radar_channels) if radar_channels is not None else None
                text_mask_flipped = flip_lr(text_mask)
                # NOTE(review): radar_points are passed unflipped, as before —
                # confirm the model does not need them mirrored.
                pred_depth_flipped, _ = model(image_flipped, radar_flipped, radar_points, text_feature_general,
                                              text_feature_left, text_feature_mid_left, text_feature_mid_right,
                                              text_feature_right, text_mask_flipped, text_length)
                pred_depth = post_process_depth(pred_depth, pred_depth_flipped)

            # Class scores and labels are reduced with argmax over dim 1
            # (label is presumably one-hot — TODO confirm).
            _, predicted = class_pred.max(1)
            _, targets = label.max(1)

            total += label.size(0)
            correct += predicted.eq(targets).sum().item()

            # Per-class accuracy for rain (index 2) and night (index 1). The
            # counts are vectorised so they also work for batch sizes > 1.
            is_rain = targets == 2
            total_rain += is_rain.sum().item()
            correct_rain += (predicted[is_rain] == 2).sum().item()

            is_night = targets == 1
            total_night += is_night.sum().item()
            correct_night += (predicted[is_night] == 1).sum().item()

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

            # Clip predictions to the evaluation range and replace inf/NaN.
            pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
            pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
            pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
            pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

            valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
            if not valid_mask.any():
                # No valid ground truth: its metrics would be NaN and poison the sums.
                continue

            measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])

            eval_measures[:-1] += torch.tensor(measures).cuda(device=gpu)
            eval_measures[-1] += 1

    if total != 0:
        print('--------------Classification--------------')
        acc = 100. * correct / total
        # NaN (printed as "nan%") when a class never occurs, instead of ZeroDivisionError.
        acc_rain = 100. * correct_rain / total_rain if total_rain else float('nan')
        acc_night = 100. * correct_night / total_night if total_night else float('nan')

        print('{:>7}, {:>7}, {:>7}'.format('Accuracy', 'Accuracy Rain', 'Accuracy Night'))
        print('{:.3f}%, {:.3f}%, {:.3f}%'.format(acc, acc_rain, acc_night))

    eval_measures_cpu = eval_measures.cpu()
    cnt = eval_measures_cpu[-1].item()
    if cnt == 0:
        # The training loop treats None as "no measurements".
        print('-------------Depth Estimation-------------')
        print('No eval samples with valid ground truth; skipping depth metrics.')
        return None
    eval_measures_cpu /= cnt
    print('-------------Depth Estimation-------------')
    print('Computing errors for {} eval samples'.format(int(cnt)))
    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
                                                                                        'sq_rel', 'log_rms', 'mae', 'd1', 'd2',
                                                                                        'd3'))
    print(', '.join('{:7.3f}'.format(value.item()) for value in eval_measures_cpu[:10]))
    return eval_measures_cpu

def eval(model, dataloader_eval):
    """Evaluate a text-embedding weather classifier and print its accuracies.

    Prints the raw totals and correct counts, then overall accuracy and the
    accuracy restricted to rain samples (class index 2) and to night samples
    (class index 1). An image-only variant used to be toggled here by calling
    ``model(image)`` instead of ``model(text_emb)``.

    Note:
        The name shadows the builtin ``eval`` inside this module. It is kept
        so existing callers keep working.

    Args:
        model: Classifier called as ``model(text_emb)``. Its output is reduced
            with ``max(1)``, so dim 1 is the class dimension.
        dataloader_eval: Object whose ``.data`` iterates over batches given as
            dicts with ``'text_emb'`` and ``'label'`` tensors. ``label`` is
            reduced with argmax over dim 1 (presumably one-hot — TODO confirm).
    """
    correct = total = 0
    correct_rain = total_rain = 0
    correct_night = total_night = 0

    with torch.no_grad():
        for eval_sample_batched in tqdm(dataloader_eval.data):
            text_emb = eval_sample_batched['text_emb'].cuda()
            label = eval_sample_batched['label'].cuda()

            outputs = model(text_emb)

            _, predicted = outputs.max(1)
            _, targets = label.max(1)

            total += label.size(0)
            correct += predicted.eq(targets).sum().item()

            # Per-class counts, vectorised so batch sizes > 1 work.
            is_rain = targets == 2
            total_rain += is_rain.sum().item()
            correct_rain += (predicted[is_rain] == 2).sum().item()

            is_night = targets == 1
            total_night += is_night.sum().item()
            correct_night += (predicted[is_night] == 1).sum().item()

    print(total, total_rain, total_night)
    print(correct, correct_rain, correct_night)

    # NaN (printed as "nan%") instead of ZeroDivisionError when a group is empty.
    acc = 100. * correct / total if total else float('nan')
    acc_rain = 100. * correct_rain / total_rain if total_rain else float('nan')
    acc_night = 100. * correct_night / total_night if total_night else float('nan')

    print('{:>7}, {:>7}, {:>7}'.format('Accuracy', 'Accuracy Rain', 'Accuracy Night'))
    print('{:.3f}%, {:.3f}%, {:.3f}%'.format(acc, acc_rain, acc_night))


def _poly_lr(step, total_steps, start_lr, end_lr, power=0.9):
    """Return the polynomially decayed learning rate for ``step``.

    Decays from ``start_lr`` at step 0 to ``end_lr`` at ``total_steps``.
    Progress is clamped to [0, 1]. Past ``total_steps`` the base
    ``1 - progress`` would be negative, and in Python a negative float raised
    to a fractional power is a complex number.
    """
    progress = min(max(step / total_steps, 0.0), 1.0) if total_steps > 0 else 1.0
    return (start_lr - end_lr) * (1 - progress) ** power + end_lr


def _batch_to_device(sample_batched, device, no_radar):
    """Return the tensors one training step needs, moved to ``device``.

    ``radar_channels`` and ``radar_points`` are ``None`` when ``no_radar`` is
    set, so the forward call never references an unbound name.
    """
    keys = ['image', 'text_feature_general', 'text_feature_left',
            'text_feature_mid_left', 'text_feature_mid_right', 'text_feature_right',
            'text_length', 'text_mask', 'label', 'depth', 'lidar']
    if not no_radar:
        keys += ['radar_channels', 'radar_points']
    batch = {key: sample_batched[key].to(device, non_blocking=True) for key in keys}
    batch.setdefault('radar_channels', None)
    batch.setdefault('radar_points', None)
    return batch


def main_worker(args):
    """Build TRIDE with its optimiser and losses, then train with optional online evaluation.

    Checkpoints and TensorBoard summaries are written under
    ``args.log_directory + '/' + args.model_name``.

    Args:
        args: Parsed command-line namespace. ``args.device`` must also be set;
            ``main`` sets it.
    """
    dataloader = TextDataLoader(args, 'train')
    dataloader_eval = TextDataLoader(args, 'test')

    model = TRIDE(args)
    model.decoder.apply(weights_init_xavier)

    model = torch.nn.DataParallel(model)
    model.cuda()

    num_params = sum(np.prod(p.size()) for p in model.parameters())
    print("Total number of parameters: {}".format(num_params))

    num_params_update = sum(np.prod(p.shape) for p in model.parameters() if p.requires_grad)
    print("Total number of learning parameters: {}".format(num_params_update))

    global_step = 0
    # Best value seen so far per metric. The first 7 metrics (silog ... mae) are
    # errors (lower is better); the last 3 (d1-d3) are accuracies.
    best_eval_measures_lower_better = torch.zeros(7).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(10, dtype=np.int32)

    # The decoder is trained without weight decay.
    optimizer = torch.optim.AdamW([{'params': model.module.image_encoder.parameters(), 'weight_decay': args.weight_decay},
                                   {'params': model.module.radar_encoder.parameters(), 'weight_decay': args.weight_decay},
                                   {'params': model.module.text_encoder.parameters(), 'weight_decay': args.weight_decay},
                                   {'params': model.module.decoder.parameters(), 'weight_decay': 0}],
                                  lr=args.learning_rate, eps=args.adam_eps)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            global_step = checkpoint['global_step']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            try:
                best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
                best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
                best_eval_steps = checkpoint['best_eval_steps']
            except KeyError:
                # Checkpoints saved without online eval carry no best-metric state.
                print("Could not load values for online evaluation")

            print("Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint['global_step']))
        else:
            print("No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True

    if args.retrain:
        global_step = 0

    cudnn.benchmark = True

    save_dir = args.log_directory + '/' + args.model_name
    writer = SummaryWriter(save_dir + '/summaries', flush_secs=30)
    if args.do_online_eval:
        if args.eval_summary_directory != '':
            eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
        else:
            eval_summary_path = os.path.join(args.log_directory, 'eval')
        eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)

    if args.reg_loss == 'silog':
        loss_depth = silog_loss(variance_focus=args.variance_focus)
    elif args.reg_loss == 'l2':
        loss_depth = l2_loss()
    else:
        # default: L1 loss
        loss_depth = l1_loss()

    loss_classification = nn.CrossEntropyLoss()
    loss_smoothness = smoothness_loss_func()

    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate
    device = args.device

    steps_per_epoch = len(dataloader.data)
    if steps_per_epoch == 0:
        raise ValueError('The training dataloader yielded no batches.')
    num_total_steps = args.num_epochs * steps_per_epoch
    start_epoch = global_step // steps_per_epoch
    # Run exactly num_epochs epochs, so global_step stays within the
    # num_total_steps horizon of the LR schedule.
    for epoch in range(start_epoch, args.num_epochs):
        with tqdm(dataloader.data, unit='batch') as tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            for sample_batched in tepoch:
                optimizer.zero_grad()

                batch = _batch_to_device(sample_batched, device, args.no_radar)
                image = batch['image']
                depth_gt = batch['depth']
                single_depth_gt = batch['lidar']

                depth_est, class_pred = model(image, batch['radar_channels'], batch['radar_points'],
                                              batch['text_feature_general'], batch['text_feature_left'],
                                              batch['text_feature_mid_left'], batch['text_feature_mid_right'],
                                              batch['text_feature_right'], batch['text_mask'], batch['text_length'])
                loss_c = loss_classification(class_pred, batch['label'])

                # Depth loss. Pixels where the single-lidar depth is valid are
                # supervised against it; the remaining pixels with valid `depth`
                # are supervised against `depth`.
                mask_single = single_depth_gt > 0.01
                mask = torch.logical_and(depth_gt > 0.01, ~mask_single)
                loss_d = loss_depth.forward(depth_est, depth_gt, mask) + \
                         loss_depth.forward(depth_est, single_depth_gt, mask_single)

                if args.w_smoothness > 0.00:
                    loss_s = loss_smoothness.forward(depth_est, image) * args.w_smoothness
                else:
                    loss_s = 0.0

                loss = loss_d + loss_s + loss_c
                loss.backward()

                current_lr = _poly_lr(global_step, num_total_steps, args.learning_rate, end_learning_rate)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr

                optimizer.step()

                if not args.do_online_eval and global_step and global_step % args.save_freq == 0:
                    checkpoint = {'global_step': global_step,
                                  'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict()}
                    torch.save(checkpoint, save_dir + '/model-{}'.format(global_step))

                if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                    time.sleep(0.1)
                    model.eval()
                    eval_measures = online_eval(model, dataloader_eval, args.gpu, post_process=True)

                    print(f'model: {args.model_name}')
                    if eval_measures is not None:
                        for i, metric in enumerate(eval_metrics):
                            measure = eval_measures[i]
                            eval_summary_writer.add_scalar(metric, measure.cpu(), int(global_step))
                            if i < 7:
                                best_store, idx = best_eval_measures_lower_better, i
                                is_best = bool(measure < best_store[idx])
                            else:
                                best_store, idx = best_eval_measures_higher_better, i - 7
                                is_best = bool(measure > best_store[idx])
                            if not is_best:
                                continue

                            old_best = best_store[idx].item()
                            best_store[idx] = measure.item()

                            # Replace the previous best checkpoint for this metric.
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(best_eval_steps[i], metric, old_best)
                            old_best_path = save_dir + old_best_name
                            if os.path.exists(old_best_path):
                                os.remove(old_best_path)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, metric, measure.item())
                            print('New best for {}. Saving model: {}'.format(metric, model_save_name))
                            checkpoint = {'global_step': global_step,
                                          'model': model.state_dict(),
                                          'optimizer': optimizer.state_dict(),
                                          'best_eval_measures_higher_better': best_eval_measures_higher_better,
                                          'best_eval_measures_lower_better': best_eval_measures_lower_better,
                                          'best_eval_steps': best_eval_steps
                                          }
                            torch.save(checkpoint, save_dir + model_save_name)
                        eval_summary_writer.flush()
                    model.train()

                model_just_loaded = False
                global_step += 1

                tepoch.set_postfix(loss=loss.item(), depth=loss_d.item(), classification=loss_c.item())

    writer.close()
    if args.do_online_eval:
        eval_summary_writer.close()
        

def main():
    """Entry point: create the run directory, archive the argument file, and train.

    Reads the module-level ``args``. It prefixes ``args.model_name`` with a
    timestamp and sets ``args.device`` before calling ``main_worker``.

    Returns:
        ``-1`` when ``--mode`` is not ``'train'`` (this script only trains);
        otherwise ``None`` once ``main_worker`` returns.
    """
    torch.manual_seed(42)

    if args.mode != 'train':
        print('main.py is only for training. Use test.py instead.')
        return -1

    # Timestamp the run name so every run gets its own output directory.
    runtime = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    args.model_name = runtime + '_' + args.model_name
    run_dir = args.log_directory + '/' + args.model_name
    os.makedirs(run_dir, exist_ok=True)

    # Keep a copy of the argument file next to the run's outputs. This is only
    # done when the script was launched with one existing file path. Copying
    # to the basename means the file's parent folders need not exist in run_dir.
    if len(sys.argv) == 2 and os.path.isfile(sys.argv[1]):
        shutil.copy(sys.argv[1], os.path.join(run_dir, os.path.basename(sys.argv[1])))

    torch.cuda.empty_cache()
    if args.do_online_eval:
        print("You have specified --do_online_eval.")
        print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
              .format(args.eval_freq))
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    main_worker(args)


if __name__ == '__main__':
    # Propagate main()'s -1 as a non-zero exit status.
    sys.exit(main())