

import torch 
import torch.nn as nn
from utils.graphconf import *
import torch.nn.functional as F
from torch.distributions.one_hot_categorical import OneHotCategorical
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical



def cat_softmax(probs, mode, tau=1, hard=False, dim=-1, EXPLORE=False):
    """Draw a categorical sample from ``probs`` with the estimator named by ``mode``.

    Args:
        probs: Class probabilities handed to the ``torch.distributions``
            classes used below (classes are indexed by the last dimension).
        mode: ``'REINFORCE'`` or ``'SCST'`` for a discrete one-hot sample,
            ``'GUMBEL'`` for a relaxed one-hot sample.
        tau: Relaxation temperature; only used in ``'GUMBEL'`` mode.
        hard: ``'GUMBEL'`` mode only. When true the relaxed sample is
            discretised with a straight-through estimator.
        dim: ``'GUMBEL'`` mode with ``hard=True`` only. Dimension along which
            the arg-max of the relaxed sample is taken.
        EXPLORE: Unused; kept so that existing callers keep working.

    Returns:
        A 2-tuple.

        * ``'REINFORCE'`` / ``'SCST'``: ``(one_hot_sample, entropy)``.
        * ``'GUMBEL'``: ``(sample, sample)`` -- the very same tensor twice.
          No entropy is computed in this mode.

    Raises:
        ValueError: If ``mode`` is not one of the three supported names.
    """
    if mode in ('REINFORCE', 'SCST'):
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()

    if mode != 'GUMBEL':
        # Fail with a clear message instead of hitting an undefined local below.
        raise ValueError(
            f"unknown sampling mode {mode!r}; "
            "expected 'REINFORCE', 'SCST' or 'GUMBEL'"
        )

    y_soft = RelaxedOneHotCategorical(temperature=tau, probs=probs).rsample()

    if not hard:
        # Reparameterization trick: the relaxed sample is differentiable as is.
        return y_soft, y_soft

    # Straight through: the forward value is the one-hot arg-max, while the
    # gradient flows through y_soft because the difference is detached.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    ret = (y_hard - y_soft).detach() + y_soft
    return ret, ret





# speaking module
class SpeakModule(nn.Module):
    """Speaker agent.

    Holds two recurrent sub-networks:

    * ``feat_vec_enc_layer`` / ``lstm1`` / ``out1`` -- used by
      :meth:`chooseConceptOrder` to emit one binary choice per round.
    * ``embedding`` / ``lstm`` / ``out`` -- used by :meth:`speak` to emit one
      vocabulary symbol per step.

    All sizes come from module-level configuration constants
    (``FEAT_VEC_SIZE``, ``N_VOCAB``, ``RNN_SIZE``, ``IMG_FEAT_SIZE``,
    ``N_CONCEPTS``).
    """

    def __init__(self):
        super(SpeakModule, self).__init__()

        # Layers are created in a fixed order so that parameter
        # initialisation consumes the RNG stream deterministically.
        self.first_lstm_hidden_size = 20
        self.feat_vec_enc_layer = nn.Linear(FEAT_VEC_SIZE, self.first_lstm_hidden_size)

        self.lstm1 = nn.LSTMCell(input_size=2, hidden_size=self.first_lstm_hidden_size)
        self.out1 = nn.Linear(self.first_lstm_hidden_size, 2)

        self.output_size = N_VOCAB
        self.hidden_size = RNN_SIZE
        self.input_size = IMG_FEAT_SIZE
        self.evalFlag = False  # not read anywhere in this class
        self.embedding = nn.Embedding(N_CONCEPTS, self.input_size)
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def chooseConceptOrder(self, feature_vec):
        """Sample one binary choice per communication round.

        Args:
            feature_vec: Batched input for ``feat_vec_enc_layer``; its
                encoding is used as the initial hidden state of ``lstm1``.

        Returns:
            Integer tensor with one row per batch entry and ``COMM_ROUND``
            columns; each entry is the arg-max (0 or 1) of the sample drawn
            in that round.
        """
        hidden = self.feat_vec_enc_layer(feature_vec)
        # Size the state from the actual input rather than the global
        # BATCH_SIZE, and keep it on the same device/dtype as the encoding.
        cell = torch.zeros_like(hidden)
        input_attr = hidden.new_zeros(hidden.shape[0], 2)

        order = []
        for _round in range(COMM_ROUND):
            hidden, cell = self.lstm1(input_attr, (hidden, cell))
            c_distr = F.softmax(self.out1(hidden), dim=1)
            chosen_c, _ = cat_softmax(c_distr, mode=MSG_MODE)
            # The sampled choice is fed back as the next step's input.
            input_attr = chosen_c
            order.append(torch.argmax(chosen_c, dim=1))

        return torch.stack(order, dim=1)

    def speak(self, comm_order, batchData):
        """Utter one message symbol per step.

        At step ``i`` the LSTM input is the embedding of
        ``batchData[:, i] * comm_order[:, i]``.

        Args:
            comm_order: Per-step multiplier applied element-wise to the
                concept ids before the embedding lookup.
            batchData: Concept ids, one row per batch entry and at least
                ``MSG_LEN`` columns.

        Returns:
            A 3-tuple ``(message, log_probs, logProbs)``:

            * ``message`` -- the ``MSG_LEN`` per-step symbols stacked along a
              new leading dimension.
            * ``log_probs`` -- per batch entry, the sum over steps of
              ``log((outDistr * msg).sum(dim=1))``.
            * ``logProbs`` -- list with one detached tensor per step holding
              ``(outDistr * msg).sum(dim=1)`` (a probability, not a log).
        """
        batch_size = batchData.shape[0]
        cell = self.initHidden(batch_size)
        hidden = self.initHidden(batch_size)

        message = []
        log_probs = 0.
        logProbs = []
        for step in range(MSG_LEN):
            step_input = self.embedding(batchData[:, step] * comm_order[:, step])
            hidden, cell = self.lstm(step_input, (hidden, cell))
            outDistr = F.softmax(self.out(hidden), dim=1)

            if self.training:
                msg, _ = cat_softmax(outDistr, mode=MSG_MODE, tau=TAU, hard=MSG_HARD, dim=1)
            else:
                # Greedy decoding outside of training.
                msg = F.one_hot(torch.argmax(outDistr, dim=1), num_classes=self.output_size).float()

            log_probs += torch.log((outDistr * msg).sum(dim=1))
            logProbs.append((outDistr.detach() * msg.detach()).sum(dim=1))
            message.append(msg)

        message = torch.stack(message)
        return message, log_probs, logProbs

    def initHidden(self, batch_size):
        """Return a zero ``(batch_size, hidden_size)`` state for ``self.lstm``.

        The tensor is created with the device and dtype of the module's own
        parameters so the state matches wherever the module has been moved.
        """
        return self.out.weight.new_zeros(batch_size, self.hidden_size)

class ListenModule(nn.Module):
    """Listener agent: reads a message step by step and predicts a concept per step.

    Sizes come from the module-level configuration constants ``N_VOCAB``
    (input), ``RNN_SIZE`` (hidden) and ``N_CONCEPTS`` (output).
    """

    def __init__(self) -> None:
        super(ListenModule, self).__init__()
        self.input_size = N_VOCAB
        self.hidden_size = RNN_SIZE
        self.output_size = N_CONCEPTS

        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def listen_and_predict(self, message, batchData):
        """Consume ``message`` and predict one concept per message step.

        Args:
            message: Tensor indexed as ``message[step]``; each step is fed to
                the LSTM cell and ``message.shape[1]`` is taken as the batch
                size.
            batchData: Ground-truth concept ids; column ``i`` is the target
                for step ``i``. It is only read in training mode and is
                expected to have one row per batch entry.

        Returns:
            A 3-tuple ``(pred_concepts, log_probs, ct2)``:

            * ``pred_concepts`` -- arg-max of the chosen concept at every
              step, stacked along dimension 1.
            * ``log_probs`` -- per batch entry, the sum over steps of
              ``log((outDistr * concept).sum(dim=1))``.
            * ``ct2`` -- per batch entry, the sum over steps of the log
              probability assigned to the ground-truth concept. It is only
              accumulated in training mode and stays zero otherwise.
        """
        batch_size = message.shape[1]
        hidden = self.initHidden(batch_size)
        cell = self.initHidden(batch_size)

        pred_concepts = []
        log_probs = 0.
        ct2 = self.out.weight.new_zeros(batch_size)
        rows = torch.arange(batch_size, device=ct2.device)

        for i in range(message.shape[0]):
            hidden, cell = self.lstm(message[i], (hidden, cell))
            outDistr = F.softmax(self.out(hidden), dim=1)

            if self.training:
                concept, _ = cat_softmax(outDistr, mode=MSG_MODE, tau=TAU, hard=MSG_HARD, dim=1)
                # Log-probability of the ground-truth concept for every batch
                # entry at once (replaces a per-sample Python loop).
                targets = batchData[:, i].to(device=outDistr.device, dtype=torch.long)
                ct2 = ct2 + torch.log(outDistr[rows, targets])
            else:
                # Greedy decoding outside of training.
                concept = F.one_hot(torch.argmax(outDistr, dim=1), num_classes=self.output_size).float()

            log_probs += torch.log((outDistr * concept).sum(dim=1))
            # Taken after both branches so the prediction also exists in
            # eval mode.
            pred_concepts.append(torch.argmax(concept, dim=1))

        pred_concepts = torch.stack(pred_concepts, dim=1)
        return pred_concepts, log_probs, ct2

    def initHidden(self, batch_size):
        """Return a zero ``(batch_size, hidden_size)`` state for ``self.lstm``.

        The tensor is created with the device and dtype of the module's own
        parameters so the state matches wherever the module has been moved.
        """
        return self.out.weight.new_zeros(batch_size, self.hidden_size)
