from utils.metrics import  masked_voc, compute_contrastive_clip_reward, compute_goal_baseline_clip_reward
import torch

def evaluate(
    val_loader,
    model,
    ttt_lr_eval,
    ttt_epochs_eval,
    device,
    baseline_clip,
    window_size=8,
    reset=False,
    shuffling_online=False,
):
    """
    Evaluate the model on the validation data and average the per-trajectory metrics.

    For every trajectory, predictions are produced by each inference method
    (no TTT, online TTT, windowed TTT, offline TTT and three CLIP baselines)
    and scored with ``masked_voc`` against the trajectory's validity mask.

    Only the first sample of each batch is evaluated (the code indexes ``[0]``
    on ``frames``, ``goal_text`` and ``valid_mask``).
    NOTE(review): presumably ``val_loader`` uses batch size 1 -- confirm.

    Args:
        val_loader (DataLoader): The validation data loader. Each item must unpack
            into ``(frames, goal_text, progress_label, valid_mask)``.
        model (torch.nn.Module): The model to evaluate. It is switched to eval
            mode and is not switched back afterwards.
        ttt_lr_eval (float): Learning rate for test-time training.
        ttt_epochs_eval (int): Number of epochs for test-time training.
        device (torch.device): The device to run the evaluation on (CPU/GPU).
        baseline_clip (str): The baseline clip for the CLIP-based reward computation.
        window_size (int): The size of the window for the "window" and
            "window_res" methods. The "online" variants always use a window of 1.
        reset (bool): Currently unused; both the reset and non-reset variants
            are always evaluated. Kept for interface compatibility.
        shuffling_online (bool): Currently unused. Kept for interface compatibility.

    Returns:
        dict: Maps ``"val/{method}_{metric}"`` to the mean of that metric over all
        evaluated trajectories. Metrics with no recorded values are omitted, so
        only the ``voc`` entries appear at present.
    """
    model.eval()

    # Order here defines the key order of the returned dictionary.
    method_names = (
        "online",
        "online_res",
        "window",
        "window_res",
        "offline",
        "no_ttt",
        "clip",
        "clip_reg",
        "clip_cont",
    )
    metric_names = ("boot_voc", "adj_ci_high", "adj_ci_low", "voc")

    # One list of per-trajectory values for every (method, metric) pair.
    metrics = {
        method: {metric: [] for metric in metric_names}
        for method in method_names
    }

    ttt_kwargs = {"ttt_lr": ttt_lr_eval, "ttt_epochs": ttt_epochs_eval}

    # NOTE(review): the TTT inference methods are called under no_grad();
    # presumably they re-enable gradients internally -- confirm.
    with torch.no_grad():
        for frames, goal_text, _progress_label, valid_mask in val_loader:
            # Number of valid frames in the first trajectory of the batch.
            valid_len = valid_mask[0].sum().long().item()
            if valid_len == 0:
                continue  # Skip empty trajectories

            # Slice to the valid portion only and restore the batch dimension.
            frames_seq = frames[0, :valid_len].unsqueeze(0).to(device)
            goal = [goal_text[0]]
            mask_seq = valid_mask[0, :valid_len].unsqueeze(0).to(device)

            # The calls below keep their original order on purpose: methods run
            # with reset=False may leave state on the model that later calls see.
            predictions = {}
            predictions["no_ttt"] = model.inference_no_ttt(frames_seq, goal)
            predictions["online"] = model.windowed_ttt_inference(
                frames_seq, goal, window_size=1, reset=False, **ttt_kwargs
            )
            predictions["online_res"] = model.windowed_ttt_inference(
                frames_seq, goal, window_size=1, reset=True, **ttt_kwargs
            )
            # Honour the caller-supplied window size (previously hard-coded to 8).
            predictions["window"] = model.windowed_ttt_inference(
                frames_seq, goal, window_size=window_size, reset=False, **ttt_kwargs
            )
            predictions["window_res"] = model.windowed_ttt_inference(
                frames_seq, goal, window_size=window_size, reset=True, **ttt_kwargs
            )
            predictions["offline"] = model.offline_ttt_inference(
                frames_seq, goal, **ttt_kwargs
            )
            predictions["clip"] = model.compute_clip_similarity_score(
                frames_seq, goal
            )[:, :valid_len]
            predictions["clip_reg"] = compute_goal_baseline_clip_reward(
                model.clip_feature_extractor, frames_seq, goal, baseline_clip, alpha=0.5
            )[:, :valid_len]
            predictions["clip_cont"] = compute_contrastive_clip_reward(
                model, frames_seq, goal, baseline_clip, tau=0.01, beta=0.5
            )

            # Score every method that produced predictions.
            for name in method_names:
                preds = predictions[name]
                if preds is not None:
                    metrics[name]["voc"].append(masked_voc(preds, mask_seq))

    # Average each populated metric; empty lists are skipped to avoid dividing by zero.
    averaged = {
        f"val/{method}_{metric}": float(sum(values) / len(values))
        for method, m_dict in metrics.items()
        for metric, values in m_dict.items()
        if values
    }

    return averaged
