from collections import defaultdict
import typing
import numpy as np

from core.environment import Environment
if typing.TYPE_CHECKING:
    from typing import Any, List, Tuple, Dict

class AlgorithmUCB:
    """
    Implementation of the UCB algorithm for the MDP.
    Runs in an unconstrained environment.

    Per-(state, action) statistics -- visit counts and running reward means --
    are kept in dicts keyed by the pairs of ``environment.sa_ind``. Each call to
    :meth:`round` builds a deterministic policy that picks, for every state it
    iterates over, the action with the highest UCB index. It then plays that
    policy in the environment and folds the observed path/rewards back into the
    statistics.
    """

    def __init__(
        self, environment: "Environment", time_horizon: int, delta: float, c: float
    ):
        """
        Args:
            environment: object exposing ``n_states``, ``n_actions``, ``sa_ind``
                (mapping ``(x, a) -> index`` into the reward vector), ``layers``,
                ``actions`` (mapping ``x -> iterable of actions``) and
                ``play_policy(policy, timestep)``.
            time_horizon: stored as ``self.horizon``; not read elsewhere in this
                class.
            delta: stored as ``self.delta``; not read elsewhere in this class.
            c: exploration coefficient multiplying the UCB confidence bonus.
        """
        self.environment = environment
        self.horizon = time_horizon
        self.delta = delta
        self.c = c
        self.policy, self.empirical_rewards = self.initialize_policy_rewards()
        self._sa_counter = self.initialize_counters()
        # Number of rounds whose first step took action 1 / action 2. These are
        # used as the time index of the UCB bonus for states 1 / 2 (and above).
        self.first_ucb_timestep = 0
        self.second_ucb_timestep = 0

    def initialize_policy_rewards(self):
        """Initialize the policy and the empirical rewards to 0 for every (x, a) in ``sa_ind``.

        Returns:
            ``(policy, empirical_rewards)``, both also stored on ``self``.
        """
        zeros = np.zeros((self.environment.n_states, self.environment.n_actions))
        self.policy = dict(zip(self.environment.sa_ind.keys(), zeros.flatten()))
        self.empirical_rewards = dict(
            zip(self.environment.sa_ind.keys(), zeros.flatten())
        )
        return self.policy, self.empirical_rewards

    def initialize_counters(self):
        """Initialize the visit counter of every pair (x, a) in ``sa_ind`` to 0.

        Returns:
            The counter dict, also stored as ``self._sa_counter``.
        """
        self._sa_counter = dict.fromkeys(self.environment.sa_ind, 0)
        return self._sa_counter

    def _update_counters(self, path):
        """Add visit counts to ``self._sa_counter``.

        Args:
            path: mapping ``(x, a, x_p) -> number of times that transition was
                taken``; counts are accumulated on the ``(x, a)`` pair.

        Returns:
            The updated counter dict.
        """
        for (x, a, _x_p), visits in path.items():
            self._sa_counter[(x, a)] += visits
        return self._sa_counter

    def round(self, timestep: int):
        """Play a round of the UCB algorithm.

        Args:
            timestep: index of the current round. It is passed to
                ``environment.play_policy`` and used as the time index of the
                UCB bonus for state 0.

        Returns:
            ``(policy, rewards, ucb_values)``, where:
            - ``policy`` is the updated ``{(x, a): 0 or 1}`` dict;
            - ``rewards`` is the reward vector returned by ``play_policy``;
            - ``ucb_values`` is the list of UCB indices in ``sa_ind`` order.
        """
        ucb_values = self._compute_ucb_values(timestep)
        self._select_greedy_policy(ucb_values)

        # Observe rewards, path, constraints from environment
        path, rewards, _ = self.environment.play_policy(self.policy, timestep)

        self._update_statistics(path, rewards)

        # Update first and second ucb timestep from the first step's action.
        # An empty path has no first step, so there is nothing to count.
        if path:
            first_action = path[0][1]
            self.first_ucb_timestep += 1 if first_action == 1 else 0
            self.second_ucb_timestep += 1 if first_action == 2 else 0

        return self.policy, rewards, list(ucb_values.values())

    def _compute_ucb_values(self, timestep):
        """Return ``{(x, a): UCB index}``; never-visited pairs get ``np.inf``."""
        zeros = np.zeros((self.environment.n_states, self.environment.n_actions))
        ucb_values = dict(zip(self.environment.sa_ind.keys(), zeros.flatten()))

        # Iterates states 0 .. layers[-1][-1] - 1, i.e. the last state of the
        # last layer is skipped (presumably the terminal state -- TODO confirm).
        for x in range(self.environment.layers[-1][-1]):
            # NOTE(review): the bonus clock is hard-coded per state: state 0
            # uses the global round index, state 1 the action-1 counter, and
            # every other state the action-2 counter. This presumably assumes
            # action 1/2 from state 0 leads to state 1/2 -- TODO confirm.
            if x == 0:
                adjusted_timestep = timestep
            elif x == 1:
                adjusted_timestep = self.first_ucb_timestep
            else:
                adjusted_timestep = self.second_ucb_timestep

            for a in self.environment.actions[x]:
                visits = self._sa_counter[(x, a)]
                if visits == 0:
                    # Force exploration of pairs that were never tried.
                    ucb_values[(x, a)] = np.inf
                else:
                    bonus = self.c * np.sqrt(
                        np.log(adjusted_timestep + 1) / (visits + 1)
                    )
                    ucb_values[(x, a)] = self.empirical_rewards[(x, a)] + bonus
        return ucb_values

    def _select_greedy_policy(self, ucb_values):
        """Set ``self.policy`` to 1 on each state's highest-UCB action and 0 elsewhere."""
        for pair in self.policy:
            self.policy[pair] = 0

        # Highest UCB value per state over every (x, a) pair in sa_ind. This is
        # computed in a single pass instead of rescanning sa_ind for each action.
        best_value = {}
        for x_s, a_s in self.environment.sa_ind:
            value = ucb_values[(x_s, a_s)]
            if x_s not in best_value or value > best_value[x_s]:
                best_value[x_s] = value

        for x in range(self.environment.layers[-1][-1]):
            for a in self.environment.actions[x]:
                if ucb_values[(x, a)] == best_value[x]:
                    self.policy[(x, a)] = 1
                    break  # if there is a tie, assign 1 just to the first action

    def _update_statistics(self, path, rewards):
        """Add the visits in ``path`` to the counters and fold ``rewards`` into the running means."""
        path_dict = defaultdict(int)
        for triple in path:
            path_dict[triple] += 1
        self._update_counters(path_dict)

        # Visits per (x, a) in this round (a pair could appear with several x_p).
        sa_visits = defaultdict(int)
        for x, a, _x_p in path:
            sa_visits[(x, a)] += 1

        # The counters already include this round's k visits, so `total` is the
        # new sample count n:  mean_n = mean_old + k * (r - mean_old) / n.
        # (Dividing by n + 1 here would bias the mean toward 0.)
        for (x, a), k in sa_visits.items():
            reward = rewards[self.environment.sa_ind[(x, a)]]
            total = self._sa_counter[(x, a)]
            mean = self.empirical_rewards[(x, a)]
            self.empirical_rewards[(x, a)] = mean + k * (reward - mean) / total
