import torch as th
import torch.nn as nn
import torch.nn.functional as F
from utils.gp_nets_simple import FCNet
from utils.calc import count_total_parameters
from utils.embed import polynomial_embed, binary_embed
from functools import partial

from modules.agents.transfer.tr_agent import *
    
class StateEBM(nn.Module, MultiHeadBase):
    """Energy model over states: a state encoder followed by adaptor head(s).

    The encoder is a ``StateAttnFeatureExtractor``; the adaptor heads are
    ``FCNet`` instances built by ``MultiHeadBase`` from the partial passed in
    ``__init__``. Raw head outputs are mapped to energies by ``out_func``,
    i.e. ``sigmoid(x) / ebm_temp``.
    """

    def __init__(self, task2decomposer, args) -> None:
        """Build the encoder and the adaptor heads.

        Args:
            task2decomposer: forwarded unchanged to ``StateAttnFeatureExtractor``.
            args: config object; this class reads ``entity_embed_dim``,
                ``device``, ``n_reuse_heads`` and ``ebm_temp`` from it.
        """
        nn.Module.__init__(self)
        MultiHeadBase.__init__(
            self,
            adaptor_func=partial(FCNet, args.entity_embed_dim, 1, hidden_layer=2, hidden_dim=32),
            device=args.device,
            max_head=args.n_reuse_heads,
            include_cur=True,
            in_dim=1,
        )
        self.state_enc = StateAttnFeatureExtractor(task2decomposer, args)
        self.ebm_temp = args.ebm_temp
        # th.sigmoid is numerically identical to F.sigmoid, which is a
        # deprecated alias in older PyTorch releases and warns on every call.
        # ebm_temp is read at call time, so later changes to it take effect.
        self.out_func = lambda x: th.sigmoid(x) / self.ebm_temp
        # Alternative output mapping kept for reference:
        # self.out_func = lambda x: th.exp(F.tanh(x) / self.ebm_temp)

    def forward(self, states, task):
        """Return the energy produced by the current adaptor (``cur_adaptor``).

        Args:
            states: input passed to the state encoder.
            task: task identifier passed to the state encoder.
        """
        state_feature = self.state_enc.forward(states, task)
        return self.out_func(self.cur_adaptor(state_feature))

    def all_forward(self, states, task):
        """Return energies from ``parallel_forward`` over the encoded states.

        When ``_vectorized_forward`` is set, the result of ``parallel_forward``
        is permuted with ``(1, 2, 3, 0)`` (its leading axis moves to the end)
        before ``out_func`` is applied; otherwise it is used as returned.
        """
        state_feature = self.state_enc.forward(states, task)
        # Evaluate the flag BEFORE calling parallel_forward, exactly as the
        # original did, in case that call initialises _vectorized_forward.
        # NOTE(review): whether it does is not visible from this file.
        vectorized = self._vectorized_forward is not None
        out = self.parallel_forward(state_feature)
        if vectorized:
            out = out.permute(1, 2, 3, 0)
        return self.out_func(out)
    
class COMADCritic(nn.Module):
    """Critic prototype that can act as a Q-function or a V-function.

    Observations are encoded by an ``ObsAttnFeatureExtractor``. The head on
    top is a single-output ``FCNet`` when ``is_v`` is true, and an
    ``OutputHead`` (which also receives the encoder's enemy features)
    otherwise.
    """

    def __init__(self, task2input_shape_info, task2decomposer, task2n_agents, surrogate_decomposer, args, is_v=False):
        """Create the encoder and the matching output head, then report size.

        Args:
            task2input_shape_info, task2decomposer, task2n_agents,
            surrogate_decomposer: forwarded unchanged to the encoder
                (``surrogate_decomposer`` is also given to ``OutputHead``).
            args: config object; ``entity_embed_dim`` is required and
                ``pa_hidden_dim`` is optional (falls back to 64).
            is_v: build the value head instead of the Q head.
        """
        super().__init__()

        self.args = args
        self.is_v = is_v
        self.entity_embed_dim = args.entity_embed_dim

        # The encoder is registered before the head, as in the original.
        self.attn_enc = ObsAttnFeatureExtractor(
            task2input_shape_info,
            task2decomposer,
            task2n_agents,
            surrogate_decomposer,
            args,
        )

        if not self.is_v:
            self.output_head = OutputHead(surrogate_decomposer, args, input_dim=self.entity_embed_dim)
        else:
            # NOTE for discrete
            self.output_head = FCNet(
                self.entity_embed_dim,
                1,
                hidden_layer=2,
                hidden_dim=getattr(args, "pa_hidden_dim", 64),
                use_layer_norm=True,
            )

        print("Critic init...")
        count_total_parameters(self, is_concrete=True)

    def forward(self, obs, task):
        """Evaluate the critic and restore the leading dimensions of ``obs``.

        The head output is reshaped to ``obs.shape[:-1] + (-1,)``.
        """
        # obs (+last_act +agent_id), shape=[bs, T, n, d_n] -> [bsTn, n_act|1]
        leading_dims = obs.shape[:-1]
        attn_feature, enemy_feats = self.attn_enc.forward(obs, task)

        head_out = (
            self.output_head(attn_feature)
            if self.is_v
            else self.output_head.forward(attn_feature, enemy_feats)
        )

        return head_out.reshape(*leading_dims, -1)
