from functools import partial
from utils.utils import worker_init_fn
import numpy as np
import torch
from nets.model_zoo.ReCOT.ReCOT import ReCOT as ReCOT
from utils.utils_acc import Acc_calculate
from utils.dataloader import Dataset, dataset_collate
from torch.utils.data import DataLoader
    
    
def filter_matching_state_dict(model_dict, pretrained_dict):
    """Select the checkpoint entries that can be loaded into a model.

    An entry is kept only when its key exists in ``model_dict`` and its
    shape equals the shape of the model's entry under that key; every other
    checkpoint entry is rejected.

    Args:
        model_dict: The model's current ``state_dict()``.
        pretrained_dict: The state dict read from a checkpoint.

    Returns:
        A tuple ``(matched, load_key, no_load_key)``: ``matched`` maps each
        accepted key to its checkpoint value, and ``load_key`` /
        ``no_load_key`` list the accepted / rejected keys in checkpoint
        iteration order.
    """
    matched, load_key, no_load_key = {}, [], []
    for k, v in pretrained_dict.items():
        # np.shape reads ``.shape`` directly for tensors and still works for
        # non-tensor checkpoint values (e.g. scalars), which then mismatch.
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v):
            matched[k] = v
            load_key.append(k)
        else:
            no_load_key.append(k)
    return matched, load_key, no_load_key


def main():
    """Evaluate a trained ReCOT checkpoint on a CVOGL test split.

    Builds the test DataLoader, constructs the model, loads every checkpoint
    entry whose key and shape match the model, and runs ``Acc_calculate``.
    The settings below are placeholders to be filled in before running.
    """
    # settings ------------------------------------------------------------ #
    dataset_cate    = ''
    batch_size      = 1
    recurrent_steps = 3
    model_path      = ""
    backbone        = "swin_v2_t"
    experiment_root_path = ''
    dataset_root         = experiment_root_path + 'CVOGL_' + dataset_cate
    test_data_path       = dataset_root + '/CVOGL_' + dataset_cate + '_test.pth'
    input_shape_reimg = [1024, 1024]
    # The DroneAerial split uses square query images; other splits use 256x512.
    input_shape_qimg  = [256, 256] if dataset_cate == 'DroneAerial' else [256, 512]
    num_classes     = 1
    num_workers     = 8

    # Pick the device once and use it consistently for the checkpoint, the
    # model, and pinned-memory loading (the script must not assume CUDA).
    device   = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'

    # loading -------------------------------------------------------------- #
    print('\nloading settings...')
    print('loading dataset...')
    test_lines   = torch.load(test_data_path)
    dataset      = Dataset(dataset_root, test_lines, input_shape_reimg, input_shape_qimg, num_classes, train = False)
    test_loader  = DataLoader(dataset, shuffle = False, batch_size = batch_size, num_workers = num_workers,
                              pin_memory=use_cuda,  # pinning only helps host->CUDA copies
                              drop_last=False, collate_fn=dataset_collate, sampler=None,
                              worker_init_fn=partial(worker_init_fn, rank=0, seed=11))
    print('Dataset done!\n')

    print('loading model...')
    model = ReCOT(dataset_cate, backbone, 'sine', 256, num_classes, 100, input_shape_qimg, recurrent_steps, pretrained=True)
    print('Load weights {}.'.format(model_path))

    model_dict      = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location = device)
    temp_dict, load_key, no_load_key = filter_matching_state_dict(model_dict, pretrained_dict)
    model_dict.update(temp_dict)
    model.load_state_dict(model_dict)
    print("\nSuccessful Load Key:", str(load_key)[:500], "\nSuccessful Load Key Num:", len(load_key))
    print("\nFail To Load Key:", str(no_load_key)[:500], "\nFail To Load Key num:", len(no_load_key))
    if not load_key:
        # Nothing matched: the evaluation below would not reflect the checkpoint.
        print("\nWARNING: no checkpoint keys were loaded; check model_path / backbone settings.")
    # Previously an unconditional model.cuda(), which crashed on CPU-only hosts.
    # NOTE(review): Acc_calculate may itself assume CUDA inputs — confirm.
    model = model.to(device)
    model = model.eval()
    print('Model done!\n')

    # Evaluating ---------------------------------------------------------- #
    evalGEO = Acc_calculate(model, test_loader, input_shape_reimg, batch_size)
    evalGEO.calculate()
    print('Evaluation done!')


if __name__ == "__main__":
    main()
