import torch as th
import torch.nn.functional as F
from torch.distributions import Categorical
from .epsilon_schedules import DecayThenFlatSchedule
import numpy as np
# Maps a selector name (string key) to the action-selector class registered
# under it; entries are added next to each class definition in this module.
REGISTRY = {}


class MultinomialActionSelector():
    """Pick actions by sampling from (or arg-maxing over) masked policies.

    ``args`` must expose ``epsilon_start``, ``epsilon_finish`` and
    ``epsilon_anneal_time``; the optional ``test_greedy`` flag (default
    ``True``) makes evaluation use the arg-max instead of sampling.
    """

    def __init__(self, args):
        self.args = args
        self.test_greedy = getattr(args, "test_greedy", True)

        # Linear epsilon schedule; its value is exposed through
        # ``self.epsilon`` but is not used for the selection itself.
        self.schedule = DecayThenFlatSchedule(
            args.epsilon_start,
            args.epsilon_finish,
            args.epsilon_anneal_time,
            decay="linear",
        )
        self.epsilon = self.schedule.eval(0)

    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
        """Return one action index per entry of the leading dimensions.

        Args:
            agent_inputs: per-action policy weights with the action axis at
                dim 2 (presumably ``(batch, agents, actions)`` - confirm
                against the caller). The tensor is not modified.
            avail_actions: mask of matching shape; entries equal to 0 mark
                actions whose weight is zeroed before selection.
            t_env: environment step used to refresh ``self.epsilon``.
            test_mode: when true *and* ``self.test_greedy`` is set, the
                arg-max action is returned instead of a sample.

        Returns:
            Long tensor holding the chosen action indices.
        """
        # Work on a private copy so the caller's tensor stays untouched.
        policy = agent_inputs.clone()
        policy[avail_actions == 0.0] = 0.0

        self.epsilon = self.schedule.eval(t_env)

        use_argmax = test_mode and self.test_greedy
        if not use_argmax:
            # Categorical normalises the remaining weights itself.
            return Categorical(policy).sample().long()

        _, best_actions = policy.max(dim=2)
        return best_actions


# Make the selector available under the "multinomial" key.
REGISTRY["multinomial"] = MultinomialActionSelector


class EpsilonGreedyActionSelector():
    """Epsilon-greedy selection over per-agent action values.

    ``args`` must expose ``epsilon_start``, ``epsilon_finish`` and
    ``epsilon_anneal_time``. ``evaluation_epsilon`` is optional and is the
    exploration rate used in test mode (defaults to ``0.0``, i.e. greedy).
    """

    def __init__(self, args):
        self.args = args

        self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
                                              decay="linear")
        self.epsilon = self.schedule.eval(0)

    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
        """Choose, per agent, either a random available action or the best one.

        Args:
            agent_inputs: 3-D tensor of action values with the action axis
                last (presumably ``(batch, agents, actions)`` - confirm
                against the caller). The tensor is not modified.
            avail_actions: mask of the same shape; entries equal to 0 mark
                actions that must not be selected.
            t_env: environment step used to evaluate the epsilon schedule.
            test_mode: when true, ``args.evaluation_epsilon`` (or ``0.0`` if
                the attribute is missing) replaces the scheduled epsilon.

        Returns:
            Long tensor of action indices, one per (batch, agent) entry.
        """
        self.epsilon = self.schedule.eval(t_env)
        if test_mode:
            # Evaluation uses its own exploration rate; fall back to purely
            # greedy selection when the option is not configured.
            self.epsilon = getattr(self.args, "evaluation_epsilon", 0.0)

        # Excluded actions get -inf so the arg-max can never prefer them.
        masked_q_values = agent_inputs.clone()
        masked_q_values[avail_actions == 0.0] = -float("inf")
        greedy_actions = masked_q_values.max(dim=2)[1]

        # One exploration coin per (batch, agent) entry.
        random_numbers = th.rand_like(agent_inputs[:, :, 0])
        pick_random = random_numbers < self.epsilon

        # Uniform sampling over the available actions. A row without any
        # available action would make Categorical fail to normalise, so such
        # rows fall back to a uniform distribution over all actions.
        random_weights = avail_actions.float()
        no_action_available = random_weights.sum(dim=-1, keepdim=True) == 0.0
        random_weights = random_weights + no_action_available.float()
        random_actions = Categorical(random_weights).sample().long()

        return th.where(pick_random, random_actions, greedy_actions)


# Make the selector available under the "epsilon_greedy" key.
REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector


class SoftPoliciesSelector():
    """Sample actions directly from the policies produced by the agents."""

    def __init__(self, args):
        self.args = args

    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
        """Draw one action index per distribution in ``agent_inputs``.

        ``agent_inputs`` is handed to ``Categorical`` as probabilities over
        the last axis. ``avail_actions``, ``t_env`` and ``test_mode`` are
        accepted for interface compatibility but are not used here.

        Returns:
            Long tensor of sampled action indices.
        """
        return Categorical(agent_inputs).sample().long()


# Make the selector available under the "soft_policies" key.
REGISTRY["soft_policies"] = SoftPoliciesSelector


class SoftOptimisticSelector():
    """Softmax (Boltzmann) selection with an annealed inverse temperature.

    ``args`` must expose ``alpha_start``, ``alpha``, ``alpha_anealing_steps``
    and ``sample_epsilon``; ``test_greedy`` is optional (default ``True``).
    During training the inverse temperature ``self.alpha`` is multiplied by
    ``self.delta_alpha`` after every call until it reaches ``args.alpha``.
    """

    def __init__(self, args):
        self.args = args
        self.test_greedy = getattr(args, "test_greedy", True)
        self.alpha = self.args.alpha_start
        # Constant factor that takes alpha_start to args.alpha in
        # ``alpha_anealing_steps`` multiplicative steps.
        self.delta_alpha = np.exp(np.log(self.args.alpha / self.alpha) / self.args.alpha_anealing_steps)
        self.epsilon = self.args.sample_epsilon

    def select_action(self, agent_inputs, avail_actions, t_env=0, test_mode=False):
        """Select one action per entry of the leading dimensions.

        Args:
            agent_inputs: action preferences with the action axis last. The
                tensor is not modified.
            avail_actions: mask of the same shape; entries equal to 0 mark
                actions that receive zero probability.
            t_env: unused; kept for interface compatibility.
            test_mode: use ``args.alpha`` instead of the annealed
                ``self.alpha``, skip exploration and annealing, and take the
                arg-max when ``self.test_greedy`` is set.

        Returns:
            Long tensor of chosen action indices.
        """
        alpha = self.args.alpha if test_mode else self.alpha

        # Scale first and mask afterwards so that the strength of the mask
        # does not depend on alpha (alpha == 0 would otherwise cancel it).
        # A finite fill value keeps rows without any available action
        # well-defined: they become uniform instead of NaN.
        scaled = agent_inputs * alpha
        scaled = scaled.masked_fill(avail_actions == 0.0, th.finfo(scaled.dtype).min)
        masked_policies = F.softmax(scaled, dim=-1)

        if test_mode:
            if self.test_greedy:
                return masked_policies.max(dim=-1)[1]
            return Categorical(masked_policies).sample().long()

        # Uniform distribution over the available actions; rows with no
        # available action fall back to uniform over all actions.
        random_policies = avail_actions.to(dtype=masked_policies.dtype)
        no_action_available = random_policies.sum(dim=-1, keepdim=True) == 0.0
        random_policies = random_policies + no_action_available.to(random_policies.dtype)

        # One exploration coin per distribution, drawn on the input's device
        # and broadcast over the action axis.
        coins = th.rand(agent_inputs.shape[:-1], device=agent_inputs.device).unsqueeze(-1)
        mixed_policies = th.where(coins < self.epsilon, random_policies, masked_policies)
        picked_actions = Categorical(mixed_policies).sample().long()

        # Anneal towards args.alpha without overshooting the target.
        if self.alpha < self.args.alpha:
            self.alpha = min(self.alpha * self.delta_alpha, self.args.alpha)

        return picked_actions


# Make the selector available under the "soft_optimistic" key.
REGISTRY["soft_optimistic"] = SoftOptimisticSelector

class EpsilonSoftOptimisticSelector():
    """Softmax (Boltzmann) selection mixed with scheduled uniform exploration.

    ``args`` must expose ``epsilon_start``, ``epsilon_finish``,
    ``epsilon_anneal_time``, ``alpha`` and ``sample_epsilon``;
    ``test_greedy`` is optional (default ``True``).
    """

    def __init__(self, args):
        self.args = args
        self.test_greedy = getattr(args, "test_greedy", True)
        self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
                                              decay="linear")
        self.epsilon = self.schedule.eval(0)

    def select_action(self, agent_inputs, avail_actions, alpha=None, t_env=0, test_mode=False):
        """Select one action per entry of the leading dimensions.

        Args:
            agent_inputs: action preferences with the action axis last. The
                tensor is not modified.
            avail_actions: mask of the same shape; entries equal to 0 mark
                actions that receive zero probability.
            alpha: inverse temperature of the softmax; ``args.alpha`` is used
                when ``None``.
            t_env: environment step for the epsilon schedule. A negative
                value selects the fixed ``args.sample_epsilon`` instead.
            test_mode: skip exploration; take the arg-max when
                ``self.test_greedy`` is set, otherwise sample the softmax.

        Returns:
            Long tensor of chosen action indices.

        Raises:
            AssertionError: if a chosen action is marked unavailable.
        """
        if alpha is None:
            alpha = self.args.alpha

        # Scale first and mask afterwards so that the strength of the mask
        # does not depend on alpha (alpha == 0 would otherwise cancel it).
        # A finite fill value keeps rows without any available action
        # well-defined: they become uniform instead of NaN.
        scaled = agent_inputs * alpha
        scaled = scaled.masked_fill(avail_actions == 0.0, th.finfo(scaled.dtype).min)
        masked_policies = F.softmax(scaled, dim=-1)

        # Uniform distribution over the available actions; rows with no
        # available action fall back to uniform over all actions. Also used
        # by the sanity check at the end.
        random_policies = avail_actions.to(dtype=masked_policies.dtype)
        no_action_available = random_policies.sum(dim=-1, keepdim=True) == 0.0
        random_policies = random_policies + no_action_available.to(random_policies.dtype)

        if test_mode:
            if self.test_greedy:
                picked_actions = masked_policies.max(dim=-1)[1]
            else:
                picked_actions = Categorical(masked_policies).sample().long()
        else:
            if t_env >= 0:
                self.epsilon = self.schedule.eval(t_env)
            else:
                self.epsilon = self.args.sample_epsilon
            # One exploration coin per entry of the dimensions preceding the
            # last two, shared by everything along the second-to-last axis.
            # NOTE(review): SoftOptimisticSelector draws one coin per
            # distribution instead - confirm the sharing here is intended.
            coins = th.rand(agent_inputs.shape[:-2], device=agent_inputs.device).unsqueeze(-1).unsqueeze(-1)
            mixed_policies = th.where(coins < self.epsilon, random_policies, masked_policies)
            picked_actions = Categorical(mixed_policies).sample().long()

        # Sanity check: every chosen action must carry positive weight in the
        # availability-based distribution.
        chosen_weight = th.gather(random_policies, -1, index=picked_actions.unsqueeze(-1))
        if not (chosen_weight > 0).all():
            print("select action: ",agent_inputs,avail_actions,masked_policies,random_policies,picked_actions,alpha)
            raise AssertionError("selected an action that is marked unavailable")
        return picked_actions


# Make the selector available under the "epsilon_soft_optimistic" key.
REGISTRY["epsilon_soft_optimistic"] = EpsilonSoftOptimisticSelector