from sklearn.preprocessing import MinMaxScaler
import torch
import numpy as np
from .context_fid import Context_FID
from .cross_correlation import CrossCorrelLoss
from .discriminative_metric import discriminative_score_metrics
from .predictive_metric import predictive_score_metrics
from .metric_utils import display_scores


def apply_scale_factor(generated: np.ndarray, original: np.ndarray):
    """
    Re-standardise the generated data so that its mean and standard deviation
    along axis 0 match those of the original data.

    Each position along the remaining axes is treated independently: the
    generated values are centred on their own mean, rescaled by
    ``std_original / std_generated`` and shifted to the mean of the original
    data.

    Positions where the generated data is constant along axis 0 (zero standard
    deviation, e.g. a batch holding a single sample) cannot be rescaled; they
    are mapped to the mean of the original data instead of producing NaN/inf
    through a division by zero.

    Args:
        generated: Generated samples; statistics are taken along axis 0.
        original: Reference samples; statistics are taken along axis 0.

    Returns:
        np.ndarray: The rescaled generated data, broadcast against the
        axis-0 statistics of ``original``.
    """
    mean_generated = np.mean(generated, axis=0, keepdims=True)
    std_generated = np.std(generated, axis=0, keepdims=True)

    mean_original = np.mean(original, axis=0, keepdims=True)
    std_original = np.std(original, axis=0, keepdims=True)

    # A spread at or below floating-point resolution of the mean is rounding
    # noise, not signal: treat it as "constant" rather than amplifying it.
    tolerance = np.finfo(std_generated.dtype).eps * np.abs(mean_generated)
    degenerate = std_generated <= tolerance

    # Substitute a harmless divisor for degenerate positions, then zero their
    # ratio so the centred term vanishes and only mean_original remains.
    safe_std = np.where(degenerate, 1.0, std_generated)
    ratio = np.where(degenerate, 0.0, std_original / safe_std)

    return mean_original + (generated - mean_generated) * ratio


def scale_generated_data(generated: np.ndarray, original: np.ndarray):
    """
    Min-max scale the original and the generated data with one shared scaler.

    The scaler is fitted on the original data only (per feature, i.e. per
    index of the last axis) and then applied to both arrays, so the generated
    data is expressed in the original data's range and may fall outside
    [0, 1].

    Args:
        generated: Generated samples whose last axis holds the same number of
            features as ``original``. The number of samples may differ from
            ``original``.
        original: Reference samples of shape (B, L, K).

    Returns:
        tuple: ``(scaled_generated, scaled_original)``, each with the same
        shape as the corresponding input.

    Raises:
        ValueError: If ``original`` is not 3-dimensional or the feature
            dimensions of the two arrays differ.
    """
    _, _, K = original.shape

    # Without this check a feature mismatch could be silently re-laid-out by
    # the flattening below and scaled against the wrong columns.
    if generated.shape[-1] != K:
        raise ValueError(
            "Generated and original data must have the same number of "
            f"features: {generated.shape[-1]} != {K}"
        )

    # Reshape to 2D for scaling
    ori_flat = original.reshape(-1, K)
    fake_flat = generated.reshape(-1, K)

    # 1. Fit scaler ONLY on the original data
    predictive_scaler = MinMaxScaler()
    predictive_scaler.fit(ori_flat)

    # 2. Transform both datasets using the SAME scaler, restoring each array
    #    to its own shape so the sample counts are allowed to differ.
    ori_data = predictive_scaler.transform(ori_flat).reshape(original.shape)
    fake_data = predictive_scaler.transform(fake_flat).reshape(generated.shape)

    return fake_data, ori_data


def evaluate_model(
    ori_data,
    fake_data,
    metrics_iterations=5,
    model_training_iterations: dict = None,
    score_file: str = None,
    compute_fid=True,
    compute_cross_corr=True,
    compute_discriminative=True,
    compute_predictive=True,
    device: str = None,
) -> dict:
    """
    Evaluate generated data against original data with up to four metrics.

    The generated data is first re-standardised against the original data
    (see ``apply_scale_factor``). Each enabled metric is then computed
    ``metrics_iterations`` times, reported through ``display_scores`` and
    averaged. A disabled metric is reported as a list of zeros, so its mean
    in the returned dictionary is 0.0.

    Args:
        ori_data: Original data, a 3-D numpy array of shape (B, L, K).
        fake_data: Generated data with exactly the same shape as ``ori_data``.
        metrics_iterations: Number of repetitions per metric (must be >= 1).
        model_training_iterations: Optional mapping with the keys
            ``"context_fid"``, ``"discriminative"`` and ``"predictive"``
            giving the training iterations forwarded to the respective
            metric. Missing keys (or ``None``) forward ``None``.
        score_file: Forwarded to ``display_scores``.
        compute_fid: Whether to compute the Context-FID metric.
        compute_cross_corr: Whether to compute the cross-correlation metric.
        compute_discriminative: Whether to compute the discriminative score.
        compute_predictive: Whether to compute the predictive score.
        device: Device string forwarded to the discriminative and predictive
            metrics. When ``None``, ``"cuda"`` is used if available and
            ``"cpu"`` otherwise.

    Returns:
        dict: Mean score per metric under the keys ``"context_fid"``,
        ``"cross_correlation"``, ``"discriminative_score"`` and
        ``"predictive_score"``.

    Raises:
        AssertionError: If the two arrays do not have the same shape.
        ValueError: If ``metrics_iterations`` is smaller than 1.
    """
    assert (
        ori_data.shape == fake_data.shape
    ), f"Original and fake data must have the same shape: {ori_data.shape} != {fake_data.shape}"

    if metrics_iterations < 1:
        raise ValueError(
            f"metrics_iterations must be at least 1, got {metrics_iterations}"
        )

    # Fall back to the CPU when no GPU is present instead of failing inside
    # the metric models. NOTE(review): assumes the metric helpers accept
    # "cpu" as a device string — confirm against their implementations.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    B, L, K = ori_data.shape

    # A missing/empty mapping and a missing key both resolve to None.
    training_iterations = model_training_iterations or {}
    context_fid_model_iterations = training_iterations.get("context_fid", None)
    discriminative_model_iterations = training_iterations.get("discriminative", None)
    predictive_model_iterations = training_iterations.get("predictive", None)

    fake_data = apply_scale_factor(fake_data, ori_data)

    # CONTEXT FID METRIC
    if compute_fid:
        print("\nComputing CONTEXT-FID")
        context_fid_score = []
        for i in range(metrics_iterations):
            context_fid = Context_FID(
                ori_data, fake_data, model_training_iterations=context_fid_model_iterations
            )
            context_fid_score.append(context_fid)
            print(f"Iter {i}: context-fid = {context_fid}\n")
    else:
        context_fid_score = [0.0] * metrics_iterations
        print("\nSkipping CONTEXT-FID computation")

    display_scores(
        context_fid_score,
        metric_name=f"CONTEXT-FID sl_{L}",
        score_file=score_file,
    )

    # CROSS CORRELATION METRIC
    if compute_cross_corr:
        print("\nComputing CROSS-CORRELATION")
        x_real = torch.from_numpy(ori_data).to(torch.float64)
        x_fake = torch.from_numpy(fake_data).to(torch.float64)

        # Each iteration draws (with replacement) a subset of roughly
        # B / metrics_iterations samples; never let it collapse to zero
        # samples when the batch is smaller than the iteration count.
        sample_size = max(1, x_real.shape[0] // metrics_iterations)

        correlational_score = []
        for i in range(metrics_iterations):
            real_idx = np.random.randint(
                low=0, high=x_real.shape[0], size=(sample_size,)
            )
            fake_idx = np.random.randint(
                low=0, high=x_fake.shape[0], size=(sample_size,)
            )
            corr = CrossCorrelLoss(x_real[real_idx, :, :], name="CrossCorrelLoss")
            loss = corr.compute(x_fake[fake_idx, :, :])
            correlational_score.append(loss.item())
            print(f"Iter {i}: cross-correlation = {loss.item()}\n")
    else:
        correlational_score = [0.0] * metrics_iterations
        print("\nSkipping CROSS-CORRELATION computation")

    display_scores(
        correlational_score,
        metric_name=f"CROSS-CORRELATION sl_{L}",
        score_file=score_file,
    )

    # DISCRIMINATIVE SCORE
    if compute_discriminative:
        print("\nComputing DISCRIMINATIVE SCORE")
        discriminative_score = []
        for i in range(metrics_iterations):
            temp_disc, fake_acc, real_acc = discriminative_score_metrics(
                ori_data,
                fake_data,
                model_training_iterations=discriminative_model_iterations,
                device=device,
            )
            discriminative_score.append(temp_disc)
            print(
                f"Iter {i}: disc_score = {temp_disc}, fake_acc = {fake_acc}, real_acc = {real_acc}\n"
            )
    else:
        discriminative_score = [0.0] * metrics_iterations
        print("\nSkipping DISCRIMINATIVE SCORE computation")

    display_scores(
        discriminative_score,
        metric_name=f"DISCRIMINATIVE-SCORE sl_{L}",
        score_file=score_file,
    )

    # PREDICTIVE SCORE
    if compute_predictive:
        # Scale the original and the generated data to the range [0, 1]
        # because the predictive model has a layer of sigmoid at the end.
        # This rebinds both arrays, so it must stay after the metrics above.
        fake_data, ori_data = scale_generated_data(fake_data, ori_data)

        print("\nComputing PREDICTIVE SCORE")
        predictive_score = []
        for i in range(metrics_iterations):
            temp_pred = predictive_score_metrics(
                ori_data,
                fake_data,
                model_training_iterations=predictive_model_iterations,
                device=device,
            )
            predictive_score.append(temp_pred)
            print(f"Iter {i}: {temp_pred}\n")
    else:
        print("\nSkipping PREDICTIVE SCORE computation")
        predictive_score = [0.0] * metrics_iterations

    display_scores(
        predictive_score,
        score_file=score_file,
        metric_name=f"PREDICTIVE-SCORE sl_{L}",
    )

    return {
        "context_fid": np.mean(context_fid_score),
        "cross_correlation": np.mean(correlational_score),
        "discriminative_score": np.mean(discriminative_score),
        "predictive_score": np.mean(predictive_score),
    }


def calculate_final_weighted_score(
    current_scores: dict, baseline_scores: dict, weights: dict = None
) -> float:
    """
    Calculates a single, final score based on the weighted average of
    percentage improvements for each metric.

    For every metric in ``baseline_scores`` the relative improvement
    ``(baseline - current) / (baseline + epsilon)`` is computed, i.e. lower
    metric values are treated as better. The improvements are combined with
    ``weights`` and subtracted from 1.0.

    This method is balanced and symmetrical, meaning a 20% improvement in one
    metric and a 20% degradation in another will cancel each other out
    (for equal weights).

    The final score represents performance relative to the baseline:
    - Score < 1.0: The current model is better than the baseline on average.
    - Score > 1.0: The current model is worse than the baseline on average.
    - Score = 1.0: The models have the same average performance.

    Note:
        ``epsilon`` only prevents a division by exactly zero. A baseline value
        of 0 combined with a non-zero current value therefore yields a very
        large (negative) improvement.

    Args:
        current_scores (dict): A dictionary of the new model's scores.
        baseline_scores (dict): Dictionary of scores from a reference model.
        weights (dict, optional): Weights for each metric. If None, all are
                                  equally weighted. Supplied weights are used
                                  as given (they are not normalised) and must
                                  contain every key of ``baseline_scores``.
    Returns:
        float: The final combined balanced score. With no metrics to compare
        (empty ``baseline_scores``) the result is 1.0.

    Raises:
        AssertionError: If a key of ``baseline_scores`` is missing from
            ``current_scores``.
        KeyError: If a key of ``baseline_scores`` is missing from ``weights``.
    """
    epsilon = 1e-10

    # Ensure all baseline keys are in the current scores
    assert all(
        key in current_scores for key in baseline_scores
    ), "All keys in baseline_scores must be present in current_scores."

    # Nothing to compare: report parity with the baseline. This also avoids a
    # division by zero when building the equal weights below.
    if not baseline_scores:
        return 1.0

    # If no weights are provided, create equal weights
    if weights is None:
        equal_weight = 1.0 / len(baseline_scores)
        weights = {key: equal_weight for key in baseline_scores}

    # Calculate the weighted average of percentage improvements
    avg_improvement = 0.0
    for key, baseline_val in baseline_scores.items():
        current_val = current_scores[key]

        # Percentage improvement: (old - new) / old
        # A positive value means the new score is better (lower)
        improvement = (baseline_val - current_val) / (baseline_val + epsilon)

        avg_improvement += weights[key] * improvement

    # Transform the average improvement into the desired 1.0-centered score
    return 1.0 - avg_improvement
