import torch
import torch.nn as nn
from tqdm import tqdm
from easydict import EasyDict as edict

from src.models.utils import cal_3d_loss
from src.data_loader.utils import get_data, get_train_val_split
from src.data_loader.data_set import Data_Set
from src.experiments.evaluation_utils import evaluate
from src.models.rn_25D_wMLPref import RN_25D_wMLPref
import argparse

from src.constants import (
    COMET_KWARGS,
    HYBRID2_CONFIG,
    BASE_DIR,
    TRAINING_CONFIG_PATH,
)
from src.utils import get_console_logger, read_json

def main():
    """Evaluate a saved RN_25D_wMLPref checkpoint on the "freihand" test split.

    The checkpoint path is taken from the required ``--model`` command-line
    argument. The weights are loaded into an ``rn152``-backed model, the
    dataset is built from the JSON training configuration at
    ``TRAINING_CONFIG_PATH``, and the value returned by ``evaluate`` is
    printed to stdout.
    """
    parser = argparse.ArgumentParser(description='fine tuning')
    parser.add_argument('--model', required=True, help='model.pth')
    args = parser.parse_args()

    batch_size = 64
    num_workers = 4

    model = RN_25D_wMLPref(backend_model='rn152')
    # map_location="cpu" lets a checkpoint that was saved from a GPU be
    # deserialized on a machine without CUDA; load_state_dict copies the
    # values into the (CPU-constructed) model parameters either way.
    state_dict = torch.load(args.model, map_location="cpu")
    model.load_state_dict(state_dict)
    # Switch dropout / batch-norm layers to inference behaviour. This is
    # idempotent, so it is harmless if evaluate() also does it.
    model.eval()

    train_param = edict(read_json(TRAINING_CONFIG_PATH))
    # NOTE(review): presumably pushes (almost) every sample into a single
    # split so nothing is held out -- confirm against get_data.
    train_param['train_ratio'] = 0.999999999

    # Data preparation. train_param is already an EasyDict, so it is passed
    # through directly instead of being wrapped a second time.
    data = get_data(
        Data_Set,
        train_param,
        sources=["freihand"],
        experiment_type="supervised",
        split='test',
    )

    # Run the evaluation and report the result.
    result = evaluate(
        model, data, batch_size=batch_size, num_workers=num_workers
    )
    print(result)
    
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not parse command-line arguments or load a model.
if __name__ == "__main__":
    main()