#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import csv
import json
import time
import pickle
import logging
from itertools import product
from typing import Dict, Any, List, Tuple
from collections import defaultdict
import numpy as np
import torch
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed

from src.polar_chmc import ConstrainedPolarHMC, run_chains
from src.dataset import sqexp_cov_and_precision, logprob_fn_mbh, logprob_fn_mvmf, make_A
from src.utils import report_one_rep_chains_to_dfs
from src.misc import setup_rep_logger, StreamToLogger, pushover_notify


########################################################################################
# Directional-distribution tag. run_one_job branches on 'mbh' / 'mvmf' when
# building the matrix A, and the tag is also forwarded to the report helper.
DDIST = 'mvmf'
# Every per-setting folder and the summary CSV are created under this root.
OUT_ROOT = os.path.expanduser(f"~/hmcstiefel/outputs_{DDIST}")

# Parameter grid. main() takes the Cartesian product of p/k/rho/L/step_size
# and then loops over c_list (only the first c is used for the naive method).
GRID = dict(
    p_list=[40],
    k_list=[10, 20, 30, 39],
    rho_list=[1.0, 10.0],
    c_list=[0.1, 0.01],
    L_list=[50],
    step_size_list=[0.1],
)

# Repetition range: reps REP_START .. REP_START + N_REPS - 1 are scheduled,
# unless REP_END is set, in which case range(REP_START, REP_END) is used.
REP_START = 0
N_REPS = 50 - REP_START
REP_END = None
# Size of the process pool used by main().
N_WORKERS = 90
# When True, reps already marked OK in a setting's results.csv are skipped.
RESUME_MISSING = True

# Run parameters (forwarded to run_chains by run_one_job).
R = 40.0
M = 1600
BURNIN = 2000
THIN = 1
N_SAMPLES = 2000  # kept draws
N_CHAINS = 4
# Per-rep seeds are derived as base_seed + 1000 * rep.
BASE_SEEDS = [1, 2, 3, 4]

# Barrier / sampler parameters (forwarded to ConstrainedPolarHMC).
RHO_BARRIER = 0.2
BARRIER_POWER = 2
BETA_PHI_K = 10.0
LAM_PENALTY = 10.0
BARRIER_SIGMA = 10.0
TARGET_ACCEPT = 0.65
HIT_TIME = True

# Report parameters (forwarded to report_one_rep_chains_to_dfs).
GAMMA_KEY = "Q"
CLAMP_EPS = 1e-12
ESS_METRIC = 'mean'


# Sampler variants to compare; each entry's flags are copied into every job.
METHODS = [
    {
        "name": "softbarrier",
        "use_reflection": True,
        "use_soft_barrier": True,
        "do_hungarian": True,
    },
    {
        "name": "naive",
        "use_reflection": False,
        "use_soft_barrier": False,
        "do_hungarian": True,
    },
]


# ============================================================
# utils
# ============================================================
def safe_mkdir(path: str):
    """Create directory ``path`` (with any missing parents) if it is absent.

    Calling this on an existing directory is not an error. An empty ``path``
    is a no-op: that is what ``os.path.dirname`` returns for a bare filename,
    and ``os.makedirs("")`` would otherwise raise ``FileNotFoundError``.

    Args:
        path: Directory to create; may be empty.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def ftag(x: float) -> str:
    """Render a number as a filename-safe tag.

    Zero becomes ``"0"``. Magnitudes below 1e-2 or at least 1e3 use one-digit
    scientific notation with the ``+`` sign dropped; everything else uses up
    to six significant digits. In both cases ``.`` is spelled ``p`` and ``-``
    is spelled ``m`` (e.g. 0.1 -> "0p1", 0.001 -> "1p0em03").
    """
    value = float(x)
    if value == 0:
        return "0"

    magnitude = abs(value)
    if magnitude < 1e-2 or magnitude >= 1e3:
        text = format(value, ".1e").replace("+", "")
    else:
        text = format(value, ".6g")

    # Swap the characters that are awkward in directory names.
    return text.translate(str.maketrans({".": "p", "-": "m"}))

def setting_dir_name(method: str, p: int, k: int, rho: float, c: float, L: int, step: float) -> str:
    """Build the folder name that identifies one experiment setting.

    The name joins the method, the grid-point values and the module-level
    BARRIER_POWER / RHO_BARRIER constants with underscores; float values are
    encoded with ``ftag`` so the result is filename-safe.
    """
    pieces = [
        f"{method}",
        f"p{p}",
        f"k{k}",
        f"rho{ftag(rho)}",
        f"c{ftag(c)}",
        f"L{L}",
        f"step{ftag(step)}",
        f"bp{BARRIER_POWER}",
        f"rhob{ftag(RHO_BARRIER)}",
    ]
    return "_".join(pieces)

def load_done_reps(results_csv: str) -> set:
    """Return the set of rep indices recorded with status OK.

    A missing file yields an empty set. The status comparison ignores case
    and surrounding whitespace. Rows that cannot be interpreted (for example
    a non-integer ``rep``) are skipped rather than treated as an error.
    """
    completed = set()
    if os.path.exists(results_csv):
        with open(results_csv, "r", newline="") as handle:
            for record in csv.DictReader(handle):
                try:
                    status = str(record.get("status", "")).strip().upper()
                    if status == "OK":
                        completed.add(int(record["rep"]))
                except Exception:
                    # Best effort: a malformed row must not abort the resume scan.
                    continue
    return completed

def append_rows(results_csv: str, fieldnames: List[str], rows: List[Dict[str, Any]]):
    """Append ``rows`` to a CSV file, writing the header when the file is new.

    Args:
        results_csv: Target CSV path. Its parent directory is created when
            missing; a bare filename (no directory part) is also accepted.
        fieldnames: Column order for the CSV.
        rows: Dicts to write. Keys absent from a row are written as empty
            strings; the caller's dicts are left unmodified.

    Raises:
        ValueError: If a row contains a key that is not in ``fieldnames``
            (the ``csv.DictWriter`` default).
    """
    parent = os.path.dirname(results_csv)
    if parent:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent, exist_ok=True)

    # An existing but empty file still needs its header row.
    needs_header = (not os.path.exists(results_csv)) or os.path.getsize(results_csv) == 0

    with open(results_csv, "a", newline="") as f:
        # restval fills missing fields without touching the input dicts.
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        if needs_header:
            writer.writeheader()
        writer.writerows(rows)

def pick_cols(df, must_have: List[str]):
    """Keep only the columns whose name contains every token in ``must_have``.

    Matching is case-insensitive substring matching on ``str(column)``.
    ``None`` or a frame with zero columns is returned unchanged; when no
    column matches, a zero-column slice of ``df`` is returned.
    """
    if df is None:
        return df
    if getattr(df, "shape", (0, 0))[1] == 0:
        return df

    def _matches(column) -> bool:
        lowered = str(column).lower()
        return all(token.lower() in lowered for token in must_have)

    selected = [column for column in df.columns if _matches(column)]
    if selected:
        return df.loc[:, selected]
    return df.iloc[:, 0:0]

def df_firstrow_to_dict(df):
    """Flatten the first row of ``df`` into a ``{str(column): value}`` dict.

    ``None`` gives ``{}``; a frame without rows maps every column to NaN.
    Otherwise each cell is converted with ``float`` when it is a scalar that
    equals itself (i.e. not NaN), mapped to NaN when it is not, and stored
    as ``str(cell)`` if the conversion raises.
    """
    if df is None:
        return {}
    if len(df.index) == 0:
        return dict((str(name), np.nan) for name in df.columns)

    first = df.iloc[0]
    result = {}
    for name in df.columns:
        cell = first[name]
        key = str(name)
        try:
            if np.isscalar(cell) and (cell == cell):
                result[key] = float(cell)
            else:
                result[key] = np.nan
        except Exception:
            # Non-numeric scalars (e.g. text) are kept in string form.
            result[key] = str(cell)
    return result

def write_setting_summary(summary_csv: str, rows: List[Dict[str, Any]]):
    """Append per-setting ok/fail tallies to ``summary_csv`` via ``append_rows``.

    The column layout is fixed: setting, ok, fail, total, fail_rate.
    """
    summary_columns = ["setting", "ok", "fail", "total", "fail_rate"]
    append_rows(summary_csv, summary_columns, rows)


def run_one_job(job: Dict[str, Any]) -> Dict[str, Any]:
    """Run one (setting, repetition) job and return its ``results.csv`` row.

    The job builds a ``ConstrainedPolarHMC`` factory, runs the chains with
    ``run_chains``, summarises them with ``report_one_rep_chains_to_dfs`` and
    writes two artefacts under ``<setting_dir>/pkls``: the raw samples
    (``rep<rep>.samples.pt``) and a pickle with samples, budget, meta and the
    report DataFrames (``rep<rep>.pkl``).

    Any exception raised inside the guarded section is converted into an
    ERROR row; the traceback is written to
    ``<setting_dir>/logs/rep<rep>.traceback.txt`` and to the rep logger.
    Failures while creating the output folders or the logger still propagate.

    Args:
        job: Dict with keys ``rep``, ``setting_dir``, ``method``, ``p``, ``k``,
            ``rho``, ``c``, ``L``, ``step_size``, ``use_reflection``,
            ``use_soft_barrier`` and ``do_hungarian``.

    Returns:
        Dict with keys ``setting_dir``, ``rep``, ``status`` ("OK" or "ERROR"),
        ``elapsed_sec``, ``ess_mean_selected``, ``rhat_max_selected``
        (JSON strings, empty on error), ``error_type`` and ``error_msg``.
    """
    rep = int(job["rep"])
    setting_dir = job["setting_dir"]

    logs_dir = os.path.join(setting_dir, "logs")
    pkls_dir = os.path.join(setting_dir, "pkls")
    safe_mkdir(logs_dir)
    safe_mkdir(pkls_dir)

    log_path = os.path.join(logs_dir, f"rep{rep}.log")
    tb_path = os.path.join(logs_dir, f"rep{rep}.traceback.txt")
    pkl_path = os.path.join(pkls_dir, f"rep{rep}.pkl")
    pt_path = os.path.join(pkls_dir, f"rep{rep}.samples.pt")
    logger = setup_rep_logger(log_path, rep)

    # Bound before the try so the failure branch can always compute elapsed,
    # even when the error happens before sampling starts.
    t0 = time.time()

    try:
        method = job["method"]
        p, k = int(job["p"]), int(job["k"])
        rho, c = float(job["rho"]), float(job["c"])
        L, step_size = int(job["L"]), float(job["step_size"])

        # seeds per rep
        seeds = [s + 1000 * rep for s in BASE_SEEDS]

        if DDIST == 'mbh':
            _, A = sqexp_cov_and_precision(p, rho)
        elif DDIST == 'mvmf':
            A = make_A(p, k)
        else:
            # Fail loudly instead of hitting an unbound ``A`` inside the sampler.
            raise ValueError(f"unsupported DDIST: {DDIST!r}")

        # NOTE(review): the log-density is always logprob_fn_mvmf, even when
        # DDIST == 'mbh' (logprob_fn_mbh is imported but never called here) —
        # confirm this is intended.
        def make_sampler() -> ConstrainedPolarHMC:
            return ConstrainedPolarHMC(
                logprob_fn=lambda q: logprob_fn_mvmf(q, A.to(dtype=q.dtype, device=q.device), p,k, kappa=rho),
                step_size=step_size,
                num_steps=L,
                device="cpu",
                dtype=torch.float64,
                use_reflection=job["use_reflection"],
                use_soft_barrier=job["use_soft_barrier"],
                beta_phiK=BETA_PHI_K,
                lam_penalty=LAM_PENALTY,
                rho=RHO_BARRIER,
                barrier_sigma=BARRIER_SIGMA,
                barrier_power=BARRIER_POWER,
                hit_time=HIT_TIME,
                target_accept=TARGET_ACCEPT,
            )

        logger.info("=" * 80)
        logger.info(f"setting_dir={setting_dir}")
        logger.info(f"method={method} rep={rep}")
        logger.info(f"p={p} k={k} rho={rho} c={c} L={L} step_size={step_size}")
        logger.info(f"use_reflection={job['use_reflection']} use_soft_barrier={job['use_soft_barrier']}")
        logger.info(f"burnin={BURNIN} thin={THIN} n_samples={N_SAMPLES} n_chains={N_CHAINS}")
        logger.info(f"R={R} M={M} rho_barrier={RHO_BARRIER} barrier_power={BARRIER_POWER}")
        logger.info(f"seeds={seeds}")
        logger.info("=" * 80)

        # Restart the clock so a successful run reports sampling + reporting
        # time only, not the setup above.
        t0 = time.time()

        samples, budget, meta = run_chains(
            make_sampler,
            R=R, r=p, u=k, c=c, M=M,
            n_kept=N_SAMPLES,
            burnin=BURNIN,
            thin=THIN,
            n_chains=N_CHAINS,
            seeds=seeds,
            verbose=False,
            kept_only=False,
            do_hungarian=job["do_hungarian"],
        )

        df_meta_b_chain, df_meta_b_sum, df_b_rhat, df_ess_b_eff = report_one_rep_chains_to_dfs(
            samples=samples,
            meta=meta,
            budget=budget,
            gamma_key=GAMMA_KEY,
            clamp_eps=CLAMP_EPS,
            directional_dist=DDIST,
            ess_pick=ESS_METRIC,
            A_ref=A,
        )

        # Only the "ess ... mean" and "rhat ... max" columns go into the CSV row.
        df_ess_mean = pick_cols(df_ess_b_eff, ["ess", "mean"])
        df_rhat_max = pick_cols(df_b_rhat, ["rhat", "max"])
        ess_dict = df_firstrow_to_dict(df_ess_mean)
        rhat_dict = df_firstrow_to_dict(df_rhat_max)

        elapsed = time.time() - t0
        logger.info(f"[DONE] elapsed_sec={elapsed:.3f}")
        logger.info(f"[ESS&MEAN cols] {list(df_ess_mean.columns)}")
        logger.info(f"[RHAT&MAX cols] {list(df_rhat_max.columns)}")

        torch.save(samples, pt_path)

        # Pickle the full result bundle: samples, budget, meta and all four
        # report DataFrames (not just the selected columns).
        with open(pkl_path, "wb") as f:
            pickle.dump(
                {
                    "rep": rep,
                    "elapsed_sec": elapsed,
                    'samples': samples,
                    'budget': budget,
                    'meta': meta,
                    'df_meta_chain': df_meta_b_chain,
                    'df_meta_sum': df_meta_b_sum,
                    'df_rhat': df_b_rhat,
                    'df_ess': df_ess_b_eff
                },
                f,
            )

        return {
            "setting_dir": setting_dir,  # lets main decide which results.csv to write to
            "rep": rep,
            'status': "OK",
            "elapsed_sec": elapsed,
            "ess_mean_selected": json.dumps(ess_dict, ensure_ascii=False),
            "rhat_max_selected": json.dumps(rhat_dict, ensure_ascii=False),
            "error_type": "",
            "error_msg": ""
        }
    except Exception as e:
        elapsed = time.time() - t0
        etype = type(e).__name__
        emsg = str(e)

        tb = traceback.format_exc()
        try:
            with open(tb_path, "w") as f:
                f.write(tb)
        except OSError:
            # Best effort: the traceback is also sent to the rep log below.
            pass

        try:
            logger.exception("[FAILED]")
        except Exception:
            # Logging problems must never mask the original failure.
            pass

        return {
            "setting_dir": setting_dir,
            "rep": rep,
            "status": "ERROR",
            "elapsed_sec": elapsed,
            "ess_mean_selected": "",
            "rhat_max_selected": "",
            "error_type": etype,
            "error_msg": emsg,
        }

def prepare_setting_folder(setting_dir: str, config_obj: Dict[str, Any]):
    """Create a setting's folder layout and record its configuration once.

    Ensures ``setting_dir`` and its ``logs`` / ``pkls`` subfolders exist, then
    writes ``config_obj`` to ``config.json`` unless that file is already there
    (an existing config is never overwritten).
    """
    folders = (
        setting_dir,
        os.path.join(setting_dir, "logs"),
        os.path.join(setting_dir, "pkls"),
    )
    for folder in folders:
        safe_mkdir(folder)

    cfg_path = os.path.join(setting_dir, "config.json")
    if os.path.exists(cfg_path):
        return
    with open(cfg_path, "w") as handle:
        json.dump(config_obj, handle, ensure_ascii=False, indent=2)

def main():
    """Schedule every (method, grid point, rep) job and collect the results.

    Steps:
      1. Build the job list from METHODS x GRID x reps, creating each setting
         folder and (when RESUME_MISSING) skipping reps already marked OK in
         that setting's ``results.csv``.
      2. Run the jobs in a process pool of N_WORKERS workers, buffering result
         rows per setting and appending them to ``results.csv`` in batches.
      3. Append per-setting ok/fail counts to ``summary_by_setting.csv`` under
         OUT_ROOT and print them.

    A job whose future raises (for example because its worker process died)
    is recorded as an ERROR row instead of aborting the sweep, and buffered
    rows are flushed even if the collection loop is interrupted.
    """
    safe_mkdir(OUT_ROOT)
    # NOTE(review): this limits torch threads in the parent process; whether
    # pool workers inherit the limit depends on the start method — confirm.
    torch.set_num_threads(1)

    if REP_END is None:
        rep_list = list(range(REP_START, REP_START + N_REPS))
    else:
        rep_list = list(range(REP_START, REP_END))

    grid_base = list(product(
        GRID["p_list"],
        GRID["k_list"],
        GRID["rho_list"],
        GRID["L_list"],
        GRID["step_size_list"],
    ))

    def c_values_for(method_cfg):
        # The naive method (no reflection, no soft barrier) is run for a
        # single c only: the first entry of the grid.
        is_naive = method_cfg["name"] == "naive" or (
            not method_cfg["use_reflection"] and not method_cfg["use_soft_barrier"]
        )
        return [GRID["c_list"][0]] if is_naive else GRID["c_list"]

    total_settings = sum(len(grid_base) * len(c_values_for(m)) for m in METHODS)

    print(f"[RUN] grid_points(base)={len(grid_base)}, methods={len(METHODS)}, total_settings={total_settings}")
    print(f"[RUN] N_REPS={N_REPS}, N_WORKERS={N_WORKERS}, RESUME={RESUME_MISSING}")

    all_jobs: List[Dict[str, Any]] = []
    results_csv_by_setting = {}

    for method_cfg in METHODS:
        c_list_eff = c_values_for(method_cfg)

        for (p, k, rho, L, step) in grid_base:
            for c in c_list_eff:
                sname = setting_dir_name(method_cfg["name"], p, k, rho, c, L, step)
                sdir = os.path.join(OUT_ROOT, sname)
                results_csv = os.path.join(sdir, "results.csv")
                results_csv_by_setting[sdir] = results_csv

                prepare_setting_folder(
                    sdir,
                    {
                        "method": method_cfg,
                        "grid_point": {"p": p, "k": k, "rho": rho, "c": c, "L": L, "step_size": step},
                        "run_params": {
                            "R": R, "M": M, "burnin": BURNIN, "thin": THIN,
                            "n_samples": N_SAMPLES, "n_chains": N_CHAINS,
                            "rho_barrier": RHO_BARRIER, "barrier_power": BARRIER_POWER,
                            "beta_phiK": BETA_PHI_K, "lam_penalty": LAM_PENALTY,
                            "barrier_sigma": BARRIER_SIGMA, "target_accept": TARGET_ACCEPT,
                            "hit_time": HIT_TIME,
                            "gamma_key": GAMMA_KEY, "directional_dist": DDIST,
                            "clamp_eps": CLAMP_EPS, "ess_pick": ESS_METRIC
                        },
                        "n_reps": N_REPS,
                    },
                )

                done = load_done_reps(results_csv) if RESUME_MISSING else set()

                for rep in rep_list:
                    if rep in done:
                        continue
                    all_jobs.append({
                        "setting_dir": sdir,
                        "method": method_cfg["name"],
                        "use_reflection": method_cfg["use_reflection"],
                        "use_soft_barrier": method_cfg["use_soft_barrier"],
                        "do_hungarian": method_cfg["do_hungarian"],
                        "rep": rep,
                        "p": p, "k": k, "rho": rho, "c": c, "L": L, "step_size": step,
                    })

    print(f"[RUN] total_jobs_to_run={len(all_jobs)}")
    if len(all_jobs) == 0:
        print("[RUN] nothing to do (all reps done?)")
        return

    fieldnames = ["rep", "status",  "elapsed_sec", "ess_mean_selected", "rhat_max_selected",
                  "error_type", "error_msg"]
    buffers = defaultdict(list)
    flush_every = 10
    counts = defaultdict(lambda: {"ok": 0, "fail": 0})

    t0 = time.time()
    try:
        with ProcessPoolExecutor(max_workers=N_WORKERS) as ex:
            # Remember which job each future belongs to so a crashed future
            # can still be attributed to its setting and rep.
            job_by_future = {ex.submit(run_one_job, job): job for job in all_jobs}

            total_jobs = len(job_by_future)
            done_jobs = 0
            ok_jobs = 0
            fail_jobs = 0

            last_print_t = time.time()
            print_interval_sec = 600

            for fut in as_completed(job_by_future):
                try:
                    row = fut.result()
                except Exception as exc:
                    # The worker died or raised outside run_one_job's own
                    # handler; record the failure and keep the sweep going.
                    failed_job = job_by_future[fut]
                    row = {
                        "setting_dir": failed_job["setting_dir"],
                        "rep": failed_job["rep"],
                        "status": "ERROR",
                        "elapsed_sec": "",
                        "ess_mean_selected": "",
                        "rhat_max_selected": "",
                        "error_type": type(exc).__name__,
                        "error_msg": str(exc),
                    }
                sdir = row.pop("setting_dir")

                buffers[sdir].append(row)

                status = str(row.get("status", "")).strip().lower()
                done_jobs += 1
                if status == "ok":
                    ok_jobs += 1
                    counts[sdir]["ok"] += 1
                else:
                    fail_jobs += 1
                    counts[sdir]["fail"] += 1
                if len(buffers[sdir]) >= flush_every:
                    append_rows(results_csv_by_setting[sdir], fieldnames, buffers[sdir])
                    buffers[sdir].clear()

                now = time.time()
                if (now - last_print_t) >= print_interval_sec or done_jobs == total_jobs:
                    remaining = total_jobs - done_jobs
                    elapsed = now - t0
                    print(
                        f"[PROGRESS] done={done_jobs}/{total_jobs} remaining={remaining} "
                        f"(ok={ok_jobs}, fail={fail_jobs}) elapsed_min={elapsed/60:.1f}",
                        flush=True
                    )
                    # Restart the interval; resetting to 0 would make every
                    # later job print a progress line.
                    last_print_t = now
    finally:
        # Flush whatever is still buffered, even if the loop above was
        # interrupted, so finished reps are not lost for the next resume.
        for sdir, rows in buffers.items():
            if rows:
                append_rows(results_csv_by_setting[sdir], fieldnames, rows)
                rows.clear()

    summary_csv = os.path.join(OUT_ROOT, "summary_by_setting.csv")
    summary_rows = []
    for sdir, d in counts.items():
        ok = d.get("ok", 0)
        fail = d.get("fail", 0)
        tot = ok + fail
        summary_rows.append({
            "setting": os.path.basename(sdir),
            "ok": ok,
            "fail": fail,
            "total": tot,
            "fail_rate": (fail / tot) if tot > 0 else "",
        })
    write_setting_summary(summary_csv, summary_rows)

    print(f"\n[SUMMARY] saved: {summary_csv}")
    for r in summary_rows:
        print(f"- {r['setting']}: ok={r['ok']} fail={r['fail']} fail_rate={r['fail_rate']}")
    print(f"\n[ALL DONE] elapsed_sec={time.time() - t0:.2f}")
    
   


# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
