import random
from typing import Optional, Tuple
from debug import debug_print
import torch
import torch.nn as nn
import numpy as np
import functools
import copy
import os
import math
from torch import Tensor
import torch.nn.functional as F
import tqdm
import onpolicy.algorithms.diffusion_ac.mlp as mlp
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably an upper bound on batch size used by callers -- confirm.
MAX_BZ_SIZE = 1024

class GaussianFourierProjection(nn.Module):
    """Map scalar time steps to random Fourier features.

    A vector of ``embed_dim // 2`` frequencies is sampled from a normal
    distribution (multiplied by ``scale``) at construction time.  The
    frequencies are stored in ``self.W`` as a parameter with
    ``requires_grad=False``, so they are part of the state dict but receive
    no gradient.

    ``forward`` appends one trailing axis to its input holding the
    concatenation ``[sin(2*pi*x*W), cos(2*pi*x*W)]``, i.e. an input of shape
    ``S`` yields an output of shape ``S + (2 * (embed_dim // 2),)``.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # View W as (1, ..., 1, n_freq) so that it broadcasts against x once
        # x has been given a new trailing axis.
        broadcast_shape = (1,) * len(x.shape) + (-1,)
        angles = x[..., None] * self.W.view(broadcast_shape) * 2 * np.pi
        features = (torch.sin(angles), torch.cos(angles))
        return torch.cat(features, dim=-1)


class Dense(nn.Module):
    """Thin wrapper around a single ``nn.Linear`` layer.

    The wrapped layer is stored as ``self.dense``; ``forward`` applies it to
    the input and returns the result without any further reshaping.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        projected = self.dense(x)
        return projected

class SiLU(nn.Module):
    """Parameter-free activation computing ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)

class MlpQnet(nn.Module):
    """Three-layer MLP that maps a (state, action) pair to a single value.

    The state and action tensors are concatenated along their last axis and
    passed through ``Linear(sdim + adim, 256) -> SiLU -> Linear(256, 256) ->
    SiLU -> Linear(256, 1)``.
    """

    def __init__(self, sdim, adim) -> None:
        super().__init__()
        layers = [
            nn.Linear(sdim + adim, 256),
            SiLU(),
            nn.Linear(256, 256),
            SiLU(),
            nn.Linear(256, 1),
        ]
        self.sort = nn.Sequential(*layers)

    def forward(self, s, a):
        state_action = torch.cat((s, a), dim=-1)
        return self.sort(state_action)

class MUlQnet(nn.Module):
    """Four-layer MLP that maps a (state, action) pair to a single value.

    Same interface as a plain (state, action) critic: the two inputs are
    concatenated along the last axis and fed through
    ``Linear(sdim + adim, 512) -> SiLU -> Linear(512, 256) -> SiLU ->
    Linear(256, 256) -> SiLU -> Linear(256, 1)``.
    """

    def __init__(self, sdim, adim) -> None:
        super().__init__()
        layers = [
            nn.Linear(sdim + adim, 512),
            SiLU(),
            nn.Linear(512, 256),
            SiLU(),
            nn.Linear(256, 256),
            SiLU(),
            nn.Linear(256, 1),
        ]
        self.sort = nn.Sequential(*layers)

    def forward(self, s, a):
        state_action = torch.cat((s, a), dim=-1)
        return self.sort(state_action)
    
class Residual_Block(nn.Module):
    """Two-layer residual MLP block modulated by a conditioning vector.

    Args:
        input_dim: width of the input ``x``.
        output_dim: width of the block output.
        t_dim: width of the conditioning vector ``t`` passed to ``forward``
            (defaults to 256).
        last: accepted for interface compatibility; not used by this block.

    ``forward(x, t)`` computes ``dense2(dense1(x) + time_mlp(t))`` and adds a
    shortcut of ``x``.  The shortcut is the identity when ``input_dim ==
    output_dim`` and a learned linear projection otherwise.
    """

    def __init__(self, input_dim, output_dim, t_dim=256, last=False):
        super().__init__()
        # Projects the conditioning vector to the block width.
        self.time_mlp = nn.Sequential(SiLU(), nn.Linear(t_dim, output_dim))
        self.dense1 = nn.Sequential(nn.Linear(input_dim, output_dim), SiLU())
        self.dense2 = nn.Sequential(nn.Linear(output_dim, output_dim), SiLU())
        # The shortcut only needs a projection when the width changes.
        if input_dim == output_dim:
            self.modify_x = nn.Identity()
        else:
            self.modify_x = nn.Linear(input_dim, output_dim)

    def forward(self, x, t):
        hidden = self.dense1(x) + self.time_mlp(t)
        hidden = self.dense2(hidden)
        shortcut = self.modify_x(x)
        return hidden + shortcut

class ScoreNet(nn.Module):
    """U-Net-shaped MLP predicting a score/noise vector for an action.

    Args:
        state_dim: width of the conditioning ``state`` input.
        action_dim: width of the (noisy) action ``x`` and of the output.
        t_dim: width of the time embedding.
        unet_hidden_size: base width ``H`` of the network.  The state encoder
            outputs ``H // 4 - t_dim`` features, so ``H // 4`` must exceed
            ``t_dim``.
        device: stored on the module as ``self.device``; not otherwise used
            in this class.
        **kwargs: accepted and ignored.

    The time embedding and the encoded state are concatenated (``H // 4``
    features) and mapped by ``sort_t`` to a conditioning vector of width
    ``H // 2`` that is fed to every residual block.
    """

    def __init__(self, state_dim, action_dim, t_dim=32, unet_hidden_size = 256, device=torch.device('cpu'), **kwargs):
        super().__init__()
        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=t_dim), nn.Linear(t_dim, t_dim))
        self.pre_sort_condition = nn.Sequential(Dense(state_dim, unet_hidden_size // 4 - t_dim), SiLU())

        # Width of the conditioning vector produced by `sort_t`.
        cond_dim = unet_hidden_size // 2
        self.sort_t = nn.Sequential(
            nn.Linear(unet_hidden_size // 4, cond_dim),
            SiLU(),
            nn.Linear(cond_dim, cond_dim),
        )

        # Every block must be told the real conditioning width.  Relying on
        # Residual_Block's default (t_dim=256) only matches when
        # unet_hidden_size == 512 and makes forward() fail with a shape
        # mismatch for every other size, including the default of 256.
        self.down_block1 = Residual_Block(action_dim, unet_hidden_size * 2, t_dim=cond_dim)
        self.down_block2 = Residual_Block(unet_hidden_size * 2, unet_hidden_size, t_dim=cond_dim)
        self.down_block3 = Residual_Block(unet_hidden_size, unet_hidden_size // 2, t_dim=cond_dim)
        self.middle1 = Residual_Block(unet_hidden_size // 2, unet_hidden_size // 2, t_dim=cond_dim)
        self.up_block3 = Residual_Block(unet_hidden_size, unet_hidden_size, t_dim=cond_dim)
        self.up_block2 = Residual_Block(unet_hidden_size * 2, unet_hidden_size * 2, t_dim=cond_dim)
        self.last = nn.Linear(unet_hidden_size * 4, action_dim)
        self.device = device

    def forward(self, x, t, state):
        """Return the network output for action ``x`` at time ``t`` given ``state``.

        ``x``, ``state`` and the embedding of ``t`` must share their leading
        dimensions, since they are concatenated / added along the last axis.
        The output has the same trailing width as ``x`` (``action_dim``).
        """
        embed = self.embed(t)

        # Joint time + state conditioning vector.
        embed = torch.cat([self.pre_sort_condition(state), embed], dim=-1)
        embed = self.sort_t(embed)

        # Encoder path: 2H -> H -> H/2.
        d1 = self.down_block1(x, embed)
        d2 = self.down_block2(d1, embed)
        d3 = self.down_block3(d2, embed)

        # Bottleneck, then decoder path with skip connections (concatenation
        # with the encoder activation of matching depth).
        u3 = self.middle1(d3, embed)
        u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed)
        u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed)
        u0 = torch.cat([d1, u1], dim=-1)
        h = self.last(u0)

        return h
    
class Unet(nn.Module):
    """U-Net-shaped MLP mapping (action, time, state) to an action-sized vector.

    Args:
        state_dim: width of the conditioning ``state`` input.
        action_dim: width of the action input ``x`` and of the output.
        t_dim: width of the time embedding.
        hidden_size: base width ``H`` of the network.  The state encoder
            outputs ``H // 4 - t_dim`` features, so ``H // 4`` must exceed
            ``t_dim``.
        device: stored on the module as ``self.device``; not otherwise used
            in this class.

    The encoded state and the time embedding are concatenated (``H // 4``
    features) and mapped to a conditioning vector of width ``H // 2`` that is
    fed to every residual block.
    """

    def __init__(self, state_dim, action_dim, t_dim=32, hidden_size=256, device=torch.device('cpu')):
        super().__init__()
        # Time embedding
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=t_dim),
            nn.Linear(t_dim, t_dim)
        )

        # State conditioning
        self.state_encoder = nn.Sequential(
            Dense(state_dim, hidden_size // 4 - t_dim),
            SiLU()
        )

        # Time and state processing; output width is the conditioning width
        # consumed by every residual block below.
        cond_dim = hidden_size // 2
        self.time_state_processor = nn.Sequential(
            nn.Linear(hidden_size // 4, cond_dim),
            SiLU(),
            nn.Linear(cond_dim, cond_dim)
        )

        # The blocks must be built with the real conditioning width.  The
        # Residual_Block default (t_dim=256) only matches hidden_size == 512
        # and makes forward() fail with a shape mismatch otherwise,
        # including for the default hidden_size of 256.

        # Downsampling path
        self.down1 = Residual_Block(action_dim, hidden_size * 2, t_dim=cond_dim)
        self.down2 = Residual_Block(hidden_size * 2, hidden_size, t_dim=cond_dim)
        self.down3 = Residual_Block(hidden_size, hidden_size // 2, t_dim=cond_dim)

        # Bottleneck
        self.middle = Residual_Block(hidden_size // 2, hidden_size // 2, t_dim=cond_dim)

        # Upsampling path
        self.up3 = Residual_Block(hidden_size, hidden_size, t_dim=cond_dim)
        self.up2 = Residual_Block(hidden_size * 2, hidden_size * 2, t_dim=cond_dim)

        # Final projection
        self.final = nn.Linear(hidden_size * 4, action_dim)

        self.device = device

    def forward(self, x, t, state):
        """Return the network output for action ``x`` at time ``t`` given ``state``.

        ``x``, ``state`` and the embedding of ``t`` must share their leading
        dimensions.  The output has trailing width ``action_dim``.
        """
        # Time embedding
        t_embed = self.embed(t)

        # Process state and time into one conditioning vector
        state_embed = self.state_encoder(state)
        embed = torch.cat([state_embed, t_embed], dim=-1)
        embed = self.time_state_processor(embed)

        # Downsampling
        d1 = self.down1(x, embed)
        d2 = self.down2(d1, embed)
        d3 = self.down3(d2, embed)

        # Bottleneck
        middle = self.middle(d3, embed)

        # Upsampling with skip connections
        u3 = self.up3(torch.cat([d3, middle], dim=-1), embed)
        u2 = self.up2(torch.cat([d2, u3], dim=-1), embed)

        # Final concatenation and projection
        u1 = torch.cat([d1, u2], dim=-1)
        output = self.final(u1)

        return output

class MlpScoreNet(nn.Module):
    """Residual-MLP network mapping (action, time, state) to an action-sized vector.

    Args:
        state_dim: width of ``state``.
        action_dim: width of the per-sample (or per-agent) action and of the
            decoder output.
        t_dim: accepted for interface compatibility; the value is replaced
            by ``unet_hidden_size // 8`` inside the constructor.
        unet_hidden_size: hidden width of the residual MLP.
        device: stored on the module as ``self.device``.
        **kwargs: ``joint_train`` (default ``True``) and ``rnum_agents``
            (default ``1``) are read; anything else is ignored.

    The time embedding is projected to 64 features and the concatenated
    (state, action) input to ``unet_hidden_size - 64`` features; both are
    concatenated and passed through ``mlp.ResidualMLP`` and a linear decoder.
    """

    def __init__(self, state_dim, action_dim, t_dim=32, unet_hidden_size = 256, device=torch.device('cpu'), **kwargs):
        super().__init__()
        # The caller-supplied t_dim is deliberately overridden here.
        t_dim = unet_hidden_size // 8
        # Swish activation, x * sigmoid(x).
        self.act = lambda x: x * torch.sigmoid(x)
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=t_dim),
            nn.Linear(t_dim, t_dim),
        )
        self.joint_train = kwargs.get('joint_train', True)
        self.rnum_agents = kwargs.get('rnum_agents', 1)
        width = unet_hidden_size
        # Time features take 64 channels, state/action features the rest.
        self.dense1 = Dense(t_dim, 64)
        self.dense2 = Dense(state_dim + action_dim, width - 64)
        self.block1 = mlp.ResidualMLP(
            [width] * 5,
            activation_type='Mish',
            out_activation_type='Identity',
        )
        self.decoder = Dense(width, action_dim)
        self.device = device

    def forward(self, x, t, state):
        """Return the prediction flattened to ``(x.shape[0], -1)``.

        When ``joint_train`` is false, ``x`` is first split into
        ``rnum_agents`` chunks along a new second axis and ``t`` is repeated
        once per agent (``state`` presumably already carries the agent axis
        in that case -- TODO confirm against callers).
        """
        if not self.joint_train:
            n_agents = self.rnum_agents
            x = x.reshape(x.shape[0], n_agents, -1)
            t = t[:, None].repeat(1, n_agents)

        state_action = torch.cat([state, x], dim=-1)
        # Random Fourier time embedding followed by swish.
        time_feat = self.act(self.embed(t))

        # Encoding path: state/action features first, time features last.
        hidden = torch.cat(
            (self.dense2(state_action), self.dense1(time_feat)),
            dim=-1,
        )
        hidden = self.block1(hidden)
        decoded = self.decoder(self.act(hidden))
        return decoded.reshape(decoded.shape[0], -1)
    

class EncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention followed by a feed-forward net.

    Configuration is read from ``config``: ``n_embed``, ``n_head``,
    ``attention_dropout``, ``normalize_before``, ``dropout``,
    ``activation_function`` (a key of ``ACT2FN``), ``activation_dropout``
    and ``ffn_dim``.  With ``normalize_before`` set, layer norm is applied
    to the input of each sub-layer; otherwise it is applied after the
    residual addition.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        self.embed_dim = width
        self.self_attn = Attention(
            width,
            config.n_head,
            dropout=config.attention_dropout,
        )
        self.normalize_before = config.normalize_before
        self.self_attn_layer_norm = LayerNorm(width)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(width, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, width)
        self.final_layer_norm = LayerNorm(width)

    def forward(self, x, encoder_padding_mask):
        """
        Args:
            x (Tensor): layer input of shape `(seq_len, batch, embed_dim)`.
            encoder_padding_mask (Tensor or None): forwarded unchanged to the
                attention module as ``key_padding_mask``; shape
                `(batch, src_len)`, with padding positions marked so that
                they are excluded from attention.

        Returns:
            Tensor of shape `(seq_len, batch, embed_dim)`.
        """
        pre_norm = self.normalize_before

        # --- self-attention sub-layer -------------------------------------
        attn_in = self.self_attn_layer_norm(x) if pre_norm else x
        attn_out = self.self_attn(query=attn_in,
                                  key=attn_in,
                                  key_padding_mask=encoder_padding_mask)
        attn_out = F.dropout(attn_out, p=self.dropout, training=self.training)
        x = x + attn_out
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-layer ---------------------------------------
        ffn_in = self.final_layer_norm(x) if pre_norm else x
        ffn_out = self.activation_fn(self.fc1(ffn_in))
        ffn_out = F.dropout(ffn_out, p=self.activation_dropout, training=self.training)
        ffn_out = self.fc2(ffn_out)
        ffn_out = F.dropout(ffn_out, p=self.dropout, training=self.training)
        x = x + ffn_out
        if not pre_norm:
            x = self.final_layer_norm(x)

        # If any entry is inf or NaN, clamp the tensor to just inside the
        # finite range of its dtype.
        if not torch.isfinite(x).all():
            limit = torch.finfo(x.dtype).max - 1000
            x = torch.clamp(x, min=-limit, max=limit)
        return x


class TransformerEncoder(nn.Module):
    """
    Stack of *config.n_layer* :class:`EncoderLayer` modules operating on
    already-embedded inputs (no token or position embedding is applied here).

    Args:
        config: object providing ``dropout``, ``encoder_layerdrop``,
            ``scale_embedding``, ``n_layer``, ``normalize_embedding``,
            ``add_final_layer_norm`` and ``n_embed`` (plus whatever
            :class:`EncoderLayer` reads).
        embed_dim: feature width of the inputs; used for the optional input
            layer norm and for ``embed_scale``.
    """

    def __init__(self, config, embed_dim):
        super().__init__()

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        # Computed for reference only; forward() does not use it.
        if config.scale_embedding:
            self.embed_scale = math.sqrt(embed_dim)
        else:
            self.embed_scale = 1.0

        stacked_layers = [EncoderLayer(config) for _ in range(config.n_layer)]
        self.layers = nn.ModuleList(stacked_layers)

        if config.normalize_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = nn.Identity()

        if config.add_final_layer_norm:
            self.layer_norm = LayerNorm(config.n_embed)
        else:
            self.layer_norm = None

    def forward(self, x, attention_mask=None):
        """
        Args:
            x (Tensor): embedded inputs of shape `(batch, seq_len, embed_dim)`.
            attention_mask (Tensor, optional): handed to every layer as its
                padding mask.

        Returns:
            Tensor of shape `(batch, seq_len, embed_dim)`.
        """
        hidden = self.layernorm_embedding(x)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C (the layers expect sequence-first input)
        hidden = hidden.transpose(0, 1)

        for layer in self.layers:
            # LayerDrop (https://arxiv.org/abs/1909.11556): while training,
            # each layer is skipped with probability `self.layerdrop`.  The
            # random draw happens on every iteration, in eval mode as well.
            draw = random.uniform(0, 1)
            skip_layer = self.training and (draw < self.layerdrop)
            if not skip_layer:
                hidden = layer(hidden, attention_mask)

        if self.layer_norm is not None:
            hidden = self.layer_norm(hidden)

        # T x B x C -> B x T x C
        return hidden.transpose(0, 1)


# Lookup table from activation name (as used in the encoder config) to the
# callable implementing it.
ACT2FN = dict(
    relu=F.relu,
    gelu=F.gelu,
    tanh=torch.tanh,
    sigmoid=torch.sigmoid,
)


class Config:
    """Plain attribute bag holding transformer hyper-parameters.

    The class attributes below are the defaults.  Every keyword argument
    given to the constructor is set as an instance attribute, either
    overriding a default or adding a new field (e.g. ``n_embed``).
    """

    # Dropout probabilities.
    dropout = 0.1
    attention_dropout = 0.0
    activation_dropout = 0.0
    # LayerDrop probabilities.
    encoder_layerdrop = 0.0
    decoder_layerdrop = 0.0
    # Embedding options.
    scale_embedding = None
    static_position_embeddings = False
    extra_pos_embeddings = 0
    normalize_embedding = True
    # Layer structure.
    normalize_before = False
    activation_function = "gelu"
    add_final_layer_norm = False
    init_std = 0.02

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])





class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights while
            the module is in training mode.
        bias: whether the q/k/v/out projections have a bias term.
        encoder_decoder_attention: stored on the module and used to pick
            ``cache_key``; it does not change the computation here.

    NOTE(review): ``out_proj`` is constructed (and therefore part of the
    state dict) but is not applied in ``forward``.  This is left unchanged so
    that outputs of already-trained checkpoints stay the same -- confirm
    whether omitting the output projection is intentional.
    """

    def __init__(
            self,
            embed_dim,
            num_heads,
            dropout=0.0,
            bias=True,
            encoder_decoder_attention=False,  # otherwise self_attention
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.encoder_decoder_attention = encoder_decoder_attention
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor, seq_len, bsz):
        return tensor.contiguous().view(seq_len, bsz * self.num_heads,
                                        self.head_dim).transpose(0, 1)

    def forward(
        self,
        query,
        key: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Compute scaled dot-product attention over ``key`` for every ``query``
        position.  Both masks are applied before the softmax.

        Args:
            query (torch.Tensor): shape (tgt_len, batch_size, embed_dim).
            key (Optional[torch.Tensor]): shape (src_len, batch_size, embed_dim);
                also used as the source of the values.  If None, ``query`` is
                used (self-attention).
            key_padding_mask (Optional[torch.Tensor]): shape (batch_size, src_len).
                Non-zero / True entries mark key positions that must not be
                attended to.  Any dtype convertible to bool is accepted.
            attn_mask (Optional[torch.Tensor]): shape (tgt_len, src_len).
                A boolean mask marks disallowed positions with True; any
                other dtype is added to the attention scores.

        Returns:
            torch.Tensor: attention output, shape (tgt_len, batch_size, embed_dim).
        """
        if key is None:
            key = query

        tgt_len, bsz = query.shape[0], query.shape[1]
        src_len = key.shape[0]

        def split_heads(projected, length):
            # (length, batch, embed) -> (length, batch, heads, head_dim)
            return projected.view(length, bsz, self.num_heads, self.head_dim)

        q = split_heads(self.q_proj(query), tgt_len).permute(1, 2, 0, 3)   # B, H, T, D
        k_t = split_heads(self.k_proj(key), src_len).permute(1, 2, 3, 0)   # B, H, D, S
        v = split_heads(self.v_proj(key), src_len).permute(1, 2, 0, 3)     # B, H, S, D

        score = torch.matmul(q, k_t) * self.scaling                        # B, H, T, S

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Adding a bool mask would only shift scores by 0/1; it has
                # to remove the masked positions instead.
                score = score.masked_fill(attn_mask[None, None], float('-inf'))
            else:
                score = score + attn_mask[None, None]

        if key_padding_mask is not None:
            # masked_fill needs a boolean mask; byte/int masks are converted.
            pad = key_padding_mask.to(torch.bool)
            score = score.masked_fill(pad[:, None, None, :], float('-inf'))

        weights = F.softmax(score, dim=-1)
        # No-op when self.dropout == 0 or when the module is in eval mode.
        weights = F.dropout(weights, p=self.dropout, training=self.training)

        context = torch.matmul(weights, v)                                 # B, H, T, D
        attn_output = context.permute(2, 0, 1, 3).reshape(tgt_len, bsz, -1)
        return attn_output


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    """Build a ``torch.nn.LayerNorm`` with the given shape, epsilon and affine flag."""
    layer = torch.nn.LayerNorm(
        normalized_shape,
        eps=eps,
        elementwise_affine=elementwise_affine,
    )
    return layer

class TransformerScoreNet(nn.Module):
    """Transformer-based network mapping (action, time, state) to an action-sized vector.

    Each agent contributes one token built from its state, its slice of the
    action vector and an agent-specific time embedding; the tokens are
    processed by a 2-layer, 2-head :class:`TransformerEncoder` of width 256
    and decoded per token.

    Args:
        state_dim: trailing width of ``state``.
        action_dim: width of one agent's action slice and of the per-token
            decoder output.
        num_agents: number of tokens the action vector is split into.
        t_dim: width of the time embedding.
        device: accepted for interface compatibility; not stored or used.
        **kwargs: accepted and ignored.

    ``block1`` is constructed (and so appears in the state dict) but is not
    used by ``forward``.
    """

    def __init__(self, state_dim, action_dim, num_agents, t_dim=32, device=torch.device('cpu'), **kwargs):
        super().__init__()
        debug_print('transformer', state_dim, action_dim, num_agents, t_dim)
        embed_dim = 256
        self.config = Config(
            n_embed=embed_dim,
            n_layer=2,
            n_head=2,
            ffn_dim=embed_dim,
        )
        # Swish activation, x * sigmoid(x).
        self.act = lambda x: x * torch.sigmoid(x)
        self.num_agents = num_agents
        self.embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=t_dim),
            nn.Linear(t_dim, t_dim),
        )
        self.transformer = TransformerEncoder(self.config, embed_dim)
        # Token features: 16 channels from time, the remainder from
        # state/action.
        self.dense1 = Dense(t_dim, 16)
        self.dense2 = Dense(state_dim + action_dim, embed_dim - 16)
        mlp_layers = [
            nn.Linear(256 * num_agents, 256),
            SiLU(),
            nn.Linear(256, 512),
            SiLU(),
            nn.Linear(512, 256),
            SiLU(),
            nn.Linear(256, 256),
            SiLU(),
            nn.Linear(256, 256),
        ]
        self.block1 = nn.Sequential(*mlp_layers)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.LayerNorm(256),
            Dense(256, action_dim),
        )

    def forward(self, x, t, state):
        """Return the per-agent predictions flattened to ``(batch, -1)``.

        ``x`` has its last axis split into ``num_agents`` slices before being
        concatenated with ``state`` (so ``state`` presumably has shape
        ``(batch, num_agents, state_dim)`` -- TODO confirm against callers).
        ``t`` is expected to be one value per batch element.
        """
        actions = x.clone()
        times = t.clone()

        n_agents = self.num_agents
        per_agent_actions = actions.reshape(*actions.shape[:-1], n_agents, -1)
        tokens_in = torch.cat([state, per_agent_actions], dim=-1)

        # Give every agent its own time value: t * 2 * num_agents + agent index.
        agent_index = torch.arange(n_agents, device=tokens_in.device).unsqueeze(0)
        agent_times = times.unsqueeze(-1).repeat(1, n_agents) * 2 * n_agents + agent_index

        # Random Fourier time embedding followed by swish.
        time_feat = self.act(self.embed(agent_times))

        tokens = torch.cat(
            (self.dense2(tokens_in), self.dense1(time_feat)),
            dim=-1,
        )
        encoded = self.transformer(tokens)
        decoded = self.decoder(self.act(encoded))
        return decoded.reshape(encoded.shape[0], -1)
    