import torch


# Training script: sample batches of vertex pairs, have the speaker encode their
# concepts as a message, and have the listener predict a concept from that message.


from utils.graphworld import  World
from utils.datagen import DataLoader
from utils.conf import *
from utils.reward import *




from models.chatbot import *
from utils.conf import *

from time import sleep
import torch.nn as nn
import numpy as np



# Objects shared by the training loop: the world, the data source and the
# two agents (constructed in this order, speaker before listener).
world = World()
data = DataLoader()
speakingAgent = SpeakModule()
listeningAgent = ListenModule()

# A single Adam optimizer holding one parameter group per agent, so the
# speaker and the listener are each updated with their own learning rate.
optimizer = torch.optim.Adam(
    [
        {'params': speakingAgent.parameters(), 'lr': SPK_LEARNING_RATE},
        {'params': listeningAgent.parameters(), 'lr': LIS_LEARNING_RATE},
    ]
)



# Print the run configuration so the log records the setup. The values are
# module-level constants brought in by the star imports above.
# (The original literal listed 'N_SECTORS' twice; one entry is enough.)
dic = {
    'N_AGENTS': N_AGENTS,
    'NUM_ATTRS': NUM_ATTRS,
    'MSG_LEN': MSG_LEN,
    'N_SECTORS': N_SECTORS,
    'N_COLORS': N_COLORS,
    'N_CONCEPTS': N_CONCEPTS,
    'N_VOCAB': N_VOCAB,
    'N_VERTEX': N_VERTEX,
}

# Training hyper-parameters.
dic2 = {
    'num_epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'msg_mode': MSG_MODE,
}


print(dic)
print(dic2)

# trainPhase
# Each iteration samples a batch of vertex pairs, lets the speaker encode their
# concepts as a message, lets the listener predict a concept from that message,
# and updates both agents from the batch reward with one joint loss.
Y = []         # epochs at which metrics were recorded
ACCURACY = []  # batch accuracy in percent (truncated to int) at each recorded epoch
LOSS = []      # joint loss value (Python float) at each recorded epoch
for epoch in range(NUM_EPOCHS):
    vertex_batch = data.getBatch(BATCH_SIZE)  # pair of index of src, target
    # get the corresponding [segment,sector,colors] from source and target pairs
    batchData = world.getconcepts(vertex_batch)

    # passing concept to speaker and getting message
    # (the third return value is not used here)
    message, spk_log_probs, _ = speakingAgent.speak(batchData)

    # passing message to listener and getting concept
    # (the third return value is not used here)
    pred_concept, lis_log_probs, _ = listeningAgent.listen_and_predict(message, batchData)

    rewards, accuracy = getBatchReward(world, pred_concept, vertex_batch, batchData)

    # Backpropagation
    optimizer.zero_grad()

    # Score-function loss over the summed speaker + listener log-probs. The
    # coefficient is built from detached tensors, so gradients flow only
    # through the (lis_log_probs + spk_log_probs) factor.
    coeff = rewards.detach() - lis_log_probs.detach() - spk_log_probs.detach() - 1
    lis_loss = -((coeff * (lis_log_probs + spk_log_probs)).mean())
    lis_loss.backward()

    # The loss depends on both agents' log-probs and the optimizer steps both
    # agents, so clip each agent's gradient norm. Clipping only the listener
    # would leave the speaker's updates unbounded.
    nn.utils.clip_grad_norm_(listeningAgent.parameters(), 50.0)
    nn.utils.clip_grad_norm_(speakingAgent.parameters(), 50.0)

    # update
    optimizer.step()

    if epoch % 100 == 0:
        print(f'epoch:{[epoch]}, lis_loss:{lis_loss:.4f}, accuracy:{accuracy/BATCH_SIZE}')
        Y.append(epoch)
        # Store a plain float rather than a tensor. np.save has to convert the
        # list, and a tensor on a GPU cannot be converted to numpy directly.
        LOSS.append(lis_loss.item())
        ACCURACY.append(int((accuracy/BATCH_SIZE)*100))
        # The full histories are rewritten each time, so the files on disk
        # always reflect progress up to the latest recorded epoch.
        np.save('Y.npy', Y)
        np.save('LOSS.npy', LOSS)
        np.save('ACCURACY.npy', ACCURACY)
        

        
        

