# implementation of UCRL2

import math
import copy
import random

import model

import numpy as np
from scipy.stats import chi2

import policy

class ParameterEstimator:
    """Empirical transition and reward estimates with confidence radii.

    For every (state, action) pair the estimator keeps the uniformized
    transition counts and the list of per-step rewards seen so far, and
    derives point estimates and confidence-interval widths from them.
    """

    def __init__(self, model_bounds, n_states, n_actions, U, state_rewards):
        """Create an estimator that has seen no observations.

        model_bounds: stored as-is; not read by this class.
        n_states, n_actions: sizes of the state and action index ranges.
        U: length of one uniformized step; elapsed time is divided by it.
        state_rewards: object whose ``holding_rewards[state]`` is the holding
            reward per unit time of ``state``.
        """
        self.model_bounds = model_bounds
        self.U = U
        self.state_rewards = state_rewards

        self.n_states = n_states
        self.n_actions = n_actions

        # change_counts[s][a][s'] = number of uniformized steps s -> s' under a
        self.change_counts = [[[0 for k in range(self.n_states)] for j in range(self.n_actions)] for i in range(self.n_states)]
        # rewards[s][a] = one entry per uniformized step taken from (s, a)
        self.rewards = [[[] for j in range(self.n_actions)] for i in range(self.n_states)]

        # Not read anywhere inside this class.
        self.max_reward = 2
        self.min_reward = -2

    def observe(self, state, next_state, action, time_elapsed, reward):
        """Record one transition, split into uniformized steps of length U.

        The elapsed time becomes ``max(round(time_elapsed / U), 1)`` steps:
        all but the last are counted as self-transitions, the last one as the
        move to ``next_state``.
        """
        # uniformize
        n_steps = max(round(time_elapsed / self.U), 1)
        self.change_counts[state][action][state] += n_steps - 1
        self.change_counts[state][action][next_state] += 1

        # observe holding cost for each self-transition
        holding_list = [self.U * self.state_rewards.holding_rewards[state] for i in range(n_steps - 1)]
        self.rewards[state][action] += holding_list

        # Whatever part of the reward the self-transitions did not account
        # for is attributed to the final step.
        self.rewards[state][action].append(reward - sum(holding_list))

    def change_prob_estimate(self, state, action):
        """Return the empirical next-state distribution of (state, action).

        Falls back to the uniform distribution when the pair is unvisited.
        """
        ct = sum(self.change_counts[state][action])

        if ct == 0:
            return [(1 / self.n_states) for x in range(self.n_states)]

        return [x / ct for x in self.change_counts[state][action]]

    def change_prob_epsilon(self, state, action, confidence_param):
        """Return the confidence radius used for the transition estimate.

        Unvisited pairs get 2, which get_eva_next_u turns into a full unit of
        movable probability mass (it adds epsilon / 2).
        """
        ct = sum(self.change_counts[state][action])
        if ct == 0:
            return 2

        inner_term = ((14 * self.n_states) / ct) * math.log(2 * (self.n_actions) / confidence_param)
        return math.sqrt(inner_term)

    def change_prob_epsilon_kl(self, state, action, confidence_param, n_steps):
        """Return the KL-variant confidence radius for the transition estimate.

        Returns ``inf`` (no constraint) when the pair is unvisited, and also
        when ``n_steps <= 1``: the formula takes ``log(n_steps)`` inside
        another logarithm and as a divisor, so it is undefined there.
        """
        ct = sum(self.change_counts[state][action])
        if ct == 0 or n_steps <= 1:
            return float("inf")

        log_steps = math.log(n_steps)
        B = math.log(2 * math.e * (self.n_states ** 2) * self.n_actions * log_steps / confidence_param)
        invt_log = 1 / log_steps
        coef = self.n_states * (B + math.log(B + invt_log) * (1 + (1 / (B + invt_log))))

        # NOTE(review): ct is only used for the zero check above; the radius
        # is divided by the total step count, not by the visit count of this
        # pair. Confirm this is intended.
        return coef / n_steps

    def reward_estimate(self, state, action):
        """Return the mean observed step reward, or 1 if nothing was observed."""
        ct = len(self.rewards[state][action])
        if ct == 0:
            return 1
        total = sum(self.rewards[state][action])

        return total / ct

    def reward_epsilon(self, state, action, confidence_param):
        """Return the confidence radius of the reward estimate.

        The sample count is floored at 1 so unvisited pairs get a finite radius.
        """
        ct = len(self.rewards[state][action])

        inner_term = (7 / (2 * max(ct, 1))) * math.log((2 * self.n_actions * self.n_states) / confidence_param)
        return 2 * math.sqrt(inner_term)  # added in the 2 for reward scaling

    def reward_epsilon_kl(self, state, action, confidence_param):
        """Return the reward radius for the KL variant.

        Currently computed with the same formula as reward_epsilon.
        """
        ct = len(self.rewards[state][action])

        inner_term = (7 / (2 * max(ct, 1))) * math.log((2 * self.n_actions * self.n_states) / confidence_param)
        return 2 * math.sqrt(inner_term)  # added in the 2 for reward scaling

    def reward_ub(self, state, action, confidence_param):
        """Return the optimistic reward: estimate plus reward_epsilon."""
        return self.reward_estimate(state, action) + self.reward_epsilon(state, action, confidence_param)

    def reward_ub_kl(self, state, action, confidence_param):
        """Return the optimistic reward for the KL variant.

        Counterpart of reward_ub built on reward_epsilon_kl; this is the
        method get_klucr_policy calls on the estimator.
        """
        return self.reward_estimate(state, action) + self.reward_epsilon_kl(state, action, confidence_param)

    def print(self, confidence_param):
        """Print reward and transition estimates and radii for every pair."""
        print("state/action rewards (baseline)")
        print([[self.reward_estimate(state, action) for action in range(self.n_actions)] for state in range(self.n_states)])
        print("state/action reward epsilon")
        print([[self.reward_epsilon(state, action, confidence_param) for action in range(self.n_actions)] for state in range(self.n_states)])
        print("state/action transitions (baseline)")
        print([[self.change_prob_estimate(state, action) for action in range(self.n_actions)] for state in range(self.n_states)])
        print("state/action transition epsilon")
        print([[self.change_prob_epsilon(state, action, confidence_param) for action in range(self.n_actions)] for state in range(self.n_states)])

class Exploration:
    """Visit bookkeeping that decides when an episode should end.

    Tracks how often each (state, action) pair was taken overall and within
    the current episode.
    """

    def __init__(self, model_bounds, n_states, n_actions):
        """Start with all counters at zero and no finished episode."""
        self.model_bounds = model_bounds
        self.n_states = n_states
        self.n_actions = n_actions
        # Lifetime and per-episode visit tables, indexed [state][action].
        self.sa_visit_counts = [[0] * self.n_actions for _ in range(self.n_states)]
        self.sa_visit_counts_in_episode = [[0] * self.n_actions for _ in range(self.n_states)]
        self.steps_before_episode = 1
        self.n_episodes = 0

    def observe(self, state, action):
        """Count one visit of (state, action).

        Returns True once the visits made during the current episode account
        for at least half of all visits to that pair.
        """
        lifetime_row = self.sa_visit_counts[state]
        episode_row = self.sa_visit_counts_in_episode[state]
        lifetime_row[action] += 1
        episode_row[action] += 1

        return episode_row[action] * 2 >= lifetime_row[action]

    def new_episode(self):
        """Reset the per-episode table and snapshot the total step count."""
        self.sa_visit_counts_in_episode = [[0] * self.n_actions for _ in range(self.n_states)]

        self.steps_before_episode = sum(map(sum, self.sa_visit_counts))
        self.n_episodes += 1


def get_eva_next_u(probs, epsilon, u):
    """Return the optimistic expected next-state value.

    Starting from the estimated distribution ``probs``, up to ``epsilon / 2``
    probability mass is moved onto the state with the largest value in ``u``
    and the same amount is removed from the lowest-valued states first. The
    result is the expectation of ``u`` under that shifted distribution.

    probs: estimated next-state probabilities; not modified.
    epsilon: confidence radius; may be ``float("inf")``.
    u: current value of each state, same length as ``probs``.

    Raises IndexError when ``u`` is empty.
    """
    # State indices from least to most valuable; ties are broken by index.
    ascending = sorted(range(len(u)), key=lambda i: (u[i], i))
    best = ascending[-1]

    optimistic = list(probs)
    # Cap the boosted entry at 1. Without the cap an infinite radius turns
    # the arithmetic below into inf - inf = nan, which zeroes every
    # probability and makes the function return 0 instead of max(u).
    optimistic[best] = min(1.0, optimistic[best] + epsilon / 2)
    total = sum(optimistic)

    # Remove the surplus mass, starting with the least valuable states.
    for state in ascending:
        if total <= 1:
            break
        others = total - optimistic[state]
        trimmed = max(0, 1 - others)
        total -= optimistic[state] - trimmed
        optimistic[state] = trimmed

    return sum(p * value for p, value in zip(optimistic, u))

def get_eva_policy(parameter_estimator, model_bounds, confidence_param, n_steps):
    """Run optimistic value iteration and return the greedy action per state.

    parameter_estimator: supplies n_states, n_actions, reward_ub,
        change_prob_estimate and change_prob_epsilon.
    model_bounds: accepted but not used here.
    confidence_param: forwarded to the estimator's confidence bounds.
    n_steps: iteration stops once the spread of the value change between two
        sweeps drops below ``n_steps ** -0.5``.

    Returns a list holding one action index for each state.
    """
    est = parameter_estimator
    state_ids = range(est.n_states)
    action_ids = range(est.n_actions)

    # Optimistic reward, estimated transition row and radius per (s, a).
    optimistic_reward = [[est.reward_ub(s, a, confidence_param) for a in action_ids] for s in state_ids]
    transition_hat = [[est.change_prob_estimate(s, a) for a in action_ids] for s in state_ids]
    transition_radius = [[est.change_prob_epsilon(s, a, confidence_param) for a in action_ids] for s in state_ids]

    utilities = [0] * len(state_ids)
    greedy_action = [0] * len(state_ids)

    while True:
        updated = [float("-inf")] * len(state_ids)

        for s in state_ids:
            for a in action_ids:
                row = copy.deepcopy(transition_hat[s][a])
                candidate = optimistic_reward[s][a] + get_eva_next_u(row, transition_radius[s][a], utilities)
                # Strict comparison: the lowest-indexed action wins ties.
                if candidate > updated[s]:
                    greedy_action[s] = a
                    updated[s] = candidate

        # Converged when every state's value moved by nearly the same amount.
        deltas = [new - old for new, old in zip(updated, utilities)]
        largest, smallest = max(deltas), min(deltas)

        utilities = updated

        if (largest - smallest) < math.pow(n_steps, -0.5):
            return greedy_action


def get_klucr_policy(parameter_estimator, model_bounds, confidence_param, n_steps):
    """Run optimistic value iteration with the KL-variant confidence radii.

    parameter_estimator: supplies n_states, n_actions, reward_estimate,
        reward_epsilon_kl, change_prob_estimate and change_prob_epsilon_kl.
    model_bounds: accepted but not used here.
    confidence_param: forwarded to the estimator's confidence bounds.
    n_steps: forwarded to change_prob_epsilon_kl; iteration also stops once
        the spread of the value change between two sweeps drops below
        ``n_steps ** -0.5``.

    Returns a list holding one action index for each state.
    """
    n_states = parameter_estimator.n_states
    n_actions = parameter_estimator.n_actions

    # Optimistic reward = empirical mean + KL-variant radius, the same
    # construction reward_ub uses with reward_epsilon.
    rewards = [
        [
            parameter_estimator.reward_estimate(state, action)
            + parameter_estimator.reward_epsilon_kl(state, action, confidence_param)
            for action in range(n_actions)
        ]
        for state in range(n_states)
    ]
    prob_estimates = [[parameter_estimator.change_prob_estimate(state, action) for action in range(n_actions)] for state in range(n_states)]

    # get_eva_next_u moves epsilon / 2 of probability mass, so any radius
    # above 2 already moves everything. Capping there keeps the infinite
    # radius reported for unvisited pairs out of the arithmetic.
    prob_epsilon = [
        [
            min(parameter_estimator.change_prob_epsilon_kl(state, action, confidence_param, n_steps), 2)
            for action in range(n_actions)
        ]
        for state in range(n_states)
    ]

    values = [0 for x in range(n_states)]
    state_action_mapping = [0 for x in range(n_states)]

    while True:
        new_values = [float("-inf") for x in range(n_states)]

        for state in range(n_states):
            for action in range(n_actions):
                adjacent_probs = list(prob_estimates[state][action])
                next_u = get_eva_next_u(adjacent_probs, prob_epsilon[state][action], values)
                u_candidate = rewards[state][action] + next_u
                # Strict comparison: the lowest-indexed action wins ties.
                if u_candidate > new_values[state]:
                    state_action_mapping[state] = action
                    new_values[state] = u_candidate

        # check for convergence and update values
        changes = [x - y for x, y in zip(new_values, values)]
        max_change = max(changes)
        min_change = min(changes)

        values = new_values

        if (max_change - min_change) < math.pow(n_steps, -0.5):
            break

    return state_action_mapping
