import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """Layer normalization whose additive bias can be switched off.

    Keeps its own scale parameter (initialised to ones) and, when ``bias`` is
    truthy, a shift parameter (initialised to zeros); otherwise ``self.bias``
    is ``None``. Normalization is delegated to ``F.layer_norm`` with eps 1e-5.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on ``scaled_dot_product_attention``.

    ``forward`` processes a whole sequence at once. ``forward_kvc`` processes a
    chunk of new tokens, writing their keys/values into a caller-provided cache
    and attending over everything cached so far plus the new chunk.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization; attention-probability dropout is applied inside
        # scaled_dot_product_attention via dropout_p, so attn_dropout is not
        # called by either forward path.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

    def forward(self, x):
        """Causal self-attention over a full (B, T, C) sequence; returns (B, T, C)."""
        B, T, C = x.size()
        hs = C // self.n_head

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        y = nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=True,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

    def forward_kvc(self, x, kv_cache, cnt):
        """Attend from a chunk of new tokens to all cached tokens plus the chunk.

        Args:
            x: (B, TT, C) activations for TT new tokens whose absolute positions
                are ``cnt .. cnt + TT - 1``.
            kv_cache: tensor indexed as ``kv_cache[slot, batch, position, channel]``.
                Slot 0 holds keys and slot 1 holds values. Positions
                ``cnt .. cnt + TT - 1`` are overwritten in place with this
                chunk's keys/values.
            cnt: number of positions already filled in ``kv_cache``.

        Returns:
            ``(y, kv_cache)``: y is (B, TT, C); kv_cache is the same, mutated object.
        """
        B, TT, C = x.size()
        hs = C // self.n_head

        # project the new tokens and append their keys/values to the cache
        q, k_new, v_new = self.c_attn(x).split(self.n_embd, dim=2)
        kv_cache[0, :, cnt:cnt + TT] = k_new
        kv_cache[1, :, cnt:cnt + TT] = v_new
        k = kv_cache[0, :, :cnt + TT]
        v = kv_cache[1, :, :cnt + TT]
        T = k.size(1)

        k = k.view(B, T, self.n_head, hs).transpose(1, 2)   # (B, nh, T, hs)
        q = q.view(B, TT, self.n_head, hs).transpose(1, 2)  # (B, nh, TT, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)   # (B, nh, T, hs)

        attn_mask = None
        is_causal = False
        if cnt == 0:
            # No cached prefix: query and key positions coincide, so the
            # built-in (top-left aligned) causal mask is exact.
            is_causal = True
        elif TT > 1:
            # Several new queries on top of a cached prefix. Query i sits at
            # absolute position cnt + i and may only see keys 0 .. cnt + i.
            # is_causal=True would align to the top-left corner and hide most
            # of the cache. No mask would let earlier new tokens see later ones.
            q_pos = torch.arange(cnt, cnt + TT, device=x.device).unsqueeze(1)
            k_pos = torch.arange(T, device=x.device).unsqueeze(0)
            attn_mask = k_pos <= q_pos  # (TT, T) boolean, True = may attend
        # else: a single new query may attend to every cached key and itself.

        y = nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0,
            is_causal=is_causal,
        )

        y = y.transpose(1, 2).contiguous().view(B, TT, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y, kv_cache


class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # Registration order (c_fc, gelu, c_proj, dropout) fixes parameter order.
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        activated = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))


class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` then ``x + mlp(ln_2(x))``."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the block to a full sequence."""
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

    def forward_kvc(self, x, kv_cache, cnt):
        """Same residual structure as ``forward``, with attention reading/writing ``kv_cache``.

        ``kv_cache`` and ``cnt`` are passed straight to ``CausalSelfAttention.forward_kvc``.
        Returns ``(x, kv_cache)``, where kv_cache is whatever the attention layer
        returned.
        """
        attn_out, kv_cache = self.attn.forward_kvc(self.ln_1(x), kv_cache, cnt)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return x, kv_cache

@dataclass
class GPTConfig:
    """Hyperparameters consumed by GPT and its submodules (all fields required)."""

    context_length: int  # wpe is sized context_length * 3 (presumably three token slots per step)
    action_size: int  # number of rows in the action embedding table wte
    n_layer: int  # number of transformer Blocks; also scales the c_proj init std
    n_head: int  # attention heads; n_embd must be divisible by n_head
    n_embd: int  # model (residual stream) width
    input_dim: int  # width of X_enc inputs and enc_head outputs; X_enc bypasses input_encoder only when n_embd == input_dim == 512
    dropout: float  # probability used by every Dropout module and as SDPA dropout_p during training
    bias: bool  # bias in the Block Linear/LayerNorm layers (input_encoder and enc_head always have bias)

class GPT(nn.Module):
    """Decoder-only transformer over interleaved (encoding, action, reward) tokens.

    Each step contributes up to three consecutive tokens:
    - a continuous encoding vector from ``X_enc``, passed through
      ``input_encoder`` unless ``n_embd == input_dim == 512``;
    - an action id embedded by ``wte``;
    - a reward id embedded by ``wre``.

    ``forward`` reads the hidden states at positions 2, 5, 8, ... (the reward
    slots) and regresses ``Y_enc`` from them with ``enc_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.action_size is not None
        assert config.context_length is not None
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.action_size, config.n_embd),
                wpe=nn.Embedding(config.context_length * 3, config.n_embd),
                # reward ids must lie in range(6)
                wre=nn.Embedding(6, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )
        # Module creation order below is kept fixed: it determines parameter
        # order and the RNG stream consumed by the init pass.
        if self._skips_input_encoder():
            self.enc_head = nn.Linear(config.n_embd, config.n_embd)
        else:
            intermediate_dim = max(config.n_embd, config.input_dim)
            input_encoder_layers = [
                nn.Linear(config.input_dim, intermediate_dim),
                nn.GELU(),
                nn.Linear(intermediate_dim, config.n_embd),
            ]
            self.input_encoder = nn.Sequential(*input_encoder_layers)
            self.enc_head = nn.Linear(config.n_embd, config.input_dim)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def _skips_input_encoder(self):
        """True when X_enc is fed to the transformer unchanged (no input_encoder exists)."""
        return self.config.n_embd == self.config.input_dim == 512

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _embed(self, X_enc, X_action, X_reward):
        """Interleave the three token streams, add position embeddings, apply dropout.

        Returns a tensor of shape (b, tt, n_embd), where tt is the total number
        of tokens across X_enc, X_action and X_reward.
        """
        if not self._skips_input_encoder():
            X_enc = self.input_encoder(X_enc)
        b, t, n_embd = X_enc.size()
        tt = t + X_action.size(1) + X_reward.size(1)
        pos = torch.arange(0, tt, dtype=torch.long, device=X_enc.device)

        tok_emb_action = self.transformer.wte(X_action)
        tok_emb_reward = self.transformer.wre(X_reward)
        # A stream that is one step shorter than X_enc gets one zero row
        # appended so that all three streams have length t.
        if t == X_action.size(1) + 1:
            tok_emb_action = F.pad(tok_emb_action, (0, 0, 0, 1))
        if t == X_reward.size(1) + 1:
            tok_emb_reward = F.pad(tok_emb_reward, (0, 0, 0, 1))
        # Concatenating on channels and viewing as (b, 3t, n_embd) interleaves
        # rows as enc_0, act_0, rew_0, enc_1, ...
        # from: https://stackoverflow.com/questions/61026393/pytorch-concatenate-rows-in-alternate-order
        tok_emb = torch.cat([X_enc, tok_emb_action, tok_emb_reward], dim=-1).view(b, -1, n_embd)
        # Trim to the tt real tokens. This removes the trailing pad rows when
        # the incomplete step is the last one.
        tok_emb = tok_emb[:, :tt]

        pos_emb = self.transformer.wpe(pos)  # (tt, n_embd), broadcast over batch
        return self.transformer.drop(tok_emb + pos_emb)

    def compute_context_vector(self, X_enc, X_action, X_reward):
        """Run the full transformer; returns final-layer-normed states of shape (b, tt, n_embd)."""
        x = self._embed(X_enc, X_action, X_reward)
        for block in self.transformer.h:
            x = block(x)
        return self.transformer.ln_f(x)

    def compute_context_vector_kvc(self, X_enc, X_action, X_reward, kv_cache, cnt):
        """Incremental variant: only tokens from position ``cnt`` onward pass through the blocks.

        The full sequence is still embedded, because the interleave needs every
        stream; positions before ``cnt`` are then dropped. ``kv_cache[l]`` is
        handed to layer l's block and the returned cache is written back.
        Returns ``(x, kv_cache)`` with x of shape (b, tt - cnt, n_embd).
        """
        x = self._embed(X_enc, X_action, X_reward)[:, cnt:]
        for layer_idx, block in enumerate(self.transformer.h):
            x, layer_cache = block.forward_kvc(x, kv_cache[layer_idx], cnt)
            kv_cache[layer_idx] = layer_cache
        return self.transformer.ln_f(x), kv_cache

    def forward(self, X_enc, X_action, X_reward, Y_enc):
        """Return ``(mean MSE, per-sample mean MSE)`` for predicting Y_enc from the reward slots."""
        b = len(X_enc)
        x = self.compute_context_vector(X_enc, X_action, X_reward)
        pred = self.enc_head(x[:, 2::3])
        loss_enc = F.mse_loss(pred, Y_enc, reduction='none')
        return loss_enc.mean(), loss_enc.view(b, -1).mean(dim=-1)

    def configure_optimizers(self, weight_decay, lr, betas, device_type):
        """Build AdamW with weight decay on parameters of rank >= 2 only.

        i.e. all weight tensors in matmuls + embeddings decay, all biases and
        layernorms don't. The fused kernel is used when available and
        device_type == "cuda".
        """
        params = [p for p in self.parameters() if p.requires_grad]
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay, 'lr': lr},
            {"params": nodecay_params, "weight_decay": 0.0, 'lr': lr},
        ]

        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        return torch.optim.AdamW(optim_groups, betas=betas, eps=1e-5, **extra_args)