from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple

import numpy as np
import matplotlib.pyplot as plt


def pareto_scatter(
    results: Dict[str, Sequence[Tuple[float, float]]],
    path: str,
    *,
    close: bool = True,
):
    """Scatter-plot (return, cost) points for each named result set and save to *path*.

    Args:
        results: Mapping from series label to a sequence of ``(return, cost)``
            pairs; column 0 is drawn on the x axis ("Return") and column 1 on
            the y axis ("Cost"). Empty sequences are skipped.
        path: Destination passed to ``Figure.savefig``.
        close: Close the figure after saving (default) so repeated calls do
            not accumulate open pyplot figures. Pass ``False`` to keep it open,
            e.g. for interactive display.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        for name, pts in results.items():
            arr = np.asarray(pts)
            if len(arr) == 0:
                continue
            ax.scatter(arr[:, 0], arr[:, 1], label=name)
        ax.set_xlabel("Return")
        ax.set_ylabel("Cost")
        # Only build a legend when something labelled was drawn; otherwise
        # matplotlib warns that no artists with labels were found.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        if close:
            # pyplot keeps every figure alive until explicitly closed.
            plt.close(fig)
                
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     


def hist(
    values: Sequence[float],
    path: str,
    title: str = "Histogram",
    bins: int | Sequence[float] = 30,
    *,
    close: bool = True,
):
    """Draw a histogram of *values* and save it to *path*.

    Args:
        values: Samples to bin.
        path: Destination passed to ``Figure.savefig``.
        title: Axes title.
        bins: Forwarded to ``Axes.hist``; defaults to 30 bins as before.
        close: Close the figure after saving (default) so repeated calls do
            not accumulate open pyplot figures. Pass ``False`` to keep it open.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.hist(values, bins=bins, alpha=0.8)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        if close:
            # pyplot keeps every figure alive until explicitly closed.
            plt.close(fig)


def compute_pareto(
    points: Sequence[Tuple[float, float]],
    maximize_return: bool = True,
) -> List[Tuple[float, float]]:
    """Return the non-dominated ``(return, cost)`` points, ordered by ascending return.

    By default return is maximized and cost minimized. A point is dropped if
    another point has return >= and cost <= (with at least one strict), and
    exact duplicates are reported once.

    Args:
        points: Sequence (or ``(n, 2)`` array) of ``(return, cost)`` pairs.
        maximize_return: If ``False``, both return and cost are minimized
            instead (the orientation used by earlier versions of this helper).

    Returns:
        The frontier as a list of ``(float, float)`` tuples sorted by return
        ascending; ``[]`` for empty input.
    """
    pts = np.asarray(points, dtype=float)
    if pts.size == 0:
        return []
    returns, costs = pts[:, 0], pts[:, 1]
    # Visit the "best" return first and, within a tied return, the cheapest
    # cost first, so a later point survives only if it is strictly cheaper
    # than everything already visited. np.lexsort uses its LAST key as primary.
    primary = -returns if maximize_return else returns
    order = np.lexsort((costs, primary))

    frontier: List[Tuple[float, float]] = []
    best_cost = np.inf
    for r, c in zip(returns[order], costs[order]):
        if c < best_cost:
            frontier.append((float(r), float(c)))
            best_cost = c

    if maximize_return:
        # Collected from highest to lowest return; report ascending return.
        frontier.reverse()
    return frontier


def merge_results(all_results: Dict[str, Sequence[Tuple[float, float]]]) -> Dict[str, Tuple[float, float]]:
    """Collapse each named point set into its per-column mean.

    Args:
        all_results: Mapping from label to a sequence of 2-tuples.

    Returns:
        A dict (in input order) mapping each label with a non-empty (truthy)
        sequence to ``(mean of column 0, mean of column 1)`` as Python floats.
        Labels whose sequence is empty are omitted.
    """

    def _column_means(samples: Sequence[Tuple[float, float]]) -> Tuple[float, float]:
        # Average the first and second entries independently.
        table = np.array(samples)
        first, second = table[:, 0], table[:, 1]
        return float(first.mean()), float(second.mean())

    return {
        label: _column_means(samples)
        for label, samples in all_results.items()
        if samples
    }


@dataclass()
class PlotConfig:
    """Appearance options consumed by the config-driven plotting helpers.

    Every field has a default, so ``PlotConfig()`` is a complete configuration.
    """

    # Text for the x-axis label.
    xlabel: str = "Return"
    # Text for the y-axis label.
    ylabel: str = "Cost"
    # When True, a legend is drawn.
    legend: bool = True
    # When True, a background grid is drawn.
    grid: bool = True
    # Passed as ``figsize`` when the figure is created.
    figsize: Tuple[int, int] = (7, 5)


def scatter_with_config(
    results: Dict[str, Sequence[Tuple[float, float]]],
    path: str,
    cfg: PlotConfig,
    *,
    close: bool = True,
):
    """Scatter-plot each named point set with labels/styling from *cfg*, then save.

    Args:
        results: Mapping from series label to a sequence of 2-tuples; column 0
            is drawn on the x axis and column 1 on the y axis. Empty sequences
            are skipped.
        path: Destination passed to ``Figure.savefig``.
        cfg: Axis labels, legend/grid toggles and figure size.
        close: Close the figure after saving (default) so repeated calls do
            not accumulate open pyplot figures. Pass ``False`` to keep it open.
    """
    fig, ax = plt.subplots(figsize=cfg.figsize)
    try:
        for name, pts in results.items():
            arr = np.asarray(pts)
            if len(arr) == 0:
                continue
            ax.scatter(arr[:, 0], arr[:, 1], label=name)
        ax.set_xlabel(cfg.xlabel)
        ax.set_ylabel(cfg.ylabel)
        # Skip an empty legend (matplotlib warns when no labelled artists exist).
        if cfg.legend and ax.get_legend_handles_labels()[0]:
            ax.legend()
        if cfg.grid:
            ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        if close:
            # pyplot keeps every figure alive until explicitly closed.
            plt.close(fig)


def plot_lines(
    series: Dict[str, Sequence[float]],
    path: str,
    cfg: PlotConfig,
    *,
    close: bool = True,
):
    """Plot each named series against its index and save the figure to *path*.

    Args:
        series: Mapping from line label to y-values; the x values are
            ``0, 1, ..., len(ys) - 1``.
        path: Destination passed to ``Figure.savefig``.
        cfg: Supplies the y label, legend/grid toggles and figure size.
            ``cfg.xlabel`` is not used; the x axis is always labelled "Step".
        close: Close the figure after saving (default) so repeated calls do
            not accumulate open pyplot figures. Pass ``False`` to keep it open.
    """
    fig, ax = plt.subplots(figsize=cfg.figsize)
    try:
        for name, ys in series.items():
            ax.plot(np.arange(len(ys)), ys, label=name)
        ax.set_xlabel("Step")
        ax.set_ylabel(cfg.ylabel)
        # Skip an empty legend (matplotlib warns when no labelled artists exist).
        if cfg.legend and ax.get_legend_handles_labels()[0]:
            ax.legend()
        if cfg.grid:
            ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        if close:
            # pyplot keeps every figure alive until explicitly closed.
            plt.close(fig)
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
