import torch
import numpy as np
import time
import utils
from algos.model import ACModel
from algos.ppo import PPO
from algos.occupancy_measure import StateOccupancyMeasure, ContStateOccupancyMeasure, ContStateOccupancyMeasure4d


class PPOwPrior(PPO):
    """PPO variant whose rollouts are augmented with a prior-based "shadow" reward.

    Before every rollout the agent's state-occupancy statistics are refreshed
    (``compute_lambda``).  During the rollout each transition receives an extra
    per-agent term (``compute_shadow_r``) that compares the current occupancy
    probability with the probability stored in ``prior`` for the visited
    state/action.

    NOTE(review): ``self.args``, ``self.env``, ``self.agent_num``,
    ``self.target_steps``, ``self.use_state_norm`` and ``self.select_action``
    are not defined in this class; they are presumably provided by ``PPO``.
    """

    def __init__(self, env, args, target_steps=2048, prior=None):
        """Set up prior-related hyper-parameters and the occupancy-measure tracker.

        Args:
            env: Environment; must expose ``agent_num`` and, for non-MPE
                environments, ``grid.shape``.
            args: Configuration object; ``N``, ``pweight``, ``pdecay``,
                ``pdecay_interval`` are read here, plus ``env.env_name`` and
                (for MPE) ``mpe_demonstration`` via ``self.args``.
            target_steps: Forwarded unchanged to ``PPO.__init__``.
            prior: Per-agent prior tables, indexed as ``prior[aid][action, ...]``
                by ``get_prior_prob``.
        """
        super().__init__(env, args, target_steps)

        self.N = args.N  # number of episodes sampled per compute_lambda() call
        self.prior = prior
        self.pweight = args.pweight
        self.pdecay = args.pdecay
        self.pweight_update_count = 0
        self.pdecay_interval = args.pdecay_interval
        self.use_shadow_reward = True

        if self.args.env.env_name == "mpe":
            # Half-width of the position range and number of bins per axis
            # used to discretise continuous MPE states.
            self.s = 1.5
            self.bins = 20
            if "4d" in self.args.mpe_demonstration:
                # Half-width of the range for the first two state components
                # (presumably velocity -- TODO confirm).
                self.vs = 3
                self.occupancy_measures = ContStateOccupancyMeasure4d(
                    self.vs, self.s, self.bins, env.agent_num
                )
            else:
                # Use self.bins rather than repeating the literal 20.
                self.occupancy_measures = ContStateOccupancyMeasure(
                    self.s, self.bins, env.agent_num
                )
        else:
            self.occupancy_measures = StateOccupancyMeasure(env.grid.shape, env.agent_num)

    def collect_experiences(self, buffer, tb_writer=None):
        """Roll out the current policy until at least ``target_steps`` transitions are stored.

        Occupancy statistics are refreshed first via ``compute_lambda``.  Only
        whole episodes are collected, so the returned count may exceed
        ``self.target_steps``.

        Args:
            buffer: Rollout buffer; ``empty_buffer_before_explore``, ``append``
                and (when ``self.use_state_norm`` is set) ``update_rms`` are
                called on it.
            tb_writer: Optional logger; when truthy, ``add_info`` is called
                once per finished episode.

        Returns:
            Number of environment steps collected.
        """
        self.compute_lambda()

        buffer.empty_buffer_before_explore()
        steps = 0
        # NOTE(review): sized 2 * agent_num, which suggests ``reward`` below is
        # a list of env rewards concatenated with the shadow rewards -- confirm
        # against the environment's reward type.
        ep_returns = np.zeros(self.agent_num * 2)

        stime = time.time()
        while steps < self.target_steps:
            state = self.env.reset()
            done = False
            ep_steps = 0
            ep_returns *= 0  # reset the per-episode accumulator in place
            while not done:
                # self.env.render()
                action = self.select_action(state["vec"], state.get("mask"))
                next_state, reward, done, info = self.env.step(action)
                if self.use_shadow_reward:
                    shadow_reward = self.compute_shadow_r(state["vec"], action)
                    reward = reward + shadow_reward

                buffer.append(state["vec"], action, next_state["vec"], reward, done, state.get("mask"))

                ep_returns += reward
                state = next_state
                steps += 1
                ep_steps += 1

            if tb_writer:
                tb_writer.add_info(ep_steps, ep_returns, self.pweight)

        elapsed = time.time() - stime
        # Guard against a zero elapsed time (coarse clocks / zero target steps).
        fps = steps / elapsed if elapsed > 0 else float("inf")
        print("FPS: ", fps)
        if self.use_state_norm:
            buffer.update_rms()
        return steps

    @staticmethod
    def _to_bin(value, half_range, n_bins):
        """Map ``value`` from ``[-half_range, half_range]`` to a bin index clamped to ``[0, n_bins - 1]``."""
        idx = int((value + half_range) / (half_range * 2) * n_bins)
        return max(0, min(idx, n_bins - 1))

    def get_prior_prob(self, state, action):
        """Look up the prior value of each agent's ``(state, action)`` pair.

        Each ``self.prior[aid]`` is indexed with the action on axis 0 and the
        (discretised) state on the remaining axes, so the number of bins for a
        state component is taken from the axis that component indexes.

        Args:
            state: Per-agent state vectors (``state[aid][k]``).
            action: Per-agent action indices (``action[aid]``).

        Returns:
            A list with one prior entry per agent.
        """
        is_mpe = self.args.env.env_name == "mpe"
        is_4d = is_mpe and "4d" in self.args.mpe_demonstration

        probs = [0] * self.agent_num
        for aid in range(self.agent_num):
            table = self.prior[aid]
            agent_state = state[aid]
            act = action[aid]
            if is_4d:
                # Axes 1-2 are binned over [-vs, vs], axes 3-4 over [-s, s].
                p = self._to_bin(agent_state[0], self.vs, table.shape[1])
                q = self._to_bin(agent_state[1], self.vs, table.shape[2])
                i = self._to_bin(agent_state[2], self.s, table.shape[3])
                j = self._to_bin(agent_state[3], self.s, table.shape[4])
                probs[aid] = table[act, p, q, i, j]
            elif is_mpe:
                # Only state components 2 and 3 are used; axes 1-2 of the
                # table are binned over [-s, s].
                i = self._to_bin(agent_state[2], self.s, table.shape[1])
                j = self._to_bin(agent_state[3], self.s, table.shape[2])
                probs[aid] = table[act, i, j]
            else:
                # Non-MPE states are used directly as integer table indices.
                probs[aid] = table[act, agent_state[0], agent_state[1]]
        return probs

    def compute_shadow_r(self, state, action):
        """Compute the per-agent shadow reward for one transition.

        For each agent, with ``c`` the current occupancy probability and ``p``
        the prior value (both offset by ``1e-12`` to keep the logs finite),
        the reward is ``-(log(2c) - log(c + p))``; it is ``0`` when ``c == p``.

        Args:
            state: Per-agent state vectors passed to the occupancy measure and
                to ``get_prior_prob``.
            action: Per-agent action indices.

        Returns:
            A list with one shadow-reward value per agent.
        """
        shadow_r = [0] * self.agent_num
        cur_probs = self.occupancy_measures.get_prob(state)
        prior_probs = self.get_prior_prob(state, action)
        for aid in range(self.agent_num):
            cur_prob = cur_probs[aid] + 1e-12
            prior_prob = prior_probs[aid] + 1e-12
            if cur_prob != prior_prob:
                shadow_r[aid] = -(np.log(2 * cur_prob) - np.log(cur_prob + prior_prob))
        return shadow_r

    def compute_lambda(self):
        """Refresh the occupancy statistics by sampling ``self.N`` episodes.

        Every visited state (including the reset state) is passed to
        ``count_cur_state``; ``update_lambdas`` is called after each episode
        and ``normalize`` once at the end.  Rewards from these episodes are
        discarded and nothing is written to a rollout buffer.
        """
        episode = 0
        while episode < self.N:
            state = self.env.reset()
            self.occupancy_measures.count_cur_state(state["vec"])
            done = False
            while not done:
                action = self.select_action(state["vec"], state.get("mask"))
                state, _reward, done, _ = self.env.step(action)
                self.occupancy_measures.count_cur_state(state["vec"])
            self.occupancy_measures.update_lambdas()
            episode += 1
        self.occupancy_measures.normalize()

