#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import csv
import json
import time
import pickle
import logging
import traceback
from itertools import product
from typing import Dict, Any, List
from collections import defaultdict

import numpy as np
import torch
from concurrent.futures import ProcessPoolExecutor, as_completed

from src.polar_chmc import ConstrainedPolarHMC, run_chains
from src.dataset import simulate_latent_subspace, make_Phi_diag, logprob_ppca_macg
from src.utils import report_one_rep_chains_to_dfs
from src.misc import setup_rep_logger, StreamToLogger, pushover_notify

# Root output folder; each (method, grid point) gets a sub-folder named by
# setting_dir_name() holding results.csv, config.json, logs/ and pkls/.
OUT_ROOT = os.path.expanduser("~/hmcstiefel/outputs_ppca")

# Experiment grid. main() runs the Cartesian product of p, u, N, eps, L and
# step_size; c is swept per method (methods run at a single c use c_list[0]).
GRID = {
    "p_list": [10, 20, 30, 40],      # p: rows of the p x u bases (passed as p= to the simulator)
    "u_list": [4],                   # u: number of basis columns
    "N_list": [30],                  # N: passed as n= to simulate_latent_subspace
    "eps_list": [1e-1, 5e-2, 1e-2],  # eps: passed to make_Phi_diag
    "c_list": [0.01],                # c: passed to run_chains
    "L_list": [50],                  # L: num_steps of ConstrainedPolarHMC
    "step_size_list": [0.1],         # step_size of ConstrainedPolarHMC
}
# Reps executed: range(REP_START, REP_START + N_REPS) (here 0..29), or
# range(REP_START, REP_END) when REP_END is not None.
REP_START = 0
N_REPS = 30 - REP_START
REP_END = None
N_WORKERS = 90         # ProcessPoolExecutor max_workers
RESUME_MISSING = True  # skip reps already recorded with status "OK" in results.csv

# run_chains / sampler settings. R and M are forwarded to run_chains as-is
# (their meaning is defined there — TODO document).
R = 15.0
M = 200
BURNIN = 2000
THIN = 1
N_SAMPLES = 2000           # n_kept for run_chains
N_CHAINS = 4
BASE_SEEDS = [1, 2, 3, 4]  # presumably one per chain; run_one_job adds 1000 * rep

# barrier params (forwarded to ConstrainedPolarHMC; RHO_BARRIER and
# BARRIER_POWER are also encoded in the setting folder name)
RHO_BARRIER = 0.2
BARRIER_POWER = 4
BETA_PHI_K = 10.0
LAM_PENALTY = 2.0
BARRIER_SIGMA = 10.0
TARGET_ACCEPT = 0.80
HIT_TIME = True

# PPCA knobs
SIGMA_Y = 1.0       # sigma_y for simulate_latent_subspace
BETA = 1.0 / 100.0  # beta argument of logprob_ppca_macg

# Diagonal of the true R matrix; padded with its last value (or truncated) to length u.
R_TRUE_BASE = [5.0, 0.5]

# report params (forwarded to report_one_rep_chains_to_dfs as gamma_key,
# clamp_eps and ess_pick)
GAMMA_KEY = "Q"
CLAMP_EPS = 1e-12
ESS_METRIC = 'mean'

# methods: "name" goes into the folder name; the flags are forwarded to the
# sampler (use_reflection, use_soft_barrier) and to run_chains (do_hungarian).
METHODS = [
    dict(name="softbarrier", use_reflection=True,  use_soft_barrier=True,  do_hungarian=True),
    dict(name="naive",       use_reflection=False, use_soft_barrier=False, do_hungarian=True),
]


def safe_mkdir(path: str):
    """Create directory *path* (and any missing parents) if it does not exist.

    An empty *path* — e.g. ``os.path.dirname("results.csv")`` — denotes the
    current working directory, which always exists. It is therefore a no-op
    instead of the ``FileNotFoundError`` that ``os.makedirs("")`` raises.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def ftag(x: float) -> str:
    """Encode a number as a filesystem-friendly tag.

    '.' becomes 'p' and '-' becomes 'm'; zero maps to "0". Magnitudes below
    1e-2 or at/above 1e3 use scientific notation with one decimal, so only
    two significant digits survive, and the '+' of a positive exponent is
    dropped. All other values keep up to 6 significant digits.
    """
    value = float(x)
    if value == 0:
        return "0"

    magnitude = abs(value)
    if magnitude < 1e-2 or magnitude >= 1e3:
        text = format(value, ".1e").replace("+", "")
    else:
        text = format(value, ".6g")
    return text.replace(".", "p").replace("-", "m")

def setting_dir_name(method: str, p: int, u: int, N: int, eps: float, c: float, L: int, step: float) -> str:
    """Build the output folder name for one (method, grid point) setting.

    The name encodes the method, the grid point, and the module-level
    BARRIER_POWER / RHO_BARRIER. Float fields go through ``ftag``, so the
    name contains neither '.' nor '-'.
    """
    eps_tag = ftag(eps)
    c_tag = ftag(c)
    step_tag = ftag(step)
    rho_tag = ftag(RHO_BARRIER)
    return (
        f"{method}_p{p}_u{u}_N{N}_eps{eps_tag}_c{c_tag}"
        f"_L{L}_step{step_tag}_bp{BARRIER_POWER}_rhob{rho_tag}"
    )


def load_done_reps(results_csv: str) -> set:
    """Return the set of rep indices already recorded with status "OK".

    Parameters
    ----------
    results_csv : path to a per-setting results CSV with at least the
        columns ``rep`` and ``status``.

    Returns
    -------
    Set of int reps whose ``status`` equals "OK" (case-insensitive,
    surrounding whitespace ignored). A missing file gives an empty set.
    Rows whose ``rep`` is missing or not an integer (e.g. a truncated
    line from an interrupted write) are skipped, so they are re-run on
    resume rather than blocking it.
    """
    done = set()
    if not os.path.exists(results_csv):
        return done
    with open(results_csv, "r", newline="") as f:
        for row in csv.DictReader(f):
            # Short rows give None for missing fields; str(None) is not "OK".
            if str(row.get("status", "")).strip().upper() != "OK":
                continue
            try:
                done.add(int(row["rep"]))
            except (KeyError, TypeError, ValueError):
                continue
    return done

def append_rows(results_csv: str, fieldnames: List[str], rows: List[Dict[str, Any]]):
    """Append *rows* to the CSV at *results_csv*, writing a header first if needed.

    Parameters
    ----------
    results_csv : destination path. Its parent directory is created if missing.
    fieldnames : column order. A row key not in this list makes
        ``csv.DictWriter`` raise ``ValueError``. Keys missing from a row are
        written as "".
    rows : dicts to write. They are NOT modified; the original implementation
        padded them in place.

    The header is written when the file does not exist or exists but is empty,
    e.g. after a crash right after creation. Otherwise the first data row would
    later be read back as the header by ``csv.DictReader``.
    """
    parent = os.path.dirname(results_csv)
    if parent:  # bare filename -> current directory, nothing to create
        os.makedirs(parent, exist_ok=True)
    needs_header = (not os.path.exists(results_csv)) or os.path.getsize(results_csv) == 0
    with open(results_csv, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        if needs_header:
            writer.writeheader()
        writer.writerows(rows)

def pick_cols(df, must_have: List[str]):
    """Keep only the columns of *df* whose name contains every token of *must_have*.

    Matching is a case-insensitive substring test on ``str(column)``.
    ``None``, or a frame with no columns, is returned unchanged. When no
    column matches, a zero-column slice of *df* (same index) is returned.
    """
    no_columns = df is None or getattr(df, "shape", (0, 0))[1] == 0
    if no_columns:
        return df

    def _matches(column) -> bool:
        name = str(column).lower()
        return all(tok.lower() in name for tok in must_have)

    selected = [column for column in df.columns if _matches(column)]
    if selected:
        return df.loc[:, selected]
    return df.iloc[:, 0:0]

def df_firstrow_to_dict(df):
    """Flatten the first row of *df* into ``{str(column): value}``.

    Scalar cells are cast to float. NaN cells and non-scalar cells become
    ``np.nan``. Cells that fail the float cast (e.g. free text) are kept as
    ``str``. ``None`` gives ``{}``; a frame with no rows gives NaN for every
    column.
    """
    if df is None:
        return {}
    if len(df.index) == 0:
        return dict.fromkeys(map(str, df.columns), np.nan)

    first = df.iloc[0]
    result = {}
    for column in df.columns:
        cell = first[column]
        key = str(column)
        try:
            # cell == cell is False only for NaN-like scalars.
            numeric = np.isscalar(cell) and (cell == cell)
            result[key] = float(cell) if numeric else np.nan
        except Exception:
            result[key] = str(cell)
    return result

def write_setting_summary(summary_csv: str, rows: List[Dict[str, Any]]):
    """Append per-setting ok/fail counts to *summary_csv*.

    Columns: setting, ok, fail, total, fail_rate. Delegates to
    ``append_rows``, which writes the header when the file is new.
    """
    append_rows(
        summary_csv,
        ["setting", "ok", "fail", "total", "fail_rate"],
        rows,
    )

def proj_dist_ct_torch(
    Q_ctpr: torch.Tensor,
    Gamma0_pr: torch.Tensor,
    eps: float = 1e-12,
) -> torch.Tensor:
    """Projection distance from every sampled subspace to a reference subspace.

    For orthonormal bases Q (p x k) and G (p x k0),

        ||Q Q^T - G G^T||_F^2 = k + k0 - 2 * ||Q^T G||_F^2,

    so the distance is computed without forming the p x p projectors. The
    previous version used ``2 * u``, which is only correct when k0 == k. The
    result is identical in that case.

    Parameters
    ----------
    Q_ctpr : tensor of shape (C, T, p, u), one p x u basis per (chain, draw).
    Gamma0_pr : tensor of shape (p, u0), the reference basis. It is moved to
        Q's device/dtype; u0 need not equal u.
    eps : added under the square root after clamping at 0. It keeps the sqrt
        away from 0 and shifts an exact-zero distance to sqrt(eps).

    Returns
    -------
    Tensor of shape (C, T) with the distances.
    """
    n_chains, n_draws, p, u = Q_ctpr.shape
    G0 = Gamma0_pr.to(device=Q_ctpr.device, dtype=Q_ctpr.dtype)

    # Optional numerical safeguard: re-orthonormalize both bases with a reduced
    # QR. This keeps the column span, so exact orthonormal inputs give the same
    # distance. Can be removed if inputs are known to be orthonormal.
    Q_flat = torch.linalg.qr(Q_ctpr.reshape(n_chains * n_draws, p, u), mode="reduced")[0]
    Q = Q_flat.reshape(n_chains, n_draws, p, -1)
    G0 = torch.linalg.qr(G0, mode="reduced")[0]
    k = Q.shape[-1]
    k0 = G0.shape[-1]

    # cross[c, t] = Q[c, t]^T @ G0, shape (C, T, k, k0)
    cross = torch.einsum("ctpa,pb->ctab", Q, G0)
    fro2 = (cross * cross).sum(dim=(-2, -1))  # (C, T)

    dist2 = float(k + k0) - 2.0 * fro2
    return torch.sqrt(torch.clamp(dist2, min=0.0) + eps)


def run_one_job(job: Dict[str, Any]) -> Dict[str, Any]:
    """Run all chains for one (setting, rep) pair and persist the artifacts.

    Intended to run inside a pool worker. Steps:

    1. Simulate a dataset seeded by ``rep``.
    2. Run ``N_CHAINS`` chains of ``ConstrainedPolarHMC`` via ``run_chains``.
    3. Summarize ESS / R-hat, plus the projection distance from the posterior
       ``Q`` samples to the true ``Gamma0``.
    4. Write these files under ``setting_dir``:
       * ``logs/rep{rep}.log`` - per-rep log
       * ``pkls/rep{rep}.samples.pt`` - raw samples (``torch.save``)
       * ``pkls/rep{rep}.pkl`` - budget, meta and report frames
       * ``logs/rep{rep}.traceback.txt`` - only on failure

    Parameters
    ----------
    job : dict with keys ``setting_dir``, ``rep``, ``method``, ``p``, ``u``,
        ``N``, ``eps``, ``c``, ``L``, ``step_size``, ``use_reflection``,
        ``use_soft_barrier`` and ``do_hungarian``.

    Returns
    -------
    A results-CSV row plus ``setting_dir``. ``status`` is "OK" or "fail".
    On failure the metric fields are "" and ``error_type`` / ``error_msg``
    describe the exception.

    Errors in parsing the job, simulating the data, sampling or reporting are
    caught and turned into a "fail" row rather than propagated. Previously the
    simulation ran outside the ``try``, so a failure there aborted the whole
    pool run. ``elapsed_sec`` therefore includes the data simulation time.
    """
    rep = int(job["rep"])
    setting_dir = job["setting_dir"]

    logs_dir = os.path.join(setting_dir, "logs")
    pkls_dir = os.path.join(setting_dir, "pkls")
    safe_mkdir(logs_dir)
    safe_mkdir(pkls_dir)

    log_path = os.path.join(logs_dir, f"rep{rep}.log")
    pkl_path = os.path.join(pkls_dir, f"rep{rep}.pkl")
    pt_path = os.path.join(pkls_dir, f"rep{rep}.samples.pt")
    tb_path = os.path.join(logs_dir, f"rep{rep}.traceback.txt")

    logger = setup_rep_logger(log_path, rep)

    t0 = time.time()
    try:
        method = job["method"]
        p, u = int(job["p"]), int(job["u"])
        N = int(job["N"])
        eps, c = float(job["eps"]), float(job["c"])
        L, step_size = int(job["L"]), float(job["step_size"])

        # Chain seeds are offset by 1000 per rep.
        seeds = [s + 1000 * rep for s in BASE_SEEDS]

        # Diagonal of the true R: R_TRUE_BASE padded with its last entry
        # (or truncated) to length u.
        rvec = (R_TRUE_BASE + [R_TRUE_BASE[-1]] * max(0, u - len(R_TRUE_BASE)))[:u]
        R_true = torch.diag(torch.tensor(rvec, dtype=torch.float64))

        # The simulator inputs depend only on (p, u, N, rep). Assuming the
        # simulator is deterministic in its seed, every method/eps/c setting
        # therefore sees the same data for a given rep.
        _, A, Gamma0, _ = simulate_latent_subspace(
            p=p, u=u, n=N,
            sigma_y=SIGMA_Y,
            R=R_true,
            seed=123 + rep
        )
        Phi = make_Phi_diag(p=p, m=u, eps=eps)

        def make_sampler() -> ConstrainedPolarHMC:
            # Factory handed to run_chains; A and Phi are cast to the sampler's
            # dtype/device at each log-prob call.
            return ConstrainedPolarHMC(
                logprob_fn=lambda q: logprob_ppca_macg(
                    q,
                    A.to(dtype=q.dtype, device=q.device),
                    Phi.to(dtype=q.dtype, device=q.device),
                    p, u, BETA
                ),
                step_size=step_size,
                num_steps=L,
                device="cpu",
                dtype=torch.float64,
                use_reflection=job["use_reflection"],
                use_soft_barrier=job["use_soft_barrier"],
                beta_phiK=BETA_PHI_K,
                lam_penalty=LAM_PENALTY,
                rho=RHO_BARRIER,
                barrier_sigma=BARRIER_SIGMA,
                barrier_power=BARRIER_POWER,
                hit_time=HIT_TIME,
                target_accept=TARGET_ACCEPT,
            )

        logger.info("=" * 80)
        logger.info(f"setting_dir={setting_dir}")
        logger.info(f"method={method} rep={rep}")
        logger.info(f"p={p} u={u} N={N} eps={eps} c={c} L={L} step={step_size}")
        logger.info(f"use_reflection={job['use_reflection']} use_soft_barrier={job['use_soft_barrier']}")
        logger.info(f"burnin={BURNIN} thin={THIN} n_kept={N_SAMPLES} n_chains={N_CHAINS}")
        logger.info(f"R={R} M={M} rho_barrier={RHO_BARRIER} barrier_power={BARRIER_POWER}")
        logger.info(f"seeds={seeds}")
        logger.info("=" * 80)

        samples, budget, meta = run_chains(
            make_sampler,
            R=R, r=p, u=u, c=c, M=M,
            n_kept=N_SAMPLES,
            burnin=BURNIN,
            thin=THIN,
            n_chains=N_CHAINS,
            seeds=seeds,
            verbose=False,
            kept_only=False,
            do_hungarian=job["do_hungarian"],
        )

        df_meta_b_chain, df_meta_b_sum, df_b_rhat, df_ess_b_eff = report_one_rep_chains_to_dfs(
            samples=samples,
            meta=meta,
            budget=budget,
            gamma_key=GAMMA_KEY,
            clamp_eps=CLAMP_EPS,
            ess_pick=ESS_METRIC,
            A_ref=None
        )

        # Projection distance of each posterior draw to the true subspace.
        Q_post = samples["Q"]  # expected (C, T, p, u) tensor by proj_dist_ct_torch
        Gamma0_t = torch.tensor(Gamma0, dtype=torch.float64) if not torch.is_tensor(Gamma0) else Gamma0.to(dtype=torch.float64)
        d_ct = proj_dist_ct_torch(Q_post, Gamma0_t)  # (C, T)

        d_mean = float(torch.nanmean(d_ct).item())
        d_median = float(torch.nanmedian(d_ct).item())
        d_chain_mean = torch.nanmean(d_ct, dim=1).detach().cpu().numpy().tolist()  # one per chain

        meta["proj_dist_ct"] = d_ct.detach().cpu().numpy()
        meta["proj_dist_mean"] = d_mean
        meta["proj_dist_median"] = d_median
        meta["proj_dist_chain_mean"] = d_chain_mean

        df_ess_mean = pick_cols(df_ess_b_eff, ["ess", "mean"])
        df_rhat_max = pick_cols(df_b_rhat, ["rhat", "max"])
        ess_dict = df_firstrow_to_dict(df_ess_mean)
        rhat_dict = df_firstrow_to_dict(df_rhat_max)

        elapsed = time.time() - t0
        logger.info(f"[DONE] elapsed_sec={elapsed:.3f}")

        torch.save(samples, pt_path)

        with open(pkl_path, "wb") as f:
            pickle.dump(
                {
                    "rep": rep,
                    "elapsed_sec": elapsed,
                    "budget": budget,
                    "meta": meta,
                    "df_meta_chain": df_meta_b_chain,
                    "df_meta_sum": df_meta_b_sum,
                    "df_rhat": df_b_rhat,
                    "df_ess": df_ess_b_eff,
                    "pt_samples_path": pt_path,
                },
                f,
            )

        return {
            "setting_dir": setting_dir,
            "rep": rep,
            "status": "OK",
            "elapsed_sec": elapsed,
            "ess_mean_selected": json.dumps(ess_dict, ensure_ascii=False),
            "rhat_max_selected": json.dumps(rhat_dict, ensure_ascii=False),
            "proj_dist_mean": d_mean,
            "proj_dist_median": d_median,
            "proj_dist_chain_mean": json.dumps(d_chain_mean, ensure_ascii=False),
            "error_type": "",
            "error_msg": "",
        }

    except Exception as e:
        elapsed = time.time() - t0
        etype = type(e).__name__
        emsg = str(e)

        # Best effort: the traceback also goes to the rep log just below.
        try:
            with open(tb_path, "w") as f:
                f.write(traceback.format_exc())
        except Exception:
            pass

        try:
            logger.exception("[FAILED]")
        except Exception:
            pass

        return {
            "setting_dir": setting_dir,
            "rep": rep,
            "status": "fail",
            "elapsed_sec": elapsed,
            "ess_mean_selected": "",
            "rhat_max_selected": "",
            "proj_dist_mean": "",
            "proj_dist_median": "",
            "proj_dist_chain_mean": "",
            "error_type": etype,
            "error_msg": emsg,
        }


# ============================================================
# main
# ============================================================
def prepare_setting_folder(setting_dir: str, config_obj: Dict[str, Any]):
    """Create *setting_dir* with ``logs/`` and ``pkls/`` and write ``config.json`` once.

    ``config.json`` is only written if it does not exist yet. A resumed run
    keeps the config from the first run of this setting, so constants changed
    since then are not reflected in it.

    The file is opened as UTF-8 explicitly. ``ensure_ascii=False`` may emit
    non-ASCII characters, which fail to encode under a non-UTF-8 locale default.
    """
    # makedirs on the subfolders also creates setting_dir itself.
    for sub in ("logs", "pkls"):
        os.makedirs(os.path.join(setting_dir, sub), exist_ok=True)
    cfg_path = os.path.join(setting_dir, "config.json")
    if not os.path.exists(cfg_path):
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(config_obj, f, ensure_ascii=False, indent=2)

def _uses_single_c(method_cfg: Dict[str, Any]) -> bool:
    """Return True for methods that are run at a single ``c`` value only.

    These are the "naive" method and any method with neither reflection nor a
    soft barrier. Presumably ``c`` has no effect without a barrier.
    """
    return method_cfg["name"] == "naive" or (
        not method_cfg["use_reflection"] and not method_cfg["use_soft_barrier"]
    )


def _effective_c_list(method_cfg: Dict[str, Any]) -> List[float]:
    """Return the ``c`` values to sweep for *method_cfg*."""
    if _uses_single_c(method_cfg):
        return [GRID["c_list"][0]]
    return list(GRID["c_list"])


def _init_worker():
    """Pool-worker initializer: limit torch to one intra-op thread per process."""
    # main() also sets this in the parent. Workers started with the
    # "spawn"/"forkserver" start methods are fresh interpreters and are not
    # guaranteed to inherit it, so set it explicitly.
    torch.set_num_threads(1)


def _crashed_job_row(job: Dict[str, Any], exc: BaseException) -> Dict[str, Any]:
    """Build a "fail" result row for a job whose future raised.

    This covers errors ``run_one_job`` does not catch itself, e.g. a worker
    process dying (BrokenProcessPool). The rep is recorded as failed instead
    of aborting the whole run.
    """
    return {
        "setting_dir": job["setting_dir"],
        "rep": job["rep"],
        "status": "fail",
        "elapsed_sec": "",
        "ess_mean_selected": "",
        "rhat_max_selected": "",
        "proj_dist_mean": "",
        "proj_dist_median": "",
        "proj_dist_chain_mean": "",
        "error_type": type(exc).__name__,
        "error_msg": str(exc),
    }


def main():
    """Build the job grid, run all missing reps in a process pool, and record results.

    For every method in METHODS and every grid point, this:

    * prepares the setting folder;
    * skips reps already recorded as OK (when RESUME_MISSING);
    * runs the remaining reps with ``run_one_job`` on N_WORKERS processes.

    Result rows are appended to each setting's ``results.csv`` in batches of
    10. Leftover rows are flushed even if the loop is interrupted. This
    invocation's per-setting ok/fail counts are appended to
    ``summary_by_setting.csv``.
    """
    safe_mkdir(OUT_ROOT)
    torch.set_num_threads(1)

    if REP_END is None:
        rep_list = list(range(REP_START, REP_START + N_REPS))
    else:
        rep_list = list(range(REP_START, REP_END))

    grid_base = list(product(
        GRID["p_list"],
        GRID["u_list"],
        GRID["N_list"],
        GRID["eps_list"],
        GRID["L_list"],
        GRID["step_size_list"],
    ))

    total_settings = sum(len(grid_base) * len(_effective_c_list(m)) for m in METHODS)

    print(f"[RUN] grid_points(base)={len(grid_base)}, methods={len(METHODS)}, total_settings={total_settings}")
    print(f"[RUN] N_REPS={N_REPS}, N_WORKERS={N_WORKERS}, RESUME={RESUME_MISSING}")

    all_jobs: List[Dict[str, Any]] = []
    results_csv_by_setting: Dict[str, str] = {}

    for method_cfg in METHODS:
        c_list_eff = _effective_c_list(method_cfg)
        for (p, u, N, eps, L, step) in grid_base:
            for c in c_list_eff:
                sname = setting_dir_name(method_cfg["name"], p, u, N, eps, c, L, step)
                sdir = os.path.join(OUT_ROOT, sname)
                results_csv = os.path.join(sdir, "results.csv")
                results_csv_by_setting[sdir] = results_csv

                prepare_setting_folder(
                    sdir,
                    {
                        "method": method_cfg,
                        "grid_point": {"p": p, "u": u, "N": N, "eps": eps, "c": c, "L": L, "step_size": step},
                        "ppca_params": {
                            "sigma_y": SIGMA_Y,
                            "beta": BETA,
                            "R_true_base": R_TRUE_BASE,
                        },
                        "run_params": {
                            "R": R, "M": M, "burnin": BURNIN, "thin": THIN,
                            "n_samples": N_SAMPLES, "n_chains": N_CHAINS,
                            "rho_barrier": RHO_BARRIER, "barrier_power": BARRIER_POWER,
                            "beta_phiK": BETA_PHI_K, "lam_penalty": LAM_PENALTY,
                            "barrier_sigma": BARRIER_SIGMA, "target_accept": TARGET_ACCEPT,
                            "hit_time": HIT_TIME,
                            "gamma_key": GAMMA_KEY,
                            "clamp_eps": CLAMP_EPS, "ess_pick": ESS_METRIC
                        },
                        "n_reps": N_REPS,
                    },
                )

                done = load_done_reps(results_csv) if RESUME_MISSING else set()

                for rep in rep_list:
                    if rep in done:
                        continue
                    all_jobs.append({
                        "setting_dir": sdir,
                        "method": method_cfg["name"],
                        "use_reflection": method_cfg["use_reflection"],
                        "use_soft_barrier": method_cfg["use_soft_barrier"],
                        "do_hungarian": method_cfg["do_hungarian"],
                        "rep": rep,
                        "p": p, "u": u, "N": N, "eps": eps, "c": c, "L": L, "step_size": step,
                    })

    print(f"[RUN] total_jobs_to_run={len(all_jobs)}")
    if not all_jobs:
        print("[RUN] nothing to do (all reps done?)")
        return

    fieldnames = [
        "rep", "status", "elapsed_sec",
        "ess_mean_selected", "rhat_max_selected",
        "proj_dist_mean", "proj_dist_median", "proj_dist_chain_mean",
        "error_type", "error_msg"
    ]

    buffers: Dict[str, List[Dict[str, Any]]] = defaultdict(list)  # setting_dir -> pending rows
    flush_every = 10
    counts = defaultdict(lambda: {"ok": 0, "fail": 0})  # setting_dir -> this run's tallies

    total_jobs = len(all_jobs)
    done_jobs = ok_jobs = fail_jobs = 0
    print_interval_sec = 600

    t0 = time.time()
    last_print_t = t0

    try:
        with ProcessPoolExecutor(max_workers=N_WORKERS, initializer=_init_worker) as ex:
            fut_to_job = {ex.submit(run_one_job, job): job for job in all_jobs}

            for fut in as_completed(fut_to_job):
                # A worker crash must not abort the whole sweep: record it as a
                # failed rep (retried on resume) and keep consuming results.
                try:
                    row = fut.result()
                except Exception as exc:
                    row = _crashed_job_row(fut_to_job[fut], exc)
                sdir = row.pop("setting_dir")

                buffers[sdir].append(row)

                status = str(row.get("status", "")).strip().lower()
                done_jobs += 1
                if status == "ok":
                    ok_jobs += 1
                    counts[sdir]["ok"] += 1
                else:
                    fail_jobs += 1
                    counts[sdir]["fail"] += 1

                if len(buffers[sdir]) >= flush_every:
                    append_rows(results_csv_by_setting[sdir], fieldnames, buffers[sdir])
                    buffers[sdir].clear()

                now = time.time()
                if (now - last_print_t) >= print_interval_sec or done_jobs == total_jobs:
                    remaining = total_jobs - done_jobs
                    elapsed = now - t0
                    print(
                        f"[PROGRESS] done={done_jobs}/{total_jobs} remaining={remaining} "
                        f"(ok={ok_jobs}, fail={fail_jobs}) elapsed_min={elapsed/60:.1f}",
                        flush=True
                    )
                    # Previously reset to 0, which made every later completion print.
                    last_print_t = now
    finally:
        # Persist buffered rows even if the loop was interrupted, so a resumed
        # run does not redo reps that already finished.
        for sdir, rows in buffers.items():
            if rows:
                append_rows(results_csv_by_setting[sdir], fieldnames, rows)
                rows.clear()

    summary_csv = os.path.join(OUT_ROOT, "summary_by_setting.csv")
    summary_rows = []
    for sdir, d in counts.items():
        ok = d.get("ok", 0)
        fail = d.get("fail", 0)
        tot = ok + fail
        summary_rows.append({
            "setting": os.path.basename(sdir),
            "ok": ok,
            "fail": fail,
            "total": tot,
            "fail_rate": (fail / tot) if tot > 0 else "",
        })
    write_setting_summary(summary_csv, summary_rows)

    print(f"\n[SUMMARY] saved: {summary_csv}")
    for r in summary_rows:
        print(f"- {r['setting']}: ok={r['ok']} fail={r['fail']} fail_rate={r['fail_rate']}")
    print(f"\n[ALL DONE] elapsed_sec={time.time() - t0:.2f}")

    
if __name__ == "__main__":
    # Script entry point. The guard also keeps pool workers that re-import
    # this module (spawn/forkserver start methods) from re-running main().
    main()
