"""
Reward Combiner Logging Module

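This module provides diagnostic logging utilities for the RewardCombiner class:
it logs the combiner's configuration and model details, runs synthetic
LLM-as-judge score test cases through it, and logs summary statistics of the results.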
"""

import logging
from typing import Dict, List, Any, Optional, Tuple
import numpy as np
import torch
try:
    from .reward_combiner import RewardCombiner
except ImportError:
    from reward_combiner import RewardCombiner

def log_test_reward_combiner(
    logger: logging.Logger,
    reward_combiner: RewardCombiner
) -> None:
    """
    Log a full diagnostic report for a RewardCombiner, including test cases.

    The report is emitted in this order:

    1. a banner header;
    2. the combiner configuration: combination function name, objective
       count and names, input/output normalisation flags and ranges, and
       fitted state;
    3. model-specific details for the combination function type;
    4. one entry per generated test case (LLM-as-judge style scores in
       [1, 10]) showing the input scores and the combined output;
    5. a statistical summary of those outputs;
    6. a closing banner.

    A test case whose evaluation or formatting raises is reported with
    ``logger.error`` and the remaining cases still run.

    Args:
        logger: The logger instance to use for logging
        reward_combiner: The RewardCombiner object to test and log
    """
    fn = reward_combiner.combination_function
    names = fn.objective_names
    count = len(names)
    rule = "=" * 80

    logger.info("\n" + rule)
    logger.info("REWARD COMBINER DETAILED ANALYSIS")
    logger.info(rule)

    # Each value is read lazily, immediately before it is logged, so a failing
    # attribute still leaves the preceding lines in the log.
    config_rows = (
        ("Combination Function Type", lambda: fn.get_name()),
        ("Number of Objectives", lambda: count),
        ("Objective Names", lambda: names),
        ("Normalize Inputs", lambda: reward_combiner.normalize_inputs),
        ("Normalize Outputs", lambda: reward_combiner.normalize_outputs),
        ("Input Range", lambda: reward_combiner.input_range),
        ("Output Range", lambda: reward_combiner.output_range),
        ("Is Fitted", lambda: fn.is_fitted),
    )
    logger.info("\n--- RewardCombiner Configuration ---")
    for label, read_value in config_rows:
        logger.info(f"{label}: {read_value()}")

    _log_model_specific_details(logger, reward_combiner, names)

    logger.info("\n--- Test Cases with Realistic LLM Scores ---")
    logger.info("(Scores in range [1, 10] as typical LLM-as-judge outputs)\n")

    cases = _generate_test_cases(names)
    for case_no, (label, scores) in enumerate(cases, start=1):
        try:
            reward = reward_combiner.combine_rewards(scores)
            rendered = ", ".join(f"{name}={scores[name]:.1f}" for name in names)
            logger.info(f"Test Case {case_no}: {label}")
            logger.info(f"  Input Scores: [{rendered}]")
            logger.info(f"  Combined Output: {reward:.6f}")
        except Exception as err:
            logger.error(f"Test Case {case_no} failed: {err}")

    _log_statistical_summary(logger, reward_combiner, cases)

    logger.info("\n" + rule + "\n")


def _log_model_specific_details(logger: logging.Logger, reward_combiner: RewardCombiner, objective_names: List[str]) -> None:
    """
    Log details specific to the combination function type.

    Dispatches on ``reward_combiner.combination_function.get_name()``:

    * ``"linear_manual"``: per-objective weights, bias, and the sum of all
      weights.
    * ``"linear_regression"`` (only when fitted): coefficients, intercept and
      coefficient sum of ``cf.model``.
    * ``"gradient_boosting"`` (only when fitted): ``n_estimators``,
      ``max_depth``, ``learning_rate`` and, if present, feature importances
      plus their sum.
    * ``"mlp"`` (only when fitted): architecture, dropout rate, learning rate
      and total/trainable parameter counts of ``cf.model``.

    Any other type, or an unfitted model-based combiner, logs nothing beyond
    this function's silence. Every branch degrades gracefully: failures while
    reading attributes are logged (warning, or debug for the MLP parameter
    count) instead of propagating, so a single malformed combiner cannot abort
    the surrounding report.

    Args:
        logger: Logger instance
        reward_combiner: The RewardCombiner object
        objective_names: List of objective names (order used for coefficient
            and feature-importance indices)
    """
    cf = reward_combiner.combination_function
    combiner_type = cf.get_name()

    if combiner_type == "linear_manual":
        logger.info("\n--- Linear Manual Configuration ---")
        # Guarded like the model-based branches: a missing weight key or bias
        # attribute should produce a warning, not abort the whole report.
        try:
            logger.info("  Weights:")
            for obj in objective_names:
                logger.info(f"    {obj}: {cf.weights[obj]:.4f}")
            logger.info(f"  Bias: {cf.bias:.4f}")

            # Log weight sum for normalization check
            weight_sum = sum(cf.weights.values())
            logger.info(f"  Weight Sum: {weight_sum:.4f}")
        except Exception as e:
            logger.warning(f"Could not extract linear manual details: {e}")

    elif combiner_type == "linear_regression" and cf.is_fitted:
        logger.info("\n--- Linear Regression Model ---")
        try:
            model = cf.model
            logger.info("  Coefficients:")
            for i, obj in enumerate(objective_names):
                logger.info(f"    {obj}: {model.coef_[i]:.4f}")
            logger.info(f"  Intercept: {model.intercept_:.4f}")

            # Log coefficient sum
            coef_sum = np.sum(model.coef_)
            logger.info(f"  Coefficient Sum: {coef_sum:.4f}")
        except Exception as e:
            logger.warning(f"Could not extract linear regression details: {e}")

    elif combiner_type == "gradient_boosting" and cf.is_fitted:
        logger.info("\n--- Gradient Boosting Model ---")
        try:
            model = cf.model
            logger.info(f"  N Estimators: {model.n_estimators}")
            logger.info(f"  Max Depth: {model.max_depth}")
            logger.info(f"  Learning Rate: {model.learning_rate}")

            if hasattr(model, 'feature_importances_'):
                logger.info("  Feature Importances:")
                for i, obj in enumerate(objective_names):
                    logger.info(f"    {obj}: {model.feature_importances_[i]:.4f}")

                # Verify importances sum to 1
                importance_sum = np.sum(model.feature_importances_)
                logger.info(f"  Importance Sum: {importance_sum:.4f}")
        except Exception as e:
            logger.warning(f"Could not extract gradient boosting details: {e}")

    elif combiner_type == "mlp" and cf.is_fitted:
        logger.info("\n--- MLP Neural Network ---")
        # Hyperparameter attributes are read defensively, matching the other
        # model-based branches.
        try:
            logger.info(f"  Architecture: {cf.input_size} -> {cf.hidden_sizes} -> 1")
            logger.info(f"  Dropout Rate: {cf.dropout_rate}")
            logger.info(f"  Learning Rate: {cf.learning_rate}")
        except Exception as e:
            logger.warning(f"Could not extract MLP details: {e}")

        # Log number of parameters (presumably cf.model is a torch module,
        # given .parameters()/.numel()/.requires_grad; failures are debug-only)
        try:
            total_params = sum(p.numel() for p in cf.model.parameters())
            trainable_params = sum(p.numel() for p in cf.model.parameters() if p.requires_grad)
            logger.info(f"  Total Parameters: {total_params}")
            logger.info(f"  Trainable Parameters: {trainable_params}")
        except Exception as e:
            logger.debug(f"Could not count MLP parameters: {e}")


def _generate_test_cases(objective_names: List[str]) -> List[Tuple[str, Dict[str, float]]]:
    """
    Generate comprehensive test cases with realistic LLM-as-judge scores.
    Does not make any assumptions about objective names or their meanings.

    Args:
        objective_names: List of objective names

    Returns:
        List of tuples (description, scores_dict) for each test case
    """
    test_cases = []
    num_objectives = len(objective_names)

    # Test Case 1: All minimum scores
    test_cases.append((
        "All Minimum (Score: 1.0)",
        {obj: 1.0 for obj in objective_names}
    ))

    # Test Case 2: All maximum scores
    test_cases.append((
        "All Maximum (Score: 10.0)",
        {obj: 10.0 for obj in objective_names}
    ))

    # Test Case 3: All medium scores
    test_cases.append((
        "All Medium (Score: 5.5)",
        {obj: 5.5 for obj in objective_names}
    ))

    # Test Case 4: Linear gradient from low to high
    if num_objectives > 1:
        gradient_scores = {}
        values = np.linspace(2.0, 9.0, num_objectives)
        for i, obj in enumerate(objective_names):
            gradient_scores[obj] = float(values[i])
        test_cases.append((
            "Linear Gradient (2.0 → 9.0)",
            gradient_scores
        ))

    # Test Case 5: Alternating high and low
    alternating_scores = {}
    for i, obj in enumerate(objective_names):
        alternating_scores[obj] = 8.0 if i % 2 == 0 else 3.0
    test_cases.append((
        "Alternating Pattern (8.0/3.0)",
        alternating_scores
    ))

    # Test Case 6: First high, rest low
    one_high_scores = {obj: 3.0 for obj in objective_names}
    if objective_names:
        one_high_scores[objective_names[0]] = 9.0
    test_cases.append((
        "First High (9.0), Rest Low (3.0)",
        one_high_scores
    ))

    # Test Case 7: Last high, rest low
    last_high_scores = {obj: 3.0 for obj in objective_names}
    if objective_names:
        last_high_scores[objective_names[-1]] = 9.0
    test_cases.append((
        "Last High (9.0), Rest Low (3.0)",
        last_high_scores
    ))

    # Test Case 8: Random uniform distribution (good range)
    np.random.seed(42)  # For reproducibility
    good_random = {}
    for obj in objective_names:
        good_random[obj] = float(np.random.uniform(6.5, 8.5))
    test_cases.append((
        "Random Uniform [6.5, 8.5]",
        good_random
    ))

    # Test Case 9: Random uniform distribution (mediocre range)
    np.random.seed(43)
    mediocre_random = {}
    for obj in objective_names:
        mediocre_random[obj] = float(np.random.uniform(4.0, 6.0))
    test_cases.append((
        "Random Uniform [4.0, 6.0]",
        mediocre_random
    ))

    # Test Case 10: Decreasing linear pattern
    if num_objectives > 1:
        decreasing_scores = {}
        values = np.linspace(9.0, 2.0, num_objectives)
        for i, obj in enumerate(objective_names):
            decreasing_scores[obj] = float(values[i])
        test_cases.append((
            "Linear Decreasing (9.0 → 2.0)",
            decreasing_scores
        ))

    # Test Case 11: Bell curve pattern (if enough objectives)
    if num_objectives >= 3:
        bell_scores = {}
        center = num_objectives // 2
        for i, obj in enumerate(objective_names):
            distance = abs(i - center) / max(1, (num_objectives - 1) / 2)
            bell_scores[obj] = float(10.0 - (distance * 7.0))  # Peak at 10, edges at 3
        test_cases.append((
            "Bell Curve Pattern",
            bell_scores
        ))

    # Test Case 12: Mostly high with center low (if enough objectives)
    if num_objectives > 2:
        dip_scores = {obj: 8.0 for obj in objective_names}
        dip_scores[objective_names[num_objectives // 2]] = 2.0
        test_cases.append((
            "Center Dip (8.0 with center at 2.0)",
            dip_scores
        ))

    # Test Case 13: Stepped increases
    if num_objectives >= 3:
        step_scores = {}
        step_size = 8.0 / max(1, num_objectives - 1)
        for i, obj in enumerate(objective_names):
            step_scores[obj] = float(1.0 + i * step_size)
        test_cases.append((
            f"Stepped Increase (step: {step_size:.2f})",
            step_scores
        ))

    # Test Case 14: Boundary values
    boundary_scores = {}
    for i, obj in enumerate(objective_names):
        if i % 3 == 0:
            boundary_scores[obj] = 1.1  # Near minimum
        elif i % 3 == 1:
            boundary_scores[obj] = 9.9  # Near maximum
        else:
            boundary_scores[obj] = 5.0  # Middle
    test_cases.append((
        "Boundary Test (1.1/9.9/5.0)",
        boundary_scores
    ))

    # Test Case 15: Sine wave pattern (if enough objectives)
    if num_objectives >= 4:
        sine_scores = {}
        for i, obj in enumerate(objective_names):
            # Map to [0, 2π] and apply sine, then scale to [2, 9]
            angle = (i / (num_objectives - 1)) * 2 * np.pi if num_objectives > 1 else 0
            sine_val = (np.sin(angle) + 1) / 2  # Normalize to [0, 1]
            sine_scores[obj] = float(2.0 + sine_val * 7.0)  # Scale to [2, 9]
        test_cases.append((
            "Sine Wave Pattern",
            sine_scores
        ))

    return test_cases


def _log_statistical_summary(logger: logging.Logger, reward_combiner: RewardCombiner,
                            test_cases: List[Tuple[str, Dict[str, float]]]) -> None:
    """
    Log a statistical summary of the combiner's outputs over the test cases.

    Each test case is passed to ``reward_combiner.combine_rewards`` exactly
    once. Cases that raise are logged at DEBUG level and excluded. From the
    successful outputs this logs:

    * the success count, mean, standard deviation, min, max, median,
      25th/75th percentiles and range;
    * when ``reward_combiner.output_range`` is truthy, whether every output
      lies inside it (inclusive) and, if not, how many do not;
    * the outputs of the gradient-style cases (descriptions containing
      "Gradient", "Decreasing" or "Stepped") for a manual monotonicity check.

    Any unexpected error while summarising is logged as an error rather than
    raised.

    Args:
        logger: Logger instance
        reward_combiner: The RewardCombiner object
        test_cases: List of (description, scores_dict) test cases
    """
    logger.info("\n--- Statistical Summary of Test Results ---")

    try:
        # Evaluate each case once and remember which description produced
        # which output. The monotonic-pattern section below reuses these
        # results instead of calling the combiner again: re-calling it outside
        # the per-case guard would re-raise for a case that already failed and
        # abort the summary halfway, and a non-deterministic combiner could
        # report values that disagree with the statistics above.
        results: List[Tuple[str, Any]] = []
        for description, scores_dict in test_cases:
            try:
                output = reward_combiner.combine_rewards(scores_dict)
            except Exception as e:
                logger.debug(f"Failed to compute output for '{description}': {e}")
                continue
            results.append((description, output))

        if not results:
            logger.warning("  No successful test outputs to summarize")
            return

        outputs_array = np.array([output for _, output in results])
        logger.info(f"  Number of Successful Tests: {len(results)}/{len(test_cases)}")
        logger.info(f"  Mean Output: {np.mean(outputs_array):.6f}")
        logger.info(f"  Std Dev: {np.std(outputs_array):.6f}")
        logger.info(f"  Min Output: {np.min(outputs_array):.6f}")
        logger.info(f"  Max Output: {np.max(outputs_array):.6f}")
        logger.info(f"  Median Output: {np.median(outputs_array):.6f}")
        logger.info(f"  25th Percentile: {np.percentile(outputs_array, 25):.6f}")
        logger.info(f"  75th Percentile: {np.percentile(outputs_array, 75):.6f}")
        logger.info(f"  Range: {np.ptp(outputs_array):.6f}")

        # Check if outputs are within expected range
        if reward_combiner.output_range:
            min_expected, max_expected = reward_combiner.output_range
            in_range = (outputs_array >= min_expected) & (outputs_array <= max_expected)
            within_range = np.all(in_range)
            # Count the complement of in_range so NaN outputs (which fail both
            # comparisons) are counted whenever they make within_range False.
            out_of_range_count = np.sum(~in_range)
            logger.info(f"  Expected Output Range: [{min_expected:.2f}, {max_expected:.2f}]")
            logger.info(f"  All Outputs Within Range: {within_range}")
            if not within_range:
                logger.info(f"  Number Out of Range: {out_of_range_count}")

        # List outputs of the gradient-style cases for a monotonicity check
        monotonic_keywords = ("Gradient", "Decreasing", "Stepped")
        gradient_outputs = [
            (desc, output)
            for desc, output in results
            if any(keyword in desc for keyword in monotonic_keywords)
        ]
        if gradient_outputs:
            logger.info("\n  Monotonic Pattern Tests:")
            for desc, output in gradient_outputs:
                logger.info(f"    {desc}: {output:.6f}")

    except Exception as e:
        logger.error(f"  Failed to compute statistical summary: {e}")