import numpy as np
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

import matplotlib.pyplot as plt
import os

from datetime import datetime
from stable_baselines3.common.callbacks import BaseCallback




class PotentialNet(nn.Module):
    """Φ_ψ(s): scalar potential, an MLP with two hidden ReLU layers.

    Args:
        obs_dim: size of the last (feature) axis of the input.
        hidden_dim: width of both hidden layers (256 by default, which
            reproduces the original fixed architecture).
    """
    def __init__(self, obs_dim, hidden_dim=256):
        super().__init__()
        # Layer order/indices are kept stable so state_dict keys
        # (net.0.*, net.2.*, net.4.*) do not change.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Map x of shape (..., obs_dim) to potentials of shape (...,)."""
        # The final Linear emits a trailing axis of size 1; drop it.
        return self.net(x).squeeze(-1)


class RewardProjectionNNWrapper(gym.Wrapper):
    """
    Gymnasium wrapper that returns the shaped/core reward:
        r_core = r - clip(γ Φ(s') - Φ(s), [-clip_val, clip_val])

    Φ is fitted by minimizing the Ω* (max-entropy) projection loss at fit-time:

        Ω*_s(x) = log E_{a~π(·|s)}[ exp(x(a)) ]
        with x(a) = r(s,a) - (b + γΦ(s'(a)) - Φ(s))

    Practical implementation:
      - During rollout: store (obs_t, mujoco qpos_t, qvel_t) only (cheap).
      - Every `fit_every_episodes` finished episodes: pick M stored states,
        sample K actions per state from the injected PPO policy, fork the
        simulator for one step to get (r_k, s'_k), then optimize Φ with:

            loss = mean_{states} [ logsumexp_k(x_k) - log(K) ]

        plus a small gauge-fixing term on mean Φ(s) and a ridge penalty on the
        weights of Φ. Φ, the bias b and the optimizer are re-created from
        scratch at every fit, and the state buffers are cleared afterwards.

    Args:
        env: environment to wrap; must have 1-D Box observations, Box actions
            and an unwrapped env exposing `data.qpos`, `data.qvel`, `set_state`.
        gamma: discount used in the shaping term and in the fit.
        lr: Adam learning rate for Φ and the bias.
        min_transitions: minimum number of stored states before a fit runs.
        train_epochs: passes over the M branched states per fit.
        batch_size: minibatch size (in states) during a fit.
        device: torch device for Φ and the fit tensors.
        fit_M: number of stored states branched per fit.
        fit_K: number of actions sampled per branched state.
        fit_every_episodes: a fit is attempted every this many episodes.
        clip_val: symmetric clamp applied to the shaping term.
        reg_lambda: ridge coefficient (applied as 2 * reg_lambda * ||ψ||²).
    """
    def __init__(
        self,
        env,
        gamma=0.99,
        lr=1e-3,
        min_transitions=1000,
        train_epochs=5,
        batch_size=256,
        device="cpu",
        # ---- fit-time branching controls ----
        fit_M=256,          # number of stored states branched per fit
        fit_K=128,          # number of actions sampled per state
        fit_every_episodes=5,
        # ---- shaping clamp ----
        clip_val=0.05,
        reg_lambda=1e-4,    # tiny ridge term to satisfy strong convexity
    ):
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Box), \
            "Wrapper assumes continuous observations."
        assert isinstance(env.action_space, gym.spaces.Box), \
            "Wrapper assumes continuous actions."
        assert env.observation_space.shape is not None and len(env.observation_space.shape) == 1, \
            "Wrapper assumes 1D observations."

        self.gamma = gamma
        self.min_transitions = min_transitions
        self.train_epochs = train_epochs
        self.batch_size = batch_size
        self.device = device

        self.fit_M = fit_M
        self.fit_K = fit_K
        self.fit_every_episodes = fit_every_episodes
        self.clip_val = clip_val

        self.lr = lr  # kept so the optimizer can be rebuilt at every fit
        self.reg_lambda = reg_lambda

        self.obs_dim = env.observation_space.shape[0]
        self._reset_potential()

        # buffers for fit-time branching (store pre-step obs and mujoco state)
        self.obs_buf = []     # list[np.ndarray obs_dim]
        self.qpos_buf = []    # list[np.ndarray]
        self.qvel_buf = []    # list[np.ndarray]

        self._last_obs = None
        self.episode_count = 0

        # SB3 policy injected by callback
        self._sb3_policy = None

    def _reset_potential(self):
        """Create a fresh Φ network, bias b and Adam optimizer over both."""
        self.potential = PotentialNet(self.obs_dim).to(self.device)
        self.bias = nn.Parameter(torch.zeros(1, device=self.device))
        self.optim = optim.Adam(
            list(self.potential.parameters()) + [self.bias],
            lr=self.lr,
        )

    # ---------------- SB3 policy injection ----------------
    def set_sb3_policy(self, policy):
        """Register the policy whose action distribution is sampled at fit-time."""
        self._sb3_policy = policy

    # ---------------- Mujoco state helpers ----------------
    def _get_mujoco_state(self):
        """Return copies of the unwrapped env's (qpos, qvel)."""
        raw = self.env.unwrapped
        qpos = raw.data.qpos.copy()
        qvel = raw.data.qvel.copy()
        return qpos, qvel

    def _set_mujoco_state(self, qpos, qvel):
        """Overwrite the unwrapped env's simulator state with (qpos, qvel)."""
        raw = self.env.unwrapped
        raw.set_state(qpos, qvel)
        # forward call name differs across versions
        if hasattr(raw, "mj_forward"):
            raw.mj_forward()
        elif hasattr(raw, "sim") and hasattr(raw.sim, "forward"):
            raw.sim.forward()

    # ---------------- gym API ----------------
    def reset(self, **kwargs):
        """Reset the wrapped env and remember the first observation."""
        obs, info = self.env.reset(**kwargs)
        self._last_obs = np.array(obs, copy=True)
        return obs, info

    def step(self, action):
        """Step the wrapped env and return the core reward instead of the raw one.

        Returns the usual (obs, reward, terminated, truncated, info) tuple with
        `reward` replaced by r - clip(γΦ(s') - Φ(s)). When an episode ends and
        the episode count is a multiple of `fit_every_episodes`, a Φ fit is
        attempted before returning.
        """
        prev_obs = self._last_obs

        # ---- store pre-step state for later fit-time branching ----
        # (only possible once reset() has provided an observation)
        if prev_obs is not None:
            qpos, qvel = self._get_mujoco_state()
            self.obs_buf.append(prev_obs.copy())
            self.qpos_buf.append(qpos)
            self.qvel_buf.append(qvel)

        # ---- real env step ----
        next_obs, reward, terminated, truncated, info = self.env.step(action)

        # ---- shaping for returned reward (uses real next_obs) ----
        # Without a previous observation Φ(s) is undefined, so no shaping is
        # applied; this mirrors the guard used for the buffers above.
        if prev_obs is None:
            shaping_term = 0.0
        else:
            shaping_term = self._shaping_term(prev_obs, next_obs)

        r_core = float(reward) - shaping_term
        self._last_obs = np.array(next_obs, copy=True)

        # ---- periodic fit ----
        if terminated or truncated:
            self.episode_count += 1
            if (self.episode_count % self.fit_every_episodes) == 0:
                self._fit_potential_if_ready()

        return next_obs, r_core, terminated, truncated, info

    def _shaping_term(self, obs, next_obs):
        """Return clip(γΦ(next_obs) - Φ(obs), ±clip_val) as a Python float."""
        with torch.no_grad():
            # One batched forward pass (and one host transfer) for both states.
            pair = torch.as_tensor(
                np.stack([obs, next_obs], axis=0),
                dtype=torch.float32,
                device=self.device,
            )
            phi_t, phi_tp1 = self.potential(pair).tolist()
        raw_term = self.gamma * phi_tp1 - phi_t
        return float(np.clip(raw_term, -self.clip_val, self.clip_val))

    # ---------------- core: Ω* projection fit-time branching ----------------
    def _fit_potential_if_ready(self):
        """Refit Φ from the stored states, if enough data and a policy exist.

        Does nothing when fewer than `min_transitions` states are buffered or
        no policy has been injected. Otherwise re-initializes Φ, branches M
        random stored states with K policy actions each, optimizes the Ω* loss
        and clears the buffers.
        """
        n = len(self.obs_buf)
        if n < self.min_transitions:
            return
        if self._sb3_policy is None:
            # cannot sample from current PPO policy
            print("Cannot sample from current PPO policy!")
            return

        # ---- reinitialize Φ and optimizer for this fit ----
        self._reset_potential()

        M = min(self.fit_M, n)
        # sample M indices from stored buffer
        idx = np.random.choice(n, size=M, replace=False)

        obs_np = np.stack([self.obs_buf[i] for i in idx], axis=0)               # [M, obs_dim]
        obs = torch.as_tensor(obs_np, dtype=torch.float32, device=self.device)  # [M, obs_dim]

        actions = self._sample_policy_actions(obs, self.fit_K)                  # [M, K, act_dim]
        next_obsK, rewK = self._branch_simulator(idx, actions)

        next_obsK_t = torch.as_tensor(next_obsK, dtype=torch.float32, device=self.device)  # [M,K,obs_dim]
        rewK_t = torch.as_tensor(rewK, dtype=torch.float32, device=self.device)            # [M,K]

        self._optimize_potential(obs, next_obsK_t, rewK_t)

        # clear buffers to keep it online
        self.obs_buf.clear()
        self.qpos_buf.clear()
        self.qvel_buf.clear()

    def _sample_policy_actions(self, obs, K):
        """Sample K clipped actions per state; returns np.ndarray [M, K, act_dim]."""
        policy = self._sb3_policy
        # The policy may live on a different device than Φ (e.g. wrapper on
        # CPU, PPO on GPU), so feed it observations on its own device.
        policy_device = getattr(policy, "device", obs.device)
        with torch.no_grad():
            dist = policy.get_distribution(obs.to(policy_device))
            # sample shape: [K, M, act_dim] -> [M, K, act_dim]
            sampled = dist.distribution.sample((K,)).permute(1, 0, 2)

        # Single device->host transfer; the simulator consumes NumPy anyway.
        actions = sampled.cpu().numpy()

        # clip actions to env bounds
        low = self.action_space.low.astype(actions.dtype)
        high = self.action_space.high.astype(actions.dtype)
        return np.clip(actions, low, high)

    def _branch_simulator(self, idx, actions):
        """One-step fork of the simulator for every (state, action) pair.

        Returns (next_obs [M, K, obs_dim], rewards [M, K]) as float32 arrays.
        The simulator state found on entry is restored before returning.
        """
        M, K = actions.shape[:2]
        next_obsK = np.zeros((M, K, self.obs_dim), dtype=np.float32)
        rewK = np.zeros((M, K), dtype=np.float32)

        # Step the unwrapped env so outer wrappers do not see the M*K
        # throw-away branch steps.
        raw = self.env.unwrapped
        live_qpos, live_qvel = self._get_mujoco_state()
        try:
            for m, buf_i in enumerate(idx):
                qpos0, qvel0 = self.qpos_buf[buf_i], self.qvel_buf[buf_i]
                for k in range(K):
                    self._set_mujoco_state(qpos0, qvel0)
                    ob_k, r_k, *_ = raw.step(actions[m, k])
                    next_obsK[m, k] = ob_k
                    rewK[m, k] = float(r_k)
        finally:
            # Leave the simulator where the live rollout left it.
            self._set_mujoco_state(live_qpos, live_qvel)

        return next_obsK, rewK

    def _optimize_potential(self, obs, next_obsK, rewK):
        """Run `train_epochs` of minibatch Adam on the Ω* projection loss.

        Args:
            obs: [M, obs_dim] branched states.
            next_obsK: [M, K, obs_dim] successor observations.
            rewK: [M, K] one-step rewards.
        """
        import math

        M, K = rewK.shape
        log_K = math.log(K)

        for _ in range(self.train_epochs):
            perm = torch.randperm(M, device=self.device)
            for start in range(0, M, self.batch_size):
                bidx = perm[start:start + self.batch_size]
                s = obs[bidx]                              # [B, obs_dim]
                s_nextK = next_obsK[bidx]                  # [B, K, obs_dim]
                rK = rewK[bidx]                            # [B, K]

                phi_s = self.potential(s)                  # [B]
                phi_nextK = self.potential(
                    s_nextK.reshape(-1, self.obs_dim)
                ).reshape(-1, K)                           # [B, K]

                # predK = b + γΦ(s'_k) - Φ(s)
                predK = self.bias + self.gamma * phi_nextK - phi_s.unsqueeze(1)  # [B, K]
                x = rK - predK                                                    # [B, K]

                # Ω*_s(x) = log E_{a~π}[exp(x(a))] ≈ log(1/K Σ exp(x_k))
                omega_star_hat = torch.logsumexp(x, dim=1) - log_K               # [B]
                loss = omega_star_hat.mean()

                # optional gauge-fixing (prevents Φ drifting by constant)
                loss = loss + 1e-3 * (phi_s.mean() ** 2)

                # ridge penalty on the weights of Φ (the bias b is excluded)
                reg = sum(torch.sum(p ** 2) for p in self.potential.parameters())
                loss = loss + 2 * self.reg_lambda * reg

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()


class InjectPolicyCallback(BaseCallback):
    """
    Inject the SB3 policy into the wrapped env(s) held by the training VecEnv.

    Every sub-env in `training_env.envs` that exposes `set_sb3_policy` receives
    `model.policy`; sub-envs without that method are left untouched. Injection
    happens at training start and again at the start of every rollout.
    """
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _inject_policy(self) -> None:
        """Hand the current policy to every sub-env that accepts one."""
        policy = self.model.policy
        for env in self.training_env.envs:
            if hasattr(env, "set_sb3_policy"):
                env.set_sb3_policy(policy)

    def _on_training_start(self) -> None:
        self._inject_policy()

    def _on_rollout_start(self) -> None:
        self._inject_policy()

    def _on_step(self) -> bool:
        # No per-step work; returning True lets training continue.
        return True


############################################
# PPO env factory (for DummyVecEnv)
############################################

def make_env(env_id="Hopper-v4", gamma=0.99, device="cpu", projected=True):
    """Return a zero-argument env constructor suitable for DummyVecEnv.

    The constructor builds `gym.make(env_id)` and, when `projected` is true,
    wraps it in RewardProjectionNNWrapper using `gamma`, `device` and the
    fixed fit settings below.
    """
    wrapper_settings = dict(
        gamma=gamma,
        lr=1e-3,
        min_transitions=1000,
        train_epochs=10,
        batch_size=1024,
        device=device,
    )

    def _build():
        base_env = gym.make(env_id)
        if not projected:
            return base_env
        return RewardProjectionNNWrapper(base_env, **wrapper_settings)

    return _build


def make_plain_env(env_id="Hopper-v4"):
    """Return a zero-argument constructor for the unwrapped-reward `env_id` env."""
    return lambda: gym.make(env_id)



############################################
# Evaluation on the raw env (defaults: Hopper-v4, 10 episodes)
############################################

def evaluate_policy_returns(model, env_id="Hopper-v4", n_eval_episodes=10):
    """
    Evaluate a trained SB3 model on the RAW Gymnasium env.

    A fresh `gym.make(env_id)` env is created, `n_eval_episodes` episodes are
    rolled out with deterministic actions from `model.predict`, and the env is
    closed again (also when a rollout raises).

    Returns: np.array of episode returns, shape (n_eval_episodes,), float32
    """
    env = gym.make(env_id)
    returns = []

    try:
        for _ in range(n_eval_episodes):
            obs, _info = env.reset()
            done = False
            ep_ret = 0.0

            while not done:
                action, _state = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, _info = env.step(action)
                done = terminated or truncated
                ep_ret += float(reward)

            returns.append(ep_ret)
    finally:
        # Release the env even if predict()/step() fails mid-episode.
        env.close()

    return np.asarray(returns, dtype=np.float32)



class EvalLoggerCallback(BaseCallback):
    """
    After every `eval_freq_rollouts` PPO rollouts, evaluate the current policy
    on a fresh RAW env (no wrapper) for `n_eval_episodes` full episodes and
    log mean/std return + timestamps to a file.

    Args:
        log_path: file the log lines are appended to.
        eval_env_id: Gymnasium id of the env used for evaluation.
        n_eval_episodes: episodes per evaluation.
        eval_freq_rollouts: evaluate every this many rollouts (must be >= 1).
        verbose: forwarded to BaseCallback.

    Raises:
        ValueError: if `eval_freq_rollouts` is smaller than 1.
    """

    def __init__(
        self,
        log_path: str,
        eval_env_id: str = "Hopper-v4",
        n_eval_episodes: int = 5,
        eval_freq_rollouts: int = 10,
        verbose: int = 0,
    ):
        super().__init__(verbose)
        if eval_freq_rollouts < 1:
            # Fail here rather than with a ZeroDivisionError (or a schedule
            # that never fires) at the end of a rollout.
            raise ValueError("eval_freq_rollouts must be >= 1")

        self.log_path = log_path
        self.eval_env_id = eval_env_id
        self.n_eval_episodes = n_eval_episodes
        self.eval_freq_rollouts = eval_freq_rollouts

        self.start_time = None
        self.file = None
        self.rollout_count = 0  # how many rollouts finished so far

    def _on_training_start(self) -> None:
        self.start_time = datetime.now()
        self.file = open(self.log_path, "a", encoding="utf-8")
        self.file.write(f"=== Training started at {self.start_time.isoformat()} ===\n")
        self.file.flush()

    def _evaluate_current_policy(self):
        """Return (mean_return, std_return, n_episodes) on the raw env."""
        # always evaluate on the RAW env (true reward), no wrapper;
        # the rollout loop is shared with the module-level helper.
        returns = evaluate_policy_returns(
            self.model,
            env_id=self.eval_env_id,
            n_eval_episodes=self.n_eval_episodes,
        )
        mean_ret = float(returns.mean())
        std_ret = float(returns.std())
        return mean_ret, std_ret, len(returns)

    def _on_rollout_end(self) -> None:
        # one PPO rollout finished
        self.rollout_count += 1

        if self.rollout_count % self.eval_freq_rollouts != 0:
            return

        mean_ret, std_ret, n_eps = self._evaluate_current_policy()

        self.file.write(
            f"[eval] timesteps={self.num_timesteps}, "
            f"rollouts={self.rollout_count}, "
            f"n_episodes={n_eps}, "
            f"mean_return_true={mean_ret:.2f}, "
            f"std_return_true={std_ret:.2f}\n"
        )
        self.file.flush()

    def _on_step(self) -> bool:
        # We don't use per-step logic in this callback,
        # but BaseCallback requires this method.
        return True

    def _on_training_end(self) -> None:
        if self.file is None:
            # Nothing was opened (training start hook never ran).
            return

        end_time = datetime.now()
        duration = end_time - self.start_time
        try:
            self.file.write(f"=== Training ended at {end_time.isoformat()} ===\n")
            self.file.write(f"Total training time: {duration}\n")
            self.file.flush()
        finally:
            # Close the handle even if the final writes fail.
            self.file.close()
            self.file = None



############################################
# Main: train + evaluate + log mean/std
############################################

if __name__ == "__main__":
    # Task to train on. Other ids used with this script:
    #   "Hopper-v4", "Walker2d-v4", "Ant-v4", "Reacher-v4", "Humanoid-v4"
    env_id = "HalfCheetah-v4"
    gamma = 0.99
    # The wrapper moves Φ to `device` unconditionally, so fall back to CPU
    # when no GPU is present instead of failing at env construction.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ----- create vec env with projected reward -----
    # For the plain-PPO baseline use: DummyVecEnv([make_plain_env(env_id)])
    vec_env = DummyVecEnv([make_env(env_id, gamma=gamma, device=device, projected=True)])

    model = PPO(
        "MlpPolicy",
        vec_env,
        verbose=1,
        gamma=gamma,
        ent_coef=0.001,   # max-entropy weight α
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        learning_rate=3e-4,
        seed=0,
        device=device,
    )

    total_timesteps = 1_000_000

    # Baseline runs log to "training_log_ppo_halfcheetah.txt" instead.
    log_file = "training_log_core_halfcheetah.txt"

    eval_callback = EvalLoggerCallback(
        log_path=log_file,
        eval_env_id=env_id,
        n_eval_episodes=5,
        eval_freq_rollouts=10,  # evaluate every 10 rollouts
    )
    inject_cb = InjectPolicyCallback()

    try:
        model.learn(total_timesteps=total_timesteps, callback=[eval_callback, inject_cb])
    finally:
        # Release the simulator even if training is interrupted.
        vec_env.close()

    os.makedirs("saved_models", exist_ok=True)
    # Saving is disabled by default. Path convention used so far:
    #   saved_models/<env>_{ppo|core_reward}_<steps>k[_<run>]
    # save_path = "saved_models/halfcheetah_core_reward_1000k_1"
    # model.save(save_path)
    # print(f"Model saved to: {save_path}")

    # ----- evaluation on the raw env (no projection) -----
    n_eval_episodes = 100
    returns = evaluate_policy_returns(model, env_id, n_eval_episodes=n_eval_episodes)

    mean_ret = float(returns.mean())
    std_ret = float(returns.std())

    print(f"Evaluation over {n_eval_episodes} episodes on {env_id} (raw env):")
    print(f"  Mean return = {mean_ret:.2f}")
    print(f"  Std return  = {std_ret:.2f}")

    # also append the final eval to the log file
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Final evaluation over {n_eval_episodes} episodes "
            f"on {env_id} (raw env) ===\n"
        )
        f.write(f"Mean return = {mean_ret:.2f}\n")
        f.write(f"Std return  = {std_ret:.2f}\n")

