import argparse
import os

import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader

from utils.model import get_dpp_model 
from model.diversity_loss import ref_diversity_loss  
from dpp_dataset import Dataset  
from utils.dpp_tools import to_device, log, get_random_phrase 

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def evaluate(text_encoder, sdp, spp, dpp_model, dpp_model_p, step, configs,
             logger=None, *, val_file="val4.txt"):
    """Compute the mean validation diversity loss of both DPP models.

    Builds a ``Dataset`` over ``val_file`` and iterates it without gradient
    tracking. For every sub-batch it samples phrases with
    ``get_random_phrase`` and runs both models on the same phrases:
    ``dpp_model`` with ``sdp`` and ``dpp_model_p`` with ``spp``. The two
    ``ref_diversity_loss`` values are each multiplied by ``len(batch[0])``
    and accumulated. The total is divided by ``len(dataset)``.

    NOTE(review): no model is switched to eval mode here; the caller is
    presumably responsible for ``.eval()`` / ``.train()`` — confirm.
    NOTE(review): phrase selection goes through ``get_random_phrase``; if it
    samples randomly, the reported loss will vary between calls.

    Args:
        text_encoder: Forwarded to both DPP models.
        sdp: Forwarded to ``dpp_model`` (presumably a duration predictor —
            confirm).
        spp: Forwarded to ``dpp_model_p`` (presumably a pitch predictor —
            confirm).
        dpp_model: Callable returning a 6-tuple whose first two entries are
            used as the kernel and the kernel mask.
        dpp_model_p: Second model with the same call/return convention.
        step: Step number shown in the message and forwarded to ``log``.
        configs: ``(preprocess_config, model_config, train_config)``. Only
            the preprocess and train configs are used here. The batch size
            is read from ``train_config["optimizer"]["batch_size"]``.
        logger: If not ``None``, the mean loss is recorded with
            ``log(logger, step, loss=...)``.
        val_file: Validation metadata file name passed to ``Dataset``.

    Returns:
        str: ``"Validation Step <step>, Loss: <mean loss to 4 decimals>"``.

    Raises:
        ValueError: If the validation dataset contains no items.
    """
    preprocess_config, _, train_config = configs

    dataset = Dataset(
        val_file, preprocess_config, train_config, sort=False, drop_last=False)
    if len(dataset) == 0:
        # Guard the per-item mean below against a 0 / 0 division.
        raise ValueError(f"validation dataset {val_file!r} is empty")

    loader = DataLoader(
        dataset,
        batch_size=train_config["optimizer"]["batch_size"],
        shuffle=False,
        collate_fn=dataset.collate_fn)

    d_loss = ref_diversity_loss()

    loss_sum = 0
    with torch.no_grad():
        # Each loader step yields a collection of sub-batches (via collate_fn).
        for sub_batches in loader:
            for batch in sub_batches:
                batch = to_device(batch, device)
                # Assumes batch[0] holds one entry per item — confirm.
                n_items = len(batch[0])

                # Both models are evaluated on the same sampled phrases.
                random_phrases = get_random_phrase(batch[-3], batch[-2], batch[-1])
                kernel, kernel_mask, *_ = dpp_model(
                    *random_phrases, text_encoder, sdp, batch[3], batch[4], batch[5])
                pkernel, pkernel_mask, *_ = dpp_model_p(
                    *random_phrases, text_encoder, spp, batch[3], batch[4], batch[5])

                # Weight by batch size. This yields a per-item mean after
                # dividing by len(dataset), assuming d_loss returns a
                # per-batch mean — confirm.
                loss1 = d_loss(kernel, kernel_mask, num_cw=2) * n_items
                loss2 = d_loss(pkernel, pkernel_mask, num_cw=2) * n_items
                loss_sum += (loss1 + loss2).cpu()

    loss_mean = loss_sum / len(dataset)
    message = "Validation Step {}, Loss: {:.4f}".format(step, loss_mean)
    if logger is not None:
        log(logger, step, loss=loss_mean)

    return message