# algorithms/finsler_actor_critic.py
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from algorithms import cvar_utils

class Actor(nn.Module):
    """Gaussian policy network with a state-independent log standard deviation.

    ``forward`` returns the Gaussian parameters ``(mean, std)`` for a state
    (or batch of states); ``sample_action`` draws from that Gaussian and
    squashes the draw with ``tanh`` so every component lies within [-1, 1].
    """

    def __init__(self, state_dim, action_dim, hidden_sizes=(256,256)):
        super().__init__()
        widths = [state_dim, *hidden_sizes]
        stack = []
        # Hidden layers: Linear followed by ReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Final affine layer produces the per-dimension policy mean.
        stack.append(nn.Linear(widths[-1], action_dim))
        self.mean_head = nn.Sequential(*stack)
        # One learnable log-std per action dimension, shared by all states;
        # initialised to 0 so the initial std is 1.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, x):
        """Return ``(mean, std)`` of the (pre-squash) Gaussian policy at ``x``."""
        return self.mean_head(x), self.log_std.exp()

    def sample_action(self, x):
        """Sample from the Gaussian at ``x`` and squash the draw with tanh.

        Assumes the environment's action space is normalised to [-1, 1].
        """
        mean, std = self.forward(x)
        raw_draw = torch.distributions.Normal(mean, std).sample()
        return raw_draw.tanh()

class CvarValueNetwork(nn.Module):
    """MLP critic mapping a state to one scalar: its estimated CVaR value.

    The input's last dimension must be ``state_dim``; the output has the same
    leading shape with the trailing unit dimension removed.
    """

    def __init__(self, state_dim, hidden_sizes=(256,256)):
        super().__init__()
        widths = [state_dim, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Scalar read-out layer.
        modules.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # Drop the trailing singleton so a (batch, state_dim) input yields (batch,).
        scalar_column = self.net(x)
        return scalar_column.squeeze(-1)
class FinslerActorCritic:
    """Risk-sensitive (CVaR) actor-critic formulated on per-step *costs*.

    The critic regresses onto ``y_i = F(x_i, u_i) + gamma * rho_alpha[V(x_{i+1})]``.
    Here ``F = -reward`` is the per-step cost, and the risk term comes from
    ``cvar_utils.get_next_values``. ``V`` therefore estimates a cost-to-go:
    lower is better. The actor is a tanh-squashed Gaussian policy updated with
    a PPO-style clipped surrogate that lowers the probability of actions whose
    cost-advantage is positive.
    """

    def __init__(self, state_dim, action_dim, cvar_alpha=0.1, quantile_mode=False, quantiles=50, lr=3e-4, gamma=0.99, device=None):
        self.device = device or torch.device('cpu')
        self.alpha = cvar_alpha  # CVaR level (e.g., 0.1 for CVaR_{0.1})
        self.gamma = gamma
        self.quantile_mode = quantile_mode
        # Networks
        self.actor = Actor(state_dim, action_dim).to(self.device)
        if quantile_mode:
            # Distributional critic: outputs multiple quantiles.
            self.critic = cvar_utils.QuantileCritic(state_dim, quantiles=quantiles).to(self.device)
        else:
            # Standard CVaR critic (single value output).
            self.critic = CvarValueNetwork(state_dim).to(self.device)
        # Optimizers
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=lr)
        # PPO-style update settings
        self.clip_ratio = 0.2
        self.value_coef = 1.0   # weight for critic loss (currently not read by update)
        self.entropy_coef = 0.0  # weight for entropy regularization
        # Off-policy reuse
        self.max_grad_norm = 0.5  # gradient-norm clip for both networks
        self.kl_penalty = 0.0     # weight for the KL(pi_old || pi_new) regularizer
        # Frozen copy of the actor taken just *before* the previous actor step
        # (only maintained while kl_penalty > 0); serves as pi_old for the KL.
        self.last_actor_params = None

    def sample_trajectory(self, env, max_steps=1000):
        """Roll out the current policy for up to ``max_steps`` steps or until ``done``.

        Assumes the classic Gym API: ``env.reset()`` returns an observation and
        ``env.step(a)`` returns ``(obs, reward, done, info)``.

        Returns:
            ``(states, actions, rewards, dones)`` as parallel lists. The stored
            actions are the tanh-squashed values that were sent to ``env``.
        """
        states, actions, rewards, dones = [], [], [], []
        state = env.reset()
        for _ in range(max_steps):
            state_t = torch.tensor(state, dtype=torch.float32, device=self.device)
            # Acting needs no autograd graph.
            with torch.no_grad():
                action = self.actor.sample_action(state_t).cpu().numpy()
            next_state, reward, done, _info = env.step(action)
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            state = next_state
            if done:
                break
        return states, actions, rewards, dones

    def estimate_cvar_returns(self, rewards, dones):
        """Compute the discounted return-to-go at every timestep.

        The running sum restarts at each step flagged ``done``, so several
        concatenated episodes are handled. Each episode is treated as one
        sample of return from the initial-state distribution; the CVaR across
        episodes must be estimated by the caller after collecting many episodes.

        Returns:
            ``(returns, episode_return)``, where ``returns[t]`` is the return-to-go
            from step ``t``. ``episode_return`` is ``returns[0]``, or ``0.0`` for
            an empty trajectory.
        """
        returns = []
        running = 0.0
        # Walk backwards, then reverse once: O(n) instead of repeated insert(0, ...).
        for reward, done in zip(reversed(rewards), reversed(dones)):
            running = reward + (0 if done else self.gamma * running)
            returns.append(running)
        returns.reverse()
        episode_return = returns[0] if returns else 0.0
        return returns, episode_return

    @staticmethod
    def _normalize(adv, eps=1e-8):
        """Standardise advantages; with fewer than 2 samples only centre them (std would be NaN)."""
        centered = adv - adv.mean()
        if adv.numel() < 2:
            return centered
        return centered / (adv.std() + eps)

    @staticmethod
    def _unsquash_actions(actions, eps=1e-6):
        """Invert tanh squashing of stored actions; clamping keeps atanh finite at +/-1."""
        return torch.atanh(actions.clamp(-1.0 + eps, 1.0 - eps))

    def update(self, trajectories, next_value_estimates=None):
        """Do one critic step and one actor step on a batch of trajectories.

        Args:
            trajectories: non-empty list of dicts with keys 'states', 'actions',
                and 'rewards'. 'dones' is consumed (if at all) by
                ``cvar_utils.get_next_values``. Actions are expected to be
                tanh-squashed, as produced by ``sample_trajectory``.
            next_value_estimates: accepted for interface compatibility; not
                currently used.

        Returns:
            ``(value_loss, policy_loss, entropy, mean_normalized_cost_advantage)``
            as Python floats.

        Raises:
            ValueError: if ``trajectories`` is empty.
        """
        if not trajectories:
            raise ValueError("update() requires at least one trajectory")

        # Flatten all trajectories into arrays, then tensors.
        states = np.concatenate([traj['states'] for traj in trajectories], axis=0)
        actions = np.concatenate([traj['actions'] for traj in trajectories], axis=0)
        rewards = np.concatenate([traj['rewards'] for traj in trajectories], axis=0)
        state_tensors = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        action_tensors = torch.as_tensor(actions, dtype=torch.float32, device=self.device)
        rewards_tensors = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)

        # ---- Critic ---------------------------------------------------------
        # NOTE(review): in quantile_mode the critic output shape is defined in
        # cvar_utils; the arithmetic below assumes values/targets broadcast to
        # (batch,) — confirm.
        values = self.critic(state_tensors)
        with torch.no_grad():
            # Risk-adjusted next-state values. Presumably terminal handling
            # (dones) happens inside get_next_values; it is not applied here.
            next_values = cvar_utils.get_next_values(self.critic, trajectories, self.device, self.alpha)
            # Bellman target (Equation 6): y_i = F(x_i, u_i) + gamma * rho_alpha[V(x_{i+1})],
            # with cost F = -reward by design of the reward wrapper.
            costs = -rewards_tensors
            targets = costs + self.gamma * next_values
            # Cost advantage: positive means the action was WORSE than expected.
            advantages = self._normalize(targets - values)
        value_loss = 0.5 * (values - targets).pow(2).mean()
        self.critic_opt.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_opt.step()

        # ---- Actor ----------------------------------------------------------
        mean, std = self.actor(state_tensors)
        dist = torch.distributions.Normal(mean, std)
        # Stored actions are tanh(sample); score the pre-squash sample. The tanh
        # Jacobian term does not depend on the policy parameters, so it cancels
        # out of the gradient and is omitted.
        raw_actions = self._unsquash_actions(action_tensors)
        log_probs = dist.log_prob(raw_actions).sum(dim=-1)
        # With a single gradient step on fresh data pi_old == pi_theta, so the
        # ratio is 1 numerically. Its gradient is still grad(log pi), i.e. this
        # reduces to a vanilla policy-gradient step and clipping is inactive.
        ratio = torch.exp(log_probs - log_probs.detach())
        # Costs are minimised, so minimise E[A_cost * ratio]. This is the PPO
        # pessimistic bound with reward-advantage = -A_cost:
        # max(A*r, A*clip(r)) == -min(-A*r, -A*clip(r)).
        surrogate = advantages * ratio
        surrogate_clipped = advantages * torch.clamp(ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio)
        policy_loss = torch.max(surrogate, surrogate_clipped).mean()
        # Optional entropy bonus (entropy of the pre-squash Gaussian).
        entropy = dist.entropy().sum(dim=-1).mean()
        policy_loss = policy_loss - self.entropy_coef * entropy
        # Trust-region-style KL(pi_old || pi_new) on the CURRENT batch's states,
        # where pi_old is the actor from before the previous step.
        if self.kl_penalty > 0.0 and self.last_actor_params is not None:
            with torch.no_grad():
                old_mean, old_std = self.last_actor_params(state_tensors)
            old_dist = torch.distributions.Normal(old_mean, old_std)
            kl_div = torch.distributions.kl_divergence(old_dist, dist).sum(dim=-1).mean()
            policy_loss = policy_loss + self.kl_penalty * kl_div

        # Snapshot BEFORE stepping: a post-step snapshot would equal the policy
        # at the start of the next update and give a zero KL gradient.
        if self.kl_penalty > 0.0:
            self.last_actor_params = copy.deepcopy(self.actor).requires_grad_(False)

        self.actor_opt.zero_grad()
        policy_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.actor_opt.step()

        return value_loss.item(), policy_loss.item(), entropy.item(), advantages.mean().item()
