import math

import torch
import torch.nn as nn
from torch.nn import functional as F

class Swish(nn.Module):
    """Element-wise activation that multiplies the input by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` with the same shape as ``x``."""
        gate = torch.sigmoid(x)
        return x * gate

class SharedAdaLinear(nn.Linear):
    """Linear layer whose output is reshaped into six equally sized chunks."""

    def forward(self, cond_BD):
        """Project ``cond_BD`` and reshape the result to ``(-1, 1, 6, C)``.

        ``C`` is one sixth of the layer's output width, so the output width
        is expected to be a multiple of six.
        """
        chunk_width = self.weight.shape[0] // 6
        projected = super().forward(cond_BD)
        return projected.view(-1, 1, 6, chunk_width)
    
class Adapter(nn.Module):
    """Bottleneck MLP: project down, apply GELU, project back up.

    Args:
        dim: Width of the input and of the output.
        bottleneck_dim: Width of the hidden bottleneck layer.
    """

    def __init__(self, dim, bottleneck_dim=64):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck_dim)
        # The attribute is called ``relu`` but the activation is a GELU.
        self.relu = nn.GELU()
        self.up = nn.Linear(bottleneck_dim, dim)

    def forward(self, x):
        """Return the adapter output; the last dimension stays ``dim``."""
        squeezed = self.down(x)
        activated = self.relu(squeezed)
        return self.up(activated)

class Critic(nn.Module):
    """Pair of independent Q-value MLPs evaluated on concatenated (state, action).

    Each network is Linear -> Mish repeated three times, followed by a final
    Linear layer that outputs a single value.

    Args:
        state_dim: Size of the last dimension of ``state``.
        action_dim: Size of the last dimension of ``action``.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        input_dim = state_dim + action_dim
        # Build q1 first, then q2, so parameter initialisation order is stable.
        self.q1_model = self._build_q_net(input_dim, hidden_dim)
        self.q2_model = self._build_q_net(input_dim, hidden_dim)

    @staticmethod
    def _build_q_net(input_dim, hidden_dim):
        """Create one Q-network: three Linear+Mish stages and a scalar head."""
        widths = [input_dim, hidden_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Mish())
        layers.append(nn.Linear(hidden_dim, 1))
        return nn.Sequential(*layers)

    def forward(self, state, action):
        """Return the tuple ``(q1, q2)`` for the given state/action pair."""
        joint = torch.cat([state, action], dim=-1)
        q1_value = self.q1_model(joint)
        q2_value = self.q2_model(joint)
        return q1_value, q2_value

    def q1(self, state, action):
        """Return only the first network's value."""
        joint = torch.cat([state, action], dim=-1)
        return self.q1_model(joint)

    def q_min(self, state, action):
        """Return the element-wise minimum of the two networks' values."""
        first, second = self.forward(state, action)
        return torch.min(first, second)

class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional mask and optional K/V caching.

    ``config`` must expose ``n_embd``, ``n_head``, ``attn_pdrop`` and
    ``resid_pdrop``; ``n_embd`` must be divisible by ``n_head``.

    When caching is enabled (see :meth:`kv_caching`), the keys and values of
    every call are appended to those of earlier calls, so the queries of the
    current call attend over all keys seen since caching was switched on.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head

        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool):
        """Turn key/value caching on or off; always clears any cached tensors."""
        self.caching, self.cached_k, self.cached_v = enable, None, None

    def forward(self, x, attn_bias=None):
        """Apply self-attention to ``x`` of shape ``(B, T, C)``.

        Args:
            x: Input tensor of shape ``(B, T, C)``.
            attn_bias: Optional 4-D mask; attention logits are set to ``-inf``
                wherever the selected mask entries equal 0. Without caching
                the top-left ``T x T`` window is used.

        Returns:
            Tensor of shape ``(B, T, C)``.

        Side effects:
            Stores a copy of the post-softmax attention weights in
            ``self._attn_map`` and, when caching is on, updates
            ``self.cached_k`` / ``self.cached_v``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        # (B, T, C) -> (B, n_head, T, head_dim)
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)

        if self.caching:
            if self.cached_k is not None and self.cached_v is not None:
                k = torch.cat([self.cached_k, k], dim=2)
                v = torch.cat([self.cached_v, v], dim=2)
            self.cached_k, self.cached_v = k, v

        # Number of keys; larger than T once cached keys have been prepended.
        T_k = k.size(2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if attn_bias is not None:
            # The current queries sit at positions [T_k - T, T_k) of the whole
            # sequence, so pick those mask rows and the first T_k columns.
            # With no cached keys (T_k == T) this is the plain [:T, :T] window.
            # NOTE(review): assumes attn_bias is indexed by absolute position
            # and covers at least T_k positions -- confirm against callers
            # that combine a mask with kv caching.
            mask = attn_bias[:, :, T_k - T:T_k, :T_k]
            att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        self._attn_map = att.clone()
        att = self.attn_drop(att)
        y = att @ v
        # (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_drop(self.proj(y))
        return y

class AdaLayerNormSelfAttnBlock(nn.Module):
    """Attention + MLP block modulated by a conditioning vector.

    Six vectors of width ``n_embd`` are derived from ``cond``: a scale and a
    shift applied after a parameter-free LayerNorm in front of each
    sub-layer, and a gate (``gamma``) multiplied onto each sub-layer's output
    before the residual addition.

    Args:
        config: Object exposing ``n_embd`` and ``resid_pdrop`` plus whatever
            ``SelfAttention`` reads from it.
        cond_dim: Size of the last dimension of ``cond``.
        shared_aln: If True, the six vectors are a learned base table
            (``ada_gss``) plus a projection of ``cond``; otherwise they come
            from a per-block ``Linear -> GELU`` applied to ``cond``.
    """

    def __init__(self, config, cond_dim: int, shared_aln: bool = False):
        super().__init__()
        self.n_embd = config.n_embd
        self.cond_dim = cond_dim
        self.shared_aln = shared_aln

        self.ln_wo_grad = nn.LayerNorm(config.n_embd, elementwise_affine=False)
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

        if self.shared_aln:
            # Learned base modulation, shape (1, 1, 6, n_embd).
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, config.n_embd) / config.n_embd**0.5)
            # Maps cond (cond_dim) to six vectors of width n_embd so the result
            # can be added to ada_gss. The sizes were previously swapped
            # (n_embd -> 6 * cond_dim), which only worked when
            # cond_dim == n_embd; parameter shapes are unchanged in that case.
            self.shared_ada_linear = nn.Sequential(
                nn.GELU(),
                SharedAdaLinear(self.cond_dim, 6 * self.n_embd),
            )
        else:
            self.ada_lin = nn.Sequential(
                nn.Linear(cond_dim, 6 * config.n_embd),
                nn.GELU(),
            )

        # NOTE: drop_path is constructed here but not applied in forward().
        self.drop_path = nn.Dropout(config.resid_pdrop) if config.resid_pdrop > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor, attn_bias=None):
        """Run the modulated attention and MLP sub-layers.

        Args:
            x: [B, T, n_embd]
            cond: [B, 1, cond_dim]
            attn_bias: Optional mask forwarded to the attention layer.

        Returns:
            Tensor of shape [B, T, n_embd].
        """
        B, T, C = x.shape

        if self.shared_aln:
            cond = self.shared_ada_linear(cond)       # (B, 1, 6, C)
            cond_proj = cond.expand(-1, T, -1, -1)    # (B, T, 6, C)
            cond_params = self.ada_gss + cond_proj
            gamma1, gamma2, scale1, scale2, shift1, shift2 = cond_params.unbind(2)
        else:
            cond_expanded = cond.expand(-1, T, -1)    # (B, T, cond_dim)
            params = self.ada_lin(cond_expanded)      # (B, T, 6 * C)
            params = params.view(B, T, 6, C)
            gamma1, gamma2, scale1, scale2, shift1, shift2 = params.unbind(2)

        # The in-place add_/mul_ calls act on intermediate results (the output
        # of .mul(...) and of the attention layer), never on x itself.
        x = x + self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1)
        x = x + self.mlp( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)

        return x

class CondAttentionBlock(nn.Module):
    """Pre-LayerNorm attention block that adds a projected condition to its input.

    The condition embedding (``GELU -> Linear`` of ``cond``) is added to the
    input before the attention sub-layer and once more to the input of the
    second LayerNorm in front of the MLP.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
            nn.Dropout(config.resid_pdrop),
        )
        self.cond_proj = nn.Sequential(nn.GELU(), nn.Linear(width, width))

    def forward(self, x, cond=None, attn_bias=None):
        """Return the block output for ``x``.

        Args:
            x: Input tensor; indexed as three-dimensional.
            cond: Optional conditioning tensor. When ``None`` a zero embedding
                shaped like ``x[:, :1, :]`` is used instead.
            attn_bias: Optional mask forwarded to the attention layer.
        """
        if cond is None:
            cond_emb = torch.zeros_like(x[:, :1, :])
        else:
            cond_emb = self.cond_proj(cond)

        x = x + cond_emb
        attn_out = self.attn(self.ln1(x), attn_bias)
        x = x + attn_out
        # The condition is injected a second time just before the MLP's norm.
        mlp_out = self.mlp(self.ln2(x + cond_emb))
        return x + mlp_out

class AttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then a 4x-wide GELU MLP."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, attn_bias=None):
        """Apply both residual sub-layers; ``attn_bias`` goes to the attention layer."""
        attn_out = self.attn(self.ln1(x), attn_bias)
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a fixed lower-triangular mask.

    ``config`` must expose ``n_embd``, ``n_head``, ``attn_pdrop``,
    ``resid_pdrop`` and ``block_size``. If ``"action_dim" in config`` is
    true, ``observation_dim`` and ``action_dim`` are also read and every
    ``joined_dim``-th key column (starting at ``joined_dim - 1``) is masked
    out for all queries, where
    ``joined_dim = observation_dim + action_dim + 2``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        # Build the 2-D mask first and edit it directly. Going through
        # ``mask.squeeze()`` collapses *every* size-1 dimension, which breaks
        # the column indexing when block_size == 1.
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        if "action_dim" in config:
            joined_dim = config.observation_dim + config.action_dim + 2
            mask[:, joined_dim - 1::joined_dim] = 0
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        """Apply masked self-attention to ``x`` of shape ``(B, T, C)``.

        Args:
            x: Input of shape ``(B, T, C)``; the top-left ``T x T`` window of
                the mask buffer is used.
            layer_past: Accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``(B, T, C)``. A copy of the post-softmax
            attention weights is kept in ``self._attn_map``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        # (B, T, C) -> (B, n_head, T, head_dim)
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        self._attn_map = att.clone()
        att = self.attn_drop(att)
        y = att @ v
        # (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """Pre-LayerNorm transformer block built on ``CausalSelfAttention``."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

class AsymBlock(nn.Module):
    """Cross-attention block that maps an input sequence onto ``out_tokens`` tokens.

    A learned query of shape ``(1, out_tokens, n_embd)`` attends over linear
    projections of the layer-normalised input; the attended result is
    layer-normalised again and passed through one linear layer.

    Args:
        config: Object exposing ``n_embd`` and ``n_head``.
        out_tokens: Number of learned query tokens, i.e. the output length.
    """

    def __init__(self, config, out_tokens):
        super().__init__()
        width = config.n_embd
        self.key = nn.Linear(width, width)
        self.query = nn.Parameter(torch.rand(1, out_tokens, width))
        self.value = nn.Linear(width, width)

        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attention = nn.MultiheadAttention(width, config.n_head, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(width, width),
        )

    def forward(self, x):
        """Return a tensor whose second dimension equals ``out_tokens``.

        The first dimension of ``x`` is treated as the batch size.
        """
        normed = self.ln1(x)
        batch_size = normed.shape[0]
        # Give every batch element its own copy of the learned query tokens.
        queries = self.query.repeat(batch_size, 1, 1)
        attended, _ = self.attention(queries, self.key(normed), self.value(normed))
        return self.mlp(self.ln2(attended))
    

class ScaleCausalSelfAttention(nn.Module):
    """Multi-head self-attention with a block-wise ("scale") causal mask.

    ``config.v_patch_nums`` gives the number of tokens in each consecutive
    group. A token may attend to every token whose group index is less than
    or equal to its own, so tokens inside one group see each other.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.proj = nn.Linear(width, width)
        self.patch_nums = config.v_patch_nums
        self.L = sum(self.patch_nums)

        # Group index of every token position, e.g. (1, 2) -> [0, 1, 1].
        group_ids = torch.cat(
            [torch.full((count,), idx) for idx, count in enumerate(self.patch_nums)]
        )
        query_group = group_ids.view(1, self.L, 1)
        key_group = query_group.transpose(1, 2)
        # 1 where the key's group does not come after the query's group.
        visible = torch.where(query_group >= key_group, 1, 0)
        self.register_buffer("mask", visible.reshape(1, 1, self.L, self.L))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        """Apply masked self-attention to ``x`` of shape ``(B, T, C)``.

        ``layer_past`` is accepted but not used. A copy of the post-softmax
        attention weights is stored in ``self._attn_map``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        blocked = self.mask[:, :, :T, :T] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        self._attn_map = weights.clone()
        weights = self.attn_drop(weights)

        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(merged))
    
class ScaleBlock(nn.Module):
    """Pre-LayerNorm transformer block built on ``ScaleCausalSelfAttention``."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = ScaleCausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out


def sample_gumbel(shape, device, eps=1e-20):
    """Draw noise of the given shape as ``-log(-log(U + eps) + eps)``.

    ``U`` is sampled with ``torch.rand`` on ``device``; ``eps`` is added
    inside both logarithms.
    """
    uniform = torch.rand(shape, device=device)
    neg_log_u = -torch.log(uniform + eps)
    return -torch.log(neg_log_u + eps)

def gumbel_softmax_sample(logits, temperature):
    """Add noise from ``sample_gumbel`` to ``logits`` and apply a tempered softmax.

    The noise has the same shape and device as ``logits``; the softmax is
    taken over the last dimension of ``(logits + noise) / temperature``.
    """
    noise = sample_gumbel(logits.size(), device=logits.device)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=True):
    """Sample from ``gumbel_softmax_sample``, optionally discretised to one-hot.

    Args:
        logits: Unnormalised scores; the last dimension is the category axis.
        temperature: Divisor applied before the softmax.
        hard: If true, the returned values are one-hot at the arg-max of the
            soft sample while gradients flow through the soft sample; if
            false, the soft sample is returned unchanged.
    """
    soft = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return soft

    _, winners = soft.max(dim=-1, keepdim=True)
    one_hot = torch.zeros_like(soft, device=logits.device).scatter_(-1, winners, 1.0)
    # Forward value equals one_hot; the backward pass only sees ``soft``.
    return one_hot - soft.detach() + soft