from tqdm import tqdm
import torch

from models.base import load_model

from loaders.image_loader import load_images
from models.base import load_model
from models.position import load_positional_embedding
from models.base.model_helper import ModelHelper
from loss.task_loss import load_task_loss

@torch.no_grad()
def test_ae(positional_embedding_dict, ae, test_model_dict, task_loss, device):
    """Evaluate autoencoder-reconstructed weights on every registered test task.

    For each entry of ``test_model_dict`` this builds the target model, loads its
    checkpoint through ``ModelHelper``, enables clustering from the task's
    cluster config, replaces the model's learnable weights with the output of
    ``ae.predict_all`` and measures loss/accuracy on the task's test split.

    Args:
        positional_embedding_dict: Mapping indexed by the same keys as
            ``test_model_dict``; each value is passed as the first argument
            of ``ae.predict_all``.
        ae: Autoencoder; must provide ``eval()`` and
            ``predict_all(positional_embeddings, learnable_weights)``.
        test_model_dict: Mapping ``"<model_name>&<num_classes>" -> task_info``.
            ``task_info`` must contain the keys ``'test_model_path'``,
            ``'cluster_cfg_path'``, ``'task_name'`` and ``'data_dir'``.
        task_loss: Callable ``(output, target)`` returning a scalar tensor.
        device: Device the model and each test batch are moved to.

    Returns:
        ``(test_losses, test_accuracies)``: two dicts keyed by
        ``"<model_name>_<task_name>"`` holding the per-sample loss and the
        accuracy in percent.
    """
    ae.eval()

    test_losses = {}
    test_accuracies = {}

    # Both the key (model name & class count) and its task_info are needed,
    # so iterate over key/value pairs rather than over the keys alone.
    for model_classnum, task_info in test_model_dict.items():
        positional_embeddings = positional_embedding_dict[model_classnum]
        model_name, num_classes = model_classnum.split('&')

        test_model_path = task_info['test_model_path']
        cluster_cfg_path = task_info['cluster_cfg_path']
        task_name = task_info['task_name']
        data_dir = task_info['data_dir']

        tgt_exp_name = f"{model_name}_{task_name}"

        test_loader = load_images(data_dir, task_name, data_type='test', batch_size=128)

        original_model = load_model(model_name, num_classes=int(num_classes)).to(device)
        original_model.eval()

        model_helper = ModelHelper(original_model)
        model_helper.load(test_model_path, device)
        model_helper.set_cluster(True, cluster_cfg_path)

        print(f'\n Starting eval {test_model_path} on test set.')
        # Swap the checkpoint's learnable weights for the autoencoder's reconstruction.
        learnable_weights = model_helper.get_learnable_weights()
        reconstructed_weights = ae.predict_all(positional_embeddings, learnable_weights)
        model_helper.update_weights(reconstructed_weights)

        test_loss, correct = _evaluate_loader(model_helper.model, test_loader, task_loss, device)

        num_samples = len(test_loader.dataset)
        # NOTE(review): the per-batch losses are summed and then divided by the
        # sample count; this is a true per-sample average only if task_loss uses
        # reduction='sum' — confirm against load_task_loss.
        test_loss /= num_samples
        accuracy = 100. * correct / num_samples

        test_losses[tgt_exp_name] = test_loss
        test_accuracies[tgt_exp_name] = accuracy

        print('Test set on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            tgt_exp_name,
            test_loss, correct, num_samples,
            accuracy))

    return test_losses, test_accuracies


@torch.no_grad()
def _evaluate_loader(model, loader, task_loss, device):
    """Run ``model`` over ``loader``; return (sum of per-batch losses, #correct predictions)."""
    loss_sum = 0
    correct = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss_sum += task_loss(output, target).item()
        # Predicted class = arg-max over dim 1; presumably target holds class
        # indices (one per sample) — TODO confirm for every task.
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    return loss_sum, correct