'''
Evaluate regression-type property predictors: report the mean squared error on the configured test file.
'''
import logging
# Suppress FAISS loading messages
logging.getLogger('faiss.loader').setLevel(logging.WARNING)
# Remaining imports follow the logger configuration above so it is already in effect when they load
import hydra
import os
import torch
import torch.nn.functional as F
import pandas as pd
from functools import partial

from multiguide.training.bucket_batch_ddp_sampler import DistributedBucketBatchSampler
from multiguide.dataset.molecule_dataset import MoleculeDataset
from multiguide.training.helpers import set_property_predictor
from multiguide.helpers import PROJECT_ROOT
from multiguide.training.helpers import collate_fn
from torch.utils.data import DataLoader

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_test_dataloader(config):
    '''
    Build the DataLoader used to evaluate a property predictor.

    The test CSV is read from
    PROJECT_ROOT/data/predictors/<property>/<dataset_name>/<test_file>,
    sliced to the rows [start_idx:end_idx], and wrapped in a MoleculeDataset
    built from its 'rxn', 'property' and 'full_length' columns. Batches are
    produced by a DistributedBucketBatchSampler configured as a single
    replica (rank 0 of 1) with shuffling enabled.

    Args:
        config: Hydra/OmegaConf-style config; the values are read from
            config.classifier_guidance and config.classifier_guidance.dataset.

    Returns:
        A torch DataLoader over the test dataset.
    '''
    guidance_cfg = config.classifier_guidance
    data_cfg = guidance_cfg.dataset

    csv_path = os.path.join(
        PROJECT_ROOT,
        'data',
        'predictors',
        guidance_cfg.property,
        str(data_cfg.dataset_name),
        data_cfg.test_file,
    )
    frame = pd.read_csv(csv_path)
    # restrict evaluation to the configured row window
    frame = frame.iloc[data_cfg.start_idx:data_cfg.end_idx]
    print(f'len of test df {data_cfg.test_file}: {len(frame)}')

    dataset = MoleculeDataset(
        frame['rxn'].to_list(),
        frame['property'].to_list(),
        frame['full_length'].to_list(),
        config,
    )

    sampler = DistributedBucketBatchSampler(
        dataset=dataset,
        batch_sizes=data_cfg.batch_sizes,  # one batch size per bucket
        num_buckets=data_cfg.num_buckets,
        shuffle=True,
        rank=0,
        num_replicas=1,
    )

    # prefetch_factor is only passed when worker processes are used
    worker_count = data_cfg.num_workers
    if worker_count > 0:
        prefetch = data_cfg.prefetch_factor
    else:
        prefetch = None

    return DataLoader(
        dataset=dataset,
        batch_sampler=sampler,
        collate_fn=partial(collate_fn, pad_idx=dataset.pad_idx),
        prefetch_factor=prefetch,
        num_workers=worker_count,
        pin_memory=True,
    )

@hydra.main(config_path='../configs', config_name='config.yaml')
def evaluate_product_classifier(config):
    '''
    Evaluate a regression property predictor on the test set and print its MSE.

    The predictor and its checkpoint are obtained from set_property_predictor.
    Predictions are un-normalized with the checkpoint's 'target_std' and
    'target_mean' before being compared with the targets.

    The reported value is the mean squared error over all evaluated target
    values (total squared error divided by the number of values). Averaging
    per-batch means instead would weight small batches more heavily, and the
    bucket sampler is configured with a separate batch size per bucket.

    Args:
        config: Hydra config, injected by the @hydra.main decorator.

    Raises:
        ValueError: if the test loader yields no values to evaluate.
    '''
    property_predictor, property_predictor_checkpoint = set_property_predictor(config, return_checkpoint=True)
    property_predictor.eval()
    test_loader = get_test_dataloader(config)

    # Loop-invariant lookups hoisted out of the batch loop.
    target_std = property_predictor_checkpoint['target_std']
    target_mean = property_predictor_checkpoint['target_mean']
    num_batches = len(test_loader)

    squared_error_sum = 0.0
    num_values = 0
    with torch.no_grad():
        for i, (seq_batch, value_batch, _, _) in enumerate(test_loader):
            print(f'batch {i} of {num_batches}')
            seq_batch = seq_batch.to(device)
            value_batch = value_batch.to(device).float()
            regression_scores = property_predictor(seq_batch)
            # compute the error after unnormalizing
            regression_scores = regression_scores * target_std + target_mean
            # Guard against silent broadcasting, e.g. (B, 1) predictions vs
            # (B,) targets, which would otherwise compare every prediction
            # with every target.
            if (regression_scores.shape != value_batch.shape
                    and regression_scores.numel() == value_batch.numel()):
                regression_scores = regression_scores.reshape(value_batch.shape)
            squared_errors = (regression_scores - value_batch) ** 2
            squared_error_sum += squared_errors.sum().item()
            num_values += squared_errors.numel()

    if num_values == 0:
        raise ValueError('test loader produced no values to evaluate; '
                         'check the test file and the start_idx/end_idx settings')

    print(f'average mse: {squared_error_sum / num_values}')

# Script entry point: called without arguments because the @hydra.main
# decorator supplies the config.
if __name__ == "__main__":
    evaluate_product_classifier()