import numpy as np


class OccupancyMeasure:
    """Per-agent occupancy arrays built from visit counts.

    Two lists of numpy arrays, one array per agent, are maintained:

    * ``local_lambdas`` -- counts gathered since the last update.
    * ``lambdas`` -- the running sum of normalised ``local_lambdas``.
    """

    def __init__(self, shape, agent_num):
        """Allocate zero-filled arrays of ``shape`` for ``agent_num`` agents."""
        self.agent_num = agent_num
        self.lambdas = [np.zeros(shape) for _ in range(agent_num)]
        self.local_lambdas = [np.zeros(shape) for _ in range(agent_num)]

    def update_lambdas(self):
        """Normalise each agent's local counts in place and add them to ``lambdas``.

        An agent whose local counts sum to zero is skipped: dividing by that
        sum would turn both arrays into NaN and poison every later update.
        """
        for aid in range(self.agent_num):
            total = self.local_lambdas[aid].sum()
            if total == 0:
                continue
            self.local_lambdas[aid] /= total
            self.lambdas[aid] += self.local_lambdas[aid]

    def normalize(self):
        """Scale each agent's ``lambdas`` array in place so that it sums to one.

        Arrays that sum to zero are left untouched instead of becoming NaN.
        """
        for aid in range(self.agent_num):
            total = self.lambdas[aid].sum()
            if total == 0:
                continue
            self.lambdas[aid] /= total

    def reset(self):
        """Zero every array in place, keeping the existing array objects."""
        for aid in range(self.agent_num):
            # fill() rather than ``*= 0``: multiplying NaN or inf by zero
            # yields NaN, so a corrupted array would never be cleared.
            self.lambdas[aid].fill(0)
            self.local_lambdas[aid].fill(0)


class StateOccupancyMeasure(OccupancyMeasure):
    """Occupancy measure indexed directly by a two-component state ``(i, j)``."""

    def __init__(self, state_shape, agent_num):
        """Create one ``state_shape`` array pair per agent."""
        super().__init__(state_shape, agent_num)

    def count_cur_state(self, state):
        """Add one visit to each agent's local count at its current state.

        ``state[aid]`` must unpack into exactly two index values.
        """
        for agent in range(self.agent_num):
            row, col = state[agent]
            counts = self.local_lambdas[agent]
            counts[row, col] += 1

    def get_prob(self, state):
        """Return a list holding, per agent, the ``lambdas`` entry at its state."""
        probs = []
        for agent in range(self.agent_num):
            row, col = state[agent]
            probs.append(self.lambdas[agent][row, col])
        return probs
    

class ContStateOccupancyMeasure(OccupancyMeasure):
    """Occupancy measure over two continuous state components on a grid.

    Components 2 and 3 of each agent's state (``x`` and ``y``) are binned
    over ``[-s, s]`` into ``bins`` cells per axis. Values outside that range
    are clamped to the nearest edge cell.
    """

    def __init__(self, s, bins, agent_num):
        """Set the ``[-s, s]`` bounds and allocate ``(bins, bins)`` arrays."""
        self.min_x = self.min_y = -s
        self.max_x = self.max_y = s
        super().__init__((bins, bins), agent_num)

    @staticmethod
    def _to_bin(value, low, high, n_bins):
        """Map ``value`` to a bin index clamped to ``[0, n_bins - 1]``."""
        idx = int((value - low) / (high - low) * n_bins)
        return max(0, min(idx, n_bins - 1))

    def _cell(self, aid, state):
        """Return the ``(i, j)`` grid cell for agent ``aid``'s current state."""
        x, y = state[aid][2], state[aid][3]
        shape = self.lambdas[aid].shape
        i = self._to_bin(x, self.min_x, self.max_x, shape[0])
        j = self._to_bin(y, self.min_y, self.max_y, shape[1])
        return i, j

    def count_cur_state(self, state):
        """Add one visit to each agent's local count at its current cell."""
        for aid in range(self.agent_num):
            i, j = self._cell(aid, state)
            self.local_lambdas[aid][i, j] += 1

    def get_prob(self, state):
        """Return a list holding, per agent, the ``lambdas`` entry at its cell."""
        prob = [0] * self.agent_num
        for aid in range(self.agent_num):
            i, j = self._cell(aid, state)
            prob[aid] = self.lambdas[aid][i, j]
        return prob
    

class ContStateOccupancyMeasure4d(OccupancyMeasure):
    """Occupancy measure over four continuous state components on a grid.

    Components 0 and 1 of each agent's state (``u`` and ``v``) are binned
    over ``[-vs, vs]``; components 2 and 3 (``x`` and ``y``) are binned over
    ``[-s, s]``. Every axis has ``bins`` cells, and values outside the range
    are clamped to the nearest edge cell.
    """

    def __init__(self, vs, s, bins, agent_num):
        """Set both pairs of bounds and allocate 4-D ``bins``-per-axis arrays."""
        self.min_u = self.min_v = -vs
        self.max_u = self.max_v = vs
        self.min_x = self.min_y = -s
        self.max_x = self.max_y = s
        super().__init__((bins, bins, bins, bins), agent_num)

    @staticmethod
    def _to_bin(value, low, high, n_bins):
        """Map ``value`` to a bin index clamped to ``[0, n_bins - 1]``."""
        idx = int((value - low) / (high - low) * n_bins)
        return max(0, min(idx, n_bins - 1))

    def _cell(self, aid, state):
        """Return the ``(p, q, i, j)`` grid cell for agent ``aid``'s state."""
        shape = self.lambdas[aid].shape
        u, v = state[aid][0], state[aid][1]
        p = self._to_bin(u, self.min_u, self.max_u, shape[0])
        q = self._to_bin(v, self.min_v, self.max_v, shape[1])
        x, y = state[aid][2], state[aid][3]
        i = self._to_bin(x, self.min_x, self.max_x, shape[2])
        j = self._to_bin(y, self.min_y, self.max_y, shape[3])
        return p, q, i, j

    def count_cur_state(self, state):
        """Add one visit to each agent's local count at its current cell."""
        for aid in range(self.agent_num):
            p, q, i, j = self._cell(aid, state)
            self.local_lambdas[aid][p, q, i, j] += 1

    def get_prob(self, state):
        """Return a list holding, per agent, the ``lambdas`` entry at its cell."""
        prob = [0] * self.agent_num
        for aid in range(self.agent_num):
            p, q, i, j = self._cell(aid, state)
            prob[aid] = self.lambdas[aid][p, q, i, j]
        return prob


class StateActionOccupancyMeasure(OccupancyMeasure):
    """Occupancy measure over (action, state) pairs.

    Each agent's arrays are indexed as ``[action, i, j]``, where ``(i, j)``
    is the agent's two-component state.
    """

    def __init__(self, state_shape, action_dim, agent_num):
        """Allocate arrays of shape ``(action_dim, *state_shape)`` per agent.

        ``state_shape`` may be any iterable of ints (list, tuple, ...).
        """
        # Unpack instead of ``[action_dim] + state_shape`` so that a tuple
        # shape does not raise TypeError on list concatenation.
        shape = (action_dim, *state_shape)
        super().__init__(shape, agent_num)

    def count_cur_state(self, state, action):
        """Add one visit to each agent's local count at its (action, state)."""
        for aid in range(self.agent_num):
            i, j = state[aid]
            self.local_lambdas[aid][action[aid], i, j] += 1

    def get_prob(self, state, action):
        """Return a list holding, per agent, the ``lambdas`` entry at its (action, state)."""
        prob = [0] * self.agent_num
        for aid in range(self.agent_num):
            i, j = state[aid]
            prob[aid] = self.lambdas[aid][action[aid], i, j]
        return prob