import os
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import wandb

from networks import Actor, Critic_r

class AR_A2C_Agent:
    """Average-Reward A2C with fixed anchor s_ref = [1, 0, 0].

    Single-environment, single-step (no batching) differential actor-critic:

    * the critic learns a differential value h(s) from the TD target
      ``r - rho + (1 - terminal) * h(s')``;
    * ``rho`` is a running average-reward estimate, nudged along the clipped
      TD error after every step;
    * every ``anchor_every`` steps the critic's output bias is shifted so that
      ``h(s_ref) == 0`` (RVI-style anchoring).

    NOTE(review): ``s_ref`` and the heatmap grid are 3-dim
    ``[cos(theta), sin(theta), theta_dot]`` states, so the env is assumed to
    have a 3-dim (Pendulum-like) observation — confirm against callers.
    """

    def __init__(self, env, seed=777,
                 actor_lr=1e-5, critic_lr=7e-5,
                 rho_lr=1e-4, rho_clip=5.0,
                 hidden=256, blocks=2,
                 wandb_project="discounted_a2c", run_name=None, log_every=100,
                 anchor_every=1):
        """Build networks, optimizers and (optionally) a wandb run.

        Args:
            env: Gymnasium-style env whose ``observation_space`` and
                ``action_space`` expose ``shape``.
            seed: seed for the first ``env.reset`` in ``train``; also used in
                the default wandb run name.
            actor_lr, critic_lr: Adam learning rates.
            rho_lr: step size of the average-reward estimate ``rho``.
            rho_clip: the TD error is clamped to ``[-rho_clip, rho_clip]``
                before it updates ``rho``.
            hidden, blocks: forwarded to ``Critic_r``.
            wandb_project: wandb project name; empty string or None disables
                wandb logging.
            run_name: wandb run name (default ``ar-a2c-anchor-fixed_<seed>``).
            log_every: log scalar diagnostics every this many steps
                (``<= 0`` disables them).
            anchor_every: re-anchor the critic every this many steps
                (``<= 0`` disables anchoring).
        """
        self.env = env
        self.seed = seed
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]

        self.actor = Actor(obs_dim, act_dim).to(self.device)
        self.critic = Critic_r(obs_dim, hidden=hidden, num_blocks=blocks).to(self.device)

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # Average-reward estimate rho, shape [1].
        self.rho = torch.zeros(1, dtype=torch.float32, device=self.device)
        self.rho_lr = rho_lr
        self.rho_clip = float(rho_clip)

        # Exponential moving statistics used to normalize the scalar TD error
        # (single-step updates, so there is no batch to normalize over).
        self.ema_mean = 0.0
        self.ema_var = 1.0
        self.ema_beta = 0.99

        self.total_step = 0
        self.is_test = False

        # Fixed anchor: s_ref = [cos 0, sin 0, dtheta=0] = [1, 0, 0]
        self.s_ref = torch.tensor([[1.0, 0.0, 0.0]], dtype=torch.float32, device=self.device)
        self.anchor_every = int(anchor_every)

        # wandb: an empty (or None) project name disables logging entirely.
        self.use_wandb = bool(wandb_project)
        if self.use_wandb:
            wandb.init(project=wandb_project,
                       name=run_name or f"ar-a2c-anchor-fixed_{seed}",
                       config=dict(seed=seed, actor_lr=actor_lr, critic_lr=critic_lr,
                                   rho_lr=rho_lr, rho_clip=rho_clip,
                                   hidden=hidden, blocks=blocks,
                                   anchor_every=self.anchor_every,
                                   setting="avg_reward_single_env_single_step_anchor_fixed"))

        self.log_every = log_every

    @staticmethod
    def _done_flag(terminated, truncated):
        """Return 1.0 if the episode ended for any reason, else 0.0."""
        return float(terminated or truncated)

    def _step_env(self, action):
        """Step the env and return ``(next_state, reward, done)``.

        ``done`` merges termination and truncation, so it only says whether
        the episode is over. It must not be used to cut bootstrapping; use
        ``_step_env_split`` for that.
        """
        next_state, reward, terminated, truncated, _ = self.env.step(action)
        return next_state, reward, self._done_flag(terminated, truncated)

    def _step_env_split(self, action):
        """Step the env and return ``(next_state, reward, terminal, done)``.

        ``terminal`` (float) is 1.0 only on true termination and gates
        bootstrapping. ``done`` (float) is 1.0 on termination OR truncation
        and signals that the env must be reset.
        """
        next_state, reward, terminated, truncated, _ = self.env.step(action)
        return next_state, reward, float(terminated), self._done_flag(terminated, truncated)

    def select_action(self, state):
        """Pick an action for ``state``.

        Returns ``(action_ndarray, log_prob)``. In test mode the actor's
        deterministic action is used (computed without autograd) and
        ``log_prob`` is a zero tensor of shape [1]. Otherwise ``log_prob`` is
        whatever ``Actor.forward`` returns and keeps its graph for the policy
        update.
        """
        s = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        if self.is_test:
            with torch.no_grad():  # evaluation only: no graph needed
                a = self.actor.deterministic(s)
            logp = torch.zeros(1, device=self.device)
        else:
            a, logp = self.actor(s)
        return a.squeeze(0).detach().cpu().numpy(), logp  # logp:[1]

    # --- helpers ---
    def _ema_norm(self, x: float) -> float:
        """Update the moving mean/variance with scalar ``x``; return ``x`` standardized.

        The variance term uses the already-updated mean and is floored at 1e-12.
        """
        m, v, b = self.ema_mean, self.ema_var, self.ema_beta
        m = b * m + (1 - b) * x
        v = b * v + (1 - b) * (x - m) ** 2
        self.ema_mean, self.ema_var = m, max(v, 1e-12)
        return (x - m) / (np.sqrt(self.ema_var) + 1e-8)

    @torch.no_grad()
    def _anchor_critic(self) -> float:
        """RVI-style anchoring: enforce h(s_ref) = 0 by shifting the final bias.

        Returns h(s_ref) as measured before the shift.
        NOTE(review): assumes ``Critic_r`` exposes its final linear layer as
        ``.out`` with no nonlinearity after it, so that the bias shift is a
        uniform shift of h — confirm in networks.py.
        """
        h_ref = self.critic(self.s_ref).squeeze(-1)  # [1]
        self.critic.out.bias.add_(-h_ref)            # global shift
        return float(h_ref.item())

    # --- updates ---
    def update_critic(self, state, reward, next_state, done):
        """One semi-gradient TD step on the differential value h.

        Target: ``r - rho + (1 - done) * h(s')`` (no gradient through it).
        Loss: Smooth-L1 between ``h(s)`` and the target, with the gradient
        norm clipped to 1. Every ``anchor_every`` steps the critic is
        re-anchored so that h(s_ref) = 0.

        Args:
            done: 1.0 when ``next_state`` is a true terminal (no bootstrap),
                else 0.0. ``train`` passes the termination flag here, not
                ``terminated or truncated``.

        Returns:
            ``(value_loss, anchor_val)``, where ``anchor_val`` is h(s_ref)
            before the shift, or 0.0 if no anchoring happened this step.
        """
        s  = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        s2 = torch.as_tensor(next_state, dtype=torch.float32, device=self.device).unsqueeze(0)
        r  = torch.tensor([reward], dtype=torch.float32, device=self.device)
        mask = torch.tensor([1.0 - done], dtype=torch.float32, device=self.device)

        with torch.no_grad():
            target = r - self.rho + mask * self.critic(s2).squeeze(-1)  # [1]

        pred = self.critic(s).squeeze(-1)  # [1]
        value_loss = F.smooth_l1_loss(pred, target)

        self.critic_opt.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 1.0)
        self.critic_opt.step()

        anchor_val = 0.0
        if (self.anchor_every > 0) and (self.total_step % self.anchor_every == 0):
            anchor_val = self._anchor_critic()

        return float(value_loss.item()), anchor_val

    def update_actor_and_rho(self, state, log_prob, reward, next_state, done):
        """Policy-gradient step with the normalized TD error, then update rho.

        The TD error ``delta = r - rho + (1 - done) * h(s') - h(s)`` is
        computed with the current critic. It is EMA-normalized to form the
        advantage for the actor. Its clamped value (``[-rho_clip, rho_clip]``)
        moves ``rho`` by ``rho_lr * delta_clip``.

        Args:
            log_prob: log-probability tensor from ``select_action`` (carries
                the actor's graph).
            done: 1.0 when ``next_state`` is a true terminal, else 0.0.

        Returns:
            ``(policy_loss, delta_raw, delta_norm, rho)`` as Python floats.
        """
        s  = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        s2 = torch.as_tensor(next_state, dtype=torch.float32, device=self.device).unsqueeze(0)
        r  = torch.tensor([reward], dtype=torch.float32, device=self.device)
        mask = torch.tensor([1.0 - done], dtype=torch.float32, device=self.device)

        with torch.no_grad():
            v_s  = self.critic(s).squeeze(-1)   # [1]
            v_s2 = self.critic(s2).squeeze(-1)  # [1]
            delta = r - self.rho + mask * v_s2 - v_s  # [1]
            delta_raw = float(delta.item())
            delta_norm = self._ema_norm(delta_raw)
            adv = torch.tensor([delta_norm], dtype=torch.float32, device=self.device)  # [1]

        policy_loss = -(adv.detach() * log_prob).sum()

        self.actor_opt.zero_grad()
        policy_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_opt.step()

        with torch.no_grad():
            # Light clipping of the TD error keeps rho stable; the bound is the
            # configurable rho_clip (previously hard-coded to 5.0).
            delta_clip = torch.clamp(delta, -self.rho_clip, self.rho_clip)
            self.rho.add_(self.rho_lr * delta_clip)

        return float(policy_loss.item()), delta_raw, float(delta_norm), float(self.rho.item())

    # --- viz ---
    def generate_v_heatmap(self):
        """Plot the critic over a 250x250 grid and save it as a PNG.

        Grid: theta in [-pi, pi], theta_dot in [-8, 8], encoded as states
        ``[cos(theta), sin(theta), theta_dot]``. The figure is written to
        ``heatmaps_ar/ar_v_heatmap_step_<total_step>.png``, and that path is
        returned.
        """
        theta = np.linspace(-np.pi, np.pi, 250)
        dtheta = np.linspace(-8.0, 8.0, 250)
        TH, DW = np.meshgrid(theta, dtheta)
        grid = np.stack([np.cos(TH), np.sin(TH), DW], axis=-1).reshape(-1, 3)
        states = torch.as_tensor(grid, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            v = self.critic(states).squeeze(-1).cpu().numpy().reshape(250, 250)  # anchored h(s)
        os.makedirs('heatmaps_ar', exist_ok=True)
        fig_path = f'heatmaps_ar/ar_v_heatmap_step_{self.total_step}.png'
        plt.figure(figsize=(7, 6), dpi=140)
        plt.pcolormesh(theta, dtheta, v, cmap='viridis', shading='auto')
        plt.colorbar(label='h(s) (anchored @ [1,0,0])')
        plt.xlabel('theta (rad)'); plt.ylabel('theta_dot (rad/s)')
        plt.title(f'Average-Reward A2C (anchored) — step {self.total_step}')
        plt.tight_layout(); plt.savefig(fig_path); plt.close()
        return fig_path

    # --- train/test ---
    def train(self, num_frames: int, heatmap_every: int = 50_000):
        """Run ``num_frames`` env steps, with one critic and one actor/rho update per step.

        Bootstrapping is cut only on true termination. A time-limit truncation
        still resets the env but keeps ``h(s')`` in the TD target, because in
        a continuing average-reward task the truncated state is not terminal.

        Args:
            num_frames: number of environment steps.
            heatmap_every: save/log a value heatmap every this many steps
                (``<= 0`` disables heatmaps).
        """
        self.is_test = False
        state, _ = self.env.reset(seed=self.seed)
        ep_ret, ep_len = 0.0, 0
        for self.total_step in range(1, num_frames + 1):
            action, log_prob = self.select_action(state)
            next_state, reward, terminal, done = self._step_env_split(action)

            # Only true termination zeroes the bootstrap term.
            critic_loss, anchor_val = self.update_critic(state, reward, next_state, terminal)
            actor_loss, delta_raw, delta_norm, rho_val = self.update_actor_and_rho(
                state, log_prob, reward, next_state, terminal
            )

            ep_ret += reward
            ep_len += 1
            state = next_state
            if done:
                if self.use_wandb:
                    wandb.log({"episode_reward": ep_ret, "episode_length": ep_len, "steps": self.total_step})
                state, _ = self.env.reset()
                ep_ret, ep_len = 0.0, 0

            if self.use_wandb and self.log_every > 0 and (self.total_step % self.log_every == 0):
                wandb.log({
                    "actor_loss": actor_loss, "critic_loss": critic_loss,
                    "delta_raw": delta_raw, "delta_norm": delta_norm,
                    "rho": rho_val, "anchor_value_pre_shift": anchor_val,
                    "steps": self.total_step
                })

            if heatmap_every > 0 and (self.total_step % heatmap_every) == 0:
                fig_path = self.generate_v_heatmap()
                if self.use_wandb:
                    wandb.log({"V_heatmap": wandb.Image(fig_path), "steps": self.total_step})

        if self.use_wandb:
            wandb.finish()

    def test(self, video_folder: str):
        """Roll out one deterministic episode, recording video to ``video_folder``.

        Prints the episode's summed reward. ``self.env`` is restored afterwards,
        even if the rollout raises.
        NOTE(review): closing the RecordVideo wrapper presumably also closes
        the wrapped env (standard Gymnasium wrapper behavior) — confirm that
        the restored env is still usable afterwards.
        """
        self.is_test = True
        base_env = self.env
        self.env = gym.wrappers.RecordVideo(base_env, video_folder=video_folder)
        try:
            state, _ = self.env.reset()
            score, done = 0.0, False
            while not done:
                action, _ = self.select_action(state)
                state, reward, done = self._step_env(action)
                score += reward
            print(f"[test] episode reward sum = {score:.2f}")
        finally:
            self.env.close()
            self.env = base_env
