'''
    Code for message length = 3
    and 3 concepts 
    train3.py -- train file for chatbot3.py


    config:

            N_AGENTS = 2
            NUM_ATTRS = 3
            MSG_LEN = 3
            N_SECTORS = 3
            N_SEGMENTS = 3
            N_COLORS = 3
            UNI_ATTR_VAL = 9
            N_CONCEPTS = 9
            N_VOCAB = 3
            IMG_FEAT_SIZE = 20  # embedding size of the input 



            # hyperparameters
            NUM_EPOCHS = 1000000
            BATCH_SIZE = 100
            TRAINING_SIZE = 0.9


            SPK_LEARNING_RATE = 0.009
            LIS_LEARNING_RATE = 0.009
            RNN_SIZE = 128
            RL_NEGATIVE_REWARD = 0
            RL_SCALE = 100
            MSG_MODE =  'GUMBEL'
            MSG_HARD = True
            TAU = 2
            CLIP = 50.0
            LAMBDA = 1


'''

import torch
import pickle
import numpy as np
import torch.nn as  nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.distributions.one_hot_categorical import OneHotCategorical
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical

from torchviz import make_dot

from utils.conf import *
from collections import defaultdict


# Word statistics for speakers with ID 0..3 (wstats1..wstats4 respectively).
# SpeakModule appends each per-step output distribution under the concept id
# it was produced for, then pickles the dict to A1_DICT..A4_DICT.
wstats1, wstats2, wstats3, wstats4 = (defaultdict(list) for _ in range(4))



def cat_softmax(probs, mode, tau=1, hard=False, dim=-1, EXPLORE=False):
    """Sample a (possibly relaxed) one-hot categorical vector from ``probs``.

    Args:
        probs: Category probabilities; the last dimension indexes categories
            (as required by ``OneHotCategorical``/``RelaxedOneHotCategorical``).
        mode: ``'REINFORCE'`` or ``'SCST'`` draw a discrete one-hot sample;
            ``'GUMBEL'`` draws a reparameterized relaxed sample via ``rsample``.
        tau: Temperature of the relaxed distribution (GUMBEL only).
        hard: GUMBEL only. If True, return a straight-through sample whose
            forward value is one-hot along ``dim`` while gradients flow
            through the relaxed sample.
        dim: Dimension along which the hard one-hot is taken (GUMBEL + hard).
        EXPLORE: Unused; kept for interface compatibility.

    Returns:
        ``(sample, entropy)`` for REINFORCE/SCST. For GUMBEL the sample is
        returned twice, ``(sample, sample)``; no entropy is computed.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    if mode in ('REINFORCE', 'SCST'):
        cat_distr = OneHotCategorical(probs=probs)
        return cat_distr.sample(), cat_distr.entropy()
    if mode != 'GUMBEL':
        # Without this guard an unknown mode fell through and crashed with an
        # UnboundLocalError on ``y_soft``.
        raise ValueError(
            f"unsupported mode {mode!r}; expected 'REINFORCE', 'SCST' or 'GUMBEL'"
        )

    y_soft = RelaxedOneHotCategorical(temperature=tau, probs=probs).rsample()
    if not hard:
        # Reparameterization trick: return the differentiable relaxed sample.
        return y_soft, y_soft

    # Straight-through: forward value is the hard one-hot, gradient is that
    # of y_soft (the detached difference contributes no gradient).
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    ret = (y_hard - y_soft).detach() + y_soft
    return ret, ret



# Speaking Agent is the speaker
class SpeakModule(nn.Module):
    """Speaker half of an agent.

    Holds two recurrent policies:

    * ``concept_select`` -- an LSTMCell over 2-way choices, seeded from a
      vertex feature vector, that emits one binary selection per
      communication round (``COMM_ROUND`` rounds).
    * ``speak`` -- an LSTMCell that turns concept ids into a message of
      ``MSG_LEN`` tokens over ``N_VOCAB`` symbols.

    When ``(agt.time + 1)`` is a multiple of ``_STATS_PERIOD`` (and on every
    ``getVocab`` call) the speaker's per-step output distributions are
    appended to the module-level ``wstats{1..4}`` dict for its ID (keyed by
    concept id) and pickled to ``A{1..4}_DICT``.
    """

    # speak() logs/dumps word statistics when (agt.time + 1) is a multiple
    # of this.
    _STATS_PERIOD = 50000

    def __init__(self, id, agt) -> None:
        super(SpeakModule,self).__init__()

        # Sampling temperature for concept_select; annealed there.
        self.TVALUE = INIT_TVALUE
        self.ID = id
        # Owning agent; its ``time`` counter gates statistics logging.
        self.agt = agt
        self.w_select_hidden_size = W_SELECT_HSIZE
        self.vertex_feature = nn.Linear(VERTEX_FEAT_VEC_SIZE, self.w_select_hidden_size)

        # Selection policy: one 2-way choice per round (concept_select).
        self.lstm_0 = nn.LSTMCell(input_size=2, hidden_size = self.w_select_hidden_size)
        self.out_0 = nn.Linear(self.w_select_hidden_size,2)

        # Message policy: concept embeddings -> N_VOCAB symbols (speak).
        self.output_size = N_VOCAB
        self.hidden_size = RNN_SIZE
        self.input_size = IMG_FEAT_SIZE
        self.evalFlag = False  # not read anywhere in this class
        self.embedding = nn.Embedding(N_CONCEPTS, self.input_size)
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def _stats_target(self):
        """Return ``(stats_dict, dump_path)`` for this speaker's ID.

        Only IDs 0-3 have a statistics slot; any other ID gets
        ``(None, None)`` and records nothing.
        """
        return {
            0: (wstats1, 'A1_DICT'),
            1: (wstats2, 'A2_DICT'),
            2: (wstats3, 'A3_DICT'),
            3: (wstats4, 'A4_DICT'),
        }.get(self.ID, (None, None))

    def _record_stats(self, concepts, out_distr):
        """Append each row of ``out_distr`` to the stats list of its concept id."""
        stats, _ = self._stats_target()
        if stats is None:
            return
        for concept, distr in zip(concepts.tolist(), out_distr.tolist()):
            stats[concept].append(distr)

    def _dump_stats(self):
        """Pickle this speaker's stats dict to its dump file, if it has one."""
        stats, path = self._stats_target()
        if stats is None:
            return
        with open(path, 'wb') as fp:
            pickle.dump(stats, fp)

    def concept_select(self,feature_vec, self_mode=False, bvec=None):
        """Run the 2-way selection LSTM for ``COMM_ROUND`` rounds.

        Args:
            feature_vec: Vertex features. Presumably (batch,
                VERTEX_FEAT_VEC_SIZE), since their encoding is used directly
                as the LSTMCell hidden state -- TODO confirm with caller.
            self_mode: If True, also accumulate the log-probability of the
                choices given in ``bvec``. The sampled choices still drive
                the recurrence. If False, ``self.TVALUE`` is annealed from
                the confidence of the sampled choices.
            bvec: Integer choices (0/1) indexed ``bvec[:, round]``. Required
                when ``self_mode`` is True.

        Returns:
            ``(cselect_t, logProbs, cselect_int_t, rpb)``:

            * ``cselect_t`` -- (batch, COMM_ROUND) indicator that option 1
              was sampled each round (via ``cat_softmax(..., hard=True)``).
            * ``logProbs`` -- (batch,) summed log-prob of the sampled choices.
            * ``cselect_int_t`` -- ``cselect_t`` cast to ``torch.int``.
            * ``rpb`` -- (batch,) summed log-prob of ``bvec``; zeros unless
              ``self_mode``.

        Raises:
            RuntimeError: If a selection is not 0 or 1.
        """
        hidden = self.vertex_feature(feature_vec)
        # Use the real batch size rather than assuming BATCH_SIZE.
        batch_size = hidden.shape[0]
        cell = torch.zeros(batch_size, self.w_select_hidden_size)
        input_attr = torch.zeros(batch_size, 2)
        rpb = torch.zeros(batch_size)
        logProbs = 0.
        cselect = []

        for i in range(COMM_ROUND):
            hidden, cell = self.lstm_0(input_attr, (hidden, cell))
            c_distr = F.softmax(self.out_0(hidden), dim=1)
            if c_distr.isnan().any():
                print("C_dist problem")
                print(c_distr)
            if self_mode:
                # Score the externally supplied choices under c_distr.
                targets = torch.as_tensor(bvec[:, i], dtype=torch.long)
                rpb = rpb + torch.log(c_distr.gather(1, targets.unsqueeze(1)).squeeze(1))
            chosen_c, _ = cat_softmax(c_distr, mode=MSG_MODE2, tau=self.TVALUE, hard=True)
            # The sampled one-hot choice is the next round's input.
            input_attr = chosen_c
            logProbs += torch.log((c_distr * chosen_c).sum(dim=1))
            cselect.append(chosen_c[:, 1])

        cselect_t = torch.stack(cselect, dim=1)
        cselect_int_t = cselect_t.type(torch.int)
        # Every entry must be exactly 0 or 1. The old ``max != 1`` test also
        # fired when every choice in the batch was 0. It then called exit(),
        # which terminates with status 0 and hides the failure.
        if ((cselect_int_t != 0) & (cselect_int_t != 1)).any():
            print(list(self.lstm_0.parameters()))
            print(list(self.out_0.parameters()))
            print("c_distr: ", c_distr)
            print("c_select_t: ", cselect_t)
            print("chosen_c", chosen_c)
            print(cselect_int_t)
            print(torch.exp(logProbs))
            raise RuntimeError("Exiting...due to incorrect(0/1) entry")

        if not self_mode:
            # Anneal the temperature: once every sampled selection sequence
            # had probability >= 0.99, sample at tau=1; otherwise reset.
            min_seq_prob = torch.exp(logProbs.detach()).min()
            self.TVALUE = 1 if min_seq_prob >= 0.99 else INIT_TVALUE

        return cselect_t, logProbs, cselect_int_t, rpb

    def speak(self,batchData,w_order=None,selfMode= False, targMsg = None):
        """Generate a ``MSG_LEN``-token message for each row of ``batchData``.

        At step ``i`` the LSTM reads the embedding of ``batchData[:, i]``
        (multiplied element-wise by ``w_order[:, i]`` when ``w_order`` is
        given) and emits a softmax over ``N_VOCAB`` symbols. In training mode
        a token is sampled with ``cat_softmax``; otherwise the argmax one-hot
        is used.

        Args:
            batchData: Concept ids indexed ``batchData[:, i]``.
            w_order: Optional multiplier applied to the ids before embedding.
                Presumably a 0/1 mask -- TODO confirm with caller.
            selfMode: If True, also accumulate the log-probability of the
                tokens in ``targMsg``.
            targMsg: One-hot target message. Its argmax over dim 2,
                transposed, is indexed ``[:, i]`` per step, so presumably it
                is shaped like the returned message.

        Returns:
            ``(message, log_probs, ct2, logProbs)``:

            * ``message`` -- stacked per-step tokens, (MSG_LEN, batch,
              N_VOCAB).
            * ``log_probs`` -- summed log-prob of the emitted tokens.
            * ``ct2`` -- summed log-prob of ``targMsg``; zeros unless
              ``selfMode``.
            * ``logProbs`` -- list of per-step log-probs.
        """
        batch_size = batchData.shape[0]
        cell = self.initHidden(batch_size)
        hidden = self.initHidden(batch_size)
        log_stats = (self.agt.time + 1) % self._STATS_PERIOD == 0
        if selfMode:
            # Hoisted out of the loop: the target ids do not change per step.
            target_ids = torch.argmax(targMsg, dim=2).T

        message = []
        log_probs = 0.
        logProbs = []
        ct2 = torch.zeros(batch_size)
        for i in range(MSG_LEN):
            concepts = batchData[:, i]
            if w_order is not None:
                try:
                    step_input = self.embedding(concepts * w_order[:, i])
                except Exception:
                    # Show the offending ids before propagating the error.
                    print(concepts)
                    print(w_order[:, i])
                    raise
            else:
                step_input = self.embedding(concepts)
            hidden, cell = self.lstm(step_input, (hidden, cell))
            outDistr = F.softmax(self.out(hidden), dim=1)

            if log_stats:
                self._record_stats(concepts, outDistr)

            if selfMode:
                step_targets = target_ids[:, i]
                ct2 = ct2 + torch.log(outDistr.gather(1, step_targets.unsqueeze(1)).squeeze(1))

            if self.training:
                msg, _ = cat_softmax(outDistr, mode=MSG_MODE, tau=TAU, hard=MSG_HARD, dim=1)
            else:
                msg = F.one_hot(torch.argmax(outDistr, dim=1), num_classes=self.output_size).float()

            step_log_prob = torch.log((outDistr * msg).sum(dim=1))
            log_probs += step_log_prob
            logProbs.append(step_log_prob)
            message.append(msg)

        message = torch.stack(message)
        if log_stats:
            self._dump_stats()
        return message, log_probs, ct2, logProbs

    def initHidden(self, batch_size):
        """Return a zero (batch_size, hidden_size) LSTM state tensor."""
        return torch.zeros(batch_size, self.hidden_size)

    def getVocab(self,batchData):
        """Record this speaker's per-step output distributions and dump them.

        The LSTM runs over ``batchData[:, i]`` for ``MSG_LEN`` steps exactly
        as in ``speak`` (without ``w_order``). Each step's softmax rows are
        appended to this speaker's ``wstats`` dict under their concept id,
        and the dict is then pickled. Unlike ``speak``, recording happens on
        every call, and no message is sampled. Returns None.
        """
        print("$$$ ", self.ID)
        batch_size = batchData.shape[0]
        cell = self.initHidden(batch_size)
        hidden = self.initHidden(batch_size)
        for i in range(MSG_LEN):
            concepts = batchData[:, i]
            hidden, cell = self.lstm(self.embedding(concepts), (hidden, cell))
            # (The old ``time % 1 == 0`` guard was always true for int time.)
            self._record_stats(concepts, F.softmax(self.out(hidden), dim=1))
        self._dump_stats()



# Listening Agent is the listener
class ListenModule(nn.Module):
    """Listener half of an agent.

    An LSTMCell reads a message token by token. After each token it predicts
    a concept, one of ``N_CONCEPTS``.
    """

    def __init__(self, id, agt) -> None:
        super(ListenModule,self).__init__()
        self.input_size = N_VOCAB
        self.hidden_size = RNN_SIZE
        self.output_size = N_CONCEPTS
        self.ID = id
        self.agt = agt
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def listen(self,message,selfMode = False, tarConcept=None):
        """Decode ``message`` into one predicted concept per message step.

        Args:
            message: Token sequence indexed ``message[i]`` per step. The batch
                size is ``message.shape[1]``, so presumably (steps, batch,
                N_VOCAB).
            selfMode: If True, also accumulate the log-probability of the
                concepts in ``tarConcept`` under the listener's distributions.
            tarConcept: Target concept ids indexed ``tarConcept[:, i]`` per
                step.

        Returns:
            ``(pred_concepts, log_probs, ct2)``:

            * ``pred_concepts`` -- (batch, steps) index of the chosen concept
              at each step.
            * ``log_probs`` -- (batch,) summed log-prob of those choices.
            * ``ct2`` -- (batch,) summed log-prob of ``tarConcept``; zeros
              unless ``selfMode``.
        """
        batch_size = message.shape[1]
        hidden = self.initHidden(batch_size)
        cell = self.initHidden(batch_size)
        pred_concepts = []
        log_probs = 0.
        ct2 = torch.zeros(batch_size)

        for i in range(message.shape[0]):
            hidden, cell = self.lstm(message[i], (hidden, cell))
            outDistr = F.softmax(self.out(hidden), dim=1)

            if selfMode:
                # Score the target concepts for this step under outDistr.
                targets = torch.as_tensor(tarConcept[:, i], dtype=torch.long)
                ct2 = ct2 + torch.log(outDistr.gather(1, targets.unsqueeze(1)).squeeze(1))

            if self.training:
                concept, _ = cat_softmax(outDistr, mode=MSG_MODE, tau=TAU, hard=MSG_HARD, dim=1)
            else:
                concept = F.one_hot(torch.argmax(outDistr, dim=1), num_classes=self.output_size).float()

            log_probs += torch.log((outDistr * concept).sum(dim=1))
            # Computed for both branches. It used to be set only in training,
            # so eval mode raised UnboundLocalError here.
            pred_concepts.append(torch.argmax(concept.detach(), dim=1))

        return torch.stack(pred_concepts, dim=1), log_probs, ct2

    def initHidden(self, batch_size):
        """Return a zero (batch_size, hidden_size) LSTM state tensor."""
        return torch.zeros(batch_size, self.hidden_size)
    



class Agent(nn.Module):
    def __init__(self, id) -> None:
        super(Agent,self).__init__()
        self.speakModule = SpeakModule(id, self)
        self.listenModule = ListenModule(id, self)
        self.agentID = id
        self.time = 0