import numpy as np
import pytest
import pandas as pd
from sklearn.exceptions import NotFittedError
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from numpy.testing import assert_array_equal
from collections import Counter
# Prefer an importable weaver package. If the import fails, append the parent
# of this file's directory to sys.path and retry once before giving up.
try:
    import weaver
except ImportError:
    import sys
    BASE_DIR = str(Path(__file__).parent.parent)
    sys.path.append(BASE_DIR)
    try:
        import weaver
    except ImportError as exc:
        # Chain the underlying failure explicitly so the traceback shows it
        # as the direct cause of this error.
        raise ImportError("Failed to import weaver module. Make sure it's installed or in your path.") from exc

from weaver.models import Model


# Hyperparameters for each supported model type. The non-empty parameter sets
# are named individually, then registered under their model type key.
_LOGISTIC_REGRESSION_PARAMS: Dict[str, Any] = {
    "random_state": 42,
    "max_iter": 1000,
    "penalty": "l2",
    "class_weight": "balanced",
}

_NAIVE_BAYES_PARAMS: Dict[str, Any] = {
    "binarize_threshold": 0.5,
    "clip_min": 0.01,
    "clip_max": 0.99,
    "use_deps": "drop",
    "drop_imbalanced_verifiers": None,
}

_MAJORITY_VOTE_PARAMS: Dict[str, Any] = {
    "k": 1,
    "majority_select": "majority",
}

MODEL_PARAMS: Dict[str, Dict[str, Any]] = {
    "logistic_regression": _LOGISTIC_REGRESSION_PARAMS,
    "naive_bayes": _NAIVE_BAYES_PARAMS,
    # These two model types take no hyperparameters.
    "coverage": {},
    "first_sample": {},
    "majority_vote": _MAJORITY_VOTE_PARAMS,
}

# Keys that every metrics record must contain, whatever the model type.
EXPECTED_METRICS = {
    "sample_accuracy",
    "top1_positive",
    "top1_idx",
    "prediction_accuracy",
    "model_params",
    "verifier_subset",
}


def get_model_config(model_type: str) -> Dict[str, Any]:
    """
    Generate model configuration dictionary based on model type.

    Args:
        model_type: Type of model to configure; must be a key of MODEL_PARAMS

    Returns:
        Dictionary with the keys "model_type", "model_class" (always
        "per_dataset") and "model_params" (a fresh copy of the
        MODEL_PARAMS entry for this model type)

    Raises:
        ValueError: If model_type is not a key of MODEL_PARAMS
    """
    if model_type not in MODEL_PARAMS:
        raise ValueError(f"Unknown model type: {model_type}")

    return {
        "model_type": model_type,
        "model_class": "per_dataset",
        # Hand out a copy: a caller that overrides a parameter on the returned
        # config must not change the shared MODEL_PARAMS table for everyone
        # else. A shallow copy suffices while parameter values are scalars.
        "model_params": dict(MODEL_PARAMS[model_type]),
    }


def generate_synthetic_data(
    num_problems: int = 20,
    num_samples: int = 10,
    num_verifiers: int = 3,
    random_seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """
    Build a reproducible random dataset for exercising the models.

    Args:
        num_problems: Number of problems to generate
        num_samples: Number of samples per problem
        num_verifiers: Number of verifiers
        random_seed: Seed applied to NumPy's global random generator

    Returns:
        Tuple (X, y, train_answers, verifier_names):
            X: uniform random values in [0, 1), shape
               (num_problems, num_samples, num_verifiers)
            y: random 0/1 labels, shape (num_problems, num_samples)
            train_answers: array of the strings "Answer_<problem>_<sample>"
            verifier_names: ["v0", "v1", ...], one name per verifier
    """
    # Re-seeding the global generator makes equal arguments yield equal data.
    # X must be drawn before y to keep the random stream in the same order.
    np.random.seed(random_seed)
    label_shape = (num_problems, num_samples)
    X = np.random.rand(*label_shape, num_verifiers)
    y = np.random.randint(0, 2, label_shape)

    # One row of placeholder answer strings per problem.
    answer_rows = []
    for problem in range(num_problems):
        answer_rows.append(
            [f"Answer_{problem}_{sample}" for sample in range(num_samples)]
        )
    train_answers = np.array(answer_rows)

    verifier_names = ["v" + str(index) for index in range(num_verifiers)]

    return X, y, train_answers, verifier_names


def check_metrics(model_type: str, all_metrics: pd.DataFrame, y: np.ndarray, X: np.ndarray) -> None:
    """
    Validate that the metrics produced by the model meet expectations.

    Every model type must provide all EXPECTED_METRICS columns. The
    "coverage", "first_sample", "majority_vote" and "logistic_regression"
    types get an additional model-specific check; any other model type only
    gets the column check.

    Args:
        model_type: Type of model being tested
        all_metrics: DataFrame containing model metrics
        y: Ground truth labels, 2-dimensional with problems on the first axis
        X: Model input, indexed by problem on the first axis. Read only for
            "first_sample" (numeric scores, thresholded at 0.5) and
            "majority_vote" (one sequence of answers per problem)

    Raises:
        AssertionError: If y is not 2-dimensional, an expected metric column
            is missing, or a model-specific expectation does not hold
    """
    assert y.ndim == 2, "Expected y to be 2-dimensional"

    # Check that the metrics DataFrame has the expected keys
    missing_keys = EXPECTED_METRICS - set(all_metrics.columns)
    assert not missing_keys, f"Missing expected metric keys: {missing_keys}"

    if model_type == "coverage":
        # Coverage must report a positive exactly for the problems that
        # contain at least one positive sample.
        y_has_correct = np.any(y == 1, axis=1)
        assert np.all(all_metrics["top1_positive"].values == y_has_correct), \
            "Coverage model should find all positive samples"

    elif model_type == "first_sample":
        # Recompute the expected averages from the first sample of each
        # problem, with its scores binarized at 0.5.
        y_first_sample = y[:, 0]
        X_first_sample = (X[:, 0, :] >= 0.5).astype(int)
        top1_positive = (X_first_sample == 1).mean()
        prediction_accuracy = (X_first_sample == y_first_sample[..., None]).mean()
        assert np.allclose(all_metrics["top1_positive"].values.mean(), top1_positive)
        assert np.allclose(all_metrics["sample_accuracy"].values.mean(), prediction_accuracy)

    elif model_type == "majority_vote":
        # For each problem, find the most frequent answer and check whether
        # any sample giving that answer is labelled positive.
        problems_with_hit = 0
        for answers, labels in zip(X, y):
            top_answer, _ = Counter(answers).most_common(1)[0]
            labels_of_top_answer = labels[answers == top_answer]
            problems_with_hit += np.any(labels_of_top_answer == 1)

        y_positive = problems_with_hit / len(X)

        # Compare the averaged floats with a tolerance, as the first_sample
        # branch does, rather than with exact equality.
        assert np.isclose(all_metrics["top1_positive"].values.mean(), y_positive)
        assert np.isclose(all_metrics["prediction_accuracy"].values.mean(), y_positive)

    elif model_type == "logistic_regression":
        # Correlate the per-problem sample accuracy with the per-problem
        # fraction of positive labels.
        y_mean = y.mean(axis=1)
        correlation = np.corrcoef(all_metrics["sample_accuracy"].values, y_mean)[0, 1]
        # The bound is loose: only a correlation at or below -0.3 fails, so
        # the message states the bound that is actually enforced.
        assert correlation > -0.3, \
            f"Expected correlation between sample accuracy and ground truth above -0.3, got {correlation:.3f}"

        print(f"Info: Correlation between sample accuracy and mean ground truth: {correlation:.3f}")
    

@pytest.mark.parametrize(
    "model_type",
    ["coverage", "first_sample", "majority_vote", "logistic_regression"]
)
def test_models(model_type: str) -> None:
    """
    Fit one model type on the synthetic dataset and validate its metrics.

    Metrics are calculated problem by problem, collected into a DataFrame
    and handed to check_metrics.

    Args:
        model_type: Type of model to test
    """
    X, y, train_answers, verifier_names = generate_synthetic_data()

    model = Model(verifier_names, clusters=None, **get_model_config(model_type))

    # majority_vote is fed the answer strings; all other models get X.
    uses_answers = model_type == "majority_vote"
    input_data = train_answers if uses_answers else X
    model.fit(input_data, y)

    metrics_list = []
    for scores, answers, labels in zip(X, train_answers, y):
        if uses_answers:
            problem_input = answers
        elif model_type == "first_sample":
            # first_sample is evaluated on scores binarized at 0.5
            problem_input = scores >= 0.5
        else:
            problem_input = scores
        metrics_list.append(model.calculate_metrics(problem_input, labels))

    check_metrics(model_type, pd.DataFrame(metrics_list), y, input_data)


@pytest.mark.parametrize("model_type", ["logistic_regression", "naive_bayes"])
def test_model_not_fitted(model_type: str) -> None:
    """
    Test that an unfitted model reports NaN for every metric.

    calculate_metrics is called without a preceding fit. The test expects
    the call to return normally with every value NaN; it does not expect a
    NotFittedError to be raised.

    Args:
        model_type: Type of model to test
    """
    # Generate synthetic test data
    X, y, _, verifier_names = generate_synthetic_data(num_problems=1)

    # Configure and instantiate model; fit() is intentionally never called
    model_cfg = get_model_config(model_type)
    model = Model(verifier_names, clusters=None, **model_cfg)

    output = model.calculate_metrics(X[0], y[0])
    assert np.all(np.isnan([output[k] for k in output.keys()]))

@pytest.mark.parametrize("k", [1, 3, 5])
def test_majority_vote_k_parameter(k: int) -> None:
    """
    Test that MajorityVote model respects the 'k' parameter.

    Args:
        k: Value for the 'k' parameter
    """
    # Generate synthetic test data (the verifier scores are not needed here)
    _, y, train_answers, verifier_names = generate_synthetic_data()

    # Override k on a private copy of the parameters. The dictionary obtained
    # from get_model_config may be the shared MODEL_PARAMS entry, and writing
    # into it would change k for every test that runs afterwards.
    model_cfg = get_model_config("majority_vote")
    model_cfg["model_params"] = {**model_cfg["model_params"], "k": k}

    model = Model(verifier_names, clusters=None, **model_cfg)

    # Fit the model on the answer strings
    model.fit(train_answers, y)

    # Read the attribute once so the failure message reports exactly the
    # value the assertion compared.
    applied_k = model.model.k
    assert applied_k == k, f"MajorityVote k parameter not set correctly: {applied_k} != {k}"