"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
    - all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
from einops import rearrange

logger = logging.getLogger(__name__)

# Number of low-dimensional state features; sizes the state projection
# (GPT.obs_emb) that is built when use_img_obs is enabled.
VALID_OBS_DIM = 9
# Backward-compatible alias for the original, misspelled name.
VAILD_OBS_DIM = VALID_OBS_DIM

class GPTConfig:
    """Base GPT config: hyperparameters shared by all GPT variants.

    Class attributes are defaults. Any of them, and any additional setting,
    can be overridden per instance through ``**kwargs``.
    """

    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    input_size = 10
    goal_size = 49152
    n_embd = 768
    n_layer = 12
    # CausalSelfAttention reads config.n_head. Without a default here, a bare
    # GPTConfig(...) lacking an n_head kwarg fails with AttributeError.
    # 12 divides n_embd=768 evenly, as the attention layer requires.
    n_head = 12
    goal_type = None
    use_img_obs = False
    use_skill_head = False

    def __init__(self, vocab_size, goal_seq_lenth, obs_seq_lenth, subgoal_seq_lenth, **kwargs):
        """Store the required sizes and apply any keyword overrides.

        Args:
            vocab_size: output width of the model's linear heads.
            goal_seq_lenth: number of goal tokens (the attribute names keep the
                original "lenth" spelling for compatibility).
            obs_seq_lenth: number of observation tokens.
            subgoal_seq_lenth: number of subgoal tokens for the prior stack.
            **kwargs: any attribute to set on this instance, e.g. ``n_head``.
        """
        self.vocab_size = vocab_size
        self.goal_seq_lenth = goal_seq_lenth
        self.obs_seq_lenth = obs_seq_lenth
        self.subgoal_seq_lenth = subgoal_seq_lenth
        for k, v in kwargs.items():
            setattr(self, k, v)


class GPT1Config(GPTConfig):
    """Preset sized like GPT-1: 12 layers, 12 attention heads, 768-dim embeddings."""

    n_layer, n_head, n_embd = 12, 12, 768


class CNN(nn.Module):
    """Small conv net mapping images to ``output_dim``-sized feature vectors.

    Pipeline: 3-channel ImageNet-statistics normalization -> conv(6x6, stride 2)
    -> maxpool(2) -> conv(6x6) -> maxpool(2) -> flatten -> three Linear layers.

    Args:
        input_dim: ``(height, width, channels)`` of a single image. Height and
            width size the first fully connected layer; channels sets the
            input width of ``conv1``. The normalization uses 3-element
            mean/std, so ``channels`` is presumably 3.
        output_dim: size of the produced feature vector.
    """

    # (kernel_size, stride) of each spatial stage, in forward order:
    # conv1, pool, conv2, pool. Must stay in sync with the layers in __init__.
    _SPATIAL_STAGES = ((6, 2), (2, 2), (6, 1), (2, 2))
    _CONV2_OUT_CHANNELS = 3

    def __init__(self, input_dim, output_dim):
        super().__init__()
        h, w, c = input_dim
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=6, kernel_size=6, stride=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=6, out_channels=self._CONV2_OUT_CHANNELS, kernel_size=6
        )
        # Derived from (h, w) instead of hard-coding 3 * 25 * 25, which was
        # only correct for 224x224 images.
        self.fc1 = nn.Linear(self._flat_feature_dim(h, w), 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    @classmethod
    def _flat_feature_dim(cls, h, w):
        """Return the flattened size of the conv trunk's output for an h x w image.

        Each unpadded conv/pool stage maps a spatial size to
        ``(size - kernel) // stride + 1``.

        Raises:
            ValueError: if the image is too small to pass through every stage.
        """
        spatial = []
        for size in (h, w):
            for kernel, stride in cls._SPATIAL_STAGES:
                size = (size - kernel) // stride + 1
                if size < 1:
                    raise ValueError(
                        f"input of size {h}x{w} is too small for the CNN trunk"
                    )
            spatial.append(size)
        return cls._CONV2_OUT_CHANNELS * spatial[0] * spatial[1]

    def forward(self, x):
        """Embed a batch of images, or a batch of image sequences.

        Args:
            x: either a 5-D channels-last tensor ``(N, T, H, W, C)``, which is
                flattened to ``(N*T, C, H, W)`` internally, or a 4-D tensor that
                goes to the convolutions as-is and so must already be
                channels-first ``(N, C, H, W)``. Presumably float pixels in
                [0, 1], matching the ImageNet statistics — confirm with callers.

        Returns:
            ``(N, T, output_dim)`` for 5-D input, otherwise ``(N, output_dim)``.
        """
        is_seq = x.dim() == 5
        if is_seq:
            n, t = x.shape[0], x.shape[1]
            x = rearrange(x, "n t h w c -> (n t) c h w")
        x = self.normalize(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        if is_seq:
            x = rearrange(x, "(n t) e -> n t e", n=n, t=t)
        return x


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.

    Attention is restricted by the product of two masks: a fixed triangular
    mask selected by ``direction``, and a per-sample key mask passed to
    ``forward``. Positions where the product is 0 are not attended to.

    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config, direction="past"):
        """
        Args:
            config: reads ``n_embd``, ``n_head``, ``attn_pdrop``, ``resid_pdrop``,
                ``goal_seq_lenth`` and ``obs_seq_lenth``. The mask covers
                ``goal_seq_lenth + obs_seq_lenth`` positions.
            direction: ``"past"`` lets query i attend to keys j <= i;
                ``"future"`` lets it attend to keys j >= i.

        Raises:
            ValueError: if ``direction`` is not ``"past"`` or ``"future"``.
        """
        super().__init__()
        if direction not in ("past", "future"):
            # Previously this fell through with `mask` unbound (UnboundLocalError).
            raise ValueError(
                f"direction must be 'past' or 'future', got {direction!r}"
            )
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # triangular mask over the maximum sequence length; sliced to T in forward
        block_size = config.goal_seq_lenth + config.obs_seq_lenth
        ones = torch.ones(block_size, block_size)
        mask = torch.tril(ones) if direction == "past" else torch.triu(ones)
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size))
        self.n_head = config.n_head

    def forward(self, x, x_mask):
        """Apply masked multi-head self-attention.

        Args:
            x: ``(B, T, C)`` input with ``C == n_embd`` and T no larger than the
                mask size.
            x_mask: 2-D per-key mask, presumably ``(B, T)``. Keys where it is 0
                are not attended to by any query.

        Returns:
            ``(B, T, C)`` projected attention output. A query whose keys are
            all masked receives zero attention weights instead of NaNs.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dim (n_embd)
        head_size = C // self.n_head

        # project, then split heads: (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, head_size).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, head_size).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_size).transpose(1, 2)

        # (B, T) -> (B, 1, 1, T): broadcasting against the (1, 1, T, T)
        # triangular mask yields the same (B, 1, T, T) combined mask that an
        # explicit repeat would, without materialising the copy.
        key_mask = x_mask.unsqueeze(1).unsqueeze(2)

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        cur_mask = self.mask[:, :, :T, :T] * key_mask
        att = att.masked_fill(cur_mask == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        # softmax over an all -inf row yields NaN; treat such rows as "attend to nothing"
        att = att.masked_fill(att.isnan(), 0)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads

        # output projection
        return self.resid_drop(self.proj(y))


class Block(nn.Module):
    """Pre-LayerNorm Transformer block: masked self-attention, then an MLP,
    each added back onto the residual stream.

    ``forward`` takes and returns an ``(x, x_mask)`` pair so blocks can be
    chained inside ``nn.Sequential``; the mask is passed through untouched.
    """

    def __init__(self, config, direction='past'):
        super().__init__()
        width = config.n_embd
        inner = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config, direction)
        self.mlp = nn.Sequential(
            nn.Linear(width, inner),
            nn.GELU(),
            nn.Linear(inner, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, input):
        hidden_states, pad_mask = input
        hidden_states = hidden_states + self.attn(self.ln1(hidden_states), pad_mask)
        hidden_states = hidden_states + self.mlp(self.ln2(hidden_states))
        return hidden_states, pad_mask


class GPT(nn.Module):
    """GPT-style policy network with an auxiliary "prior" transformer stack.

    Both stacks share the observation and goal embedding layers. Each has its
    own positional embedding, transformer blocks, final LayerNorm and head:

    * policy stack (``pos_emb``, ``blocks``, ``ln_f``, ``head`` and optional
      ``skill_head``) runs over ``[goal tokens, observation tokens]`` in
      ``forward_policy``;
    * prior stack (``prior_pos_emb``, ``prior_block``, ``prior_ln_f``,
      ``prior_head``) runs over ``[subgoal tokens, observation tokens]`` in
      ``forward_prior``.
    """

    def __init__(self, config: GPTConfig):
        """Build every layer from ``config`` and initialise the weights.

        Raises:
            ValueError: if ``config.goal_type`` is neither ``"video"`` nor
                ``"future"``.
        """
        super().__init__()
        self.use_img_obs = config.use_img_obs
        self.use_skill_head = config.use_skill_head
        logger.info("use_img_obs: %s", self.use_img_obs)

        # Layers are created in the same order as before. Construction and
        # self.apply(_init_weights) both follow this order, so reordering would
        # change the initial weights produced for a given seed.

        # --- input embedding stem ---
        if self.use_img_obs:
            # NOTE: obs_emb and joint_emb are built, so they appear in the
            # state_dict and the optimizer. However, the state/image fusion that
            # would use them is disabled in _embed_obs.
            self.obs_emb = nn.Linear(VAILD_OBS_DIM, config.n_embd)
            self.img_emb = CNN((224, 224, 3), config.n_embd)
            self.joint_emb = nn.Linear(config.n_embd * 2, config.n_embd)
        else:
            self.obs_emb = nn.Linear(config.input_size, config.n_embd)
        if config.goal_type == "video":
            self.goal_emb = CNN((224, 224, 3), config.n_embd)
        elif config.goal_type == "future":
            self.goal_emb = nn.Linear(config.input_size, config.n_embd)
        else:
            # Raise instead of exit(): a constructor must not terminate the
            # host process; callers get a catchable, descriptive error.
            raise ValueError(
                f"goal_type {config.goal_type!r} not in ['video', 'future']"
            )

        # --- policy transformer ---
        block_size = config.goal_seq_lenth + config.obs_seq_lenth
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if self.use_skill_head:
            self.skill_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        else:
            self.skill_head = None

        # --- prior transformer ---
        prior_block_size = config.subgoal_seq_lenth + config.obs_seq_lenth
        self.prior_pos_emb = nn.Parameter(torch.zeros(1, prior_block_size, config.n_embd))
        prior_config = self._prior_block_config(config)
        self.prior_block = nn.Sequential(
            *[Block(prior_config) for _ in range(config.n_layer)]
        )
        self.prior_ln_f = nn.LayerNorm(config.n_embd)
        self.prior_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)

        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    @staticmethod
    def _prior_block_config(config):
        """Return the config used to build the prior stack's attention blocks.

        The attention mask is sized ``goal_seq_lenth + obs_seq_lenth``, but the
        prior stack attends over ``subgoal_seq_lenth + obs_seq_lenth`` tokens.
        When subgoals are longer than goals, the mask would be too small and
        broadcasting would fail. In that case a shallow copy with a widened
        ``goal_seq_lenth`` is returned. Otherwise ``config`` is returned
        unchanged, so existing buffer shapes and checkpoints are unaffected.
        """
        if config.subgoal_seq_lenth <= config.goal_seq_lenth:
            return config
        import copy

        prior_config = copy.copy(config)
        prior_config.goal_seq_lenth = config.subgoal_seq_lenth
        return prior_config

    def _init_weights(self, module):
        """Weight init applied to every submodule via ``self.apply``.

        Linear/Embedding weights get N(0, 0.02) and Linear biases are zeroed.
        LayerNorm is reset to the identity (weight 1, bias 0). Both positional
        embeddings get N(0, 0.02). Other modules (e.g. Conv2d) keep PyTorch's
        default initialisation.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, GPT):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.prior_pos_emb, mean=0.0, std=0.02)

    def configure_optimizers(self, train_config):
        """Build an AdamW optimizer with weight decay applied selectively.

        Linear and Conv2d weights are decayed. Biases, LayerNorm/Embedding
        weights and both positional embeddings are not.

        Args:
            train_config: reads ``weight_decay``, ``learning_rate`` and ``betas``.

        Returns:
            ``torch.optim.AdamW`` over all parameters, in two param groups.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn  # full param name

                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # position embedding parameters live directly on this module
        no_decay.add("pos_emb")
        no_decay.add("prior_pos_emb")

        # validate that every parameter landed in exactly one bucket
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(decay)],
                "weight_decay": train_config.weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return torch.optim.AdamW(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas
        )

    def _embed_obs(self, obs, img):
        """Embed the observation stream: image CNN if use_img_obs, else obs_emb(obs)."""
        if self.use_img_obs:
            # State/image fusion through obs_emb + joint_emb is disabled; only
            # the image features are used (``obs`` is ignored on this path).
            return self.img_emb(img)
        return self.obs_emb(obs)

    def forward_policy(self, obs, img, obs_mask, goal, goal_mask):
        """Run the policy stack over ``[goal tokens, observation tokens]``.

        Args:
            obs: 3-D tensor ``(B, T_obs, obs_dim)``. It is embedded by
                ``obs_emb`` unless ``use_img_obs`` is set.
            img: image observations fed to ``img_emb`` when ``use_img_obs`` is
                set (presumably ``(B, T_obs, H, W, C)``); unused otherwise.
            obs_mask: per-token mask for observations, presumably ``(B, T_obs)``.
                Zeros mark tokens that are not attended to.
            goal: goal inputs with batch ``B`` and ``T_goal`` tokens on dim 1,
                embedded by ``goal_emb``.
            goal_mask: per-token mask for goals, presumably ``(B, T_goal)``.

        Returns:
            ``(policy_outputs, policy_x, skill_outputs)``: head output
            ``(B, T, vocab_size)``, final hidden states ``(B, T, n_embd)``, and
            skill-head output (or ``None`` without a skill head), where
            ``T = T_goal + T_obs``.
        """
        b_obs, t_obs, _ = obs.size()  # also enforces a 3-D obs tensor
        b_goal, t_goal = goal.shape[0], goal.shape[1]
        assert b_obs == b_goal, "obs and goal should have same batch size"
        t = t_goal + t_obs

        obs_embeddings = self._embed_obs(obs, img)
        goal_embeddings = self.goal_emb(goal)

        token_embeddings = torch.cat([goal_embeddings, obs_embeddings], dim=1)  # (B, T, E)
        token_mask = torch.cat([goal_mask, obs_mask], dim=1)

        # learned positions, truncated to the current sequence length
        policy_x = self.drop(token_embeddings + self.pos_emb[:, :t, :])
        policy_x, _ = self.blocks((policy_x, token_mask))
        policy_x = self.ln_f(policy_x)
        policy_outputs = self.head(policy_x)
        skill_outputs = self.skill_head(policy_x) if self.use_skill_head else None

        return policy_outputs, policy_x, skill_outputs

    def forward_prior(self, obs, img, obs_mask, subgoal, subgoal_mask):
        """Run the prior stack over ``[subgoal tokens, observation tokens]``.

        The shared embeddings are detached, so the embedding layers receive no
        gradient through this path. Unlike ``forward_policy``, no embedding
        dropout is applied here.

        Args:
            obs, img, obs_mask: as in ``forward_policy``.
            subgoal: subgoal inputs with batch ``B`` and ``T_sub`` tokens on
                dim 1, embedded by ``goal_emb``.
            subgoal_mask: per-token mask for subgoals, presumably ``(B, T_sub)``.

        Returns:
            ``(prior_outputs, prior_x)``: head output ``(B, T, vocab_size)`` and
            final hidden states ``(B, T, n_embd)``, where ``T = T_sub + T_obs``.
        """
        b_obs, t_obs, _ = obs.size()  # also enforces a 3-D obs tensor
        b_goal, t_goal = subgoal.shape[0], subgoal.shape[1]
        assert b_obs == b_goal, "obs and goal should have same batch size"
        t = t_goal + t_obs

        obs_embeddings = self._embed_obs(obs, img).detach()
        subgoal_embeddings = self.goal_emb(subgoal).detach()

        token_embeddings = torch.cat([subgoal_embeddings, obs_embeddings], dim=1)  # (B, T, E)
        token_mask = torch.cat([subgoal_mask, obs_mask], dim=1)

        prior_x = token_embeddings + self.prior_pos_emb[:, :t, :]
        prior_x, _ = self.prior_block((prior_x, token_mask))
        prior_x = self.prior_ln_f(prior_x)
        prior_outputs = self.prior_head(prior_x)

        return prior_outputs, prior_x