import os
import sys
import yaml
import argparse
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm
import random


# Make the repository root (the parent of this file's directory) importable so
# the project-local `modules` / `experiments` imports below resolve when this
# file is executed directly as a script.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..")
)
if project_root not in sys.path:
    # Prepend so project packages take precedence over anything later on the path.
    sys.path[:0] = [project_root]

from modules.survey_converter import Survey, BinaryExtendedSurvey
from modules.response_converter import Responses, BinaryExtendedResponses
from modules.endowment_manager import ActiveEndowments
from experiments.experiments import SimulationExperiment
import logging
import copy

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach handlers only once per logger object: getLogger() returns the same
# instance on re-import / importlib.reload (e.g. re-running a Jupyter cell), so
# checking this logger's *own* handlers is what prevents duplicate output.
# Logger.hasHandlers() is not used as this outer guard because it also looks at
# ancestor loggers; any handler on the root logger would then silently skip the
# whole setup, including the per-run log file.
if not logger.handlers:
    formatter = logging.Formatter("[%(levelname)s] %(message)s")

    # This logger has no handlers of its own here, so hasHandlers() reports
    # only ancestors. Add a console handler only if none of them will already
    # receive the propagated records; otherwise every line would print twice.
    if not logger.hasHandlers():
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # File handler (with timestamped filename). "logs" is relative to the
    # current working directory.
    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"entropy_sweep_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

from experiments.simulations import run_entropy_sweep, assign_ground_truth_from_modes

def compute_ordered_modes(endowments, responses, survey):
    """
    Compute modes ordered by entropy using raw (not binarized) responses.

    The passed-in ``responses`` / ``survey`` objects are used only for their file
    locations: a plain ``Survey`` is rebuilt from ``survey.csv_path`` and
    ``survey.config_path``, and a ``Responses`` object with
    ``output_format="answer"`` from ``responses.source_path``. Entropy is then
    computed on that freshly loaded raw data. Either raw (``Survey`` /
    ``Responses``) or binarized (``BinaryExtendedSurvey`` /
    ``BinaryExtendedResponses``) objects may therefore be passed, as long as
    they expose those attributes.

    Args:
        endowments (ActiveEndowments): Agent metadata; its
            ``get_entropy_by_mode`` yields ``(mode, entropy, count)`` triples.
        responses: Response object exposing ``source_path``.
        survey: Survey object exposing ``csv_path`` and ``config_path``.

    Returns:
        List[Tuple[str]]: Modes in the reverse of the order returned by
        ``get_entropy_by_mode``. This is high-to-low entropy, assuming that
        method sorts ascending (TODO confirm).
    """
    original_survey = Survey(csv_path=survey.csv_path, config_path=survey.config_path)
    original_responses = Responses(source=responses.source_path, survey=original_survey, output_format="answer")
    sorted_modes = endowments.get_entropy_by_mode(original_responses, original_survey)
    # Keep only the mode of each (mode, entropy, count) triple, in reversed order.
    return [mode for mode, _entropy, _count in sorted_modes[::-1]]

def get_gt_modes_by_entropy_range(
    sorted_modes,
    lower: float = None,
    upper: float = None
) -> list:
    """
    Select the modes whose entropy falls inside the closed interval [lower, upper].

    Args:
        sorted_modes (List[Tuple[Tuple[str], float, int]]): ``(mode, entropy, count)``
            triples, as produced by get_entropy_by_mode().
        lower (float, optional): Inclusive lower bound; ``None`` means unbounded below.
        upper (float, optional): Inclusive upper bound; ``None`` means unbounded above.

    Returns:
        List[Tuple[str]]: The qualifying modes, in the same order as the input.
    """
    def _within_bounds(value):
        return (lower is None or value >= lower) and (upper is None or value <= upper)

    return [mode for mode, entropy, _count in sorted_modes if _within_bounds(entropy)]


def _select_ground_truth_modes(runner_cfg, sorted_modes, ordered_modes, verbose):
    """Return the ground-truth modes requested by the ``runner`` config section.

    Precedence: entropy bounds (``entropy_threshold_lower`` /
    ``entropy_threshold_upper``), then an explicit ``ground_truth_modes`` list,
    then the last two entries of ``ordered_modes``.

    Raises:
        ValueError: If entropy bounds are given but no mode satisfies them.
    """
    lower = runner_cfg.get("entropy_threshold_lower", None)
    upper = runner_cfg.get("entropy_threshold_upper", None)

    if lower is not None or upper is not None:
        gt_modes = get_gt_modes_by_entropy_range(sorted_modes, lower=lower, upper=upper)
        if verbose:
            bound_desc = []
            if lower is not None:
                bound_desc.append(f"≥ {lower}")
            if upper is not None:
                bound_desc.append(f"≤ {upper}")
            logger.info(f"Using entropy threshold(s) {' and '.join(bound_desc)} → {len(gt_modes)} ground truth modes selected.")
        if not gt_modes:
            raise ValueError(f"No modes satisfy entropy thresholds: lower={lower}, upper={upper}. Try relaxing the bounds.")
        return gt_modes

    manual_modes = runner_cfg.get("ground_truth_modes")
    if manual_modes is None:
        # Missing or explicitly null in the config: fall back to the default.
        gt_modes = ordered_modes[-2:]
        if verbose:
            logger.info(f"Using default ground truth modes (last two entropy-ordered modes): {gt_modes}")
        return gt_modes

    # yaml.safe_load has no tuple type, so modes written in the config arrive as
    # lists. Convert them so they compare equal to (and hash like) the tuple
    # modes produced by get_entropy_by_mode().
    gt_modes = [tuple(mode) if isinstance(mode, list) else mode for mode in manual_modes]
    if verbose:
        logger.info(f"Using manually specified ground truth modes: {gt_modes}")
    return gt_modes


def _aggregate_results(all_dfs, output_dir, timestamp):
    """Combine the per-run result frames.

    A single frame is returned unchanged. Several frames are concatenated,
    written to ``summary_all_<timestamp>.csv`` in ``output_dir``, and returned.
    """
    if len(all_dfs) == 1:
        return all_dfs[0]
    df_all = pd.concat(all_dfs, ignore_index=True)
    summary_path = os.path.join(output_dir, f"summary_all_{timestamp}.csv")
    df_all.to_csv(summary_path, index=False)
    logger.info(f"\n[✓] All rounds completed. Summary saved to: {summary_path}")
    return df_all


def run_multi_round_entropy_sweep(config_path: str):
    """Run the entropy sweep for every configured model over several seeded repeats.

    Args:
        config_path: Path to the YAML config. It reads the ``experiment.name``
            key, the ``runner`` section (seed, n_repeats, n_ground_truth,
            n_proxies, models, plot, verbose, output_dir, entropy thresholds /
            ground_truth_modes) and the ``paths`` section (survey_csv,
            survey_yaml, responses_csv, endowments_csv).

    Returns:
        pandas.DataFrame: If only one (repeat, model) run was made, that run's
        result from ``run_entropy_sweep``, tagged with ``repeat`` and ``model``
        columns. Otherwise, all runs concatenated; that concatenation is also
        saved as ``summary_all_<timestamp>.csv`` in the output directory.

    Raises:
        ValueError: If ``n_repeats`` < 1, no model is configured, or the entropy
            thresholds select no mode.
        KeyError: If a required ``paths`` entry is missing.
    """
    # Load YAML config
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    experiment_cfg = config.get("experiment", {})
    runner_cfg = config.get("runner", {})
    paths = config.get("paths", {})

    seed = runner_cfg.get("seed", 101)
    n_repeats = runner_cfg.get("n_repeats", 1)
    n_gt = runner_cfg.get("n_ground_truth", 10)
    n_proxies = runner_cfg.get("n_proxies", 30)
    models = runner_cfg.get("models", ["lasso"])
    plot = runner_cfg.get("plot", False)
    verbose = runner_cfg.get("verbose", True)

    # `models: lasso` in YAML is a bare string; iterating it would yield characters.
    if isinstance(models, str):
        models = [models]
    # Fail fast, before any data is loaded, instead of hitting an IndexError at the end.
    if n_repeats < 1:
        raise ValueError(f"runner.n_repeats must be >= 1, got {n_repeats}.")
    if not models:
        raise ValueError("runner.models must list at least one model type.")

    experiment_name = experiment_cfg.get("name", "experiment").lower().replace(" ", "_")
    output_dir = runner_cfg.get("output_dir", f"outputs/entropy_sweep/{experiment_name}")
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # --- Load raw data for entropy ---
    original_survey = Survey(csv_path=paths["survey_csv"], config_path=paths["survey_yaml"])
    original_responses = Responses(source=paths["responses_csv"], survey=original_survey, output_format="answer")
    endowments_base = ActiveEndowments.load(path=paths["endowments_csv"])

    sorted_modes = endowments_base.get_entropy_by_mode(original_responses, original_survey)
    # Same result as compute_ordered_modes(), without reloading the raw data and
    # recomputing the entropy a second time.
    ordered_modes = [mode for mode, _entropy, _count in sorted_modes[::-1]]

    entropy_log_path = os.path.join(output_dir, f"entropy_sorted_{timestamp}.csv")
    pd.DataFrame(sorted_modes, columns=["mode", "entropy", "count"]).to_csv(entropy_log_path, index=False)

    gt_modes = _select_ground_truth_modes(runner_cfg, sorted_modes, ordered_modes, verbose)

    endowments_base.assign_roles()

    # The ordering does not change between repeats, so log it once.
    if verbose:
        logger.info("Ordered modes by entropy:")
        for i, mode in enumerate(ordered_modes):
            logger.info(f"  {i+1:02d}. {mode}")

    all_dfs = []

    for repeat in range(n_repeats):
        seed_i = seed + repeat
        logger.info(f"\n[⏱] Running repeat {repeat + 1} of {n_repeats} (seed={seed_i})...")

        # Each repeat draws its ground truth on a fresh copy of the base endowments.
        endowments = copy.deepcopy(endowments_base)
        endowments = assign_ground_truth_from_modes(endowments, gt_modes=gt_modes, n_gt=n_gt, seed=seed_i)

        # --- Now load binarized response setup ---
        survey = BinaryExtendedSurvey(csv_path=paths["survey_csv"], config_path=paths["survey_yaml"])
        responses = BinaryExtendedResponses(source=paths["responses_csv"], survey=survey, output_format="code")

        experiment = SimulationExperiment(
            responses=responses,
            survey=survey,
            endowments=endowments,
            aggregate_stats=None,
            filter_binary=True,
            drop_na=True,
        )

        for model_type in models:
            logger.info(f"\n>>> Running entropy sweep with {model_type} model...")

            # Paths for output
            output_csv = os.path.join(output_dir, f"entropysweep_{model_type}_rep{repeat+1}_{timestamp}.csv")
            plot_dir = os.path.join(output_dir, f"plots_{model_type}_rep{repeat+1}_{timestamp}") if plot else None
            config_copy_path = os.path.join(output_dir, f"config_copy_{model_type}_rep{repeat+1}_{timestamp}.yaml")

            # Save config for reproducibility
            with open(config_copy_path, "w") as f_out:
                yaml.dump(config, f_out)

            df = run_entropy_sweep(
                base_experiment=experiment,
                config=config,
                ordered_modes=ordered_modes,
                n_proxies=n_proxies,
                model_type=model_type,
                seed=seed_i,
                plot=plot,
                plot_dir=plot_dir,
                output_csv=output_csv,
                verbose=verbose
            )
            df["repeat"] = repeat + 1
            df["model"] = model_type
            all_dfs.append(df)

    # Aggregate whenever more than one run was made (several repeats OR several
    # models); previously a single repeat with several models returned only the
    # first model's results.
    return _aggregate_results(all_dfs, output_dir, timestamp)


def main():
    """Command-line entry point: read ``--config`` and launch the multi-round sweep."""
    cli = argparse.ArgumentParser(description="Run multi-round entropy sweep experiment.")
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML config file",
    )
    config_path = cli.parse_args().config
    run_multi_round_entropy_sweep(config_path)


# Run only when executed as a script, not when imported (e.g. from a notebook).
if __name__ == "__main__":
    main()