import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import torch.nn.functional as F


class ActorCritic(nn.Module):
    """Policy/value network over a flat state vector with an embedded grid.

    The last 256 entries of every state row are reshaped to a single-channel
    16x16 image and encoded by a small CNN; the leading ``input_dim - 256``
    entries are encoded by one linear layer. Both encodings are concatenated
    and passed through a shared MLP that feeds an actor head and a critic
    head.
    """

    def __init__(self, input_dim, output_dim):
        """Create the encoders, the shared trunk and both heads.

        Args:
            input_dim: Total length of a state row (grid part included).
            output_dim: Number of entries in the action distribution.
        """
        super().__init__()

        # CNN encoder for the grid slice: each pooling stage halves the
        # spatial size (16 -> 8 -> 4), ending with 32 channels.
        grid_encoder = [
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv = nn.Sequential(*grid_encoder)

        # Encoder for everything in the state that is not the grid.
        vector_encoder = [nn.Linear(input_dim - 256, 64), nn.ReLU()]
        self.linear = nn.Sequential(*vector_encoder)

        # Trunk shared by both heads; its input is the flattened CNN output
        # (32 * 4 * 4) concatenated with the 64 vector features.
        trunk = [
            nn.Linear(32 * 4 * 4 + 64, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk)

        # Output heads: action scores and a scalar state value.
        self.actor = nn.Linear(128, output_dim)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        """Return ``(action_probs, value)`` for a batch of state rows.

        Args:
            x: 2-D tensor whose rows are flat states; the trailing 256
                columns hold the grid.

        Returns:
            A tuple of the softmax over the actor outputs (last dimension)
            and the critic output with a trailing dimension of size 1.
        """
        grid_cells = 256

        # Separate the trailing grid columns from the leading vector columns.
        grid_in = x[:, -grid_cells:].view(-1, 1, 16, 16)
        vector_in = x[:, :-grid_cells]
        batch = grid_in.size(0)

        # Encode each part on its own.
        conv_out = self.conv(grid_in)
        grid_feats = conv_out.view(batch, -1)
        vector_feats = self.linear(vector_in)

        # Fuse the encodings and run the shared trunk.
        fused = torch.cat([grid_feats, vector_feats], dim=1)
        hidden = self.shared(fused)

        action_probs = torch.softmax(self.actor(hidden), dim=-1)
        state_value = self.critic(hidden)
        return action_probs, state_value


class PPOAgent:
    """Agent that trains an ``ActorCritic`` model with a clipped-ratio PPO loss.

    Transitions are buffered in a bounded deque (at most 4000 entries; the
    deque discards the oldest entry once it is full) and consumed by
    :meth:`train`, which empties the buffer after an update.
    """

    def __init__(self, input_dim, output_dim, lr=1e-3, gamma=0.99,
                 eps_clip=0.2, num_epochs=16, batch_size=128):
        """Build the model, its Adam optimizer and the transition buffer.

        Args:
            input_dim: Length of the flat state vector given to the model.
            output_dim: Number of discrete actions.
            lr: Adam learning rate.
            gamma: Discount factor used when computing returns.
            eps_clip: Half-width of the probability-ratio clipping interval.
            num_epochs: Passes over the buffer per :meth:`train` call.
            batch_size: Mini-batch size; also the minimum buffer size needed
                before :meth:`train` performs an update.
        """
        self.model = ActorCritic(input_dim, output_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.memory = deque(maxlen=4000)
        self.num_epochs = num_epochs
        self.batch_size = batch_size

    def select_action(self, state):
        """Sample an action from the current policy.

        Args:
            state: A single state; it is flattened to one row.

        Returns:
            ``(action, prob)``: the sampled action index and the probability
            the policy assigned to it (store it as ``old_prob``).
        """
        state = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            probs, _ = self.model(state)
        action = torch.multinomial(probs, 1).item()
        return action, probs.squeeze(0)[action].item()

    def store_transition(self, transition):
        """Append one transition to the buffer.

        Args:
            transition: ``(state, action, reward, next_state, old_prob)``,
                optionally followed by a ``done`` flag as a sixth element.
                The flag takes effect only if every buffered transition
                carries one.
        """
        self.memory.append(transition)

    @staticmethod
    def _discounted_returns(rewards, dones, gamma):
        """Return the discounted return of each step as a list of floats.

        The running return is reset at every step whose ``done`` flag is
        truthy so that rewards do not leak across episode boundaries.
        """
        returns = []
        running = 0.0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                running = 0.0
            running = float(reward) + gamma * running
            returns.append(running)
        # Built back-to-front in O(n); restore chronological order.
        returns.reverse()
        return returns

    def train(self):
        """Run the PPO update on the buffered transitions, then clear them.

        Does nothing (and keeps the buffer) while fewer than ``batch_size``
        transitions are stored.
        """
        if len(self.memory) < self.batch_size:
            return

        columns = list(zip(*self.memory))
        # next_state is stored but not used by this update.
        states, actions, rewards, _next_states, old_probs = columns[:5]
        dones = columns[5] if len(columns) > 5 else (False,) * len(rewards)

        # Flatten each state to a row; works for lists, arrays and tensors.
        states = torch.stack([
            torch.as_tensor(s, dtype=torch.float32).reshape(-1) for s in states
        ])
        actions = torch.tensor(actions, dtype=torch.long)
        old_probs = torch.tensor(old_probs, dtype=torch.float32)
        returns = torch.tensor(
            self._discounted_returns(rewards, dones, self.gamma),
            dtype=torch.float32,
        )

        num_samples = states.size(0)
        for _ in range(self.num_epochs):
            permutation = torch.randperm(num_samples)
            for start in range(0, num_samples, self.batch_size):
                batch_idx = permutation[start:start + self.batch_size]

                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_returns = returns[batch_idx]
                batch_old_probs = old_probs[batch_idx]

                probs, values = self.model(batch_states)
                # Keep a 1-D batch axis even when the mini-batch has a single
                # sample (a bare squeeze() would collapse it to 0-d).
                values = values.reshape(-1)
                new_probs = probs.gather(1, batch_actions.unsqueeze(1)).squeeze(1)

                advantages = batch_returns - values.detach()
                # std() of a single element is NaN and would poison every
                # parameter through the loss; only normalise real batches.
                if advantages.numel() > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Clipped surrogate objective.
                ratio = new_probs / batch_old_probs
                surr1 = ratio * advantages
                surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = 0.5 * (batch_returns - values).pow(2).mean()
                loss = policy_loss + value_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        self.memory.clear()
