import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import numpy as np

def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable.

    A parameter counts only when its ``requires_grad`` flag is set; frozen
    parameters are skipped. A module without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

class DecisionConvMLP(nn.Module):
    """Goal-conditioned MLP policy over a pixel observation plus proprioception.

    Pipeline: conv encoder -> Linear/LayerNorm/Tanh trunk (50 features) ->
    concatenate with ``proprio`` and ``goals`` -> two ReLU hidden layers ->
    Tanh-squashed action of width ``act_dim``.

    Args:
        env_name: stored on the module as ``env_name``.
        env: must provide ``observation_space['observation'].shape``,
            ``observation_space['pixels'].shape`` (read as H, W, C) and
            ``action_space.shape``.
        goal_dim: feature width of the ``goals`` input.
        h_dim: hidden width of the MLP.
    """

    def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
        super().__init__()
        self.env_name = env_name
        state_dim = env.observation_space['observation'].shape[0]
        self.obs_shape = env.observation_space['pixels'].shape
        self.act_dim = env.action_space.shape[0]
        self.h_dim = h_dim

        self.convnet = nn.Sequential(nn.Conv2d(self.obs_shape[-1], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Spatial size after the unpadded 3x3 convs: the stride-2 layer maps
        # s -> (s - 3) // 2 + 1 and each stride-1 layer removes 2 more.
        # The former ((s - 2) // 2 - 4) matched only for even s and was one
        # too small for odd s, which made forward() fail.
        def conv_out(size):
            return (size - 3) // 2 + 1 - 4

        n = self.obs_shape[0]
        m = self.obs_shape[1]
        self.embedding_dim = conv_out(n) * conv_out(m) * 32

        self.trunk = nn.Sequential(nn.Linear(self.embedding_dim, 50),
                                   nn.LayerNorm(50), nn.Tanh())

        # Input = 50 trunk features + proprio (state_dim - 2 wide) + goals.
        self.mlp = nn.Sequential(
            nn.Linear(50 + state_dim - 2 + goal_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, self.act_dim),
            nn.Tanh()
        )

    def forward(self, states, proprio, goals):
        """Predict actions.

        Args:
            states: channels-first images for ``nn.Conv2d``, e.g. (B, C, H, W);
                presumably raw 0-255 pixels given the rescaling below.
            proprio: (B, state_dim - 2) features.
            goals: (B, goal_dim) features.

        Returns:
            (B, act_dim) actions in [-1, 1].
        """
        states = states / 255.0 - 0.5
        # reshape (not view) so conv outputs with non-default strides are accepted.
        h = self.trunk(self.convnet(states).reshape(-1, self.embedding_dim))
        h = torch.cat((h, proprio, goals), dim=-1)
        return self.mlp(h)
    
class DecisionConvRvs(nn.Module):
    """Conv + MLP policy conditioned on image, proprio, goal and a 1-wide ``q`` input.

    Same architecture as a pixel MLP policy: conv encoder -> Linear/LayerNorm/
    Tanh trunk (50 features) -> concat with ``proprio``, ``goals`` and ``q`` ->
    two ReLU hidden layers -> Tanh-squashed action.

    Args:
        env_name: stored on the module as ``env_name``.
        env: must provide ``observation_space['observation'].shape``,
            ``observation_space['pixels'].shape`` (read as H, W, C) and
            ``action_space.shape``.
        goal_dim: feature width of the ``goals`` input.
        h_dim: hidden width of the MLP.
    """

    def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
        super().__init__()
        self.env_name = env_name
        state_dim = env.observation_space['observation'].shape[0]
        self.obs_shape = env.observation_space['pixels'].shape
        self.act_dim = env.action_space.shape[0]
        self.h_dim = h_dim

        self.convnet = nn.Sequential(nn.Conv2d(self.obs_shape[-1], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Spatial size after the unpadded 3x3 convs: the stride-2 layer maps
        # s -> (s - 3) // 2 + 1 and each stride-1 layer removes 2 more.
        # The former ((s - 2) // 2 - 4) was one too small for odd s.
        def conv_out(size):
            return (size - 3) // 2 + 1 - 4

        n = self.obs_shape[0]
        m = self.obs_shape[1]
        self.embedding_dim = conv_out(n) * conv_out(m) * 32

        self.trunk = nn.Sequential(nn.Linear(self.embedding_dim, 50),
                                   nn.LayerNorm(50), nn.Tanh())

        # Input = 50 trunk features + proprio (state_dim - 2) + goals + q (1).
        self.mlp = nn.Sequential(
            nn.Linear(50 + state_dim - 2 + goal_dim + 1, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, self.act_dim),
            nn.Tanh()
        )

    def forward(self, states, proprio, goals, q):
        """Predict actions.

        Args:
            states: channels-first images for ``nn.Conv2d``, e.g. (B, C, H, W);
                presumably raw 0-255 pixels given the rescaling below.
            proprio: (B, state_dim - 2) features.
            goals: (B, goal_dim) features.
            q: (B, 1) conditioning value.

        Returns:
            (B, act_dim) actions in [-1, 1].
        """
        states = states / 255.0 - 0.5
        # reshape (not view) so conv outputs with non-default strides are accepted.
        h = self.trunk(self.convnet(states).reshape(-1, self.embedding_dim))
        h = torch.cat((h, proprio, goals, q), dim=-1)
        return self.mlp(h)

class DecisionConvV(nn.Module):
    """Goal-conditioned value network over a pixel observation plus proprioception.

    Conv encoder -> Linear/LayerNorm/Tanh trunk (50 features) -> concat with
    ``proprio`` and ``goals`` -> two ReLU hidden layers -> one Tanh output,
    so predictions lie in [-1, 1].

    Args:
        env_name: stored on the module as ``env_name``.
        env: must provide ``observation_space['observation'].shape`` and
            ``observation_space['pixels'].shape`` (read as H, W, C).
        goal_dim: feature width of the ``goals`` input.
        h_dim: hidden width of the MLP.
    """

    def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
        super().__init__()

        self.env_name = env_name
        state_dim = env.observation_space['observation'].shape[0]
        self.obs_shape = env.observation_space['pixels'].shape
        self.h_dim = h_dim

        self.convnet = nn.Sequential(nn.Conv2d(self.obs_shape[-1], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Spatial size after the unpadded 3x3 convs: the stride-2 layer maps
        # s -> (s - 3) // 2 + 1 and each stride-1 layer removes 2 more.
        # The former ((s - 2) // 2 - 4) was one too small for odd s.
        def conv_out(size):
            return (size - 3) // 2 + 1 - 4

        n = self.obs_shape[0]
        m = self.obs_shape[1]
        self.embedding_dim = conv_out(n) * conv_out(m) * 32

        self.trunk = nn.Sequential(nn.Linear(self.embedding_dim, 50),
                                   nn.LayerNorm(50), nn.Tanh())

        # Input = 50 trunk features + proprio (state_dim - 2) + goals.
        self.mlp = nn.Sequential(
            nn.Linear(50 + state_dim - 2 + goal_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Tanh()
        )

    def forward(self, states, proprio, goals):
        """Predict values.

        Args:
            states: channels-first images for ``nn.Conv2d``, e.g. (B, C, H, W);
                presumably raw 0-255 pixels given the rescaling below.
            proprio: (B, state_dim - 2) features.
            goals: (B, goal_dim) features.

        Returns:
            (B, 1) predictions in [-1, 1].
        """
        states = states / 255.0 - 0.5
        # reshape (not view) so conv outputs with non-default strides are accepted.
        h = self.trunk(self.convnet(states).reshape(-1, self.embedding_dim))
        h = torch.cat((h, proprio, goals), dim=-1)
        return self.mlp(h)

class DecisionRvs(nn.Module):
    """MLP policy over concatenated state, goal and a 1-wide ``q`` input.

    The three inputs are joined on the last axis and fed to a two-hidden-layer
    ReLU MLP whose Tanh output is an action of width ``act_dim``.
    """

    def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
        super().__init__()

        self.env_name = env_name
        self.state_dim = env.observation_space['observation'].shape[0]
        self.act_dim = env.action_space.shape[0]
        self.h_dim = h_dim

        in_features = self.state_dim + goal_dim + 1  # +1 for the q column
        layers = [
            nn.Linear(in_features, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, self.act_dim), nn.Tanh(),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, states, goals, q):
        """Return Tanh-squashed actions for ``concat(states, goals, q)``."""
        features = torch.cat([states, goals, q], dim=-1)
        return self.mlp(features)

class DecisionMLP(nn.Module):
    """Goal-conditioned MLP policy: ``concat(states, goals)`` -> action in [-1, 1]."""

    def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
        super().__init__()

        self.env_name = env_name
        self.state_dim = env.observation_space['observation'].shape[0]
        self.act_dim = env.action_space.shape[0]
        self.h_dim = h_dim

        # Two ReLU hidden layers followed by a Tanh-squashed output layer.
        widths = [self.state_dim + goal_dim, h_dim, h_dim]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers += [nn.Linear(h_dim, self.act_dim), nn.Tanh()]
        self.mlp = nn.Sequential(*layers)

    def forward(self, states, goals):
        """Return actions for the concatenation of ``states`` and ``goals``."""
        joined = torch.cat([states, goals], dim=-1)
        return self.mlp(joined)
    
class DecisionV(nn.Module):
    """Goal-conditioned value MLP: ``concat(states, goals)`` -> one value in [-1, 1]."""

    def __init__(self, env_name, env, goal_dim=2, h_dim=1024):
        super().__init__()

        self.env_name = env_name
        self.state_dim = env.observation_space['observation'].shape[0]
        self.h_dim = h_dim

        in_dim = self.state_dim + goal_dim
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, 1), nn.Tanh(),
        )

    def forward(self, states, goals):
        """Return a Tanh-bounded prediction with a trailing dimension of 1."""
        return self.mlp(torch.cat((states, goals), dim=-1))

    
class MaskedCausalAttention(nn.Module):
    '''
    Multi-head causal self-attention built on ``F.scaled_dot_product_attention``.

    Thanks https://github.com/nikhilbarhate99/min-decision-transformer/tree/master

    Args:
        h_dim: total model width ``C``; split into ``n_heads`` heads of
            ``C // n_heads`` features each.
        max_T: maximum sequence length. Only stored; causality comes from
            ``is_causal=True``, so no explicit mask buffer is built.
        n_heads: number of attention heads; must divide ``h_dim``.
        drop_p: dropout probability for the attention weights (training only)
            and for the output projection.

    Raises:
        ValueError: if ``h_dim`` is not divisible by ``n_heads``.
    '''
    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        # Fail here rather than with an opaque view() error inside forward().
        if h_dim % n_heads != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.max_T = max_T

        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)

        self.proj_net = nn.Linear(h_dim, h_dim)

        self.dropout = drop_p
        # Not used in forward(): attention dropout is applied inside SDPA via
        # dropout_p. Kept so existing attribute access keeps working.
        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: (B, T, C) tensor with C == h_dim.

        Returns:
            (B, T, C) tensor in which position t attends only to positions <= t.
        """
        B, T, C = x.shape  # batch size, sequence length, h_dim (all heads together)

        N, D = self.n_heads, C // self.n_heads  # N = num heads, D = per-head width

        # rearrange q, k, v as (B, N, T, D)
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        attention = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=True)
        # back to (B, T, N * D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        return self.proj_drop(self.proj_net(attention))

class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Causal self-attention, then a 4x-wide GELU feed-forward with dropout, each
    applied to a LayerNorm of the input and added back as a residual.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
        hidden = 4 * h_dim
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x):
        """Return ``x`` updated by the attention and feed-forward residual branches."""
        attn_out = self.attention(self.ln1(x))
        x = x + attn_out
        ff_out = self.mlp(self.ln2(x))
        return x + ff_out

class DecisionConvTransformer(nn.Module):
    """Goal-conditioned causal transformer policy over image + proprio inputs.

    Each timestep contributes three tokens, interleaved as
    (g_0, s_0, a_0, g_1, s_1, a_1, ...). The action for step t is read from the
    output at the s_t token, which under the causal mask sees g_0..g_t,
    s_0..s_t and a_0..a_{t-1}.

    Args:
        env_name: stored as ``env_name``.
        env: must provide ``observation_space['pixels'].shape`` (read as H, W, C),
            ``observation_space['observation'].shape`` and ``action_space.shape``.
        n_blocks: number of transformer blocks.
        h_dim: token embedding width.
        context_len: timesteps per sequence; each Block is given ``3 * context_len``.
        n_heads: attention heads per block.
        drop_p: dropout probability.
        goal_dim: feature width of the ``goals`` input.
        max_timestep: size of the positional embedding table; T must not exceed it.
    """

    def __init__(self, env_name, env, n_blocks, h_dim, context_len,
                 n_heads, drop_p, goal_dim=2, max_timestep=4096):
        super().__init__()

        self.env_name = env_name
        self.obs_shape = env.observation_space['pixels'].shape
        self.state_dim = env.observation_space['observation'].shape[0]
        self.act_dim = env.action_space.shape[0]
        self.goal_dim = goal_dim
        self.n_heads = n_heads
        self.h_dim = h_dim

        ### transformer blocks
        input_seq_len = 3 * context_len
        blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### projection heads (project to embedding)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)

        self.convnet = nn.Sequential(nn.Conv2d(self.obs_shape[-1], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Spatial size after the unpadded 3x3 convs: the stride-2 layer maps
        # s -> (s - 3) // 2 + 1 and each stride-1 layer removes 2 more.
        # The former ((s - 2) // 2 - 4) was one too small for odd s.
        def conv_out(size):
            return (size - 3) // 2 + 1 - 4

        n = self.obs_shape[0]
        m = self.obs_shape[1]
        self.embedding_dim = conv_out(n) * conv_out(m) * 32

        # Image features are concatenated with proprio (state_dim - 2 wide).
        self.embed_state = nn.Sequential(nn.Linear(self.embedding_dim + self.state_dim - 2, h_dim),
                                         nn.LayerNorm(h_dim), nn.Tanh())
        self.embed_goal = torch.nn.Linear(goal_dim, h_dim)
        self.embed_action = torch.nn.Linear(self.act_dim, h_dim)

        ### prediction heads
        self.final_ln = nn.LayerNorm(h_dim)
        self.predict_action = nn.Sequential(nn.Linear(h_dim, self.act_dim), nn.Tanh())

    def forward(self, states, proprio, actions, goals):
        """Predict an action for every timestep.

        Args:
            states: (B, T, C, H, W) images; presumably raw 0-255 pixels given
                the rescaling below.
            proprio: (B, T, state_dim - 2) features.
            actions: (B, T, act_dim).
            goals: (B, T, goal_dim).

        Returns:
            (B, T, act_dim) actions in [-1, 1].
        """
        B, T, c, img_h, img_w = states.shape

        timesteps = torch.arange(0, T, dtype=torch.long, device=states.device)
        time_embeddings = self.embed_timestep(timesteps)
        states = states / 255.0 - 0.5
        # reshape (not view) so inputs / conv outputs with non-default strides work.
        features = self.convnet(states.reshape(B * T, c, img_h, img_w))
        features = features.reshape(B, T, self.embedding_dim)
        features = torch.cat((features, proprio), dim=-1)
        state_embeddings = self.embed_state(features) + time_embeddings     # B, T, h_dim
        action_embeddings = self.embed_action(actions) + time_embeddings    # B, T, h_dim
        goal_embeddings = self.embed_goal(goals) + time_embeddings          # B, T, h_dim

        # Interleave to (B, 3T, h_dim): g_0, s_0, a_0, g_1, s_1, a_1, ...
        tokens = torch.stack(
            (goal_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)

        # transformer and prediction
        tokens = self.final_ln(self.transformer(tokens))

        # (B, 3, T, h_dim): tokens[:, 0/1/2, t] are the outputs at g_t / s_t / a_t.
        # The s_t output has seen g_0..g_t, s_0..s_t, a_0..a_{t-1}.
        tokens = tokens.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)
        return self.predict_action(tokens[:, 1])

class DecisionTransformer(nn.Module):
    """Goal-conditioned causal transformer policy over vector states.

    Per timestep the tokens are ordered goal, state, action. After the causal
    transformer, the output at each state token is mapped to a Tanh-squashed
    action, so the prediction for step t depends on g_0..g_t, s_0..s_t and
    a_0..a_{t-1}.
    """

    def __init__(self, env_name, env, n_blocks, h_dim, context_len,
                 n_heads, drop_p, goal_dim=2, max_timestep=4096):
        super().__init__()

        self.env_name = env_name
        self.state_dim = env.observation_space['observation'].shape[0]
        self.act_dim = env.action_space.shape[0]
        self.goal_dim = goal_dim
        self.n_heads = n_heads
        self.h_dim = h_dim

        # Causal transformer trunk; each block is told the full token count.
        seq_tokens = 3 * context_len
        self.transformer = nn.Sequential(
            *(Block(h_dim, seq_tokens, n_heads, drop_p) for _ in range(n_blocks))
        )

        # Token embeddings plus a learned per-position embedding.
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.embed_goal = torch.nn.Linear(goal_dim, h_dim)
        self.embed_state = torch.nn.Linear(self.state_dim, h_dim)
        self.embed_action = torch.nn.Linear(self.act_dim, h_dim)

        # Output head.
        self.final_ln = nn.LayerNorm(h_dim)
        self.predict_action = nn.Sequential(nn.Linear(h_dim, self.act_dim), nn.Tanh())

    def forward(self, states, actions, goals):
        """Return (B, T, act_dim) actions for (B, T, ·) states, actions and goals."""
        batch, seq_len, _ = states.shape

        positions = torch.arange(0, seq_len, dtype=torch.long, device=states.device)
        pos_emb = self.embed_timestep(positions)
        goal_tok = self.embed_goal(goals) + pos_emb
        state_tok = self.embed_state(states) + pos_emb
        action_tok = self.embed_action(actions) + pos_emb

        # (B, T, 3, h_dim) -> (B, 3T, h_dim) gives g_0, s_0, a_0, g_1, ...
        interleaved = torch.stack((goal_tok, state_tok, action_tok), dim=2)
        interleaved = interleaved.reshape(batch, 3 * seq_len, self.h_dim)

        hidden = self.final_ln(self.transformer(interleaved))

        # Pick the output at every state token (slot 1 of each triple).
        per_step = hidden.reshape(batch, seq_len, 3, self.h_dim)
        return self.predict_action(per_step[:, :, 1])

class DecisionMaxConvTransformer(nn.Module):
    """Causal transformer over (state, return-to-go, action) tokens with image states.

    Each timestep contributes three tokens interleaved as
    (s_0, R_0, a_0, s_1, R_1, a_1, ...). Here s_t is a conv embedding of the
    image concatenated with ``proprio``. Three heads read the transformer
    output: return-to-go from the s_t token, action from the R_t token, and a
    ``state_dim``-wide prediction from the a_t token.

    Args:
        env_name: stored as ``env_name``.
        env: must provide ``observation_space['pixels'].shape`` (read as H, W, C),
            ``observation_space['observation'].shape`` and ``action_space.shape``.
        n_blocks: number of transformer blocks.
        h_dim: token embedding width.
        context_len: timesteps per sequence; each Block is given ``3 * context_len``.
        n_heads: attention heads per block.
        drop_p: dropout probability.
        goal_dim: input width of ``embed_goal`` (which forward() does not use).
        max_timestep: size of the positional embedding table; T must not exceed it.
    """

    def __init__(self, env_name, env, n_blocks, h_dim, context_len,
                 n_heads, drop_p, goal_dim=2, max_timestep=4096):
        super().__init__()

        self.env_name = env_name
        self.obs_shape = env.observation_space['pixels'].shape
        self.state_dim = env.observation_space['observation'].shape[0]
        self.act_dim = env.action_space.shape[0]
        self.goal_dim = goal_dim
        self.n_heads = n_heads
        self.h_dim = h_dim

        ### transformer blocks
        self.num_inputs = 3  # tokens per timestep: state, return-to-go, action
        input_seq_len = self.num_inputs * context_len
        blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### projection heads (project to embedding)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)

        self.convnet = nn.Sequential(nn.Conv2d(self.obs_shape[-1], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Spatial size after the unpadded 3x3 convs: the stride-2 layer maps
        # s -> (s - 3) // 2 + 1 and each stride-1 layer removes 2 more.
        # The former ((s - 2) // 2 - 4) was one too small for odd s.
        def conv_out(size):
            return (size - 3) // 2 + 1 - 4

        n = self.obs_shape[0]
        m = self.obs_shape[1]
        self.embedding_dim = conv_out(n) * conv_out(m) * 32

        # Image features are concatenated with proprio (state_dim - 2 wide).
        self.embed_state = nn.Sequential(nn.Linear(self.embedding_dim + self.state_dim - 2, h_dim),
                                         nn.LayerNorm(h_dim), nn.Tanh())
        # NOTE: embed_goal is not used in forward(); removing it would change
        # the state_dict keys, so it is left in place.
        self.embed_goal = torch.nn.Linear(goal_dim, h_dim)
        self.embed_rtg = nn.Linear(1, h_dim)
        self.embed_action = torch.nn.Linear(self.act_dim, h_dim)

        ### prediction heads
        self.final_ln = nn.LayerNorm(h_dim)
        self.predict_action = nn.Sequential(nn.Linear(h_dim, self.act_dim), nn.Tanh())
        self.predict_rtg = nn.Linear(h_dim, 1)
        self.predict_state = nn.Linear(h_dim, self.state_dim)

    def forward(self,
                timesteps,
                states,
                proprio,
                actions,
                returns_to_go):
        """Run the model over a (B, T) batch of trajectories.

        Args:
            timesteps: ignored. Positions are always 0..T-1, because the argument
                is overwritten below; it is accepted only for interface compatibility.
            states: (B, T, C, H, W) images; presumably raw 0-255 pixels given
                the rescaling below.
            proprio: (B, T, state_dim - 2) features.
            actions: (B, T, act_dim).
            returns_to_go: (B, T, 1).

        Returns:
            Tuple ``(rtg_preds, action_preds, state_preds)``. The shapes are
            (B, T, 1), (B, T, act_dim) in [-1, 1] and (B, T, state_dim).
        """
        B, T, c, img_h, img_w = states.shape

        # NOTE(review): the caller's ``timesteps`` is discarded; positional
        # embeddings always use 0..T-1. Confirm this is intended.
        timesteps = torch.arange(0, T, dtype=torch.long, device=states.device)
        time_embeddings = self.embed_timestep(timesteps)
        states = states / 255.0 - 0.5
        # reshape (not view) so inputs / conv outputs with non-default strides work.
        features = self.convnet(states.reshape(B * T, c, img_h, img_w))
        features = features.reshape(B, T, self.embedding_dim)
        features = torch.cat((features, proprio), dim=-1)
        state_embeddings = self.embed_state(features) + time_embeddings     # B, T, h_dim
        action_embeddings = self.embed_action(actions) + time_embeddings    # B, T, h_dim
        rtg_embeddings = self.embed_rtg(returns_to_go) + time_embeddings    # B, T, h_dim

        # Interleave to (B, 3T, h_dim): s_0, R_0, a_0, s_1, R_1, a_1, ...
        tokens = (
            torch.stack((state_embeddings, rtg_embeddings, action_embeddings), dim=1)
            .permute(0, 2, 1, 3)
            .reshape(B, self.num_inputs * T, self.h_dim)
        )

        # transformer and prediction
        tokens = self.final_ln(self.transformer(tokens))

        # (B, 3, T, h_dim): tokens[:, 0/1/2, t] are the outputs at s_t / R_t / a_t.
        # Under the causal mask, s_t sees everything up to s_t, R_t adds R_t,
        # and a_t adds a_t.
        tokens = tokens.reshape(B, T, self.num_inputs, self.h_dim).permute(0, 2, 1, 3)

        rtg_preds = self.predict_rtg(tokens[:, 0])            # predict rtg given s
        action_dist_preds = self.predict_action(tokens[:, 1])  # predict action given s, R
        state_preds = self.predict_state(tokens[:, 2])         # predict next state given s, R, a
        return (
            rtg_preds,
            action_dist_preds,
            state_preds,
        )

class DecisionMaxTransformer(nn.Module):
    """Causal transformer over (state+goal, return-to-go, action) tokens.

    Each timestep contributes three tokens interleaved as
    (sg_0, R_0, a_0, sg_1, R_1, a_1, ...), where sg_t embeds the concatenated
    state and goal vector. Three heads read the transformer output:
    return-to-go from the sg_t token, action from the R_t token, and a
    ``state_dim``-wide prediction from the a_t token.

    Args:
        env_name: stored as ``env_name``.
        env: must provide ``observation_space['observation'].shape`` and
            ``action_space.shape``.
        n_blocks: number of transformer blocks.
        h_dim: token embedding width.
        context_len: timesteps per sequence; each Block is given ``3 * context_len``.
        n_heads: attention heads per block.
        drop_p: dropout probability.
        goal_dim: goal width; ``state_goals`` must be ``state_dim + goal_dim`` wide.
        max_timestep: size of the positional embedding table; T must not exceed it.
    """

    def __init__(self, env_name, env, n_blocks, h_dim, context_len,
                 n_heads, drop_p, goal_dim=2, max_timestep=4096):
        super().__init__()

        self.env_name = env_name
        self.state_dim = env.observation_space['observation'].shape[0]
        self.act_dim = env.action_space.shape[0]
        self.goal_dim = goal_dim
        self.n_heads = n_heads
        self.h_dim = h_dim

        ### transformer blocks
        self.num_inputs = 3  # tokens per timestep: state+goal, return-to-go, action
        input_seq_len = self.num_inputs * context_len
        blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### projection heads (project to embedding)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.embed_state_goal = torch.nn.Linear(self.state_dim + self.goal_dim, h_dim)
        self.embed_rtg = nn.Linear(1, h_dim)
        self.embed_action = torch.nn.Linear(self.act_dim, h_dim)

        ### prediction heads
        self.predict_rtg = nn.Linear(h_dim, 1)
        self.predict_state = nn.Linear(h_dim, self.state_dim)
        self.final_ln = nn.LayerNorm(h_dim)
        self.predict_action = nn.Sequential(nn.Linear(h_dim, self.act_dim), nn.Tanh())

    def forward(self,
                timesteps,
                state_goals,
                actions,
                returns_to_go):
        """Run the model over a (B, T) batch of trajectories.

        Args:
            timesteps: ignored. Positions are always 0..T-1, because the argument
                is overwritten below; it is accepted only for interface compatibility.
            state_goals: (B, T, state_dim + goal_dim).
            actions: (B, T, act_dim).
            returns_to_go: (B, T, 1).

        Returns:
            Tuple ``(rtg_preds, action_preds, state_preds)``. The shapes are
            (B, T, 1), (B, T, act_dim) in [-1, 1] and (B, T, state_dim).
        """
        B, T, _ = state_goals.shape

        # NOTE(review): the caller's ``timesteps`` is discarded; positional
        # embeddings always use 0..T-1. Confirm this is intended.
        timesteps = torch.arange(0, T, dtype=torch.long, device=state_goals.device)
        time_embeddings = self.embed_timestep(timesteps)
        state_goal_embeddings = self.embed_state_goal(state_goals) + time_embeddings  # B, T, h_dim
        action_embeddings = self.embed_action(actions) + time_embeddings              # B, T, h_dim
        rtg_embeddings = self.embed_rtg(returns_to_go) + time_embeddings              # B, T, h_dim

        # Interleave to (B, 3T, h_dim): sg_0, R_0, a_0, sg_1, R_1, a_1, ...
        h = (
            torch.stack((state_goal_embeddings, rtg_embeddings, action_embeddings), dim=1)
            .permute(0, 2, 1, 3)
            .reshape(B, self.num_inputs * T, self.h_dim)
        )

        # transformer and prediction
        h = self.transformer(h)
        h = self.final_ln(h)

        # (B, 3, T, h_dim): h[:, 0/1/2, t] are the outputs at sg_t / R_t / a_t.
        # Under the causal mask, sg_t sees everything up to sg_t, R_t adds R_t,
        # and a_t adds a_t.
        h = h.reshape(B, T, self.num_inputs, self.h_dim).permute(0, 2, 1, 3)

        # get predictions
        rtg_preds = self.predict_rtg(h[:, 0])            # predict rtg given s
        action_dist_preds = self.predict_action(h[:, 1])  # predict action given s, R
        state_preds = self.predict_state(h[:, 2])         # predict next state given s, R, a
        return (
            rtg_preds,
            action_dist_preds,
            state_preds,
        )