#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import numpy as np
import math
import copy
import time
import cvxopt
import matplotlib.patches as patch
import matplotlib.pyplot as plt
# Record the initial state of the agent
class Agent(object):
    def __init__(self, state, actions, time_step,neighbors=None):
        # np.array([x_1,x_2])
        self.state = np.array(state)
        # the trajectory for agent
        self.traj = [np.array(state)]       
        # the time frequency for update e.g. 0.05s
        self.time_step = time_step          
    
        # Action Space 
        self.actions = actions             #[[0,1],[0,2],[1,0],[2,0]]
        self.n_actions = len(self.actions)
        self.action_indices = np.array(range(self.n_actions))
        self.next_action_index = 0
        # Neighbors 
        self.neighbors = neighbors

    def motion_model(self, state, action):
        '''
        :param state: The current state at time t.
        :param action: The current action at time t.
        :return: The resulting state x_{t+1}
        '''
        return state+action*self.time_step
    
    def apply_next_action(self):
        """
        Applies the next action to modify the agent state.
        :param t: Time step.
        :return: None
        """
        self.state = self.motion_model(self.state, self.actions[self.next_action_index])

    def Initialize(self,N=16):
        """
        Intialize the action_prob_dist
        :param N: the number of agents
        :return: None
        """
        self.action_prob_dist=(1/self.n_actions)*np.ones((N,self.n_actions))
        self.pre_prob_dist=(1/self.n_actions)*np.ones((N,self.n_actions))
        

    
    def OSG_intialize(self,n_time_step=2000):
        self.J = math.ceil(math.log2(n_time_step)) # number of experts
        self.g = np.sqrt(math.log(self.J) / n_time_step) 
        self.beta = 1 / n_time_step
        self.gamma = [np.sqrt(math.log(self.n_actions * n_time_step) / 2 ** (j - 1)) for j in range(1, self.J + 1)]   # J x 1
        self.expert_weight =np.array([1.0 / self.J for i in range(self.J)])                                  # weights of experts        
        self.action_weight = np.array([[1.0 / self.n_actions for k in range(self.n_actions)] for i in range(self.J)]) # weights of actions for all experts        
        self.action_prob_dist = None
        self.loss =[0.0 for i in range(self.n_actions)]
        
    
    
    def get_losses(self, agents_considered, targets):
        """
        Returns the losses of all possible actions based on the estimation result of just executed actions
        :return: The losses of all possible actions.
        """
        #agents_considered is the list of locations/states of agents considered
    
        losses = np.zeros(self.n_actions)
        # the loss f(A\cup\{a\})-f(A)
        if len(agents_considered) == 0:
            # f(empty set) = 0 before any agent's moving
            # f(a): distance with action a without any other agent's action
            obj_action = []
            for i in range(self.n_actions):
                action = self.actions[i] # self.action_indices 需要用 np.array()
                state = self.motion_model(self.state, action)
                obj_action.append(sum([Reward_func(np.linalg.norm(tar.state - state)) for tar in targets]))
            losses =np.array(obj_action)
        else:
            tar_min=[max([Reward_func(np.linalg.norm(tar.state - ag)) for ag in agents_considered]) for tar in targets]
            curr_obj = sum(tar_min)
            length=len(targets)
            for i in range(self.n_actions):
                action = self.actions[i]
                state = self.motion_model(self.state, action)
                # f(A_{i-1} U a)
                temp_obj = 0
                for j in range(length):
                    temp_obj += max([Reward_func(np.linalg.norm(targets[j].state - state)), tar_min[j]])
        

                losses[i] = temp_obj-curr_obj #obj_action[i]
                #if abs(losses[i])<0.0001:
                    #print(losses[i])
                    #print(np.array([1/(np.linalg.norm(targets[j].state - state)+ 0.001) for j in range(length)])<=np.array(tar_min))
        return losses

    def update_experts(self):
        """
        Updates the parameters of experts after getting losses (from t to t + 1)
        :param t: The index of time step
        :return: None
        """
        for j in range(self.J):
            Index_Set=[self.gamma[j] * self.loss[i] for i in range(self.n_actions)]
            MAX=max(Index_Set)
            if MAX>700:
                Index_Set=[Index_Set[i]*700/MAX  for i in range(self.n_actions)]
            v = [self.action_weight[j][i] * np.exp(Index_Set[i]) for i in range(self.n_actions)]
            #################
            update1=np.dot(np.array(self.loss), np.array(self.action_weight[j]))
            self.action_weight[j]= [self.beta * np.sum(v) / self.n_actions + (1 - self.beta) * v[i] for i in range(self.n_actions)]
            self.expert_weight[j] = self.expert_weight[j] * np.exp(self.g *update1)
            

        self.action_weight= np.array([self.action_weight[j] / np.linalg.norm(self.action_weight[j], ord=1) for j in range(self.J)],dtype=np.float128)
        self.expert_weight= self.expert_weight/ np.linalg.norm(np.array(self.expert_weight), ord=1)
    
    def get_action_prob_dist(self):
        """
        Returns the output of FSF* (the predicted action probability distribution)
        :param t: The index of time step.
        :return: None.
        """
        q = np.array(self.expert_weight) # J x 1
        p = np.array(self.action_weight) # J x m
        self.action_prob_dist=np.dot(q, p).tolist() # m x 1, note that self.action_prob_dist is a list

