import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange

from sub_models.attention_blocks import get_vector_mask
from sub_models.attention_blocks import PositionalEncoding1D,  AttentionBlockKVCache
from sub_models.attention_blocks import CrossAttention , MoEAttentionBlockKVCache, CrossAttention4#Gate, 


 
class StochasticTransformerKVCache(nn.Module):
    """Transformer over per-step tokens built from a latent sample and an action.

    Each token is ``concat(sample, one_hot(action))``. It is projected to
    ``feat_dim`` by an MLP stem (Linear -> LayerNorm -> ReLU -> Linear ->
    LayerNorm), then position-encoded and layer-normed. Finally it is fed
    through ``num_layers`` ``AttentionBlockKVCache`` blocks.

    There are two entry points:

    * ``forward`` processes a whole sequence at once, using a mask supplied by
      the caller.
    * ``forward_with_kv_cache`` processes a single step. For each environment
      key it keeps one tensor per layer holding every input that layer has
      seen so far. Call ``reset_kv_cache_list`` before the first step of an
      environment.
    """

    def __init__(self, stoch_dim, action_dim, feat_dim,
                 num_layers, num_heads, max_length, dropout):
        super().__init__()
        self.action_dim = action_dim
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        # Keep the submodule registration order (stem, position_encoding,
        # layer_norm, layer_stack): it fixes the order of self.parameters(),
        # which optimizers index their state by.
        self.stem = nn.Sequential(
            nn.Linear(stoch_dim + action_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
        )
        self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim)
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)
        blocks = [
            AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim * 2,
                                  num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.layer_stack = nn.ModuleList(blocks)
        # env key -> list with one tensor per layer, each (batch, steps_so_far, feat_dim)
        self.kv_cache_list = {}

    def _embed_inputs(self, samples, action):
        """One-hot encode ``action``, append it to ``samples`` on the last dim, and run the stem.

        ``action`` must hold integer class indices in ``[0, action_dim)``
        (required by ``F.one_hot``).
        """
        action_onehot = F.one_hot(action.long(), self.action_dim).float()
        return self.stem(torch.cat([samples, action_onehot], dim=-1))

    def forward(self, samples, action, mask):
        """Encode a full sequence in one pass.

        Args:
            samples: latent samples whose last dim is ``stoch_dim``. Presumably
                ``(batch, seq, stoch_dim)``; TODO confirm against callers.
            action: integer action indices broadcastable against
                ``samples[..., 0]``.
            mask: attention mask, passed unchanged to every block.

        Returns:
            The output features of the last attention block.
        """
        hidden = self._embed_inputs(samples, action)
        hidden = self.layer_norm(self.position_encoding(hidden))
        for attn_block in self.layer_stack:
            hidden, _ = attn_block(hidden, hidden, hidden, mask)
        return hidden

    def reset_kv_cache_list(self, env_name, batch_size, dtype):
        """Start an empty cache for ``env_name``, replacing any existing one.

        The cache gets one zero-length ``(batch_size, 0, feat_dim)`` tensor per
        layer. Each tensor has the given ``dtype`` and lives on the device of
        the model's first parameter.
        """
        device = next(self.parameters()).device
        self.kv_cache_list[env_name] = [
            torch.zeros(size=(batch_size, 0, self.feat_dim), dtype=dtype, device=device)
            for _ in self.layer_stack
        ]

    def forward_with_kv_cache(self, samples, action, env):
        '''
        Forward pass for a single step, using the cache stored in self.kv_cache_list[env].

        ``samples`` must have exactly one step on dim 1 (asserted). The step's
        position index is the current cache length. For every layer, the
        layer's input is first appended to that layer's cache along dim 1.
        The whole cache is then passed as the block's second and third
        arguments (presumably key and value). The mask comes from
        ``get_vector_mask(cache_len + 1, device)``.

        NOTE(review): nothing here limits how long the cache grows. Positions
        at or beyond ``max_length`` are presumably unsupported by
        PositionalEncoding1D; confirm.
        '''
        assert samples.shape[1] == 1
        env_cache = self.kv_cache_list[env]
        step = env_cache[0].shape[1]
        mask = get_vector_mask(step + 1, samples.device)

        hidden = self._embed_inputs(samples, action)
        hidden = self.position_encoding.forward_with_position(hidden, position=step)
        hidden = self.layer_norm(hidden)
        for layer_idx, attn_block in enumerate(self.layer_stack):
            env_cache[layer_idx] = torch.cat([env_cache[layer_idx], hidden], dim=1)
            hidden, _ = attn_block(hidden, env_cache[layer_idx], env_cache[layer_idx], mask)
        return hidden


class StochasticTransformerwoaction(nn.Module):
    """Transformer over pre-built per-step feature vectors (no action input).

    The caller supplies feature vectors whose last dim is
    ``task_dim + n_activate_experts * feat_dim``, which is the input width of
    ``task_emb_stem``. These are presumably a task embedding concatenated
    with ``n_activate_experts`` expert outputs; confirm against callers.

    The stem (Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm) projects the
    features to ``feat_dim``. They are then position-encoded, layer-normed,
    and fed through ``num_layers`` ``AttentionBlockKVCache`` blocks.
    ``forward_with_kv_cache`` runs one step at a time against a per-env cache
    kept in ``self.shared_kv_cache_list``.
    """

    def __init__(self, task_dim, n_activate_experts, feat_dim,
                 num_layers, num_heads, max_length, dropout):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        # Submodule registration order (position_encoding, layer_norm,
        # task_emb_stem, layer_stack) fixes the order of self.parameters().
        # Optimizers index their state by that order, so keep it.
        self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim)
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)
        self.num_layers = num_layers
        self.task_emb_stem = nn.Sequential(
            nn.Linear(task_dim + n_activate_experts * feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
        )
        self.layer_stack = nn.ModuleList([
            AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim * 2,
                                  num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])
        # env key -> list with one tensor per layer, each (batch, steps_so_far, feat_dim)
        self.shared_kv_cache_list = {}

    def forward(self, feats, mask):
        """Encode a full sequence in one pass.

        Args:
            feats: input features whose last dim matches the ``task_emb_stem``
                input width. Presumably ``(batch, seq, in_dim)``; TODO confirm.
            mask: attention mask, passed unchanged to every block.

        Returns:
            The output features of the last attention block.
        """
        feats = self.task_emb_stem(feats)
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)
        for layer in self.layer_stack:
            feats, _ = layer(feats, feats, feats, mask)
        return feats

    def reset_kv_cache_list(self, env_name, batch_size, dtype):
        """Start an empty cache for ``env_name``, replacing any existing one.

        The cache gets one zero-length ``(batch_size, 0, feat_dim)`` tensor per
        layer. Each tensor has the given ``dtype`` and lives on the device of
        the model's first parameter.
        """
        model_device = next(self.parameters()).device
        self.shared_kv_cache_list[env_name] = [
            torch.zeros(size=(batch_size, 0, self.feat_dim), dtype=dtype, device=model_device)
            for _ in self.layer_stack
        ]

    def _get_kv_cache(self, env):
        """Return the per-layer cache list for ``env``.

        Raises:
            KeyError: with a message explaining that ``reset_kv_cache_list``
                must be called first, if ``env`` has no cache.
        """
        try:
            return self.shared_kv_cache_list[env]
        except KeyError:
            raise KeyError(
                f"no KV cache for env {env!r}; call reset_kv_cache_list({env!r}, ...) first"
            ) from None

    def forward_with_kv_cache(self, feats, env):
        '''
        Forward pass for a single step, using the cache stored in self.shared_kv_cache_list[env].

        ``feats`` must have exactly one step on dim 1 (asserted). The step's
        position index is the current cache length. For every layer, the
        layer's input is first appended to that layer's cache along dim 1.
        The whole cache is then passed as the block's second and third
        arguments (presumably key and value). The mask comes from
        ``get_vector_mask(cache_len + 1, device)``.

        Raises:
            KeyError: if ``reset_kv_cache_list`` was never called for ``env``.

        NOTE(review): nothing here limits how long the cache grows. Positions
        at or beyond ``max_length`` are presumably unsupported by
        PositionalEncoding1D; confirm.
        '''
        assert feats.shape[1] == 1
        cache = self._get_kv_cache(env)
        position = cache[0].shape[1]
        mask = get_vector_mask(position + 1, feats.device)
        feats = self.task_emb_stem(feats)
        feats = self.position_encoding.forward_with_position(feats, position=position)
        feats = self.layer_norm(feats)
        for idx, layer in enumerate(self.layer_stack):
            # Store this layer's input, then attend over everything seen so far.
            cache[idx] = torch.cat([cache[idx], feats], dim=1)
            feats, _ = layer(feats, cache[idx], cache[idx], mask)
        return feats
