from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
from datetime import datetime
from itertools import product
from functools import partial
import tqdm

import jax
import jax.numpy as jnp
import optax
import jaxopt

from torch.utils import data

from .model_loader import construct_net


class DeepONet:
    """DeepONet built from a branch network and a trunk network.

    Both networks are created with ``construct_net`` and map their inputs to
    ``hidden_dim`` features.  The model output is the inner product of the
    branch features and the trunk features along the last axis.
    """

    def __init__(self, branch_input_dim, trunk_input_dim, hidden_layers=2, hidden_dim=128, arch=None):
        """Build the branch and trunk networks.

        Args:
            branch_input_dim: Size of the last axis of inputs fed to the branch net.
            trunk_input_dim: Size of the last axis of inputs fed to the trunk net.
            hidden_layers: Number of hidden layers, forwarded to ``construct_net``.
            hidden_dim: Hidden width. It is also used as the output width of
                both nets, so their outputs can be multiplied elementwise in
                :meth:`apply`.
            arch: Architecture selector, forwarded unchanged to ``construct_net``.
        """
        self.branch_input_dim = branch_input_dim
        self.trunk_input_dim = trunk_input_dim
        # NOTE(review): construct_net appears to return a tuple whose first
        # element is a network object exposing .init/.apply; confirm in model_loader.
        self.branch_net = construct_net(input_dim=self.branch_input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim,
                                        output_dim=hidden_dim, arch=arch)[0]
        self.trunk_net = construct_net(input_dim=self.trunk_input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim,
                                       output_dim=hidden_dim, arch=arch)[0]

    def init(self, rng_key=None):
        """Initialise parameters for both sub-networks.

        Args:
            rng_key: JAX PRNG key. Defaults to ``jax.random.PRNGKey(42)``. The
                default key is created on each call, not once at class-definition
                time, so importing this module does not touch the JAX backend.

        Returns:
            ``dict(branch=..., trunk=...)`` holding the parameters returned by
            each network's ``init``.
        """
        if rng_key is None:
            rng_key = jax.random.PRNGKey(42)
        keyb, keyt = jax.random.split(rng_key)
        # Placeholder batch of one sample: only its (1, dim) shape is meaningful.
        branch_params = self.branch_net.init(keyb, jnp.empty((1, self.branch_input_dim)))
        trunk_params = self.trunk_net.init(keyt, jnp.empty((1, self.trunk_input_dim)))
        return dict(branch=branch_params, trunk=trunk_params)

    def apply(self, params, branch_in, trunk_in):
        """Evaluate the DeepONet.

        Args:
            params: Parameters as returned by :meth:`init`.
            branch_in: Input to the branch network.
            trunk_in: Input to the trunk network. The two network outputs must
                broadcast against each other for the elementwise product.

        Returns:
            ``sum(B * T, axis=-1)`` with the reduced axis kept as size 1, where
            ``B``/``T`` are the branch/trunk network outputs.
        """
        B = self.branch_net.apply(params['branch'], branch_in)
        T = self.trunk_net.apply(params['trunk'], trunk_in)
        return jnp.sum(B * T, axis=-1, keepdims=True)

    def apply_single_branch(self, params, branch_in, trunk_in):
        """Evaluate one branch input against every row of ``trunk_in``.

        ``jax.vmap`` maps over the leading axis of ``trunk_in`` while
        ``branch_in`` is held fixed. The result stacks one :meth:`apply` output
        per row of ``trunk_in``.
        """
        return jax.vmap(lambda x_: self.apply(params, branch_in, x_))(trunk_in)