import itertools
from tqdm import tqdm
import json
import os
from datetime import datetime
import numpy as np
from ar import ActivationReasoning, LogicConfig
import pandas as pd
from collections import defaultdict
from sklearn.metrics import (
    accuracy_score,
)
import os
from typing import List, Dict, Tuple
import json
import pandas as pd
import numpy as np


# Metric names accepted by ``hyperparameter_search`` for each mode. Kept as
# lists so the validation error message lists them in a stable order.
_DETECTION_METRICS = [
    "accuracy",
    "balanced_accuracy",
    "f1_score",
    "precision",
    "recall",
    "auc",
    "roc_auc",
]
_GENERATION_METRICS = ["model_accuracy", "rule_accuracy"]

# Columns of a results CSV that hold scores rather than hyperparameters; every
# other column is treated as part of the configuration when resuming.
_RESULT_METRIC_COLUMNS = frozenset(
    {
        "score",
        "model_acc",
        "rule_acc",
        "model_accuracy",
        "rule_accuracy",
        "accuracy",
        "balanced_accuracy",
        "f1_score",
        "precision",
        "recall",
        "auc",
        "roc_auc",
        "best_threshold",
    }
)

# Greedy-decoding settings passed to ``ar_model.generate`` in generation mode.
_GENERATION_MODEL_HYP = {
    "do_sample": False,
    "temperature": None,
    "top_k": None,
    "top_p": None,
}


def hyperparameter_search(
    ar_model: "ActivationReasoning",
    trains: List,
    labels: List,
    hyperparameters: Dict,
    strategy: str = "grid",
    n_trials: int = None,
    metric: str = "model_accuracy",
    verbose: bool = False,
    save_path: str = None,
    batch_size: int = 80,
    existing_results_path: str = None,
    evaluation_data: Tuple = None,
    detection_mode: bool = False,
):
    """
    Perform hyperparameter search for the Activation Reasoning model.

    Args:
        ar_model (ActivationReasoning): The model to evaluate
        trains (list): List of training examples
        labels (list): List of labels
        hyperparameters (dict): Dictionary of hyperparameters to search, format:
            {
                'detection_top_k_output': [1, 2, 3],
                'detection_top_k_concepts': [1, 2, 3],
                'steering_top_k_rule': [1, 3, 5, 10],
                'steering_weighting_function': ['uniform', 'linear_decay', 'softmax_based'],
                'steering_factor': [0.2, 0.4, 0.6],
                'detection_threshold': [0.0, 0.1, 0.2],
                'detection_allow_multi': [True, False],
                'steering_norm': [2, 'off']
            }
        strategy (str): Search strategy - 'grid', 'random', or 'bayesian' (default: 'grid').
            'bayesian' requires scikit-optimize and falls back to 'random' when it is
            not installed. In 'bayesian' mode, all-int / all-float value lists are
            searched as the continuous range [min, max], not just the listed values.
        n_trials (int): Number of trials for random/bayesian search
            (default: min(10, grid size) for random, 10 for bayesian, ignored for grid)
        metric (str): Metric to optimize - For generation: 'model_accuracy' or 'rule_accuracy'
                      For detection: 'accuracy', 'balanced_accuracy', 'f1_score', 'precision',
                      'recall', 'auc', 'roc_auc' (default: 'model_accuracy')
        verbose (bool): Whether to print verbose output (default: False)
        save_path (str): Path prefix to save the results to (``.csv`` and ``.json``
            are appended); results are also saved every 5 evaluations (default: None)
        batch_size (int): Batch size for processing (default: 80)
        existing_results_path (str): Path to an existing .json/.csv results file to
            resume from; configurations found there are not re-evaluated. An
            unreadable file is reported and a fresh search is started (default: None)
        evaluation_data (tuple): Tuple of (eval_inputs, eval_labels) for evaluation
            in detection mode; defaults to (trains, labels) (default: None)
        detection_mode (bool): Whether to use detection mode (with batch_detect) or generation mode (default: False)

    Returns:
        dict: ``{"all_configs": [...], "best_config": dict | None,
        "best_score": float, "search_metadata": {...}}`` where ``all_configs``
        holds one result dict per evaluated configuration, sorted by
        ``score`` in descending order.

    Raises:
        ValueError: if ``hyperparameters`` is empty, ``metric`` is not valid for
            the selected mode, or ``strategy`` is unknown.
    """
    if not hyperparameters:
        raise ValueError("Hyperparameters dictionary cannot be empty")

    # Validate metric based on mode
    if detection_mode:
        if metric not in _DETECTION_METRICS:
            raise ValueError(
                f"Invalid detection metric '{metric}'. Must be one of {_DETECTION_METRICS}"
            )
    elif metric not in _GENERATION_METRICS:
        raise ValueError(
            f"Invalid generation metric '{metric}'. Must be one of {_GENERATION_METRICS}"
        )

    # Resume from a previous run when possible. A bad file deliberately falls
    # back to a fresh search rather than aborting.
    if existing_results_path and os.path.exists(existing_results_path):
        try:
            results, seen_keys = _load_existing_results(
                existing_results_path,
                strategy,
                len(trains),
                metric,
                detection_mode,
                verbose,
            )
        except Exception as e:
            print(f"Error loading existing results: {str(e)}")
            print("Starting fresh search...")
            results = _new_results(strategy, len(trains), metric, detection_mode)
            seen_keys = set()
    else:
        results = _new_results(strategy, len(trains), metric, detection_mode)
        seen_keys = set()

    def evaluate(config):
        """Evaluate one configuration in the mode selected for this search."""
        return _evaluate_single_config(
            ar_model,
            trains,
            labels,
            config,
            evaluation_data,
            batch_size,
            metric,
            verbose,
            detection_mode,
        )

    # Check the optional dependency up front. An ImportError raised later,
    # during evaluation, must not be mistaken for "scikit-optimize missing".
    if strategy == "bayesian":
        try:
            import skopt  # noqa: F401  (availability check only)
        except ImportError:
            print(
                "Bayesian optimization requires scikit-optimize. Falling back to random search."
            )
            strategy = "random"
            if n_trials is None:
                n_trials = 10

    # Configurations still to be evaluated by the loop below (the Bayesian
    # strategy evaluates inside the optimiser instead).
    configs = []

    if strategy == "grid":
        keys = list(hyperparameters)
        all_configs = [
            dict(zip(keys, combo))
            for combo in itertools.product(*hyperparameters.values())
        ]
        # Filter out already evaluated configurations
        configs = [c for c in all_configs if _config_key(c) not in seen_keys]

        if verbose:
            total_configs = len(all_configs)
            new_configs = len(configs)
            skipped_configs = total_configs - new_configs
            print(f"Grid search: {total_configs} total combinations")
            print(f"Skipping {skipped_configs} already evaluated configurations")
            print(f"Will evaluate {new_configs} new configurations")

    elif strategy == "random":
        if n_trials is None:
            n_trials = int(
                min(10, np.prod([len(v) for v in hyperparameters.values()]))
            )

        if verbose:
            print(f"Random search with {n_trials} trials")

        configs, attempts = _sample_random_configs(hyperparameters, n_trials, seen_keys)

        if verbose and attempts >= n_trials * 3:
            print(
                f"Warning: Could only generate {len(configs)} unique configurations after {attempts} attempts"
            )

    elif strategy == "bayesian":
        if n_trials is None:
            n_trials = 10

        if verbose:
            print(f"Bayesian optimization search with {n_trials} trials")

        _run_bayesian_search(
            hyperparameters, n_trials, results, seen_keys, evaluate, verbose
        )

    else:
        raise ValueError(
            f"Invalid search strategy: {strategy}. Must be 'grid', 'random', or 'bayesian'"
        )

    # Evaluate all pending configurations
    if configs:
        progress_bar = tqdm(
            configs, desc=f"{strategy.capitalize()} Search", disable=not verbose
        )
        for config in progress_bar:
            result = evaluate(config)
            # Record before updating the bar so it shows the current best.
            _record_result(results, result)

            progress_bar.set_postfix(
                {"best_score": results["best_score"], "current_score": result["score"]}
            )

            # Periodically save results if requested
            if save_path and (len(results["all_configs"]) % 5 == 0):
                _save_results(results, save_path, verbose=False)
    elif verbose and strategy != "bayesian":
        print("No new configurations to evaluate - all have been tested already")

    # Sort results by score
    results["all_configs"] = sorted(
        results["all_configs"], key=lambda x: x["score"], reverse=True
    )

    # Print best results
    if verbose:
        print("\nTop 3 configurations:")
        for i, result in enumerate(results["all_configs"][:3]):
            print(f"{i + 1}. Score: {result['score']:.3f}")
            if detection_mode:
                print(
                    f"   Balanced Acc: {result.get('balanced_accuracy', 0.0):.3f}, Accuracy: {result['accuracy']:.3f}, F1: {result['f1_score']:.3f}, "
                    f"Precision: {result['precision']:.3f}, Recall: {result['recall']:.3f}"
                )
            else:
                print(
                    f"   Model Acc: {result['model_acc']:.3f}, Rule Acc: {result['rule_acc']:.3f}"
                )
            print(f"   Config: {result['config']}")

    # Save results if requested
    if save_path:
        _save_results(results, save_path, verbose)

    return results


def _config_key(config):
    """Return a hashable, key-order-independent identity for ``config`` (values compared via ``str``)."""
    return tuple(sorted((k, str(v)) for k, v in config.items()))


def _to_native(value):
    """Convert a NumPy scalar to the equivalent Python object; return anything else unchanged."""
    return value.item() if isinstance(value, np.generic) else value


def _new_results(strategy, num_examples, metric, detection_mode, resumed_from=None):
    """Build an empty results structure with search metadata."""
    metadata = {
        "strategy": strategy,
        "num_examples": num_examples,
        "metric": metric,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    if resumed_from is not None:
        metadata["resumed_from"] = resumed_from
    metadata["detection_mode"] = detection_mode
    return {
        "all_configs": [],
        "best_config": None,
        "best_score": 0,
        "search_metadata": metadata,
    }


def _record_result(results, result):
    """Append ``result`` to ``results`` and update the best score/config if it improves."""
    results["all_configs"].append(result)
    if result["score"] > results["best_score"]:
        results["best_score"] = result["score"]
        results["best_config"] = result["config"]


def _load_existing_results(path, strategy, num_examples, metric, detection_mode, verbose):
    """Load a previous search from ``path``; return ``(results, seen_keys)``.

    Raises:
        ValueError: if the file extension is neither ``.json`` nor ``.csv``.
    """
    file_ext = os.path.splitext(path)[1].lower()
    if file_ext == ".json":
        return _load_json_results(path, verbose)
    if file_ext == ".csv":
        return _load_csv_results(
            path, strategy, num_examples, metric, detection_mode, verbose
        )
    raise ValueError(f"Unsupported file format: {file_ext}. Please use .json or .csv")


def _load_json_results(path, verbose):
    """Load a JSON results file written by ``_save_results``."""
    with open(path, "r") as f:
        results = json.load(f)

    seen_keys = {
        _config_key(entry.get("config", {}))
        for entry in results.get("all_configs", [])
    }
    if verbose:
        print(f"Loaded {len(seen_keys)} existing configurations from {path}")

    # Partial files may lack the list; new results are appended to it later.
    results.setdefault("all_configs", [])

    # Recompute the best entry when it is missing or was never set.
    if "best_score" not in results or results["best_score"] == 0:
        best_score, best_config = 0, None
        for entry in results["all_configs"]:
            score = entry.get("score", 0)
            if score > best_score:
                best_score, best_config = score, entry.get("config")
        results["best_score"] = best_score
        results["best_config"] = best_config

    return results, seen_keys


def _load_csv_results(path, strategy, num_examples, metric, detection_mode, verbose):
    """Load a CSV results file, treating every non-metric column as a hyperparameter."""
    df = pd.read_csv(path)
    config_cols = [col for col in df.columns if col not in _RESULT_METRIC_COLUMNS]

    results = _new_results(
        strategy, num_examples, metric, detection_mode, resumed_from=path
    )
    seen_keys = set()

    # to_dict("records") keeps each column's own dtype; iterrows() upcasts an
    # all-numeric row to float (3 -> 3.0), which would never match the
    # str()-based keys of freshly generated configs.
    for row in df.to_dict("records"):
        config = {col: _to_native(row[col]) for col in config_cols}
        seen_keys.add(_config_key(config))

        # Handle different column names based on mode
        if detection_mode:
            result = {
                "config": config,
                "score": row.get(metric, 0),
                "accuracy": row.get("accuracy", 0),
                "balanced_accuracy": row.get(
                    "balanced_accuracy", row.get("accuracy", 0)
                ),
                "f1_score": row.get("f1_score", 0),
                "precision": row.get("precision", 0),
                "recall": row.get("recall", 0),
                "roc_auc": row.get("roc_auc", 0.0),
                "auc": row.get("auc", 0.0),
                "best_threshold": row.get("best_threshold", 0.5),
            }
        else:
            model_acc = row.get("model_acc", row.get("model_accuracy", 0))
            rule_acc = row.get("rule_acc", row.get("rule_accuracy", 0))
            score = row.get(
                "score",
                model_acc if metric == "model_accuracy" else rule_acc,
            )
            result = {
                "config": config,
                "score": score,
                "model_acc": model_acc,
                "rule_acc": rule_acc,
            }

        _record_result(results, result)

    if verbose:
        print(f"Loaded {len(seen_keys)} existing configurations from {path}")

    return results, seen_keys


def _sample_random_configs(hyperparameters, n_trials, seen_keys):
    """Draw up to ``n_trials`` unseen configs uniformly at random.

    Adds each sampled key to ``seen_keys`` to avoid duplicates. Returns
    ``(configs, attempts)``; gives up after ``3 * n_trials`` draws so dense,
    mostly-explored spaces cannot loop forever.
    """
    sampled = []
    attempts = 0
    max_attempts = n_trials * 3
    while len(sampled) < n_trials and attempts < max_attempts:
        # Pick by index to keep the original Python objects: np.random.choice
        # coerces a mixed list like [2, "off"] to strings and ints to np.int64.
        config = {
            key: values[np.random.randint(len(values))]
            for key, values in hyperparameters.items()
        }
        key = _config_key(config)
        if key not in seen_keys:
            sampled.append(config)
            seen_keys.add(key)
        attempts += 1
    return sampled, attempts


def _evaluate_single_config(
    ar_model,
    trains,
    labels,
    config,
    evaluation_data,
    batch_size,
    metric,
    verbose,
    detection_mode,
):
    """Dispatch one configuration to the detection or generation evaluator."""
    if detection_mode:
        eval_data = evaluation_data if evaluation_data else (trains, labels)
        return _evaluate_detection_config(
            ar_model,
            trains,
            labels,
            config,
            eval_data,
            batch_size,
            metric,
            verbose,
        )
    return _evaluate_generation_config(
        ar_model,
        trains,
        labels,
        config,
        dict(_GENERATION_MODEL_HYP),
        batch_size,
        metric,
        verbose,
    )


def _run_bayesian_search(hyperparameters, n_trials, results, seen_keys, evaluate, verbose):
    """Run scikit-optimize's ``gp_minimize`` and record every evaluation in ``results``."""
    from skopt import gp_minimize
    from skopt.space import Categorical, Integer, Real
    from skopt.utils import use_named_args

    param_names = list(hyperparameters)
    search_space = []
    for key, values in hyperparameters.items():
        # bool is checked first because bool is a subclass of int.
        if all(isinstance(v, bool) for v in values):
            search_space.append(Categorical([True, False], name=key))
        elif all(isinstance(v, int) for v in values):
            search_space.append(Integer(min(values), max(values), name=key))
        elif all(isinstance(v, float) for v in values):
            search_space.append(Real(min(values), max(values), name=key))
        else:
            search_space.append(Categorical(values, name=key))

    @use_named_args(search_space)
    def objective(**params):
        # skopt may hand back NumPy scalars; convert so results stay JSON-safe.
        config = {name: _to_native(params[name]) for name in param_names}

        key = _config_key(config)
        if key in seen_keys:
            if verbose:
                print(f"Skipping already evaluated config: {config}")
            # Worst possible value (scores are in [0, 1], negated below).
            return 1.0

        result = evaluate(config)
        _record_result(results, result)
        seen_keys.add(key)

        # gp_minimize minimises, so negate the score.
        return -result["score"]

    gp_minimize(
        objective,
        search_space,
        n_calls=n_trials,
        random_state=42,
        verbose=verbose,
    )


def _save_results(results, save_path, verbose=True):
    """Helper function to save results to disk"""
    os.makedirs(
        os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True
    )

    # Convert to DataFrame for easier analysis
    if results.get("search_metadata", {}).get("detection_mode", False):
        df_results = pd.DataFrame(
            [
                {
                    **r["config"],
                    "score": r["score"],
                    "accuracy": r["accuracy"],
                    "balanced_accuracy": r.get("balanced_accuracy", r["accuracy"]),
                    "f1_score": r["f1_score"],
                    "precision": r["precision"],
                    "recall": r["recall"],
                    "roc_auc": r.get("roc_auc", 0.0),
                    "auc": r.get("auc", 0.0),
                    "best_threshold": r.get("best_threshold", 0.5),
                }
                for r in results["all_configs"]
            ]
        )
    else:
        df_results = pd.DataFrame(
            [
                {
                    **r["config"],
                    "score": r["score"],
                    "model_acc": r["model_acc"],
                    "rule_acc": r["rule_acc"],
                }
                for r in results["all_configs"]
            ]
        )

    # Save CSV
    df_results.to_csv(f"{save_path}.csv", index=False)

    # Save JSON with full results
    with open(f"{save_path}.json", "w") as f:
        json.dump(results, f, indent=2)

    if verbose:
        print(f"Results saved to {save_path}.csv and {save_path}.json")


def _evaluate_generation_config(
    ar_model, trains, labels, config, model_hyp, batch_size, metric, verbose
):
    """
    Score one hyperparameter configuration on a generation task.

    Runs ``ar_model.generate`` once over ``trains`` with the configuration
    wrapped in a ``LogicConfig``. It then measures two accuracies against
    ``labels``:

    * model accuracy - the generated text (stripped, with every "." and every
      "the " removed) must equal the label;
    * rule accuracy - the first entry of the returned metadata's ``"rules"``
      must contain exactly one item, and that entry must match the label;
      otherwise ``["False"]`` is used as the prediction.

    Args:
        ar_model: The Activation Reasoning model
        trains: List of training examples
        labels: List of labels, aligned with ``trains``
        config: Configuration to evaluate (keyword arguments for ``LogicConfig``)
        model_hyp: Decoding hyperparameters forwarded to ``generate``
        batch_size: Batch size for processing
        metric: "model_accuracy" selects model accuracy as the score; any
            other value selects rule accuracy
        verbose: Print per-label and overall accuracies when True

    Returns:
        dict: ``{"config", "score", "model_acc", "rule_acc"}``
    """
    outputs, metadata = ar_model.generate(
        trains,
        model_hyp=model_hyp,
        logic_config=LogicConfig(**config),
        verbose=False,
        return_meta_data=True,
        batch_size=batch_size,
    )

    # Group predictions by their gold label.
    # NOTE(review): meta["rules"][0] is assumed to be the list of rules that
    # fired for the example - confirm against ActivationReasoning.generate.
    # NOTE(review): replace("the ", "") also strips "the " inside words
    # (e.g. "bathe ") - confirm this normalisation is intended.
    text_preds = defaultdict(list)
    rule_preds = defaultdict(list)
    for _, label, text, meta in zip(trains, labels, outputs, metadata):
        text_preds[label].append(text.strip().replace(".", "").replace("the ", ""))
        fired = meta["rules"][0]
        rule_preds[label].append(fired if len(fired) == 1 else ["False"])

    # Per-label accuracies (printed in verbose mode) and flattened pairs for
    # the overall scores, both built in one pass over the labels.
    gold_text, pred_text = [], []
    gold_rule, pred_rule = [], []
    for label, preds in text_preds.items():
        fired_preds = rule_preds[label]
        label_text_acc = accuracy_score([label] * len(preds), preds)
        label_rule_acc = accuracy_score([label] * len(fired_preds), fired_preds)
        if verbose:
            print(
                f"Rule: {label}, Model: {label_text_acc * 100:.2f}%, Rule: {label_rule_acc * 100:.2f}%"
            )
        gold_text += [label] * len(preds)
        pred_text += preds
        gold_rule += [label] * len(fired_preds)
        pred_rule += fired_preds

    model_accuracy = accuracy_score(gold_text, pred_text)
    rule_accuracy = accuracy_score(gold_rule, pred_rule)

    if metric == "model_accuracy":
        score = model_accuracy
    else:
        score = rule_accuracy

    if verbose:
        print(f"Config: {config}")
        print(f"Model Accuracy: {model_accuracy * 100:.3f}%")
        print(f"Rule Accuracy: {rule_accuracy * 100:.3f}%")
        print("-" * 40)

    return {
        "config": config,
        "score": score,
        "model_acc": model_accuracy,
        "rule_acc": rule_accuracy,
    }


def _evaluate_detection_config(
    ar_model, train_inputs, train_labels, config, eval_data, batch_size, metric, verbose
):
    """
    Evaluate one configuration for detection tasks.

    Configures ``ar_model`` with ``LogicConfig(**config)`` and scores it with
    ``ar.eval.evaluate_model`` on ``eval_data``. Calling evaluate_model with
    ``threshold=None`` asks it to find the optimal decision threshold.

    Args:
        ar_model: The Activation Reasoning model
        train_inputs: Training examples (unused here; kept for interface parity)
        train_labels: Training labels (unused here; kept for interface parity)
        config: Configuration to evaluate (keyword arguments for ``LogicConfig``)
        eval_data: Tuple of (eval_inputs, eval_labels) for evaluation
        batch_size: Batch size for processing
        metric: Metric to optimize ('accuracy', 'balanced_accuracy', 'f1_score',
            'precision', 'recall', 'auc', 'roc_auc'). 'auc' is the area under
            the precision-recall curve ("auprc" in evaluate_model's metrics).
            Unknown names fall back to balanced accuracy.
        verbose: Whether to print verbose output

    Returns:
        dict: Result dictionary with the chosen ``score``, all detection
        metrics, and the ``best_threshold`` reported by evaluate_model.
    """
    # Imported lazily so generation-only use does not require ar.eval.
    from ar.eval import evaluate_model

    logic_config = LogicConfig(**config)

    # Configure the model with the current configuration
    ar_model.configure(logic_config)

    eval_inputs, eval_labels = eval_data

    eval_result = evaluate_model(
        model=ar_model,
        test_data=eval_inputs,
        test_labels=eval_labels,
        batch_size=batch_size,
        threshold=None,  # Let evaluate_model find the optimal threshold
        verbose=verbose,
        save_path=None,  # Don't save plots during hyperparameter search
    )

    # Map search metric names to evaluate_model's metric keys. "roc_auc" was
    # previously missing here and silently optimised balanced accuracy.
    metric_mapping = {
        "accuracy": "accuracy",
        "balanced_accuracy": "balanced_accuracy",
        "f1_score": "f1",
        "precision": "precision",
        "recall": "recall",
        "auc": "auprc",
        "roc_auc": "roc_auc",
    }

    # Use balanced accuracy as default if metric not found
    eval_metric = metric_mapping.get(metric, "balanced_accuracy")

    metrics = eval_result["metrics"]
    best_threshold = metrics["threshold"]

    if eval_metric in metrics:
        score = metrics[eval_metric]
    else:
        score = metrics["balanced_accuracy"]

    if verbose:
        print(f"Config: {config}")

    return {
        "config": config,
        "score": score,
        "accuracy": metrics["accuracy"],
        "balanced_accuracy": metrics["balanced_accuracy"],
        "f1_score": metrics["f1"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "roc_auc": metrics["roc_auc"],
        "auc": metrics["auprc"],
        "best_threshold": best_threshold,
    }
