import torch
import torch.nn.functional as F
from Network import PolicyNet, ValueNet
from util import compute_advantage


class LICACritic(torch.nn.Module):
    """
    Centralized feedforward critic for LICA.

    The flattened joint action representation is appended to the global state
    and passed through a three-layer MLP (two ReLU hidden layers, one scalar
    output).

    Inputs:
        - state: [B, state_dim]
        - action_rep: [B, n_agents, max_action_dim] (probabilities or one-hot)
    Output:
        - joint Q-value: [B, 1]
    """

    def __init__(self, state_dim: int, n_agents: int, max_action_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.n_agents = n_agents
        self.max_action_dim = max_action_dim
        # Every agent contributes a max_action_dim-wide slot to the input.
        joint_action_width = n_agents * max_action_dim
        self.input_dim = state_dim + joint_action_width

        # Layers are created in this fixed order so parameter initialisation
        # and state_dict keys stay the same (fc1, fc2, fc3).
        self.fc1 = torch.nn.Linear(self.input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, action_rep: torch.Tensor, state: torch.Tensor):
        batch_size = state.size(0)
        flat_actions = action_rep.reshape(batch_size, -1)
        hidden = torch.cat((state, flat_actions), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class LICA:
    """
    Lightweight actor-critic model with centralized critic and independent policies.

    Every agent has its own ``PolicyNet`` actor and AdamW optimizer. A single
    ``LICACritic`` scores the joint action representation, together with the
    flattened global state. The joint representation is each agent's action
    probabilities or one-hot action, zero-padded to ``max_action_dim``.

    Actors are trained by maximising the critic's value of their current
    action distributions, plus an entropy bonus. The critic is trained with a
    one-step TD target whose bootstrap value comes from a target critic. The
    target critic is soft-updated with rate ``tau``.
    """

    def __init__(
        self,
        agent_num: int,
        state_dim_list: list[int],
        hidden_dim: int,
        action_num_list: list[int],
        actor_lr: float,
        critic_lr: float,
        epochs: int,
        eps: float,
        gamma: float,
        device: torch.device,
        entropy_coef: float = 0.01,
        sample_size=None,
        entropy_soft=None,
        tau: float = 0.01,
    ) -> None:
        """
        Build the actors, the critic and its target, and their optimizers.

        Args:
            agent_num: number of agents (one actor each).
            state_dim_list: per-agent state size. The critic's state input
                size is their sum.
            hidden_dim: hidden width passed to both the actors and the critic.
            action_num_list: per-agent number of discrete actions.
            actor_lr: AdamW learning rate for every actor.
            critic_lr: AdamW learning rate for the critic.
            epochs: stored on the instance; not read by ``update``.
            eps: accepted for interface compatibility; unused by this class.
            gamma: discount factor in the critic's TD target.
            device: device all networks are moved to.
            entropy_coef: weight of the entropy bonus in the actor loss.
            sample_size: accepted for interface compatibility; unused here.
            entropy_soft: accepted for interface compatibility; unused here.
            tau: soft-update rate of the target critic (was hard-coded 0.01).
        """
        self.agent_num = agent_num
        self.device = device
        self.gamma = gamma
        self.epochs = epochs
        self.entropy_coef = entropy_coef
        self.tau = tau
        # Kept so update() can size each agent's one-hot encoding without
        # reaching into PolicyNet's internal layer names.
        self.action_num_list = list(action_num_list)

        self.actors = [
            PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.actor_opts = [
            torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors
        ]

        self.state_dim = sum(state_dim_list)
        self.max_action_dim = max(action_num_list)
        self.critic = LICACritic(self.state_dim, agent_num, self.max_action_dim, hidden_dim).to(device)
        self.target_critic = LICACritic(self.state_dim, agent_num, self.max_action_dim, hidden_dim).to(device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.AdamW(self.critic.parameters(), lr=critic_lr)
        # NOTE(review): not read anywhere in this class; update() uses a
        # one-step TD target.
        self.td_lambda = 0.95

    def take_action(self, state_list):
        """
        Sample one action per agent from its current policy, without gradients.

        Args:
            state_list: indexable of per-agent state tensors. Element ``i`` is
                fed to ``self.actors[i]``, whose output is treated as action
                probabilities.

        Returns:
            List of sampled action-index tensors, one per agent.
        """
        actions = []
        with torch.no_grad():
            for i in range(self.agent_num):
                probs = self.actors[i](state_list[i].to(self.device))
                dist = torch.distributions.Categorical(probs)
                actions.append(dist.sample())
        return actions

    def _pad_actions(self, action_idx: torch.Tensor, action_dim: int):
        """
        One-hot encode ``action_idx`` and zero-pad it to ``max_action_dim``.

        The encoding uses ``action_dim`` classes, then the last dimension is
        padded up to ``max_action_dim``. ``F.one_hot`` raises if an index is
        not below ``action_dim``.
        """
        one_hot = F.one_hot(action_idx, num_classes=action_dim).float()
        if action_dim < self.max_action_dim:
            one_hot = F.pad(one_hot, (0, self.max_action_dim - action_dim))
        return one_hot

    def _pad_probs(self, probs: torch.Tensor, action_dim: int):
        """
        Zero-pad the last dimension of ``probs`` up to ``max_action_dim``.

        The input width is ``action_dim``. ``F.pad`` keeps the tensor's dtype
        and device and stays differentiable.
        """
        if action_dim < self.max_action_dim:
            probs = F.pad(probs, (0, self.max_action_dim - action_dim))
        return probs

    def update(self, transition_dict: dict):
        """
        Run one actor step, one critic step and one target-critic soft update.

        Args:
            transition_dict: mapping with keys ``"states"``, ``"actions"``,
                ``"rewards"`` and ``"dones"``.

                - ``states`` is indexed as ``states[:, i, :]``, one slice per
                  agent.
                - ``actions`` is indexed as ``actions[:, i]``.
                - ``rewards`` and ``dones`` are flattened to one column. They
                  presumably hold one team value per row; TODO confirm.
                - Rows are assumed to be consecutive time steps, because the
                  critic bootstraps from the following row. TODO confirm
                  against the rollout code.
                - ``"next_states"`` may be present but is not used.

        With fewer than two rows no TD pair exists. In that case only the actor
        step runs.
        """
        device = self.device
        states = transition_dict["states"].to(device)
        # Move actions too: otherwise the one-hot tensors stay on the input's
        # device and the critic call fails when running on CUDA.
        actions_idx = transition_dict["actions"].long().to(device)
        # Cast to float so a float64 reward tensor cannot create a dtype
        # mismatch with the float32 critic output inside mse_loss.
        rewards = transition_dict["rewards"].reshape(-1, 1).float().to(device)
        dones = transition_dict["dones"].reshape(-1, 1).float().to(device)

        T = states.size(0)
        global_state = states.reshape(T, -1)

        self._update_actors(states, global_state)

        if T < 2:
            # The [:-1] / [1:] slices would be empty and mse_loss would be NaN.
            return
        self._update_critic(actions_idx, global_state, rewards, dones)
        self._soft_update_target()

    def _update_actors(self, states: torch.Tensor, global_state: torch.Tensor):
        """Take one gradient step on every actor against the current critic."""
        padded_probs_list = []
        entropy_list = []
        for i in range(self.agent_num):
            probs_i = self.actors[i](states[:, i, :])
            # 1e-8 avoids log(0) for zero-probability actions.
            entropy_list.append(-(probs_i * (probs_i + 1e-8).log()).sum(dim=-1))
            padded_probs_list.append(self._pad_probs(probs_i, probs_i.size(-1)).unsqueeze(1))
        policy_tensor = torch.cat(padded_probs_list, dim=1)

        q_tot_pred = self.critic(policy_tensor, global_state)
        actor_loss = -q_tot_pred.mean() - self.entropy_coef * torch.stack(entropy_list, dim=1).mean()

        for opt in self.actor_opts:
            opt.zero_grad()
        # This backward also writes .grad on the critic's parameters.
        # _update_critic clears them with critic_opt.zero_grad() before its
        # own backward.
        actor_loss.backward()
        for opt in self.actor_opts:
            torch.nn.utils.clip_grad_norm_(opt.param_groups[0]["params"], 40.0)
            opt.step()

    def _update_critic(
        self,
        actions_idx: torch.Tensor,
        global_state: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
    ):
        """Take one TD(0) step on the critic using consecutive batch rows."""
        onehot_list = [
            self._pad_actions(actions_idx[:, i], self.action_num_list[i]).unsqueeze(1)
            for i in range(self.agent_num)
        ]
        action_onehot = torch.cat(onehot_list, dim=1)

        q_t = self.critic(action_onehot[:-1], global_state[:-1])
        with torch.no_grad():
            q_next = self.target_critic(action_onehot[1:], global_state[1:])
            # Rows flagged done do not bootstrap from the following row.
            td_target = rewards[:-1] + self.gamma * q_next * (1 - dones[:-1])

        critic_loss = F.mse_loss(q_t, td_target)
        self.critic_opt.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 40.0)
        self.critic_opt.step()

    def _soft_update_target(self):
        """Blend critic weights into the target: target <- (1 - tau) * target + tau * critic."""
        tau = self.tau
        with torch.no_grad():
            for p, p_targ in zip(self.critic.parameters(), self.target_critic.parameters()):
                p_targ.mul_(1 - tau).add_(p, alpha=tau)
