from dataclasses import dataclass
from typing import Any, Optional, Tuple

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F

from data import Batch
from .kv_caching import KeysValues
from .slicer import Embedder, Head
from .tokenizer import Tokenizer
from .transformer import Transformer, TransformerConfig
from utils import init_weights


@dataclass
class GPTOutput:
    """Bundle of the three logit tensors returned by ``GPT.forward``."""

    # Output of the observation head; its last layer has ``obs_vocab_size`` units.
    logits_obs: torch.FloatTensor
    # Output of the reward head; its last layer has 3 units.
    logits_rew: torch.FloatTensor
    # Output of the end head; its last layer has 2 units.
    logits_end: torch.FloatTensor


class GPT(nn.Module):
    """Transformer over interleaved observation and action tokens.

    The token stream is organised in blocks of ``config.tokens_per_block``
    tokens: every position of a block except the last holds an observation
    token, and the last position holds the action token. Three heads read the
    transformer output and produce logits for the next observation token, the
    reward class and the end flag.
    """

    def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig) -> None:
        """Build the transformer, the token/position embeddings and the heads.

        Args:
            obs_vocab_size: Number of distinct observation tokens.
            act_vocab_size: Number of distinct action tokens.
            config: Transformer configuration; this constructor reads
                ``tokens_per_block``, ``max_tokens``, ``max_blocks`` and
                ``embed_dim`` from it.
        """
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size
        self.config = config
        self.transformer = Transformer(config)

        # Per-block 0/1 position masks.
        # Every position except the last observation token (index -2): the
        # token following that one is the action token, for which this model
        # has no prediction head.
        all_but_last_obs_tokens_pattern = torch.ones(config.tokens_per_block)
        all_but_last_obs_tokens_pattern[-2] = 0
        # Only the action token, i.e. the last position of the block.
        act_tokens_pattern = torch.zeros(config.tokens_per_block)
        act_tokens_pattern[-1] = 1
        # Every position that is not the action token.
        obs_tokens_pattern = 1 - act_tokens_pattern

        self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim)

        # One embedding table per token kind; the masks and tables are listed
        # in the same order (actions first, observations second).
        # NOTE(review): Embedder presumably routes each position to the table
        # whose mask selects it -- confirm in .slicer.
        self.embedder = Embedder(
            max_blocks=config.max_blocks,
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)])
        )

        # NOTE(review): Head presumably applies `head_module` only at the
        # positions selected by `block_mask` -- confirm in .slicer.
        self.head_observations = Head(
            max_blocks=config.max_blocks,
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=self._make_mlp(config.embed_dim, obs_vocab_size)
        )

        # 3 classes: compute_loss maps the reward sign {-1, 0, 1} to {0, 1, 2}.
        self.head_rewards = Head(
            max_blocks=config.max_blocks,
            block_mask=act_tokens_pattern,
            head_module=self._make_mlp(config.embed_dim, 3)
        )

        # 2 classes for the end flag.
        self.head_ends = Head(
            max_blocks=config.max_blocks,
            block_mask=act_tokens_pattern,
            head_module=self._make_mlp(config.embed_dim, 2)
        )

        self.apply(init_weights)

    @staticmethod
    def _make_mlp(embed_dim: int, out_dim: int) -> nn.Sequential:
        """Return the Linear -> ReLU -> Linear module used by every head."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, out_dim)
        )

    def __repr__(self) -> str:
        return "gpt"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> GPTOutput:
        """Run the transformer on ``tokens`` and apply the three heads.

        Args:
            tokens: Token ids of shape (B, T).
            past_keys_values: Optional cache of earlier steps. When given, its
                ``size`` is used as the position offset of ``tokens``.

        Returns:
            A ``GPTOutput`` with the observation, reward and end logits.

        Raises:
            AssertionError: If the cached steps plus the new steps exceed
                ``config.max_tokens``.
        """
        num_steps = tokens.size(1)  # (B, T)
        prev_steps = 0 if past_keys_values is None else past_keys_values.size

        # Positions run from prev_steps to prev_steps + num_steps - 1 and index
        # a table with max_tokens rows, so the bound must include the cached
        # steps; otherwise the embedding lookup below goes out of range.
        assert prev_steps + num_steps <= self.config.max_tokens, (
            f"{prev_steps} cached + {num_steps} new tokens exceed max_tokens={self.config.max_tokens}"
        )

        positions = prev_steps + torch.arange(num_steps, device=tokens.device)
        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(positions)

        x = self.transformer(sequences, past_keys_values)

        logits_obs = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rew = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_end = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)

        return GPTOutput(logits_obs, logits_rew, logits_end)

    def compute_loss(self, batch: Batch, tokenizer: Tokenizer) -> Tuple[torch.Tensor, dict]:
        """Compute the summed cross-entropy loss of the three heads on a batch.

        Args:
            batch: Provides ``obs``, ``act``, ``rew``, ``end`` and
                ``mask_padding``. ``mask_padding`` is used as a boolean (B, L)
                mask that is True on the steps to keep.
            tokenizer: Its ``encode(batch.obs).tokens`` supplies the
                observation tokens; it is called under ``torch.no_grad()``.

        Returns:
            A pair ``(total_loss, components)`` where ``components`` maps
            ``'loss_obs'``, ``'loss_rew'`` and ``'loss_end'`` to the detached
            individual losses.
        """
        with torch.no_grad():
            obs_tokens = tokenizer.encode(batch.obs).tokens  # (B, L, K)

        act_tokens = batch.act.unsqueeze(-1)  # (B, L) -> (B, L, 1)
        # Append each step's action token after its K observation tokens, then
        # flatten the steps into one sequence.
        tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)')  # (B, L(K+1))

        outputs = self(tokens)

        mask = batch.mask_padding

        # The last observation logit has no target inside the batch, so it is
        # dropped; compute_target_obs drops the first token to align the two.
        logits_obs = rearrange(outputs.logits_obs[:, :-1], 'b tk o -> (b tk) o')
        target_obs = self.compute_target_obs(obs_tokens, mask)
        loss_obs = F.cross_entropy(logits_obs, target=target_obs)

        # Reward clipped to {-1, 0, 1}, then shifted to class ids {0, 1, 2}.
        loss_rew = F.cross_entropy(outputs.logits_rew[mask], target=batch.rew[mask].sign().long().add(1))
        # NOTE(review): assumes batch.end already holds integer class ids -- confirm.
        loss_end = F.cross_entropy(outputs.logits_end[mask], target=batch.end[mask])

        components = {'loss_obs': loss_obs.detach(), 'loss_rew': loss_rew.detach(), 'loss_end': loss_end.detach()}
        return loss_obs + loss_rew + loss_end, components

    def compute_target_obs(self, obs_tokens: torch.Tensor, mask_padding: torch.Tensor) -> torch.Tensor:
        """Build the flat next-token targets for the observation head.

        Args:
            obs_tokens: Observation tokens of shape (B, L, K).
            mask_padding: Boolean (B, L) mask, True on the steps to keep.

        Returns:
            A 1-D tensor of B * (L * K - 1) targets in which every token of a
            masked-out step is replaced by -100.
        """
        mask_fill = torch.logical_not(mask_padding)
        # -100 is the default `ignore_index` of F.cross_entropy, so these
        # positions do not contribute to the loss.
        ignored = obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100)
        # Drop the first token: the target at position i is the token at i + 1.
        target_obs = rearrange(ignored, 'b t k -> b (t k)')[:, 1:]
        return target_obs.reshape(-1)
