# Get the source and target vertex of each sampled pair.

# Build a feature vector from the source and target vertices, also taking the other vertices into account:

# [x1, y1, x2, y2, distance(p1, ..., pn), angle(p1, ..., pn)]

from utils.graphworld import World
from utils.datagen import DataLoader
from utils.conf import *
from utils.reward import *
import json

from models.chatbot import *
from utils.conf import *

# Module-level singletons used by batchfeature() and the training loop:
# `world` supplies vertex locations/features (world.locations, world.feat)
# and concept lookup (world.getconcepts); `data` samples vertex batches
# (data.getBatch).
world = World()
data = DataLoader()


# print(f'vertex_batch = {vertex_batch}')
# print(f'batchData = {batchData}')
# print(f'feat = {world.feat}')



# generate feature vector for each source,target pair
def batchfeature(vertex_batch):
    """Build the speaker's input features for a batch of (source, target) pairs.

    Each row is the concatenation of ``world.locations[src]``,
    ``world.locations[targ]`` and ``world.feat[src.item()]``.

    Args:
        vertex_batch: iterable of ``(src, targ)`` pairs as produced by
            ``data.getBatch``. ``src`` must support ``.item()`` (presumably a
            0-d tensor -- TODO confirm against DataLoader).

    Returns:
        A float tensor of shape ``(BATCH_SIZE, FEAT_VEC_SIZE)``. Rows past
        ``len(vertex_batch)`` are left as zeros.
    """
    # torch.zeros rather than torch.Tensor(BATCH_SIZE, FEAT_VEC_SIZE): the
    # latter returns uninitialised memory, so any row not written below would
    # hold arbitrary garbage values.
    graphBatch = torch.zeros(BATCH_SIZE, FEAT_VEC_SIZE)
    for b_i, (src, targ) in enumerate(vertex_batch):
        row = []
        row.extend(world.locations[src].tolist())
        row.extend(world.locations[targ].tolist())
        row.extend(world.feat[src.item()])
        graphBatch[b_i] = torch.Tensor(row)
    return graphBatch


# Create two agents for speaking and listening

speaker = SpeakModule()
listener = ListenModule()

# A single Adam optimiser updates both agents; each agent is its own
# parameter group so it can use its own learning rate.
optimizer = torch.optim.Adam([
    {'params': speaker.parameters(), 'lr': SPK_LEARNING_RATE},
    {'params': listener.parameters(), 'lr': LIS_LEARNING_RATE},
])


# print configuration of the setup
# Experiment configuration, echoed as indented JSON so each run's log
# records the settings it was trained with.
setting_details = dict(
    N_AGENTS=N_AGENTS,
    NUM_ATTRS=NUM_ATTRS,
    MSG_LEN=MSG_LEN,
    N_SECTORS=N_SECTORS,
    N_SEGMENTS=N_SEGMENTS,
    N_COLORS=N_COLORS,
    N_CONCEPTS=N_CONCEPTS,
    N_VOCAB=N_VOCAB,
    N_VERTEX=N_VERTEX,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    msg_mode=MSG_MODE,
)

print(json.dumps(setting_details, indent=4))




Y = []         # epochs at which metrics were logged
LOSS = []      # joint loss at those epochs (plain Python floats)
ACCURACY = []  # batch accuracy at those epochs, as an integer percentage
# START THE TRAINING

for epoch in range(NUM_EPOCHS):
    # Sample (source, target) vertex pairs, look up their concepts and build
    # the speaker's input features.
    vertex_batch = data.getBatch(BATCH_SIZE)
    batchData = world.getconcepts(vertex_batch)
    graphBatch = batchfeature(vertex_batch)

    # The speaker picks the order in which to communicate the concepts,
    # then utters a message for that order; the listener predicts the
    # concepts from the message.
    comm_order = speaker.chooseConceptOrder(graphBatch)
    message, spk_log_probs, logProbs = speaker.speak(comm_order, batchData)
    pred_concept, lis_log_probs, ct2 = listener.listen_and_predict(message, batchData)

    rewards1, accuracy = getBatchReward(world, pred_concept, vertex_batch, batchData)
    rewards, acc_1, acc_2, acc_3 = perConceptAccuracy(batchData, pred_concept)

    # Backpropagation
    optimizer.zero_grad()

    # Joint policy-gradient loss for both agents. The weight
    # (reward - log p_joint - 1) is detached, so gradients flow only through
    # log p_joint = lis_log_probs + spk_log_probs. The "- log p - 1" part is
    # the score-function gradient of an entropy bonus on the joint policy.
    lis_loss = -(((rewards.detach() - lis_log_probs.detach() - spk_log_probs.detach() - 1) * (lis_log_probs + spk_log_probs)).mean())
    lis_loss.backward()

    # Both agents receive gradients from the joint loss and share one
    # optimiser, so clip both; clipping only the listener left the speaker's
    # gradients unbounded.
    nn.utils.clip_grad_norm_(speaker.parameters(), 50.0)
    nn.utils.clip_grad_norm_(listener.parameters(), 50.0)

    # update
    optimizer.step()

    if epoch % 50 == 0:
        print(f'rewards1 = {rewards1}')
        print(f'rewards = {rewards}')

        print(f'epoch:{[epoch]}, lis_loss:{lis_loss:.4f}, accuracy:{accuracy/BATCH_SIZE}, first: {acc_1/BATCH_SIZE}, first_two :{acc_2/BATCH_SIZE}, first_3 = {acc_3/BATCH_SIZE}')
        Y.append(epoch)
        # Store a plain float: np.save on a list of tensors depends on
        # implicit tensor->array conversion and fails for CUDA tensors.
        LOSS.append(lis_loss.item())
        ACCURACY.append(int((accuracy / BATCH_SIZE) * 100))
        np.save('Y.npy', Y)
        np.save('LOSS.npy', LOSS)
        np.save('ACCURACY.npy', ACCURACY)

        # Dump both agents' raw parameters for debugging.
        for p in speaker.parameters():
            print(p.data)

        print('*'*100)
        for p in listener.parameters():
            print(p.data)
    






























