from dataclasses import dataclass
from typing import List
import jax
import jax.numpy as jnp
import jax.random
from jax.tree_util import tree_map

@dataclass
class LinearTabularApproximation:
    """Linear function approximation over a single, fixed feature matrix.

    The approximated values are obtained as ``feature_mat @ params``.
    """

    # Feature matrix, left-multiplied onto ``params`` in ``apply_fn``.
    feature_mat: jnp.array

    def apply_fn(self, params):
        """Return the matrix product ``self.feature_mat @ params``."""
        features = self.feature_mat
        values = features @ params
        return values

@dataclass
class Identity:
    """Pass-through approximation: parameters are used directly as values."""

    def apply_fn(self, params):
        """Return ``params`` itself (the same object, not a copy)."""
        result = params
        return result


@dataclass
class LinearGameApproximation:
    """Per-player linear approximation: output ``i`` is ``feature_mats[i] @ params[i]``."""

    # One feature matrix per player, paired positionally with the entries of ``params``.
    feature_mats: List[jnp.array]

    def apply_fn(self, params: list[jnp.array]):
        """Apply each player's feature matrix to that player's parameters.

        Args:
            params: Sequence (list or tuple) with one parameter array per entry
                of ``self.feature_mats``.

        Returns:
            A tuple whose ``i``-th element is ``self.feature_mats[i] @ params[i]``.

        Raises:
            ValueError: If ``params`` and ``self.feature_mats`` differ in length.
        """
        # Pair positionally instead of via ``tree_map``. ``tree_map`` also
        # demands identical container types, so a tuple of params (the form
        # every ``apply_fn`` in this module returns) would be rejected against
        # a list of feature matrices.
        if len(params) != len(self.feature_mats):
            raise ValueError(
                f"expected {len(self.feature_mats)} parameter arrays, got {len(params)}"
            )
        return tuple(ft @ p for ft, p in zip(self.feature_mats, params))
    
@dataclass
class TinyMLP:
    """Per-player elementwise two-stage network: ``sigmoid(a1 * celu(a0 * x))``."""

    # One ``(a0, a1)`` pair of scales per player, paired positionally with ``params``.
    player_alphas: List[tuple]

    @staticmethod
    def mlp(x, alphas):
        """Return ``sigmoid(alphas[1] * celu(alphas[0] * x))`` computed elementwise.

        Declared ``@staticmethod`` so it is callable both as ``TinyMLP.mlp(...)``
        and on an instance without ``self`` being bound to ``x``.
        """
        x = jax.nn.celu(alphas[0] * x)
        x = jax.nn.sigmoid(alphas[1] * x)
        return x

    def apply_fn(self, params: list[jnp.array]):
        """Apply ``mlp`` to each player's parameters with that player's alphas.

        Args:
            params: Sequence (list or tuple) with one array per entry of
                ``self.player_alphas``.

        Returns:
            A tuple of per-player outputs, in the same order as ``params``.

        Raises:
            ValueError: If ``params`` and ``self.player_alphas`` differ in length.
        """
        # Positional pairing accepts lists and tuples alike. ``tree_map`` needed
        # the exact same container type for params and alphas.
        if len(params) != len(self.player_alphas):
            raise ValueError(
                f"expected {len(self.player_alphas)} parameter arrays, got {len(params)}"
            )
        return tuple(TinyMLP.mlp(p, a) for p, a in zip(params, self.player_alphas))
    
@dataclass
class SoftmaxMLP:
    """Per-player two-layer map: ``softmax(A1 @ celu(A0 @ x))``."""

    # One ``(A0, A1)`` pair of matrices per player, paired positionally with ``params``.
    player_alphas: List[tuple]

    @staticmethod
    def mlp(x, alphas):
        """Return ``softmax(alphas[1] @ celu(alphas[0] @ x))``.

        ``jax.nn.softmax`` is applied with its default axis (the last one).
        Declared ``@staticmethod`` so it is callable both as ``SoftmaxMLP.mlp(...)``
        and on an instance without ``self`` being bound to ``x``.
        """
        x = jax.nn.celu(alphas[0] @ x)
        x = jax.nn.softmax(alphas[1] @ x)
        return x

    def apply_fn(self, params):
        """Apply ``mlp`` to each player's parameters with that player's matrices.

        Args:
            params: Sequence with one array per entry of ``self.player_alphas``.

        Returns:
            A tuple of per-player outputs, in the same order as ``params``.

        Raises:
            ValueError: If ``params`` and ``self.player_alphas`` differ in length
                (plain ``zip`` would otherwise silently drop the extra entries).
        """
        if len(params) != len(self.player_alphas):
            raise ValueError(
                f"expected {len(self.player_alphas)} parameter arrays, got {len(params)}"
            )
        outputs = [SoftmaxMLP.mlp(p, alphas) for p, alphas in zip(params, self.player_alphas)]
        return tuple(outputs)
    
@dataclass
class SoftmaxGameApproximation:
    """Per-player softmax parameterisation: each parameter vector is mapped through softmax."""

    def apply_fn(self, params):
        """Return a tuple with ``jax.nn.softmax`` (default axis) applied to each entry of ``params``."""
        return tuple(map(jax.nn.softmax, params))

def random_split_like_tree(rng_key, target=None, treedef=None):
    """Split ``rng_key`` into one fresh key per leaf of a pytree.

    Args:
        rng_key: PRNG key to split.
        target: Pytree whose structure is mirrored. Ignored when ``treedef``
            is given.
        treedef: Optional precomputed tree structure to mirror instead of
            ``target``.

    Returns:
        A pytree with the same structure as ``target`` (or ``treedef``) whose
        leaves are independent keys produced by ``jax.random.split``.
    """
    # Use the stable ``jax.tree_util`` functions; the top-level aliases
    # ``jax.tree_structure`` / ``jax.tree_unflatten`` are removed in current JAX.
    if treedef is None:
        treedef = jax.tree_util.tree_structure(target)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, keys)


def tree_random_alphas_like(rng_key, shapes, minval=-1, maxval=1):
    """Draw uniform random arrays whose shapes are given by a pytree of shapes.

    A tuple made up only of ints (including ``()``) is treated as a single
    shape, not as a container. So ``[(2, 3), ()]`` yields a ``(2, 3)`` array
    and a scalar. Other leaves (e.g. a bare int) are passed to
    ``jax.random.uniform`` as the shape unchanged.

    Args:
        rng_key: PRNG key; split into one key per shape.
        shapes: Pytree whose leaves are shapes.
        minval: Lower bound of the uniform distribution (inclusive).
        maxval: Upper bound of the uniform distribution (exclusive).

    Returns:
        A pytree with the structure of ``shapes`` where each shape has been
        replaced by a uniformly sampled array of that shape.
    """

    def _is_shape(node):
        # Without this, a shape tuple is flattened into its individual ints
        # (and ``()`` into nothing), so no array of the intended shape is made.
        return isinstance(node, tuple) and all(isinstance(d, int) for d in node)

    shape_leaves, treedef = jax.tree_util.tree_flatten(shapes, is_leaf=_is_shape)
    keys = jax.random.split(rng_key, len(shape_leaves))
    samples = [
        jax.random.uniform(key, shape, minval=minval, maxval=maxval)
        for key, shape in zip(keys, shape_leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)

