from abc import ABC, abstractmethod
from typing import Dict, Any, Sequence

from coba.learners import VowpalMediator

# Alias for the value type a memory model stores and returns; deliberately
# unconstrained (``Any``) so concrete models can decide what a "value" is.
MemVal = Any

class MemoryModel(ABC):
    """Abstract interface for a learnable memory.

    Implementations must provide ``predict`` (look up a value for some
    features) and ``learn`` (store/update a value for some features with a
    weight). ``params`` describes the model's hyper-parameters.
    """

    @property
    def params(self) -> Dict[str,Any]:
        """Hyper-parameters describing this model (empty by default).

        This is a property so it is accessed the same way as in subclasses
        that override it with ``@property`` -- i.e. ``model.params``, not
        ``model.params()``. Previously the base defined a plain method, so a
        subclass without an override returned a bound method instead of a
        dict when read polymorphically.
        """
        return {}

    @abstractmethod
    def predict(self, features) -> MemVal:
        """Return the value the memory associates with ``features``."""
        ...

    @abstractmethod
    def learn(self, features, value, weight) -> None:
        """Update the memory so ``features`` maps toward ``value``, weighted by ``weight``."""
        ...

class EMT(MemoryModel):
    """Memory model backed by Vowpal Wabbit's ``--eigen_memory_tree`` reduction.

    The constructor arguments are translated into a VW command line and a
    learner is created through ``VowpalMediator``. The raw arguments are kept
    in ``self._args`` so that pickling (``__reduce__``) simply re-runs the
    constructor, and so that ``params`` can report them.
    """

    def __init__(self, split:int = 100, scorer:int=3, router:int=2, bound:int=-1, interactions: Sequence[str] | None = None, rng : int = 1337) -> None:
        """Build the VW learner.

        Args:
            split: forwarded to VW as ``--leaf``.
            scorer: forwarded to VW as ``--scorer``.
            router: forwarded to VW as ``--router``.
            bound: forwarded to VW as ``--tree``.
            interactions: each entry becomes one ``--interactions`` flag;
                ``None`` (the default) means no interactions.
            rng: forwarded to VW as ``--random_seed``.
        """

        # Use a fresh list per instance. A shared ``[]`` default would be
        # stored in ``_args`` and handed out via ``params['X']``, so a caller
        # mutating it would silently change the default for every later EMT.
        if interactions is None:
            interactions = []

        self._args = (split, scorer, router, bound, interactions, rng)

        vw_args = [
            "--eigen_memory_tree",
            f"--tree {bound}",
            f"--leaf {split}",
            f"--scorer {scorer}",
            f"--router {router}",
            "--min_prediction 0",
            "--max_prediction 3",
            "--coin",
            "--noconstant",
            "--power_t 0",
            "--loss_function squared",
            "-b 26",
            "--initial_weight 0",
            *[ f"--interactions {i}" for i in interactions ]
        ]

        init_args = f"{' '.join(vw_args)} --quiet --random_seed {rng}"
        # NOTE(review): passed straight through to VowpalMediator.init_learner;
        # presumably VW's multiclass label type (labels are "value weight") -- confirm.
        label_type = 2

        self._vw = VowpalMediator().init_learner(init_args, label_type)

    def __reduce__(self) -> str | tuple[Any, ...]:
        # Pickle by re-invoking the constructor with the original arguments;
        # the underlying VW learner state is not serialized.
        return (EMT, self._args)

    @property
    def params(self) -> Dict[str,Any]:
        """Constructor arguments keyed by name (``rng`` is not included).

        ``zip`` stops at the five keys, so the sixth element of ``_args``
        (``rng``) is dropped. ``'X'`` holds the interactions.
        """
        keys = ['split', 'scorer', 'router', 'bound', 'X']
        return { 'type':'EMT', **dict(zip(keys,self._args))}

    def predict(self, features) -> MemVal:
        """Return the VW prediction for an unlabeled example built from ``features``."""
        ex = self._vw.make_example(features, None)
        pr = self._vw.predict(ex)
        return pr

    def learn(self, features, value: MemVal, weight: float):
        """Train VW on ``features`` with the label string ``"<value> <weight>"``."""
        self._vw.learn(self._vw.make_example(features, f"{value} {weight}"))

class CMT(MemoryModel):
    """Memory model backed by Vowpal Wabbit's ``--memory_tree`` reduction.

    The constructor arguments are translated into a VW command line and a
    learner is created through ``VowpalMediator``. The raw arguments are kept
    in ``self._args`` so that pickling (``__reduce__``) simply re-runs the
    constructor, and so that ``params`` can report them.
    """

    def __init__(self, n_nodes:int=100, leaf_multiplier:int=15, dream_repeats:int=5, alpha:float=0.5, coin:bool = True, interactions: Sequence[str] | None = None, rng : int = 1337) -> None:
        """Build the VW learner.

        Args:
            n_nodes: forwarded to VW as ``--memory_tree``.
            leaf_multiplier: forwarded to VW as ``--leaf_example_multiplier``.
            dream_repeats: forwarded to VW as ``--dream_repeats``.
            alpha: forwarded to VW as ``--alpha``.
            coin: when true, ``--coin`` is appended to the VW arguments.
            interactions: each entry becomes one ``--interactions`` flag;
                ``None`` (the default) means no interactions.
            rng: forwarded to VW as ``--random_seed``.
        """

        # Use a fresh list per instance. A shared ``[]`` default would be
        # stored in ``_args`` and handed out via ``params['X']``, so a caller
        # mutating it would silently change the default for every later CMT.
        if interactions is None:
            interactions = []

        self._args = (n_nodes, leaf_multiplier, dream_repeats, alpha, coin, interactions, rng)

        vw_args = [
            f"--memory_tree {n_nodes}",
            "--learn_at_leaf",
            "--online 1",
            f"--leaf_example_multiplier {leaf_multiplier}",
            f"--dream_repeats {dream_repeats}",
            f"--alpha {alpha}",
            "--power_t 0",
            "-b 25",
            *[ f"--interactions {i}" for i in interactions ]
        ]

        if coin: vw_args.append("--coin")

        init_args = f"{' '.join(vw_args)} --quiet --random_seed {rng}"
        # NOTE(review): passed straight through to VowpalMediator.init_learner;
        # presumably VW's multiclass label type (labels are "value weight") -- confirm.
        label_type = 2

        self._vw = VowpalMediator().init_learner(init_args, label_type)

    def __reduce__(self) -> str | tuple[Any, ...]:
        # Pickle by re-invoking the constructor with the original arguments;
        # the underlying VW learner state is not serialized.
        return (CMT, self._args)

    @property
    def params(self) -> Dict[str,Any]:
        """Constructor arguments keyed by name (``rng`` is not included).

        ``zip`` stops at the six keys, so the seventh element of ``_args``
        (``rng``) is dropped. ``'X'`` holds the interactions.
        """
        keys = ['nodes','multiplier','dreams','alpha','coin','X']
        return { 'type':'CMT', **dict(zip(keys,self._args)) }

    def predict(self, features) -> MemVal:
        """Return the VW prediction for an unlabeled example built from ``features``."""
        ex = self._vw.make_example(features, None)
        pr = self._vw.predict(ex)
        return pr

    def learn(self, features, value: MemVal, weight: float):
        """Train VW on ``features`` with the label string ``"<value> <weight>"``."""
        self._vw.learn(self._vw.make_example(features, f"{value} {weight}"))
