"""GRU and MLP policy networks."""

import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils.rnn import pad_sequence


class MLPPolicy(nn.Module):
    """Feed-forward policy with no recurrent state (memoryless baseline).

    Each timestep is mapped to action logits independently of every other
    timestep. Because ``nn.Linear`` acts on the last dimension, any input
    whose trailing dimension is ``obs_dim`` is accepted.

    Args:
        obs_dim: Size of each observation vector.
        hidden_dim: Width of both hidden layers.
        n_actions: Number of output logits.
        dropout: Dropout probability applied after the first hidden layer.
    """

    def __init__(self, obs_dim, hidden_dim=128, n_actions=4, dropout=0.1):
        super().__init__()
        self.n_actions = n_actions
        self.hidden_dim = hidden_dim
        layers = [
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs_seq):
        """Return logits of shape ``(*obs_seq.shape[:-1], n_actions)``."""
        logits = self.net(obs_seq)
        return logits


def train_mlp_policy(policy, dataset, n_epochs=100, lr=3e-4, batch_size=32,
                     print_every=20):
    """Train MLP policy via cross-entropy distillation.

    Each dataset item is expected to be a dict with ``"obs"`` (a
    ``(T, obs_dim)`` tensor) and ``"targets"`` (a ``(T, n_actions)`` tensor of
    target action probabilities of the same length). Sequences in a minibatch
    are right-padded to a common length; padded timesteps are masked out of
    the loss.

    Args:
        policy: Module mapping ``(B, T, obs_dim)`` observations to
            ``(B, T, n_actions)`` logits.
        dataset: Indexable, sized collection of items as described above.
        n_epochs: Number of passes over the dataset.
        lr: Initial Adam learning rate, cosine-annealed to 0 over ``n_epochs``.
        batch_size: Number of sequences per minibatch.
        print_every: Print the epoch loss every this many epochs; ``0`` or
            ``None`` disables printing.

    Returns:
        ``(policy, history)`` where ``history["loss"]`` holds the mean
        per-timestep cross-entropy of each epoch (``nan`` if an epoch had no
        non-empty sequences).

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset must contain at least one sequence")

    # Batches are moved to wherever the policy lives (CPU or GPU).
    device = next(policy.parameters()).device
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    history = {"loss": []}

    for epoch in range(n_epochs):
        policy.train()
        indices = np.random.permutation(len(dataset))
        epoch_loss, n_steps = 0.0, 0

        for start in range(0, len(indices), batch_size):
            batch = [dataset[j] for j in indices[start : start + batch_size]]
            lengths = torch.tensor([len(b["obs"]) for b in batch], device=device)
            count = int(lengths.sum())
            if count == 0:
                continue  # only empty sequences: no loss to compute

            obs = pad_sequence([b["obs"] for b in batch], batch_first=True).to(device)
            targets = pad_sequence([b["targets"] for b in batch], batch_first=True).to(device)

            log_probs = torch.log_softmax(policy(obs), dim=-1)
            # (B, T) boolean mask: True on real timesteps, False on padding.
            mask = torch.arange(obs.shape[1], device=device) < lengths.unsqueeze(1)
            ce = -(targets * log_probs).sum(dim=-1)
            loss = ce[mask].sum() / count

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight by timestep count so the epoch figure is a per-timestep
            # mean rather than a mean of (possibly uneven) batch means.
            epoch_loss += loss.item() * count
            n_steps += count

        scheduler.step()
        avg_loss = epoch_loss / n_steps if n_steps else float("nan")
        history["loss"].append(avg_loss)

        if print_every and epoch % print_every == 0:
            print(f"Epoch {epoch}: CE loss = {avg_loss:.4f}")

    return policy, history


class GRUPolicy(nn.Module):
    """Recurrent policy: per-step observation encoder -> GRU -> logit head.

    The GRU is unidirectional, so for right-padded batches the outputs at
    real timesteps do not depend on the trailing padding. The returned final
    hidden state, however, is taken after the last (possibly padded) step.

    Args:
        obs_dim: Size of each observation vector.
        hidden_dim: Width of the encoder output, GRU state, and head layer.
        n_actions: Number of output logits.
        dropout: Dropout probability applied to GRU outputs in the head.
    """

    def __init__(self, obs_dim, hidden_dim=128, n_actions=4, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_actions = n_actions
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.policy_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, obs_seq, hidden=None):
        """Compute action logits for every timestep.

        Args:
            obs_seq: Observations of shape ``(B, T, obs_dim)``, or an
                unbatched ``(T, obs_dim)`` sequence.
            hidden: Optional initial GRU state of shape ``(1, B, hidden_dim)``
                (``(1, hidden_dim)`` for unbatched input). ``None`` starts
                from zeros.

        Returns:
            ``(logits, hidden)``: logits of shape ``(..., T, n_actions)`` and
            the GRU state after the final timestep.
        """
        encoded = self.obs_encoder(obs_seq)
        # hidden=None lets nn.GRU create a zero state with the input's dtype
        # and device, so .double()/.half() models work without a dtype clash.
        gru_out, hidden = self.gru(encoded, hidden)
        logits = self.policy_head(gru_out)
        return logits, hidden

    def get_recurrent_output(self, obs_seq, hidden=None):
        """Return the raw GRU outputs (before the policy head) for ``obs_seq``.

        ``hidden`` is an optional initial state, as in ``forward``.
        """
        encoded = self.obs_encoder(obs_seq)
        gru_out, _ = self.gru(encoded, hidden)
        return gru_out


def train_gru_policy(policy, dataset, n_epochs=100, lr=3e-4, batch_size=32,
                     print_every=20):
    """Train GRU policy via cross-entropy distillation.

    Each dataset item is expected to be a dict with ``"obs"`` (a
    ``(T, obs_dim)`` tensor) and ``"targets"`` (a ``(T, n_actions)`` tensor of
    target action probabilities of the same length). Sequences in a minibatch
    are right-padded to a common length; padded timesteps are masked out of
    the loss.

    Args:
        policy: Module whose ``forward(obs)`` returns ``(logits, hidden)`` with
            logits of shape ``(B, T, n_actions)``; ``hidden`` is ignored.
        dataset: Indexable, sized collection of items as described above.
        n_epochs: Number of passes over the dataset.
        lr: Initial Adam learning rate, cosine-annealed to 0 over ``n_epochs``.
        batch_size: Number of sequences per minibatch.
        print_every: Print the epoch loss every this many epochs; ``0`` or
            ``None`` disables printing.

    Returns:
        ``(policy, history)`` where ``history["loss"]`` holds the mean
        per-timestep cross-entropy of each epoch (``nan`` if an epoch had no
        non-empty sequences).

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset must contain at least one sequence")

    # Batches are moved to wherever the policy lives (CPU or GPU).
    device = next(policy.parameters()).device
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    history = {"loss": []}

    for epoch in range(n_epochs):
        policy.train()
        indices = np.random.permutation(len(dataset))
        epoch_loss, n_steps = 0.0, 0

        for start in range(0, len(indices), batch_size):
            batch = [dataset[j] for j in indices[start : start + batch_size]]
            lengths = torch.tensor([len(b["obs"]) for b in batch], device=device)
            count = int(lengths.sum())
            if count == 0:
                continue  # only empty sequences: no loss to compute

            obs = pad_sequence([b["obs"] for b in batch], batch_first=True).to(device)
            targets = pad_sequence([b["targets"] for b in batch], batch_first=True).to(device)

            logits, _ = policy(obs)
            log_probs = torch.log_softmax(logits, dim=-1)
            # (B, T) boolean mask: True on real timesteps, False on padding.
            mask = torch.arange(obs.shape[1], device=device) < lengths.unsqueeze(1)
            ce = -(targets * log_probs).sum(dim=-1)
            loss = ce[mask].sum() / count

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight by timestep count so the epoch figure is a per-timestep
            # mean rather than a mean of (possibly uneven) batch means.
            epoch_loss += loss.item() * count
            n_steps += count

        scheduler.step()
        avg_loss = epoch_loss / n_steps if n_steps else float("nan")
        history["loss"].append(avg_loss)

        if print_every and epoch % print_every == 0:
            print(f"Epoch {epoch}: CE loss = {avg_loss:.4f}")

    return policy, history
