
from utils.conf import *
from collections import defaultdict


# Module-level cache filled lazily by getReward.
# Layout: vv_list[source_vertex_index][(seg, sec)] -> list of vertex indices.
# A 0 in either position of the (seg, sec) key is used by getReward as a
# wildcard entry, e.g. (seg, 0) collects the vertices stored for that seg.
vv_list = defaultdict(dict)


def _add_region_vertex(regions, src, key, v_idx):
    """Record ``v_idx`` under ``key`` in ``regions`` without touching other keys.

    Creates the list when ``key`` is new, appends when ``v_idx`` is not listed
    yet, and does nothing otherwise.  ``src`` is only used in the trace line
    printed whenever the cache changes.
    """
    members = regions.get(key)
    if members is None:
        regions[key] = [v_idx]
        print(src, key[0], key[1], "##> ", regions[key])
    elif v_idx not in members:
        members.append(v_idx)
        print(src, key[0], key[1], "--> ", members)


def getReward(world, pred_concept, vertex_pair, org_concept, msg, nl_msg):
    """Score one listener prediction against the true target vertex.

    Args:
        world: object providing ``_getconcepts(targ_idx=..., src_idx=...)``
            (must return two ``(seg, sec, color)`` triples) and
            ``getColor(vertex_index)``.
        pred_concept: ``(seg, sec, color)`` predicted by the listener; each
            element must support ``.item()``.
        vertex_pair: ``(source_index, target_index)``; each element must
            support ``.item()``.
        org_concept: the true ``(seg, sec, color)``; each element must
            support ``.item()``.
        msg: three symbols ``(w1, w2, w3)`` supporting ``.item()``.
        nl_msg: three symbols ``(nw1, nw2, nw3)`` supporting ``.item()``.
            NOTE(review): presumably a reference/"null" message that the real
            message must differ from -- confirm with the caller.

    Returns:
        ``(reward, c_flag)`` where ``c_flag`` is ``1.0`` when the prediction
        counts as correct and ``0.0`` otherwise, and ``reward`` is

        * ``100``  -- target lies in the predicted region and has the
          predicted colour (``c_flag`` is 1),
        * ``-10``  -- target lies in the predicted region but the colour does
          not match (``c_flag`` is 1 only if the region holds one vertex),
        * ``-15``  -- target is not in the predicted region,
        * ``-1200`` -- override: some true concept component is non-zero
          while the matching ``msg`` symbol equals the ``nl_msg`` symbol,
        * ``-2000`` -- override: all three true concept components are zero.

    Side effects:
        Fills the module-level ``vv_list`` cache and prints a trace line each
        time the cache gains an entry.
    """
    targ_seg, targ_sec, targ_color = pred_concept
    org_seg, org_sec, org_color = org_concept
    src_idx2, targ_idx2 = vertex_pair
    w1, w2, w3 = msg
    nw1, nw2, nw3 = nl_msg

    src = src_idx2.item()
    region_key = (targ_seg.item(), targ_sec.item())

    # A (0, 0) prediction names no region at all, so no vertex can match it.
    vertices = []
    if region_key != (0, 0):
        regions = vv_list[src]
        if region_key not in regions:
            # Cache miss: (re)scan every other vertex as seen from the source.
            # Each key is updated independently so that registering a new
            # (seg, sec) pair never discards vertices already stored under
            # the (seg, 0) / (0, sec) wildcard keys.
            for v_idx in range(N_VERTEX):
                if v_idx == src:
                    continue
                p1, p2 = world._getconcepts(targ_idx=v_idx, src_idx=src_idx2)
                for v_seg, v_sec, _v_col in (p1, p2):
                    for key in ((v_seg, v_sec), (v_seg, 0), (0, v_sec)):
                        _add_region_vertex(regions, src, key, v_idx)
        # Still absent after the scan means the predicted region is empty.
        vertices = regions.get(region_key, [])

    # Vertices of the predicted region that also carry the predicted colour.
    wanted_color = targ_color.item()
    vertices_cols = [vi for vi in vertices if world.getColor(vi) == wanted_color]

    target = targ_idx2.item()
    c_flag = 0.0
    if target in vertices_cols:
        reward = 100
        c_flag += 1
    elif target in vertices:
        reward = -10
        # Right region, wrong colour: still counted as correct when the
        # region contains only the target itself.
        if len(vertices) == 1:
            c_flag += 1
    else:
        reward = -15

    # Penalties below replace the reward computed above; c_flag is untouched.
    if org_seg.item() != 0 and w1.item() == nw1.item():
        reward = -1200

    if org_sec.item() != 0 and w2.item() == nw2.item():
        reward = -1200

    if org_color.item() != 0 and w3.item() == nw3.item():
        reward = -1200

    if org_sec.item() == 0 and org_seg.item() == 0 and org_color.item() == 0:
        reward = -2000

    return reward, c_flag


def getBatchReward(world, pred_batch, vertex_batch, org_concepts, msg, nl_msg):
    """Evaluate ``getReward`` for every sample index in ``[ZERO_DATA, BATCH_SIZE)``.

    Args:
        world: forwarded unchanged to ``getReward``.
        pred_batch, vertex_batch, org_concepts, msg, nl_msg: per-sample
            sequences; element ``b`` of each is passed to ``getReward``.

    Returns:
        ``(reward, accuracy)`` where ``reward`` is a long tensor of length
        ``BATCH_SIZE`` holding the per-sample rewards (slots below
        ``ZERO_DATA`` are left at 0) and ``accuracy`` is the sum of the
        ``c_flag`` values returned by ``getReward``.
    """
    # Zero-initialise: the legacy ``torch.LongTensor(n)`` constructor returns
    # uninitialised memory, which would leak garbage into skipped slots.
    reward = torch.zeros(BATCH_SIZE, dtype=torch.long)
    accuracy = 0
    for b in range(ZERO_DATA, BATCH_SIZE):
        reward[b], c_flag = getReward(
            world, pred_batch[b], vertex_batch[b], org_concepts[b], msg[b], nl_msg[b]
        )
        accuracy += c_flag
    return reward, accuracy


def perConceptAccuracy(batchData, concept):
    """Count cumulative concept matches over the first ``BATCH_SIZE`` samples.

    For each sample index the entries of ``batchData`` and ``concept`` are
    compared position by position, stopping at the first mismatch.

    Returns:
        A 3-tuple ``(acc_1, acc_2, acc_3)``: the number of samples whose
        position 0 matches, whose positions 0 and 1 both match, and whose
        positions 0, 1 and 2 all match.
    """
    first_hits = 0
    first_two_hits = 0
    all_three_hits = 0

    for idx in range(BATCH_SIZE):
        predicted = batchData[idx]
        expected = concept[idx]

        # Each deeper level only counts when every shallower level matched.
        if predicted[0] == expected[0]:
            first_hits += 1
            if predicted[1] == expected[1]:
                first_two_hits += 1
                if predicted[2] == expected[2]:
                    all_three_hits += 1

    return first_hits, first_two_hits, all_three_hits