from math import sqrt

import torch
import torch.nn as nn
import transformers
from torch import Tensor
from torch.distributions import Categorical
from torch.nn import init

from args import PPOConfig

# Seed through transformers with a fixed value of 0 at import time; this runs
# before the GPT-2 classes below are imported.
transformers.set_seed(0)
from transformers import GPT2Config, GPT2Model

# Module-wide default device: CUDA when torch reports it available, else CPU.
# The model helpers below allocate query rows on it and move contexts to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Transformer(nn.Module):
    """GPT-2 backbone mapping a sequence of transition rows to action predictions.

    Every input row has width ``2 * state_dim + action_dim + 1``. Rows are
    linearly embedded to ``n_embd``, run through a ``GPT2Model`` and projected
    to ``action_dim`` outputs per sequence position.

    ``config`` must provide the keys ``H``, ``n_embd``, ``n_layer``, ``n_head``,
    ``state_dim``, ``action_dim`` and ``dropout``. ``test`` is optional and
    defaults to ``False``.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.config = config
        self.test = config.get("test", False)
        self.horizon = config["H"]
        self.n_embd = config["n_embd"]
        self.n_layer = config["n_layer"]
        self.n_head = config["n_head"]
        self.state_dim = config["state_dim"]
        self.action_dim = config["action_dim"]
        self.dropout = config["dropout"]

        transition_width = 2 * self.state_dim + self.action_dim + 1

        # The same dropout probability is used for residual, embedding and
        # attention dropout.
        # NOTE(review): ``config["n_head"]` is stored on the instance, but the
        # backbone is always built with a single attention head - confirm
        # this is intended.
        gpt2_config = GPT2Config(
            n_positions=4 * (self.horizon + 1),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt2_config)

        self.embed_transition = nn.Linear(transition_width, self.n_embd)
        self.pred_actions = nn.Linear(self.n_embd, self.action_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x``, run the backbone and project to action predictions.

        Returns only the last sequence position when ``self.test`` is truthy;
        otherwise returns every position except the first.
        """
        # presumably x is (batch, seq, 2 * state_dim + action_dim + 1) - TODO confirm
        embedded = self.embed_transition(x)
        hidden = self.transformer(inputs_embeds=embedded)["last_hidden_state"]
        action_preds = self.pred_actions(hidden)

        return action_preds[:, -1, :] if self.test else action_preds[:, 1:, :]

    def __call__(self, *args, **kwds) -> Tensor:
        """Delegate to ``nn.Module.__call__``; overridden only to narrow the return type."""
        return super().__call__(*args, **kwds)

    def make_query_line(self, states: Tensor, batch_size: int) -> Tensor:
        """Build a zero row per batch item with ``states`` in its leading entries.

        The result has shape ``(batch_size, 1, 2 * state_dim + action_dim + 1)``
        and is allocated on the module-level ``device``.
        """
        width = self.state_dim * 2 + self.action_dim + 1
        row = torch.zeros((batch_size, 1, width), device=device)
        row[:, 0, : states.shape[-1]] = states
        return row

    def predict_actions(self, context: Tensor, query_line: Tensor) -> Tensor:
        """Prepend ``query_line`` to ``context`` along dim -2 and run the model.

        ``context`` is moved to the module-level ``device`` first, and the
        concatenated sequence is converted to ``float32``.
        """
        sequence = torch.cat((query_line, context.to(device)), dim=-2)
        return self(sequence.type(torch.float32))


class EnvLinear(nn.Module):
    """A bank of ``n_envs`` independent linear layers applied in one batched matmul.

    Parameters:
        n_envs: number of independent layers (size of the leading batch axis).
        input_dim: number of input features of every layer.
        output_dim: number of output features of every layer.

    Attributes:
        weights: parameter of shape ``(n_envs, input_dim, output_dim)``.
        biases: parameter of shape ``(n_envs, 1, output_dim)``.
    """

    def __init__(self, n_envs: int, input_dim: int, output_dim: int):
        super().__init__()
        self.weights = nn.Parameter(torch.empty(n_envs, input_dim, output_dim))
        self.biases = nn.Parameter(torch.empty(n_envs, 1, output_dim))

        # Draw weights and biases from U(-1/sqrt(input_dim), 1/sqrt(input_dim)),
        # i.e. what ``nn.Linear`` yields via ``kaiming_uniform_(a=sqrt(5))`` on
        # a 2-D weight. Calling ``kaiming_uniform_`` on the 3-D bank would take
        # fan-in as ``input_dim * output_dim`` (dim 1 times all trailing dims)
        # and shrink the weights by a factor ``sqrt(output_dim)``, so the bound
        # is applied explicitly instead.
        bound = 1 / sqrt(input_dim) if input_dim > 0 else 0
        init.uniform_(self.weights, -bound, bound)
        init.uniform_(self.biases, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        """Apply layer ``i`` to ``x[i]`` for every environment ``i``.

        ``x`` must be 3-D with shape ``(n_envs, batch_size, input_dim)`` as
        required by ``torch.bmm``; the result has shape
        ``(n_envs, batch_size, output_dim)``.
        """
        # The (n_envs, 1, output_dim) bias broadcasts over the batch axis.
        return torch.bmm(x, self.weights) + self.biases


class ValueAndPolicyNetwork(nn.Module):
    """Shared MLP trunk followed by per-environment policy and value heads.

    The trunk is ``config.n_layer`` repetitions of ``Linear -> LeakyReLU`` with
    ``config.n_hidden`` units. Both heads are ``EnvLinear`` banks with one
    layer per environment: the policy head outputs ``action_dim`` logits and
    the value head a single scalar.
    """

    config: PPOConfig
    n_envs: int
    state_dim: int
    action_dim: int

    layers: nn.Sequential
    policy: EnvLinear
    value: EnvLinear

    def __init__(self, config: PPOConfig, n_envs: int, state_dim: int, action_dim: int):
        super().__init__()

        self.config = config
        self.n_envs = n_envs
        self.state_dim = state_dim
        self.action_dim = action_dim

        depth = self.config.n_layer
        hidden = self.config.n_hidden

        # Only the first trunk layer consumes the raw state; later ones map
        # hidden -> hidden.
        trunk: list[nn.Module] = []
        in_features = self.state_dim
        for _ in range(depth):
            trunk.append(nn.Linear(in_features, hidden))
            trunk.append(nn.LeakyReLU())
            in_features = hidden
        self.layers = nn.Sequential(*trunk)

        self.policy = EnvLinear(n_envs, hidden, self.action_dim)
        self.value = EnvLinear(n_envs, hidden, 1)

    def forward(self, x: Tensor) -> tuple[Categorical, Tensor]:
        """Return the action distribution and the value estimate for ``x``.

        The distribution is a ``Categorical`` built from the policy head's raw
        logits; the value head's trailing singleton axis is squeezed away.
        """
        # presumably x is (n_envs, batch_size, state_dim), which is what the
        # batched heads require of the trunk output - TODO confirm
        features: Tensor = self.layers(x)
        logits: Tensor = self.policy(features)
        pi = Categorical(logits=logits)
        value: Tensor = self.value(features).squeeze(-1)

        return pi, value

    def __call__(self, *args, **kwds) -> tuple[Categorical, Tensor]:
        """Delegate to ``nn.Module.__call__``; overridden only to narrow the return type."""
        return super().__call__(*args, **kwds)


class ImageTransformer(Transformer):
    """Transformer class for image-based data.

    Each sequence position carries an image plus a vector of width
    ``state_dim + action_dim + 1``. Images go through a small convolutional
    encoder producing ``im_embd`` features, which are concatenated with the
    vector before being embedded for the GPT-2 backbone.

    In addition to the keys read by ``Transformer``, ``config`` must provide
    ``image_size``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.im_embd = 8

        # Side length of the feature map after the two un-padded 3x3
        # convolutions below (stride 2, then stride 1).
        size = self.config["image_size"]
        size = (size - 3) // 2 + 1
        size = (size - 3) // 1 + 1

        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Flatten(start_dim=1),
            nn.Linear(int(16 * size * size), self.im_embd),
            nn.ReLU(),
        )

        # Replaces the embedding created by the base class: the input row is
        # now image features + state + action + reward.
        new_dim = self.im_embd + self.state_dim + self.action_dim + 1
        self.embed_transition = torch.nn.Linear(new_dim, self.n_embd)
        self.embed_ln = nn.LayerNorm(self.n_embd)

    def forward(self, batch: tuple[Tensor, Tensor]):
        """Encode the images, fuse them with ``x`` and predict actions.

        ``batch`` is ``(x_images, x)``. Returns only the last sequence position
        when ``self.test`` is truthy; otherwise every position except the first.
        """
        x_images, x = batch
        batch_size = x.shape[0]

        # Merge the first two axes so the encoder sees one image per row, then
        # restore the (batch, seq, im_embd) layout.
        # presumably x_images is (batch, seq, 3, image_size, image_size) - TODO confirm
        x_images = x_images.view(-1, *x_images.size()[2:])
        image_enc_seq = self.image_encoder(x_images)
        image_enc_seq = image_enc_seq.view(batch_size, -1, self.im_embd)

        stacked_inputs = torch.cat([image_enc_seq, x], dim=2)
        stacked_inputs = self.embed_transition(stacked_inputs)
        stacked_inputs = self.embed_ln(stacked_inputs)

        transformer_outputs = self.transformer(inputs_embeds=stacked_inputs)
        preds = self.pred_actions(transformer_outputs["last_hidden_state"])

        if self.test:
            return preds[:, -1, :]
        return preds[:, 1:, :]

    def make_input(self, batch) -> tuple[Tensor, Tensor]:
        """Assemble the ``(image_seq, x)`` pair consumed by ``forward``.

        ``batch`` is a mapping with the keys ``query_images``,
        ``context_images``, ``context_states``, ``context_actions``,
        ``context_rewards`` and ``query_states``. The query image and a
        zero-padded query row are placed in front of the context along dim 1.
        """
        query_images = batch["query_images"][:, None, :]

        batch_size = query_images.shape[0]

        c_images = batch["context_images"]
        c_states = batch["context_states"]
        c_actions = batch["context_actions"]
        c_rewards = batch["context_rewards"]
        # Rewards given as 2-D (one scalar per step) need a trailing feature
        # axis to be concatenated with the 3-D states and actions on dim 2.
        # (Indexing a 1-D tensor with ``[:, :, None]`` is an IndexError, so the
        # rank to expand is 2.)
        if len(c_rewards.shape) == 2:
            c_rewards = c_rewards[:, :, None]
        query_states = batch["query_states"][:, None, :]

        image_seq = torch.cat([query_images, c_images], dim=1)

        ctx = torch.cat((c_states, c_actions, c_rewards), dim=2)
        # Query row: the query state in the leading entries, zeros elsewhere.
        query_line = torch.zeros((batch_size, 1, ctx.shape[-1]), device=device)
        query_line[:, :, : query_states.shape[-1]] = query_states
        x = torch.cat((query_line, ctx), dim=1)

        return image_seq, x

    def predict_actions(self, batch: tuple[Tensor, Tensor], query: tuple[Tensor, Tensor]) -> Tensor:
        """Prepend the query image and query row to the context and run the model.

        ``batch`` is ``(context_images, context)`` and ``query`` is
        ``(query_images, query_line)``. ``context`` is moved to the
        module-level ``device`` and the row sequence is converted to ``float32``.
        """
        c_images, context = batch
        query_images, query_line = query

        # The query image gets a sequence axis before being put first.
        x_images = torch.cat([query_images[:, None], c_images], dim=1)

        x = torch.cat((query_line, context.to(device)), dim=-2).type(torch.float32)

        return self((x_images, x))

    def make_query_line(self, states: Tensor, batch_size: int) -> Tensor:
        """Build a zero row per batch item with ``states`` in its leading entries.

        The result has shape ``(batch_size, 1, state_dim + action_dim + 1)``
        and is allocated on the module-level ``device``.
        """
        query_line = torch.zeros((batch_size, 1, self.state_dim + self.action_dim + 1), device=device)
        query_line[:, 0, : states.shape[-1]] = states
        return query_line
