"""
A2C-GNN End-to-End
------------------
The actor outputs per-node pax/reb concentrations; A2C.select_action turns them into paxAction and rebAction for all edges.
"""

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Dirichlet
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from collections import namedtuple

# Record stored for every sampled action: its log-probability and the critic's value estimate.
SavedAction = namedtuple("SavedAction", "log_prob value")
# float32 machine epsilon as a plain Python float; added to the std when returns are normalized.
eps = float(np.finfo(np.float32).eps)


#########################################
############## PARSER ###################
#########################################
class GNNParser:
    """Convert a raw environment observation into a graph ``Data`` object.

    Every region in ``env.region`` becomes one node with three features, each
    multiplied by ``scale_factor``:

    1. ``acc[n][t]`` - the value of ``acc`` for the region at the current time,
    2. the sum of ``dacc[n][tau]`` over ``tau`` in ``[t, t + T)`` (missing
       time steps are skipped),
    3. the sum of ``demand[n, j][t]`` over all regions ``j`` (missing pairs or
       time steps contribute nothing).
    """

    def __init__(self, env, T=6, scale_factor=0.01):
        """Store the environment, the look-ahead window ``T`` and the feature scale."""
        self.env = env
        self.T = T
        self.s = scale_factor

    def parse_obs(self, obs, edge_index):
        """Build the node-feature matrix for ``obs`` and wrap it in ``Data``.

        Args:
            obs: 4-tuple ``(acc, t, dacc, demand)`` as unpacked below.
            edge_index: graph connectivity, passed through unchanged.

        Returns:
            ``Data`` whose ``x`` is a float32 tensor with one 3-feature row per
            region, in ``env.region`` order.
        """
        acc, t, dacc, demand = obs

        rows = []
        for n in self.env.region:
            available = acc[n][t]
            incoming = sum(
                dacc[n][tau] for tau in range(t, t + self.T) if tau in dacc[n]
            )
            outgoing = sum(
                demand[n, j][t]
                for j in self.env.region
                if (n, j) in demand and t in demand[n, j]
            )
            rows.append([available * self.s, incoming * self.s, outgoing * self.s])

        node_features = torch.tensor(rows, dtype=torch.float32)
        return Data(x=node_features, edge_index=edge_index)


#########################################
############## ACTOR ####################
#########################################
class GNNActor(nn.Module):
    """Policy network: one graph convolution followed by a two-layer MLP per node.

    The final layer emits two raw (unconstrained) values per node, one for the
    passenger ("pax") share and one for the rebalancing ("reb") share.
    """

    def __init__(self, in_channels, hidden_size):
        """Create the layers.

        Args:
            in_channels: number of input features per node.
            hidden_size: output width of the graph convolution.
        """
        super().__init__()
        mlp_width = 64
        self.conv1 = GCNConv(in_channels, hidden_size)
        self.lin1 = nn.Linear(hidden_size, mlp_width)
        # Two outputs per node: pax / reb
        self.lin2 = nn.Linear(mlp_width, 2)

    def forward(self, data):
        """Return the per-node outputs for ``data`` (reads ``data.x`` and ``data.edge_index``)."""
        node_emb = self.conv1(data.x, data.edge_index).relu()
        hidden = self.lin1(node_emb).relu()
        return self.lin2(hidden)  # [|V|, 2]


#########################################
############## CRITIC ###################
#########################################
class GNNCritic(nn.Module):
    """Value network: graph convolution, sum-pooling over nodes, then a two-layer MLP.

    Produces a single value estimate for the whole graph.
    """

    def __init__(self, in_channels, hidden_size):
        """Create the layers.

        Args:
            in_channels: number of input features per node.
            hidden_size: output width of the graph convolution.
        """
        super().__init__()
        mlp_width = 64
        self.conv1 = GCNConv(in_channels, hidden_size)
        self.lin1 = nn.Linear(hidden_size, mlp_width)
        self.lin2 = nn.Linear(mlp_width, 1)

    def forward(self, data):
        """Return the value estimate for ``data`` (reads ``data.x`` and ``data.edge_index``)."""
        node_emb = self.conv1(data.x, data.edge_index).relu()
        # Global graph pooling: collapse the node dimension by summation.
        pooled = node_emb.sum(dim=0)
        hidden = self.lin1(pooled).relu()
        return self.lin2(hidden)


#########################################
############## A2C ######################
#########################################
class A2C(nn.Module):
    """Advantage actor-critic agent with separate GNN actor and critic.

    The actor produces, for every node, two positive concentration parameters
    of a Dirichlet distribution over (passenger share, rebalancing share).
    ``select_action`` samples those shares, converts them into integer vehicle
    counts per node and spreads the counts over each node's outgoing edges.
    ``training_step`` performs one policy-gradient update from the buffered
    ``saved_actions`` / ``rewards`` and then clears both buffers.
    """

    # Fractions of a node's available vehicles that the sampled pax / reb
    # shares are applied to before rounding.
    PAX_FRACTION = 0.9
    REB_FRACTION = 0.4

    def __init__(self, env, input_size, edge_index, hidden_size=64, device="cpu"):
        """Build actor, critic, observation parser and one Adam optimizer per network.

        Args:
            env: environment; this class reads ``env.region``, ``env.edges``,
                ``env.acc`` and ``env.time``.
            input_size: number of input features per node.
            edge_index: graph connectivity handed to the parser on every step.
            hidden_size: width of the graph-convolution layer in both networks.
            device: torch device the networks and observations are moved to.
        """
        super().__init__()
        self.env = env
        self.device = device
        self.edge_index = edge_index

        self.actor = GNNActor(input_size, hidden_size).to(device)
        self.critic = GNNCritic(input_size, hidden_size).to(device)
        self.parser = GNNParser(env)

        self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=1e-3)
        self.optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=1e-3)

        # Episode buffers, filled by select_action() and by the caller (rewards).
        self.saved_actions = []
        self.rewards = []

    def forward(self, obs):
        """Parse ``obs`` and return ``(concentrations, value)``.

        The actor output goes through softplus plus a small constant so every
        concentration is strictly positive, as the Dirichlet requires.
        """
        data = self.parser.parse_obs(obs, self.edge_index).to(self.device)
        a_out = F.softplus(self.actor(data)) + 1e-6  # [|V|, 2]
        v_out = self.critic(data)
        return a_out, v_out

    def select_action(self, obs):
        """Sample an action for ``obs`` and translate it into edge-level flows.

        Side effect: appends ``SavedAction(total_log_prob, value)`` to
        ``self.saved_actions`` for the next ``training_step``.

        Returns:
            ``(paxAction, rebAction)``: two integer numpy arrays with one entry
            per element of ``env.edges``.
        """
        a_out, value = self.forward(obs)

        # One independent 2-way Dirichlet per node, handled as a single batch
        # instead of a Python loop; the joint log-probability is the sum.
        policy = Dirichlet(a_out)
        action = policy.sample()  # [num_nodes, 2], each row sums to 1
        total_log_prob = policy.log_prob(action).sum()
        self.saved_actions.append(SavedAction(total_log_prob, value))

        action_np = action.detach().cpu().numpy()

        # Vehicles currently available in every region (in env.region order).
        regions = list(self.env.region)
        q = np.array([self.env.acc[r][self.env.time] for r in regions])

        # Integer vehicle budgets per node; at least one passenger vehicle is
        # requested, the clamp below brings it back within availability.
        node_pax = np.maximum(
            1, np.rint(action_np[:, 0] * q * self.PAX_FRACTION)
        ).astype(int)
        node_reb = np.maximum(
            0, np.rint(action_np[:, 1] * q * self.REB_FRACTION)
        ).astype(int)

        # Scale down (truncating) wherever pax + reb exceeds what is available.
        total_used = node_pax + node_reb
        over = total_used > q
        if over.any():
            scale = q[over] / total_used[over]  # total_used >= 1 here
            node_pax[over] = (node_pax[over] * scale).astype(int)
            node_reb[over] = (node_reb[over] * scale).astype(int)

        paxAction = np.zeros(len(self.env.edges), dtype=int)
        rebAction = np.zeros(len(self.env.edges), dtype=int)

        # Group edge indices by origin once (O(E)) rather than rescanning all
        # edges for every region. Origins are matched against the region id,
        # the same key used to read env.acc above, so the mapping stays
        # correct even if region ids are not exactly 0..N-1.
        edges_from = {}
        for idx, (origin, _dest) in enumerate(self.env.edges):
            edges_from.setdefault(origin, []).append(idx)

        for pos, region in enumerate(regions):
            edge_ids = edges_from.get(region)
            pax, reb = int(node_pax[pos]), int(node_reb[pos])
            if not edge_ids or (pax <= 0 and reb <= 0):
                continue
            # Spread as evenly as possible: every edge gets the quotient and
            # the first `remainder` edges get one extra vehicle.
            pax_base, pax_extra = divmod(pax, len(edge_ids))
            reb_base, reb_extra = divmod(reb, len(edge_ids))
            for k, idx in enumerate(edge_ids):
                paxAction[idx] = pax_base + (k < pax_extra)
                rebAction[idx] = reb_base + (k < reb_extra)

        return paxAction, rebAction

    def training_step(self, gamma=0.99):
        """Run one actor and one critic update from the buffered episode.

        Discounted returns are computed from ``self.rewards`` (normalized when
        there is more than one), the actor is trained on
        ``-log_prob * (return - value)`` and the critic on a smooth-L1 loss
        against the returns. Both buffers are cleared afterwards. If either
        buffer is empty the call only clears the buffers.

        Args:
            gamma: discount factor.
        """
        n_steps = min(len(self.saved_actions), len(self.rewards))
        if n_steps == 0:
            # Nothing to learn from; stacking an empty list would raise.
            self.saved_actions, self.rewards = [], []
            return

        # Accumulate back-to-front, then reverse (avoids O(n^2) insert(0, ...)).
        discounted = []
        running = 0.0
        for r in reversed(self.rewards):
            running = r + gamma * running
            discounted.append(running)
        discounted.reverse()

        returns = torch.tensor(discounted, dtype=torch.float32, device=self.device)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + eps)
        # Pair returns with saved actions position by position.
        returns = returns[:n_steps]

        steps = self.saved_actions[:n_steps]
        log_probs = torch.stack([lp.reshape(()) for lp, _ in steps])
        values = torch.stack([v.reshape(()) for _, v in steps])

        # The advantage is a constant w.r.t. the critic for the actor update.
        advantages = returns - values.detach()
        policy_loss = -(log_probs * advantages).sum()
        value_loss = F.smooth_l1_loss(values, returns, reduction="sum")

        self.optimizer_actor.zero_grad()
        policy_loss.backward()
        self.optimizer_actor.step()

        self.optimizer_critic.zero_grad()
        value_loss.backward()
        self.optimizer_critic.step()

        self.saved_actions, self.rewards = [], []

    def save_checkpoint(self, path="ckpt/a2c_gnn.pth"):
        """Save the module's ``state_dict`` to ``path``, creating its directory if needed."""
        import os

        if isinstance(path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(path))
            if parent:
                os.makedirs(parent, exist_ok=True)
        torch.save(self.state_dict(), path)

    def load_checkpoint(self, path="ckpt/a2c_gnn.pth"):
        """Load a ``state_dict`` from ``path``, mapping tensors onto ``self.device``."""
        self.load_state_dict(torch.load(path, map_location=self.device))
