from __future__ import annotations

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from typing import Dict, List, Optional, Tuple


def _load_history_from_json(json_files: List[str], problem_name: str) -> pd.DataFrame:
    """Concatenate every non-empty JSON history file into a single DataFrame.

    Each file's records become rows, and every row of that file is tagged with
    ``run_name = f"{problem_name}_{method}_run.{seed}"``, where ``method`` and
    ``seed`` are taken from the file's FIRST record (defaulting to ``"unknown"``
    and ``0``). Files whose decoded content is falsy (e.g. ``[]``) are skipped.
    NOTE(review): assumes each file holds a list of record dicts — confirm
    against whatever writes these files.

    Raises:
        ValueError: If no file contributed any records.
    """
    frames = []
    for json_file in json_files:
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not data:
            continue

        # Get method and seed from first entry
        method = data[0].get("method", "unknown")
        seed = data[0].get("seed", 0)

        df = pd.DataFrame(data)
        df["run_name"] = f"{problem_name}_{method}_run.{seed}"
        frames.append(df)

    if not frames:
        raise ValueError("No data found in JSON files")

    return pd.concat(frames, ignore_index=True)


def _constant_run_budget(df: pd.DataFrame, budget_key: str, run_name: str) -> float:
    """Return the single budget value recorded for a run, ignoring NaN entries.

    Raises:
        ValueError: If the run has no non-NaN budget, or more than one distinct one.
    """
    bud_vals = df[budget_key].to_numpy(dtype=float)
    bud_vals = bud_vals[~np.isnan(bud_vals)]
    if bud_vals.size == 0:
        raise ValueError(f"No finite '{budget_key}' in {run_name}")
    unique_b = np.unique(bud_vals)
    if unique_b.size != 1:
        raise ValueError(f"Non-constant '{budget_key}' in {run_name}: {unique_b.tolist()}")
    return float(unique_b[0])


def _subsample_every(
    costs: np.ndarray, regrets: np.ndarray, step_interval: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Keep every ``step_interval``-th point, always retaining the last point."""
    indices = np.arange(0, len(costs), step_interval)
    if indices[-1] != len(costs) - 1:
        indices = np.append(indices, len(costs) - 1)
    return costs[indices], regrets[indices]


def plot_regret_vs_cost_from_json(
    json_files: List[str],
    problem_name: str,
    methods_list: List[str],
    seed_list: List[int],
    *,
    n_grid: Optional[int] = None,
    cost_key: str = "cost",
    regret_key: str = "best_nsga2_regret",
    budget_key: str = "budget",
    budget: Optional[float] = None,
    title: Optional[str] = None,
    use_sem_shading: bool = False,
    normalized_bounds: Optional[Tuple[float, float]] = None,
    step_interval: Optional[int] = None,
    method_colors: Optional[Dict[str, str]] = None,
    figsize: tuple = (5, 4)
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Plot regret vs COST for multiple methods from JSON files.

    Every JSON file is tagged with a run name built from the method and seed of
    its first record. For each method in ``methods_list`` and seed in
    ``seed_list``, the run ``f"{problem_name}_{method}_run.{seed}"`` is looked up;
    missing runs are skipped. Each run is clipped to its effective budget
    (``budget`` if given, otherwise the run's own constant ``budget_key`` value),
    linearly interpolated onto ``np.linspace(1.0, common_budget, n_grid)``, and
    aggregated per method into a mean curve with a std/SEM band.

    A seed with no point within the effective budget for ANY method is excluded
    from ALL methods, so every curve averages over the same set of seeds.

    Args:
        json_files: List of JSON file paths to load
        problem_name: Name of the problem (prefix of every run name)
        methods_list: List of method names to compare
        seed_list: List of random seeds to aggregate over
        n_grid: Number of points in interpolation grid. If None, set to
            ``int(common_budget)``, but never fewer than 2
        cost_key: Column name for cost values
        regret_key: Column name for regret values (also used as the y-label)
        budget_key: Column name for budget values
        budget: If provided, clip all runs to this target budget
        title: Custom plot title
        use_sem_shading: If True, use SEM for error bands; otherwise use std
        normalized_bounds: If provided, map the cost axis linearly so that
            ``common_budget`` lands on ``high`` (``x = low + cost/budget*(high-low)``)
        step_interval: If provided, only interpolate through every Nth point
            (plus the last point); must be >= 1
        method_colors: Dictionary mapping method names to colors
        figsize: Matplotlib figure size (width, height) in inches

    Returns:
        Dictionary keyed by method with "grid", "x_plot", "mean", "std", "sem",
        "per_run" (array of shape (n_runs, n_grid)) and "budget" (a float).
        "std" is the population standard deviation across runs (numpy's default
        ddof=0); "sem" is ``std / sqrt(n_runs)`` (zeros for a single run).

    Raises:
        ValueError: If ``step_interval`` < 1, no JSON file has data, a run lacks
            a required column, a run's budget is missing or non-constant, a run's
            maximum cost is below an explicit ``budget``, a run's costs decrease,
            no runs could be loaded, or (with ``budget=None``) runs disagree on
            their budget.
    """
    if step_interval is not None and step_interval < 1:
        raise ValueError(f"step_interval must be a positive integer, got {step_interval}")

    # ---------- 0) Build full_history_df from JSON files ----------
    full_history_df = _load_history_from_json(json_files, problem_name)
    # Split by run once instead of re-filtering the whole frame for every (method, seed).
    runs: Dict[str, pd.DataFrame] = {
        name: group for name, group in full_history_df.groupby("run_name", sort=False)
    }

    # ---------- 1) First pass: identify invalid seeds ----------
    invalid_seeds = set()

    for method in methods_list:
        for seed in seed_list:
            run_name = f"{problem_name}_{method}_run.{seed}"
            df = runs.get(run_name)
            if df is None:
                continue  # Skip if run not found

            # Required columns (absent only if no loaded file had them at all)
            missing = [k for k in (cost_key, regret_key, budget_key) if k not in df.columns]
            if missing:
                raise ValueError(f"Missing columns {missing} in {run_name}")

            costs = df[cost_key].to_numpy(dtype=float)
            run_budget = _constant_run_budget(df, budget_key, run_name)
            effective_budget = budget if budget is not None else run_budget

            # A seed with no point inside the budget is dropped for every method.
            if not np.any(costs <= effective_budget):
                if seed not in invalid_seeds:
                    print(f"Warning: Seed {seed} has 0 points ≤ effective budget {effective_budget} in {method}; excluding seed {seed} from ALL methods.")
                    invalid_seeds.add(seed)
                continue

            if budget is not None:
                max_cost_in_run = costs.max()
                if max_cost_in_run < budget:
                    raise ValueError(
                        f"Run '{run_name}' has maximum cost {max_cost_in_run:.2f} which is less than "
                        f"specified target budget {budget:.2f}."
                    )

    # ---------- 2) Second pass: load valid runs only ----------
    valid_seed_list = [seed for seed in seed_list if seed not in invalid_seeds]
    all_payloads: Dict[str, List[dict]] = {m: [] for m in methods_list}
    all_budgets: List[float] = []

    for method in methods_list:
        for seed in valid_seed_list:
            run_name = f"{problem_name}_{method}_run.{seed}"
            df = runs.get(run_name)
            if df is None:
                continue

            costs = df[cost_key].to_numpy(dtype=float)
            regrets = df[regret_key].to_numpy(dtype=float)
            # Already validated in the first pass, so this cannot raise here.
            run_budget = _constant_run_budget(df, budget_key, run_name)

            # Interpolation below needs non-decreasing costs.
            if not np.all(np.diff(costs) >= 0):
                raise ValueError(f"{run_name}: non-monotone costs.")

            effective_budget = budget if budget is not None else run_budget

            # Clip within budget
            m = costs <= effective_budget
            costs, regrets = costs[m], regrets[m]

            # Pad a single point into a flat segment ending at the budget.
            if costs.size == 1:
                print(f"Warning: {run_name}: only 1 point ≤ effective budget {effective_budget}; padding.")
                costs = np.array([costs[0], effective_budget])
                regrets = np.array([regrets[0], regrets[0]])

            all_payloads[method].append(
                {"seed": seed, "costs": costs, "regrets": regrets, "budget": effective_budget}
            )
            all_budgets.append(run_budget)

    # Print discarded seeds
    if invalid_seeds:
        print(f"\nDiscarded seeds {sorted(invalid_seeds)} from ALL methods.")
        print(f"Using {len(valid_seed_list)}/{len(seed_list)} seeds for comparison.")

    if not all_budgets:
        raise ValueError("No runs loaded; cannot determine common budget.")

    # ---------- 3) Determine common budget ----------
    if budget is not None:
        common_budget = float(budget)
        print(f"Using specified target budget: {common_budget}")
    else:
        unique_budgets = set(all_budgets)
        if len(unique_budgets) != 1:
            raise ValueError(f"Budget mismatch. Unique budgets: {sorted(unique_budgets)}")
        common_budget = float(next(iter(unique_budgets)))

    if n_grid is None:
        # Roughly one grid point per unit of cost; int() of a budget below 2 would
        # otherwise give an empty or single-point grid.
        n_grid = max(int(common_budget), 2)

    grid = np.linspace(1.0, common_budget, n_grid)

    # Normalize if requested
    if normalized_bounds is not None:
        low, high = normalized_bounds
        x_plot = low + (grid / common_budget) * (high - low)
        xlabel = f"Cost (normalized to [{low}, {high}])"
    else:
        x_plot = grid
        xlabel = "Cost"

    # ---------- 4) Interpolate and aggregate ----------
    results: Dict[str, Dict[str, np.ndarray]] = {}
    plt.figure(figsize=figsize)

    for method in methods_list:
        payloads = all_payloads[method]
        if not payloads:
            print(f"Warning: No runs found for method '{method}'. Skipping.")
            continue

        per_run = []
        for rp in payloads:
            costs_sub, regrets_sub = rp["costs"], rp["regrets"]
            if step_interval is not None:
                costs_sub, regrets_sub = _subsample_every(costs_sub, regrets_sub, step_interval)
            # np.interp holds the first/last regret outside the data range (the
            # former interp1d fill_value) and stays finite when costs repeat.
            per_run.append(np.interp(grid, costs_sub, regrets_sub))

        arr = np.vstack(per_run)
        mean = arr.mean(axis=0)
        std = arr.std(axis=0)
        sem = std / np.sqrt(arr.shape[0]) if arr.shape[0] > 1 else np.zeros_like(std)

        color = method_colors.get(method) if method_colors else None

        # Only show uncertainty bands if we have multiple runs
        if arr.shape[0] > 1:
            spread = sem if use_sem_shading else std
            spread_label = "SEM" if use_sem_shading else "Std. Dev."
            line = plt.plot(x_plot, mean, label=f"{method} (mean)", color=color, linewidth=2)
            # Use the same color as the line for the uncertainty band
            line_color = line[0].get_color()
            plt.fill_between(x_plot, mean - spread, mean + spread, alpha=0.2,
                             label=f"{method} ± {spread_label}", color=line_color)
        else:
            # Single run - just plot the line without uncertainty
            plt.plot(x_plot, mean, label=f"{method}", color=color, linewidth=2)

        results[method] = {
            "grid": grid,
            "x_plot": x_plot,
            "mean": mean,
            "std": std,
            "sem": sem,
            "per_run": arr,
            "budget": common_budget,
        }

    # ---------- 5) Final formatting ----------
    plt.xlabel(xlabel)
    plt.ylabel(regret_key)
    ttl = title or f"{regret_key} vs. Cost"
    if normalized_bounds is not None:
        ttl += f" (normalized to [{normalized_bounds[0]}, {normalized_bounds[1]}])"
    plt.title(ttl)
    plt.grid(alpha=0.3)
    plt.legend(fontsize=9)
    plt.tight_layout()
    plt.show()

    return results