from typing import Any, Dict, Hashable, List, Sequence, Tuple

from coba.learners import VowpalMediator

from memory import MemoryModel

# Module-level constant. Not referenced by any code visible in this module;
# presumably read by whatever experiment script imports it — confirm.
logn: int = 500

class MemoryToCBLearner:
    """Epsilon-greedy contextual-bandit learner whose action scores come from a memory model.

    For every candidate action the memory is queried with the key
    ``{'x': context, 'a': action}``. The action(s) with the highest remembered
    value share ``1 - epsilon`` of the probability mass, and ``epsilon`` is
    spread uniformly over all actions.
    """

    def __init__(self, epsilon: float, mem: "MemoryModel") -> None:
        """Create the learner.

        Args:
            epsilon: Exploration rate; must lie in ``[0, 1]``.
            mem: Memory model providing ``predict``, ``learn`` and ``params``.
        """

        assert 0 <= epsilon and epsilon <= 1

        self._epsilon = epsilon
        self._mem     = mem

    @property
    def params(self) -> Dict[str,Any]:
        """Describe this learner: its family, epsilon, and the memory model's own params."""
        return { 'family': 'MemToCB','e':self._epsilon, **self._mem.params }

    def predict(self, context: Hashable, actions: Sequence[Hashable]) -> Tuple[List[float], int]:
        """Return action probabilities together with the info ``learn`` needs.

        Args:
            context: The observed context.
            actions: The candidate actions.

        Returns:
            ``(probs, n_actions)`` where ``probs[i]`` is the probability of
            playing ``actions[i]`` and ``n_actions == len(actions)``.
        """

        n_actions = len(actions)

        # A missing (falsy, e.g. None) memory counts as a value of 0.
        values = [ self._mem.predict({'x': context, 'a': a}) or 0 for a in actions ]

        # Collect the *positions* of the maximal values. Tracking indices rather
        # than action values keeps this O(n), and only positions that actually
        # scored highest receive greedy mass. An equality-based membership test
        # would also credit a lower-scoring action that merely compares equal to
        # a greedy one (e.g. 1 vs True), pushing the total probability above 1.
        greedy_r = -float('inf')
        greedy_i: List[int] = []

        for i, value in enumerate(values):
            if value > greedy_r:
                greedy_r = value
                greedy_i = [i]
            elif value == greedy_r:
                greedy_i.append(i)

        # NaN compares false against everything. If every value is NaN, no
        # position is greedy, so treat all actions as greedy (uniform
        # distribution) rather than dividing by zero.
        if not greedy_i:
            greedy_i = list(range(n_actions))

        min_p  = self._epsilon / n_actions
        grd_p  = (1-self._epsilon)/len(greedy_i)
        greedy = set(greedy_i)

        return [ grd_p+min_p if i in greedy else min_p for i in range(n_actions) ], n_actions

    def learn(self, context: Hashable, action: Hashable, reward: float, probability: float, predict_info: Any) -> None:
        """Store the observed reward in memory with an importance weight.

        Args:
            context: Context in which the action was played.
            action: The action that was played.
            reward: The observed reward.
            probability: Probability with which ``action`` was played.
            predict_info: The ``n_actions`` value returned by ``predict``.
        """
        n_actions = predict_info
        # Inverse-propensity weight scaled by the number of actions, so an action
        # played with uniform probability 1/n_actions gets weight 1.
        self._mem.learn({'x': context, 'a': action}, value=reward, weight=1/(n_actions*probability))

class StackedMemLearner:
    """Stack a memory model underneath a Vowpal Wabbit ``--cb_explore_adf`` learner.

    For each action, the memory's prediction for ``{'x': context, 'a': action}``
    is attached as an ``'m'`` feature of that action's example, and VW produces
    the action probabilities. Every observed reward is also fed back into the
    memory model.
    """

    def __init__(self, epsilon: float, mem: "MemoryModel", X:str, coin:bool, constant:bool) -> None:
        """Create the learner and its underlying VW instance.

        Args:
            epsilon: Exploration rate passed to VW's ``--epsilon``; must lie in ``[0, 1]``.
            mem: Memory model providing ``predict``, ``learn`` and ``params``.
            X: Interaction set, either ``'xa'`` or ``'xxa'`` (adds ``--interactions xxa``).
            coin: If true, add ``--coin`` to the VW arguments.
            constant: If false, add ``--noconstant`` to the VW arguments.

        Raises:
            ValueError: If ``X`` is not ``'xa'`` or ``'xxa'``.
        """

        assert 0 <= epsilon and epsilon <= 1

        self._epsilon = epsilon
        self._mem     = mem
        self._args    = (X, coin, constant)

        # Build (and validate) the arguments before touching VW at all.
        args = self._vw_args(epsilon, X, coin, constant)

        # NOTE(review): 4 is the label-type code handed to VowpalMediator;
        # presumably VW's contextual-bandit label type — confirm against coba.
        self._vw = VowpalMediator().init_learner(args, 4)

    @staticmethod
    def _vw_args(epsilon: float, X: str, coin: bool, constant: bool) -> str:
        """Build the VW command-line string for the given configuration."""

        if X == 'xa':
            interactions = "--interactions xa"
        elif X == 'xxa':
            interactions = "--interactions xa --interactions xxa"
        else:
            raise ValueError(f"Unsupported interaction set X={X!r}; expected 'xa' or 'xxa'.")

        args = f"--quiet --cb_explore_adf --epsilon {epsilon} --ignore_linear x {interactions} --random_seed 1"

        if coin:
            args += ' --coin'

        if not constant:
            args += " --noconstant"

        return args

    @property
    def params(self) -> Dict[str,Any]:
        """Describe this learner: family, epsilon, the memory's params, and (X, coin, constant)."""
        return { 'family': 'Stacked', 'e': self._epsilon, **self._mem.params, "other": self._args }

    def predict(self, context: Hashable, actions: Sequence[Hashable]) -> Tuple[Any, Tuple[Sequence[Hashable], List[Dict[str,Any]]]]:
        """Return VW's action probabilities together with the info ``learn`` needs.

        Returns:
            ``(probs, (actions, adfs))``: ``probs`` as returned by the VW
            mediator, plus the actions and the per-action feature dicts so that
            ``learn`` can rebuild identical examples.
        """

        memories = [ self._mem.predict({'x':context, 'a':a}) for a in actions ]
        # NOTE(review): raw memory values (possibly None for unseen keys) are
        # passed to VW unchanged — confirm the mediator handles that.
        adfs     = [ {'a':a, 'm':m } for a,m in zip(actions,memories) ]
        probs    = self._vw.predict(self._vw.make_examples({'x': context}, adfs, None))

        return probs, (actions,adfs)

    def learn(self, context: Hashable, action: Hashable, reward: float, probability: float, predict_info: Any) -> None:
        """Update both the memory model and VW with the observed reward.

        Args:
            context: Context in which the action was played.
            action: The action that was played.
            reward: The observed reward.
            probability: Probability with which ``action`` was played.
            predict_info: The ``(actions, adfs)`` pair returned by ``predict``.
        """

        actions,adfs = predict_info
        n_actions    = len(actions)
        labels       = self._labels(actions, action, reward, probability)

        # Inverse-propensity weight scaled by the number of actions, so an action
        # played with uniform probability 1/n_actions gets weight 1.
        self._mem.learn({'x':context,'a':action}, value=reward, weight=1/(n_actions*probability))
        self._vw.learn(self._vw.make_examples({'x':context}, adfs, labels))

    def _labels(self,actions,action,reward:float,prob:float) -> List[Any]:
        """Build per-action VW labels: only the played action is labelled.

        The label string is ``'{index+1}:{-reward}:{prob}'`` (values rounded to
        5 places). The reward is negated because VW treats it as a cost to
        minimise. All other actions get ``None``.
        """
        return [ f"{i+1}:{round(-reward,5)}:{round(prob,5)}" if a == action else None for i,a in enumerate(actions)]

    def __reduce__(self):
        # Pickle via the constructor arguments. On unpickle, __init__ builds a
        # fresh VW learner, so learned VW weights are not carried over; ``mem``
        # is pickled along with the learner.
        return (type(self), (self._epsilon, self._mem, *self._args))
