#!/usr/bin/env python3
import argparse
import json
import logging
import shutil
import warnings
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hard dependencies that are not part of the standard library; abort early with
# an actionable message if either is missing.
try:
    import statsmodels.api as sm
    from jinja2 import Template
except ImportError as err:
    # This runs before logging.basicConfig below; the module-level
    # logging.error() installs a default stderr handler on an unconfigured
    # root logger, so the message is still shown.
    logging.error(
        "Required libraries not found (%s). "
        "Please run reproduce_all.sh to install dependencies.",
        err,
    )
    # `exit` is provided by the `site` module and is absent under `python -S`
    # or in frozen executables; raising SystemExit is the portable equivalent.
    raise SystemExit(1) from err
from analysis_plots.ablation_plotter import run_ablation_analysis
from analysis_plots.kaplan_meier_plotter import create_kaplan_meier_plots
from analysis_plots.hazard_plotter import create_hazard_plots
from analysis_plots.tool_usage_plotter import create_tool_usage_plots
from analysis_plots.peer_exposure_plotter import create_peer_exposure_plots
from analysis_plots.event_distribution_plotter import create_event_distribution_plot
from analysis_plots.statistical_modeler import run_statistical_models
from analysis_plots.forest_plotter import create_hazard_forest_plot
from analysis_plots.summary_metrics_calculator import calculate_summary_metrics
from analysis_plots.bulk.suite_survival_plotter import create_suite_survival_plot

# NOTE(review): this ignores *every* warning category process-wide, including
# any convergence/fit warnings raised by the statistical modelling code —
# confirm that blanket suppression is intended.
warnings.simplefilter("ignore")

# Root logger configuration: INFO and above, timestamped single-line records.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide matplotlib style and seaborn colour palette used for the plots.
plt.style.use("default")
sns.set_palette(palette="husl")
# Jinja2 HTML template for ``suite_comparison_report.html`` (rendered in
# ExperimentAnalyzer.generate_report). Context keys used: ``experiment_name``,
# ``analysis_date``, ``total_experiments`` and ``total_trajectories`` (the
# latter is passed through "%.0f", so it must be numeric). The image path is
# relative: the PNG is expected next to the HTML file — presumably written to
# the analysis directory by create_suite_survival_plot; confirm.
SUITE_COMPARISON_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{{ experiment_name }}</title>
    <style>
        body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; margin: 40px auto; max-width: 1200px; line-height: 1.6; color: #333; }
        h1, h2, h3 { line-height: 1.2; color: #000; }
        h1 { font-size: 2.2em; }
        h2 { font-size: 1.8em; border-bottom: 2px solid #eee; padding-bottom: 10px; margin-top: 40px; }
        img { max-width: 100%; height: auto; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); margin-top: 20px; }
        .container { display: flex; flex-wrap: wrap; justify-content: center; }
        .plot { width: 100%; margin-bottom: 30px; }
        .header { background-color: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 30px; border: 1px solid #dee2e6; }
    </style>
</head>
<body>
    <div class="header">
        <h1>{{ experiment_name }}</h1>
        <p><strong>Analysis Date:</strong> {{ analysis_date }}</p>
        <p><strong>Total Experiments Analyzed:</strong> {{ total_experiments }}</p>
        <p><strong>Total Agent Trajectories:</strong> {{ "%.0f"|format(total_trajectories) }}</p>
    </div>
    <h2>Cross-Model Survival Comparison</h2>
    <div class="container">
        <div class="plot">
            <img src="suite_survival_comparison.png" alt="Suite Survival Comparison Plot">
        </div>
    </div>
</body>
</html>
"""
# Jinja2 HTML template for ``analysis_report.html`` (rendered in
# ExperimentAnalyzer.generate_report). Context keys used:
#   experiment_name, analysis_date, risk_horizon   - shown verbatim
#   total_trajectories, total_cells                 - formatted "%.0f" (numeric)
#   metrics.overall_results.{initial_eat_rate, total_eat_rate, winners_rate}
#                                                   - formatted "%.3f" (numeric)
#   metrics.overall_results.{median_tte, rmst, total_winners} and
#   metrics.data_quality.*                          - shown verbatim
#   model_results.hazard_model.summary              - preformatted text block
# All <img> paths are relative, so the PNGs are expected in the same directory
# as the HTML — presumably written there by the plotting helpers that receive
# the analysis directory; confirm.
MAIN_REPORT_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{{ experiment_name }}</title>
    <style>
        body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; margin: 40px auto; max-width: 1200px; line-height: 1.6; color: #333; }
        h1, h2, h3 { line-height: 1.2; color: #000; }
        img { max-width: 100%; height: auto; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1); margin-top: 20px; }
        .container { display: flex; flex-wrap: wrap; justify-content: space-between; }
        .plot { width: 48%; margin-bottom: 30px; }
        .full-width { width: 100%; }
        .header { background-color: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 30px; border: 1px solid #dee2e6; }
        pre { background-color: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #dee2e6; white-space: pre-wrap; word-wrap: break-word; }
    </style>
</head>
<body>
    <div class="header">
        <h1>{{ experiment_name }}</h1>
        <p><strong>Analysis Date:</strong> {{ analysis_date }}</p>
        <p><strong>Total Trajectories:</strong> {{ "%.0f"|format(total_trajectories) }}</p>
        <p><strong>Total Cells:</strong> {{ "%.0f"|format(total_cells) }}</p>
        <p><strong>Risk Horizon:</strong> {{ risk_horizon }} steps</p>
    </div>

    <h2>Key Survival Metrics</h2>
    <div class="container">
        <div class="full-width">
            <h3>Overall Results</h3>
            <table>
                <tr><th style="text-align: left;">Metric</th><th style="text-align: left;">Value</th></tr>
                <tr><td>Initial Eat Rate (t=0 impulse)</td><td>{{ "%.3f"|format(metrics.overall_results.initial_eat_rate) }}</td></tr>
                <tr><td>Total Eat Rate</td><td>{{ "%.3f"|format(metrics.overall_results.total_eat_rate) }}</td></tr>
                <tr><td>Winners Rate</td><td>{{ "%.3f"|format(metrics.overall_results.winners_rate) }}</td></tr>
                <tr><td>Median Time-to-Eat</td><td>{{ metrics.overall_results.median_tte }}</td></tr>
                <tr><td>RMST (Expected Waiting Steps)</td><td>{{ metrics.overall_results.rmst }}</td></tr>
                <tr><td>Total Winners</td><td>{{ metrics.overall_results.total_winners }}</td></tr>
            </table>
            <h3>Data Quality</h3>
            <table>
                <tr><th style="text-align: left;">Metric</th><th style="text-align: left;">Value</th></tr>
                <tr><td>Data Quality Rate</td><td>{{ metrics.data_quality.data_quality_rate }}</td></tr>
                <tr><td>Invalid Outcome Rate</td><td>{{ metrics.data_quality.invalid_outcome_rate }}</td></tr>
                <tr><td>Valid Agent Trajectories</td><td>{{ metrics.data_quality.valid_agent_trajectories }}</td></tr>
                <tr><td>Invalid Agent Trajectories</td><td>{{ metrics.data_quality.invalid_agent_trajectories }}</td></tr>
            </table>
        </div>
    </div>

    <h2>Overall Summary Plots</h2>
    <div class="container"><div class="plot full-width"><img src="suite_survival_comparison.png" alt="Suite Survival Comparison"></div><div class="plot"><img src="overall_hazard_rates.png" alt="Overall Hazard Rates"></div><div class="plot"><img src="overall_event_distribution.png" alt="Overall Event Distribution"></div></div>
    <h2>Survival Analysis by Experimental Factors</h2>
    <div class="container"><div class="plot full-width"><img src="kaplan_meier_curves.png" alt="Kaplan-Meier Curves by Factors"></div><div class="plot full-width"><img src="hazard_rates.png" alt="Hazard Rates by Factors"></div><div class="plot full-width"><img src="event_distribution.png" alt="Event Distribution by Factors"></div></div>
    <h2>Ablation Analysis: Survival Curves</h2>
    <div class="container"><div class="plot full-width"><img src="ablation_survival_curves.png" alt="Ablation Survival Curves"></div></div>
    <h2>Behavioral Analysis</h2>
    <div class="container"><div class="plot full-width"><img src="tool_usage_patterns.png" alt="Tool Usage Patterns"></div><div class="plot full-width"><img src="tool_usage_by_policy.png" alt="Tool Usage by Policy"></div><div class="plot full-width"><img src="tool_usage_ablation.png" alt="Tool Usage Ablation"></div></div>
    <h2>Social Influence Analysis</h2>
    <div class="container"><div class="plot full-width"><img src="peer_exposure_patterns.png" alt="Peer Exposure Patterns"></div><div class="plot full-width"><img src="peer_exposure_by_factors.png" alt="Peer Exposure By Factors"></div></div>
    <h2>Statistical Model Results</h2>
    <div class="container">
        <div class="plot full-width">
            <h3>Discrete-Time Hazard Model Effects</h3>
            <img src="hazard_forest_plot.png" alt="Hazard Model Forest Plot">
        </div>
        <div class="full-width">
            <h3>Discrete-Time Hazard Model Summary</h3>
            <pre><code>{{ model_results.hazard_model.summary }}</code></pre>
        </div>
    </div>
</body>
</html>
"""


class ExperimentAnalyzer:
    """Run the survival-analysis pipeline over a suite of experiment runs.

    ``suite_dir`` holds one sub-directory per experiment run; only runs that
    contain an ``aggregated_data`` folder are loaded. Plots, HTML reports and
    JSON summaries are written to ``output_dir``.
    """

    # CSV tables read from every run's ``aggregated_data`` folder. After
    # ``load_data`` each one is available as an attribute of the same name.
    _TABLES = ("agent_outcomes", "step_level_data", "cell_aggregates", "cell_summary")

    def __init__(
        self,
        suite_dir: Path,
        output_dir: Path,
        social_influence: bool = False,
        survival_prob_range: Optional[List[float]] = None,
    ):
        """Store configuration; no data is read until ``load_data`` runs.

        Args:
            suite_dir: Directory containing one folder per experiment run.
            output_dir: Directory that receives plots and reports (created on
                demand).
            social_influence: Forwarded to the plotting helpers that take it.
            survival_prob_range: ``[ymin, ymax]`` for survival-plot y-axes;
                ``None`` selects ``[0.0, 1.05]``.
        """
        self.suite_dir = Path(suite_dir)
        self.analysis_dir = Path(output_dir)
        self.social_influence = social_influence
        # Build the default per instance: a list literal in the signature would
        # be a single object shared by every analyzer (and by every plotting
        # helper it is passed to), so any mutation would leak across runs.
        self.survival_prob_range = (
            [0.0, 1.05] if survival_prob_range is None else survival_prob_range
        )
        self.suite_name = self.suite_dir.name
        self.timestamp = pd.Timestamp.now().strftime("%Y-%m-%d_%H-%M-%S")
        logger.info(f"Initializing analyzer for experiment suite: {self.suite_name}")

    def load_data(self):
        """Load and concatenate the aggregated CSV tables of every run.

        Every loaded frame gets an ``experiment_id`` column holding the run's
        directory name. Runs are processed in sorted path order so the combined
        frames do not depend on the filesystem's listing order.

        Raises:
            ValueError: if no run with an ``aggregated_data`` folder exists.
            FileNotFoundError: if a run is missing one of the expected CSVs
                (raised by ``pd.read_csv``).
        """
        logger.info(f"Scanning for experiment runs in: {self.suite_dir}")
        experiment_dirs = sorted(
            d
            for d in self.suite_dir.iterdir()
            if d.is_dir() and (d / "aggregated_data").exists()
        )
        if not experiment_dirs:
            # Without this, pd.concat([]) fails with the opaque
            # "No objects to concatenate".
            raise ValueError(
                "No experiment runs with an 'aggregated_data' folder found in "
                f"{self.suite_dir}"
            )
        frames: Dict[str, List[pd.DataFrame]] = {name: [] for name in self._TABLES}
        for exp_dir in experiment_dirs:
            logger.info(f"Loading data for experiment: {exp_dir.name}")
            agg_dir = exp_dir / "aggregated_data"
            for name, collected in frames.items():
                df = pd.read_csv(agg_dir / f"{name}.csv")
                df["experiment_id"] = exp_dir.name
                collected.append(df)
        self.agent_outcomes = pd.concat(frames["agent_outcomes"], ignore_index=True)
        self.step_level_data = pd.concat(frames["step_level_data"], ignore_index=True)
        self.cell_aggregates = pd.concat(frames["cell_aggregates"], ignore_index=True)
        self.cell_summary = pd.concat(frames["cell_summary"], ignore_index=True)
        logger.info(f"Loaded data across {len(experiment_dirs)} experiments.")

    def generate_report(self, metrics: Dict, model_results: Dict):
        """Render both HTML reports and dump ``metrics``/``model_results`` as JSON.

        Writes ``analysis_report.html``, ``suite_comparison_report.html``,
        ``summary_metrics.json`` and ``model_results.json`` into the analysis
        directory. Requires ``load_data`` to have been called.
        """
        logger.info("Generating analysis reports...")
        self.analysis_dir.mkdir(exist_ok=True, parents=True)
        outcomes = self.agent_outcomes
        # Fall back to "N/A" for an empty frame or one lacking the column.
        if len(outcomes) > 0 and "risk_horizon" in outcomes.columns:
            risk_horizon = outcomes["risk_horizon"].iloc[0]
        else:
            risk_horizon = "N/A"
        # autoescape: the model summary is free text that may contain
        # HTML-special characters such as "<", ">" or "&".
        main_template = Template(MAIN_REPORT_TEMPLATE, autoescape=True)
        main_context = {
            "experiment_name": f"Analysis Suite: {self.suite_name}",
            "analysis_date": self.timestamp,
            "total_trajectories": len(outcomes),
            "total_cells": len(self.cell_summary),
            "risk_horizon": risk_horizon,
            "model_results": model_results,
            "metrics": metrics,
        }
        self._write_text("analysis_report.html", main_template.render(main_context))
        suite_template = Template(SUITE_COMPARISON_TEMPLATE, autoescape=True)
        suite_context = {
            "experiment_name": f"Suite Comparison: {self.suite_name}",
            "analysis_date": self.timestamp,
            "total_trajectories": len(outcomes),
            "total_experiments": outcomes["experiment_id"].nunique(),
        }
        self._write_text(
            "suite_comparison_report.html", suite_template.render(suite_context)
        )
        for name, data in (
            ("summary_metrics.json", metrics),
            ("model_results.json", model_results),
        ):
            # default=str: stringify values json cannot encode natively.
            self._write_text(name, json.dumps(data, indent=2, default=str))
        logger.info(f"✓ Generated analysis reports in: {self.analysis_dir}")

    def _write_text(self, filename: str, text: str) -> None:
        """Write ``text`` to ``analysis_dir / filename`` encoded as UTF-8.

        The HTML templates declare ``<meta charset="UTF-8">``; writing with the
        platform default encoding (e.g. cp1252 on Windows) would contradict it.
        """
        (self.analysis_dir / filename).write_text(text, encoding="utf-8")

    def run_complete_analysis(self):
        """Load data, create every plot, compute metrics/models, write reports.

        Returns:
            True if every stage completed; False if any stage raised (the
            exception is logged with its traceback and not re-raised).
        """
        logger.info("Starting complete analysis pipeline...")
        try:
            self.load_data()

            # Ensure analysis directory exists before creating plots
            self.analysis_dir.mkdir(exist_ok=True, parents=True)

            create_kaplan_meier_plots(
                self.agent_outcomes,
                self.analysis_dir,
                self.social_influence,
                self.survival_prob_range,
            )
            create_hazard_plots(
                self, self.cell_aggregates, self.analysis_dir, self.social_influence
            )
            create_tool_usage_plots(
                self.step_level_data, self.analysis_dir, self.social_influence
            )
            create_peer_exposure_plots(self.step_level_data, self.analysis_dir)
            create_event_distribution_plot(
                self, self.cell_aggregates, self.analysis_dir, self.social_influence
            )
            create_suite_survival_plot(
                self.agent_outcomes,
                self.analysis_dir,
                self.social_influence,
                self.survival_prob_range,
            )
            metrics = calculate_summary_metrics(self.cell_summary, self.agent_outcomes)
            model_results = run_statistical_models(
                self.cell_aggregates, self.step_level_data
            )
            create_hazard_forest_plot(model_results, self.analysis_dir)
            run_ablation_analysis(
                self.agent_outcomes,
                self.step_level_data,
                self.analysis_dir,
                self.social_influence,
                self.survival_prob_range,
            )
            self.generate_report(metrics, model_results)
            logger.info("Analysis pipeline completed successfully!")
            return True
        except Exception as e:
            # Top-level boundary: report the failure and signal it to main().
            logger.error(f"Analysis pipeline failed: {e}", exc_info=True)
            return False


def main(argv: Optional[List[str]] = None):
    """Command-line entry point: analyze one suite directory, then exit.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Exits with status 0 when the pipeline succeeds and 1 when it fails;
    argparse itself exits with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Run a standalone analysis suite.")
    # Every flag accepts both the underscore and the hyphen spelling, so the
    # original (mixed) spellings keep working for existing scripts.
    parser.add_argument(
        "--suite_directory",
        "--suite-directory",
        dest="suite_directory",
        required=True,
        help="Path to the directory of experiment folders.",
    )
    parser.add_argument(
        "--output_directory",
        "--output-directory",
        dest="output_directory",
        required=True,
        help="Path to save the analysis report.",
    )
    parser.add_argument(
        "--survival-prob-range",
        "--survival_prob_range",
        dest="survival_prob_range",
        type=float,
        nargs=2,
        metavar=("YMIN", "YMAX"),
        default=[0.4, 1.05],
        help="Y-axis range for survival plots.",
    )
    args = parser.parse_args(argv)
    analyzer = ExperimentAnalyzer(
        suite_dir=Path(args.suite_directory),
        output_dir=Path(args.output_directory),
        # The CLI always runs with social-influence factors enabled.
        social_influence=True,
        survival_prob_range=args.survival_prob_range,
    )
    success = analyzer.run_complete_analysis()
    # SystemExit rather than the site-provided `exit()` helper, which is not
    # guaranteed to exist (e.g. under `python -S` or in frozen executables).
    raise SystemExit(0 if success else 1)


if __name__ == "__main__":
    main()
