import numpy as np
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.optimize import linprog
import json, time
from pathlib import Path

from utils import *

def _json_float(x):
    try:
        return float(x)
    except Exception:
        return None


class JobSchedulingCMDP:
    """Toy job-scheduling CMDP whose loss function switches between two variants.

    States are ``0 .. S-1`` and every episode starts in ``S - 1``.  Action
    ``1`` moves the state down by 2 (prob. 0.8), down by 1 (prob. 0.1) or
    leaves it unchanged (prob. 0.1), clipped at 0; any other action leaves the
    state unchanged.  The per-step loss depends only on ``(a, h)`` and on the
    active ``loss_variant`` (0 -> ``loss_fn1``, otherwise ``loss_fn2``).  The
    per-step cost is ``1 - (s - s_next) / 2``.
    """

    def __init__(self, S=10, A=2, H=10, seed=0):
        self.rng = np.random.default_rng(seed)
        self.S, self.A, self.H = S, A, H
        self.s_initial = None

        self.loss_variant = 0        # 0 -> loss_fn1, any other value -> loss_fn2
        self.loss_schedule = None    # optional per-episode list of loss variants
        self.reset()

    def reset(self):
        """Put the start state at the top of the state range (``S - 1``)."""
        self.s_initial = self.S - 1

    def step(self, s, a, h):
        """Sample one transition from state ``s`` under action ``a`` at step ``h``.

        Returns:
            ``(s_next, loss, cost)``.
        """
        s_next = s
        if a == 1:
            draw = self.rng.random()
            if draw < 0.8:
                s_next = max(s - 2, 0)
            elif draw < 0.9:
                s_next = max(s - 1, 0)
            # otherwise (prob. 0.1) the job does not progress
        return s_next, self.get_loss(a, h), self.get_cost(s, s_next)

    def loss_fn1(self, a, h):
        """Loss of variant 0: idling costs 1.0; acting is expensive for steps 3..6."""
        if a == 0:
            return 1.0
        return 0.55 * a if 3 <= h <= 6 else 0.1 * a

    def loss_fn2(self, a, h):
        """Loss of variant 1: idling costs 1.0; acting is expensive for steps 4..6."""
        if a == 0:
            return 1.0
        return 0.6 * a if 4 <= h <= 6 else 0.2 * a

    def get_loss(self, a, h):
        """Evaluate the currently active loss variant at ``(a, h)``."""
        active = self.loss_fn1 if self.loss_variant == 0 else self.loss_fn2
        return active(a, h)

    def get_cost(self, s, s_next):
        """Per-step constraint cost: 1 when the state does not drop, less the further it drops."""
        progress = s - s_next
        return 1 - progress / 2.0

    def prepare_adversarial_schedule(self, K: int):
        """Pre-draw the loss variant for each of ``K`` episodes.

        Episode ``k`` gets variant 1 with probability ``0.1 + 0.9 * k / (K - 1)``
        (a linear ramp from 0.1 to 1.0); for ``K <= 1`` the schedule is ``[0]``.
        """
        if K <= 1:
            self.loss_schedule = [0]
            return
        self.loss_schedule = [
            int(self.rng.random() < 0.1 + 0.9 * k / (K - 1)) for k in range(K)
        ]

    def select_loss_variant_for_episode(self, k: int, K: int, use_schedule: bool = True):
        """Set ``loss_variant`` for episode ``k``.

        Uses ``loss_schedule[k]`` when allowed and available; otherwise draws
        variant 1 with probability ``k / (K - 1)`` (always 0 when ``K <= 1``).
        NOTE(review): this fallback ramp differs from the 0.1 + 0.9 * k / (K - 1)
        ramp used by ``prepare_adversarial_schedule`` -- confirm this is intended.
        """
        schedule = self.loss_schedule
        if use_schedule and schedule is not None and k < len(schedule):
            self.loss_variant = schedule[k]
            return
        if K <= 1:
            self.loss_variant = 0
            return
        self.loss_variant = 1 if self.rng.random() < k / (K - 1) else 0


def trans_probs(env, s, a, h):
    """Return the transition distribution ``P(. | s, a)`` as ``{next_state: prob}``.

    Action ``0`` keeps the state.  Any other action moves down by 2 with
    probability 0.8, down by 1 with probability 0.1 and stays with probability
    0.1; moves are clipped at state 0, and coinciding targets have their
    probabilities merged.

    ``env`` and ``h`` are accepted for interface compatibility only: the
    dynamics are stationary and read no environment attributes.

    NOTE(review): ``JobSchedulingCMDP.step`` moves only when ``a == 1``, while
    this model moves for every ``a != 0``.  The two agree for ``A == 2``;
    confirm before using more actions.
    """
    if a == 0:
        return {s: 1.0}
    probs = {}
    for s_next, p in ((max(s - 2, 0), 0.8), (max(s - 1, 0), 0.1), (s, 0.1)):
        probs[s_next] = probs.get(s_next, 0.0) + p
    return probs


def expected_step_cost(env, s, a, h):
    """Expected one-step constraint cost at ``(h, s, a)``.

    Computes ``E[g | h, s, a] = sum_{s'} P(s' | s, a, h) * g(s, s')`` using the
    model from ``trans_probs`` and the cost from ``env.get_cost``.
    """
    total = 0
    for s_next, prob in trans_probs(env, s, a, h).items():
        total += prob * env.get_cost(s, s_next)
    return total


def build_cmdp_lp(env, H, b, K):
    """Build the occupancy-measure linear program of the CMDP.

    Variables are occupancy measures ``x[h, s, a] >= 0``, flattened as
    ``h * S * A + s * A + a``.  The LP is::

        minimise    c @ x
        subject to  A_ub @ x <= b_ub    (expected episode cost <= b)
                    A_eq @ x == b_eq    (flow conservation from env.s_initial)

    The objective mixes ``env.loss_fn1`` and ``env.loss_fn2``.  Their weights
    are ``1 - w2`` and ``w2``, where ``w2`` is the mean of the first ``K``
    entries of ``env.loss_schedule`` (the frequency of variant 1 when entries
    are 0/1).

    Args:
        env: Environment providing ``S``, ``A``, ``s_initial``,
            ``loss_schedule``, ``loss_fn1`` and ``loss_fn2``, plus
            ``get_cost`` (used through ``expected_step_cost``).
        H: Horizon, i.e. the number of layers in the LP.
        b: Upper bound on the expected per-episode cost.
        K: Number of schedule entries used to weight the two loss functions.

    Returns:
        ``(c, A_eq, b_eq, A_ub, b_ub)`` in the form expected by
        ``scipy.optimize.linprog``.

    Raises:
        ValueError: if ``K <= 0``, or if ``env.loss_schedule`` is missing or
            has fewer than ``K`` entries.
    """
    if K <= 0:
        # An empty schedule slice would make the mixture weights NaN.
        raise ValueError(f"K must be positive, got K={K}")
    schedule_len = 0 if env.loss_schedule is None else len(env.loss_schedule)
    if schedule_len < K:
        raise ValueError(
            f"env.loss_schedule must contain at least K={K} entries "
            f"(found {schedule_len}); call env.prepare_adversarial_schedule(K) first"
        )

    S, A = env.S, env.A
    n_sa = S * A
    nvar = H * n_sa

    def idx(h, s, a):
        # Column of variable x[h, s, a].
        return h * n_sa + s * A + a

    def row(h, s):
        # Row of the flow-conservation constraint for (h, s).
        return h * S + s

    schedule = np.asarray(env.loss_schedule[:K], dtype=int)
    w2 = schedule.mean()
    w1 = 1.0 - w2

    # Objective coefficients and expected per-step costs, column by column.
    c = np.zeros(nvar)
    cost_vec = np.zeros(nvar)
    for h in range(H):
        for s in range(S):
            for a in range(A):
                j = idx(h, s, a)
                c[j] = w1 * env.loss_fn1(a, h) + w2 * env.loss_fn2(a, h)
                cost_vec[j] = expected_step_cost(env, s, a, h)

    A_ub = cost_vec.reshape(1, -1)
    b_ub = np.array([b], dtype=float)

    n_eq = H * S
    A_eq = np.zeros((n_eq, nvar))
    b_eq = np.zeros(n_eq)

    d0 = np.zeros(S)
    d0[env.s_initial] = 1.0

    # h = 0: sum_a x[0, s, a] = d0(s)
    for s in range(S):
        for a in range(A):
            A_eq[row(0, s), idx(0, s, a)] = 1.0
        b_eq[row(0, s)] = d0[s]

    # h >= 1: sum_a x[h, s, a] - sum_{s', a'} P(s | s', a') x[h-1, s', a'] = 0
    for h in range(1, H):
        for s in range(S):
            for a in range(A):
                A_eq[row(h, s), idx(h, s, a)] += 1.0
        # Each transition row is computed once and scattered into its targets.
        for s_prev in range(S):
            for a_prev in range(A):
                col = idx(h - 1, s_prev, a_prev)
                for s_next, p in trans_probs(env, s_prev, a_prev, h - 1).items():
                    if p != 0.0 and 0 <= s_next < S:
                        A_eq[row(h, s_next), col] -= p

    return c, A_eq, b_eq, A_ub, b_ub

def recover_policy_from_occupancy(x, env, H):
    """Turn a flat occupancy vector into one ``(S, A)`` policy table per step.

    ``x[h * S * A + s * A + a]`` is read as the occupancy of ``(h, s, a)``.
    Each row is normalised to sum to one.  Rows whose total mass is not
    positive get the uniform distribution ``1 / A``.

    Returns:
        A list of ``H`` arrays of shape ``(S, A)``.
    """
    n_states, n_actions = env.S, env.A
    occ = np.asarray(x, dtype=float)[: H * n_states * n_actions]
    occ = occ.reshape(H, n_states, n_actions)

    # Accumulate action by action, in order, so each row total is summed
    # exactly as a left-to-right running sum.
    mass = np.zeros((H, n_states))
    for a in range(n_actions):
        mass += occ[:, :, a]

    policy = np.empty_like(occ)
    visited = mass > 0
    policy[visited] = occ[visited] / mass[visited][:, None]
    if not visited.all():
        policy[~visited] = 1.0 / n_actions

    return [policy[h].copy() for h in range(H)]





class PrimalDualPPM:
    """CMDP agent with primal-dual periodic policy mixing.

    Tabular agent over ``S`` states, ``A`` actions and horizon ``H``, run for
    ``K`` episodes.  Each episode it:

    1. Rolls out its per-step policies ``pi_hats`` in ``env`` and updates
       per-layer visit counts.
    2. Runs a backward pass that builds bonus-adjusted Q-estimates for the
       loss (``Q_f``) and the constraint cost (``Q_g``).
    3. Takes an exponentiated-gradient step on ``Q_f + Y * Q_g``.

    The Lagrange multiplier ``Y`` is then moved by a projected
    (``max(0, .)``) step driven by the estimated violation
    ``V_g[0, s_initial] - b``.

    Args:
        env: Environment exposing ``s_initial``, ``step(s, a, h)``,
            ``get_loss(a, h)``, ``select_loss_variant_for_episode``,
            ``loss_schedule`` and ``prepare_adversarial_schedule``
            (e.g. ``JobSchedulingCMDP``).
        S, A, H, K: Number of states, actions, horizon length and episodes.
        b: Constraint threshold on the expected per-episode cost.
        seed: Seed of the agent's own action-sampling RNG.
    """

    def __init__(self, env, S, A, H, K, b, seed=0):
        self.rng = np.random.default_rng(seed)
        self.env = env
        self.S, self.A, self.H, self.K = S, A, H, K
        self.d = S * A  # size of the flattened (state, action) index space

        self.seed = seed
        self.Y_history = []  # value of the dual variable Y after each episode

        # Hyperparameters, scaled with K (episodes) and H (horizon).
        self.alpha = 0.1                        # policy (mirror-descent) step size
        self.beta_b = K**0.25                   # bonus scale
        self.beta_w = self.beta_b * np.log(K)   # slope inside the contraction sigmoid
        self.theta = K ** (-1)                  # weight of the uniform policy when mixing
        self.eta = H ** (-2) * K ** (-3/4)      # dual step size
        self.KB = int(K ** (3/4))               # period (in episodes) of policy mixing
        # Constraint threshold supplied by the caller.
        self.b = b

        self.k_e = 0  # episode index of the most recent reset()

        # Per-layer diagonal statistics: 1 + number of visits of each (s, a) index.
        self.diag = [np.ones(self.d, dtype=float) for _ in range(H)]
        self.logdet_Lambdas = np.zeros(H, dtype=float)     # current sum(log diag) per layer
        self.logdet_Lambdas_e = np.zeros(H, dtype=float)   # snapshot taken at the last reset
        self.diag_e_inv = [np.ones(self.d, dtype=float) for _ in range(H)]  # 1/diag at the last reset

        # Empirical model: transition counts C[h][(s,a), s'] and summed costs per (s, a).
        self.C_counts = [np.zeros((self.d, self.S), dtype=np.int64) for _ in range(H)]
        self.g_sums = [np.zeros(self.d, dtype=float) for _ in range(H)]

        # Most recent sample at each layer; it is excluded from the estimates in run().
        self.last_idx = np.full(H, -1, dtype=int)
        self.last_snext = np.zeros(H, dtype=int)
        self.last_g = np.zeros(H, dtype=float)

        self.contracted_coef = [np.ones(self.d, dtype=float) for _ in range(H)]

        self.pi_hats = [self.unif() for _ in range(H)]
        self.Y = 0.0

        self.loss_history = []
        self.cost_history = []
        self.loss_avg_history = []
        self.cost_avg_history = []

        self.reset(k=0)

    def unif(self):
        """Return the uniform policy table of shape ``(S, A)``."""
        return np.full((self.S, self.A), 1.0 / self.A)

    def reset(self, k):
        """Start a new phase at episode ``k``.

        Snapshots the current per-layer log-determinants and inverse
        diagonals.  Recomputes the contraction coefficients
        ``sigmoid(log K - beta_w * sqrt(1 / diag))`` for every ``(s, a)`` index.
        Restarts the policies at uniform and the dual variable at 0.
        """
        self.k_e = k
        for h in range(self.H):
            self.logdet_Lambdas[h] = np.sum(np.log(self.diag[h]))
        self.logdet_Lambdas_e = self.logdet_Lambdas.copy()

        for h in range(self.H):
            self.diag_e_inv[h] = 1.0 / self.diag[h]
            sqrt_term = np.sqrt(self.diag_e_inv[h])  # (d,)
            self.contracted_coef[h] = 1.0 / (1.0 + np.exp(-( -self.beta_w * sqrt_term + math.log(self.K) )))

        self.pi_hats = [self.unif() for _ in range(self.H)]
        self.Y = 0.0

    def solve_cmdp_via_lp(self, b=None, K=None, verbose=False):
        """Solve the hindsight occupancy-measure LP and store the optimal policy.

        Args:
            b: Cost budget; defaults to ``self.b``.
            K: Number of schedule entries used to weight the losses; defaults
                to ``self.K``.  If the env's schedule is missing or shorter
                than ``K``, a new one is drawn via
                ``env.prepare_adversarial_schedule(K)``.
            verbose: Print the solver message on failure.

        Returns:
            ``(pi_opt, loss_opt_episode, cost_opt_episode)``; also stored on
            ``self``.

        Raises:
            RuntimeError: if ``linprog`` does not report success (status 0).
        """
        if b is None:
            b = self.b
        if K is None:
            K = self.K

        if (self.env.loss_schedule is None) or (len(self.env.loss_schedule) < K):
            self.env.prepare_adversarial_schedule(K)

        c, A_eq, b_eq, A_ub, b_ub = build_cmdp_lp(
            self.env, self.H, b, K=K
        )
        bounds = [(0.0, None)] * (self.H * self.S * self.A)

        res = linprog(c=c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq,
                      bounds=bounds, method="highs")
        if res.status != 0:
            msg = f"LP solver failed (status={res.status}): {res.message}"
            if verbose:
                print(msg)
            raise RuntimeError(msg)

        x_opt = res.x
        self.pi_opt = recover_policy_from_occupancy(x_opt, self.env, self.H)

        total_loss = float(c @ x_opt)
        # A_ub has a single row; take it explicitly so float() gets a scalar
        # (float() on a 1-element 1-D array is deprecated in NumPy >= 1.25).
        total_cost = float(A_ub[0] @ x_opt)
        self.loss_opt_episode = total_loss
        self.cost_opt_episode = total_cost
        return self.pi_opt, self.loss_opt_episode, self.cost_opt_episode

    def run(self, print_every=50):
        """Play ``K`` episodes, updating the policies and the dual variable.

        Appends, for every episode, the realised loss and cost, their running
        averages and the new ``Y`` to the history lists.  When ``print_every``
        is truthy, a progress line is written every ``print_every`` episodes.
        """
        env = self.env
        S, A, H, K = self.S, self.A, self.H, self.K

        V_f = np.zeros((H + 1, S), dtype=float)
        V_g = np.zeros((H + 1, S), dtype=float)

        # Running totals make each running-average update O(1).  Recomputing
        # np.mean over the full history every episode is O(K^2) overall.
        loss_total = float(np.sum(self.loss_history))
        cost_total = float(np.sum(self.cost_history))

        for k in tqdm(range(K)):
            env.select_loss_variant_for_episode(k, K, use_schedule=True)

            # New phase when any layer's log-determinant grew by >= log 2 since the last reset.
            if (k == 0) or np.any(self.logdet_Lambdas - self.logdet_Lambdas_e >= np.log(2.0)):
                self.reset(k)

            # Rollout
            s = env.s_initial
            ep_loss = 0.0
            ep_cost = 0.0
            for h in range(H):
                pi_hat = self.pi_hats[h][s]
                a = self.rng.choice(A, p=pi_hat)
                s_next, f, g = env.step(s, a, h)

                # phi_idx comes from utils; presumably the flat (s, a) index -- TODO confirm.
                i = phi_idx(s, a, S, A)

                self.diag[h][i] += 1.0
                self.C_counts[h][i, s_next] += 1
                self.g_sums[h][i] += g

                self.last_idx[h] = i
                self.last_snext[h] = s_next
                self.last_g[h] = g

                ep_loss += f
                ep_cost += g
                s = s_next

            # The mixing decision depends only on (k, k_e), so it is shared by all layers.
            mix_with_uniform = ((k - self.k_e) % self.KB) == 0

            for h in reversed(range(H)):
                self.logdet_Lambdas[h] = np.sum(np.log(self.diag[h]))

                i_last = self.last_idx[h]
                s_last = self.last_snext[h]

                # Statistics with this episode's sample at layer h removed.
                diag_prev = self.diag[h].copy()
                diag_prev[i_last] -= 1.0
                inv_diag_prev = 1.0 / diag_prev

                u_f = self.C_counts[h] @ V_f[h + 1]
                u_f[i_last] -= V_f[h + 1, s_last]
                psiV_f = inv_diag_prev * u_f

                u_g = self.C_counts[h] @ V_g[h + 1]
                u_g[i_last] -= V_g[h + 1, s_last]
                psiV_g = inv_diag_prev * u_g

                g_sums_prev = self.g_sums[h].copy()
                g_sums_prev[i_last] -= self.last_g[h]
                theta_g_hat = inv_diag_prev * g_sums_prev

                loss_a = np.array([env.get_loss(aa, h) for aa in range(A)], dtype=float)
                # Tiling assumes index s * A + a, the same layout reshape(S, A) uses below.
                inst_f_vec = np.tile(loss_a, S)

                coef_vec = self.contracted_coef[h]
                bonus_vec = self.beta_b * coef_vec * np.sqrt(self.diag_e_inv[h])

                Qf_vec = coef_vec * (inst_f_vec + psiV_f) - bonus_vec
                Qg_vec = coef_vec * (theta_g_hat + psiV_g) - bonus_vec

                Q_f = Qf_vec.reshape(S, A)
                Q_g = Qg_vec.reshape(S, A)

                V_f[h] = (self.pi_hats[h] * Q_f).sum(axis=1)
                V_g[h] = (self.pi_hats[h] * Q_g).sum(axis=1)

                # Periodic policy mixing
                if mix_with_uniform:
                    pi_tilde = (1.0 - self.theta) * self.pi_hats[h] + self.theta * self.unif()
                else:
                    pi_tilde = self.pi_hats[h]

                # pi <- pi_tilde * exp(-alpha * (Q_f + Y * Q_g)), renormalised per state.
                logits = np.log(np.maximum(pi_tilde, 1e-12)) - self.alpha * (Q_f + self.Y * Q_g)
                try:
                    # softmax is expected from utils; fall back to a local stable version.
                    self.pi_hats[h] = softmax(logits, axis=1)
                except Exception:
                    expx = np.exp(logits - logits.max(axis=1, keepdims=True))
                    self.pi_hats[h] = expx / expx.sum(axis=1, keepdims=True)

            # Dual update: decay Y, step along the estimated violation, project onto [0, inf).
            self.Y = max(
                0.0,
                self.Y * (1 - 4 * self.alpha * self.eta * (self.H ** 3))
                + self.eta * (
                    V_g[0, self.env.s_initial]
                    - self.b
                    - 4 * self.alpha * (self.H ** 3)
                    - 4 * self.theta * (self.H ** 2)
                )
            )
            self.Y_history.append(self.Y)

            self.loss_history.append(ep_loss)
            self.cost_history.append(ep_cost)
            loss_total += ep_loss
            cost_total += ep_cost
            self.loss_avg_history.append(loss_total / len(self.loss_history))
            self.cost_avg_history.append(cost_total / len(self.cost_history))

            if print_every and ((k + 1) % print_every == 0):
                tqdm.write(
                    f"[Episode {k+1}] "
                    f"loss={ep_loss:.4f} | cost={ep_cost:.4f} | "
                    f"avg_loss={self.loss_avg_history[-1]:.4f} | avg_cost={self.cost_avg_history[-1]:.4f}"
                )

    def save_results(self, out_dir="results", tag=None, extra_notes=None):
        """Write histories, metadata and (optionally) the loss schedule to a new run directory.

        The directory is ``out_dir/<YYYYmmdd-HHMMSS>[_<tag>]``.  Runs started in
        the same second with the same tag share it (``exist_ok=True``).

        Returns:
            The run directory path as a string.
        """
        out_path = Path(out_dir)
        out_path.mkdir(parents=True, exist_ok=True)

        ts = time.strftime("%Y%m%d-%H%M%S")
        run_name = ts + (f"_{tag}" if tag else "")
        run_dir = out_path / run_name
        run_dir.mkdir(parents=True, exist_ok=True)

        # 1) Save the time series (npz)
        np.savez_compressed(
            run_dir / "histories.npz",
            loss_history=np.asarray(self.loss_history, dtype=float),
            cost_history=np.asarray(self.cost_history, dtype=float),
            loss_avg_history=np.asarray(self.loss_avg_history, dtype=float),
            cost_avg_history=np.asarray(self.cost_avg_history, dtype=float),
            Y_history=np.asarray(self.Y_history, dtype=float),
        )

        meta = {
            "S": self.S, "A": self.A, "H": self.H, "K": self.K,
            "seed_agent": self.seed,
            "params": {
                "alpha": _json_float(self.alpha),
                "eta": _json_float(self.eta),
                "theta": _json_float(self.theta),
                "beta_b": _json_float(getattr(self, "beta_b", getattr(self, "beta", None))),
                "beta_w": _json_float(getattr(self, "beta_w", None)),
                "KB": int(self.KB),
                "b": _json_float(self.b),
            },
            "lp_opt": {
                "loss_opt_episode": _json_float(getattr(self, "loss_opt_episode", None)),
                "cost_opt_episode": _json_float(getattr(self, "cost_opt_episode", None)),
            },
            "adversarial": {
                "has_schedule": bool(getattr(self.env, "loss_schedule", None)),
                "schedule_len": len(getattr(self.env, "loss_schedule", [])) if getattr(self.env, "loss_schedule", None) is not None else 0,
            },
            "notes": extra_notes,
            "created_at": ts,
        }
        with open(run_dir / "meta.json", "w", encoding="utf-8") as f:
            json.dump(meta, f, ensure_ascii=False, indent=2)

        if getattr(self.env, "loss_schedule", None) is not None:
            np.save(run_dir / "loss_schedule.npy",
                    np.asarray(self.env.loss_schedule, dtype=int))

        # Describe only what histories.npz actually contains (no loss_variant array is saved).
        with open(run_dir / "README.txt", "w", encoding="utf-8") as f:
            f.write(
                "Files:\n"
                "- histories.npz: loss/cost, running averages, Y\n"
                "- meta.json: env/hparams/LP-opt/meta\n"
                "- loss_schedule.npy: (optional) adversarial schedule (0/1 per episode)\n"
            )

        return str(run_dir)





def main():
    """Run ``R`` seeded repetitions of PrimalDualPPM on one shared adversarial schedule.

    Steps:

    1. Draw the loss schedule once (env seed 0).
    2. Solve the hindsight LP on it and print the optimum.
    3. Run every repetition on a copy of the schedule.
    4. Save the stacked running-average curves, plus the LP optimum, to
       ``results/aggregate_R{R}_K{K}.npz``.
    """
    K = 100000
    S, A, H = 10, 2, 10
    R = 10     # number of independent repetitions
    b = 5.6    # cost budget

    # One schedule, shared by every repetition.
    base_env = JobSchedulingCMDP(S, A, H, seed=0)
    base_env.prepare_adversarial_schedule(K)
    realized_schedule = list(base_env.loss_schedule)

    # Hindsight LP optimum on the realised schedule.
    lp_agent = PrimalDualPPM(base_env, S, A, H, K, b, seed=0)
    lp_agent.solve_cmdp_via_lp(K=K, verbose=False)
    loss_opt_episode = float(lp_agent.loss_opt_episode)
    cost_opt_episode = float(lp_agent.cost_opt_episode)
    print(f"[LP-Optimal (Hindsight on realized schedule)] "
          f"loss_ep={loss_opt_episode:.4f}, cost_ep={cost_opt_episode:.4f}")

    # Repetition r uses seed r for both the environment and the agent.
    per_run = [
        run_once(
            K=K, S=S, A=A, H=H, b=b,
            schedule=realized_schedule,
            seed_env=r, seed_agent=r,
            print_every=0, save_tag=f"advK{K}_run{r}",
        )
        for r in range(R)
    ]
    loss_runs = np.vstack([res["loss_avg"] for res in per_run])
    cost_runs = np.vstack([res["cost_avg"] for res in per_run])

    out_dir = Path("results")
    out_dir.mkdir(parents=True, exist_ok=True)
    aggregate_path = out_dir / f"aggregate_R{R}_K{K}.npz"
    np.savez_compressed(
        aggregate_path,
        loss_runs=loss_runs,
        cost_runs=cost_runs,
        loss_opt_episode=loss_opt_episode,
        cost_opt_episode=cost_opt_episode,
    )
    print(f"Saved aggregate arrays to: {aggregate_path}")


def run_once(K, S, A, H, b, schedule, seed_env, seed_agent, print_every=0, save_tag=None):
    """Train one PrimalDualPPM agent on a fixed loss schedule and summarise the run.

    Builds a fresh environment (seeded with ``seed_env``) that holds its own
    copy of ``schedule``.  Solves the hindsight LP, runs the agent, and saves
    results under ``results/`` when ``save_tag`` is given.

    Returns:
        Dict with ``"loss_avg"``/``"cost_avg"`` (running-average arrays) and
        ``"loss_opt"``/``"cost_opt"`` (LP optimum, NaN if unavailable).
    """
    env = JobSchedulingCMDP(S, A, H, seed=seed_env)
    env.loss_schedule = list(schedule)  # private copy of the shared schedule
    agent = PrimalDualPPM(env, S, A, H, K, b, seed=seed_agent)

    agent.solve_cmdp_via_lp(K=K, verbose=False)
    agent.run(print_every=print_every)

    if save_tag is not None:
        agent.save_results(out_dir="results", tag=save_tag)

    def lp_value(attr):
        return float(getattr(agent, attr, np.nan))

    return {
        "loss_avg": np.asarray(agent.loss_avg_history, dtype=float),
        "cost_avg": np.asarray(agent.cost_avg_history, dtype=float),
        "loss_opt": lp_value("loss_opt_episode"),
        "cost_opt": lp_value("cost_opt_episode"),
    }



if __name__ == "__main__":
    # Script entry point: run the full experiment only when executed directly.
    main()