
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from collections import namedtuple
import pandas as pd
import torch.nn.init as init

# One record per select_action call: the surrogate log-probability and the
# critic's value estimate, consumed together by A2C.training_step.
SavedAction = namedtuple("SavedAction", ("log_prob", "value"))

# Run-time settings; `gamma` is the discount factor used in A2C.training_step.
Args = namedtuple("Args", ["render", "gamma", "log_interval"])
args = Args(True, 0.97, 10)

# -------------------------
# Orthogonal initialization
# -------------------------
def init_weights_orthogonal(m):
    """Orthogonally initialise Linear / Conv1d / Conv2d weights and zero their biases.

    Intended for ``nn.Module.apply``: modules of any other type are left
    untouched. The orthogonal gain is ``nn.init.calculate_gain("relu")``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        return
    init.orthogonal_(m.weight, gain=nn.init.calculate_gain("relu"))
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

# -------------------------
# Parser
# -------------------------
class GNNParser:
    """Convert an environment observation into a PyG ``Data`` graph.

    Each region becomes one node with ``1 + 2*T`` features, all multiplied by
    ``scale_factor``:

    * column 0: ``obs[0][n][env.time + 1]`` (0 when missing);
    * columns 1..T: ``obs[0][n][t] + env.dacc[n][t]`` for
      ``t = env.time+1 .. env.time+T`` (``dacc`` contributes 0 when its
      per-region entry is not a dict or has no entry for ``t``);
    * columns T+1..2T: for each such ``t``, the sum over destinations
      ``j != i`` of ``demand(i, j, t) * price(i, j, t)``. ``env.demand`` /
      ``env.price`` are used when they hold an entry for ``(i, j)`` at ``t``;
      otherwise the static CSV value is used (0 if the pair is absent).
    """

    def __init__(self, env, T=10, scale_factor=0.01, csv_path="data/network_data.csv"):
        """
        Args:
            env: environment object; ``time``, ``region``, ``nregion``, ``dacc``
                and optionally ``demand`` / ``price`` are read in ``parse_obs``.
            T: number of future time steps encoded per feature group.
            scale_factor: multiplier applied to every feature.
            csv_path: CSV with columns ``i``, ``j``, ``demand``, ``price``. Rows
                are keyed by integer ``(i, j)``; a later duplicate row overrides
                an earlier one.
        """
        super().__init__()
        self.env = env
        self.T = T
        self.s = scale_factor

        # Static fallback tables, built in a single pass over the CSV columns
        # instead of two row-by-row iterrows() passes.
        df = pd.read_csv(csv_path)
        keys = [(int(i), int(j)) for i, j in zip(df['i'], df['j'])]
        self.demand_static = {k: float(d) for k, d in zip(keys, df['demand'])}
        self.price_static = {k: float(p) for k, p in zip(keys, df['price'])}

    @staticmethod
    def _lookup(dynamic, static, key, t):
        """Return ``float(dynamic[key][t])`` when present, else ``static.get(key, 0.0)``."""
        if key in dynamic and t in dynamic[key]:
            return float(dynamic[key][t])
        return static.get(key, 0.0)

    def _node_features(self, obs):
        """Build the ``[nregion, 1 + 2*T]`` float32 node-feature matrix for ``obs``."""
        env, s = self.env, self.s
        region = env.region
        times = range(env.time + 1, env.time + self.T + 1)
        # The dynamic tables are optional on env; an empty dict means
        # "always fall back to the static CSV values".
        demand = env.demand if hasattr(env, "demand") else {}
        price = env.price if hasattr(env, "price") else {}

        # Vehicles at t+1 (missing entries count as 0).
        x_cur = [obs[0][n].get(env.time + 1, 0.0) * s for n in region]

        # Next T steps: q + dacc (dacc only counted when it is a dict).
        fut = [
            [
                (obs[0][n].get(t, 0.0)
                 + (env.dacc[n].get(t, 0.0) if isinstance(env.dacc[n], dict) else 0.0)) * s
                for n in region
            ]
            for t in times
        ]

        # Next T steps: demand * price summed over outgoing OD pairs (i, j), j != i.
        demand_feat = []
        for t in times:
            row = []
            for i in region:
                acc = 0.0
                for j in region:
                    if j == i:
                        continue
                    d_ijt = self._lookup(demand, self.demand_static, (i, j), t)
                    p_ijt = self._lookup(price, self.price_static, (i, j), t)
                    acc += d_ijt * p_ijt * s
                row.append(acc)
            demand_feat.append(row)

        rows = [x_cur, *fut, *demand_feat]
        # view() also checks that len(region) matches env.nregion.
        x = torch.tensor(rows, dtype=torch.float32).view(1 + 2 * self.T, env.nregion)
        return x.T

    def parse_obs(self, obs, edge_index):
        """Return ``Data(x, edge_index)`` with ``x`` of shape ``[nregion, 1 + 2*T]``."""
        return Data(self._node_features(obs), edge_index)

# -------------------------
# Actor (per-edge flow outputs)
# -------------------------
class GNNActor(nn.Module):
    """Policy network producing two non-negative scores per edge.

    One GCN layer embeds every node. Each edge ``(src, dst)`` of the edge list
    given at construction is then scored by an MLP on the concatenated endpoint
    embeddings. The two output channels are the (passenger, rebalancing)
    scores and are kept >= 0 by the final Softplus.
    """

    def __init__(self, in_channels, hidden_dim, edge_index, num_edges):
        """
        Args:
            in_channels: node-feature width.
            hidden_dim: GCN output width.
            edge_index: ``[2, |E|]`` edge list used to score edges in ``forward``.
            num_edges: stored for reference; not used by ``forward``.
        """
        super().__init__()
        # Non-persistent buffer: moves with module.to(device), unlike a plain
        # tensor attribute, but is excluded from state_dict so existing
        # checkpoints still load unchanged.
        self.register_buffer("edge_index", torch.as_tensor(edge_index), persistent=False)
        self.num_edges = num_edges

        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 2),    # (pax, reb)
            nn.Softplus()        # >= 0
        )

    def forward(self, data):
        """Return per-edge scores of shape ``[|E|, 2]`` for graph ``data``."""
        h = F.relu(self.conv1(data.x, data.edge_index))    # [N, H]
        src, dst = self.edge_index
        e_feat = torch.cat([h[src], h[dst]], dim=-1)       # [|E|, 2H]
        return self.edge_mlp(e_feat)                       # [|E|, 2]

# -------------------------
# Critic
# -------------------------
class GNNCritic(nn.Module):
    """State-value network: one GCN layer, sum-pooling over nodes, then a 3-layer MLP."""

    def __init__(self, in_channels, hidden_dim):
        super().__init__()
        # Layer construction order is kept so parameter initialisation and
        # state_dict keys stay the same.
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.lin1 = nn.Linear(hidden_dim, 32)
        self.lin2 = nn.Linear(32, 32)
        self.lin3 = nn.Linear(32, 1)

    def forward(self, data):
        """Return the value estimate for graph ``data`` as a tensor of shape ``[1]``."""
        node_emb = F.relu(self.conv1(data.x, data.edge_index))
        pooled = node_emb.sum(dim=0, keepdim=True)   # [1, hidden_dim]
        for hidden in (self.lin1, self.lin2):
            pooled = F.relu(hidden(pooled))
        return self.lin3(pooled).squeeze(0)

# -------------------------
# A2C Agent
# -------------------------
class A2C(nn.Module):
    """Advantage actor-critic agent built from a GNN actor and a GNN critic.

    ``select_action`` turns the actor's per-edge scores into integer passenger
    and rebalancing flows and records a ``SavedAction``. ``training_step``
    consumes ``self.saved_actions`` together with ``self.rewards``. These are
    zipped pairwise, so the caller is presumably expected to append one reward
    per ``select_action`` call. It then updates both networks and clears the
    buffers.
    """

    def __init__(self, env, input_size, edge_index, eps=np.finfo(np.float32).eps.item(),
                 hidden_size=64, device=torch.device("cpu"), csv_path="data/network_data.csv"):
        """
        Args:
            env: environment; this class reads ``nregion``, ``edges``, ``time``,
                ``acc`` and optionally ``demand`` / ``price``.
            input_size: node-feature width for both GCNs. It should equal the
                parser's ``1 + 2*T`` (21 with the default ``T=10``).
            edge_index: ``[2, |E|]`` edge tensor. Its columns are assumed to be
                in the same order as ``env.edges``, because ``select_action``
                indexes the actor output by ``env.edges`` position.
            eps: added to the return std when standardising returns.
            hidden_size: GCN hidden width.
            device: device for the networks and all tensors built here.
            csv_path: static demand/price table used as a fallback.
        """
        super().__init__()
        self.env = env
        self.eps = eps
        self.input_size = input_size
        self.device = device
        self.edge_index = edge_index
        self.num_edges = edge_index.shape[1]

        # Pass csv_path through, since previously the parser always read the
        # default file. Reuse its static tables as select_action's fallback
        # instead of parsing the CSV a second time.
        self.obs_parser = GNNParser(self.env, csv_path=csv_path)
        self.price_static = self.obs_parser.price_static
        self.demand_static = self.obs_parser.demand_static

        self.actor = GNNActor(self.input_size, hidden_size, edge_index, self.num_edges)
        self.critic = GNNCritic(self.input_size, hidden_size)

        # Orthogonal initialisation of every matching submodule.
        self.apply(init_weights_orthogonal)

        self.optimizers = self.configure_optimizers()
        self.saved_actions = []
        self.rewards = []
        self.to(self.device)

    # ---------- helpers ----------
    def parse_obs(self, obs):
        """Convert ``obs`` into a PyG ``Data`` graph via the observation parser."""
        return self.obs_parser.parse_obs(obs, self.edge_index)

    def _q_available(self, t):
        """Vehicles per region at ``t + 1``, falling back to ``t``, then 0. Shape ``[nregion]``."""
        acc = self.env.acc
        q = [
            float(acc[i][t + 1]) if (t + 1) in acc[i] else float(acc[i].get(t, 0.0))
            for i in range(self.env.nregion)
        ]
        return torch.tensor(q, dtype=torch.float32, device=self.device)

    def _edge_values_at_t(self, attr, static, t):
        """Per-edge ``env.<attr>[(i, j)][t]``, falling back to ``static.get((i, j), 0.0)``.

        Returns a float32 tensor of length ``len(env.edges)`` in edge order.
        """
        dynamic = getattr(self.env, attr) if hasattr(self.env, attr) else {}
        values = []
        for i, j in self.env.edges:
            if (i, j) in dynamic and t in dynamic[(i, j)]:
                values.append(float(dynamic[(i, j)][t]))
            else:
                values.append(static.get((i, j), 0.0))
        return torch.tensor(values, dtype=torch.float32, device=self.device)

    def _price_at_t(self, t):
        """Per-edge price at time ``t``, shape ``[len(env.edges)]``."""
        return self._edge_values_at_t("price", self.price_static, t)

    def _demand_cap_at_t(self, t):
        """Per-edge demand at time ``t``, shape ``[len(env.edges)]``."""
        return self._edge_values_at_t("demand", self.demand_static, t)

    @staticmethod
    def _allocate_pax(q, demand_cap, gain, logits_pax, out_index_by_i, alpha_gain, alpha_logit):
        """Greedy passenger assignment: for each origin, fill edges by descending score.

        The score is ``alpha_gain * gain + alpha_logit * logits_pax``. Each edge
        takes ``min(vehicles left, demand_cap)``. ``q`` is not modified.
        """
        pax = torch.zeros_like(demand_cap)
        for i, idxs in enumerate(out_index_by_i):
            if not idxs:
                continue
            scores = alpha_gain * gain[idxs] + alpha_logit * logits_pax[idxs]
            order = torch.argsort(scores, descending=True)

            # q[i] is a *view* into q. Rebind rather than subtracting in place;
            # otherwise q itself is decremented and the rebalancing step would
            # subtract passenger flow a second time.
            remain_i = q[i]
            for k in order.tolist():
                if remain_i <= 1e-9:
                    break
                e = idxs[k]
                cap_e = demand_cap[e]
                if cap_e <= 1e-9:
                    continue
                alloc = torch.minimum(remain_i, cap_e)
                pax[e] = alloc
                remain_i = remain_i - alloc
        return pax

    @staticmethod
    def _allocate_reb(q, pax, logits_reb, edges, out_index_by_i):
        """Spread each region's leftover vehicles over its out-edges by ``softmax(logits_reb)``."""
        used_pax_out = torch.zeros_like(q)
        for e_idx, (i, _j) in enumerate(edges):
            used_pax_out[i] += pax[e_idx]
        remain = torch.clamp(q - used_pax_out, min=0.0)

        reb = torch.zeros_like(pax)
        for i, idxs in enumerate(out_index_by_i):
            if not idxs or remain[i] <= 1e-9:
                continue
            w = torch.softmax(logits_reb[idxs], dim=0)
            reb[idxs] = remain[i] * w
        return reb

    # ---------- forward ----------
    def forward(self, obs):
        """Return ``(actor per-edge scores [|E|, 2], critic value [1])`` for ``obs``."""
        data = self.parse_obs(obs).to(self.device)
        flows = self.actor(data)
        value = self.critic(data)
        return flows, value

    # ---------- action ----------
    def select_action(self, obs, alpha_gain=3.0, alpha_logit=1.0):
        """Choose integer passenger and rebalancing flows for every edge.

        Args:
            obs: environment observation, forwarded to the parser.
            alpha_gain: weight on the demand*price "revenue" proxy when ranking
                out-edges for passengers (larger = more revenue-greedy).
            alpha_logit: weight on the actor's passenger scores in that ranking.

        Returns:
            ``(pax, reb)``: two int numpy arrays of length ``len(env.edges)``.

        Side effect: appends ``SavedAction(log_prob, value)`` to ``self.saved_actions``.
        """
        flows, value = self.forward(obs)
        logits_pax = flows[:, 0]
        logits_reb = flows[:, 1]

        t = self.env.time
        N = self.env.nregion

        q = self._q_available(t)                   # [N] available vehicles (prefer t+1)
        demand_cap = self._demand_cap_at_t(t)      # [E] passenger cap per edge
        gain = demand_cap * self._price_at_t(t)    # [E] revenue proxy

        # Outgoing edge indices per origin region.
        out_index_by_i = [[] for _ in range(N)]
        for e_idx, (i, _j) in enumerate(self.env.edges):
            out_index_by_i[i].append(e_idx)

        # Step 1: passengers first (revenue-greedy, modulated by the policy).
        pax = self._allocate_pax(q, demand_cap, gain, logits_pax, out_index_by_i,
                                 alpha_gain, alpha_logit)
        # Step 2: leftover vehicles are rebalanced by the policy softmax.
        reb = self._allocate_reb(q, pax, logits_reb, self.env.edges, out_index_by_i)

        # NOTE(review): each edge is rounded independently, so the integer
        # totals leaving a region can differ from the vehicles available.
        pax_np = np.rint(np.clip(pax.detach().cpu().numpy(), 0, None)).astype(int)
        reb_np = np.rint(np.clip(reb.detach().cpu().numpy(), 0, None)).astype(int)

        # Surrogate log-prob summed over regions.
        # NOTE(review): it depends only on the logits, not on the flows actually
        # chosen, so the policy gradient is not conditioned on the action taken.
        # Confirm this is intended.
        logp = 0.0
        for idxs in out_index_by_i:
            if idxs:
                logp += torch.log_softmax(logits_pax[idxs], dim=0).sum()
                logp += torch.log_softmax(logits_reb[idxs], dim=0).sum()

        self.saved_actions.append(SavedAction(logp, value))
        return pax_np, reb_np

    # ---------- train step ----------
    def training_step(self):
        """Run one A2C update from the buffered rewards and actions, then clear both buffers.

        Discounted returns (``args.gamma``) are standardised when there is more
        than one. The actor minimises ``-log_prob * advantage``, where the
        advantage uses a detached value. The critic minimises the smooth-L1 loss
        between value and return.
        """
        R = 0.0
        policy_losses, value_losses, returns = [], [], []

        for r in self.rewards[::-1]:
            R = r + args.gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + self.eps)

        for (log_prob, value), R in zip(self.saved_actions, returns):
            advantage = R - value.detach().squeeze()
            policy_losses.append(-log_prob * advantage)
            value_losses.append(F.smooth_l1_loss(value.squeeze(), R))

        self.optimizers['a_optimizer'].zero_grad()
        if policy_losses:
            torch.stack(policy_losses).sum().backward()
        self.optimizers['a_optimizer'].step()

        self.optimizers['c_optimizer'].zero_grad()
        if value_losses:
            torch.stack(value_losses).sum().backward()
        self.optimizers['c_optimizer'].step()

        self.rewards.clear()
        self.saved_actions.clear()

    # ---------- optim ----------
    def configure_optimizers(self):
        """Separate Adam optimizers (lr=1e-3) for the actor and the critic."""
        return {
            'a_optimizer': torch.optim.Adam(self.actor.parameters(), lr=1e-3),
            'c_optimizer': torch.optim.Adam(self.critic.parameters(), lr=1e-3)
        }

    # ---------- io ----------
    def save_checkpoint(self, path='ckpt.pth'):
        """Save model and optimizer state dicts to ``path``."""
        checkpoint = {'model': self.state_dict()}
        for k, opt in self.optimizers.items():
            checkpoint[k] = opt.state_dict()
        torch.save(checkpoint, path)

    def load_checkpoint(self, path='ckpt.pth'):
        """Restore model and optimizer state dicts written by ``save_checkpoint``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.load_state_dict(checkpoint['model'])
        for k, opt in self.optimizers.items():
            opt.load_state_dict(checkpoint[k])