import os
import sys
import yaml
import argparse
import shutil
import logging
import json
from datetime import datetime

# Resolve the project root (the parent of this file's directory) and put it at
# the front of sys.path, unless it is already listed there, so that the
# `scripts` and `modules` packages imported below can be found.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import runner modules
from scripts.run_parallel_elicitation import main as run_parallel_elicitation
from scripts.run_regression_experiment import run_regression_experiment
from modules.response_converter import ResponseUtils


def setup_logger(log_path):
    """
    Configure and return the shared pipeline logger.

    The logger named "plugin_pipeline_logger" is reset on every call: handlers
    left over from a previous call are detached and closed, then a file
    handler writing to ``log_path`` and a console handler writing to
    ``sys.stdout`` are attached. Both use the same timestamped format and the
    logger level is set to INFO.

    Args:
        log_path: Path of the log file to write to.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger("plugin_pipeline_logger")
    logger.setLevel(logging.INFO)

    # Close stale handlers before dropping them. Simply clearing the list
    # would leave the previous run's log file open for the process lifetime.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    # File handler
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    return logger


def generate_regression_config(elicitation_config, output_dir, output_path):
    """
    Derive the regression-stage config from an elicitation config and save it as YAML.

    The experiment name and seed come from ``elicitation_config["metadata"]``;
    survey and endowment paths are taken from ``elicitation_config["paths"]``;
    the responses path points at ``responses.csv`` inside ``output_dir``.
    The model sections ("lasso", "elasticnet", "split_settings",
    "weight_assignment_model", "report") are copied from the elicitation
    config when present and otherwise filled with the defaults below.

    Args:
        elicitation_config: Parsed elicitation config (mapping).
        output_dir: Directory expected to contain the elicited ``responses.csv``.
        output_path: Destination file for the generated YAML.

    Returns:
        ``output_path``, unchanged.
    """
    def _validation_section():
        # Built fresh on every call so the fallback dicts of the two model
        # sections stay distinct objects.
        return elicitation_config.get("validation", {
            "strategy": "cv",
            "cv_folds": 5
        })

    metadata = elicitation_config["metadata"]
    source_paths = elicitation_config["paths"]

    # Responses are produced by the elicitation stage; endowments are pre-supplied.
    elicited_responses = os.path.join(output_dir, "responses.csv")
    endowments = source_paths["endowments_csv"]

    experiment_name = metadata["name"]
    experiment_section = {
        "name": f"{experiment_name}_weight_determination",
        "description": f"Weight determination on responses elicited from: {experiment_name}",
        "seed": metadata.get("seed", 101)
    }

    paths_section = {
        "survey_csv": source_paths["survey_csv"],
        "survey_yaml": source_paths["survey_yaml"],
        "responses_csv": elicited_responses,
        "endowments_csv": endowments,
        "aggregate_json": source_paths.get("aggregate_json")
    }

    lasso_fallback = {
        "alpha_expr": "np.logspace(-5, 1, 20)",
        "max_iter": 10000,
        "validation": _validation_section(),
        "post_selection_refit": False
    }
    lasso_section = elicitation_config.get("lasso", lasso_fallback)

    elasticnet_fallback = {
        "alpha_expr": "np.logspace(-5, 1, 20)",
        "l1_ratio_expr": "np.linspace(0.1,1,10)",
        "max_iter": 10000,
        "validation": _validation_section(),
        "plot_style": "2D",
        "post_selection_refit": False
    }
    elasticnet_section = elicitation_config.get("elasticnet", elasticnet_fallback)

    split_fallback = {
        "use_proxy_only": True,
        "train_val_split": ["train", "valid"],
        "test_split": "test"
    }
    split_section = elicitation_config.get("split_settings", split_fallback)

    model_choice = elicitation_config.get("weight_assignment_model", "lasso")
    report_css = elicitation_config.get("report", {}).get("regression", "styles/lasso_report.css")

    regression_config = {
        "experiment": experiment_section,
        "paths": paths_section,
        "lasso": lasso_section,
        "elasticnet": elasticnet_section,
        "split_settings": split_section,
        "weight_assignment_model": model_choice,
        "report": {"css": report_css}
    }

    with open(output_path, "w") as handle:
        yaml.safe_dump(regression_config, handle)

    return output_path


def run_plugin_pipeline(config_path, token_tracker=None, output_dir=None):
    """
    Run both parallel elicitation and regression experiment sequentially.

    Stages, in order:
      1. Load the elicitation config YAML and create (or reuse) the output directory.
      2. Set up logging to ``<output_dir>/pipeline.log`` and copy the config
         to ``<output_dir>/original_config.yaml``.
      3. Run the parallel elicitation on the copied config.
      4. Count valid questions/endowments per split and write them to
         ``<output_dir>/valid_counts_summary.json``.
      5. Generate ``<output_dir>/regression_config.yaml`` and run the
         regression experiment with it.

    Args:
        config_path: Path to the plugin (parallel elicitation) config YAML.
        token_tracker: Optional object forwarded unchanged to the elicitation runner.
        output_dir: Directory for all outputs. When ``None``, a directory
            ``outputs/<metadata.name>_<YYYYmmdd_HHMMSS>`` is created.

    Returns:
        None.
    """
    # === Load original config
    with open(config_path, "r") as f:
        elicitation_config = yaml.safe_load(f)

    # === Create/reuse output directory
    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        exp_name = elicitation_config["metadata"]["name"]
        output_dir = os.path.join("outputs", f"{exp_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # === Setup logger
    log_path = os.path.join(output_dir, "pipeline.log")
    logger = setup_logger(log_path)
    logger.info(f"[Pipeline] Output directory: {output_dir}")

    # === Copy original config
    config_copy_path = os.path.join(output_dir, "original_config.yaml")
    try:
        shutil.copy(config_path, config_copy_path)
    except shutil.SameFileError:
        # Re-running on a reused output directory with its own stored config:
        # the copy is already in place, so there is nothing to do.
        logger.info("[Pipeline] Config is already stored in the output directory; skipping copy.")

    # === Run Parallel Elicitation
    logger.info("[Pipeline] Running plugin (parallel) survey elicitation...")
    responses, _ = run_parallel_elicitation(
        config_copy_path,
        output_dir=output_dir,
        return_outputs=True,
        token_tracker=token_tracker,
    )
    responses_code = responses.to_format(output_format="code")  # for validation
    survey = responses.survey

    # === Count Valid Questions
    logger.info("[Pipeline] Computing valid question counts after response cleaning...")

    summary_stats = {
        "valid_questions_by_split": {},
        "valid_eids_by_split": {},
        "total_questions": len(responses.responses),
        "total_endowments": len(responses)
    }

    for split in ["train", "valid", "test"]:
        # Only the first returned value is used; its keys are treated as the
        # question ids to drop (presumably questions with unmapped answers —
        # confirm against ResponseUtils.analyze_missing_mappings).
        problematic_answers, *_ = ResponseUtils.analyze_missing_mappings(
            code_responses=responses_code,
            answer_responses=responses,
            split=split,
            verbose=False
        )
        qids_to_remove = list(problematic_answers.keys())

        # Rows of the split matrix are dropped by index; ids that are not
        # present are ignored rather than raising.
        df_split = responses.get_matrix_by_split(split, dropna=True)
        df_clean = df_split.drop(index=qids_to_remove, errors="ignore")

        valid_eids_count = ResponseUtils.count_valid_endowments(responses_code, split=split)

        summary_stats["valid_questions_by_split"][split] = df_clean.shape[0]
        summary_stats["valid_eids_by_split"][split] = valid_eids_count

        logger.info(f"[Pipeline] Valid questions in {split}: {df_clean.shape[0]} "
                    f"(cleaned from {df_split.shape[0]} original questions)")
        logger.info(f"[Pipeline] Valid endowments in {split}: {valid_eids_count}")

    # === Count valid EIDs across all questions
    valid_eids_all = ResponseUtils.count_valid_endowments(responses_code, split=None)
    logger.info(f"[Pipeline] Valid endowments across all splits: {valid_eids_all}")
    summary_stats["valid_eids_all"] = valid_eids_all

    summary_path = os.path.join(output_dir, "valid_counts_summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary_stats, f, indent=2)
    logger.info(f"[Pipeline] Valid count summary saved to {summary_path}")

    # === Generate Regression Config
    regression_config_path = os.path.join(output_dir, "regression_config.yaml")
    generate_regression_config(elicitation_config, output_dir, regression_config_path)
    logger.info("[Pipeline] Regression config saved.")

    # === Run Regression Experiment
    logger.info("[Pipeline] Running regression experiment...")
    run_regression_experiment(regression_config_path, output_dir)

    logger.info("[Pipeline] All stages complete.")


if __name__ == "__main__":
    # Command-line entry point: the only (mandatory) option is the config path.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to plugin (parallel elicitation) config YAML file.",
    )
    run_plugin_pipeline(cli.parse_args().config)
