import copy
import random
from stable_baselines3.common.evaluation import evaluate_policy
import numpy as np
import torch
from pathlib import Path
import sys
from a2c_ppo_acktr.model import Policy
from stable_baselines3.common.utils import obs_as_tensor

# Make the parent directory and the current directory importable.  Both
# entries are relative, so they are resolved against the working directory
# of the running process.
root_path = '..'
root_path2 = '.'
sys.path.extend([root_path, root_path2])

from overcooked_ai_py.mdp.actions import Action, Direction
from overcooked_ai_py.agents.agent import AgentFromPolicy, AgentPair
from overcooked_ai_py.agents.benchmarking import AgentEvaluator

from human_aware_rl.rllib.utils import get_base_ae
from overcooked_ai_py.agents.agent import Agent
from stable_baselines3.common.policies import ActorCriticPolicy
from collections import defaultdict

@torch.no_grad()
def get_bc_agent_from_policy(policy, agent_index, featurize_fn, unstuck=True):
    """Construct a behaviour-cloning agent around ``policy``.

    The construction itself runs with gradient tracking disabled (see the
    decorator).

    Args:
        policy: policy object forwarded unchanged to the agent constructor.
        agent_index: index forwarded to the agent constructor.
        featurize_fn: featurization callable forwarded to the agent constructor.
        unstuck: if truthy an ``UnstuckBCAgent`` is created, otherwise a
            ``SimpleBCAgent``.

    Returns:
        The newly created agent instance.
    """
    if unstuck:
        agent_cls = UnstuckBCAgent
    else:
        agent_cls = SimpleBCAgent
    return agent_cls(policy, agent_index, featurize_fn)


def get_bc_agent_fn_from_policy(policy, unstuck=True):
    """Return a factory that builds BC agents sharing the same ``policy``.

    Args:
        policy: policy object captured by the returned factory.
        unstuck: if truthy the factory creates ``UnstuckBCAgent`` instances,
            otherwise ``SimpleBCAgent`` instances.

    Returns:
        A callable ``bc_agent_fn(agent_index, featurize_fn)`` producing a new
        agent on every call.
    """
    if unstuck:
        agent_cls = UnstuckBCAgent
    else:
        agent_cls = SimpleBCAgent

    def bc_agent_fn(agent_index, featurize_fn):
        return agent_cls(policy, agent_index, featurize_fn)

    return bc_agent_fn


def get_bc_agent_fn_from_saved(layout_name, mode, **kwargs):
    """Load a saved BC policy from disk and return an agent factory for it.

    The policy is read once from
    ``./saved_models/bc/<layout_name>/<mode>.zip`` (relative to the working
    directory) and shared by every agent the factory creates.

    Args:
        layout_name: name of the layout sub-directory.
        mode: base name of the ``.zip`` file to load.
        **kwargs: extra keyword arguments passed on to
            ``get_bc_agent_from_policy`` on every factory call.

    Returns:
        A callable ``bc_agent_fc(agent_index, featurize_fn)``.
    """
    from stable_baselines3.common.policies import ActorCriticPolicy

    policy = ActorCriticPolicy.load(f"./saved_models/bc/{layout_name}/{mode}.zip")

    def bc_agent_fc(agent_index, featurize_fn):
        return get_bc_agent_from_policy(policy, agent_index, featurize_fn, **kwargs)

    return bc_agent_fc


def mean_and_std_err(lst):
    """Return the pair ``(mean, standard error)`` of the values in ``lst``."""
    return np.mean(lst), std_err(lst)


def std_err(lst):
    """Return the standard error of ``lst``.

    Computed as ``np.std(lst)`` (called with numpy's default arguments)
    divided by the square root of ``len(lst)``.
    """
    spread = np.std(lst)
    sample_count = len(lst)
    return spread / np.sqrt(sample_count)


def evaluate_bc_agent_from_saved(layout_name):
    """Load the saved "train" and "test" BC agent factories for a layout.

    Currently a stub: the two factories are loaded (in that order) but not
    evaluated, and nothing is returned.
    """
    bc_train_fn, bc_test_fn = (
        get_bc_agent_fn_from_saved(layout_name, mode) for mode in ("train", "test")
    )
    # TODO add agent evaluating


class SimpleBCAgent(Agent):
    """Agent that picks its action directly from a behaviour-cloning policy.

    Attributes are assigned after the base-class constructor has run:
    ``policy`` (the wrapped policy), ``agent_index`` (which entry of the
    featurized observation belongs to this agent) and ``featurize`` (the
    state -> observations callable).
    """

    def __init__(self, policy: ActorCriticPolicy, agent_index: int, featurize_fn):
        super().__init__()
        self.policy = policy
        self.agent_index = agent_index
        self.featurize = featurize_fn

    def _as_policy_input(self, observation):
        """Return ``observation`` as a tensor; tensors pass through untouched."""
        if isinstance(observation, torch.Tensor):
            return observation
        return obs_as_tensor(observation, self.policy.device)

    def action_probabilities(self, state):
        """Return the policy's action probabilities for ``state`` as a numpy array.

        The featurizer is called with ``debug=False`` here.  Only the first
        row of the distribution's ``probs`` is returned.
        """
        features = self.featurize(state, debug=False)
        policy_input = self._as_policy_input(features[self.agent_index])
        categorical = self.policy.get_distribution(policy_input).distribution
        return categorical.probs.detach().cpu().numpy()[0]

    def action(self, state):
        """Run the policy on ``state`` and return ``(action, info)``.

        ``info['action_probs']`` is the exponential of the log-probability
        returned by ``policy.forward`` (presumably the probability of the
        chosen action only, not the full distribution -- confirm against the
        policy implementation).
        """
        features = self.featurize(state)
        policy_input = self._as_policy_input(features[self.agent_index])

        # forward() yields three values; the middle one is not needed here.
        chosen, _, log_prob = self.policy.forward(policy_input)
        agent_action = Action.INDEX_TO_ACTION[chosen[0].item()]
        agent_action_info = {'action_probs': log_prob.exp()}
        return agent_action, agent_action_info


class UnstuckBCAgent(Agent):
    """BC agent that stops repeating actions once it has not moved for a while.

    The agent records, per agent index, the last ``stuck_time + 1``
    ``(state, action)`` pairs.  When the player's position and orientation
    were identical over the last ``stuck_time`` recorded states and the
    current state, the actions taken in those steps are removed from the
    policy's distribution before an action is picked.

    Args:
        policy: policy exposing ``get_distribution`` and ``device``.
        agent_index: index of the player controlled by this agent.
        featurize_fn: callable mapping a state to per-player observations.
        no_waits: if true, the probability of the stay action is removed.
        stochastic: sample from the distribution if true, else take argmax.
        stuck_time: number of unchanged steps after which the agent counts as
            stuck; ``0`` disables the unblocking logic.
    """

    def __init__(self, policy, agent_index, featurize_fn, no_waits=False, stochastic=True, stuck_time=3):
        self.policy = policy
        self.featurize = featurize_fn
        self.stuck_time = stuck_time
        self.history_length = stuck_time + 1
        self.stochastic = stochastic
        self.action_probs = False
        self.no_waits = no_waits
        self.will_unblock_if_stuck = stuck_time != 0
        self.reset()
        # Assigned after reset(): reset() delegates to the base class, which
        # may overwrite agent_index (upstream overcooked_ai's Agent.reset sets
        # it to None -- confirm for the installed version).  Setting it last
        # guarantees the constructor argument is honoured.
        self.agent_index = agent_index

    def preprocess(self, state):
        """Pick an action for ``state`` and record it in the history.

        Returns:
            ``(action, action_probs)`` where ``action_probs`` is the (possibly
            filtered and renormalized) distribution the action was drawn from.
        """
        # np.stack always yields an ndarray, so the conversion is unconditional.
        stacked_obs = np.stack(self.featurize(state), axis=0)
        obs_tensor = obs_as_tensor(stacked_obs, self.policy.device)
        all_probs = self.policy.get_distribution(obs_tensor).distribution.probs
        action_probs = all_probs.detach().cpu().numpy()[self.agent_index]

        if self.no_waits:
            # The helper expects a collection of indices, hence the list.
            # NOTE(review): Direction.STAY is kept from the original code --
            # confirm the installed overcooked_ai exposes it on Direction.
            action_probs = self.remove_indices_and_renormalize(
                action_probs, [Action.ACTION_TO_INDEX[Direction.STAY]]
            )

        if self.will_unblock_if_stuck:
            action_probs = self.unblock_if_stuck(state, action_probs)

        if self.stochastic:
            action_idx = np.random.choice(len(action_probs), p=action_probs)
        else:
            action_idx = np.argmax(action_probs)

        action = Action.INDEX_TO_ACTION[action_idx]
        self.add_to_history(state, action)
        return action, action_probs

    def action(self, state):
        """Return ``(action, info)`` with ``info['action_probs']`` set to the
        distribution used to choose the action."""
        processed_action, processed_action_probs = self.preprocess(state)

        agent_action_info = {
            "action_probs": processed_action_probs,
        }

        return processed_action, agent_action_info

    def unblock_if_stuck(self, state, action_probs):
        """Return ``action_probs`` with recently tried actions removed if stuck.

        When the agent is not stuck the probabilities are returned unchanged.
        NOTE: relies on ``self.agent_index`` already being set correctly for
        the agent the unstuck logic is evaluated for.
        """
        stuck, last_actions = self.is_stuck(state)
        if stuck:
            # At least one movement direction must still be available.
            assert any([a not in last_actions for a in Direction.ALL_DIRECTIONS]), last_actions
            last_action_idxes = [Action.ACTION_TO_INDEX[a] for a in last_actions]
            action_probs = self.remove_indices_and_renormalize(action_probs, last_action_idxes)
        return action_probs

    def is_stuck(self, state):
        """Return ``(stuck, last_actions)`` for the current agent.

        The agent is stuck when its ``pos_and_or`` is the same in the last
        ``stuck_time`` recorded states and in ``state``.  While the history
        still contains unfilled (``None``) slots the result is
        ``(False, [])``.
        """
        history = self.history[self.agent_index]
        if None in history:
            return False, []

        recent = history[-self.stuck_time:]
        last_states = [s_a[0] for s_a in recent]
        last_actions = [s_a[1] for s_a in recent]
        pos_and_ors = [s.players[self.agent_index].pos_and_or for s in last_states]
        pos_and_ors.append(state.players[self.agent_index].pos_and_or)
        if self.checkEqual(pos_and_ors):
            return True, last_actions
        return False, []

    @staticmethod
    def remove_indices_and_renormalize(probs, indices):
        """Zero out the given entries of ``probs`` and renormalize.

        Args:
            probs: 1-D distribution, or 2-D array with one distribution per
                row.  The input is never modified; a copy is returned.
            indices: for 1-D ``probs`` a single index or an iterable of
                indices to remove; for 2-D ``probs`` one entry per row, each
                being an index or an iterable of indices to remove from that
                row.

        Returns:
            Array of the same shape as ``probs`` whose rows sum to one.  No
            guard is applied if all probability mass is removed (the division
            then yields NaN).
        """
        # Work on a copy so the caller's array is left untouched.
        probs = np.array(probs)
        if probs.ndim > 1:
            for row_idx, row in enumerate(indices):
                # Only the indices listed for this row are removed from it.
                for idx in np.atleast_1d(row):
                    probs[row_idx][idx] = 0
            norm_probs = probs.T / np.sum(probs, axis=1)
            return norm_probs.T
        # np.atleast_1d lets callers pass a bare index as well as a list.
        for idx in np.atleast_1d(indices):
            probs[idx] = 0
        return probs / np.sum(probs)

    def checkEqual(self, iterator):
        """Return True if the first two components of every item match those
        of the first item (position and orientation for ``pos_and_or``)."""
        first_pos_and_or = iterator[0]
        return all(
            curr[0] == first_pos_and_or[0] and curr[1] == first_pos_and_or[1]
            for curr in iterator
        )

    def add_to_history(self, state, action):
        """Append ``(state, action)`` to the agent's history, dropping the
        oldest entry so the length stays at ``history_length``."""
        history = self.history[self.agent_index]
        assert len(history) == self.history_length, "something wrong"
        self.history[self.agent_index] = history[1:] + [(state, action)]

    def reset(self):
        """Clear all recorded histories and reset the base class."""
        # One history list per agent index, created lazily and pre-filled
        # with None placeholders.
        self.history = defaultdict(lambda: [None] * self.history_length)
        super().reset()


class ApagAgentNewVersion(Agent):
    """Agent wrapping an ``a2c_ppo_acktr`` actor-critic ``Policy``.

    Args:
        actor_critic: the policy network; moved to ``device`` and put in
            eval mode.
        featurize_fn: callable mapping a state to per-player observations.
        agent_index: which entry of the featurized observation to use.
        deterministic: passed to ``actor_critic.act``; see ``action``.
        sample_prob: when ``deterministic`` is false and
            ``0 < sample_prob < 1``, each step samples with this probability
            and acts deterministically otherwise.
        name: optional label stored on the agent.
        device: device the network is moved to.
    """

    def __init__(self, actor_critic: Policy, featurize_fn, agent_index: int = 0, deterministic=False, sample_prob=1.0, name=None, device='cpu'):
        # reset() reads self.actor_critic, so the network has to be stored
        # before the base-class constructor or reset() can run.
        self.actor_critic = actor_critic.to(device)
        self.actor_critic.eval()
        # The previous ``super(ApagAgentNewVersion).__init__()`` only
        # re-initialised an unbound super object and never reached
        # ``Agent.__init__``.
        super().__init__()

        self.reset()
        self.agent_index = agent_index
        self.featurize = featurize_fn
        self.deterministic = deterministic
        self.sample_prob = sample_prob
        self.name = name
        self.device = device

    def reset(self):
        """Re-create the hidden-state placeholder for non-recurrent policies.

        The base-class reset is intentionally not called.  Recurrent policies
        are not handled yet, so ``rnn_hxs`` is left untouched for them.
        """
        if self.actor_critic.is_recurrent:
            # TODO add recurrent policy initial state
            pass
        else:
            # Allocate on the same device as the network's parameters so the
            # placeholder never mismatches the observation tensors.
            model_device = next(self.actor_critic.parameters()).device
            self.rnn_hxs = torch.zeros(
                1, self.actor_critic.recurrent_hidden_state_size, device=model_device
            )

    def _prepare_obs(self, state):
        """Featurize ``state`` and return this agent's observation as a
        float32 tensor with a leading batch dimension of one."""
        my_obs = self.featurize(state)[self.agent_index]
        model_device = next(self.actor_critic.parameters()).device
        if isinstance(my_obs, torch.Tensor):
            my_obs = my_obs.to(device=model_device, dtype=torch.float32)
        else:
            my_obs = obs_as_tensor(my_obs.astype(np.float32), model_device)

        # A trailing dimension of 26 is treated as channels-last and moved to
        # the front (presumably 26 feature planes -- confirm with the
        # featurizer in use).
        if my_obs.shape[-1] == 26:
            my_obs = torch.permute(my_obs, [2, 0, 1])

        return my_obs.unsqueeze(0)

    def action_probabilities(self, state):
        """Return the action distribution for ``state`` as a numpy array.

        The batch dimension is kept in the result.  The stored ``rnn_hxs``
        is replaced by the one returned from the network base.
        """
        my_obs = self._prepare_obs(state)
        _, feats, self.rnn_hxs = self.actor_critic.base(my_obs, self.rnn_hxs, None)

        dist = self.actor_critic.dist(feats)

        return dist.probs.cpu().detach().numpy()

    def action(self, state):
        """Choose an action for ``state``.

        Returns:
            ``(action, info)`` where ``info['action_probs']`` is the
            exponential of the log-probability returned by
            ``actor_critic.act``.
        """
        my_obs = self._prepare_obs(state)

        deterministic = self.deterministic
        if not self.deterministic and 0 < self.sample_prob < 1:
            # Sample with probability sample_prob, act greedily otherwise.
            deterministic = random.random() > self.sample_prob

        _, action, action_log_prob, rnn_hxs = self.actor_critic.act(
            my_obs, self.rnn_hxs, None, deterministic=deterministic
        )
        action = action.to('cpu').detach()
        agent_action_info = {
            "action_probs": action_log_prob.exp(),
        }
        # .item() yields a plain Python int for the list lookup, matching
        # SimpleBCAgent.
        agent_action = Action.INDEX_TO_ACTION[action[0].item()]

        self.rnn_hxs = rnn_hxs

        return agent_action, agent_action_info


def eval_self_play(agent, eval_params, num_games=1):
    """Evaluate ``agent`` paired with itself.

    Args:
        agent: the agent used for both player slots.
        eval_params: mapping providing ``"mdp_params"`` and ``"env_params"``
            for the evaluator.
        num_games: number of games to roll out.

    Returns:
        ``(mean, standard_error, rollouts)`` where the statistics are taken
        over ``rollouts["ep_returns"]``.
    """
    evaluator = get_base_ae(eval_params["mdp_params"], eval_params["env_params"], None, None)
    pair = AgentPair(agent, agent, allow_duplicate_agents=True)
    rollouts = evaluator.evaluate_agent_pair(pair, num_games=num_games)

    mean_return, return_se = mean_and_std_err(rollouts["ep_returns"])
    return mean_return, return_se, rollouts


def eval_paired_play(agent0, agent1, eval_params, num_games=1):
    """Evaluate ``agent0`` playing together with ``agent1``.

    Args:
        agent0: agent for the first player slot.
        agent1: agent for the second player slot.
        eval_params: mapping providing ``"mdp_params"`` and ``"env_params"``
            for the evaluator.
        num_games: number of games to roll out.

    Returns:
        ``(mean, standard_error, rollouts)`` where the statistics are taken
        over ``rollouts["ep_returns"]``.
    """
    mdp_params = eval_params["mdp_params"]
    env_params = eval_params["env_params"]
    evaluator = get_base_ae(mdp_params, env_params, None, None)

    # Duplicates are allowed so the same instance may be passed twice.
    pair = AgentPair(agent0, agent1, allow_duplicate_agents=True)
    rollouts = evaluator.evaluate_agent_pair(pair, num_games=num_games)

    mean_return, return_se = mean_and_std_err(rollouts["ep_returns"])
    return mean_return, return_se, rollouts


def get_dummy_ae():
    """Return an evaluator for the ``cramped_room`` layout with horizon 400."""
    mdp_params = {'layout_name': 'cramped_room'}
    env_params = {'horizon': 400}
    return get_base_ae(mdp_params, env_params)
