import math
import random

import policy
import model
import ac
import lp
import ac_ablation

import ucrl

# Upper bound on the number of policy-improvement sweeps performed by the
# update_policy methods of ACRLAgent and AblationACRLAgent.
PI_ITERATIONS = 1000
# Starting confidence parameter; the learning agents divide it by
# exploration.steps_before_episode when they rebuild their model or policy.
INITIAL_CONFIDENCE_PARAM = 10

class Agent:
    """Baseline agent: decides at random and learns nothing from feedback."""

    def __init__(self):
        """The random baseline keeps no state."""

    def act(self, state, transition_type):
        """Return the accept/reject decision for the offered transition.

        Transitions of type 0 are always accepted; for any other type the
        decision is a fair coin flip. ``state`` is not consulted.
        """
        if transition_type != 0:
            return random.choice([True, False])
        return True

    def observe(self, state, next_state, action, transition_type, reward, time_elapsed):
        """Receive feedback about the last step; the baseline discards it."""
        return None

class DeterministicAgent(Agent):
    """Agent that follows a fixed, externally supplied policy."""

    def __init__(self, model, policy):
        """Store the model and the policy to follow.

        Note that the ``model`` and ``policy`` parameters shadow the
        module-level imports of the same names inside this method.
        """
        super().__init__()
        self.policy = policy
        self.model = model

    def act(self, state, transition_type):
        """Accept type-0 transitions outright; otherwise defer to the policy."""
        if transition_type != 0:
            return self.policy.act(state, transition_type)
        return True

    def get_estimated_gain(self):
        """Return the gain component of ``model.get_gain_bias`` for the policy."""
        _, gain = self.model.get_gain_bias(self.policy)
        return gain

class KnownPOAgent(Agent):
    """Agent that follows the policy ``lp.get_optimal_policy`` returns for a model.

    The policy is computed once at construction time and never updated
    afterwards; ``observe`` is inherited from ``Agent`` and ignores feedback.
    """

    def __init__(self, model):
        """Compute and store the policy and its gain for ``model``."""
        super().__init__()
        self.model = model
        # NOTE(review): this full-acceptance policy is overwritten on the very
        # next line. It is kept only in case constructing it has side effects
        # on ``model`` -- confirm, and drop the call if it has none.
        self.policy = policy.Policy.full_acceptance_policy(model)
        self.policy, self.gain, _, _ = lp.get_optimal_policy(self.model)
        # Re-evaluate the gain of the selected policy with the model itself;
        # this replaces the gain value returned by lp.get_optimal_policy.
        _, self.gain = self.model.get_gain_bias(self.policy)

    def act(self, state, transition_type):
        """Accept type-0 transitions outright; otherwise defer to the policy."""
        if transition_type == 0:
            return True

        return self.policy.act(state, transition_type)

    def get_estimated_gain(self):
        """Return the gain computed for the policy at construction time."""
        return self.gain

class ACRLAgent(Agent):
    """Learning agent built on the ``ac`` estimator, exploration and model builder.

    Each time ``ac.Exploration.observe`` signals a new episode, the agent
    rebuilds its extended model from the current parameter estimates and
    re-derives its policy by repeated policy improvement.
    """

    def __init__(self, model_bounds, state_rewards):
        """Build the initial extended model and derive the first policy."""
        # make sure there's +1 classes for handling abandonments
        super().__init__()
        self.parameter_estimator = ac.ParameterEstimator(model_bounds)
        self.state_rewards = state_rewards
        self.model_bounds = model_bounds
        self.exploration = ac.Exploration(model_bounds)

        # Set to True to print a trace of every policy-improvement sweep.
        self.verbose_pi = False

        self.initial_confidence_param = INITIAL_CONFIDENCE_PARAM
        # NOTE(review): the second return value (called ``failed`` in the
        # original code) is ignored -- confirm that is intended.
        self.model, _ = ac.generate_extended_model(
            model_bounds,
            self.parameter_estimator,
            self.state_rewards,
            self.initial_confidence_param,
        )
        self.policy = policy.Policy.full_acceptance_policy(self.model)
        self.n_policies = 1
        self.gain = float("-inf")

        # Not read anywhere in this class; kept as public attributes.
        self.v_basis = None
        self.c_basis = None

        self.update_policy()

    def update_policy(self):
        """Improve ``self.policy`` until it stops changing, then refresh ``self.gain``.

        At most ``PI_ITERATIONS`` improvement sweeps are run. The loop ends
        silently if that limit is reached before the policy stabilises.
        """
        if self.verbose_pi:
            print("-------------------------------")
            print("Beginning Policy Improvement")
        for _ in range(PI_ITERATIONS):
            new_policy = self.policy.get_improved_policy()
            if self.policy == new_policy:
                break
            self.policy = new_policy
            if self.verbose_pi:
                # The intermediate gain is only needed for the trace, so it is
                # evaluated only when tracing is enabled.
                _, gain = self.model.get_gain_bias(new_policy)
                print(f"gain: {gain}")
                print(f"limiting_types: {new_policy.limiting_types}")
                print(f"transition_rates: {[self.model.get_transition_rates(x, new_policy.limiting_types[x]) for x in range(self.model.n_states)]}")
                print(f"rewards: {[self.model.get_mean_reward(x, new_policy.limiting_types[x]) for x in range(self.model.n_states)]}")
        _, self.gain = self.model.get_gain_bias(self.policy)
        if self.verbose_pi:
            print("Policy improvement finished.")
            print("-------------------------------")
        self.policy = self.policy.clean()

    def get_confidence_param(self):
        """Return the initial confidence parameter divided by ``steps_before_episode``."""
        return self.initial_confidence_param/self.exploration.steps_before_episode

    def act(self, state, transition_type):
        """Accept type-0 transitions outright; otherwise defer to the policy."""
        if transition_type == 0:
            return True

        return self.policy.act(state, transition_type)

    def observe(self, state, next_state, action, transition_type, reward, time_elapsed):
        """Record one transition and, on a new episode, rebuild model and policy.

        Only ``state``, ``transition_type`` and ``time_elapsed`` are passed to
        the parameter estimator; the remaining arguments are unused here.
        """
        self.parameter_estimator.observe(state, transition_type, time_elapsed)
        if self.exploration.observe(state):
            self.exploration.new_episode()
            print(f"new episode. total episodes: {self.exploration.n_episodes}")
            # The confidence parameter is read after new_episode() has run.
            self.model, _ = ac.generate_extended_model(
                self.model_bounds,
                self.parameter_estimator,
                self.state_rewards,
                self.get_confidence_param(),
            )
            # Point the current policy at the new model before improving it.
            self.policy.model = self.model
            self.update_policy()

    def get_estimated_gain(self):
        """Return the gain computed by the most recent ``update_policy`` call."""
        return self.gain

class AblationACRLAgent(Agent):
    """Variant of the ``ac``-based learner that uses the ablation model builder.

    It uses ``ac.ParameterEstimator`` and ``ac.Exploration`` but builds its
    model with ``ac_ablation.generate_extended_model_ablation``. On every new
    exploration episode the model is rebuilt and the policy re-derived by
    repeated policy improvement.
    """

    def __init__(self, model_bounds, state_rewards):
        """Build the initial ablation model and derive the first policy."""
        # make sure there's +1 classes for handling abandonments
        super().__init__()
        self.parameter_estimator = ac.ParameterEstimator(model_bounds)
        self.state_rewards = state_rewards
        self.model_bounds = model_bounds
        self.exploration = ac.Exploration(model_bounds)
        # Set to True to print a trace of every policy-improvement sweep.
        self.verbose_pi = False

        self.initial_confidence_param = INITIAL_CONFIDENCE_PARAM
        # NOTE(review): the second return value (called ``failed`` in the
        # original code) is ignored -- confirm that is intended.
        self.model, _ = ac_ablation.generate_extended_model_ablation(
            model_bounds,
            self.parameter_estimator,
            self.state_rewards,
            self.initial_confidence_param,
        )
        self.policy = policy.Policy.full_acceptance_policy(self.model)
        self.n_policies = 1
        self.gain = float("-inf")
        # Not read anywhere in this class; kept as public attributes.
        self.v_basis = None
        self.c_basis = None

        self.update_policy()

    def update_policy(self):
        """Improve ``self.policy`` until it stops changing, then refresh ``self.gain``.

        At most ``PI_ITERATIONS`` improvement sweeps are run. The loop ends
        silently if that limit is reached before the policy stabilises.
        """
        if self.verbose_pi:
            print("-------------------------------")
            print("Beginning Policy Improvement")
        for _ in range(PI_ITERATIONS):
            new_policy = self.policy.get_improved_policy()
            if self.policy == new_policy:
                break
            self.policy = new_policy
            if self.verbose_pi:
                # Gain and bias of intermediate policies are only needed for
                # the trace, so they are evaluated only when tracing is on.
                bias, gain = self.model.get_gain_bias(new_policy)
                print(f"gain: {gain}")
                print(f"bias: {bias}")
                print(f"limiting_types: {new_policy.limiting_types}")
                print(f"transition_rates: {[self.model.get_transition_rates(x, new_policy.limiting_types[x]) for x in range(self.model.n_states)]}")
                print(f"rewards: {[self.model.get_mean_reward(x, new_policy.limiting_types[x]) for x in range(self.model.n_states)]}")
        _, self.gain = self.model.get_gain_bias(self.policy)
        if self.verbose_pi:
            print("Policy improvement finished.")
            print("-------------------------------")
        self.policy = self.policy.clean()

    def get_confidence_param(self):
        """Return the initial confidence parameter divided by ``steps_before_episode``."""
        return self.initial_confidence_param/self.exploration.steps_before_episode

    def act(self, state, transition_type):
        """Accept type-0 transitions outright; otherwise defer to the policy."""
        if transition_type == 0:
            return True

        return self.policy.act(state, transition_type)

    def observe(self, state, next_state, action, transition_type, reward, time_elapsed):
        """Record one transition and, on a new episode, rebuild model and policy.

        Only ``state``, ``transition_type`` and ``time_elapsed`` are passed to
        the parameter estimator; the remaining arguments are unused here.
        """
        self.parameter_estimator.observe(state, transition_type, time_elapsed)
        if self.exploration.observe(state):
            self.exploration.new_episode()
            # The confidence parameter is read after new_episode() has run.
            self.model, _ = ac_ablation.generate_extended_model_ablation(
                self.model_bounds,
                self.parameter_estimator,
                self.state_rewards,
                self.get_confidence_param(),
            )
            # Point the current policy at the new model before improving it.
            self.policy.model = self.model
            self.update_policy()

    def get_estimated_gain(self):
        """Return the gain computed by the most recent ``update_policy`` call."""
        return self.gain

class AlternateUCRLAgent(Agent):
    """UCRL-style agent working on flattened (state, transition type) indices.

    The index the environment moves to is only known at the next ``act``
    call, so ``observe`` just stores the step and ``act`` completes and
    records it once the following flattened index is available.

    NOTE(review): ``ucrl.ParameterEstimator`` and ``ucrl.Exploration`` are
    called with a single argument here, whereas ``UCRLAgent`` passes
    additional arguments -- confirm which signature ``ucrl`` provides.
    """

    def __init__(self, model_bounds, state_rewards):
        """Create the estimator and exploration helpers and an initial policy."""
        super().__init__()
        self.parameter_estimator = ucrl.ParameterEstimator(model_bounds)
        self.state_rewards = state_rewards
        self.model_bounds = model_bounds
        self.exploration = ucrl.Exploration(model_bounds)

        self.initial_confidence_param = INITIAL_CONFIDENCE_PARAM

        # Placeholder with one entry per flattened index; update_policy()
        # replaces it immediately.
        self.policy = [1] * (self.model_bounds.n_states * self.model_bounds.n_transitions)
        self.update_policy()

        # Pending step as [flattened state, action, time_elapsed, reward];
        # empty when nothing is waiting to be recorded.
        self.last_observation = []

    def update_policy(self):
        """Recompute the policy with ``ucrl.get_eva_policy`` from current estimates."""
        confidence_param = self.initial_confidence_param/self.exploration.steps_before_episode

        self.policy = ucrl.get_eva_policy(self.parameter_estimator, self.model_bounds, confidence_param, self.exploration.steps_before_episode)

    def state_conversion(self, state, transition_type):
        """Flatten ``(state, transition_type)`` into a single integer index.

        The index is ``state * n_transitions + transition_type + n_classes[1]``.
        """
        return (state*self.model_bounds.n_transitions)+(transition_type+self.model_bounds.n_classes[1])

    def act(self, state, transition_type):
        """Record any pending step, then decide on the offered transition.

        Type-0 transitions are always accepted; otherwise the transition is
        accepted when the policy entry for the flattened index equals 1.
        """
        new_state = self.state_conversion(state, transition_type)
        if self.last_observation:
            self.analyze_transition(new_state)
            # Consume the pending step so a second act() call without an
            # intervening observe() cannot record the same step twice.
            self.last_observation = []

        if transition_type == 0:
            return True

        return self.policy[new_state] == 1

    def analyze_transition(self, new_state):
        """Feed the pending step, now that its successor is known, to the learner."""
        state, action, time_elapsed, reward = self.last_observation
        self.parameter_estimator.observe(state, new_state, action, time_elapsed, reward)
        if self.exploration.observe(state, action):
            self.exploration.new_episode()
            self.update_policy()

    def observe(self, state, next_state, action, transition_type, reward, time_elapsed):
        """Store the step; it is recorded at the next ``act`` call."""
        # we don't *really* know the next state here
        conv_state = self.state_conversion(state, transition_type)
        self.last_observation = [conv_state, action, time_elapsed, reward]

    def print(self):
        """Print the current policy and the estimator's beliefs.

        NOTE(review): the return value of ``parameter_estimator.print`` is
        itself printed; if that method returns nothing this prints ``None``.
        """
        confidence_param = self.initial_confidence_param/self.exploration.steps_before_episode
        print("-------------------------------------------")
        print("Policy:")
        print(self.policy)
        print("-------------------------------------------")
        print("Beliefs:")
        print(self.parameter_estimator.print(confidence_param))

class UCRLAgent(Agent):
    """UCRL-style agent whose actions are (limiting customer, limiting server) pairs.

    The policy maps each state to an index into ``expanded_as``. An arriving
    customer or server is accepted when its reward in the current state is at
    least the reward of the limiting type selected by that action.
    """

    def __init__(self, model_bounds, state_rewards):
        """Create the learner components, an initial policy and the action table."""
        super().__init__()

        self.state_rewards = state_rewards
        self.model_bounds = model_bounds
        bounds = self.model_bounds
        self.U = bounds.customer_ub + bounds.abandonment_ub + bounds.server_ub
        self.initial_confidence_param = INITIAL_CONFIDENCE_PARAM

        self.exploration = ucrl.Exploration(model_bounds, model_bounds.n_states, model_bounds.n_actions)
        self.parameter_estimator = ucrl.ParameterEstimator(
            model_bounds,
            model_bounds.n_states,
            model_bounds.n_actions,
            self.U,
            self.state_rewards,
        )

        # Placeholder policy; update_policy() replaces it immediately.
        self.policy = [1] * self.model_bounds.n_states
        self.update_policy()

        self.last_observation = []

        # enumerate the action space: every pair of limiting types, where -1
        # is the extra value that makes the accept_* methods reject everything
        n_customer_classes = self.model_bounds.n_classes[0]
        n_server_classes = self.model_bounds.n_classes[1]
        self.expanded_as = [
            [customer_type, server_type]
            for customer_type in range(-1, n_customer_classes)
            for server_type in range(-1, n_server_classes)
        ]

    def update_policy(self):
        """Recompute the policy with ``ucrl.get_eva_policy`` from current estimates."""
        steps = self.exploration.steps_before_episode
        confidence_param = self.initial_confidence_param/steps

        self.policy = ucrl.get_eva_policy(
            self.parameter_estimator,
            self.model_bounds,
            confidence_param,
            self.exploration.steps_before_episode,
        )

    def action_conversion(self, action):
        """Return the ``[limiting_customer, limiting_server]`` pair for an action index."""
        return self.expanded_as[action]

    def accept_customer(self, state, action, customer_type):
        """Decide whether a customer of ``customer_type`` is accepted under ``action``."""
        limiting_customer, _ = self.action_conversion(action)

        if limiting_customer == -1:
            return False
        rewards_here = self.state_rewards.customer_rewards[state]
        return rewards_here[customer_type] >= rewards_here[limiting_customer]

    def accept_server(self, state, action, server_type):
        """Decide whether a server of ``server_type`` is accepted under ``action``."""
        _, limiting_server = self.action_conversion(action)

        if limiting_server == -1:
            return False
        rewards_here = self.state_rewards.server_rewards[state]
        return rewards_here[server_type] >= rewards_here[limiting_server]

    def act(self, state, transition_type):
        """Apply the policy's action for ``state`` to the offered transition.

        Type 0 is always accepted. Negative types are servers and positive
        types are customers, each mapped to a zero-based class index.
        """
        action = self.policy[state]
        if transition_type == 0:
            return True
        if transition_type < 0:
            return self.accept_server(state, action, (-transition_type)-1)
        return self.accept_customer(state, action, transition_type-1)

    def observe(self, state, next_state, accepted, transition_type, reward, time_elapsed):
        """Record the transition and refresh the policy when a new episode starts."""
        action = self.policy[state]
        self.parameter_estimator.observe(state, next_state, action, time_elapsed, reward)
        episode_finished = self.exploration.observe(state, action)
        if episode_finished:
            self.exploration.new_episode()
            self.update_policy()

    def print(self):
        """Print the current policy and the estimator's beliefs."""
        confidence_param = self.initial_confidence_param/self.exploration.steps_before_episode
        separator = "-------------------------------------------"
        print(separator)
        print("Policy:")
        print(self.policy)
        print(separator)
        print("Beliefs:")
        print(self.parameter_estimator.print(confidence_param))


class LearnerAgent:
    """Adapter that lets an external ``learner`` object drive the decisions.

    The learner picks an action index per state via ``play``; that index is
    decoded through ``expanded_as`` into limiting customer/server types.
    Feedback is passed back through ``learner.update``.

    NOTE(review): unlike the other agents in this file, this class does not
    derive from ``Agent`` -- confirm whether that is intentional.
    """

    def __init__(self, model_bounds, state_rewards, learner):
        """Store the collaborators and enumerate the action table."""
        super().__init__()

        self.state_rewards = state_rewards
        self.model_bounds = model_bounds
        bounds = self.model_bounds
        self.U = bounds.customer_ub + bounds.abandonment_ub + bounds.server_ub
        self.initial_confidence_param = INITIAL_CONFIDENCE_PARAM

        self.exploration = ucrl.Exploration(model_bounds, model_bounds.n_states, model_bounds.n_actions)
        self.parameter_estimator = ucrl.ParameterEstimator(
            model_bounds,
            model_bounds.n_states,
            model_bounds.n_actions,
            self.U,
            self.state_rewards,
        )

        self.learner = learner
        # Rewards are divided by this constant before reaching the learner.
        self.reward_norm = 2

        self.last_action = 0

        # enumerate the action space: every pair of limiting types, where -1
        # is the extra value that makes the accept_* methods reject everything
        n_customer_classes = self.model_bounds.n_classes[0]
        n_server_classes = self.model_bounds.n_classes[1]
        self.expanded_as = [
            [customer_type, server_type]
            for customer_type in range(-1, n_customer_classes)
            for server_type in range(-1, n_server_classes)
        ]

    def update_policy(self):
        """Compute ``self.policy`` with ``ucrl.get_eva_policy``.

        Nothing in this class calls this method or reads ``self.policy``.
        """
        steps = self.exploration.steps_before_episode
        confidence_param = self.initial_confidence_param/steps

        self.policy = ucrl.get_eva_policy(
            self.parameter_estimator,
            self.model_bounds,
            confidence_param,
            self.exploration.steps_before_episode,
        )

    def action_conversion(self, action):
        """Return the ``[limiting_customer, limiting_server]`` pair for an action index."""
        return self.expanded_as[action]

    def accept_customer(self, state, action, customer_type):
        """Decide whether a customer of ``customer_type`` is accepted under ``action``."""
        limiting_customer, _ = self.action_conversion(action)

        if limiting_customer == -1:
            return False
        rewards_here = self.state_rewards.customer_rewards[state]
        return rewards_here[customer_type] >= rewards_here[limiting_customer]

    def accept_server(self, state, action, server_type):
        """Decide whether a server of ``server_type`` is accepted under ``action``."""
        _, limiting_server = self.action_conversion(action)

        if limiting_server == -1:
            return False
        rewards_here = self.state_rewards.server_rewards[state]
        return rewards_here[server_type] >= rewards_here[limiting_server]

    def act(self, state, transition_type):
        """Ask the learner for an action, remember it, and apply it.

        Type 0 is always accepted. Negative types are servers and positive
        types are customers, each mapped to a zero-based class index.
        """
        chosen = self.learner.play(state)
        self.last_action = chosen
        if transition_type == 0:
            return True
        if transition_type < 0:
            return self.accept_server(state, chosen, (-transition_type)-1)
        return self.accept_customer(state, chosen, transition_type-1)

    def observe(self, state, next_state, accepted, transition_type, reward, time_elapsed):
        """Turn one timed transition into one or more learner updates.

        The elapsed time is split into ``round(time_elapsed / U) - 1`` self
        transitions (never fewer than zero) plus the real transition to
        ``next_state``. ``reward * time_elapsed`` is shared equally among
        those steps and then divided by ``reward_norm``.

        NOTE(review): the step count uses ``time_elapsed / U`` -- confirm
        this matches the intended discretisation.
        """
        played = self.last_action

        fictitious_steps = max(round(time_elapsed/self.U) - 1, 0)
        total_steps = 1 + fictitious_steps

        per_step_reward = (reward*time_elapsed)/total_steps
        per_step_reward = per_step_reward/self.reward_norm

        episode_over = False
        for _ in range(fictitious_steps):
            self.learner.update(state, played, per_step_reward, state)
            # Always query the exploration tracker, even once an episode
            # boundary has already been seen.
            if self.exploration.observe(state, played):
                episode_over = True

        self.learner.update(state, played, per_step_reward, next_state)
        final_trigger = self.exploration.observe(state, played)
        if final_trigger or episode_over:
            self.exploration.new_episode()
            self.learner.new_episode()
