from time import time

import numpy as np
import torch
from sklearn.metrics import mean_absolute_percentage_error, r2_score

# Default torch device string: "cuda" when a CUDA device is visible, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def _batched_predict(model, inputs_scaled, batch_size):
    """Run ``model.predict`` over ``inputs_scaled`` in row chunks of ``batch_size``.

    Returns ``(predictions, predictions_before_projection, mean_iterations,
    mean_seconds)`` where the two means are taken over the batches.
    """
    from time import perf_counter  # monotonic, high-resolution duration timer

    predictions, before_projection = [], []
    total_iterations = 0
    total_seconds = 0.0
    for start in range(0, inputs_scaled.shape[0], batch_size):
        batch = inputs_scaled[start : start + batch_size, :]
        on_cuda = getattr(batch, "is_cuda", False)
        # CUDA kernels are launched asynchronously; synchronize before and after
        # so the timer brackets the actual computation, not just the launches.
        if on_cuda:
            torch.cuda.synchronize()
        tic = perf_counter()
        batch_pred, batch_before, batch_iterations = model.predict(batch)
        if on_cuda:
            torch.cuda.synchronize()
        total_seconds += perf_counter() - tic
        predictions.append(batch_pred)
        before_projection.append(batch_before)
        total_iterations += batch_iterations
    n_batches = len(predictions)
    return (
        torch.cat(predictions, dim=0),
        torch.cat(before_projection, dim=0),
        total_iterations / n_batches,
        total_seconds / n_batches,
    )


def _per_output_metrics(targets, predictions):
    """Return R^2 / MSE / MAPE for every column, keyed 'r2_y{k}', 'mse_y{k}', 'mape_y{k}' (k 1-based)."""
    metrics = {}
    for j in range(targets.shape[1]):
        y_true, y_pred = targets[:, j], predictions[:, j]
        metrics[f"r2_y{j + 1}"] = r2_score(y_true, y_pred)
        metrics[f"mse_y{j + 1}"] = np.mean((y_true - y_pred) ** 2)
        metrics[f"mape_y{j + 1}"] = mean_absolute_percentage_error(y_true, y_pred)
    return metrics


def _flatten_residuals(residuals):
    """Return constraint residuals as a single 1-D tensor.

    A tuple/list of tensors is flattened part by part before concatenation, so
    parts with different shapes (e.g. different numbers of constraints) can be
    combined. Mean/max over all entries are unaffected by the flattening.
    """
    if isinstance(residuals, (tuple, list)):
        return torch.cat([part.reshape(-1) for part in residuals])
    return residuals.reshape(-1)


def evaluate_model(
    model: torch.nn.Module,
    test_inputs_scaled: torch.Tensor,
    test_outputs_scaled: torch.Tensor,
    scaling_params: dict,
    c: callable,
    batch_size: int = None,
) -> tuple[dict, np.ndarray, np.ndarray]:
    """
    Evaluate a trained model on scaled test data and compute regression and constraint metrics.

    Inference runs through ``model.predict`` in batches. Predictions and targets
    are mapped back to the original scale with ``model.unscale`` before any metric
    is computed. A one-line summary of timing and residuals is printed.

    Parameters:
    model (torch.nn.Module): Trained model to evaluate. It must provide
        ``predict(inputs)`` returning ``(predictions, predictions_before_projection,
        n_iterations)`` and ``unscale(inputs, outputs)`` returning
        ``(inputs, outputs)`` in the original scale. An ``ssl_loss(inputs, outputs)``
        method is optional; when it is present and succeeds, it is used for the
        objective values.
    test_inputs_scaled (torch.Tensor): Scaled test inputs, one sample per row.
    test_outputs_scaled (torch.Tensor): Scaled test targets, one sample per row.
    scaling_params (dict): Scaling parameters. Only 'output_std' and 'output_mean'
        are read here, to unscale the pre-projection predictions.
    c (callable): Constraint function ``c(inputs, predictions)`` (original scale)
        returning a residual tensor or a tuple/list of residual tensors.
    batch_size (int, optional): Rows per ``model.predict`` call. ``None`` or 0
        evaluates the whole test set as a single batch.

    Returns:
    tuple: A tuple containing:
        - metrics (dict):
            - 'projection iterations': Mean of the per-batch iteration counts
              returned by ``model.predict``.
            - 'inference time': Mean wall-clock seconds per ``model.predict`` call.
            - 'r2_y{k}', 'mse_y{k}', 'mape_y{k}': R^2, mean squared error and mean
              absolute percentage error of output column k (1-based), for every
              output column.
            - 'residual_avg', 'residual_max': Mean and maximum absolute constraint
              residual over all residual entries.
            - 'obj_value_opt', 'obj_value_pred': ``model.ssl_loss`` evaluated on the
              targets and on the predictions; present only if ``ssl_loss`` succeeds.
        - test_predictions (numpy.ndarray): Predictions in the original scale.
        - test_prediction_before_projection: Pre-projection predictions unscaled as
          ``x * output_std + output_mean``. NOTE(review): this bypasses
          ``model.unscale``. It assumes both use the same parameters and that the
          parameters broadcast against a NumPy array.

    Raises:
    ValueError: If the test set is empty or ``batch_size`` is negative.
    """
    model.eval()
    n_samples = test_inputs_scaled.shape[0]
    if n_samples == 0:
        raise ValueError("evaluate_model requires at least one test sample")
    if not batch_size:
        batch_size = n_samples
    elif batch_size < 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    (
        test_predictions_scaled,
        before_projection_scaled,
        proj_iter,
        avg_inference_time,
    ) = _batched_predict(model, test_inputs_scaled, batch_size)

    test_inputs, test_predictions = model.unscale(
        test_inputs_scaled, test_predictions_scaled
    )
    _, test_outputs = model.unscale(test_inputs_scaled, test_outputs_scaled)

    test_predictions_npy = test_predictions.detach().cpu().numpy()
    test_outputs_npy = test_outputs.detach().cpu().numpy()

    metrics = {"projection iterations": proj_iter, "inference time": avg_inference_time}
    metrics.update(_per_output_metrics(test_outputs_npy, test_predictions_npy))

    abs_residuals = torch.abs(_flatten_residuals(c(test_inputs, test_predictions)))
    residual = torch.mean(abs_residuals).item()
    max_residual = torch.max(abs_residuals).item()
    metrics["residual_avg"] = residual
    metrics["residual_max"] = max_residual

    summary = f"Inference time testing: {avg_inference_time}, EqMeanResidual: {residual}, EqMaxResidual: {max_residual}"
    try:
        obj_value_opt = model.ssl_loss(test_inputs, test_outputs).item()
        obj_value_pred = model.ssl_loss(test_inputs, test_predictions).item()
    except Exception:
        # ssl_loss is optional (best effort): models without it, or whose loss
        # cannot be evaluated here, just skip the objective metrics.
        print(summary)
    else:
        metrics["obj_value_opt"] = obj_value_opt
        metrics["obj_value_pred"] = obj_value_pred
        print(f"{summary}, ObjValueOpt: {obj_value_opt}, ObjValuePred: {obj_value_pred}")

    before_projection = (
        before_projection_scaled.detach().cpu().numpy()
        * scaling_params["output_std"]
        + scaling_params["output_mean"]
    )
    return metrics, test_predictions_npy, before_projection
