import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgb
from tensorboard.backend.event_processing import event_accumulator


# Root of the log tree; main() reads runs from log/<env>/<alg>/<dir>/<seed>/ and,
# without --output, writes figures under log/Figures/<env>/...
LOG_PATH = Path("log")

# One run of one scalar tag: (steps, values, run label, optional separation step).
RunSeries = Tuple[List[int], List[float], str, Optional[int]]
# Nested mapping: algorithm name -> scalar tag -> every RunSeries loaded for that pair.
RunData = Dict[str, Dict[str, List[RunSeries]]]


def parse_args() -> argparse.Namespace:
    """Parse the command line of the plotting script.

    Options are declared in a table (flag, keyword arguments for
    ``add_argument``) and registered in table order, so ``--help`` lists them in
    that order.

    Returns:
        Namespace with the attributes ``alg``, ``env``, ``seed``, ``seeds``,
        ``dir``, ``run``, ``tags``, ``smooth``, ``dpi``, ``output``, ``show``,
        ``merge_events``, ``max_steps``, ``points``, ``ci``, ``show_runs`` and
        ``list_tags``.
    """
    option_table = [
        ("--alg", {
            "nargs": "+",
            "default": ["rffsac"],
            "help": "Algorithm folder(s) inside the log directory",
        }),
        ("--env", {
            "default": "HalfCheetah-v5",
            "help": "Environment folder inside the log directory",
        }),
        ("--seed", {"default": 0, "type": int, "help": "Seed folder to read (numeric)"}),
        ("--seeds", {
            "nargs": "+",
            "type": int,
            "help": "Space-separated list of seed folders to average; overrides --seed when provided",
        }),
        ("--dir", {
            "default": 0,
            "type": int,
            "help": "Directory index between algorithm and seed (log/env/alg/dir/seed); negative lists all directories",
        }),
        ("--run", {
            "default": -1,
            "type": int,
            "help": "Optional sub-run index inside each seed; negative includes every run directory",
        }),
        ("--tags", {"nargs": "*", "default": None, "help": "Specific scalar tags to plot"}),
        ("--smooth", {"default": 1, "type": int, "help": "Moving-average window for smoothing"}),
        ("--dpi", {"default": 150, "type": int, "help": "Figure resolution"}),
        ("--output", {"default": None, "help": "Destination directory for the figures"}),
        ("--show", {"action": "store_true", "help": "Display plots on screen as well"}),
        ("--merge-events", {
            "action": "store_true",
            "help": "Treat all event files within a run directory as a single run",
        }),
        ("--max-steps", {
            "type": float,
            "default": None,
            "help": "Upper bound on steps to include (e.g. 2e5)",
        }),
        ("--points", {
            "type": int,
            "default": 300,
            "help": "Number of samples for the averaging grid",
        }),
        ("--ci", {
            "type": float,
            "default": 1.96,
            "help": "Standard-error multiplier for the confidence region; set 0 to disable",
        }),
        ("--show-runs", {
            "action": "store_true",
            "help": "Overlay individual run curves alongside the averaged result",
        }),
        ("--list-tags", {"action": "store_true", "help": "List available scalar tags and exit"}),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def moving_average(values: List[float], window: int) -> List[float]:
    """Return the trailing moving average of ``values`` over ``window`` samples.

    Element ``i`` of the result is the mean of
    ``values[max(0, i - window + 1) : i + 1]``. The output therefore has the same
    length as the input, and the first ``window - 1`` entries average over fewer
    samples. A ``window`` of 1 or less disables smoothing and returns ``values``
    itself (not a copy).

    A running sum keeps the cost at O(len(values)). Whenever that sum stops being
    finite (a NaN or infinite sample entered the window), it is recomputed from
    the current window. A single bad sample then only affects the outputs whose
    window actually contains it, instead of every later point.
    """
    if window <= 1:
        return values
    import math  # local import keeps this helper self-contained

    samples = list(values)
    averaged: List[float] = []
    acc = 0.0
    for index, value in enumerate(samples):
        acc += value
        if index >= window:
            # Remove the sample that just slid out of the window (O(1), unlike list.pop(0)).
            acc -= samples[index - window]
        start = max(0, index - window + 1)
        if not math.isfinite(acc):
            # NaN/inf never cancels out of a running sum; rebuild it from the live window.
            acc = sum(samples[start:index + 1], 0.0)
        averaged.append(acc / (index - start + 1))
    return averaged


def lighten_color(color: Any, amount: float = 0.6) -> Any:
    """Blend a matplotlib colour toward white.

    ``amount`` is clamped to [0, 1] and is the fraction of the colour's distance
    from white that is kept: 1.0 returns the colour unchanged (as an RGB tuple),
    and 0.0 returns pure white. Any alpha channel is dropped. ``None`` and
    values matplotlib cannot parse (``to_rgb`` raising ``ValueError``) are
    returned as given.
    """
    if color is None:
        return None
    try:
        red, green, blue = to_rgb(color)
    except ValueError:
        return color
    keep = max(0.0, min(1.0, amount))
    return tuple(1 - (1 - channel) * keep for channel in (red, green, blue))


def collect_run_paths(base_dir: Path, dir_index: int, merge_events: bool) -> Dict[str, Path]:
    """Map run labels to TensorBoard sources (event files or directories) under ``base_dir``.

    Which folders are inspected:
      * ``dir_index < 0``: every immediate subdirectory, or ``base_dir`` itself
        when it has none;
      * ``dir_index >= 0``: only ``base_dir / str(dir_index)``.

    In each inspected folder, ``events.out.tfevents.*`` files are registered as
    follows. With ``merge_events``, the folder itself is registered under its
    name. A single file is registered under the folder name. Several files are
    registered as ``"<folder>/<file name>"``. If nothing is found, event files
    sitting directly in ``base_dir`` are used as a fallback.

    Raises:
        FileNotFoundError: ``base_dir`` is missing, the requested ``dir_index``
            folder yields nothing (and ``base_dir`` has no events either), or no
            event files exist at all.
    """
    if not base_dir.exists():
        raise FileNotFoundError(f"Missing log directory: {base_dir}")

    found: Dict[str, Path] = {}

    def add_container(folder: Path, label: str) -> None:
        events = sorted(folder.glob("events.out.tfevents.*"))
        if not events:
            return
        if merge_events:
            found[label] = folder
        elif len(events) == 1:
            found[label] = events[0]
        else:
            found.update({f"{label}/{event.name}": event for event in events})

    child_dirs = sorted(entry for entry in base_dir.iterdir() if entry.is_dir())
    if dir_index >= 0:
        candidates: List[Path] = [base_dir / str(dir_index)]
    else:
        candidates = child_dirs or [base_dir]

    for candidate in candidates:
        if candidate.is_dir():
            add_container(candidate, candidate.name)
        elif candidate.is_file() and candidate.name.startswith("events.out.tfevents"):
            found[candidate.name] = candidate

    if not found:
        # Fallback: event files stored directly inside base_dir.
        add_container(base_dir, base_dir.name)
        if not found and dir_index >= 0:
            available = [entry.name for entry in child_dirs]
            raise FileNotFoundError(
                f"No run directory '{dir_index}' under {base_dir}. Available: {', '.join(available) or 'none'}"
            )

    if not found:
        raise FileNotFoundError(f"No event files found under {base_dir}")

    return found


def append_series(
    steps: Iterable[int],
    values: Iterable[float],
    extra_steps: Iterable[int],
    extra_values: Iterable[float],
) -> Tuple[List[int], List[float], Optional[int]]:
    """Merge a base series with extra samples, ordered by step.

    Steps are coerced to ``int`` and values to ``float``. Pairs are formed with
    ``zip``, so surplus elements of the longer input are ignored. The merge sort
    is stable: for equal steps, base samples come before extra samples and each
    keeps its original order.

    Returns:
        ``(steps, values, separation_step)``, where ``separation_step`` is the
        first step of the base series as given (``None`` if it is empty).
    """
    base_steps = list(map(int, steps))
    base_values = list(map(float, values))
    extras = [(int(step), float(val)) for step, val in zip(extra_steps, extra_values)]
    first_step = base_steps[0] if base_steps else None

    merged = sorted([*zip(base_steps, base_values), *extras], key=lambda pair: pair[0])
    if not merged:
        return [], [], first_step
    step_column, value_column = zip(*merged)
    return list(step_column), list(value_column), first_step

def enumerate_run_directories(dir_root: Path, dir_index: int) -> List[Path]:
    """List the run directories directly under an algorithm directory.

    Subdirectories are sorted by name, and one named ``figures`` is ignored.
    A negative ``dir_index`` returns all of them. Otherwise only
    ``dir_root / str(dir_index)`` is returned, provided that path exists.

    Raises:
        SystemExit: ``dir_root`` is missing, it has no run directories (negative
            index), or the requested index does not exist.
    """
    try:
        children = list(dir_root.iterdir())
    except FileNotFoundError as error:
        raise SystemExit(f"Missing algorithm directory: {dir_root}") from error

    subdirectories = [child for child in children if child.is_dir() and child.name != "figures"]
    subdirectories.sort(key=lambda child: child.name)

    if dir_index >= 0:
        requested = dir_root / str(dir_index)
        if requested.exists():
            return [requested]
        listing = ", ".join(child.name for child in subdirectories) or "none"
        raise SystemExit(f"No directory '{dir_index}' under {dir_root}. Available: {listing}")

    if subdirectories:
        return subdirectories
    raise SystemExit(f"No run directories found under {dir_root}")


def load_scalars(run_dir: Path, tags: Optional[Iterable[str]]) -> Tuple[Dict[str, Tuple[List[int], List[float]]], List[str]]:
    """Read scalar series from one TensorBoard event file or directory.

    ``size_guidance={"scalars": 0}`` is passed to the accumulator. Requested
    ``tags`` that are absent from the run are skipped silently. With no
    ``tags``, every scalar tag is read in sorted order.

    Returns:
        ``(series, available)``. ``series`` maps each loaded tag to
        ``(steps, values)``; ``available`` lists every scalar tag the
        accumulator reported for the run.
    """
    loader = event_accumulator.EventAccumulator(str(run_dir), size_guidance={"scalars": 0})
    loader.Reload()

    scalar_tags = loader.Tags().get("scalars", [])
    wanted = list(tags) if tags else sorted(scalar_tags)

    series: Dict[str, Tuple[List[int], List[float]]] = {}
    for tag in (name for name in wanted if name in scalar_tags):
        records = loader.Scalars(tag)
        series[tag] = (
            [record.step for record in records],
            [record.value for record in records],
        )

    return series, scalar_tags


def _warn_runs(
    labels: Iterable[str],
    run_sources: Dict[str, Path],
    detail: str,
    tag: str,
    algorithm: str,
) -> None:
    """Print one warning per run label, naming its source path when known."""
    for label in labels:
        source_path = run_sources.get(label)
        source_text = str(source_path) if source_path is not None else "unknown"
        print(f"Warning: run '{label}' from {source_text} {detail} for tag '{tag}' (algorithm '{algorithm}').")


def _prepare_run(
    steps: List[int],
    values: List[float],
    separation_step: Optional[int],
    smooth_window: int,
    max_steps: Optional[int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Turn one raw scalar series into sorted, de-duplicated, smoothed float arrays.

    The steps are processed in this order:
      1. Samples before ``separation_step`` (when given) are dropped.
      2. The series is ordered by step, keeping the first logged value for a
         repeated step.
      3. The values are smoothed with a trailing moving average.
      4. The series is clipped to ``max_steps``.

    Either returned array may be empty.
    """
    # Pair steps and values positionally; ignore any surplus on the longer side.
    length = min(len(steps), len(values))
    steps_array = np.asarray(steps[:length], dtype=float)
    values_array = np.asarray(values[:length], dtype=float)

    if separation_step is not None:
        keep = steps_array >= separation_step
        steps_array, values_array = steps_array[keep], values_array[keep]

    # A stable sort keeps logging order among equal steps, so np.unique's
    # first-occurrence indices deterministically keep the first value logged.
    order = np.argsort(steps_array, kind="stable")
    steps_array, first_index = np.unique(steps_array[order], return_index=True)
    values_array = values_array[order][first_index]

    # Smooth only after ordering so the window slides along the step axis.
    values_array = np.asarray(moving_average(values_array.tolist(), smooth_window), dtype=float)

    # A trailing average only looks backwards, so clipping after smoothing
    # leaves the retained points unchanged.
    if max_steps is not None:
        keep = steps_array <= max_steps
        steps_array, values_array = steps_array[keep], values_array[keep]
    return steps_array, values_array


def _mean_and_standard_error(values_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Column-wise NaN-aware mean and standard error of the mean.

    ``values_matrix`` has one row per run and one column per grid point; NaN
    marks "no data".

    Returns:
        ``(mean, std_error, counts)``. ``mean`` is NaN where a column has no
        data. ``std_error`` uses the sample standard deviation (ddof=1) and is
        0 where fewer than two runs contribute. ``counts`` is the number of
        non-NaN entries per column.
    """
    valid = ~np.isnan(values_matrix)
    counts = valid.sum(axis=0)
    width = values_matrix.shape[1]

    mean_values = np.full(width, np.nan)
    np.divide(np.where(valid, values_matrix, 0.0).sum(axis=0), counts, out=mean_values, where=counts > 0)

    squared = np.where(valid, (values_matrix - mean_values) ** 2, 0.0)
    std_error = np.zeros(width)
    multi = counts > 1
    variance = squared[:, multi].sum(axis=0) / (counts[multi] - 1)
    std_error[multi] = np.sqrt(variance / counts[multi])
    return mean_values, std_error, counts


def plot_scalars(
    run_data: RunData,
    run_sources: Dict[str, Path],
    tags: Iterable[str],
    output_dir: Path,
    smooth_window: int,
    dpi: int,
    show: bool,
    max_steps: Optional[int],
    num_points: int,
    ci_multiplier: float,
    show_runs: bool,
    env_name: str,
) -> bool:
    """Save one PNG per tag comparing the seed-averaged curve of each algorithm.

    For every tag:
      * each run is cleaned with ``_prepare_run``;
      * the run is interpolated onto a shared grid of ``num_points`` steps
        (NaN outside the run's own step range);
      * the runs are averaged per algorithm.

    When ``ci_multiplier > 0``, a band of ``ci_multiplier`` standard errors is
    shaded wherever at least two runs overlap. The best (max) value of every run
    is also collected, and the mean/std of those maxima are printed at the end.

    Args:
        run_data: algorithm -> tag -> list of (steps, values, run label, separation step).
        run_sources: run label -> event file/directory; only used in warnings.
        tags: tags to plot, in order; tags without data are skipped.
        output_dir: destination directory, created if missing. Files are named
            after the tag with ``/`` -> ``_`` and ``:`` -> ``-``.
        smooth_window: trailing moving-average window (<= 1 disables smoothing).
        dpi: resolution passed to ``savefig``.
        show: also display every figure with ``plt.show()``.
        max_steps: drop samples beyond this step and end the grid there when given.
        num_points: grid resolution (at least 2 is enforced).
        ci_multiplier: standard-error multiplier for the shaded band; <= 0 disables it.
        show_runs: overlay each individual run as a faint line.
        env_name: figure title.

    Returns:
        True if at least one figure was written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    num_points = max(2, num_points)
    # matplotlib.cm.get_cmap (reached via plt.cm) was removed in matplotlib 3.9;
    # pyplot.get_cmap is available in every release.
    color_map = plt.get_cmap("tab10")
    plotted_any = False
    best_statistics: DefaultDict[str, Dict[str, Tuple[float, float]]] = defaultdict(dict)

    for tag in tags:
        series_by_alg: Dict[str, List[Tuple[np.ndarray, np.ndarray, str]]] = {}
        for algorithm, tag_runs in run_data.items():
            processed_runs: List[Tuple[np.ndarray, np.ndarray, str]] = []
            per_run_best: List[float] = []
            for steps, values, run_label, separation_step in tag_runs.get(tag) or []:
                steps_array, values_array = _prepare_run(
                    steps, values, separation_step, smooth_window, max_steps
                )
                if steps_array.size == 0:
                    continue
                processed_runs.append((steps_array, values_array, run_label))
                observed = values_array[~np.isnan(values_array)]
                if observed.size:
                    per_run_best.append(float(observed.max()))
            if not processed_runs:
                continue
            series_by_alg[algorithm] = processed_runs
            # The membership guard keeps the first statistics if ``tags`` repeats a tag.
            if per_run_best and algorithm not in best_statistics[tag]:
                maxima = np.asarray(per_run_best, dtype=float)
                best_statistics[tag][algorithm] = (float(np.mean(maxima)), float(np.std(maxima)))

        if not series_by_alg:
            continue

        # Every stored run is non-empty, so its first/last steps are well defined.
        all_runs = [run for runs in series_by_alg.values() for run in runs]
        effective_min_step = min(float(run_steps[0]) for run_steps, _, _ in all_runs)
        if max_steps is None:
            effective_max_step = max(float(run_steps[-1]) for run_steps, _, _ in all_runs)
        else:
            effective_max_step = float(max_steps)
        effective_max_step = max(effective_max_step, effective_min_step)
        if effective_max_step <= 0:
            continue

        grid = np.linspace(effective_min_step, effective_max_step, num_points)
        figure = plt.figure()
        plotted_tag = False
        for color_index, (algorithm, series) in enumerate(series_by_alg.items()):
            color = color_map(color_index % color_map.N)
            if show_runs:
                for steps_array, values_array, _ in series:
                    plt.plot(steps_array, values_array, alpha=0.3, linewidth=1, color=color)

            # One row per run; NaN outside each run's own step range.
            values_matrix = np.vstack(
                [
                    np.interp(grid, steps_array, values_array, left=np.nan, right=np.nan)
                    for steps_array, values_array, _ in series
                ]
            )
            empty_runs = [
                run_label
                for (_, _, run_label), row in zip(series, values_matrix)
                if np.all(np.isnan(row))
            ]
            mean_values, std_error, counts = _mean_and_standard_error(values_matrix)
            if not np.any(counts):
                _warn_runs(empty_runs, run_sources, "has no usable data", tag, algorithm)
                continue
            _warn_runs(empty_runs, run_sources, "contributed only NaNs", tag, algorithm)

            plt.plot(grid, mean_values, label=algorithm, linewidth=2, color=color)
            plotted_tag = True

            ci_mask = counts > 1
            if ci_multiplier > 0 and np.any(ci_mask):
                band = std_error * ci_multiplier
                # ``where=`` leaves gaps unshaded instead of bridging across them.
                plt.fill_between(
                    grid,
                    mean_values - band,
                    mean_values + band,
                    where=ci_mask,
                    alpha=0.2,
                    color=color,
                )

            zero_columns = np.flatnonzero(counts == 0)
            if zero_columns.size:
                first_missing = float(grid[zero_columns[0]])
                last_missing = float(grid[zero_columns[-1]])
                # The first run reaching the furthest step, named to help locate the gap.
                longest_steps, _, longest_run_label = max(series, key=lambda run: run[0][-1])
                longest_run_step = float(longest_steps[-1])
                source_path = run_sources.get(longest_run_label)
                source_text = str(source_path) if source_path is not None else "unknown"
                if first_missing == last_missing:
                    span_text = f"step {first_missing:.0f}"
                else:
                    span_text = f"steps {first_missing:.0f}-{last_missing:.0f}"
                print(
                    "Warning: all runs for tag "
                    f"'{tag}' (algorithm '{algorithm}') lack data at {span_text}. "
                    f"Longest run: {longest_run_label} (max step {longest_run_step:.0f}, source={source_text})."
                )

        if not plotted_tag:
            plt.close(figure)
            continue

        plt.title(f"{env_name}")
        plt.xlabel("Step")
        plt.ylabel(tag)
        # Identical limits only trigger a matplotlib warning; let autoscale handle it.
        if effective_max_step > effective_min_step:
            plt.xlim(effective_min_step, effective_max_step)
        plt.legend()
        plt.grid(True, linestyle="--", linewidth=0.5)
        filename = tag.replace("/", "_").replace(":", "-")
        plt.tight_layout()
        plt.savefig(output_dir / f"{filename}.png", dpi=dpi)
        if show:
            plt.show()
        plt.close(figure)
        plotted_any = True

    if best_statistics:
        print("Best mean and std per tag and algorithm:")
        for tag, alg_stats in best_statistics.items():
            print(f"  Tag '{tag}':")
            for algorithm, (mean_value, std_value) in sorted(alg_stats.items()):
                print(f"    {algorithm}: mean={mean_value:.6f}, std={std_value:.6f}")

    return plotted_any


def _format_missing(missing_locations: List[Tuple[str, str, int, Exception]]) -> str:
    """Render skipped (algorithm, dir, seed, error) entries, one per line."""
    return os.linesep.join(
        f"alg {algorithm}, dir {dir_name}, seed {seed}: {error}"
        for algorithm, dir_name, seed, error in missing_locations
    )


def main() -> None:
    """Load the requested runs and save one figure per scalar tag.

    Runs are discovered under ``LOG_PATH/<env>/<alg>/<dir>/<seed>``. Seed
    folders that cannot be resolved are collected and reported instead of
    aborting the whole run.

    Raises:
        SystemExit: an algorithm directory is missing, no run could be loaded,
            no scalar tag is available, or none of the requested tags produced
            a figure.
    """
    args = parse_args()

    # dict.fromkeys drops duplicates while preserving command-line order.
    seeds = list(dict.fromkeys(args.seeds if args.seeds else [args.seed]))
    algorithms = list(dict.fromkeys(args.alg))

    run_data: DefaultDict[str, DefaultDict[str, List[RunSeries]]] = defaultdict(
        lambda: defaultdict(list)
    )
    run_dirs_for_report: Dict[str, Path] = {}
    # Insertion-ordered set of every scalar tag seen (O(1) membership, unlike a list).
    all_tags: Dict[str, None] = {}
    missing_locations: List[Tuple[str, str, int, Exception]] = []
    loaded_alg_dirs: Set[Tuple[str, str]] = set()
    loaded_dir_seeds: Set[Tuple[str, str, int]] = set()

    for algorithm in algorithms:
        dir_root = LOG_PATH / args.env / algorithm
        try:
            dir_root_exists = dir_root.exists()
        except OSError as error:
            raise SystemExit(f"Unable to access directory {dir_root}: {error}") from error
        if not dir_root_exists:
            raise SystemExit(f"Missing algorithm directory: {dir_root}")

        for dir_path in enumerate_run_directories(dir_root, args.dir):
            dir_name = dir_path.name
            for seed in seeds:
                seed_dir = dir_path / str(seed)
                try:
                    seed_run_paths = collect_run_paths(seed_dir, args.run, args.merge_events)
                except FileNotFoundError as error:
                    missing_locations.append((algorithm, dir_name, seed, error))
                    continue

                for run_name, run_dir in seed_run_paths.items():
                    data, available = load_scalars(run_dir, args.tags)
                    run_label = f"{algorithm}/{dir_name}/{seed}/{run_name}"
                    run_dirs_for_report[run_label] = run_dir
                    all_tags.update(dict.fromkeys(available))
                    # load_scalars only returns tags that the run actually has.
                    for tag, (steps, values) in data.items():
                        run_data[algorithm][tag].append((steps, values, run_label, None))

                if seed_run_paths:
                    loaded_alg_dirs.add((algorithm, dir_name))
                    loaded_dir_seeds.add((algorithm, dir_name, seed))

    # Decide "no runs" from the sources actually loaded, not from run_data:
    # run_data stays empty when runs exist but none carries a requested tag.
    # That case must still reach --list-tags and the "Requested tags were not
    # found" message below.
    if not run_dirs_for_report:
        if missing_locations:
            raise SystemExit(
                f"No runs found for requested directories/seeds. Details:{os.linesep}"
                f"{_format_missing(missing_locations)}"
            )
        raise SystemExit("No runs found for requested directories/seeds.")

    if args.list_tags:
        for tag in sorted(all_tags):
            print(tag)
        return

    selected_tags = list(args.tags) if args.tags else sorted(all_tags)
    if not selected_tags:
        if missing_locations:
            raise SystemExit(
                f"No scalar tags found to plot. Missing entries:{os.linesep}"
                f"{_format_missing(missing_locations)}"
            )
        raise SystemExit("No scalar tags found to plot.")

    if args.output:
        output_dir = Path(args.output)
    else:
        # Default layout: log/Figures/<env>/<alg1_alg2...>[/<dir>][/<seed>].
        # The dir and seed levels are added only when exactly one was loaded.
        output_dir = LOG_PATH / "Figures" / args.env / "_".join(algorithms)
        dir_names = sorted({dir_name for _, dir_name in loaded_alg_dirs})
        if len(dir_names) == 1:
            output_dir = output_dir / dir_names[0]
        loaded_seeds = sorted({seed for _, _, seed in loaded_dir_seeds})
        if len(loaded_seeds) == 1:
            output_dir = output_dir / str(loaded_seeds[0])

    max_steps = int(args.max_steps) if args.max_steps is not None else None
    plotted_any = plot_scalars(
        run_data,
        run_dirs_for_report,
        selected_tags,
        output_dir,
        args.smooth,
        args.dpi,
        args.show,
        max_steps,
        args.points,
        args.ci,
        args.show_runs,
        args.env,
    )

    if not plotted_any:
        raise SystemExit("Requested tags were not found in the selected runs.")

    print(f"Saved figures for tags: {', '.join(selected_tags)}")
    print(f"Output directory: {output_dir}")
    printed_runs = os.linesep.join(f"{label}: {run_dir}" for label, run_dir in run_dirs_for_report.items())
    print(f"Runs plotted:{os.linesep}{printed_runs}")
    if missing_locations:
        print(f"Skipped entries:{os.linesep}{_format_missing(missing_locations)}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()

    