from abc import ABC, abstractmethod
import jax.numpy as jnp
import flax.linen as nn
import orthax
import optax
import jax
from jax import vmap, jit, grad
import jax.random as random
from functools import partial
from pprint import pformat

from .gating import GatingNetwork
from .poly import PNets
from .custom_types import Array, Any

class Level(ABC):
    """Common state and helpers shared by the concrete level classes.

    The constructor stores the configuration and then delegates to
    ``setup_level_specifics``, which each subclass implements. The helper
    methods below read attributes that the subclass is expected to set
    there: ``level``, ``basis_choice``, ``poly_model``, ``gating_model`` and
    (for the ``"classical"`` basis) ``vander_fn``.
    """

    def __init__(self,
                 key: jax.random.PRNGKey,
                 net_setup_params : dict,
                 problem_params : dict,
                 device : Any,
                 logger : Any
                 ):
        """Store the configuration and run the level-specific setup.

        Args:
            key: PRNG key used only when ``net_setup_params`` has no
                ``"key"`` entry.
            net_setup_params: Network configuration; a ``"key"`` entry, when
                present, takes precedence over the ``key`` argument.
            problem_params: Problem configuration; must contain ``"x"``.
            device: Target passed to ``jax.device_put`` by the subclasses.
            logger: Object exposing ``info(msg)``, used by ``print_cfg``.
        """
        self.net_setup_params = net_setup_params
        self.problem_params = problem_params
        self.x = self.problem_params["x"]
        # Size of the trailing axis of the input array.
        self.dim = self.x.shape[-1]
        # The configured key wins (unchanged behaviour); the constructor
        # argument is used instead of failing when the entry is absent.
        self.key = self.net_setup_params.get("key", key)
        self.device = device
        self.logger = logger
        self.setup_level_specifics()

    @abstractmethod
    def setup_level_specifics(self):
        """Build the models, parameters and coefficients of this level."""

    def gate_net(self, params, x):
        """Apply the level's single gating model to ``x``."""
        return self.gating_model.apply(params, x)

    def basis_net_c(self, params, x):
        """Evaluate one basis per entry of ``self.poly_model`` and stack them.

        Args:
            params: For the ``"mlp"`` basis, a sequence indexed like
                ``self.poly_model``; ignored for the ``"classical"`` basis.
            x: Input forwarded to ``vander_fn`` or to each model's ``apply``.

        Returns:
            ``jnp.asarray`` of the per-model basis evaluations.

        Raises:
            ValueError: If ``self.basis_choice`` is not ``"classical"`` or
                ``"mlp"``.
        """
        if self.basis_choice == "classical":
            # The Vandermonde row does not depend on the partition, so it is
            # evaluated once and reused for every entry.
            basis = self.vander_fn(x).ravel()
            basis_out = [basis] * len(self.poly_model)
        elif self.basis_choice == "mlp":
            basis_out = [model.apply(params[i], x)
                         for i, model in enumerate(self.poly_model)]
        else:
            raise ValueError(f"Unsupported basis_choice: {self.basis_choice}")
        return jnp.asarray(basis_out)

    def print_cfg(self, poly_cfg, gating_cfg):
        """Log the level name and both config dicts on process index 0 only."""
        if jax.process_index() == 0:
            self.logger.info("===================================================")
            self.logger.info(f"level={self.level}:")
            self.logger.info("\n" + pformat({"poly_cfg" : poly_cfg, "gating_cfg" : gating_cfg}, indent=1))

class LevelCoarse(Level):
    """Coarse level: one gating network and one basis model per partition."""

    def setup_level_specifics(self):
        """Create the coarse basis models, coefficients and gating network.

        Reads ``net_setup_params["poly"]["coarse"]`` and
        ``net_setup_params["gating"]["coarse"]`` and sets ``level``,
        ``basis_choice``, ``basis_size``, ``poly_model``, ``poly_params``,
        ``coeffs``, ``gating_model`` and ``gating_params`` (plus
        ``poly_type`` and ``vander_fn`` for the ``"classical"`` basis).

        Raises:
            ValueError: If ``basis_choice`` or ``poly_type`` is unsupported,
                or if the ``"classical"`` basis is configured with different
                ``num_partitions`` for the gating and poly networks.
        """
        self.level = "coarse"
        # setup/config dictionaries
        poly_cfg = self.net_setup_params["poly"][self.level]
        self.basis_choice = poly_cfg["basis_choice"]
        gating_cfg = self.net_setup_params["gating"][self.level]
        self.print_cfg(poly_cfg, gating_cfg)

        key_gate, key_poly = random.split(self.key)

        def dput(params):
            """Place ``params`` on this level's device."""
            return jax.device_put(params, self.device)

        if self.basis_choice == "classical":
            # Explicit check instead of ``assert`` so that it still runs
            # when Python is started with -O.
            if gating_cfg["num_partitions"] != poly_cfg["num_partitions"]:
                raise ValueError(
                    "classical basis requires equal num_partitions for the "
                    f"{self.level} gating and poly configs, got "
                    f"{gating_cfg['num_partitions']} and {poly_cfg['num_partitions']}"
                )

            self.basis_size = poly_cfg["basis_size"]
            self.poly_type = poly_cfg["poly_type"]
            if self.poly_type == "monomial":
                vander = orthax.polynomial.polyvander
            elif self.poly_type == "legendre":
                vander = orthax.legendre.legvander
            elif self.poly_type == "chebyshev":
                vander = orthax.chebyshev.chebvander
            else:
                raise ValueError(f"Unsupported poly_type: {self.poly_type}")

            # Degree basis_size - 1 gives basis_size Vandermonde columns.
            self.vander_fn = partial(vander, deg=self.basis_size - 1)

            # Initialize coarse poly network
            self.poly_model = [PNets(**poly_cfg) for _ in range(poly_cfg["num_partitions"])]
            # Scalar placeholder: basis_net_c ignores params for this basis.
            self.poly_params = 0.
        elif self.basis_choice == "mlp":
            self.basis_size = poly_cfg["basis_size"]
            # Initialize coarse poly network, one PRNG key per partition
            poly_keys = random.split(key_poly, poly_cfg["num_partitions"])
            self.poly_model = [PNets(**poly_cfg) for _ in range(poly_cfg["num_partitions"])]
            # Each model is initialised on a dummy input of length self.dim.
            self.poly_params = [dput(model.init(poly_keys[idx], jnp.ones(self.dim)))
                                for idx, model in enumerate(self.poly_model)]
        else:
            raise ValueError(f"Unsupported basis_choice: {self.basis_choice}")

        # One all-ones coefficient vector per partition, stacked row-wise.
        coeff_rows = [jnp.ones(self.basis_size) for _ in range(poly_cfg["num_partitions"])]
        self.coeffs = dput(jnp.asarray(coeff_rows))

        # Initialize coarse gating network
        self.gating_model = GatingNetwork(**gating_cfg)
        self.gating_params = dput(self.gating_model.init(key_gate, self.x))

class LevelFine(Level):
    """Fine level: one gating network and one set of basis models per coarse partition."""

    def setup_level_specifics(self):
        """Create the fine basis models, coefficients and gating networks.

        Reads ``net_setup_params["poly"]["fine"]``,
        ``net_setup_params["gating"]["fine"]`` and the coarse gating
        ``num_partitions``. ``poly_model``, ``poly_params`` (``"mlp"`` only)
        and the rows of ``coeffs`` are nested as
        ``[coarse partition][fine partition]``; ``gating_model`` and
        ``gating_params`` hold one entry per coarse partition.

        For the ``"mlp"`` basis ``basis_size`` becomes the fine size plus the
        coarse size, matching the concatenation done in ``basis_net_f``.

        Raises:
            ValueError: If ``basis_choice`` or ``poly_type`` is unsupported,
                or if the ``"classical"`` basis is configured with different
                ``num_partitions`` for the gating and poly networks.
        """
        self.level = "fine"
        poly_cfg = self.net_setup_params["poly"][self.level]
        self.basis_choice = poly_cfg['basis_choice']
        self.basis_size = poly_cfg["basis_size"]
        gating_cfg = self.net_setup_params["gating"][self.level]
        self.print_cfg(poly_cfg, gating_cfg)
        c_num_partitions = self.net_setup_params["gating"]["coarse"]["num_partitions"]

        key_gate_main, key_poly_main = random.split(self.key)

        def dput(params):
            """Place ``params`` on this level's device."""
            return jax.device_put(params, self.device)

        if self.basis_choice == "classical":
            # Explicit check instead of ``assert`` so that it still runs
            # when Python is started with -O.
            if gating_cfg["num_partitions"] != poly_cfg["num_partitions"]:
                raise ValueError(
                    "classical basis requires equal num_partitions for the "
                    f"{self.level} gating and poly configs, got "
                    f"{gating_cfg['num_partitions']} and {poly_cfg['num_partitions']}"
                )

            self.poly_type = poly_cfg["poly_type"]
            if self.poly_type == "monomial":
                vander = orthax.polynomial.polyvander
            elif self.poly_type == "legendre":
                vander = orthax.legendre.legvander
            elif self.poly_type == "chebyshev":
                vander = orthax.chebyshev.chebvander
            else:
                raise ValueError(f"Unsupported poly_type: {self.poly_type}")

            # Degree basis_size - 1 gives basis_size Vandermonde columns.
            self.vander_fn = partial(vander, deg=self.basis_size - 1)
            # Initialize fine poly network
            self.poly_model = [[PNets(**poly_cfg)
                                for _ in range(poly_cfg["num_partitions"])]
                               for _ in range(c_num_partitions)]
            # Scalar placeholder: basis_net_f ignores params for this basis.
            self.poly_params = 0.
        elif self.basis_choice == "mlp":
            poly_keys_flat = random.split(key_poly_main, c_num_partitions * poly_cfg["num_partitions"])
            # Keep whatever trailing shape a single key has, so the grid works
            # both for legacy uint32 keys (trailing axis) and for new-style
            # typed keys (no trailing axis).
            poly_keys_nested = jnp.reshape(
                poly_keys_flat,
                (c_num_partitions, poly_cfg["num_partitions"]) + tuple(poly_keys_flat.shape[1:]),
            )

            self.poly_model = [[PNets(**poly_cfg)
                                for _ in range(poly_cfg["num_partitions"])]
                               for _ in range(c_num_partitions)]
            # Each model is initialised on a dummy input of length self.dim.
            self.poly_params = [[dput(model.init(poly_keys_nested[ci, fi], jnp.ones(self.dim)))
                                 for fi, model in enumerate(row)]
                                for ci, row in enumerate(self.poly_model)]
            # The fine basis is concatenated with the coarse one in basis_net_f.
            self.basis_size = self.basis_size + self.net_setup_params["poly"]["coarse"]["basis_size"]
        else:
            raise ValueError(f"Unsupported basis_choice: {self.basis_choice}")

        # All-ones coefficients nested as [coarse partition][fine partition].
        coeff_rows = [[jnp.ones(self.basis_size)
                       for _ in range(poly_cfg["num_partitions"])]
                      for _ in range(c_num_partitions)]
        self.coeffs = dput(jnp.asarray(coeff_rows))

        # Initialize fine gating network: one model per coarse partition.
        self.gating_model = [GatingNetwork(**gating_cfg) for _ in range(c_num_partitions)]
        # Every gate is initialised with the same key, i.e. identically.
        # To initialise them differently, split key_gate_main into
        # c_num_partitions keys and pass one key to each model's init.
        self.gating_params = [dput(model.init(key_gate_main, self.x))
                              for model in self.gating_model]

    def gate_net(self, params, x):
        """Apply every fine gating model to ``x`` and stack the outputs.

        ``params`` is indexed like ``self.gating_model``.
        """
        gate_f = [model.apply(params[i], x)
                  for i, model in enumerate(self.gating_model)]
        return jnp.asarray(gate_f)

    def basis_net_f(self, params_c, params_f, basis_net_c, x):
        """Evaluate the fine basis for every (coarse, fine) partition pair.

        Args:
            params_c: Parameters forwarded to ``basis_net_c`` (``"mlp"`` only).
            params_f: Nested parameters indexed like ``self.poly_model``
                (``"mlp"`` only).
            basis_net_c: Callable ``(params_c, x)`` returning the coarse
                basis (``"mlp"`` only).
            x: Input forwarded to ``vander_fn`` or to each model's ``apply``.

        Returns:
            For ``"classical"``, the stacked Vandermonde evaluations. For
            ``"mlp"``, the fine model outputs concatenated with the coarse
            basis along the last axis.

        Raises:
            ValueError: If ``self.basis_choice`` is not ``"classical"`` or
                ``"mlp"``.
        """
        if self.basis_choice == "classical":
            # The Vandermonde row is the same for every partition pair, so it
            # is evaluated once and reused.
            basis = self.vander_fn(x).ravel()
            basis_out = [[basis] * len(row) for row in self.poly_model]
            return jnp.asarray(basis_out)
        elif self.basis_choice == "mlp":
            basis_f = [[model.apply(params_f[i][j], x)
                        for j, model in enumerate(row)]
                       for i, row in enumerate(self.poly_model)]
            basis_f = jnp.asarray(basis_f)
            # Insert a fine-partition axis into the coarse basis and repeat
            # it along that axis.
            # NOTE(review): the repeat count is the fine *gating*
            # num_partitions, while basis_f's matching axis comes from the
            # fine *poly* num_partitions; the concatenation presumably
            # requires the two to be equal - confirm against the config.
            basis_c = basis_net_c(params_c, x)[:, jnp.newaxis, :]
            basis_c = jnp.repeat(basis_c, repeats=self.net_setup_params["gating"]["fine"]["num_partitions"],
                                 axis=1)
            return jnp.concatenate([basis_f, basis_c], axis=-1)
        else:
            # Previously fell through and returned None.
            raise ValueError(f"Unsupported basis_choice: {self.basis_choice}")
