import tool

import torch
import logging

# 1. Read & Set Parameters
# Parse CLI options and resolve the torch device that model/data will use.
args = tool.parse_arg()
device = torch.device(args.device)

# Seed first, then configure logging. The run is identified to setup_logging by
# task, save path, model, fine-tuning type, training method and seed.
tool.set_seed(args.random_seed)
tool.setup_logging(
    args.task_name,
    args.save_path,
    args.model_name,
    args.ft_type,
    args.tr_method,
    args.random_seed,
)

# Record the full argument namespace in the log so the run's settings are kept with its output.
arg_dict = vars(args)
logging.warning('\n <ARGUMENTS>')
logging.warning(arg_dict)

# 2. Load Dataset & Metric
# load_data returns the three split loaders plus the metric object used by test_model.
(
    train_dataloader,
    eval_dataloader,
    test_dataloader,
    metric,
) = tool.load_data(args.task_name, args.model_name, args.batch_size)

# 3. Load Model & Optimizer
# NOTE: load_model takes (model_name, task_name), the reverse order of load_data.
model = tool.load_model(args.model_name, args.task_name, args.ft_type)
model = model.to(device)
optimizer = tool.load_optimizer(model, args.optimizer_type, args.optimizer_lr)

# 4. train model
logging.warning('\n <TRAIN>')
tool.train_model(model, optimizer, args, train_dataloader, device)

# Report CUDA allocator statistics after training.
# - Peak memory is, by default, measured since program start, so it also covers
#   data/model loading.
# - These counters only describe CUDA devices, so skip them for any other device
#   type instead of logging a meaningless 0 or raising.
# - The inner literals use double quotes: reusing the outer quote character
#   inside an f-string is a SyntaxError before Python 3.12 (PEP 701).
if device.type == 'cuda':
    peak_mem = torch.cuda.max_memory_allocated(device=device)
    logging.warning(f'- {"peak mem":10s}: {peak_mem} bytes')
    allo_mem = torch.cuda.memory_allocated(device=device)
    logging.warning(f'- {"allo mem":10s}: {allo_mem} bytes')
else:
    logging.warning(f'- {"mem stats":10s}: skipped (device={device})')

# 5. evaluate model
# Score the trained model on every split with the same metric object.
logging.warning('\n <TEST>')
train_metric = tool.test_model(model, metric, train_dataloader, device)
eval_metric = tool.test_model(model, metric, eval_dataloader, device)
test_metric = tool.test_model(model, metric, test_dataloader, device)

# One log line per metric key: "train | eval | test", each value scaled to a percentage.
# Keys are taken from the train results; eval/test are assumed to share them.
for k in train_metric.keys():
    row = ' | '.join(
        f'{split[k] * 100.0:5.2f}%'
        for split in (train_metric, eval_metric, test_metric)
    )
    logging.warning(f'- {k:10s}: {row}')

