import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import math


class NewGELU(nn.Module):
    """
    GELU activation using the tanh approximation, as in Google's BERT repo and OpenAI GPT.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))`` elementwise.
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

class CausalSelfAttention(nn.Module):
    """
    Multi-head masked (causal) self-attention followed by an output projection.

    Written out explicitly rather than via ``torch.nn.MultiheadAttention`` so the
    computation is easy to follow: position ``i`` may only attend to positions ``j <= i``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        embd = config.n_embd
        size = config.block_size
        # a single Linear produces queries, keys and values for every head at once
        self.c_attn = nn.Linear(embd, 3 * embd)
        # maps the concatenated head outputs back to n_embd
        self.c_proj = nn.Linear(embd, embd)
        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # lower-triangular ones: entry [i, j] is 1 iff j <= i (i is allowed to see j).
        # The buffer keeps the historical name "bias" so existing state_dicts still load.
        lower = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", lower.view(1, 1, size, size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        # x: (batch, seq_len, n_embd)
        B, T, C = x.size()
        head_dim = C // self.n_head

        # (B, T, 3C) -> three (B, T, C) chunks -> each reshaped to (B, nh, T, hs)
        q, k, v = (
            chunk.view(B, T, self.n_head, head_dim).transpose(1, 2)
            for chunk in self.c_attn(x).split(self.n_embd, dim=2)
        )

        # scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale
        # hide future positions; the diagonal is always visible, so no row is all -inf
        future = self.bias[:, :, :T, :T] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then heads side by side
        heads = weights @ v
        merged = heads.transpose(1, 2).contiguous().view(B, T, C)

        return self.resid_dropout(self.c_proj(merged))

class Block(nn.Module):
    """
    An unassuming pre-LayerNorm Transformer block.

    ``forward`` applies causal self-attention and then a 4x-wide MLP. Each
    sub-layer sees a LayerNorm-ed copy of its input, and its output is added
    back through a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))

    def mlpf(self, x):
        """
        MLP forward: expand to 4 * n_embd, GELU, project back to n_embd, dropout.

        This is a method rather than a lambda stored on the instance, for two reasons:
        - Such a lambda cannot be pickled.
        - ``copy.deepcopy`` copies functions by reference, so a deep-copied block
          would keep calling the ORIGINAL block's MLP. As a method, ``mlpf`` always
          uses this instance's own ``self.mlp``.
        """
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x):
        # x: (batch, seq_len, n_embd); the output has the same shape
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x

class GPT_for_DT(nn.Module):
    """
    GPT-style causal Transformer backbone for a Decision Transformer.

    ``forward`` does four things:
    - Interleaves pre-computed return-to-go, state and action embeddings into one
      token sequence ``[R_1, s_1, a_1, R_2, s_2, a_2, ...]``.
    - Adds a learned embedding of the window's starting timestep and a learned
      embedding of each token's position inside the context window.
    - Runs the stack of Transformer blocks.
    - Returns the logits read out at the state-token positions.
    """

    def __init__(self, config):
        super().__init__()

        self.block_size = config.block_size
        self.n_embd = config.n_embd

        # build modules
        # indexed by the window's starting timestep, so it must be < config.max_timestep
        self.global_timestep_encoding = nn.Embedding(config.max_timestep, config.n_embd)
        # indexed by token position inside the window, so it must be < config.block_size
        self.context_position_encoding = nn.Embedding(config.block_size, config.n_embd)
        self.dropout = nn.Dropout(config.embd_pdrop)
        self.block_loop = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.norm = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # initialize weights
        self.apply(self._init_weights)

    # see karpathy/minGPT for the weight initialization used in OpenAI GPT
    def _init_weights(self, module):
        """Set Linear/Embedding weights from N(0, 0.02), zero Linear biases, and reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, rtgs_emb, states_emb, actions_emb, timesteps):
        """
        Args:
            rtgs_emb:    (batch_size, step_size, n_embd)
            states_emb:  (batch_size, step_size, n_embd)
            actions_emb: (batch_size, step_size, n_embd), or None. When None, the
                         action slots of the token sequence are left as zeros.
            timesteps:   (batch_size, step_size) integer tensor. Only the first column
                         (the window's starting timestep) is used.

        Returns:
            Logits at the state-token positions, shape (batch_size, step_size, vocab_size).

        Raises:
            ValueError: if the 3 * step_size tokens do not fit in ``block_size``.
        """
        batch_size = states_emb.shape[0]
        actual_step_size = states_emb.shape[1]
        seq_len = actual_step_size * 3
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence of {seq_len} tokens (3 * {actual_step_size} steps) "
                f"exceeds block_size={self.block_size}"
            )

        # Build the interleaved token sequence [R, s, a, R, s, a, ...].
        # Use the inputs' (promoted) dtype instead of a hard-coded float32 so that
        # half/bfloat16 models do not end up feeding fp32 activations into fp16 layers.
        dtype = torch.promote_types(rtgs_emb.dtype, states_emb.dtype)
        if actions_emb is not None:
            dtype = torch.promote_types(dtype, actions_emb.dtype)
        token_emb = torch.zeros(
            (batch_size, seq_len, self.n_embd),
            dtype=dtype,
            device=states_emb.device)
        token_emb[:, 0::3, :] = rtgs_emb
        token_emb[:, 1::3, :] = states_emb
        if actions_emb is not None:
            token_emb[:, 2::3, :] = actions_emb

        # Position encoding = embedding of the window's starting timestep (shared
        # by every token of a sample) + embedding of the token's index in the window.
        # Both terms broadcast to (batch_size, seq_len, n_embd).
        pos_global = self.global_timestep_encoding(timesteps[:, :1])      # (batch_size, 1, n_embd)
        context_position = torch.arange(seq_len, device=states_emb.device)
        pos_relative = self.context_position_encoding(context_position)   # (seq_len, n_embd)
        pos_emb = pos_global + pos_relative

        x = self.dropout(token_emb + pos_emb)

        # Apply multi-layered MHA (multi-head attentions)
        for block in self.block_loop:
            x = block(x)

        x = self.norm(x)

        # Project to action logits and keep only the predictions made at state tokens
        logits = self.lm_head(x)
        logits = logits[:, 1::3, :]

        return logits


class Embeddings_for_Atari(nn.Module):
    """
    Embed return-to-go, state and action inputs into n_embd-dimensional vectors.

    - States are reshaped to (4, 10, 10) stacks and passed through a small CNN.
      Its two unpadded convolutions shrink 10x10 -> 8x8 -> 7x7, so the
      flattened feature size is 16 * 7 * 7 = 784.
    - Actions are looked up in a table of ``config.vocab_size`` entries.
    - Return-to-go scalars go through a Linear layer.

    Every branch ends in tanh.
    """

    def __init__(self, config):
        super().__init__()

        # NOTE: the layer positions inside each Sequential define the state_dict
        # keys (e.g. "state_embedding.5.weight"), so keep their order stable.
        self.state_embedding = nn.Sequential(
            nn.Conv2d(4, 8, 3, padding=0),    # (4, 10, 10) -> (8, 8, 8)
            nn.ReLU(),
            nn.Conv2d(8, 16, 2, padding=0),   # (8, 8, 8) -> (16, 7, 7)
            nn.ReLU(),
            nn.Flatten(),                     # -> 784 features
            nn.Linear(784, config.n_embd),
            nn.Tanh(),
        )
        self.action_embedding = nn.Sequential(
            nn.Embedding(config.vocab_size, config.n_embd),
            nn.Tanh(),
        )
        self.rtg_embedding = nn.Sequential(
            nn.Linear(1, config.n_embd),
            nn.Tanh(),
        )

        # initialize weights
        self.apply(self._init_weights)

    # see karpathy/minGPT for the weight initialization used in OpenAI GPT
    def _init_weights(self, module):
        """Set Linear/Embedding weights from N(0, 0.02), zero Linear biases, and reset LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, rtgs, states, actions):
        """
        Args:
            rtgs:    return-to-go values, (batch_size, step_size, 1).
            states:  observations. They are reshaped to (-1, 4, 10, 10), so each step
                     must hold 4 * 10 * 10 = 400 values,
                     e.g. (batch_size, step_size, 4, 10, 10).
            actions: integer action ids, (batch_size, step_size), or None.

        Returns:
            (rtgs_emb, states_emb, actions_emb), each (batch_size, step_size, n_embd).
            actions_emb is None when actions is None.
        """
        batch_size, step_size = states.shape[0], states.shape[1]

        rtgs_emb = self.rtg_embedding(rtgs)

        # run the CNN over all (batch, step) frames at once, then restore the step axis
        frames = states.reshape(-1, 4, 10, 10)
        features = self.state_embedding(frames)          # (batch_size * step_size, n_embd)
        states_emb = features.reshape(batch_size, step_size, features.shape[1])

        actions_emb = None if actions is None else self.action_embedding(actions)

        return rtgs_emb, states_emb, actions_emb


# class CfgNode:
#     step_size = 50
#     n_head = 8
#     n_layer = 6
#     n_embd = 128  # each head has n_embd / n_head
#     attn_pdrop = 0.1
#     resid_pdrop = 0.1
#     embd_pdrop = 0.1
#     block_size = step_size * 3
#     # max_timestep = max_timesteps
#     max_timestep = 200
#     # vocab_size = len(action_dict)  # all actions
#     vocab_size = 5

class CfgNode:
    """
    Hyper-parameters for the Decision Transformer model.

    Architecture and regularization settings are shared class attributes. The
    values that depend on the context length are set per instance.
    """
    n_head = 8
    n_layer = 6
    n_embd = 128  # each head has n_embd / n_head
    attn_pdrop = 0.1
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    vocab_size = 5  # number of discrete actions (action embedding / output logits size)

    def __init__(self, step_size=50, max_timesteps=176):
        # every step contributes three tokens (return-to-go, state, action);
        # max_timestep sizes the timestep embedding table, so timesteps must be below it
        self.block_size, self.max_timestep, self.step_size = (
            step_size * 3,
            max_timesteps,
            step_size,
        )


