import torch 
import torch.nn as nn 

from copy import deepcopy 

from sal.utils import ForwardPolicyMeta, BaseNN, BaseTransformer 
from sal.gym.sequences import state_to_node 
from sal.pac_utils import BayesianMLP, MLP, BayesianPolicyMeta 

class SequenceTransformer(nn.Module):
    """Self-attention encoder turning a batch of scalar sequences into `output_dim` features.

    Each scalar token is lifted to an `embed_dim` vector, one linear layer emits
    the query/key/value projections (each `embed_dim * num_heads` wide),
    multi-head attention mixes the positions, and a final linear layer reads
    out the flattened per-position features.

    Args:
        seq_size: Sequence length the read-out layer is sized for; inputs to
            `forward` must have exactly this many positions.
        embed_dim: Width of the per-token embedding.
        output_dim: Number of output features per sequence.
        num_heads: Number of attention heads.
        device: Device every sub-module is moved to.
    """

    def __init__(self, seq_size, embed_dim, output_dim, num_heads=4, device='cpu'):
        super().__init__()
        self.device = device
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        attn_width = embed_dim * num_heads
        # Sub-modules are created in the order emb -> qkv -> attn -> out so that
        # seeded initialisation draws random numbers in a fixed sequence.
        self.emb = nn.Linear(1, embed_dim).to(device)
        self.qkv = nn.Linear(embed_dim, 3 * attn_width).to(device)
        self.attn = nn.MultiheadAttention(
            attn_width, num_heads=num_heads, batch_first=True
        ).to(device)
        self.out = nn.Linear(seq_size * embed_dim, output_dim).to(device)

    def forward(self, x: torch.Tensor):
        """Encode `x` of shape (B, L) into a (B, output_dim) tensor."""
        batch_size, seq_len = x.shape[0], x.shape[1]
        tokens = self.emb(x.unsqueeze(-1))  # (B, L, E)
        query, key, value = self.qkv(tokens).chunk(3, dim=2)  # each (B, L, E * H)
        mixed, _ = self.attn(query, key, value, need_weights=False)  # (B, L, E * H)
        # Split the attention width into (E, H) and average over the trailing axis.
        pooled = mixed.view(batch_size, seq_len, self.embed_dim, self.num_heads).mean(dim=-1)
        return self.out(pooled.flatten(start_dim=1))

class LSTMBase(nn.Module):
    """Bidirectional LSTM encoder applied to a truncated prefix of scalar sequences.

    Args:
        embed_dim: Width of the per-token embedding fed to the LSTM.
        hidden_dim: Hidden size of each LSTM direction.
        output_dim: Number of output features per sequence.
        device: Device every sub-module is moved to.
    """

    def __init__(self, embed_dim, hidden_dim, output_dim, device='cpu'):
        super().__init__()
        self.device = device
        self.emb = nn.Linear(1, embed_dim).to(device)
        recurrent = nn.LSTM(embed_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.lstm = recurrent.to(device)
        # The read-out consumes both directions, hence the 2 * hidden_dim input width.
        head = nn.Sequential(nn.LeakyReLU(), nn.Linear(2 * hidden_dim, output_dim))
        self.out = head.to(device)

    def forward(self, x: torch.Tensor):
        """Encode the columns of `x` that precede the cut-off position.

        The cut-off is the largest position in the LAST batch row whose value
        equals the global maximum of `x`; that column and everything after it
        are dropped for every row.
        NOTE(review): presumably the maximum value marks padding or an
        end-of-sequence token -- confirm against the environment encoding.
        """
        peak_positions = torch.argwhere(x[-1] == x.max())
        cutoff = peak_positions.max()
        prefix = x[:, :cutoff]
        features, _ = self.lstm(self.emb(prefix.unsqueeze(-1)))
        # Only the features at the final retained time step are read out.
        return self.out(features[:, -1])

class ForwardPolicy(ForwardPolicyMeta):
    """Forward policy producing action logits and a scalar flow for each state.

    Two networks with the same architecture are built: `mlp_logit` (one output
    per action) and `mlp_flows` (a single output). The architecture is chosen
    by `base_model`.

    Args:
        seq_size: Length of the state vector fed to the networks.
        num_actions: Number of logits produced by `mlp_logit`.
        hidden_dim: Hidden width handed to the selected backbone.
        num_layers: Depth handed to the 'mlp' and 'transformer' backbones.
        device: Device the networks are placed on.
        eps: Forwarded unchanged to `ForwardPolicyMeta`.
        base_model: One of 'mlp', 'transformer', 'transformer-sequence', 'lstm'.

    Raises:
        ValueError: If `base_model` is not one of the supported names.
    """

    def __init__(
            self, seq_size, num_actions, hidden_dim, num_layers, device='cpu', eps=.3, base_model='mlp'
        ):
        super(ForwardPolicy, self).__init__(eps=eps, device=device)
        self.seq_size = seq_size
        self.num_actions = num_actions
        self.hidden_dim = hidden_dim
        self.device = device
        self.num_layers = num_layers
        self.base_model = base_model

        # The logit network is built before the flow network so that parameter
        # initialisation consumes random numbers in a fixed order.
        self.mlp_logit = self._build_backbone(self.num_actions)
        self.mlp_flows = self._build_backbone(1)

    def _build_backbone(self, output_dim):
        """Instantiate the network selected by `self.base_model` with `output_dim` outputs."""
        if self.base_model == 'mlp':
            return BaseNN(
                self.seq_size, self.hidden_dim, self.num_layers, output_dim
            ).to(self.device)
        if self.base_model == 'transformer':
            return BaseTransformer(
                self.seq_size, self.hidden_dim, self.num_layers, output_dim, device=self.device
            )
        if self.base_model == 'transformer-sequence':
            return SequenceTransformer(
                self.seq_size, self.hidden_dim, output_dim, device=self.device
            )
        if self.base_model == 'lstm':
            return LSTMBase(
                self.hidden_dim, self.hidden_dim, output_dim, device=self.device
            )
        # Fail at construction time instead of leaving the policy without networks,
        # which would otherwise surface later as an unrelated AttributeError.
        raise ValueError(
            f"Unknown base_model {self.base_model!r}; expected one of "
            "'mlp', 'transformer', 'transformer-sequence', 'lstm'"
        )

    def get_latent_emb(self, batch_state, gflownets=None):
        """Compute logits and flows for every state in `batch_state`.

        Without `gflownets`, this policy's own networks are applied to every
        row. With `gflownets`, rows where `curr_idx <= max_depth` use this
        policy's networks, and each remaining row uses the networks of
        `gflownets[i].pf`, where `i` is the index `state_to_node` returns for
        that row.

        Args:
            batch_state: Object exposing `state`; when `gflownets` is given,
                `batch_size`, `curr_idx` and `max_depth` are read as well.
            gflownets: Optional collection supporting `len()` and integer
                indexing whose items expose `.pf.mlp_logit` and `.pf.mlp_flows`.

        Returns:
            Tuple `(logits, flows)`; `flows` has its trailing axis squeezed.
        """
        states = batch_state.state.to(torch.get_default_dtype())

        if gflownets is None:
            # Nothing to mix in: skip the index tensors, scratch buffers and
            # `torch.where` calls, which would all be no-ops here.
            return self.mlp_logit(states), self.mlp_flows(states).squeeze(dim=-1)

        ones = torch.ones((batch_state.batch_size,), device=self.device)
        curr_indices = ones * batch_state.curr_idx
        max_indices = ones * batch_state.max_depth
        node_indices = state_to_node(batch_state, len(gflownets))

        # True for rows handled by this policy's own networks.
        own_rows = curr_indices <= max_indices

        logit_all = self.mlp_logit(states)
        flows_all = self.mlp_flows(states)

        # Zero-initialised so that a row matched by no model cannot leak
        # uninitialised memory into the result.
        logit_model = torch.zeros_like(logit_all)
        flows_model = torch.zeros_like(flows_all)

        # Index-based access keeps any container with `len()` and integer lookup working.
        for idx in range(len(gflownets)):
            node_mask = (node_indices == idx) & ~own_rows
            if node_mask.any():
                node_policy = gflownets[idx].pf
                logit_model[node_mask] = node_policy.mlp_logit(states[node_mask])
                flows_model[node_mask] = node_policy.mlp_flows(states[node_mask])

        # Merge: own rows come from this policy, the rest from the per-node models.
        select_own = own_rows.unsqueeze(-1)
        logit_lst = torch.where(select_own, logit_all, logit_model)
        flows_lst = torch.where(select_own, flows_all, flows_model)

        return logit_lst, flows_lst.squeeze(dim=-1)

    def get_pol(self, logits_flows, mask=None):
        """Turn `(logits, flows)` into `(softmax(logits), flows)`; `mask` is ignored."""
        del mask
        logits, flows = logits_flows
        pol = torch.softmax(logits, dim=-1)
        return pol, flows

class BayesianPolicy(BayesianPolicyMeta):
    """Policy whose logits come from a `BayesianMLP` built on a prior/posterior MLP pair.

    The posterior MLP starts as a deep copy of the prior MLP. Flows are
    produced by a separate plain `MLP` with a single output.

    Args:
        seq_size: Input width of every network.
        num_actions: Number of logits produced.
        hidden_dim: Width of each hidden layer.
        num_layers: The MLPs receive `num_layers - 1` hidden widths.
        device: Device the networks are placed on.
        eps: Forwarded unchanged to `BayesianPolicyMeta`.
        **bayesian_kwargs: Forwarded to both `BayesianPolicyMeta` and `BayesianMLP`.
    """

    def __init__(self, seq_size, num_actions, hidden_dim, num_layers, device='cpu', eps=.3, **bayesian_kwargs):
        super().__init__(eps=eps, device=device, **bayesian_kwargs)
        self.seq_size = seq_size
        self.num_actions = num_actions
        self.hidden_dim = hidden_dim

        n_hidden = num_layers - 1

        # Logit branch: prior, posterior (copy of the prior), then the Bayesian wrapper.
        prior = MLP(seq_size, num_actions, [hidden_dim] * n_hidden).to(device)
        self.mlp_logit_prior = prior
        self.mlp_logit_posterior = deepcopy(prior)
        bayesian_head = BayesianMLP(
            self.mlp_logit_prior, self.mlp_logit_posterior, device=device, **bayesian_kwargs
        )
        self.mlp_logit = bayesian_head.to(self.device)

        # Flow branch: deterministic MLP with one output.
        self.mlp_flows = MLP(seq_size, 1, [hidden_dim] * n_hidden).to(device)

    def get_latent_emb(self, env):
        """Return `(logits, flows)` for `env.state`; `flows` has its trailing axis squeezed."""
        state = env.state.to(dtype=torch.get_default_dtype())
        logits = self.mlp_logit(state)
        flows = self.mlp_flows(state).squeeze(dim=-1)
        return logits, flows

    def get_pol(self, logits_flows, mask=None):
        """Turn `(logits, flows)` into `(softmax(logits), flows)`; `mask` is ignored."""
        logits, flows = logits_flows
        return logits.softmax(dim=-1), flows

class ForwardPolicyLA(ForwardPolicyMeta):
    """Forward policy that scores each child from a joint parent/child embedding.

    `mlp_emb` embeds an input of width `seq_size + 1`; `mlp_logit` maps the
    concatenated parent and child embeddings to one logit per child; and
    `mlp_flows` maps the parent input to a single flow value.

    Args:
        seq_size: The networks take inputs of width `seq_size + 1`.
        src_size: Accepted for interface compatibility; not used here.
        hidden_dim: Embedding and hidden width.
        num_layers: Depth handed to `mlp_emb` and `mlp_flows`.
        device: Forwarded to `ForwardPolicyMeta`.
        eps: Forwarded to `ForwardPolicyMeta`.
    """

    def __init__(self, seq_size, src_size, hidden_dim, num_layers, device='cpu', eps=.3):
        super().__init__(eps=eps, device=device)
        self.input_dim = seq_size + 1
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        pair_dim = 2 * self.hidden_dim  # parent embedding next to child embedding
        self.mlp_emb = BaseNN(self.input_dim, self.hidden_dim, self.num_layers).to(self.device)
        self.mlp_logit = BaseNN(pair_dim, pair_dim, 1, 1).to(self.device)
        self.mlp_flows = BaseNN(self.input_dim, self.hidden_dim, self.num_layers, 1).to(self.device)

    def get_latent_emb(self, batch_state):
        """Return `(logits, flows)`, with one logit per child of `batch_state`."""
        dtype = torch.get_default_dtype()
        parent_emb = self.mlp_emb(batch_state.unique_input.type(dtype))

        # One (parent, child) embedding per child, stacked along a new axis 1.
        pair_embs = [
            torch.hstack([parent_emb, self.mlp_emb(child.unique_input.type(dtype))])
            for child in batch_state.get_children()
        ]
        emb = torch.stack(pair_embs, dim=1)

        logits = self.mlp_logit(emb).squeeze(dim=-1)
        flows = self.mlp_flows(batch_state.unique_input.type(dtype)).squeeze(dim=-1)
        return logits, flows

    def get_pol(self, logits_flows, mask):
        """Softmax over the logits after masked-out entries are replaced by `self.masked_value`."""
        logits, flows = logits_flows
        # Entries with mask == 0 are set to `masked_value` before the softmax.
        masked_logits = logits * mask + (1 - mask) * self.masked_value
        return masked_logits.softmax(dim=-1), flows

class BackwardPolicy(nn.Module):
    """Parameter-free backward policy whose forward pass returns all-zero tensors.

    Args:
        device: Device the returned tensors are created on.
    """

    # Class-level constant; not read inside this class.
    masked_value = -1e5

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Return a pair of zero vectors of length `batch_state.batch_size`.

        `actions` is accepted but ignored. Both elements of the returned tuple
        are the same tensor object.
        """
        batch_shape = (batch_state.batch_size,)
        zero_vec = torch.zeros(batch_shape, device=self.device)
        return zero_vec, zero_vec