import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
class PPOBinaryActorCritic(nn.Module):
    """Actor-critic network for a single binary (Bernoulli) action.

    A shared MLP trunk (obs_dim -> 512 -> 256 -> 128, ReLU after each layer)
    feeds two linear heads: a policy head producing one logit per observation
    (sigmoid of it is P(action == 1)) and a value head producing one scalar
    value estimate per observation.
    """

    def __init__(self, obs_dim):
        super().__init__()
        # Layer layout is kept exactly as-is so state_dict keys
        # ("shared.0.weight", ...) stay compatible with saved checkpoints.
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(128, 1)  # logit of P(action == 1)
        self.value_head = nn.Linear(128, 1)   # scalar value estimate

    def forward(self, obs):
        """Return ``(logit, value)`` for a batch of observations.

        Args:
            obs: tensor whose last dimension is ``obs_dim``; assumed to be
                [B, obs_dim] (TODO confirm against callers).

        Returns:
            logit: [B, 1] pre-sigmoid policy logits.
            value: [B] value estimates (the trailing size-1 dim is squeezed).
        """
        features = self.shared(obs)
        logit = self.policy_head(features)         # shape [B, 1]
        value = self.value_head(features).squeeze(-1)  # shape [B]
        return logit, value

    def act(self, obs):
        """Sample one binary action per observation.

        Gradients are not disabled here; callers that only collect rollouts
        may want to wrap the call in ``torch.no_grad()``.

        Args:
            obs: batch of observations, as accepted by ``forward``.

        Returns:
            action: [B] sampled actions (float 0. or 1.).
            log_prob: [B] log-probability of each sampled action.
            value: [B] value estimates.
        """
        logit, value = self.forward(obs)
        # Parameterizing by logits makes log_prob use a log-sigmoid form that
        # stays accurate for large-magnitude logits; log(sigmoid(x)) degrades
        # (and is clamped at -100) once sigmoid rounds to 0 or 1 in float32.
        dist = torch.distributions.Bernoulli(logits=logit)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.squeeze(-1), log_prob.squeeze(-1), value
    

def ppo_update(model, optimizer, ppo_dataset, clip_eps=0.2, value_coef=0.5, entropy_coef=0.01,
               epochs=4, batch_size=512*4*32, device="cuda" if torch.cuda.is_available() else "cpu",
               save_every=10, save_dir="results/models"):
    """Run ``epochs`` passes of clipped-surrogate PPO over one collected rollout.

    Each epoch shuffles the rollout, walks it in minibatches of ``batch_size``
    and minimises ``policy_loss + value_coef * value_loss - entropy_coef * entropy``
    with one optimizer step per minibatch.

    Args:
        model: module whose forward returns ``(logits, values)`` with
            ``logits`` shaped [B, 1] (one Bernoulli logit per sample) and
            ``values`` shaped [B]. May be a wrapper exposing the underlying
            network as ``.module`` (e.g. ``DataParallel``); checkpoints are
            then taken from that inner network.
        optimizer: optimizer over ``model``'s parameters.
        ppo_dataset: mapping with tensor entries ``"states"``, ``"actions"``
            (0/1), ``"log_probs"`` (log-probabilities recorded when the
            actions were taken; the "old" policy in the ratio),
            ``"advantages"`` and ``"returns"``, all sharing the same leading
            sample dimension. Per-sample fields are flattened to 1-D.
        clip_eps: ratios are clamped to ``[1 - clip_eps, 1 + clip_eps]``.
        value_coef: weight on the value MSE loss.
        entropy_coef: weight on the entropy bonus (subtracted from the loss).
        epochs: number of passes over the rollout.
        batch_size: minibatch size; the last minibatch may be smaller.
        device: device the rollout tensors are moved to.
        save_every: save a checkpoint when ``epoch % save_every == 0``
            (0-based epoch); ``None`` or ``0`` disables checkpointing.
        save_dir: directory for ``ppo_model_epoch_{epoch}.pth`` files;
            created if it does not exist.

    Returns:
        list of ``(avg_policy_loss, avg_value_loss, avg_entropy)`` tuples, one
        per epoch, each averaged over that epoch's minibatches.

    Raises:
        ValueError: if the rollout is empty or its fields disagree on the
            number of samples.
    """
    import os

    states = ppo_dataset["states"].to(device)
    # Flatten per-sample fields to [N]: they are combined elementwise with
    # ``values`` ([B]) and the squeezed log-probs ([B]); an [N, 1] tensor
    # would otherwise silently broadcast to [B, B] inside the losses.
    actions = ppo_dataset["actions"].to(device).reshape(-1)
    old_log_probs = ppo_dataset["log_probs"].to(device).reshape(-1)
    advantages = ppo_dataset["advantages"].to(device).reshape(-1)
    returns = ppo_dataset["returns"].to(device).reshape(-1)

    n_samples = states.shape[0]
    if any(t.shape[0] != n_samples for t in (actions, old_log_probs, advantages, returns)):
        raise ValueError("ppo_dataset fields must all have the same number of samples")
    if n_samples == 0:
        raise ValueError("ppo_dataset is empty; nothing to update on")
    n_batches = (n_samples + batch_size - 1) // batch_size

    # Save the underlying network when wrapped; a plain nn.Module has no
    # ``.module`` attribute, so fall back to the model itself.
    net_to_save = getattr(model, "module", model)

    model.train()

    epoch_losses = []

    for epoch in range(epochs):
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0

        # Shuffle once per epoch and slice minibatches directly; indexing the
        # on-device tensors avoids DataLoader's per-sample fetch + collate.
        perm = torch.randperm(n_samples, device=states.device)
        for start in range(0, n_samples, batch_size):
            idx = perm[start:start + batch_size]
            batch_states = states[idx]
            batch_actions = actions[idx].unsqueeze(-1).float()  # [B, 1], matches logits
            batch_old_log_probs = old_log_probs[idx]
            batch_advantages = advantages[idx]
            batch_returns = returns[idx]

            logits, values = model(batch_states)  # logits shape [B,1], values shape [B]

            # Logit parameterization keeps log_prob finite and differentiable
            # when the policy saturates (log(sigmoid(x)) hits log(0) in float32).
            dist = torch.distributions.Bernoulli(logits=logits)
            new_log_probs = dist.log_prob(batch_actions).squeeze(-1)  # [B]
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_log_probs - batch_old_log_probs)
            clipped_ratio = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
            # Pessimistic bound of clipped vs. unclipped surrogate, negated for descent.
            policy_loss = -torch.min(ratio * batch_advantages, clipped_ratio * batch_advantages).mean()

            value_loss = F.mse_loss(values, batch_returns)

            loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_entropy += entropy.item()

        avg_policy_loss = total_policy_loss / n_batches
        avg_value_loss = total_value_loss / n_batches
        avg_entropy = total_entropy / n_batches

        print(f"Epoch {epoch+1} PPO update: policy_loss={avg_policy_loss:.4f}, "
              f"value_loss={avg_value_loss:.4f}, entropy={avg_entropy:.4f}")

        epoch_losses.append((avg_policy_loss, avg_value_loss, avg_entropy))

        if save_every and epoch % save_every == 0:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(net_to_save.state_dict(),
                       os.path.join(save_dir, f"ppo_model_epoch_{epoch}.pth"))
            print(f"Model saved at epoch {epoch}")

    return epoch_losses
