import argparse
import io
import logging
import pathlib
from datetime import date
from typing import Any, List, Literal

import cairosvg
import dvc.api
import numpy as np
import pandas as pd
import yaml
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.dates import date2num
from matplotlib.figure import Figure
from matplotlib.offsetbox import AnnotationBbox, HPacker, OffsetImage, TextArea
from PIL import Image

import src.utils.plots
from src.plot.logistic import (
    _get_title,
    _process_agent_summaries,
    fit_trendline,
    plot_horizon_graph,
    plot_trendline,
)

# Module-level logger named after this module; main() configures the root
# handler/level via logging.basicConfig before anything is logged.
logger = logging.getLogger(__name__)


def _add_watermark(fig: Figure, logo_path: pathlib.Path) -> None:
    """Placeholder for drawing a logo watermark on ``fig``; currently a no-op.

    Args:
        fig: Figure that would receive the watermark (ignored).
        logo_path: Path to the logo file (ignored).
    """
    return None


def _add_time_markers(ax: Axes) -> None:
    """Draw labelled example-task reference ticks along the y-axis.

    Every reference entry produces a short horizontal tick covering the
    leftmost 1% of the axes width, plus a text label placed just to the
    right of it on a semi-transparent white background.

    Args:
        ax: Axes that is annotated in place.
    """
    # y-axis position -> description of an example task of that length.
    # NOTE(review): the keys look like minutes (15 / 60 would be 15 s and
    # 4 * 60 would be 4 h) -- confirm against the unit of the plotted y-axis.
    reference_tasks = {
        15 / 60: "Answer question",
        2: "Count words in passage",
        10: "Find fact on web",
        49: "Train classifier",
        4 * 60: "Train adversarially robust image model",
    }

    for y_position, task_label in reference_tasks.items():
        # Short tick at the left edge; xmin/xmax are fractions of the axes width.
        ax.axhline(
            y=y_position,
            xmin=0,
            xmax=0.01,
            color="#2c7c58",
            linestyle="-",
            alpha=0.4,
            zorder=1,
        )
        # The y-axis transform places x in axes fraction and y in data
        # coordinates, so the label sits next to its tick at any zoom level.
        ax.text(
            0.02,
            y_position,
            task_label,
            transform=ax.get_yaxis_transform(),
            horizontalalignment="left",
            verticalalignment="center",
            fontsize=10,
            bbox={
                "facecolor": "white",
                "edgecolor": "none",
                "alpha": 0.7,
                "pad": 0.5,
            },
        )


def _add_individual_labels(
    axs: list[Axes],
    agent_summaries: pd.DataFrame,
    script_params: dict[str, Any],
    scale: Literal["log", "linear"],
) -> None:
    """Annotate selected agents' points on ``axs[0]`` with their display names.

    Agents are ordered by ``release_date`` and a hand-tuned subset of them
    (chosen by position in that ordering) is labelled next to its data point.
    Each scale has its own table of positions and label offsets.

    Positions that do not exist in the data (fewer agents than the tables
    were tuned for) are skipped instead of raising ``IndexError``; an empty
    ``agent_summaries`` results in no labels at all.

    Args:
        axs: Plot axes; only ``axs[0]`` is annotated.
        agent_summaries: One row per agent with ``agent``, ``release_date``
            and a ``p<success_percent>`` column.
        script_params: Figure parameters. ``rename_legend_labels`` maps raw
            agent names to display names; ``success_percent`` (default 50)
            selects the y-value column.
        scale: ``"linear"`` selects the linear offset table; any other value
            selects the log table.
    """
    if agent_summaries.empty:
        return

    rename_map = script_params["rename_legend_labels"]
    # sort_values returns a new frame, so the caller's data is not modified.
    sorted_data = agent_summaries.sort_values("release_date")
    sorted_data["agent"] = sorted_data["agent"].map(
        lambda name: rename_map.get(name, name)
    )

    # Each entry is (row position in release-date order, (dx, dy) offset in
    # points). Position 0 is the earliest agent, -1 the most recent one.
    # NOTE(review): model names in the comments reflect the dataset these
    # offsets were tuned for -- re-check them if the agent list changes.
    log_label_specs = [
        (0, (4, -6)),
        (1, (8, 0)),
        (2, (8, 0)),
        (3, (-6, 0)),  # gpt 4
        (5, (4, -2)),  # gpt 4o
        (9, (4, -10)),
        (-1, (8, -10)),
    ]
    linear_label_specs = [
        (0, (0, 6)),
        (1, (0, 6)),
        (2, (0, 6)),
        (3, (0, 6)),  # gpt 4
        (4, (0, 6)),
        (5, (6, -2)),  # gpt 4o
        (6, (-12, 0.1)),
        (7, (-12, 0.1)),
        (8, (-12, 0.1)),
        (9, (-12, 0.1)),
        (-1, (-12, 0.1)),
    ]
    label_specs = linear_label_specs if scale == "linear" else log_label_specs

    value_column = f"p{script_params.get('success_percent', 50)}"
    n_rows = len(sorted_data)

    for position, label_pos in label_specs:
        # Skip table entries that point past the available agents.
        if not -n_rows <= position < n_rows:
            continue
        point = sorted_data.iloc[position]
        axs[0].annotate(
            point["agent"],
            xy=(pd.to_datetime(point["release_date"]), point[value_column]),
            xytext=label_pos,
            textcoords="offset points",
            # Anchor the text on the side facing away from the data point.
            ha="left" if label_pos[0] > 0 else "right",
            va="bottom" if label_pos[1] > 0 else "top",
            fontsize=12,
            color="grey",
        )


def add_bootstrap_confidence_region(
    ax: Axes,
    bootstrap_results: pd.DataFrame,
    release_dates: dict[str, dict[str, date]],
    after_date: str,
    max_date: pd.Timestamp,
    confidence_level: float,
    exclude_agents: list[str],
) -> List[float]:
    """Shade the bootstrap confidence band of the trendline on ``ax``.

    For every bootstrap sample (one row of ``bootstrap_results``) a trendline
    is fitted through the agents' p50 values against their release dates and
    evaluated on a daily grid from ``after_date`` to ``max_date``. The region
    between the lower and upper percentile of those per-sample curves is
    filled on ``ax``.

    Samples with fewer than two usable p50 values produce no curve: their row
    of the prediction matrix stays NaN and is ignored by ``np.nanpercentile``
    rather than being counted as an all-zero prediction.

    Args:
        ax: Axes the confidence region is drawn on.
        bootstrap_results: One row per bootstrap sample with a
            ``"<agent>_p50"`` column per agent. Values that are non-numeric,
            NaN, infinite or below 1e-3 are ignored for that sample.
        release_dates: Mapping whose ``"date"`` entry maps agent name to
            release date. Agents without a matching column are skipped.
        after_date: First day of the prediction grid.
        max_date: Last day of the prediction grid.
        confidence_level: Central coverage of the band, e.g. ``0.95`` for the
            2.5th to 97.5th percentile.
        exclude_agents: Agent names left out of every fit.

    Returns:
        ``ln(2) / slope`` for every bootstrap sample whose fitted slope is
        positive (presumably in days, matching the ``date2num`` x-values used
        for prediction -- confirm against ``fit_trendline``).

    Raises:
        ValueError: If ``bootstrap_results`` has no rows.
    """
    n_bootstraps = len(bootstrap_results)
    if n_bootstraps == 0:
        raise ValueError("bootstrap_results must contain at least one sample")

    dates = release_dates["date"]
    excluded = set(exclude_agents)
    focus_agents = [
        agent
        for agent in sorted(dates, key=lambda name: dates[name])
        if agent not in excluded
    ]

    # Coerce every available agent column to floats once, up front, instead
    # of re-checking and re-parsing a single cell per agent per sample.
    p50_by_agent = {
        agent: pd.to_numeric(
            bootstrap_results[f"{agent}_p50"], errors="coerce"
        ).to_numpy(dtype=float)
        for agent in focus_agents
        if f"{agent}_p50" in bootstrap_results.columns
    }

    # Daily prediction grid; its numeric form does not depend on the sample.
    time_points = pd.date_range(
        start=pd.to_datetime(after_date),
        end=max_date,
        freq="D",
    )
    time_x = date2num(time_points).reshape(-1, 1)

    # NaN (not zero) marks "no prediction" so skipped samples cannot drag the
    # lower percentile toward 0.
    predictions = np.full((n_bootstraps, len(time_points)), np.nan)
    doubling_times: List[float] = []

    for sample_idx in range(n_bootstraps):
        # Collect the usable p50 values and matching dates for this sample.
        valid_p50s = []
        valid_dates = []
        for agent, values in p50_by_agent.items():
            p50 = values[sample_idx]
            if not np.isfinite(p50) or p50 < 1e-3:
                continue
            valid_p50s.append(p50)
            valid_dates.append(dates[agent])

        # A line needs at least two points.
        if len(valid_p50s) < 2:
            continue

        reg, _ = fit_trendline(
            pd.Series(valid_p50s),
            pd.Series(pd.to_datetime(valid_dates)),
            log_scale=True,
        )
        # The prediction is exponentiated to return to the p50 scale
        # (assumes the log_scale=True fit predicts log values).
        predictions[sample_idx] = np.exp(reg.predict(time_x))

        # Only growing trends have a meaningful (finite, positive) doubling
        # time; a zero slope would otherwise yield infinity.
        slope = reg.coef_[0]
        if slope > 0:
            doubling_times.append(np.log(2) / slope)

    # Central interval: e.g. 0.95 -> 2.5th and 97.5th percentiles.
    low_q = (1 - confidence_level) / 2
    high_q = 1 - low_q
    lower_bound = np.nanpercentile(predictions, low_q * 100, axis=0)
    upper_bound = np.nanpercentile(predictions, high_q * 100, axis=0)

    ax.fill_between(
        time_points,
        lower_bound,
        upper_bound,
        color="#d2dfd7",
        alpha=0.4,
    )

    return doubling_times


def main() -> None:
    """Render the horizon plot with trendline and bootstrap confidence band.

    Reads the CLI arguments and the DVC params of the ``plot_bootstrap_ci``
    stage, loads the bootstrap results, agent summaries and release dates,
    draws the horizon graph, optionally adds example-task markers, individual
    agent labels, the fitted trendline with its bootstrap confidence region
    and a doubling-time boxplot, then saves or opens the figure.

    Raises:
        ValueError: If ``show_boxplot`` is enabled together with
            ``hide_trendline`` (no doubling times are computed in that case).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fig-name", type=str, required=True)
    parser.add_argument("--input-file", type=pathlib.Path, required=True)
    parser.add_argument("--agent-summaries-file", type=pathlib.Path, required=True)
    parser.add_argument("--release-dates", type=pathlib.Path, required=True)
    parser.add_argument("--output-file", type=pathlib.Path, required=True)
    parser.add_argument("--log-level", type=str, default="INFO")
    parser.add_argument(
        "--y-scale",
        choices=["log", "linear"],
        default="log",
        help="Scale type for y-axis",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=args.log_level.upper(),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    params = dvc.api.params_show(stages="plot_bootstrap_ci", deps=True)
    plot_params = params["plots"]
    script_params = params["figs"]["plot_logistic_regression"][args.fig_name]

    confidence_level = 0.95

    # Load data
    bootstrap_results = pd.read_csv(args.input_file)
    agent_summaries = pd.read_csv(args.agent_summaries_file)
    release_dates = yaml.safe_load(args.release_dates.read_text())
    agent_summaries = _process_agent_summaries(
        script_params["exclude_agents"], agent_summaries, release_dates
    )

    subtitle = script_params["subtitle"] or ""
    title = _get_title(script_params, script_params.get("success_percent", 50))

    # Create the plot, with a narrow second subplot when the boxplot is shown.
    if script_params.get("show_boxplot", False):
        fig, axs = plt.subplots(1, 2, width_ratios=[6, 1], figsize=(12, 6))
    else:
        fig, axs = plt.subplots(1, 1, figsize=(12, 6))
        axs = [axs]

    # NOTE(review): overrides are applied after title, subtitle and the
    # subplot layout were read above, so overriding those keys has no effect
    # on them -- confirm that this ordering is intended.
    if "linear_overrides" in script_params and args.y_scale == "linear":
        for key, override in script_params["linear_overrides"].items():
            if isinstance(override, dict):
                # Dict-valued overrides are merged one level deep.
                script_params[key] = {**script_params.get(key, {}), **override}
            else:
                script_params[key] = override
    if "plot_style_overrides" in script_params:
        for key, sub_overrides in script_params["plot_style_overrides"].items():
            target = plot_params.setdefault(key, {})
            for subkey, values in sub_overrides.items():
                target[subkey] = {**target.get(subkey, {}), **values}

    # Read after the overrides so a linear override of success_percent applies.
    success_percent = script_params.get("success_percent", 50)
    value_column = f"p{success_percent}"

    end_date = script_params["x_lim_end"]
    upper_y_lim = script_params["upper_y_lim"]
    trendline_end_date = script_params["x_lim_end"]
    if args.y_scale == "linear":
        # On a linear axis, fit the view to the data instead of the
        # configured limits.
        latest_release = agent_summaries["release_date"].max()
        end_date = latest_release + pd.Timedelta(days=60)
        # Use the plotted column (not a hard-coded "p50") for the y-limit.
        upper_y_lim = agent_summaries[value_column].max() * 1.2
        trendline_end_date = latest_release
    plot_horizon_graph(
        plot_params,
        agent_summaries,
        title=title,
        release_dates=release_dates,
        runs_df=pd.DataFrame(),  # Empty DataFrame since we don't need task distribution
        subtitle=subtitle,
        lower_y_lim=script_params["lower_y_lim"],
        upper_y_lim=upper_y_lim,
        x_lim_start=script_params["x_lim_start"],
        x_lim_end=end_date,
        include_task_distribution="none",
        weight_key=script_params["weighting"],
        trendlines=None,
        exclude_agents=script_params["exclude_agents"],
        fig=fig,
        success_percent=success_percent,
        y_scale=args.y_scale,
        script_params=script_params,
        marker_override="o",
    )

    if script_params.get("show_example_tasks", False):
        _add_time_markers(axs[0])

    if script_params["individual_labels"]:
        _add_individual_labels(axs, agent_summaries, script_params, args.y_scale)

    doubling_times = None
    if not script_params.get("hide_trendline", False):
        trendline_start = script_params["trendlines"][0]["line_start_date"]
        reg, score = fit_trendline(
            agent_summaries[value_column],
            pd.to_datetime(agent_summaries["release_date"]),
            log_scale=True,
        )
        dashed_outside = (
            agent_summaries["release_date"].min(),
            agent_summaries["release_date"].max(),
        )
        annotation = plot_trendline(
            ax=axs[0],
            dashed_outside=dashed_outside,
            plot_params=plot_params,
            trendline_params={
                "after_date": trendline_start,
                "color": "#2c7c58",
                "line_start_date": None,
                "line_end_date": trendline_end_date,
                "display_r_squared": True,
                "data_file": None,
                "styling": None,
                "caption": None,
                "skip_annotation": False,
                "fit_type": "exponential",
            },
            reg=reg,
            score=score,
            log_scale=True,
        )
        if (
            not script_params.get("hide_regression_info", False)
            and annotation is not None
        ):
            # Position annotation at bottom right
            axs[0].annotate(
                **annotation,
                xy=(0.98, 0.02),
                xycoords="axes fraction",
                ha="right",
                va="bottom",
            )

        doubling_times = add_bootstrap_confidence_region(
            ax=axs[0],
            bootstrap_results=bootstrap_results,
            release_dates=release_dates,
            after_date=trendline_start,
            max_date=trendline_end_date,
            confidence_level=confidence_level,
            exclude_agents=script_params["exclude_agents"],
        )
        if doubling_times:
            # Same central interval as the shaded confidence region.
            tail_percent = (1 - confidence_level) / 2 * 100
            lower_bound = np.percentile(doubling_times, tail_percent)
            upper_bound = np.percentile(doubling_times, 100 - tail_percent)
            median = np.median(doubling_times)
            logger.info(
                f"{confidence_level:.0%} CI for doubling times: "
                f"[{lower_bound:.0f}, {upper_bound:.0f}] days "
                f"(+{(upper_bound - median) / median:.0%}"
                f"/-{(median - lower_bound) / median:.0%})"
            )
        else:
            logger.warning(
                "No bootstrap sample produced a positive doubling time; "
                "skipping the doubling-time confidence interval"
            )

    if script_params["individual_labels"]:
        # Points are labelled directly, so drop the legend if one was drawn.
        legend = axs[0].get_legend()
        if legend is not None:
            legend.remove()
    else:
        # Rebuild the legend in the configured order with display names.
        handles, labels = axs[0].get_legend_handles_labels()
        legend_order = plot_params["legend_order"]
        sorted_items = sorted(
            zip(handles, labels), key=lambda item: legend_order.index(item[1])
        )
        rename_map = script_params["rename_legend_labels"]
        axs[0].legend(
            [handle for handle, _ in sorted_items],
            [rename_map.get(label, label) for _, label in sorted_items],
            loc="upper left",
            fontsize=script_params["legend_fontsize"],
            frameon=script_params["legend_frameon"],
        )
    axs[0].grid(script_params.get("show_grid", True))
    axs[0].grid(which="minor", linestyle=":", alpha=0.6, color="#d2dfd7")

    if script_params.get("show_boxplot", False):
        if doubling_times is None:
            raise ValueError(
                "show_boxplot requires the trendline; it cannot be combined "
                "with hide_trendline"
            )
        axs[1].boxplot([doubling_times], vert=True, showfliers=False, whis=(10, 90))
        axs[1].set_xticklabels(["Doubling times\n(days)"])
        axs[1].set_ylim(0, 365)

    src.utils.plots.save_or_open_plot(args.output_file, params["plot_format"])


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
