import torch.nn as nn; from torch.distributions.categorical import Categorical; import torch;  from sklearn.decomposition import PCA; import torch.nn.functional as F

class forward_backward(nn.Module):
    """Forward pass and loss update for a recurrent actor-critic agent.

    Every keyword argument given to the constructor is stored verbatim as an
    instance attribute. The methods below read these attributes:

    * ``actor``, ``critic``: recurrent nets called as ``net(input, streams)``.
    * ``actor_streams``, ``critic_streams``: recurrent state passed to / replaced
      by each call (the ``(S, L)`` pair).
    * ``to_action``, ``to_value``, ``soft``: read-out callables.
    * ``use_vanilla_torch``: if true the nets return only ``(S, L)``; otherwise
      they return ``(output, (S, L), gate)``.
    * ``train_recurrent``: if false the recurrent nets run under ``torch.no_grad()``.
    * ``device``: device used for the placeholder gate tensor.
    * ``rewards``, ``values``, ``actions``, ``action_probs``: per-step buffers
      consumed by ``backwards`` (filled by the caller, not by this class).
    * ``discount``, ``B_val``, ``B_ent``, ``decrease_entropy``: loss settings.
    * ``optim``: optimizer stepped by ``update``.
    * ``subnets``: when not ``None``, ``get_subnet_loss`` is added to the loss
      (needs ``CEL``, ``MSE``, ``pred_states``, ``true_states``, ``pred_ctxs``,
      ``true_ctxs``, which are likewise supplied from outside this class).
    """

    def __init__(self, **params):
        super().__init__()
        # NOTE(review): writing into __dict__ bypasses nn.Module.__setattr__, so modules and
        # parameters passed here are NOT registered (they do not show up in parameters(),
        # state_dict() or .to()). Kept unchanged for backward compatibility -- confirm intended.
        self.__dict__.update(params)

    def forward(self, net_input):
        """Run actor and critic on one observation.

        Returns:
            ``(action, action_prob, value)`` as stored by ``get_actor`` / ``get_critic``.
        """
        net_input = net_input[None, None, :]                  # LSTM inputs and outputs are of shape: sequence length, batch size, input dim
        if self.use_vanilla_torch:
            net_input = net_input.squeeze(0)                  # vanilla path: drop the leading axis again
        self.get_actor(net_input)                             # run forward Actor
        self.get_critic(net_input)                            # run forward Critic
        return self.action, self.action_prob, self.value

    def _run_recurrent(self, net, net_input, streams):
        """Call ``net(net_input, streams)``, under ``torch.no_grad()`` unless ``train_recurrent``."""
        if self.train_recurrent:
            return net(net_input, streams)
        with torch.no_grad():
            return net(net_input, streams)

    def get_actor(self, actor_input, eps=1e-4):
        """Advance the actor one step and sample an action.

        Sets ``gate``, ``LTM``, ``output``, ``Q_value``, ``actor_streams``, ``input``,
        ``action_prob`` and ``action`` on ``self``.

        Args:
            actor_input: input handed to ``self.actor``.
            eps: unused; retained so existing callers keep working.
        """
        actor_return = self._run_recurrent(self.actor, actor_input, self.actor_streams)

        if self.use_vanilla_torch:
            S, L = actor_return                               # vanilla net returns only the two streams
            actor_output = S
            self.gate = torch.zeros(4, 1, device=self.device)  # no gate returned on this path: zero placeholder
            self.LTM = L.detach()[None, ...]                  # re-add the leading axis removed in forward()
            self.output = S.detach()[None, ...]
        else:
            actor_output, (S, L), self.gate = actor_return
            self.LTM = L.detach()
            self.output = S.detach()

        self.Q_value = self.to_action(actor_output).view(-1)  # output to Q values
        self.actor_streams = (S, L)                           # long and short term memory streams
        self.input = actor_input
        self.action_prob = self.soft(self.Q_value).view(-1)   # Q values to action probabilities
        self.action = Categorical(self.action_prob).sample().view(-1).detach()  # sample action

    def get_critic(self, critic_input):
        """Advance the critic one step; sets ``value`` and ``critic_streams`` on ``self``."""
        critic_return = self._run_recurrent(self.critic, critic_input, self.critic_streams)

        if self.use_vanilla_torch:
            S, L = critic_return
            critic_output = S
        else:
            critic_output, (S, L), _ = critic_return

        self.value = self.to_value(critic_output).view(-1)    # output to value
        self.critic_streams = (S, L)                          # long and short term memory streams

    def backwards(self, R=0, critic_loss=0, policy_loss=0, entropy_loss=0, episode_ratio=0):
        """Accumulate critic, entropy and policy losses over the stored steps, then update.

        Args:
            R: initial value of the discounted return accumulator.
            critic_loss, policy_loss, entropy_loss: initial values of the loss accumulators.
            episode_ratio: only used when ``decrease_entropy`` is set, in which case
                ``B_ent`` becomes ``1 - episode_ratio``.

        Side effects: overwrites ``self.rewards[-1]`` and calls ``self.update``.
        """
        if self.decrease_entropy:                             # if linearly decreasing entropy throughout training
            self.B_ent = 1 - episode_ratio

        # The last reward slot is replaced by the (detached) last value estimate.
        self.rewards[-1] = self.values[-1].detach()
        for i in reversed(range(len(self.rewards))):          # iterate through rewards from most to least recent
            R = self.rewards[i] + self.discount * R           # sum of discounted future rewards
            RPE = R - self.values[i]                          # reward prediction error
            act_i = self.actions[i].long()                    # index of action taken
            probs_i = self.action_probs[:, i]                 # action probabilities stored for step i
            critic_loss = critic_loss + RPE ** 2              # squared error between value and discounted return
            entropy_loss = entropy_loss + (probs_i * probs_i.log()).sum()   # sum p*log(p), i.e. negative entropy
            policy_loss = policy_loss - self.action_probs[act_i, i].log() * RPE.detach()  # -log p(action) * RPE

        self.update(critic_loss, entropy_loss, policy_loss)

    def update(self, critic_loss, entropy_loss, policy_loss):
        """Combine the losses (weighted by ``B_val`` / ``B_ent``) and take one optimizer step."""
        loss = (policy_loss + self.B_val * critic_loss + self.B_ent * entropy_loss).mean()
        if self.subnets is not None:
            loss = loss + self.get_subnet_loss()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def get_subnet_loss(self):
        """Return ``CEL(pred_states, true_states) / true_states.shape[1] + MSE(pred_ctxs, true_ctxs)``.

        NOTE(review): ``CEL`` / ``MSE`` are presumably cross-entropy and mean-squared-error
        criteria; they and the four tensors are set outside this class -- confirm with caller.
        """
        state_loss = self.CEL(self.pred_states, self.true_states) / self.true_states.shape[1]
        ctx_loss = self.MSE(self.pred_ctxs, self.true_ctxs)
        return state_loss + ctx_loss
        