# Discrete-control DQN with soft-∞ surrogates (batch-LSC / elem-LSC)
# Adds: block-end AlphaWatcher (fixed V_t, Y_t, beta_ref) and greedy evaluation loop

import copy
import os
import random
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter

from cleanrl_utils.buffers import ReplayBuffer


@dataclass
class Args:
    """Run configuration, parsed from the command line by ``tyro``.

    ``__post_init__`` rejects option values that the training loop would
    otherwise silently reinterpret (e.g. an unknown ``surrogate_type`` falling
    back to MSE).
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    seed: int = 42
    torch_deterministic: bool = True
    cuda: bool = True
    track: bool = False
    wandb_project_name: str = "cleanRL"
    wandb_entity: Optional[str] = None
    capture_video: bool = False
    save_model: bool = False
    upload_model: bool = False
    hf_entity: str = ""

    # Algorithm specific
    env_id: str = "CartPole-v1"  # e.g., "CartPole-v1", "Acrobot-v1", "MountainCar-v0"
    total_timesteps: int = 600_000
    learning_rate: float = 2.5e-4
    num_envs: int = 1
    buffer_size: int = 20_000
    gamma: float = 0.99
    tau: float = 1.0  # Polyak factor for the target-network update (1.0 = hard copy)
    target_network_frequency: int = 500  # also the length of one alpha-watcher block
    batch_size: int = 32
    start_e: float = 1.0
    end_e: float = 0.05
    exploration_fraction: float = 0.2  # fraction of total_timesteps spent decaying epsilon
    learning_starts: int = 10_000
    train_frequency: int = 4
    use_double_q: bool = True

    # Surrogate options (parity with your Atari script)
    surrogate_type: str = "elem_lsc"  # "mse", "batch_lsc", "elem_lsc", "huber"
    beta: float = 0.5  # fixed temperature when adaptive_beta is off (huber uses delta = 1/beta)
    adaptive_beta: bool = True
    beta_c: float = 1.          # target c in: beta ≈ c / E_hat  (recommend 0.7–1.5)
    beta_min: float = 1e-1
    beta_max: float = 20.0
    beta_eps: float = 2e-6       # avoids division by 0
    beta_quantile: float = 0.95  # use high quantile for E
    beta_ema: float = 0.10       # EMA smoothing for E_hat

    # Evaluation settings (added)
    eval_frequency: int = 5000
    eval_episodes: int = 5
    eval_epsilon: float = 0.0  # greedy eval

    # Alpha watcher settings (added)
    alpha_val_multiplier: int = 50  # |V_t| = multiplier * batch_size
    alpha_beta_ref_mode: str = "c_over_E"  # {"c_over_E", "const"}
    alpha_beta_ref_const: float = 1.0

    def __post_init__(self):
        """Validate enumerated options and the surrogate temperature.

        Raises:
            ValueError: on an unknown ``surrogate_type`` / ``alpha_beta_ref_mode``,
                or a non-positive ``beta`` for a non-MSE surrogate.
        """
        valid_surrogates = ("mse", "batch_lsc", "elem_lsc", "huber")
        if self.surrogate_type not in valid_surrogates:
            raise ValueError(
                f"surrogate_type must be one of {valid_surrogates}, got {self.surrogate_type!r}"
            )
        valid_ref_modes = ("c_over_E", "const")
        if self.alpha_beta_ref_mode not in valid_ref_modes:
            raise ValueError(
                f"alpha_beta_ref_mode must be one of {valid_ref_modes}, got {self.alpha_beta_ref_mode!r}"
            )
        # Every non-MSE surrogate divides by beta (huber's delta = 1/beta, LSC losses / beta).
        if self.surrogate_type != "mse" and self.beta <= 0:
            raise ValueError(f"beta must be > 0 for surrogate_type={self.surrogate_type!r}, got {self.beta}")


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-arg thunk that builds one env for ``gym.vector.SyncVectorEnv``.

    Args:
        env_id: Gymnasium environment id.
        seed: Seed for the env's action space.
        idx: Index of this env in the vector; only ``idx == 0`` records video.
        capture_video: Whether env 0 should be wrapped in ``RecordVideo``.
        run_name: Sub-directory under ``videos/`` for recordings.
    """
    # MountainCar-v0 gets a 1000-step episode cap. Apply it in BOTH branches so
    # turning on video recording does not change the environment's dynamics.
    make_kwargs = {"max_episode_steps": 1000} if env_id == "MountainCar-v0" else {}

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array", **make_kwargs)
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id, **make_kwargs)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


class QNetwork(nn.Module):
    """Two-hidden-layer MLP (120 -> 84, ReLU) producing one Q-value per discrete action.

    ``env`` must expose ``single_observation_space.shape`` and
    ``single_action_space.n`` (as a Gymnasium vector env does).
    """

    def __init__(self, env):
        super().__init__()
        in_features = int(np.prod(env.single_observation_space.shape))
        n_actions = env.single_action_space.n
        first_hidden, second_hidden = 120, 84
        layers = [
            nn.Linear(in_features, first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, n_actions),
        ]
        # Kept under the name `network` so state-dict keys stay "network.<i>.*".
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, obs_dim) batch of observations to (B, n_actions) Q-values."""
        q_values = self.network(x)
        return q_values


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Interpolate linearly from ``start_e`` to ``end_e`` over ``duration`` steps, then hold ``end_e``.

    Handles both decaying (``start_e > end_e``) and increasing schedules, and
    returns ``end_e`` immediately when ``duration <= 0`` (e.g.
    ``exploration_fraction == 0``) instead of dividing by zero.
    """
    if duration <= 0 or t >= duration:
        return end_e
    slope = (end_e - start_e) / duration
    return slope * t + start_e


# -------- Batch-LSC (soft L∞): (1/β) * log( mean_i cosh(β δ_i) ) ----------
def batch_lsc_loss(td_errors: torch.Tensor, beta: float) -> torch.Tensor:
    """Soft maximum of |δ| over the batch: ``(1/β) * log(mean_i cosh(β δ_i))``.

    Computed in shifted form so large ``β δ`` never overflows ``cosh``. Because
    ``log mean cosh(u) = m + log mean(cosh(u) e^{-m})`` for any constant ``m``,
    detaching the shift leaves the gradient unchanged.

    Args:
        td_errors: 1-D tensor of TD errors (B,).
        beta: Temperature; larger values approach ``max_i |δ_i|``.
    """
    scaled = beta * td_errors
    shift = scaled.abs().max().detach()
    # cosh(u) * e^{-shift}, with each exponent <= 0
    half_cosh = 0.5 * (torch.exp(scaled - shift) + torch.exp(-scaled - shift))
    return (shift + torch.log(half_cosh.mean())) / beta


# -------- Elementwise-LSC (smooth |·| family): mean_i [ (1/β²) * log cosh(β δ_i) ] ----------
def elemwise_lsc_loss(td_errors: torch.Tensor, beta: float) -> torch.Tensor:
    """Mean of ``log cosh(β δ_i) / β²`` — a smooth Huber-like loss.

    The 1/β² scale makes the loss ≈ δ²/2 for small ``β|δ|`` and ≈ |δ|/β for
    large ``β|δ|``, matching ``F.huber_loss(..., delta=1/β)`` asymptotically.
    """
    scaled_abs = (td_errors * beta).abs()
    # Overflow-free identity: log cosh(x) = |x| + log(1 + e^{-2|x|}) - log 2
    log_cosh = scaled_abs + torch.log1p(torch.exp(-2.0 * scaled_abs)) - np.log(2.0)
    return log_cosh.mean() / (beta**2)


class BetaScheduler:
    """Adaptive surrogate temperature following ``beta ≈ c / Ê``.

    ``Ê`` is an exponential moving average of the ``q``-quantile of |TD error|.
    If ``update_and_get`` receives both ``step`` and ``warmup_steps``, it runs a
    two-phase schedule:

    * Phase 1 returns the constant ``beta_warmup``.
    * It ends once ``step >= warmup_steps``, or once ``beta_warmup * Ê > c`` has
      held for ``patience`` consecutive calls.
    * Phase 2 uses the adaptive rule permanently.
    """

    def __init__(self, c: float, q: float, ema: float, beta_min: float, beta_max: float, eps: float):
        self.c = c                  # target value of beta * Ê
        self.q = q                  # quantile of |TD| used as the error scale
        self.ema = ema              # weight of the newest quantile in the EMA
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.eps = eps              # guards c / Ê against Ê == 0
        self.E_hat = None           # EMA state (tensor after the first update)
        self._switched = False      # True once Phase 2 is active
        self._violations = 0        # consecutive Phase-1 violations

    def _adaptive_beta(self, E: float) -> float:
        """Return ``clamp(c / (E + eps), beta_min, beta_max)`` as a Python float."""
        # The clamp goes through a default-dtype tensor, as the schedule always has.
        bounded = torch.clamp(torch.as_tensor(self.c / (E + self.eps)), self.beta_min, self.beta_max)
        return float(bounded.item())

    @torch.no_grad()
    def update_and_get(
        self,
        td_errors: torch.Tensor,
        step: int = None,           # <- optional: provide to enable two-phase
        warmup_steps: int = None,   # <- optional: provide to enable two-phase
        beta_warmup: float = 0.1,
        patience: int = 0           # require N consecutive violations to switch
    ) -> float:
        """Fold ``td_errors`` into Ê and return the beta to train with."""
        latest = torch.quantile(td_errors.detach().abs(), self.q)
        if self.E_hat is None:
            self.E_hat = latest
        else:
            self.E_hat = (1 - self.ema) * self.E_hat + self.ema * latest
        E = self.get_E_hat()

        two_phase = step is not None and warmup_steps is not None
        if two_phase and not self._switched:
            if step >= warmup_steps:
                self._switched = True
            elif (beta_warmup * E) > self.c:
                self._violations += 1
                self._switched = self._violations >= patience
            else:
                self._violations = 0  # the streak is broken
            if not self._switched:
                return float(beta_warmup)

        return self._adaptive_beta(E)

    def get_E_hat(self) -> float:
        """Current Ê as a Python float (0.0 before the first update)."""
        if self.E_hat is None:
            return 0.0
        if isinstance(self.E_hat, torch.Tensor):
            return float(self.E_hat.item())
        return float(self.E_hat)


# === Alpha watcher helpers (fixed targets over a validation set) ===
@torch.no_grad()
def make_fixed_targets(
    q_online_snap: nn.Module, q_target_snap: nn.Module, batch, gamma: float, use_double_q: bool
) -> torch.Tensor:
    """Build bootstrapped TD targets ``r + γ (1 - done) Q'(s', a*)`` for a replay batch.

    With ``use_double_q`` the next action ``a*`` is the online network's argmax,
    evaluated by the target network (Double DQN). Otherwise the target network's
    own max is used (vanilla DQN). No gradients are recorded.

    Returns:
        A (B,) tensor of targets.
    """
    successor = batch.next_observations
    if not use_double_q:
        bootstrap = q_target_snap(successor).max(dim=1).values
    else:
        greedy_actions = q_online_snap(successor).argmax(dim=1, keepdim=True)
        bootstrap = q_target_snap(successor).gather(1, greedy_actions).squeeze(1)
    alive = 1.0 - batch.dones.flatten()
    return batch.rewards.flatten() + gamma * bootstrap * alive


def loss_against_fixed_Y(
    q_net: nn.Module,
    obs: torch.Tensor,
    acts: torch.Tensor,
    Y_fixed: torch.Tensor,
    surrogate_type: str,
    beta: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Score ``q_net`` against frozen targets with the chosen surrogate.

    Args:
        q_net: Network mapping ``obs`` to (B, A) Q-values.
        obs: Observations fed to ``q_net``.
        acts: Taken actions, shape (B,) or (B, 1).
        Y_fixed: Frozen targets, shape (B,).
        surrogate_type: "batch_lsc", "elem_lsc", "huber"; anything else is MSE.
        beta: Surrogate temperature (huber uses ``delta = 1/beta``).

    Returns:
        ``(loss, td)`` where ``td = Q(s, a) - Y_fixed`` has shape (B,).
    """
    index = acts.view(-1, 1) if acts.ndim == 1 else acts
    q_pred = q_net(obs).gather(1, index.long()).squeeze(1)
    td = q_pred - Y_fixed

    if surrogate_type == "batch_lsc":
        return batch_lsc_loss(td, beta=beta), td
    if surrogate_type == "elem_lsc":
        return elemwise_lsc_loss(td, beta=beta), td
    if surrogate_type == "huber":
        return F.huber_loss(q_pred, Y_fixed, delta=1/beta), td
    return F.mse_loss(q_pred, Y_fixed), td


class AlphaWatcher:
    """Measures per-block progress ``alpha = L_end / L_start`` on a frozen validation set.

    Workflow:

    * ``start_block`` samples V_t from the replay buffer.
    * It freezes targets Y_t built from snapshot networks.
    * It records the surrogate loss of the snapshot online network (the baseline).
    * ``finalize_alpha`` re-scores the same (V_t, Y_t) with the current online
      network at the same ``beta_ref`` and returns the ratio.
    """

    def __init__(self):
        self.active = False        # True between start_block and finalize_alpha
        self.obs = None            # V_t observations
        self.acts = None           # V_t actions
        self.Y_fixed = None        # frozen targets Y_t, shape (B,)
        self.beta_ref = None       # beta used for both baseline and final loss
        self.baseline_loss = None  # loss of the block-start snapshot on (V_t, Y_t)

    @torch.no_grad()
    def start_block(
        self,
        q_online_snap: nn.Module,
        q_target_snap: nn.Module,
        rb,
        batch_size: int,
        gamma: float,
        surrogate_type: str,
        beta_ref: float,
        device: torch.device,
        use_double_q: bool,
    ):
        """Sample ``batch_size`` transitions, freeze their targets and record the baseline loss."""
        batch = rb.sample(batch_size)
        self.obs = batch.observations.to(device)
        self.acts = batch.actions.to(device)
        self.Y_fixed = make_fixed_targets(q_online_snap, q_target_snap, batch, gamma, use_double_q).to(device)
        self.beta_ref = float(beta_ref)
        loss0, _ = loss_against_fixed_Y(
            q_net=q_online_snap,  # baseline measured at block start snapshot
            obs=self.obs,
            acts=self.acts,
            Y_fixed=self.Y_fixed,
            surrogate_type=surrogate_type,
            beta=self.beta_ref,
        )
        self.baseline_loss = float(loss0.item())
        self.active = True

    @torch.no_grad()
    def finalize_alpha(self, q_online_current: nn.Module, surrogate_type: str) -> tuple[float, float]:
        """Close the current block.

        Returns:
            ``(alpha, max_abs_td)``.

            * ``alpha`` is ``loss(current) / baseline``; ``max_abs_td`` is the
              largest |TD| of the current network on V_t.
            * ``(1.0, 0.0)`` is returned when no block is active or the baseline
              is missing or non-positive, since alpha is undefined then.
        """
        if not self.active:
            return 1.0, 0.0
        if self.baseline_loss is None or self.baseline_loss <= 0:
            # Degenerate baseline: alpha is undefined. Still close the block so
            # the stale V_t / Y_t are not scored again.
            self.active = False
            return 1.0, 0.0
        loss_k, td = loss_against_fixed_Y(
            q_net=q_online_current,
            obs=self.obs,
            acts=self.acts,
            Y_fixed=self.Y_fixed,
            surrogate_type=surrogate_type,
            beta=self.beta_ref,
        )
        alpha = float(loss_k.item()) / self.baseline_loss
        max_abs_td = float(td.abs().max().item())
        # reset active so we don't accidentally reuse
        self.active = False
        return alpha, max_abs_td


@torch.no_grad()
def evaluate_policy(
    q_net: nn.Module,
    env_id: str,
    episodes: int,
    epsilon: float,
    device: torch.device,
    seed: Optional[int] = None,
) -> tuple[float, float, float]:
    """Run epsilon-greedy episodes on a fresh env and summarize them.

    Args:
        q_net: Network mapping a (1, obs_dim) float tensor to (1, n_actions) Q-values.
        env_id: Gymnasium id; ``MountainCar-v0`` gets a 1000-step cap, as in ``make_env``.
        episodes: Number of episodes. ``<= 0`` returns NaNs without creating an env.
        epsilon: Per-step probability of a uniformly random action.
        device: Device for the observation tensor.
        seed: Optional seed for the first reset, the env's action space and the
            epsilon coin. Without it, the env is unseeded and the coin uses the
            global ``random`` module.

    Returns:
        ``(mean return, std of returns, mean episode length)``.
    """
    if episodes <= 0:
        nan = float("nan")
        return nan, nan, nan

    # Only draw from an RNG when exploration is possible. With epsilon <= 0 no
    # draw happens, so greedy evaluation never advances the global `random`
    # stream that training-time exploration relies on.
    coin = random.Random(seed) if seed is not None else random
    env = gym.make(env_id, max_episode_steps=1000) if env_id == "MountainCar-v0" else gym.make(env_id)
    if seed is not None:
        env.action_space.seed(seed)

    returns = []
    lengths = []
    try:
        for ep in range(episodes):
            obs, _ = env.reset(seed=seed if ep == 0 else None)
            terminated = truncated = False
            ep_ret, ep_len = 0.0, 0
            while not (terminated or truncated):
                if epsilon > 0.0 and coin.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    q = q_net(torch.as_tensor(obs, device=device).float().unsqueeze(0))
                    action = int(torch.argmax(q, dim=1).item())
                obs, r, terminated, truncated, _ = env.step(action)
                ep_ret += float(r)
                ep_len += 1
            returns.append(ep_ret)
            lengths.append(ep_len)
    finally:
        env.close()

    returns = np.asarray(returns, dtype=np.float32)
    lengths = np.asarray(lengths, dtype=np.float32)
    return float(returns.mean()), float(returns.std()), float(lengths.mean())


def format_loss_tag(surrogate_type: str, beta: float) -> str:
    """Build the ``loss=...`` fragment of the run name.

    Matching is case-insensitive. Huber is labelled with ``delta = 1/beta``. An
    unrecognised type is echoed back verbatim.
    """
    kind = surrogate_type.lower()
    if kind == "huber":
        return f"loss=huber(delta={1/beta:g})"
    if kind == "elem_lsc":
        return f"loss=elem_lsc(beta={beta:g})"
    if kind == "batch_lsc":
        return f"loss=batch_lsc(beta={beta:g})"
    if kind == "mse":
        return "loss=mse"
    return f"loss={surrogate_type}"


if __name__ == "__main__":
    args = tyro.cli(Args)
    if args.num_envs != 1:
        raise ValueError("vectorized envs are not supported at the moment")

    # Human-readable run name like your Atari script
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    loss_tag = format_loss_tag(args.surrogate_type, args.beta)
    if args.adaptive_beta and args.surrogate_type != "mse":
        loss_tag += f"_beta=adaptive(C={args.beta_c:g},[{args.beta_min:g},{args.beta_max:g}])"
    if args.use_double_q:
        run_name = "__".join(["final_new_new", args.env_id, loss_tag, f"seed{args.seed}", timestamp])
    else:
        run_name = "__".join([args.env_id, loss_tag, f"seed{args.seed}", timestamp])

    if args.track:
        import wandb
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{k}|{v}|" for k, v in vars(args).items()])),
    )

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Env
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    if not isinstance(envs.single_action_space, gym.spaces.Discrete):
        raise ValueError("only discrete action space is supported")
    q_network = QNetwork(envs).to(device)
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())

    # Optimizer
    optimizer = optim.AdamW(q_network.parameters(), lr=args.learning_rate)

    # NOTE: handle_timeout_termination=False means the buffer does NO time-limit
    # correction; truncations are handled explicitly when storing transitions.
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=False,
    )

    start_time = time.time()
    obs, _ = envs.reset(seed=args.seed)

    beta_sched = None
    if args.adaptive_beta and args.surrogate_type != "mse":
        beta_sched = BetaScheduler(
            c=args.beta_c,
            q=args.beta_quantile,
            ema=args.beta_ema,
            beta_min=args.beta_min,
            beta_max=args.beta_max,
            eps=args.beta_eps,
        )
    # Two-phase beta schedule settings (fixed here; not exposed on Args).
    beta_warmup_steps = int(0.2 * args.total_timesteps)
    beta_warmup_value = 0.5
    beta_warmup_patience = 10

    # Loop invariants
    exploration_steps = int(args.exploration_fraction * args.total_timesteps)
    val_bs = args.batch_size * args.alpha_val_multiplier  # |V_t|

    # Alpha watcher setup
    alpha_watcher = AlphaWatcher()

    for global_step in range(args.total_timesteps):
        # Action selection
        epsilon = linear_schedule(args.start_e, args.end_e, exploration_steps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            with torch.no_grad():
                q_values = q_network(torch.as_tensor(obs, device=device).float())
                actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # Step
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # Logging episodic stats
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")

        # Store. Only true terminations cut bootstrapping; a time-limit truncation
        # must still bootstrap from the real final observation. The vector env has
        # already auto-reset, so `next_obs` holds the NEW episode's first obs for
        # any finished sub-env. Recover the true final obs from infos for truncations.
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
        obs = next_obs  # crucial

        # Learning
        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
            batch = rb.sample(args.batch_size)

            # --- bootstrapped target (no grad; same DQN/DDQN rule the watcher uses) ---
            td_target = make_fixed_targets(
                q_online_snap=q_network,
                q_target_snap=target_network,
                batch=batch,
                gamma=args.gamma,
                use_double_q=args.use_double_q,
            )

            # --- current Q for chosen actions ---
            q_pred_all = q_network(batch.observations)               # (B, A)
            acts = batch.actions.long()
            if acts.ndim == 1:
                acts = acts.view(-1, 1)
            q_pred = q_pred_all.gather(1, acts).squeeze(1)          # (B,)
            td_errors = q_pred - td_target

            # choose beta for TRAINING (can adapt every step)
            if beta_sched is not None:
                beta_train = beta_sched.update_and_get(
                    td_errors=td_errors,
                    step=global_step,
                    warmup_steps=beta_warmup_steps,
                    beta_warmup=beta_warmup_value,
                    patience=beta_warmup_patience,
                )
            else:
                beta_train = args.beta

            print(f"step={global_step} | beta_train={beta_train:.3f}", end="\r")
            if args.surrogate_type == "batch_lsc":
                loss = batch_lsc_loss(td_errors, beta=beta_train)
            elif args.surrogate_type == "elem_lsc":
                loss = elemwise_lsc_loss(td_errors, beta=beta_train)
            elif args.surrogate_type == "huber":
                loss = F.huber_loss(q_pred, td_target, delta=1/beta_train)
            else:  # "mse"
                loss = F.mse_loss(q_pred, td_target)

            # --- optimize ---
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(q_network.parameters(), 10.0)
            optimizer.step()

            # periodic logging
            if global_step % 100 == 0:
                train_abs = td_errors.detach().abs()
                writer.add_scalar("losses/td_loss", loss.item(), global_step)
                writer.add_scalar("losses/q_values", q_pred.mean().item(), global_step)
                sps = int(global_step / (time.time() - start_time))
                print("SPS:", sps)
                writer.add_scalar("charts/SPS", sps, global_step)
                writer.add_scalar("train/max_abs_td", train_abs.max().item(), global_step)
                writer.add_scalar("train/q95_abs_td", torch.quantile(train_abs, 0.95).item(), global_step)
                writer.add_scalar("train/q99_abs_td", torch.quantile(train_abs, 0.99).item(), global_step)
                writer.add_scalar("train/td_target_mean", td_target.mean().item(), global_step)
                writer.add_scalar("train/td_target_std",  td_target.std().item(), global_step)
                if beta_sched is not None:
                    writer.add_scalar("beta/value", beta_train, global_step)
                    writer.add_scalar("beta/E_hat", beta_sched.get_E_hat(), global_step)

            # === BLOCK BOUNDARY: finalize previous alpha, sync targets, and start next block ===
            if global_step % args.target_network_frequency == 0:
                # (1) Finalize alpha for the current block just BEFORE syncing targets
                if alpha_watcher.active:
                    alpha_final, max_abs_td_final = alpha_watcher.finalize_alpha(
                        q_online_current=q_network, surrogate_type=args.surrogate_type
                    )
                    writer.add_scalar("val/alpha_block_final", alpha_final, global_step)
                    writer.add_scalar("val/block_max_abs_td_final", max_abs_td_final, global_step)

                # (2) Sync target network (Polyak: target <- (1 - tau) * target + tau * online)
                with torch.no_grad():
                    for tp, p in zip(target_network.parameters(), q_network.parameters()):
                        tp.data.lerp_(p.data, args.tau)

                # (3) Snapshots used to build the next block's fixed targets and baseline
                q_online_snap = copy.deepcopy(q_network).eval()
                q_target_snap = copy.deepcopy(target_network).eval()

                # (4) beta_ref for scoring the next block. In "c_over_E" mode it is
                # derived in step (6) from the block's own V_t; a provisional value
                # is used only to initialise the watcher.
                derive_beta_ref = args.surrogate_type != "mse" and args.alpha_beta_ref_mode == "c_over_E"
                if args.surrogate_type == "mse":
                    beta_ref = float(args.beta)
                elif derive_beta_ref:
                    beta_ref = 1.0
                else:
                    beta_ref = float(args.alpha_beta_ref_const)

                # (5) Start next block: the watcher samples V_t, builds the fixed
                # targets Y_t (DDQN/DQN rule matched to training) and records the
                # snapshot's baseline loss.
                alpha_watcher.start_block(
                    q_online_snap=q_online_snap,
                    q_target_snap=q_target_snap,
                    rb=rb,
                    batch_size=val_bs,
                    gamma=args.gamma,
                    surrogate_type=args.surrogate_type,
                    beta_ref=beta_ref,
                    device=device,
                    use_double_q=args.use_double_q,
                )

                # (6) c/Ê mode: estimate Ê on the SAME V_t / Y_t the watcher will
                # score, then re-measure the baseline at that beta so alpha is a
                # ratio of two losses taken at one beta.
                if derive_beta_ref:
                    with torch.no_grad():
                        _, td_v = loss_against_fixed_Y(
                            q_net=q_online_snap,
                            obs=alpha_watcher.obs,
                            acts=alpha_watcher.acts,
                            Y_fixed=alpha_watcher.Y_fixed,
                            surrogate_type="mse",
                            beta=1.0,
                        )
                        E_q = torch.quantile(td_v.abs(), args.beta_quantile)
                        raw = torch.clamp(args.beta_c / (E_q + args.beta_eps), args.beta_min, args.beta_max)
                        beta_ref = float(raw.item())
                        loss0, _ = loss_against_fixed_Y(
                            q_net=q_online_snap,
                            obs=alpha_watcher.obs,
                            acts=alpha_watcher.acts,
                            Y_fixed=alpha_watcher.Y_fixed,
                            surrogate_type=args.surrogate_type,
                            beta=beta_ref,
                        )
                    alpha_watcher.beta_ref = beta_ref
                    alpha_watcher.baseline_loss = float(loss0.item())
                writer.add_scalar("beta/ref", beta_ref, global_step)

        # === Greedy episodic evaluation on cadence ===
        if (
            global_step >= args.learning_starts
            and args.eval_frequency > 0
            and global_step % args.eval_frequency == 0
        ):
            with torch.no_grad():
                avg_ret, std_ret, avg_len = evaluate_policy(
                    q_net=q_network,
                    env_id=args.env_id,
                    episodes=args.eval_episodes,
                    epsilon=args.eval_epsilon,
                    device=device,
                )
            writer.add_scalar("eval/episodic_return_greedy", avg_ret, global_step)
            writer.add_scalar("eval/episodic_return_std", std_ret, global_step)
            writer.add_scalar("eval/episodic_length_mean", avg_len, global_step)

    # === Optional: save & eval after training ===
    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save(q_network.state_dict(), model_path)
        print(f"model saved to {model_path}")
        # quick greedy eval using our helper
        avg_ret, std_ret, avg_len = evaluate_policy(
            q_net=q_network,
            env_id=args.env_id,
            episodes=max(10, args.eval_episodes),
            epsilon=0.0,
            device=device,
        )
        writer.add_scalar("eval/episodic_return_greedy_final", avg_ret, args.total_timesteps)
        writer.add_scalar("eval/episodic_return_std_final", std_ret, args.total_timesteps)
        writer.add_scalar("eval/episodic_length_mean_final", avg_len, args.total_timesteps)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub
            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            # Note: push_to_hub expects videos etc.; this is optional
            push_to_hub(args, [avg_ret], repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
