import hydra
import os
import torch
import torch.nn.functional as F
from multiguide.training.helpers import set_property_predictor
from multiguide.dataset.helpers import turn_seq_to_ids, get_batch

# Run evaluation on the default CUDA device when one is available, otherwise on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

@hydra.main(config_path='../configs', config_name='config.yaml')
def evaluate_product_classifier(config):
    """Score the product classifier on the configured evaluation batch.

    For every ``(_, product, class_idx)`` triple yielded by ``get_batch(config)``:
    - convert the product sequence to token ids;
    - score the ids with the model returned by ``set_property_predictor``;
    - compare the arg-max class of the softmax against ``class_idx``.

    Args:
        config: Config composed by Hydra from ``../configs/config.yaml``.
            ``classifier_guidance.checkpoint_path`` may be filled in from
            ``classifier_guidance.product_classifier_checkpoint_path``, which
            mutates the config in place.

    Returns:
        A dict with two entries, both also printed:
        - ``'average_confidence'``: mean of the max softmax probability.
        - ``'accuracy'``: fraction of examples whose predicted class equals
          ``class_idx``.

    Raises:
        ValueError: If the batch yields no examples.
    """
    batch = get_batch(config)
    true_classes = []
    predicted_classes = []
    predicted_confidences = []

    # NOTE: set_property_predictor loads from classifier_guidance.checkpoint_path,
    # not classifier_guidance.product_classifier_checkpoint_path, so copy the
    # latter over when only it is set.
    if config.classifier_guidance.checkpoint_path is None and config.classifier_guidance.product_classifier_checkpoint_path is not None:
        config.classifier_guidance.checkpoint_path = config.classifier_guidance.product_classifier_checkpoint_path

    # Put the model on the same device as the inputs built below. Module.to()
    # returns the module itself and is a no-op if it is already on `device`.
    product_classifier = set_property_predictor(config).to(device)
    product_classifier.eval()

    with torch.no_grad():
        for _, product, class_idx in batch:
            # unsqueeze(0) adds a leading batch dimension of size 1.
            input_ids = turn_seq_to_ids(config, product).unsqueeze(0).to(device)
            classifier_scores = product_classifier(input_ids)
            # Highest class probability and its index for the single example.
            confidence, predicted = F.softmax(classifier_scores, dim=1).max(dim=1)
            true_classes.append(class_idx)
            predicted_classes.append(int(predicted.item()))
            predicted_confidences.append(confidence.item())

    # get_batch may return a generator, so emptiness can only be checked after iterating.
    if not true_classes:
        raise ValueError('evaluate_product_classifier: get_batch returned no examples to evaluate')

    average_confidence = sum(predicted_confidences) / len(predicted_confidences)
    accuracy = sum(expected == got for expected, got in zip(true_classes, predicted_classes)) / len(true_classes)
    print(f'average confidence: {average_confidence}')
    print(f'accuracy: {accuracy}')
    return {'average_confidence': average_confidence, 'accuracy': accuracy}

if __name__ == '__main__':
    # No argument is passed: the @hydra.main decorator composes `config` from
    # ../configs/config.yaml and supplies it to the function.
    evaluate_product_classifier()