import hydra
from omegaconf import DictConfig, OmegaConf
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import csv

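# Put src/ on sys.path, presumably so modules inside src/ can import each other
# by bare name; the `src.`-qualified imports below resolve via the script's own directory.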
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
from src.utils import set_seed
from src.dataset import load_adult
from src.problem import Fairness_Learning
from src.solver import EG, run_LEN, spaco

# Solver presets selectable through the top-level ``algo`` config key.
_ALGO_PRESETS = {
    "eg": ("EG",),
    "spaco": ("SPACO",),
    "npe": ("NPE",),
    "len": ("LEN",),
    # Standard methods only (no custom sEG variants).
    "all": ("EG", "SPACO", "LEN"),
}

# Every known solver with its matplotlib line style, in run / CSV-column order.
# The solver name doubles as its plot/CSV label.
_SOLVER_STYLES = (
    ("EG", "-.b"),
    ("SPACO", "-c"),
    ("NPE", ":r"),
    ("LEN", "-k"),
)

# Core utility & fairness metrics plotted and exported per solver:
# DPD = |P(Ŷ=1|Z=0) - P(Ŷ=1|Z=1)|,  EOD = |TPR0-TPR1| + |FPR0-FPR1|
_METRICS = ("accuracy", "dpd", "eod")
_METRIC_TITLES = {"dpd": "DPD", "eod": "EOD"}


def _make_solver_cfg(cfg: DictConfig, solver_name: str) -> DictConfig:
    """Return a resolved, non-struct copy of ``cfg`` with the solver preset overlaid.

    The preset block is read from the lower-cased solver name
    (``eg`` / ``spaco`` / ``npe`` / ``len``); each of its top-level keys
    overwrites the same key in the copied base config. When the block is
    absent the copy is returned unchanged.
    """
    base = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    # Allow the preset to introduce keys the base config does not have.
    OmegaConf.set_struct(base, False)
    preset = cfg.get(solver_name.lower())
    if preset is not None:
        for key, value in OmegaConf.to_container(preset, resolve=True).items():
            base[key] = value
    return base


def _run_solver(name, oracle, z0, solver_cfg):
    """Run solver ``name`` and return its ``(time, gnorm, eval)`` triple.

    Raises:
        ValueError: if ``name`` is not one of the known solvers.
    """
    print(f"Running solver: {name}")
    # Each solver gets its own copy of the start point so that a solver which
    # updates its iterate in place cannot alter the start of later solvers.
    start = z0.copy()
    if name == "EG":
        return EG(oracle, start, solver_cfg)
    if name == "SPACO":
        return spaco(oracle, start, solver_cfg)
    if name == "NPE":
        # NPE is run_LEN with m=1.
        return run_LEN(oracle, start, solver_cfg, m=1)
    if name == "LEN":
        # LEN is run_LEN with the general m taken from the (merged) config.
        return run_LEN(oracle, start, solver_cfg, m=solver_cfg.m)
    raise ValueError(f"Unknown solver '{name}'.")


def _export_metric_csv(runs, metric_name, csv_path):
    """Write ``metric_name`` for every solver to ``csv_path`` on a shared time axis.

    The header row is ``time`` followed by one column per solver. Known
    solvers come first in ``_SOLVER_STYLES`` order, then any others sorted by
    name. There is one row per distinct timestamp seen across all solvers.
    Each solver cell holds that solver's most recent value at or before the
    row's time (step-wise hold). The cell is empty before the solver's first
    record. Records missing ``time`` or the metric are skipped. Nothing is
    written when no solver has any usable record.
    """
    per_solver = {}
    all_times = set()
    for key, run in runs.items():
        points = []
        for rec in run.get("eval") or ():
            t = rec.get("time", None)
            v = rec.get(metric_name, None)
            if t is None or v is None:
                continue
            points.append((float(t), float(v)))
        if points:
            points.sort(key=lambda p: p[0])
            per_solver[key] = points
            all_times.update(t for t, _ in points)

    if not per_solver:
        return

    solvers = [name for name, _ in _SOLVER_STYLES if name in per_solver]
    # Append any remaining (unexpected) solvers deterministically.
    solvers += sorted(s for s in per_solver if s not in solvers)

    with open(csv_path, mode="w", newline="", encoding="utf-8") as f_csv:
        writer = csv.writer(f_csv)
        writer.writerow(["time"] + solvers)

        # Per-solver cursor into its sorted points and last value emitted.
        cursors = {s: 0 for s in solvers}
        last_vals = {s: "" for s in solvers}

        for t in sorted(all_times):
            row = [f"{t:.6f}"]
            for s in solvers:
                points = per_solver[s]
                j = cursors[s]
                while j < len(points) and points[j][0] <= t:
                    last_vals[s] = f"{points[j][1]:.6f}"
                    j += 1
                cursors[s] = j
                row.append(last_vals[s])
            writer.writerow(row)
    print(f"Saved {metric_name} CSV to {csv_path}")


def _plot_metric(ax, metric, eval_data, label, style):
    """Plot ``metric`` against wall-clock time for one solver on ``ax``.

    Records lacking a ``time`` or ``metric`` entry (or holding ``None``) are
    skipped, consistent with the CSV export. Returns True if a line was drawn.
    """
    points = [
        (rec.get("time"), rec.get(metric))
        for rec in eval_data or ()
    ]
    points = [(t, v) for t, v in points if t is not None and v is not None]
    if not points:
        return False
    times, values = zip(*points)
    ax.plot(times, values, style, label=label, linewidth=3)
    return True


def _plot_metrics(series, metric_path):
    """Save a composite figure with one subplot per metric in ``_METRICS``.

    ``series`` is a list of ``(style, label, eval_data)`` tuples. The figure is
    always closed, even if saving fails.
    """
    n_cols = 3
    n_rows = int(np.ceil(len(_METRICS) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    axes = np.array(axes).reshape(-1)
    try:
        for ax_m, metric in zip(axes, _METRICS):
            ax_m.grid()
            plotted = False
            for style, label, data in series:
                # Plot every series; do not short-circuit after the first hit.
                if _plot_metric(ax_m, metric, data, label, style):
                    plotted = True
            ax_m.set_xlabel('time (s)')
            if plotted:
                ax_m.legend(fontsize=12, loc='best')
            else:
                ax_m.text(0.5, 0.5, 'No data', ha='center', va='center')
            ax_m.set_title(_METRIC_TITLES.get(metric, metric))

        # Hide any unused subplots.
        for ax_unused in axes[len(_METRICS):]:
            fig.delaxes(ax_unused)

        fig.tight_layout()
        fig.savefig(metric_path)
    finally:
        # Release the figure's memory; pyplot keeps figures alive until closed.
        plt.close(fig)
    print(f"Saved metric composite plot to {metric_path}")


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Run the selected fairness-learning solvers and save their results.

    Steps:
    1. Select solvers via ``cfg.algo`` (default ``"spaco"``; see ``_ALGO_PRESETS``).
    2. Seed the RNG and load the Adult data from ``cfg.data_dir``.
    3. Run each selected solver from a zero start point.
    4. Print each solver's final accuracy, DPD and EOD.
    5. Write one CSV per metric and a composite PNG into ``cfg.results_dir``.

    Raises:
        ValueError: if ``cfg.algo`` is not a known preset.
    """
    print(f"Current working directory: {os.getcwd()}")
    print(f"Config:\n{cfg}")

    algo = cfg.get("algo", "spaco")
    preset_solvers = _ALGO_PRESETS.get(algo)
    if preset_solvers is None:
        raise ValueError(f"Unknown algo preset '{algo}'. "
                         f"Expected one of ['eg','spaco','npe','len','all'].")

    set_seed(cfg.seed)

    # Relative data paths are resolved against the original (project) cwd,
    # not Hydra's run directory.
    data_dir = hydra.utils.to_absolute_path(cfg.data_dir)
    print(f"Loading data from: {data_dir}")
    A, b, c = load_adult(data_dir)

    oracle = Fairness_Learning(A, b, c, lamb=cfg.lamb, gamma=cfg.gamma, beta=cfg.beta)
    z0 = np.zeros((oracle.d, 1))

    # Run solvers in the fixed _SOLVER_STYLES order, regardless of preset order.
    runs = {}
    for name, style in _SOLVER_STYLES:
        if name not in preset_solvers:
            continue
        solver_cfg = _make_solver_cfg(cfg, name)
        t, g, e = _run_solver(name, oracle, z0, solver_cfg)
        runs[name] = {"style": style, "label": name, "time": t, "gnorm": g, "eval": e}

    if not runs:
        print("No solvers selected; nothing to run.")
        return

    # Save Results
    os.makedirs(cfg.results_dir, exist_ok=True)

    # Plotting style (TrueType fonts so PDF/PS text stays editable).
    plt.rcParams['pdf.fonttype'] = 42
    plt.rcParams['ps.fonttype'] = 42
    plt.rc('font', size=21)

    series = [
        (r["style"], r["label"], r["eval"])
        for r in runs.values()
        if r["eval"]
    ]
    if not series:
        print("No evaluation data available to plot.")
        return

    # Print final core metrics for each solver (keys: accuracy, dpd, eod).
    print("Final core metrics (acc, DPD, EOD):")
    for _, label, data in series:
        last = data[-1]
        acc = float(last.get("accuracy", float("nan")))
        dpd = float(last.get("dpd", float("nan")))
        eod = float(last.get("eod", float("nan")))
        print(f"{label}: acc={acc:.6f}, DPD={dpd:.6f}, EOD={eod:.6f}")

    for metric in _METRICS:
        _export_metric_csv(
            runs,
            metric,
            os.path.join(cfg.results_dir, f"Fairness_{metric}_seed_{cfg.seed}.csv"),
        )

    metric_path = os.path.join(cfg.results_dir, f'Fairness_metrics_seed_{cfg.seed}.png')
    _plot_metrics(series, metric_path)

if __name__ == "__main__":
    # The @hydra.main decorator on main() composes `cfg` from config_name="config"
    # under conf/ plus any command-line overrides, so main() is called with no arguments.
    main()
