import torch
from torch import nn
import torch.nn.functional as F

# ----------------- VOCAB -----------------
# Token alphabet shared by the environment and the policies in this module.
# Index 0 is the end-of-sequence symbol; the remaining 19 entries are
# single-letter residue codes, for 20 tokens in total.
# NOTE(review): the alphabet contains no 'C' -- confirm the omission is intended.
EOS_TOKEN = "*"  # end-of-sequence symbol ("used in RF" per the original note)
STOP_IDX = 0  # integer id of the STOP/EOS token (position of EOS_TOKEN below)
AA_VOCAB = [EOS_TOKEN, *"ADEFGHIKLMNPQRSTVWY"]  # |V| = 20
VOCAB_SIZE = len(AA_VOCAB)


def sample_from_mixed_logits(logits, eps=0.0):
    """Draw one action per row from an epsilon-smoothed softmax policy.

    The sampling distribution is
    ``p_mix = (1 - eps) * softmax(logits) + eps * (1 / A)``,
    where ``A`` is the size of the last dimension of ``logits``.

    Args:
        logits: Unnormalised scores; the last dimension indexes actions.
        eps: Weight of the uniform component.

    Returns:
        Tensor of sampled action indices with the shape of ``logits`` minus
        its last dimension.
    """
    num_actions = logits.size(-1)
    policy_probs = F.softmax(logits, dim=-1)
    # The uniform share is added to every action, including ones whose logit
    # was pushed very low by a mask.
    mixed_probs = (1 - eps) * policy_probs + eps / num_actions
    mixture = torch.distributions.Categorical(probs=mixed_probs)
    return mixture.sample()


class Sequences:
    """Batched container for partially built token sequences.

    Holds a ``[batch_size, seq_size]`` integer ``state`` tensor initialised to
    zeros and a boolean ``alive`` mask initialised to ``True``, and samples
    actions from policy logits with optional epsilon-uniform exploration.

    Args:
        seq_size: Number of token slots per sequence.
        batch_size: Number of sequences held at once.
        log_reward: Callable applied to ``state`` by :meth:`log_reward`.
        eps: Exploration weight used by :meth:`get_actions` when training.
        device: Device on which ``state``, ``alive`` and the private
            generator are allocated.
        seed: Optional seed; when given, action sampling is made reproducible.
    """

    def __init__(self, seq_size, batch_size, log_reward, eps, device='cpu', seed=None):
        self.seq_size = int(seq_size)
        self.batch_size = int(batch_size)
        self.device = device
        self.seed = seed
        self.eps = eps
        self._log_reward = log_reward

        # Always define the attribute so unseeded instances can be inspected
        # without an AttributeError.
        self.g = None
        if seed is not None:
            self.g = torch.Generator(device=self.device)
            self.g.manual_seed(seed)

        # Single allocation path shared with reset().
        self.reset()

    def log_reward(self):
        """Return the log-reward callable evaluated on the current ``state``."""
        return self._log_reward(self.state)

    def reset(self, batch_size=None):
        """Re-create ``state`` (all zeros) and ``alive`` (all True).

        Args:
            batch_size: New batch size; ``None`` keeps the current one.
        """
        if batch_size is not None:
            # Cast like __init__ does so both entry points agree.
            self.batch_size = int(batch_size)
        self.state = torch.zeros(
            (self.batch_size, self.seq_size), dtype=torch.long, device=self.device
        )
        self.alive = torch.ones(self.batch_size, dtype=torch.bool, device=self.device)

    def get_actions(self, logits, training=True):
        """Sample one action per row of ``logits``.

        When a seed was supplied, a fresh seed is drawn from the private
        generator and installed with ``torch.manual_seed`` before sampling.
        Note that this reseeds the *global* torch RNG as a side effect.

        Args:
            logits: Unnormalised action scores; last dimension indexes actions.
            training: If False, exploration is disabled (``eps`` treated as 0).

        Returns:
            Tensor of sampled action indices.
        """
        if self.seed is not None and self.g is not None:
            # The output must live on the generator's device; without the
            # explicit device a non-CPU generator is rejected by randint.
            new_seed = torch.randint(
                0, 2 ** 31, (1,), generator=self.g, device=self.g.device
            ).item()
            torch.manual_seed(new_seed)
        eps = self.eps if training else 0.0
        if eps <= 0.0:
            return torch.distributions.Categorical(logits=logits).sample()
        return sample_from_mixed_logits(logits, eps=eps)


class Policy(nn.Module):
    """Feed-forward next-token policy over the most recent tokens.

    For each sequence the last ``window`` tokens of the filled prefix are
    embedded, concatenated with a sinusoidal encoding of the prefix length,
    and passed through one hidden ReLU layer to produce logits over the
    vocabulary.

    Args:
        vocab_size: Number of tokens (size of the output logits).
        pad_id: Token id treated as padding; the original comment notes that
            PAD and EOS share id 0.
        emb_dim: Embedding dimension per token.
        window: Number of trailing tokens fed to the network (W).
        pos_dim: Dimension of the sinusoidal positional encoding; must be even.
        hidden: Width of the single hidden layer (the original comment cites
            128 as the value used in "the paper").
        force_stop_on_full: If True, rows whose prefix already fills the whole
            sequence get every logit except column 0 set to ``-1e9``.
    """

    def __init__(
            self,
            vocab_size: int = VOCAB_SIZE,  # V=20
            pad_id: int = 0,  # PAD/EOS = 0
            emb_dim: int = 64,  # embedding dimension
            window: int = 6,  # W (window size)
            pos_dim: int = 16,  # positional-encoding dimension (even)
            hidden: int = 128,  # one hidden layer of 128 units
            force_stop_on_full: bool = True  # force STOP once seq_size is reached
    ):
        super().__init__()
        assert pos_dim % 2 == 0, "pos_dim deve ser par (PE sinusoidal)"
        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.emb_dim = emb_dim
        self.window = window
        self.pos_dim = pos_dim
        self.hidden = hidden
        self.force_stop_on_full = force_stop_on_full
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_id)
        self.fc1 = nn.Linear(window * emb_dim + pos_dim, hidden)
        self.fc2 = nn.Linear(hidden, vocab_size)

    @staticmethod
    def sinusoidal_pe(t: torch.Tensor, d: int) -> torch.Tensor:
        """Return a ``[B, d]`` sinusoidal encoding of the positions ``t``.

        Even columns hold ``sin(t / 10000^(2i/d))`` and odd columns the
        matching ``cos``, for ``i = 0 .. d/2 - 1``.

        Args:
            t: 1-D tensor of ``B`` positions.
            d: Encoding dimension (expected to be even).
        """
        B = t.size(0)
        device = t.device
        t = t.float().unsqueeze(1)  # [B, 1]
        i = torch.arange(0, d // 2, device=device).float()  # [d/2]
        div = torch.pow(10000.0, (2.0 * i) / d)  # [d/2]
        angle = t / div.unsqueeze(0)  # [B, d/2]
        pe = torch.zeros(B, d, device=device)
        pe[:, 0::2] = torch.sin(angle)
        pe[:, 1::2] = torch.cos(angle)
        return pe

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits for a batch of partial sequences.

        Args:
            states: ``[B, L]`` tensor of token ids. The prefix length of each
                row is taken to be its number of non-pad entries, so the
                filled part is assumed to be contiguous from column 0.

        Returns:
            ``[B, vocab_size]`` logits.
        """
        B, L = states.shape
        device = states.device
        states = states.long()

        lengths = (states != self.pad_id).sum(dim=1)  # [B]

        # Gather the last W tokens of each prefix from a left-padded copy.
        # Positions before the start of the sequence read pad_id instead of
        # repeating the first token, which is what clamping the index to 0
        # did for prefixes shorter than the window. Results are identical
        # whenever length >= window, and L == 0 no longer breaks the gather.
        left_pad = states.new_full((B, self.window), self.pad_id)
        padded = torch.cat([left_pad, states], dim=1)  # [B, W + L]
        offsets = torch.arange(self.window, device=device)  # [0, ..., W-1]
        idx = lengths.unsqueeze(1) + offsets  # [B, W], always within [0, W + L)
        win = padded.gather(1, idx)  # [B, W]

        E = self.emb(win)  # [B, W, D]
        feat_seq = E.reshape(B, self.window * self.emb_dim)  # [B, W*D]

        pe = self.sinusoidal_pe(lengths, self.pos_dim)  # [B, pos_dim]
        x = torch.cat([feat_seq, pe], dim=1)  # [B, W*D + pos_dim]
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)  # [B, V]

        if self.force_stop_on_full and L > 0:
            full = (lengths >= L)
            if full.any():
                # Clone so the masking does not write into fc2's output in place.
                logits = logits.clone()
                # Column 0 is the only action left for full rows.
                # NOTE(review): this hard-codes STOP at index 0 regardless of pad_id.
                logits[full, 1:] = -1e9
        return logits


class DummyPolicy(nn.Module):

    def __init__(self, vocab_size: int, mode: str = "randn",
                 scale: float = 1.0, seed: int | None = None,
                 learnable_dummy: bool = False, device: str = "cpu"):
        super().__init__()
        self.vocab_size = int(vocab_size)
        self.mode = mode
        self.scale = float(scale)
        self.device = torch.device(device)

        if learnable_dummy:
            self.dummy = nn.Parameter(torch.zeros((), device=self.device))
        else:
            self.register_buffer("_dummy_buf", torch.zeros((), device=self.device))

        self.gen = None
        if seed is not None:
            self.gen = torch.Generator(device=self.device)
            self.gen.manual_seed(int(seed))

        self.register_buffer("_fixed_logits", torch.zeros(1, self.vocab_size, device=self.device))

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        B = states.shape[0]
        dev = states.device

        if self.mode == "zeros":
            # Softmax -> uniforme. Nada de RNG, nada de alocação pesada.
            return torch.zeros(B, self.vocab_size, device=dev)

        elif self.mode == "fixed":
            # Reutiliza o mesmo buffer, expandindo para [B, V]
            if self._fixed_logits.device != dev:
                self._fixed_logits = self._fixed_logits.to(dev)
            return self._fixed_logits.expand(B, -1)

        elif self.mode == "randn":
            # Logits gaussianos i.i.d.; custo baixo e ainda "aleatório"
            if self.gen is None:
                return self.scale * torch.randn(B, self.vocab_size, device=dev)
            else:
                return self.scale * torch.randn(B, self.vocab_size, device=dev, generator=self.gen)

        else:
            raise ValueError(f"Modo inválido: {self.mode}. Use 'zeros', 'fixed' ou 'randn'.")


if __name__ == "__main__":
    # Smoke test: run one trajectory through the policy, backpropagate the
    # mean log-probability, and print the gradient of the first parameter.
    # torch is already imported at the top of the file, so it is not
    # re-imported here.
    # NOTE: the relative import only resolves when this module is executed as
    # part of its package (e.g. ``python -m <package>.<module>``).
    from .peptide_sampling import forward_trajectory_log_prob

    seq_size = 5
    batch_size = 10

    env = Sequences(seq_size, batch_size, None, 0)
    # Policy's first parameter is the vocabulary size, not the sequence
    # length, so pass the module's vocabulary size explicitly.
    net = Policy(vocab_size=VOCAB_SIZE)

    opt = torch.optim.AdamW([{"params": net.parameters(), "lr": 0.1}])
    opt.zero_grad()
    logp = forward_trajectory_log_prob(env, net)
    logp.mean().backward()
    for p in net.parameters():
        # Only the first parameter's gradient is shown.
        print(p.grad)
        break


