'''
    Training script for the 3-word message setup.
    chatbot4.py -- neural network architecture (note: the code below imports models.chatbot3)

    config:
                N_AGENTS = 2
                NUM_ATTRS = 3
                MSG_LEN = 3
                N_SECTORS = 3
                N_SEGMENTS = 3
                N_COLORS = 3
                UNI_ATTR_VAL = 9
                N_CONCEPTS = 9
                N_VOCAB = 3
                IMG_FEAT_SIZE = 20  # embedding size of the input 



                # hyperparameters
                NUM_EPOCHS = 1000000
                BATCH_SIZE = 100
                TRAINING_SIZE = 0.9


                SPK_LEARNING_RATE = 0.009
                LIS_LEARNING_RATE = 0.009
                RNN_SIZE = 128
                RL_NEGATIVE_REWARD = 0
                RL_SCALE = 100
                MSG_MODE =  'GUMBEL'
                MSG_HARD = True
                TAU = 2
                CLIP = 50.0
                LAMBDA = 1



'''
'''
    Instantiates agents that each have a speaking and a listening module (the code below creates four agents). We add an extra KL-divergence-style loss term between the speaker's listening probability and the listener's speaking probability.

    Experiments:
        Taking only one agent as speaker and other agent as listener
'''



import pickle
import torch
from tqdm import tqdm
from colorama import Fore, Back, Style
from utils import *
from utils.graphworld import  World
from models.chatbot3 import *
from utils.datagen import *
from utils.conf import *
# from utils.helper import DotDic
from time import sleep
import torch.nn as nn
import numpy as np
from collections import defaultdict
from utils.reward import *



# Shared environment: the batch sampler and the graph world.
data = DataLoader()
world = World()

# Agents are created in ID order 0..3; agent1..agent4 alias the list entries.
agents = [Agent(agent_id) for agent_id in range(4)]
agent1, agent2, agent3, agent4 = agents

# Running usage statistics, updated by the training loop.
vusage = defaultdict(dict)    # concept value -> {word id: count}
cusage = defaultdict(int)     # concept value -> count
vc_usage = defaultdict(int)   # word id -> count
vclog = defaultdict(int)      # concept value -> history of top-word share

# Per-speaker concept-selection histograms, keyed by speaker agentID N:
# seldictN is the latest histogram, seldict2_N its history over epochs and
# seldictsizeN the normaliser.
seldict0, seldict2_0, seldictsize0 = defaultdict(int), {}, 0.0
seldict1, seldict2_1, seldictsize1 = defaultdict(int), {}, 0.0
seldict2, seldict2_2, seldictsize2 = defaultdict(int), {}, 0.0
seldict3, seldict2_3, seldictsize3 = defaultdict(int), {}, 0.0

# Template with a zero count for every 0/1 triple, in lexicographic order
# (0,0,0), (0,0,1), ..., (1,1,1).
seldict_def = {
    (first, second, third): 0
    for first in (0, 1)
    for second in (0, 1)
    for third in (0, 1)
}

# Joint optimiser over all four agents: agent1/agent3 use the speaker
# learning rate, agent2/agent4 the listener learning rate.
optimizer = torch.optim.SGD([
    {'params': member.parameters(), 'lr': rate}
    for member, rate in zip(
        agents,
        (SPK_LEARNING_RATE, LIS_LEARNING_RATE, SPK_LEARNING_RATE, LIS_LEARNING_RATE),
    )
])

# Optimiser with a fixed 1e-4 rate over each agent's speakModule.lstm_0 and
# speakModule.out_0 (the training loop zero_grads it; its step() is
# commented out there).
optimizer2 = torch.optim.SGD([
    {'params': getattr(member.speakModule, part).parameters(), 'lr': 1e-4}
    for member in agents
    for part in ('lstm_0', 'out_0')
])

# Pairwise Adam optimisers: agent1 speaks -> agent2 listens, and the reverse.
# Their only uses in the training loop are commented out.
opt01 = torch.optim.Adam([
    {'params': agent1.speakModule.parameters(), 'lr': SPK_LEARNING_RATE},
    {'params': agent2.listenModule.parameters(), 'lr': LIS_LEARNING_RATE},
])
opt10 = torch.optim.Adam([
    {'params': agent1.listenModule.parameters(), 'lr': SPK_LEARNING_RATE},
    {'params': agent2.speakModule.parameters(), 'lr': LIS_LEARNING_RATE},
])

optz = [opt01, opt10]


# generate feature vector for each source,target pair
def get_vertex_feature_batch(vertex_batch):
    """Build the feature matrix for a batch of (source, target) vertex pairs.

    Each row is the concatenation of ``world.locations[src]``,
    ``world.locations[targ]`` and ``world.feat[src]``.

    Args:
        vertex_batch: iterable of ``(src, targ)`` pairs whose elements support
            ``.item()``; at most ``BATCH_SIZE`` pairs.

    Returns:
        A ``(BATCH_SIZE, VERTEX_FEAT_VEC_SIZE)`` float tensor whose i-th row
        holds the features of the i-th pair. Rows past the end of
        ``vertex_batch`` are zero.
    """
    # torch.zeros instead of torch.Tensor(rows, cols): the latter returns
    # uninitialised memory, so any row not written below would contain garbage.
    graphBatch = torch.zeros(BATCH_SIZE, VERTEX_FEAT_VEC_SIZE)
    for b_i, (src, targ) in enumerate(vertex_batch):
        src_id = src.item()
        targ_id = targ.item()
        b = []
        b.extend(world.locations[src_id].tolist())
        b.extend(world.locations[targ_id].tolist())
        b.extend(world.feat[src_id])
        graphBatch[b_i] = torch.Tensor(b)
    return graphBatch

def calculateCorrectPred(batchData, concept):
    """Score each row by how many leading attributes match the reference.

    Attributes are compared position by position from the first one; the
    length of the matching prefix (0-3) selects the reward:

        0 -> -1,  1 -> 5,  2 -> 10,  3 -> 100

    Args:
        batchData: rows to score (must expose ``shape[0]``; rows are expected
            to have 3 attributes).
        concept: reference rows, indexed the same way.

    Returns:
        ``(reward, acc_count_1, acc_count_2, acc_count_3)``: ``reward`` is a
        float tensor with one entry per row, and ``acc_count_k`` is the number
        of rows whose first ``k`` attributes all match.
    """
    n_rows = batchData.shape[0]
    reward = torch.zeros(n_rows)
    prefix_reward = (-1, 5, 10, 100)
    prefix_hits = [0, 0, 0]
    for row in range(n_rows):
        depth = 0
        # Grow the matching prefix one attribute at a time; stop at the first
        # mismatch so later attributes are never inspected.
        while depth < 3 and batchData[row][depth] == concept[row][depth]:
            depth += 1
        reward[row] = prefix_reward[depth]
        for level in range(depth):
            prefix_hits[level] += 1
    return reward, prefix_hits[0], prefix_hits[1], prefix_hits[2]


def calculateCorrectPred3(batchData, concept):
    """Reward predicted concepts (``concept``) against targets (``batchData``).

    The base reward depends on the length of the matching attribute prefix:

        0 -> -15,  1 -> -10,  2 -> -5,  3 -> 100

    Then:
      * rows with index >= ZERO_DATA whose target is all zeros get -1000;
      * rows with index < ZERO_DATA (the training loop fills these with null
        data) get +100 when the prediction's 2nd and 3rd attributes are 0.

    Returns:
        ``(reward, acc_count_1, acc_count_2, acc_count_3)``: ``reward`` is a
        float tensor with one entry per row, and ``acc_count_k`` is the number
        of rows whose first ``k`` attributes all match.
    """
    n_rows = batchData.shape[0]
    reward = torch.zeros(n_rows)
    # Reward after the three prefix checks, indexed by matched-prefix length.
    prefix_reward = (-15, -10, -5, 100)
    hits = [0, 0, 0]
    for row_idx in range(n_rows):
        target = batchData[row_idx]
        guess = concept[row_idx]

        depth = 0
        while depth < 3 and target[depth] == guess[depth]:
            depth += 1
        for level in range(depth):
            hits[level] += 1

        score = prefix_reward[depth]
        if row_idx >= ZERO_DATA:
            if target[0] == target[1] == target[2] == 0:
                score = -1000
        elif guess[1] == guess[2] == 0:
            score += 100
        reward[row_idx] = score

    return reward, hits[0], hits[1], hits[2]



def calculateCorrectPred2(batchData):
    """Reward word agreement inside each message row of ``batchData``.

    Rows with index < ZERO_DATA earn 100 when their 2nd and 3rd words are
    equal. All other rows earn 100 only when all three words are equal.
    Every row that fails its check gets -1.

    Returns:
        ``(reward, acc_count)``: a float tensor with one entry per row, and
        the number of rows that earned 100.
    """
    n_rows = batchData.shape[0]
    reward = torch.zeros(n_rows)
    hits = 0
    for row_idx in range(n_rows):
        words = batchData[row_idx]
        if row_idx < ZERO_DATA:
            agreed = words[1] == words[2]
        else:
            agreed = words[0] == words[1] and words[0] == words[2]
        if agreed:
            reward[row_idx] = 100
            hits += 1
        else:
            reward[row_idx] = -1
    return reward, hits


# Speaker/listener schedule: every ROLE_WINDOW epochs a new random pairing is
# drawn. The training loop uses (speakers_idx, listener_idx) on even epochs
# and (speakers2_idx, listener2_idx) on odd epochs.
ROLE_WINDOW = 500

speakers_idx = np.zeros(NUM_EPOCHS, dtype=int)
speakers2_idx = np.zeros(NUM_EPOCHS, dtype=int)
listener_idx = np.zeros(NUM_EPOCHS, dtype=int)
listener2_idx = np.zeros(NUM_EPOCHS, dtype=int)

for i in range(0, NUM_EPOCHS, ROLE_WINDOW):
    window = slice(i, i + ROLE_WINDOW)
    agl = list(range(NUM_AGENTS))
    s1 = np.random.choice(agl)
    agl.remove(s1)
    l1 = np.random.choice(agl)
    agl.remove(l1)
    speakers_idx[window] = s1
    listener_idx[window] = l1
    if NUM_AGENTS == 4:
        # A second, disjoint pair drawn from the two remaining agents.
        s2 = np.random.choice(agl)
        agl.remove(s2)
        l2 = np.random.choice(agl)
        agl.remove(l2)
    else:
        # No second disjoint pair exists. Without this fallback the odd-epoch
        # arrays would stay 0, pairing agent 0 with itself on every odd epoch.
        s2, l2 = s1, l1
    speakers2_idx[window] = s2
    listener2_idx[window] = l2


# Metric histories, appended every INTERVAL epochs and pickled to disk.
Y = []
ACCURACY_1 = []
ACCURACY_2 = []
ACCURACY_3 = []
LOSS = []
LOSS_0 = []
LOSS_2 = []
LOSS_3 = []

SUCCESS = []
SEL_PROB_MAX_0 = []
SEL_PROB_MIN_0 = []
SEL_PROB_MAX_1 = []
SEL_PROB_MIN_1 = []
REWARDS = []

# Per-batch frequency tables, rebuilt every epoch by update_c_count /
# update_m_count.
cc = defaultdict(int)  # concept tuple -> occurrences in the current batch
mm = defaultdict(int)  # (w0, w1, w2) message tuple -> occurrences in the batch


def update_c_count(batchData):
    """Recount how often each concept row occurs in ``batchData``.

    Clears ``cc`` and fills it with ``tuple(row.tolist()) -> count``.
    """
    cc.clear()
    for b in batchData:
        cc[tuple(b.tolist())] += 1


def update_m_count(message):
    """Recount how often each 3-word message occurs in the batch.

    ``message`` is indexed ``[word_position][sample]`` (a nested list). Clears
    ``mm`` and fills it with ``(w0, w1, w2) -> count``. The number of samples
    comes from the message itself rather than the global ``BATCH_SIZE``, so a
    partial batch is counted correctly instead of raising IndexError.
    """
    mm.clear()
    for key in zip(message[0], message[1], message[2]):
        mm[key] += 1


def get_c_count(c_pair):
    """Return the current-batch count of concept tuple ``c_pair`` (0 if unseen)."""
    return cc[c_pair]


def get_m_count(m_pair):
    """Return the current-batch count of message tuple ``m_pair`` (0 if unseen)."""
    return mm[m_pair]

# Main training loop. Each epoch:
#   1. picks a speaker/listener pair from the precomputed schedule,
#   2. lets the speaker choose which concept attributes to express
#      (concept_select) and builds the target concepts for a vertex batch,
#   3. has the speaker describe a null batch and the real batch,
#   4. has the speaker itself and the listener decode the message,
#   5. builds reward-weighted log-probability losses and takes one
#      optimizer step,
#   6. every INTERVAL epochs prints diagnostics and pickles the metric logs.
for epoch in tqdm(range(NUM_EPOCHS), colour="red"):
    # Per-sample frequencies: cc_vect = target-concept count, mm_vect =
    # message count, cc2_vect = listener-predicted-concept count.
    cc_vect = torch.zeros(BATCH_SIZE)
    mm_vect = torch.zeros(BATCH_SIZE)
    cc2_vect = torch.zeros(BATCH_SIZE)
    #vusage.clear()

    # count pr_c for each pair in batch
    # Even epochs use the primary pair, odd epochs the secondary pair.
    # NOTE(review): speakers2_idx/listener2_idx are only filled when
    # NUM_AGENTS == 4; otherwise they stay 0 and agent 0 is both speaker and
    # listener on odd epochs -- confirm this is intended.
    if(epoch %2 == 0):
        speaker = agents[speakers_idx[epoch]]
        listener = agents[listener_idx[epoch]]
    else:
        speaker = agents[speakers2_idx[epoch]]
        listener = agents[listener2_idx[epoch]]

    #listener2 = agents[listener2_idx[epoch]]



    vertex_batch, nullData = data.get_vertex_batch(BATCH_SIZE) # pair of index of src, target

    v_feature_batch  = get_vertex_feature_batch(vertex_batch)
    # The speaker decides which attributes of each concept to express.
    # c_select is multiplied into the concept batches below.
    c_select_f, select_probs, c_select, _ = speaker.speakModule.concept_select(v_feature_batch)

    # Skip the epoch when no selection was produced.
    # NOTE(review): `c_select is None` is the idiomatic test here.
    if(c_select == None):
        continue



    #print(vertex_batch)
    # Two concept encodings of the same vertex batch; batchData2 is only used
    # as an alternative listener target (l_ct3) below.
    batchData, batchData2  = world.getconcepts(vertex_batch)



    #batchData2  = world.getconcepts(vertex_batch)
    #for kk in range(BATCH_SIZE):
    #    print(batchData[kk].tolist(), batchData2[kk].tolist())
    #exit()


    # Mask out the attributes the speaker chose not to express.
    batchData = batchData*c_select
    batchData2  = batchData2*c_select

    # The first ZERO_DATA rows are replaced by null data; the reward
    # functions treat rows below ZERO_DATA specially.
    for rr in range(ZERO_DATA):
        batchData[rr] = nullData[rr]
        #batchData = nullData
    #print(batchData)

    #print(batchData*c_select)
    #print(batchData)

    # get the batchData
    #batchData, nullData = data.getBatch(BATCH_SIZE)


    # print(batchData.shape)
    # print(batchData[:,])
    # exit()
    # NOTE(review): update_c_count is called again below with the same
    # batchData; one of the two calls is redundant.
    update_c_count(batchData)

    # Always executed: the speaker describes the null batch, and
    # calculateCorrectPred2 rewards word agreement in those messages.
    if(True):
        nl_msg, nl_log_spk_prob,_, nl_log_prob_vect = speaker.speakModule.speak(nullData)

        d_nl_msg = nl_msg.detach()

        d_nl_msg_t = torch.argmax(d_nl_msg,dim=2)
        reward_nl, acc_nl = calculateCorrectPred2(d_nl_msg_t.T)
        #print(reward_nl)

    #print((nl_log_prob_vect[0]-nl_log_prob_vect[1])**2.0+(nl_log_prob_vect[0]-nl_log_prob_vect[2])**2.0)


    update_c_count(batchData)

    # The speaker describes the (masked) real batch.
    message, spk_log_probs,_, _ = speaker.speakModule.speak(batchData, c_select)

    #_, _ ,chk,_ = speaker.speakModule.speak(batchData, c_select, selfMode=True, targMsg=message)

    #print("Chk: ", chk)
    #print("Spk_log_probs: ", spk_log_probs)
    #exit()

    d_message = message.detach()


    # Word ids of the message; msg_list below is indexed [word][sample].
    message_t = torch.argmax(d_message,dim=2)


    d_msg_t = message_t.detach()
    update_m_count(d_msg_t.tolist())
    msg_list = d_msg_t.tolist()

    # Per-sample batch frequencies and cumulative usage statistics:
    # cusage counts concept values, vc_usage counts word ids, and vusage
    # counts which word was used at each position for each concept value.
    for i in range(BATCH_SIZE):
        cc_vect[i] = cc[tuple(batchData[i].tolist())]

        mm_vect[i] = mm[(msg_list[0][i], msg_list[1][i], msg_list[2][i])]

        cusage[batchData[i][0].item()] += 1
        cusage[batchData[i][1].item()] += 1
        cusage[batchData[i][2].item()] += 1

        for mi in range(3):
            vc_usage[msg_list[mi][i]] += 1

        try:
            vusage[batchData[i][0].item()][msg_list[0][i]] += 1
        except KeyError:
            vusage[batchData[i][0].item()][msg_list[0][i]] = 1
        try:
            vusage[batchData[i][1].item()][msg_list[1][i]] += 1
        except KeyError:
            vusage[batchData[i][1].item()][msg_list[1][i]] = 1
        try:
            vusage[batchData[i][2].item()][msg_list[2][i]] += 1
        except KeyError:
            vusage[batchData[i][2].item()][msg_list[2][i]] = 1

    # Concept frequency / message frequency per sample (mm_vect >= 1 since
    # every sample's own message is counted); weights the speaker's
    # self-listening term in the loss.
    ratio = cc_vect/mm_vect




    # print(f'ratio = {ratio}')
    # exit(0)
    # pass the message to listenining module of same agent
    s_concept, _ , s_lis_log_probs = speaker.listenModule.listen(d_message,selfMode = True, tarConcept =  batchData)
    d_s_concept = s_concept.detach()

    # The listener decodes the (non-detached) message; lis_log_probs enters
    # the main loss.
    concept,lis_log_probs,l_ct2 = listener.listenModule.listen(message,selfMode = True, tarConcept= batchData)

    d_concept = concept.detach()

    # Same decoding scored against the alternative encoding batchData2.
    _ ,_ , l_ct3 = listener.listenModule.listen(message,selfMode = True, tarConcept= batchData2)

    # 0/1 mask of the attributes the listener predicted as non-zero.
    b_d_concept = d_concept > 0
    b_d_concept = b_d_concept.type(torch.int)

    # pass the concept to the speaking module of the listener agent
    l_message, _ , l_spk_log_probs, _ = listener.speakModule.speak(d_concept,selfMode = True, targMsg = d_message)
    d_l_message = l_message.detach()


############################# 3_agents
    #concept2,lis_log_probs2,l_ct2_2 = listener2.listenModule.listen(message,selfMode = True, tarConcept= batchData)

    #d_concept2 = concept2.detach()
    # pass the concept to the speaking module of the listener agent
    #l_message2, _ , l_spk_log_probs2 = listener2.speakModule.speak(d_concept2,selfMode = True, targMsg = d_message)

    #d_l_message2 = l_message2.detach()

########################### 3_agents

    # Log-probability of the listener's concept selection given the mask it
    # decoded; used in loss_c_sel below.
    _,_,_,l_prb_csel = listener.speakModule.concept_select(v_feature_batch, True, b_d_concept)


    # rewards2: speaker's own decoding vs. target (logged only).
    rewards2, ac_1_2,ac_2_2,ac_3_2 = calculateCorrectPred(batchData, d_s_concept)

    #print("Epoch: [ ", epoch, " ]")
    # rewards: listener's decoding vs. target; drives the main loss.
    rewards, ac_1,ac_2,ac_3 = calculateCorrectPred3(batchData, d_concept)


    # Integer snapshot of the concept-match rewards, for logging.
    rewards_copy2 = torch.LongTensor(BATCH_SIZE)
    for vv in range(BATCH_SIZE):
        rewards_copy2[vv] = rewards[vv]



    # After the first few epochs, rows from ZERO_DATA onward get the reward
    # from getBatchReward instead of the concept-match reward.
    accuracy = 0.0
    if epoch > 5:
        rewards_v, accuracy = getBatchReward(world, d_concept, vertex_batch, batchData,  d_msg_t.T, d_nl_msg_t.T)
        for vv in range(ZERO_DATA, BATCH_SIZE):
            rewards[vv] = rewards_v[vv]

    #for vv in range(ZERO_DATA):
    #    rewards[vv] = rewards[vv]

    # Integer snapshot of the rewards before the random overrides below;
    # used by loss_w_sel.
    rewards_copy = torch.LongTensor(BATCH_SIZE)
    for vv in range(BATCH_SIZE):
        rewards_copy[vv] = rewards[vv]
    #print("bd: ", batchData)
    #print("d_c", d_concept)
    #print("rw:::", rewards)
   # print(rewards_v)
    ### 3 agents
    #rewards_2a, ac_1_2a, ac_2_2a, ac_3_2a = calculateCorrectPred(batchData, d_concept2)
    ### 3 agents

    # rewards3: speaker's message vs. the listener's re-spoken message
    # (logged only).
    l_message_t = torch.argmax(d_l_message,dim=2)
    rewards3, ac_1_3,ac_2_3,ac_3_3 = calculateCorrectPred(message_t.T, l_message_t.T)


    # Guidance: 50 random indices (drawn with replacement) whose reward is
    # below 100 are forced to 100, and their listener log-prob is replaced
    # by the mean of the two target-forced log-probs.
    # NOTE(review): `!= 1200` is redundant given `< 100`; the write into
    # lis_log_probs is in-place on a tensor that later enters backward --
    # confirm autograd accepts this.
    RI = np.random.randint(BATCH_SIZE, size=50)
    for x_i in RI:
        if(rewards[x_i] < 100 and rewards[x_i] != 1200):
            rewards[x_i] = 100
            lis_log_probs[x_i] = (l_ct2[x_i] + l_ct3[x_i])/2.0


    # Frequency of each listener-predicted concept in the batch.
    l_cc = defaultdict(int)
    #concept_copy = concept.detach()
    for b in d_concept:
        l_cc[tuple(b.tolist())]+=1

    for i in range(BATCH_SIZE):
        cc2_vect[i] = l_cc[tuple(d_concept[i].tolist())]

    # Message frequency / predicted-concept frequency; weights the listener's
    # re-speaking term in the loss.
    ratio2 = mm_vect/cc2_vect
    #print("d_concept: ", d_concept)
    #print(l_cc)
    #print(mm)
    #print(l_spk_log_probs)
    #print("RATIO2: ", ratio2)
    #print("RATIO: ", ratio)
    ### 3 agents
    #l_cc2 = defaultdict(int)
    #for b in d_concept2:
     #   l_cc2[tuple(b.tolist())]+=1

#    for i in range(BATCH_SIZE):
 #       cc2_vect[i] = l_cc2[tuple(d_concept2[i].tolist())]

  #  ratio2_2a = cc2_vect/mm_vect
#############################

    optimizer.zero_grad()
    optimizer2.zero_grad()
    #optz[listener_idx[epoch]].zero_grad()
    # print(f'{(-l_spk_log_probs*(1/ratio.detach())-s_lis_log_probs*ratio.detach()).mean()}')
    # Main loss: reward minus an ALPHA-weighted log-prob baseline, multiplied
    # by the joint speaker+listener log-prob, plus the LAMBDA-weighted
    # ratio-scaled self-listening / re-speaking terms.
    lis_loss = -(((rewards.detach()-ALPHA*(lis_log_probs.detach()+spk_log_probs.detach()+1))*(lis_log_probs+spk_log_probs)+\
                  LAMBDA*(l_spk_log_probs*ratio2.detach()+s_lis_log_probs*ratio.detach())).mean())
    #lis_loss = -(((rewards.detach()-lis_log_probs.detach()-spk_log_probs.detach()-1)*(lis_log_probs+spk_log_probs)).mean())
    #3_agents
    #lis_loss_2a = -(((rewards_2a.detach()-lis_log_probs2.detach()-spk_log_probs.detach()-1)*((lis_log_probs2+spk_log_probs)+LAMBDA*(l_spk_log_probs2*(1/ratio2_2a.detach())+s_lis_log_probs*ratio.detach()))).mean())
    #spk_loss_null = ((nl_log_prob_vect[0]-nl_log_prob_vect[1])**2.0+(nl_log_prob_vect[0]-nl_log_prob_vect[2])**2.0).mean()
    # Reward-weighted log-prob loss for the null-data messages.
    spk_loss_null = -(((reward_nl.detach())*nl_log_spk_prob).mean())


    # loss2/loss3 are the two LAMBDA terms on their own, for logging only
    # (floss is not used afterwards).
    loss2 = -((s_lis_log_probs*ratio.detach()).mean())
    loss3 = -((l_spk_log_probs*(ratio2.detach())).mean())
    floss = loss2+loss3

    #print(c_select)
    #loss_w_sel = (torch.norm(c_select, dim=1)**2.0).mean()

    loss_c_sel = -(l_prb_csel).mean()

    #print("LLLLL: ", loss3)
    #print("LLLLL2: ", loss2)
    #loss_w_sel_0 = rewards_copy.detach()*select_probs
    # Concept-selection loss: squared-magnitude penalty on c_select_f plus a
    # reward-weighted log-prob term on select_probs with a W_SEL_ENTROPY
    # baseline, plus loss_c_sel; the whole term is scaled by 0.01.
    loss_w_sel = 0.01*((0.1*torch.square(c_select_f).sum(dim=1, keepdim=True)-0.9*(rewards_copy.detach()-W_SEL_ENTROPY*select_probs.detach()-W_SEL_ENTROPY)*select_probs).mean() + loss_c_sel)
    #loss_w_sel = (9*torch.square(c_select_f).sum(dim=1, keepdim=True)).mean() + loss_c_sel

    #print("LOSS_W_SELL", loss_w_sel)
    #loss_w_sel.requires_grad = True

    #loss_w_sel = -((select_probs).mean())
    #print(loss_w_sel)
    loss4_2a = lis_loss + spk_loss_null + loss_w_sel# + lis_loss_2a #3_agents
    #lis_loss = -(((rewards.detach()-lis_log_probs.detach()-spk_log_probs.detach()-1)*(lis_log_probs+spk_log_probs)).mean())
    #lis_loss = -(((rewards.detach())*((lis_log_probs+spk_log_probs)+LAMBDA*(l_spk_log_probs*(1/ratio.detach())+s_lis_log_probs*ratio.detach()))).mean())


    #loss_w_sel.backward()
    loss4_2a.backward()

    #loss2.backward()
    #loss3.backward()
    nn.utils.clip_grad_norm_(listener.parameters(), 100.0)
    nn.utils.clip_grad_norm_(speaker.parameters(), 100.0)
    #nn.utils.clip_grad_norm_(listener2.parameters(),50.0)

    #print("GRAD: ", speaker.speakModule.lstm_0.weight_hh.grad)

    # NOTE(review): optimizer2 is zero_grad'ed above but never stepped.
    #optimizer2.step()
    optimizer.step()
    #optz[listener_idx[epoch]].step()

    #print(list(speaker.speakModule.lstm_0.parameters()))
    #print(list(speaker.speakModule.out_0.parameters()))

    # Periodic diagnostics and metric dumps.
    INTERVAL = 100
    if epoch%INTERVAL == 0:
        dataX = world.getCheckData()
        speaker.speakModule.getVocab(dataX)
        print()
        print("*"*100)
        # NOTE(review): vv_list is not defined anywhere in this file; unless a
        # star import provides it, this raises NameError at epoch 0.
        print(vv_list)
        #print(d_nl_msg_t)
        #print(f'******************************, accuracy:{ac_1_2a/BATCH_SIZE}, acc_2:{ac_2_2a/BATCH_SIZE}, acc_3:{ac_3_2a/BATCH_SIZE}')
        Y.append(epoch)
        # For every concept value, print word usage normalised by how often
        # the value occurred, and log the share of its most-used word.
        for k,v in vusage.items():
            bc = np.array((list(v.keys()), list(v.values())))
            #bc[0] = bc[0]*cusage[k]
            print("[ ", k, " ]")
            sorted_items = sorted(v.items(), key=lambda e: e[1], reverse=True)
            print(' # '.join('{}: {}'.format(qq, round(rr/cusage[k], 4)) for qq, rr in sorted_items))
            vc = sorted_items[0]
            # vclog is a defaultdict(int): the first append hits an int and
            # falls into the bare except, which stores a new list.
            try:
                vclog[k].append(vc[1]/cusage[k])
            except:
                vclog[k] = [vc[1]/cusage[k]]

        print("~"*100)
        # Overall word usage. NOTE(review): the 270 normaliser is a magic
        # constant -- each epoch adds 3*BATCH_SIZE word counts to vc_usage.
        sorted_items2 = sorted(vc_usage.items(), key=lambda e: e[0], reverse=False)
        print(' $$ '.join('{}: {} ({})'.format(qq, round(rr/(270*(epoch+1)), 4), rr) for qq, rr in sorted_items2))

    #    print(list(speaker.speakModule.lstm_0.parameters()))
        #    if paramx.requires_grad:
         #       print(namex, paramx.data)
        print(c_select_f)
        print("pred concept: ", d_concept)
        print("orig concept: ", batchData)
        print("Message: ", message_t.T)
        print(d_nl_msg_t)
        print("null-rw:", reward_nl)
        print("guided-rw:", rewards)
        print("org-rw:", rewards_copy)
        print("Concept-match-rw:", rewards_copy2)

        print("msg_t: ", message_t)
        print("l_msg_t: ", l_message_t)
        print(torch.exp(select_probs.detach()))
        print(torch.max(torch.exp(select_probs.detach())), torch.min(torch.exp(select_probs.detach())), "TVALUE: ", speaker.speakModule.TVALUE)
        print(Fore.RED + f'epoch:{[epoch]}, v_accuracy = {accuracy/(BATCH_SIZE-ZERO_DATA)}')
        print("LOSS_W_SELL", loss_w_sel)
        print("-"*150)
        print(Fore.GREEN + f'epoch:{[epoch]}, speaker = {speaker.agentID}, list = {listener.agentID},lis_loss:{lis_loss:.4f}')
        print(f'** accuracy:{ac_1/BATCH_SIZE}, acc_2:{ac_2/BATCH_SIZE}, acc_3:{ac_3/BATCH_SIZE}')
        print(f'** accuracy:{ac_1_2/BATCH_SIZE}, acc_2:{ac_2_2/BATCH_SIZE}, acc_3:{ac_3_2/BATCH_SIZE}')
        print(f'** accuracy:{ac_1_3/BATCH_SIZE}, acc_2:{ac_2_3/BATCH_SIZE}, acc_3:{ac_3_3/BATCH_SIZE}')
        print("loss2: ", loss2, " loss3: ", loss3)
        #print(f'***************************************************, loss:{loss_w_sel:.4f}')
        print(Style.RESET_ALL)
        #logfile = open('outlog2.npy', 'wb')


        # Append this interval's metrics to the histories.
        LOSS.append(lis_loss.detach().item())
        LOSS_0.append(loss_w_sel.detach().item())
        LOSS_2.append(loss2.detach().item())
        LOSS_3.append(loss3.detach().item())

        ACCURACY_1.append(int((ac_1/BATCH_SIZE)*100))
        ACCURACY_2.append(int((ac_2/BATCH_SIZE)*100))
        ACCURACY_3.append(int((ac_3/BATCH_SIZE)*100))
        REWARDS.append(rewards.mean().item())
        # epoch%INTERVAL == 0 implies an even epoch, so the current speaker
        # came from speakers_idx.
        if(speakers_idx[epoch] == 0):
            SEL_PROB_MAX_0.append(torch.max(torch.exp(select_probs.detach())).item())
            SEL_PROB_MIN_0.append(torch.min(torch.exp(select_probs.detach())).item())
        else:
            SEL_PROB_MAX_1.append(torch.max(torch.exp(select_probs.detach())).item())
            SEL_PROB_MIN_1.append(torch.min(torch.exp(select_probs.detach())).item())
        SUCCESS.append(accuracy/(BATCH_SIZE-ZERO_DATA))

        # Persist every history (overwritten each interval).
        with open('LOSS', 'wb') as fp:
            pickle.dump(LOSS, fp)
        with open('A_1', 'wb') as fp:
            pickle.dump(ACCURACY_1, fp)
        with open('A_2', 'wb') as fp:
            pickle.dump(ACCURACY_2, fp)
        with open('A_3', 'wb') as fp:
            pickle.dump(ACCURACY_3, fp)
        with open('SS', 'wb') as fp:
            pickle.dump(SUCCESS, fp)
        with open('RW', 'wb') as fp:
            pickle.dump(REWARDS, fp)
        with open('S_P_MX_0', 'wb') as fp:
            pickle.dump(SEL_PROB_MAX_0, fp)
        with open('S_P_MN_0', 'wb') as fp:
            pickle.dump(SEL_PROB_MIN_0, fp)
        with open('S_P_MX_1', 'wb') as fp:
            pickle.dump(SEL_PROB_MAX_1, fp)
        with open('S_P_MN_1', 'wb') as fp:
            pickle.dump(SEL_PROB_MIN_1, fp)
        with open('LOSS_0', 'wb') as fp:
            pickle.dump(LOSS_0, fp)
        with open('VOCAB', 'wb') as fp:
            pickle.dump(vclog, fp)
        with open('LOSS2', 'wb') as fp:
            pickle.dump(LOSS_2, fp)
        with open('LOSS3', 'wb') as fp:
            pickle.dump(LOSS_3, fp)
        with open('WUSAGE', 'wb') as fp:
            pickle.dump(sorted_items2, fp)

        with open('CUSAGE', 'wb') as fp:
            pickle.dump(cusage, fp)



        # Histogram of this batch's c_select rows for the current speaker,
        # normalised by 100.0 (presumably BATCH_SIZE -- confirm), appended to
        # that speaker's history and pickled to WOSTAS<agentID>.
        if(speaker.agentID == 0):
            seldictsize0 = 100.0
            seldict0 = seldict_def.copy()
            for e in c_select:
                seldict0[tuple(e.tolist())] += 1
            for e in seldict0:
                print(e, ": ", seldict0[e]/seldictsize0)
                try:
                    seldict2_0[e].append((epoch, seldict0[e]/seldictsize0))
                except:
                    seldict2_0[e] = [(epoch, seldict0[e]/seldictsize0)]
            with open('WOSTAS0', 'wb') as fp:
                pickle.dump(seldict2_0, fp)
        elif(speaker.agentID == 1):
            seldictsize1 = 100.0
            seldict1 = seldict_def.copy()
            for e in c_select:
                seldict1[tuple(e.tolist())] += 1
            for e in seldict1:
                print(e, ": ", seldict1[e]/seldictsize1)
                try:
                    seldict2_1[e].append((epoch, seldict1[e]/seldictsize1))
                except:
                    seldict2_1[e] = [(epoch, seldict1[e]/seldictsize1)]
            with open('WOSTAS1', 'wb') as fp:
                pickle.dump(seldict2_1, fp)
        elif(speaker.agentID == 2):
            seldictsize2 = 100.0
            seldict2 = seldict_def.copy()
            for e in c_select:
                seldict2[tuple(e.tolist())] += 1
            for e in seldict2:
                print(e, ": ", seldict2[e]/seldictsize2)
                try:
                    seldict2_2[e].append((epoch, seldict2[e]/seldictsize2))
                except:
                    seldict2_2[e] = [(epoch, seldict2[e]/seldictsize2)]
            with open('WOSTAS2', 'wb') as fp:
                pickle.dump(seldict2_2, fp)
        elif(speaker.agentID == 3):
            seldictsize3 = 100.0
            seldict3 = seldict_def.copy()
            for e in c_select:
                seldict3[tuple(e.tolist())] += 1
            for e in seldict3:
                print(e, ": ", seldict3[e]/seldictsize3)
                try:
                    seldict2_3[e].append((epoch, seldict3[e]/seldictsize3))
                except:
                    seldict2_3[e] = [(epoch, seldict3[e]/seldictsize3)]
            with open('WOSTAS3', 'wb') as fp:
                pickle.dump(seldict2_3, fp)



        #np.savez('outlog3.npy', np.array(ACCURACY_1), np.array(ACCURACY_2), np.array(ACCURACY_3), np.array(REWARDS), np.array(SEL_PROB_MAX), np.array(SEL_PROB_MIN), np.array(SUCCESS))
        

