import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from torch import Tensor

# NOTE: runs at import time, so merely importing this module calls
# transformers.set_seed with a fixed seed of 0.
transformers.set_seed(0)
from transformers import GPT2Config, GPT2Model

# Module-level default device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class AttackerTransformer(nn.Module):
    """GPT-2 backbone mapping a sequence of feature vectors to (means, stds).

    Every input position is a vector of ``2 * state_dim + action_dim + 1``
    features. It is linearly embedded to ``n_embd``, run through a
    ``GPT2Model`` and projected to ``2 * action_dim`` values, which are split
    into a "means" half and a "stds" half.

    NOTE(review): the feature layout is presumably
    ``[state, action, next_state, reward]`` -- confirm against the data
    pipeline; nothing in this class enforces it.

    Config keys read in ``__init__``:
        H: horizon; the backbone gets ``4 * (1 + H)`` positions.
        n_embd: embedding / hidden width.
        n_layer: number of transformer layers.
        n_head: stored on ``self.n_head`` (see the note in ``__init__``).
        state_dim, action_dim: feature sizes used for the input/output heads.
        dropout: used for residual, embedding and attention dropout.
        test (optional, default ``False``): when true, ``forward`` returns
            only the last sequence position.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.config = config
        self.test = self.config.get("test", False)
        self.horizon = self.config["H"]
        self.n_embd = self.config["n_embd"]
        self.n_layer = self.config["n_layer"]
        self.n_head = self.config["n_head"]
        self.state_dim = self.config["state_dim"]
        self.action_dim = self.config["action_dim"]
        self.dropout = self.config["dropout"]

        # NOTE(review): the head count is hard-coded to 1; config["n_head"] is
        # stored on self.n_head but never forwarded to the backbone. Left
        # untouched so model outputs do not change -- confirm this is intended.
        model_config = GPT2Config(
            n_positions=4 * (1 + self.horizon),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(model_config)

        # Input projection: one feature vector per position -> n_embd.
        self.embed_transition = nn.Linear(2 * self.state_dim + self.action_dim + 1, self.n_embd)
        # Output head: two values (mean, raw std) per action dimension.
        self.pred_actions = nn.Linear(
            self.n_embd,
            self.action_dim * 2,
        )

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Run the backbone over ``x`` and split the head output in two.

        Args:
            x: tensor indexed as ``(batch, seq, 2 * state_dim + action_dim + 1)``.

        Returns:
            ``(means, stds)``. In test mode each holds only the last sequence
            position, shape ``(batch, action_dim)``; otherwise position 0 is
            dropped, shape ``(batch, seq - 1, action_dim)``. ``stds`` is the
            raw head output -- no positivity transform is applied here.
        """
        embedded = self.embed_transition(x)
        transformer_outputs = self.transformer(inputs_embeds=embedded)
        preds: Tensor = self.pred_actions(transformer_outputs["last_hidden_state"])
        # Last axis (2 * action_dim) -> (2, action_dim), then split on the 2-axis.
        means, stds = preds.reshape(x.size(0), x.size(1), 2, self.action_dim).unbind(dim=-2)

        if self.test:
            return means[:, -1, :], stds[:, -1, :]
        return means[:, 1:, :], stds[:, 1:, :]

    def __call__(self, *args, **kwds) -> tuple[Tensor, Tensor]:
        """Delegate to ``nn.Module.__call__``; overridden only to narrow the return type."""
        return super().__call__(*args, **kwds)

    def make_query_line(self, states: Tensor, batch_size: int) -> Tensor:
        """Build a zero-padded single-position input holding ``states``.

        Args:
            states: values written into the leading ``states.shape[-1]``
                features; must be assignable to a ``(batch_size, k)`` slice.
            batch_size: number of rows in the returned tensor.

        Returns:
            Tensor of shape ``(batch_size, 1, 2 * state_dim + action_dim + 1)``
            with ``states`` in the leading features and zeros elsewhere.
        """
        feature_dim = self.state_dim * 2 + self.action_dim + 1
        # Allocate on the device of the layer that will consume this tensor
        # rather than a module-level global, so the result follows the model
        # when it is moved with ``.to(...)``.
        target_device = self.embed_transition.weight.device
        query_line = torch.zeros((batch_size, 1, feature_dim), device=target_device)
        query_line[:, 0, : states.shape[-1]] = states
        return query_line

    def predict_rewards(self, context: Tensor, query_line: Tensor) -> tuple[Tensor, Tensor]:
        """Prepend ``query_line`` to ``context`` and run the model.

        Args:
            context: tensor of feature vectors, concatenated after the query
                along the sequence axis (``dim=-2``).
            query_line: tensor placed at sequence position 0, e.g. the output
                of ``make_query_line``.

        Returns:
            ``(means, stds)`` from ``forward``, with softplus applied to
            ``stds`` so they are positive.

        NOTE(review): despite the method name, the outputs are sized by
        ``action_dim`` -- confirm what quantity callers expect.
        """
        x = torch.cat((query_line, context), dim=-2)
        means, stds = self(x)
        return means, F.softplus(stds)
