import utils

import copy

class Policy:
    """Per-state acceptance policy described by "limiting types".

    ``limiting_types[state]`` is a pair ``[customer_limit, server_limit]``.
    Each entry is either ``-1`` (every transition of that kind is rejected
    in that state) or the index of a type; a transition is then accepted
    when its reward in that state is at least the reward of the limiting
    type (see :meth:`act`).
    """

    # Sentinel stored in ``limiting_types`` meaning "reject everything".
    _REJECT_ALL = -1

    # Policies are mutable and compared by value, so they are unhashable.
    # Defining __eq__ already implies this; it is spelled out for clarity.
    __hash__ = None

    def __init__(self, model):
        """Create a policy bound to ``model`` with no limiting types yet.

        ``limiting_types`` starts empty; :meth:`full_acceptance_policy` and
        :meth:`full_rejection_policy` fill in one entry per state.
        """
        self.model = model
        self.n_states = model.n_states
        self.n_server_types = model.n_server_types
        self.n_customer_types = model.n_customer_types

        self.limiting_types = []

    def get_limiting_types(self, state):
        """Return the ``[customer_limit, server_limit]`` entry of ``state``."""
        return self.limiting_types[state]

    def get_improved_policy(self):
        """Return a new policy holding the model's maximal action per state.

        The current policy is evaluated once through
        ``model.get_gain_bias``; the result is passed, together with this
        policy's current entry, to ``model.get_maximal_action`` for every
        state in ``range(self.n_states)``. ``self`` is not modified.
        """
        # NOTE(review): the result is unpacked as (bias, gain) although the
        # method is named get_gain_bias -- confirm the order against the model.
        bias, gain = self.model.get_gain_bias(self)

        new_policy = Policy(self.model)
        # Every entry is replaced, so the list is built directly instead of
        # deep-copying the old entries first.
        new_policy.limiting_types = [
            self.model.get_maximal_action(
                state, bias, gain, self.limiting_types[state]
            )
            for state in range(self.n_states)
        ]
        return new_policy

    def clean(self):
        """Return a copy in which full rejection spreads away from the pivot.

        The pivot state is ``model.capacities[1]``. Walking upwards from the
        pivot, once a state rejects all customers (entry 0 is ``-1``) every
        higher state is set to reject all customers too. Walking downwards
        from the pivot to state 0, the same is done for servers (entry 1).
        ``self`` is not modified.
        """
        new_policy = Policy(self.model)
        new_policy.limiting_types = copy.deepcopy(self.limiting_types)

        pivot = self.model.capacities[1]
        # NOTE(review): states beyond the first full rejection are presumably
        # unreachable, which is why their entries are normalised -- confirm.
        self._propagate_rejection(
            new_policy, range(pivot, self.model.n_states), 0
        )
        self._propagate_rejection(new_policy, range(pivot, -1, -1), 1)
        return new_policy

    def _propagate_rejection(self, target, states, side):
        """Mark ``side`` as reject-all in ``target`` from the first rejection on.

        ``states`` is walked in order; the first state whose ``side`` entry
        in ``self`` is ``-1`` and every state after it get ``-1`` written
        into ``target.limiting_types[state][side]``.
        """
        blocked = False
        for state in states:
            if (
                not blocked
                and self.limiting_types[state][side] == self._REJECT_ALL
            ):
                blocked = True
            if blocked:
                target.limiting_types[state][side] = self._REJECT_ALL

    def act(self, state, transition_type):
        """Decide whether a transition is accepted in ``state``.

        Args:
            state: Index into ``limiting_types`` and the reward tables.
            transition_type: ``0`` is always accepted. A positive value ``k``
                denotes customer type ``k - 1``; a negative value ``-k``
                denotes server type ``k - 1``.

        Returns:
            ``True`` when the transition is accepted, ``False`` otherwise.
        """
        if transition_type == 0:
            return True
        if transition_type < 0:
            server_type = -transition_type - 1
            limit = self.limiting_types[state][1]
            if limit == self._REJECT_ALL:
                return False
            rewards = self.model.state_rewards.server_rewards[state]
            return rewards[server_type] >= rewards[limit]
        if transition_type > 0:
            customer_type = transition_type - 1
            limit = self.limiting_types[state][0]
            if limit == self._REJECT_ALL:
                return False
            rewards = self.model.state_rewards.customer_rewards[state]
            return rewards[customer_type] >= rewards[limit]

    def __eq__(self, other):
        """Compare two policies by their per-state limiting types.

        Returns ``NotImplemented`` for objects that are not policies, and
        ``False`` when the two policies cover a different number of states
        (``zip`` alone would silently ignore the surplus entries).
        """
        if not isinstance(other, Policy):
            return NotImplemented
        if len(self.limiting_types) != len(other.limiting_types):
            return False
        return all(
            mine[0] == theirs[0] and mine[1] == theirs[1]
            for mine, theirs in zip(self.limiting_types, other.limiting_types)
        )

    @staticmethod
    def full_acceptance_policy(model):
        """Build the policy whose limits come from ``utils.argmin`` per state.

        For each state the customer and server limits are
        ``utils.argmin`` of that state's customer and server rewards.
        Presumably ``argmin`` returns the index of the smallest reward, so
        every type passes the ``>=`` test in :meth:`act` -- verify against
        ``utils``.
        """
        out_policy = Policy(model)
        out_policy.limiting_types = [
            [
                utils.argmin(model.state_rewards.customer_rewards[state]),
                utils.argmin(model.state_rewards.server_rewards[state]),
            ]
            for state in range(model.n_states)
        ]
        return out_policy

    @staticmethod
    def full_rejection_policy(model):
        """Build the policy that rejects all customers and servers everywhere."""
        out_policy = Policy(model)
        out_policy.limiting_types = [
            [Policy._REJECT_ALL, Policy._REJECT_ALL]
            for _ in range(model.n_states)
        ]
        return out_policy
