import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch import distributions as pyd
import wandb


class MaskedCausalAttention(nn.Module):
    """Multi-head causal self-attention with a fixed, distance-based logit penalty.

    For query position ``i`` and key position ``j`` the attention logit is

        (q_i . k_j - |w * (j - i)^2 + b|) / sqrt(D)

    where ``D`` is the per-head width. Positions with ``j > i`` are masked out
    before the softmax. ``w`` and ``b`` are plain Python numbers stored on the
    module (not ``nn.Parameter``s), so they are neither trained nor saved in
    the state dict.

    Args:
        h_dim: Model width; must be divisible by ``n_heads``.
        max_T: Longest sequence length the precomputed mask/distances support.
        n_heads: Number of attention heads.
        drop_p: Dropout probability used after attention and after projection.
        w: Multiplier on the squared position distance inside the penalty.
        b: Offset added inside the penalty.

    Raises:
        ValueError: If ``h_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p,w=0.1,b=-0.1):
        super().__init__()

        # Fail early with a readable message instead of an opaque ``view``
        # shape error on the first forward pass.
        if h_dim % n_heads != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.max_T = max_T

        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)

        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        # causal mask: lower-triangular ones, shaped to broadcast over (B, N)
        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)
        self.register_buffer('mask', mask)

        # penalty hyperparameters, kept as fixed (non-learnable) values
        self.w = w  # decay factor
        self.b = b  # self-attention bias

        # precompute squared distance matrix (max_T, max_T): distances[i, j] = (j - i)^2
        distances = torch.arange(max_T).view(1, -1) - torch.arange(max_T).view(-1, 1)
        distances = distances.float().pow(2)
        self.register_buffer('distances', distances)

    def forward(self, x):
        """Apply causal attention to ``x``.

        Args:
            x: Tensor of shape (B, T, h_dim) with ``T <= max_T``.

        Returns:
            Tuple ``(out, normalized_weights)`` where ``out`` has shape
            (B, T, h_dim) and ``normalized_weights`` has shape (B, N, T, T).

        Raises:
            ValueError: If ``T`` exceeds ``max_T``.
        """
        B, T, C = x.shape  # batch size, seq length, h_dim
        if T > self.max_T:
            # The buffers only cover max_T positions; slicing them with a larger
            # T would otherwise surface as a confusing broadcasting error.
            raise ValueError(
                f"sequence length {T} exceeds max_T={self.max_T}"
            )
        N, D = self.n_heads, C // self.n_heads

        # rearrange q, k, v as (B, N, T, D)
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # base attention logits (B, N, T, T)
        weights = q @ k.transpose(2, 3)

        # distance penalty (T, T), shared by every batch element and head
        gaussian_bias = -torch.abs(self.w * self.distances[:T, :T] + self.b)

        # the penalty is added before scaling, so it is divided by sqrt(D) as well
        weights = (weights + gaussian_bias) / math.sqrt(D)

        # causal mask: future positions get -inf so softmax assigns them zero weight
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))

        # normalize over keys
        normalized_weights = F.softmax(weights, dim=-1)

        # attention (B, N, T, D); dropout is applied to the attended values,
        # not to the attention weights themselves
        attention = self.att_drop(normalized_weights @ v)

        # gather heads and project (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)
        out = self.proj_drop(self.proj_net(attention))

        return out, normalized_weights
#
# class MaskedCausalAttention(nn.Module):
#     def __init__(self, h_dim, max_T, n_heads, drop_p,init_sigmas = torch.FloatTensor([0.5,0.5,0.5])):
#         super().__init__()
#
#         self.n_heads = n_heads
#         self.max_T = max_T
#
#         self.q_net = nn.Linear(h_dim, h_dim)
#         self.k_net = nn.Linear(h_dim, h_dim)
#         self.v_net = nn.Linear(h_dim, h_dim)
#
#         self.proj_net = nn.Linear(h_dim, h_dim)
#
#         self.att_drop = nn.Dropout(drop_p)
#         self.proj_drop = nn.Dropout(drop_p)
#
#         # causal mask
#         ones = torch.ones((max_T, max_T))
#         mask = torch.tril(ones).view(1, 1, max_T, max_T)
#         self.register_buffer('mask', mask)
#
#         # relative position matrix: rel[i,j] = i - j
#         rel = torch.arange(max_T).view(max_T, 1) - torch.arange(max_T).view(1, max_T)
#         self.register_buffer("rel_positions", rel)
#
#         # learnable log-sigmas for 3 modalities (RTG=0, S=1, A=2)
#           # can be tuned
#         self.sigmas = nn.Parameter(init_sigmas)
#
#     def forward(self, x):
#         B, T, C = x.shape  # batch size, seq length, h_dim * n_heads
#         N, D = self.n_heads, C // self.n_heads
#
#         # rearrange q, k, v as (B, N, T, D)
#         q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
#         k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
#         v = self.v_net(x).view(B, T, N, D).transpose(1, 2)
#
#         # base attention weights (B, N, T, T)
#         weights = q @ k.transpose(2, 3) / math.sqrt(D)
#
#         # --- Gaussian bias per query modality ---
#         rel = -self.rel_positions[:T, :T].to(x.device)  # (T,T)
#         modalities = torch.arange(T, device=x.device) % 3  # query modality per row
#         self.sigmas=self.sigmas.to(device=x.device)
#
#         sigma_per_row = (self.sigmas[modalities]).view(T, 1)  # (T,1)
#         gauss_bias = torch.exp(- ((rel ** 2) / (sigma_per_row ** 2 + 1e-8))/2)/sigma_per_row  # (T,T)
#
#         # broadcast to batch & heads: (B,N,T,T)
#         gaussian_bias = gauss_bias.unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1)
#
#         # add Gaussian bias to attention logits
#         weights = weights + gaussian_bias
#
#
#         # causal mask applied
#         weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
#
#         # normalize
#         normalized_weights = F.softmax(weights, dim=-1)
#
#         # attention (B, N, T, D)
#         attention = self.att_drop(normalized_weights @ v)
#
#         # gather heads and project (B, N, T, D) -> (B, T, N*D)
#         attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)
#         out = self.proj_drop(self.proj_net(attention))
#
#         return out, normalized_weights

class Block(nn.Module):
    """One transformer layer: causal attention followed by a GELU MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``, i.e. the
    normalisation is applied after the residual sum.

    Args:
        h_dim: Model width.
        max_T: Maximum sequence length handed to the attention layer.
        n_heads: Number of attention heads.
        drop_p: Dropout probability for the attention layer and the MLP output.
        w: Forwarded to ``MaskedCausalAttention`` as its ``w``.
        b: Forwarded to ``MaskedCausalAttention`` as its ``b``.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p=0.1,w=0.1,b=-0.1):
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p, w, b)

        # position-wise feed-forward network with a 4x hidden expansion
        feed_forward = [
            nn.Linear(h_dim, 4 * h_dim),
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
            nn.Dropout(drop_p),
        ]
        self.mlp = nn.Sequential(*feed_forward)

        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x):
        """Return ``(layer_output, attention_weights)`` for input ``x``."""
        attended, att_weights = self.attention(x)

        # residual + norm around the attention sub-layer
        after_attention = self.ln1(x + attended)

        # residual + norm around the MLP sub-layer
        layer_output = self.ln2(after_attention + self.mlp(after_attention))

        return layer_output, att_weights

class DecisionTransformer(nn.Module):
    """
    This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)

    With ``sar=True`` the token order is (state_1, action_1, reward_1, state_2, ...)
    instead.

    NOTE: ``attention_mask`` is accepted by several methods for interface
    compatibility but is never passed to the transformer blocks, so left-padding
    tokens are attended to like any other token.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            sar=False,
            scale=1.,
            n_layer=6,
            n_head=8,
            action_range=(-1., 1.),
            state_mean=1,
            state_std=0,
            predict_rewards=False,
            w=0.1,
            b=-0.1,

            **kwargs
    ):
        """Build the embedding layers, transformer blocks and prediction heads.

        Args:
            state_dim: Size of a state vector.
            act_dim: Size of an action vector.
            hidden_size: Transformer width.
            max_length: Context length in timesteps; the blocks are built for
                ``3 * max_length`` tokens. Required despite the ``None`` default.
            max_ep_len: Number of rows in the timestep embedding table.
            action_tanh: Append ``nn.Tanh`` to the action head when true.
            sar: Use (state, action, reward) token order instead of
                (return, state, action).
            scale: Divisor applied to rewards before embedding (``sar`` mode).
            n_layer: Number of transformer blocks.
            n_head: Attention heads per block.
            action_range: ``(low, high)`` used by ``clamp_action``.
            state_mean: Stored on the module; not used in this class.
            state_std: Stored on the module; not used in this class.
                NOTE(review): defaults of mean=1 / std=0 look swapped — confirm
                against callers before relying on them.
            predict_rewards: Create a scalar reward/return head when true.
            w: Distance-penalty multiplier given to every block.
            b: Distance-penalty offset given to every block.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``max_length`` is ``None``.
        """
        super().__init__()

        if max_length is None:
            # The causal mask size is derived from max_length, so the model
            # cannot be built without it.
            raise ValueError("max_length must be set to build the transformer blocks")

        self.action_range = action_range
        self.state_mean = state_mean
        self.state_std = state_std

        self.hidden_size = hidden_size

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.sar = sar
        self.scale = scale

        self.max_length = max_length

        input_seq_len = 3 * max_length
        # w and b must go by keyword: Block's fourth positional parameter is
        # drop_p, so passing them positionally would use w as the dropout
        # probability and b as the decay factor.
        blocks = [Block(hidden_size, input_seq_len, n_head, w=w, b=b) for _ in range(n_layer)]
        # Held in nn.Sequential for registration only; blocks return tuples,
        # so the container is iterated rather than called.
        self.transformer = nn.Sequential(*blocks)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_return = torch.nn.Linear(1, hidden_size)
        self.embed_rewards = torch.nn.Linear(1, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)

        self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
        self.predict_action = nn.Sequential(
            *(
                [nn.Linear(hidden_size, self.act_dim)]
                + ([nn.Tanh()] if action_tanh else [])
            )
        )

        if predict_rewards:
            self.predict_rewards = torch.nn.Linear(hidden_size, 1)
        else:
            self.predict_rewards = None

    def _embed_sequence(self, states, actions, rewards, returns_to_go, timesteps):
        """Embed each modality, interleave per timestep and layer-normalise.

        Returns a tensor of shape (B, 3 * T, hidden_size). Only the modalities
        used by the active token order are embedded, so ``returns_to_go`` may be
        ``None`` in ``sar`` mode and ``rewards`` may be ``None`` otherwise.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        time_embeddings = self.embed_timestep(timesteps)
        state_embeddings = self.embed_state(states) + time_embeddings
        action_embeddings = self.embed_action(actions) + time_embeddings

        if self.sar:
            reward_embeddings = self.embed_rewards(rewards / self.scale) + time_embeddings
            modalities = (state_embeddings, action_embeddings, reward_embeddings)
        else:
            returns_embeddings = self.embed_return(returns_to_go) + time_embeddings
            modalities = (returns_embeddings, state_embeddings, action_embeddings)

        # (B, 3, T, H) -> (B, T, 3, H) -> (B, 3T, H): tokens of one timestep are adjacent
        stacked_inputs = (torch.stack(modalities, dim=1)
                          .permute(0, 2, 1, 3)
                          .reshape(batch_size, 3 * seq_length, self.hidden_size))
        return self.embed_ln(stacked_inputs)

    def _to_modalities(self, hidden, batch_size, seq_length):
        """Undo the interleaving: (B, 3T, H) -> (B, 3, T, H)."""
        return hidden.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

    def get_representation(self, states, actions, rewards=None, returns_to_go=None, timesteps=None, attention_mask=None):
        """
        Extracts a pooled sequence representation for contrastive learning.

        The state and action tokens of the final layer are averaged over both
        the modality axis and the time axis. ``attention_mask`` is not used, so
        padded timesteps contribute to the average.

        Returns: (batch_size, hidden_size)
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        hidden = self._embed_sequence(states, actions, rewards, returns_to_go, timesteps)
        for block in self.transformer:
            hidden, _ = block(hidden)

        x = self._to_modalities(hidden, batch_size, seq_length)  # (B, 3, T, H)

        # Token order is (s, a, r) in sar mode and (R, s, a) otherwise.
        state_action_tokens = x[:, :2] if self.sar else x[:, 1:]  # (B, 2, T, H)

        # Pool over modality AND time; averaging over the modality axis alone
        # would leave a (B, T, H) tensor rather than the documented (B, H).
        return state_action_tokens.mean(dim=(1, 2))  # (B, H)

    def forward(self, states, actions, rewards=None, targets=None, returns_to_go=None, timesteps=None,
                attention_mask=None):
        """Run the transformer over a batch of trajectories.

        Args:
            states: (B, T, state_dim).
            actions: (B, T, act_dim).
            rewards: (B, T, 1); only read in ``sar`` mode.
            targets: Unused.
            returns_to_go: (B, T, 1); only read when not in ``sar`` mode.
            timesteps: (B, T) integer indices into the timestep embedding.
            attention_mask: Unused (see class note).

        Returns:
            ``(state_preds, action_preds, rewards_preds, attention_maps)``.
            ``rewards_preds`` is ``None`` when the model has no reward head;
            ``attention_maps`` holds one (B, n_head, 3T, 3T) tensor per block.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        # we feed in the input embeddings (not word indices as in NLP) to the model
        hidden = self._embed_sequence(states, actions, rewards, returns_to_go, timesteps)

        attention_maps = []
        for block in self.transformer:
            hidden, attention = block(hidden)
            attention_maps.append(attention)

        # x[:, m, t] is the output token of modality m at timestep t
        x = self._to_modalities(hidden, batch_size, seq_length)

        has_reward_head = self.predict_rewards is not None
        if self.sar:
            # (s, a, r): state -> action, action -> reward, reward -> next state
            action_preds = self.predict_action(x[:, 0])
            rewards_preds = self.predict_rewards(x[:, 1]) if has_reward_head else None
            state_preds = self.predict_state(x[:, 2])
        else:
            # (R, s, a): state -> action, action -> next state / reward
            action_preds = self.predict_action(x[:, 1])
            state_preds = self.predict_state(x[:, 2])
            rewards_preds = self.predict_rewards(x[:, 2]) if has_reward_head else None

        return state_preds, action_preds, rewards_preds, attention_maps

    def _left_pad(self, seq, dtype):
        """Left-pad ``seq`` with zeros along dim 1 up to ``max_length`` and cast to ``dtype``."""
        pad_shape = (seq.shape[0], self.max_length - seq.shape[1]) + tuple(seq.shape[2:])
        padding = torch.zeros(pad_shape, dtype=seq.dtype, device=seq.device)
        return torch.cat([padding, seq], dim=1).to(dtype=dtype)

    def _prepare_context(self, states, actions, rewards, returns_to_go, timesteps):
        """Shape a single trajectory history into a batch-of-one model input.

        Inputs are reshaped to a leading batch dimension of 1; only the first
        row of ``returns_to_go`` is used. When ``max_length`` is set, each
        sequence is truncated to its last ``max_length`` steps and left-padded
        with zeros.

        Returns:
            ``(states, actions, rewards, returns_to_go, timesteps, attention_mask)``;
            ``attention_mask`` is a (1, max_length) long tensor with 0 on padding
            and 1 on real steps, or ``None`` when ``max_length`` is ``None``.
            ``rewards`` stays ``None`` if it was passed as ``None``.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        if rewards is not None:
            rewards = rewards.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)
        returns_to_go = returns_to_go[0].reshape(1, -1, 1)

        if self.max_length is None:
            return states, actions, rewards, returns_to_go, timesteps, None

        states = states[:, -self.max_length:]
        actions = actions[:, -self.max_length:]
        returns_to_go = returns_to_go[:, -self.max_length:]
        timesteps = timesteps[:, -self.max_length:]
        if rewards is not None:
            rewards = rewards[:, -self.max_length:]

        # padding mask must be built before the states themselves are padded
        n_valid = states.shape[1]
        attention_mask = torch.cat([torch.zeros(self.max_length - n_valid), torch.ones(n_valid)])
        attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)

        states = self._left_pad(states, torch.float32)
        actions = self._left_pad(actions, torch.float32)
        returns_to_go = self._left_pad(returns_to_go, torch.float32)
        timesteps = self._left_pad(timesteps, torch.long)
        if rewards is not None:
            rewards = self._left_pad(rewards, torch.float32)

        return states, actions, rewards, returns_to_go, timesteps, attention_mask

    def _step_output(self, action_preds, return_preds, attentions, attention_mask):
        """Pick the last-timestep action and assemble the diagnostic ``info`` dict."""
        idx = 0

        action_preds = action_preds[:, -1, :]
        if return_preds is not None:
            return_preds = return_preds[:, -1, :][idx].item()

        # stack -> (B, L, N, T, T); mean over heads -> (B, L, T, T); with B == 1 the
        # squeeze drops the batch axis, so [idx] selects the first layer's map.
        # detach() keeps .numpy() valid when called outside torch.no_grad().
        attention = torch.stack(attentions, 1).mean(2).squeeze(0)[idx].squeeze(0).detach().cpu().numpy()

        info = {"reward": 0,
                "q_value": 0,
                "return": return_preds if return_preds is not None else 0,
                "attention": attention,
                "attention_mask": attention_mask[idx] if attention_mask is not None else None}

        return self.clamp_action(action_preds[idx]), info

    def get_action_critic(self, critic, states, actions, rewards=None, returns_to_go=None, timesteps=None, **kwargs):
        """Predict the next action for one trajectory history.

        ``critic`` is accepted but not used: Q-value based selection among
        candidates is disabled, so this behaves like ``get_action``.

        Returns:
            ``(action, info)`` where ``action`` is clamped to ``action_range``
            and ``info`` has the keys ``reward``, ``q_value``, ``return``,
            ``attention`` and ``attention_mask``.
        """
        states, actions, rewards, returns_to_go, timesteps, attention_mask = self._prepare_context(
            states, actions, rewards, returns_to_go, timesteps)

        _, action_preds, return_preds, attentions = self.forward(states, actions, rewards, None,
                                                                 returns_to_go=returns_to_go, timesteps=timesteps,
                                                                 attention_mask=attention_mask, **kwargs)

        return self._step_output(action_preds, return_preds, attentions, attention_mask)

    def clamp_action(self, action):
        """Clamp ``action`` element-wise to ``self.action_range``."""
        return action.clamp(*self.action_range)

    def get_action(self, states, actions, rewards=None, returns_to_go=None, timesteps=None, **kwargs):
        """Predict the next action for one trajectory history.

        Returns:
            ``(action, info)`` where ``action`` is clamped to ``action_range``
            and ``info`` has the keys ``reward``, ``q_value``, ``return``
            (a float, or 0 without a reward head), ``attention`` and
            ``attention_mask``.
        """
        states, actions, rewards, returns_to_go, timesteps, attention_mask = self._prepare_context(
            states, actions, rewards, returns_to_go, timesteps)

        _, action_preds, return_preds, attentions = self.forward(states, actions, rewards, None,
                                                                 returns_to_go=returns_to_go, timesteps=timesteps,
                                                                 attention_mask=attention_mask, **kwargs)

        return self._step_output(action_preds, return_preds, attentions, attention_mask)

