# Script for building the spatial conceptualisation
from utils.library import *

import random
import string
from utils.conf import *
import pickle

class World:
    '''
        Generates all the hidden details for the environment 
    '''
    def __init__(self) -> None:
        # Radii of the concentric circles around a source vertex; distances
        # are bucketed against them to obtain segment concepts.
        self.radii = RADII
        # generate_feature reads self.locations, so generate_vertex must run
        # before it; the steps are executed strictly in this order.
        build_steps = (
            self.generate_vertex,
            self.generate_colors,
            self.generate_vocabulary,
            self.generate_concepts,
            self.generate_feature,
        )
        for build_step in build_steps:
            build_step()
        
    
    def generate_vertex(self):
        '''
        Sample random 2-D vertex locations and persist them.

        torch's RNG is first reseeded with a value drawn from Python's
        ``random`` module, so seeding ``random`` also fixes the sampled
        locations. The locations form an (N_VERTEX, 2) tensor uniformly
        spread over [MIN_DIMENSION, MAX_DIMENSION) per coordinate; it is
        stored on self.locations and pickled to the file 'WORLD_LOC' in the
        current working directory.
        '''
        seed = random.randint(0, 1000)
        torch.manual_seed(seed)

        span = MAX_DIMENSION - MIN_DIMENSION
        self.locations = span * torch.rand(N_VERTEX, 2) + MIN_DIMENSION

        with open('WORLD_LOC', 'wb') as out_file:
            pickle.dump(self.locations, out_file)
    
    
    def generate_colors(self):
        '''
        Assign a color id to every vertex and persist the assignment.

        Colors are dealt from the pool range(N_COLORS) without replacement
        using np.random.choice; when the pool runs dry it is refilled, so no
        color repeats within a round of N_COLORS vertices. Each drawn color
        is shifted by N_SEGMENTS + N_SECTORS + 1 before being stored in
        self.vertex_colors (an integer numpy array of length N_VERTEX). The
        array is printed and pickled to the file 'WORLD_COLORS' in the
        current working directory.

        NOTE(review): these ids use a different offset than self.colors built
        in generate_concepts (201 and up) -- confirm which encoding the
        consumers of vertex_colors expect.
        '''
        self.vertex_colors = np.array([-1] * N_VERTEX)
        offset = N_SEGMENTS + N_SECTORS + 1

        pool = []
        for vertex in range(N_VERTEX):
            # refill lazily: the pool is empty at the start or after a full round
            if not pool:
                pool = list(range(N_COLORS))
            picked = np.random.choice(pool)
            pool.remove(picked)
            self.vertex_colors[vertex] = picked + offset

        print(self.vertex_colors)
        with open('WORLD_COLORS', 'wb') as out_file:
            pickle.dump(self.vertex_colors, out_file)


    def generate_vocabulary(self):
        '''
        Returns vocabulary size equal to concept space size

        Arguments:
        ------------
                   n_concepts: number of concepts 
        Returns:
        ------------
                List containing vocabularies 
        '''
        self.vocabularies = []
        while(len(self.vocabularies)<N_CONCEPTS):
            res = ''.join(random.choices(string.ascii_lowercase,k=3))
            if res not in self.vocabularies:
                self.vocabularies.append(res)


    def generate_concepts(self):
        '''
        Build the concept id lists for sectors, segments and colors.

        Each family gets its own id range: sectors are 1..N_SECTORS,
        segments 101..100+N_SEGMENTS and colors 201..200+N_COLORS, so the
        ranges stay disjoint as long as there are at most 100 sectors and at
        most 100 segments. self.all_concepts concatenates the three lists in
        that order.
        '''
        self.sectors = list(range(1, N_SECTORS + 1))
        self.segments = list(range(101, N_SEGMENTS + 101))
        self.colors = list(range(201, N_COLORS + 201))
        self.all_concepts = [*self.sectors, *self.segments, *self.colors]
    
    def getconcepts2(self, target_index: int, source_index: int) -> tuple:
        '''
        Gives the concepts of the corresponding source and target location

        Arguements:
        ----------------
                    target_loc : target location for the agent
                    source_loc : source location of the agent
        
        Returns:
        ----------------
                octant, segment ,quadrant
        '''
        target_loc = self.locations[target_index]
        source_loc = self.locations[source_index]
        co1 = target_loc[0] - source_loc[0]
        co2 = target_loc[1] - source_loc[1]
        point1, point2 = source_loc, target_loc
        length = torch.sqrt(torch.square(point1[0] - point2[0])+torch.square(point1[1] - point2[1]))
        angle = torch.rad2deg(torch.acos((point2[0]-point1[0])/length))
        if point1[1]>point2[1]:
            angle = 360 - angle
    
        octant = 0
        segment = 0
        quadrant = 0
        if angle>=0 and angle<=45: quadrant,octant = 1,1
        
        elif angle>45 and angle<=90: quadrant,octant = 1,2

        elif angle>90 and angle<=135: quadrant,octant = 2,3

        elif angle>135 and angle<=180: quadrant,octant = 2,4

        elif angle>180 and angle<=225: quadrant,octant =3,5
            
        elif angle>225 and angle<=270: quadrant,octant = 3,6

        elif angle>270 and angle<=315: quadrant,octant =4,7
        
        elif angle>315: quadrant,octant = 4,8
        
        
        # finding the circle
        # c_x,c_y =source_loc[0], source_loc[1] #coordinates of the origin
        distance = torch.sqrt(torch.square(co1) + torch.square(co2))
        
        # radiuses = torch.linspace(0, 20, steps=constants.n_segments)
        # radiuses = torch.load('data/radiuses.pt')

        for i, s in enumerate(self.radii):
            if distance<=s.item():
                segment = i
                break

        # color of target location
        color = self.vertex_colors[target_index]
        return quadrant, segment, quadrant, color
        

    def get_vocab_tensor(self, vocab):
        '''
            One-hot encode ``vocab`` over self.vocabularies.

            Returns a torch.zeros tensor (default dtype) of length
            len(self.vocabularies) with a 1 at the word's first position.
            Raises ValueError when the word is not in the vocabulary.
        '''
        position = self.vocabularies.index(vocab)
        one_hot = torch.zeros(len(self.vocabularies))
        one_hot[position] = 1
        return one_hot
    
    def get_concept_tensor(self, concept):
        '''
            One-hot encode ``concept`` over the full concept space.

            Returns a torch.zeros tensor (default dtype) of length
            len(self.all_concepts) with a 1 at the concept's first position
            in self.all_concepts. Raises ValueError when the concept is
            unknown.
        '''
        slot = self.all_concepts.index(concept)
        encoded = torch.zeros(len(self.all_concepts))
        encoded[slot] = 1
        return encoded
    def get_concept_enc(self, concept, concept_type):
        '''
            One-hot encode ``concept`` within its own concept family.

            ``concept_type`` selects the family: 'sector' -> self.sectors,
            'segment' -> self.segments, 'color' -> self.colors. Returns a
            plain list of ints with a single 1 at the concept's first
            position in that family. Any other ``concept_type`` yields None;
            a concept missing from its family raises ValueError.
        '''
        families = (
            ('sector', 'sectors'),
            ('segment', 'segments'),
            ('color', 'colors'),
        )
        for type_name, attr_name in families:
            if concept_type == type_name:
                members = getattr(self, attr_name)
                position = members.index(concept)
                return [1 if k == position else 0 for k in range(len(members))]
        return None
    
    def _getconcepts(self,targ_idx, src_idx):
        '''
            Gives the spatial concepts of the target vertex taking source vertex as the origin
            
            args:
            -----  
                targ_idx : index of target vertex
                src_idx : index of the source vertex

            Return:
            ------
                    [segment,sector,color]
        
        '''
        target_loc = self.locations[targ_idx]
        source_loc = self.locations[src_idx]
        co1 = target_loc[0] - source_loc[0]
        co2 = target_loc[1] - source_loc[1]
        point1, point2 = source_loc, target_loc
        length = torch.sqrt(torch.square(point1[0] - point2[0])+torch.square(point1[1] - point2[1]))
        angle = torch.rad2deg(torch.acos((point2[0]-point1[0])/length))
        if point1[1]>point2[1]:
            angle = 360 - angle

        sector = 0
        l_sector = []
        segment = N_SEGMENTS
        color = 0
        if angle>=0 and angle<180: l_sector = [4]     
        elif angle>=180 and angle<=360: l_sector = [5]
        #elif angle>240 and angle<=360: l_sector = [6]

        if angle>=0 and angle<90: l_sector.append(6)     
        elif angle>=90 and angle<180: l_sector.append(7)

         # finding the circle
        # c_x,c_y =source_loc[0], source_loc[1] #coordinates of the origin
        
        for i, s in enumerate(self.radii):
            if length <= s.item():
                segment = i
                break
    
        sector = np.random.choice(l_sector)
        l_sector.remove(sector)
        if(len(l_sector) != 0):
            sector2 = np.random.choice(l_sector)
        else:
            sector2 = sector
        # color of target location
        color = self.vertex_colors[targ_idx]
        return [segment, sector, color], [segment, sector2, color]
    
    def getconcepts(self,vertex_pairs):
        # now for each pair get 
        concepts1 = []
        concepts2 = []
        for src_idx, targ_idx in vertex_pairs:
            p1, p2 = self._getconcepts(targ_idx, src_idx)
            concepts1.append(p1)
            concepts2.append(p2)

        return torch.LongTensor(concepts1), torch.LongTensor(concepts2)
    
    def getCheckData(self):
        Xdata = [[1,4,8],
                [2,5,9],
                [3,6,10],
                [0,7,11]]
        return torch.LongTensor(Xdata)

    def getColor(self, idx):
        '''Return the color id stored for vertex ``idx`` in self.vertex_colors.'''
        colors = self.vertex_colors
        return colors[idx]
        
    def getvertexInregion(self, targ_sec,targ_seg,targ_col, src_idx,targ_idx):
        '''
            Score a predicted region against the true target vertex.

            Every vertex other than the source is placed relative to the
            source with _getconcepts. A vertex lies in the predicted region
            when its segment equals ``targ_seg`` and ``targ_sec`` is one of
            its sectors.

            args:
            -----
                targ_sec, targ_seg, targ_col : predicted sector, segment, color
                src_idx : index of the source vertex (the origin)
                targ_idx : index of the true target vertex

            Return:
            ------
                    -1  if the target vertex is not in the region
                    100 if it is the only vertex in the region
                    50  if it is in the region and more than one vertex in
                        the region has color ``targ_col``
                    0   otherwise
        '''
        region = []
        region_cols = []
        for v_idx in range(N_VERTEX):
            if v_idx == src_idx:
                continue
            # _getconcepts takes (target, source) and returns two lists that
            # differ only in the sector; the original call swapped the
            # arguments and unpacked the pair into three names, which raised.
            (v_seg, v_sec, v_col), (_, v_sec2, _) = self._getconcepts(v_idx, src_idx)
            # Vertices above the source belong to two sectors (4 plus 6 or 7);
            # testing membership in both keeps this check independent of
            # which one _getconcepts drew at random.
            if v_seg == targ_seg and targ_sec in (v_sec, v_sec2):
                region.append(v_idx)
                region_cols.append(v_col)

        if targ_idx not in region:
            return -1
        if len(region) == 1:
            return 100
        # NOTE(review): the original comment said partial reward is for a
        # target color that is *unique* in the region, but this condition
        # rewards a color shared by several vertices -- confirm intent.
        if region_cols.count(targ_col) > 1:
            return 50
        return 0
        

    def generate_feature(self):
        '''
            Build a geometric feature vector for every vertex.

            self.feat maps each vertex index i to a flat list: first the
            Euclidean distances from i to every other vertex j (increasing
            j, skipping i), then the angles in degrees of the direction
            i -> j, in the same order. The angles come straight from
            torch.acos, so they lie in [0, 180].

            NOTE(review): unlike _getconcepts, the angle is not reflected
            to [0, 360] for targets below the origin, and coincident
            vertices give a NaN angle -- confirm this is intended.
        '''
        self.feat = {}
        for i in range(N_VERTEX):
            origin = self.locations[i]
            distances, angles = [], []
            for j in range(N_VERTEX):
                if j == i:
                    continue
                other = self.locations[j]
                dist = torch.sqrt(
                    torch.square(origin[0] - other[0])
                    + torch.square(origin[1] - other[1])
                )
                angle = torch.rad2deg(torch.acos((other[0] - origin[0]) / dist))
                distances.append(dist.tolist())
                angles.append(angle.tolist())
            self.feat[i] = distances + angles