from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
from datetime import datetime
from itertools import product
from functools import partial
import tqdm

import jax
import jax.numpy as jnp
import optax
import jaxopt

from torch.utils import data

from .model_loader import construct_net
from .. import deepxde as dde
from ..icbc_patch import generate_residue
from .deeponet import DeepONet


class DeepONetModified:
    """DeepONet whose branch input is the inverse and experimental parameters side by side.

    The branch network of the wrapped ``DeepONet`` receives
    ``concatenate([inv, exp], axis=-1)``; the trunk network receives the query
    coordinates ``xs``.
    """

    def __init__(self, inv_input_dim, exp_input_dim, x_input_dim, hidden_layers=2, hidden_dim=128, arch=None):
        # Branch width is the combined width of the two parameter groups.
        branch_dim = inv_input_dim + exp_input_dim
        self.deeponet = DeepONet(
            branch_input_dim=branch_dim,
            trunk_input_dim=x_input_dim,
            hidden_layers=hidden_layers,
            hidden_dim=hidden_dim,
            arch=arch,
        )

    def init(self, rng_key=jax.random.PRNGKey(42)):
        """Initialise and return the parameters of the wrapped DeepONet."""
        net = self.deeponet
        return net.init(rng_key=rng_key)

    def apply(self, params, inv, exp, xs):
        """Evaluate the operator for parameters ``(inv, exp)`` at the points ``xs``."""
        branch_in = jnp.concatenate([inv, exp], axis=-1)
        return self.deeponet.apply(params=params, branch_in=branch_in, trunk_in=xs)

    def apply_single_branch(self, params, inv, exp, xs):
        """Evaluate one ``(inv, exp)`` branch input at every row of the 2-D array ``xs``.

        Each row of ``xs`` is passed separately (via ``jax.vmap``) as the trunk input.
        """
        assert len(xs.shape) == 2

        def _eval_row(x_row):
            return self.apply(params, inv, exp, x_row)

        return jax.vmap(_eval_row)(xs)
    
    
# Data generator
class DataGenerator(data.Dataset):
    """Endless sampler of supervised mini-batches ``((inv, exp, x), s)``.

    Every ``__getitem__`` call ignores its index, advances the internal PRNG key
    and draws ``batch_size`` distinct row indices (sampling without replacement)
    shared by all four arrays, so rows stay aligned.
    """

    def __init__(self, inv, exp, x, s, batch_size=64, rng_key=jax.random.PRNGKey(1234)):
        """Store the row-aligned arrays and the sampling configuration."""
        self.inv, self.exp, self.x, self.s = inv, exp, x, s

        self.N = x.shape[0]  # number of rows available for sampling
        self.batch_size = batch_size
        self.key = rng_key

    def __getitem__(self, index):
        """Return a freshly sampled batch; ``index`` is not used."""
        self.key, batch_key = jax.random.split(self.key)
        return self.__data_generation(batch_key)

    # NOTE(review): ``self`` is a static jit argument, so the stored arrays are baked
    # into the compiled function as constants (one compilation per instance).
    @partial(jax.jit, static_argnums=(0,))
    def __data_generation(self, key):
        """Pick ``batch_size`` distinct rows and gather every array at those rows."""
        rows = jax.random.choice(key, self.N, (self.batch_size,), replace=False)

        def take(arr):
            return arr[rows, :]

        return (take(self.inv), take(self.exp), take(self.x)), take(self.s)
    
    
class CollocationPointGenerator(data.Dataset):
    """Endless sampler of collocation inputs drawn independently from three pools.

    Each batch samples ``batch_size`` rows, with replacement and independently,
    from the ``x``, ``inv`` and ``exp`` candidate arrays. It returns
    ``((inv, exp, x), (idx_inv, idx_exp, idx_x))``.
    """

    def __init__(self, inv, exp, x, batch_size=64, rng_key=jax.random.PRNGKey(1234)):
        """Store the candidate pools and the sampling configuration."""
        self.inv_candidate, self.exp_candidate, self.x_candidate = inv, exp, x

        self.batch_size = batch_size
        self.key = rng_key

    def __getitem__(self, index):
        """Return a freshly sampled batch; ``index`` is not used."""
        new_key, batch_key = jax.random.split(self.key)
        self.key = new_key
        return self.__data_generation(batch_key)

    @partial(jax.jit, static_argnums=(0,))
    def __data_generation(self, key):
        """Draw row indices for each pool and gather the corresponding rows."""
        gathered, drawn = [], []
        # The pool order (x, inv, exp) fixes which sub-key each pool receives.
        for pool in (self.x_candidate, self.inv_candidate, self.exp_candidate):
            key, sub = jax.random.split(key)
            rows = jax.random.choice(sub, pool.shape[0], (self.batch_size,), replace=True)
            gathered.append(pool[rows, :])
            drawn.append(rows)

        (x, inv, exp), (idx_x, idx_inv, idx_exp) = gathered, drawn
        return (inv, exp, x), (idx_inv, idx_exp, idx_x)
    
    
def generate_train_set_for_deeponet(
    noisy_simulator_xs, exp_setup_list, inv_param_sampler, pde_domain, 
    n_inv=100, n_xs=10000, batch_size=64, rng=jax.random.PRNGKey(0)):
    """Simulate noisy observations and wrap them in a ``DataGenerator``.

    For every experimental setup, ``n_inv`` inverse parameters are sampled and the
    simulator is evaluated at points of ``pde_domain`` drawn with
    ``random='Hammersley'``. Every (setup, inverse parameter, point) triple becomes
    one row of the dataset.

    Args:
        noisy_simulator_xs: callable ``(exp_params, inv_param, rng) -> f`` where
            ``f(xs)`` returns the observed outputs at the points ``xs``.
        exp_setup_list: sized iterable of experimental-parameter vectors; each must
            support ``exp_params[None, :]`` (presumably 1-D arrays — confirm).
        inv_param_sampler: callable ``(n, rng)`` returning ``n`` inverse-parameter
            samples; each sample must support ``beta[None, :]``.
        pde_domain: object providing ``random_points(n, random='Hammersley')``.
        n_inv: number of inverse-parameter samples per experimental setup.
        n_xs: number of query points requested per experimental setup.
        batch_size: batch size of the returned generator.
        rng: JAX PRNG key; split for every sampler/simulator call.

    Returns:
        ``DataGenerator`` over the stacked ``(inv, exp, x) -> s`` rows.
    """
    betas_list, exp_list, xs_list, ys_list = [], [], [], []

    # The context manager closes the progress bar even if a simulator call raises.
    with tqdm.tqdm(total=len(exp_setup_list) * n_inv) as pbar:
        for exp_params in exp_setup_list:
            rng, r_ = jax.random.split(rng)
            betas_samples = inv_param_sampler(n_inv, r_)

            xs = pde_domain.random_points(n_xs, random='Hammersley')
            # Repeat the per-curve parameters once per point actually returned, so that
            # inv/exp/x/s stay row-aligned even if the domain returns != n_xs points.
            n_pts = xs.shape[0]

            for beta in betas_samples:
                rng, r_ = jax.random.split(rng)
                ys = noisy_simulator_xs(exp_params, beta, r_)(xs)
                exp_list.append(jnp.repeat(exp_params[None, :], repeats=n_pts, axis=0))
                betas_list.append(jnp.repeat(beta[None, :], repeats=n_pts, axis=0))
                xs_list.append(xs)
                ys_list.append(ys)
                pbar.update()

    return DataGenerator(
        inv=jnp.vstack(betas_list), 
        exp=jnp.vstack(exp_list), 
        x=jnp.vstack(xs_list), 
        s=jnp.vstack(ys_list), 
        batch_size=batch_size, 
        rng_key=rng,
    )
    
    

def generate_collocation_train_set_for_deeponet(
    exp_setup_list, inv_param_sampler, pde_domain, 
    n_inv=1000, n_xs=10000, batch_size=256, rng=jax.random.PRNGKey(0)):
    """Build a sampler of PDE collocation inputs.

    Draws ``n_inv`` inverse parameters and ``n_xs`` points of ``pde_domain``
    (``random='Hammersley'``), and returns a ``CollocationPointGenerator``. The
    generator samples ``(inv, exp, x)`` batches independently from these pools and
    from ``exp_setup_list``.

    Args:
        exp_setup_list: experimental-parameter vectors (array or list of equal-length
            vectors); converted to a JAX array.
        inv_param_sampler: callable ``(n, rng)`` returning ``n`` inverse-parameter samples.
        pde_domain: object providing ``random_points(n, random='Hammersley')``.
        n_inv: number of inverse-parameter candidates.
        n_xs: number of spatial/temporal candidates.
        batch_size: batch size of the returned generator.
        rng: JAX PRNG key.

    Returns:
        ``CollocationPointGenerator`` over the three candidate pools.
    """
    rng, r_ = jax.random.split(rng)
    betas = jnp.array(inv_param_sampler(n_inv, r_))

    xs = jnp.array(pde_domain.random_points(n_xs, random='Hammersley'))

    return CollocationPointGenerator(
        inv=betas,
        # Convert like betas/xs: the generator reads ``.shape`` and indexes this pool
        # with traced indices inside jax.jit, which lists/NumPy arrays do not support.
        exp=jnp.asarray(exp_setup_list),
        x=xs,
        batch_size=batch_size,
        rng_key=rng,
    )
    
    
def generate_train_set_for_deeponet_from_known_xs(
    exp_setup_list, inv_param_sampler, xs,
    n_inv=1000, n_xs=10000, batch_size=256, rng=jax.random.PRNGKey(0)):
    """Build a collocation sampler from a caller-supplied set of points.

    Same as ``generate_collocation_train_set_for_deeponet``, except that the
    candidate points are the first ``n_xs`` rows of ``xs`` instead of being drawn
    from a domain.

    Args:
        exp_setup_list: experimental-parameter vectors; converted to a JAX array.
        inv_param_sampler: callable ``(n, rng)`` returning ``n`` inverse-parameter samples.
        xs: candidate points (array-like, one point per row).
        n_inv: number of inverse-parameter candidates.
        n_xs: maximum number of rows of ``xs`` to use.
        batch_size: batch size of the returned generator.
        rng: JAX PRNG key.

    Returns:
        ``CollocationPointGenerator`` over the three candidate pools.
    """
    rng, r_ = jax.random.split(rng)
    betas = jnp.array(inv_param_sampler(n_inv, r_))
    return CollocationPointGenerator(
        inv=betas,
        # Both pools are indexed with traced indices inside jax.jit, so they must be
        # JAX arrays (lists/NumPy arrays fail there).
        exp=jnp.asarray(exp_setup_list),
        x=jnp.asarray(xs[:n_xs]),
        batch_size=batch_size,
        rng_key=rng,
    )
    
    
def generate_bc_train_set_for_deeponet(
    exp_setup_list, inv_param_sampler, pde_domain, bc,
    n_inv=1000, n_xs=10000, batch_size=256, rng=jax.random.PRNGKey(0)):
    """Build a sampler of boundary/initial-condition inputs.

    Candidate points are ``pde_domain.random_initial_points(n_xs)`` for a
    ``dde.icbc.IC`` and ``pde_domain.random_boundary_points(n_xs)`` for other
    conditions.

    Args:
        exp_setup_list: experimental-parameter vectors; converted to a JAX array.
        inv_param_sampler: callable ``(n, rng)`` returning ``n`` inverse-parameter samples.
        pde_domain: domain providing ``random_initial_points``/``random_boundary_points``.
        bc: the deepxde condition the points are generated for.
        n_inv: number of inverse-parameter candidates.
        n_xs: number of candidate points.
        batch_size: batch size of the returned generator.
        rng: JAX PRNG key.

    Returns:
        ``CollocationPointGenerator`` over the three candidate pools.

    Raises:
        NotImplementedError: for ``dde.icbc.PointSetBC``.
    """
    rng, r_ = jax.random.split(rng)
    betas = jnp.array(inv_param_sampler(n_inv, r_))

    if isinstance(bc, dde.icbc.IC):
        xs = jnp.array(pde_domain.random_initial_points(n_xs))
    elif isinstance(bc, dde.icbc.PointSetBC):
        raise NotImplementedError('PointSetBC is not supported by generate_bc_train_set_for_deeponet')
    else:
        xs = jnp.array(pde_domain.random_boundary_points(n_xs))

    return CollocationPointGenerator(
        inv=betas,
        # Convert like betas/xs: this pool is indexed with traced indices inside jax.jit.
        exp=jnp.asarray(exp_setup_list),
        x=xs,
        batch_size=batch_size,
        rng_key=rng,
    )
    
    
def train_deeponet(
    deeponet, dset, batch_size=256, steps=200000, optim_type='adam', optim_args=None, rng=jax.random.PRNGKey(0)):
    """Train a DeepONet on observation data only (mean squared error).

    Args:
        deeponet: model exposing ``init(rng_key=...)`` and ``apply(p, inv, exp, x)``.
        dset: iterable yielding ``((inv, exp, x), s)`` batches.
        batch_size: unused; the batch size is set by ``dset``. Kept for API compatibility.
        steps: number of optimisation steps.
        optim_type: only ``'adam'`` is supported.
        optim_args: keyword arguments for ``optax.adam``; defaults to an
            exponentially decaying learning rate starting at 1e-3.
        rng: PRNG key used to initialise the parameters.

    Returns:
        ``(params, aux)`` where ``aux`` holds the training loss of the current batch
        every 1000 steps, the matching step numbers, the final optimiser state and
        ``dset``.

    Raises:
        ValueError: if ``optim_type`` is not ``'adam'``.
    """

    @jax.jit
    def new_loss(p, in_obs, out_obs):
        inv, exp, x = in_obs
        y = deeponet.apply(p, inv, exp, x)
        return jnp.mean((y - out_obs)**2)

    if optim_args is None:
        optim_args = {'learning_rate': optax.exponential_decay(1e-3, transition_steps=2000, decay_rate=0.9)}
    if optim_type == 'adam':
        # optax.adam requires ``learning_rate``; forward the configured arguments.
        opt = optax.adam(**optim_args)
    else:
        raise ValueError(f'Invalid optim_type "{optim_type}"')

    params = deeponet.init(rng_key=rng)
    solver = jaxopt.OptaxSolver(opt=opt, fun=jax.jit(jax.value_and_grad(new_loss)), value_and_grad=True)
    opt_state = solver.init_state(params)
    data_iterator = iter(dset)
    update_fn = jax.jit(solver.update)

    losses_steps = []
    losses_train = []

    for s in tqdm.trange(steps):
        xs, ys = next(data_iterator)
        params, opt_state = update_fn(params, opt_state, xs, ys)
        if (s+1) % 1000 == 0:
            losses_train.append(new_loss(params, xs, ys))
            losses_steps.append(s+1)

    aux = {
        'losses_train': losses_train,
        'losses_steps': losses_steps,
        'opt_state': opt_state,
        'dset': dset,
    }
    return params, aux


# def generate_residue_deeponet(bc, net_apply, exp_fixed):
    
#     if isinstance(bc, dde.icbc.boundary_conditions.PeriodicBC):
        
#         def residue(params, inv, xs):
#             Xl = xs.at[:,0].set(bc.geom.geometry.l)
#             Xr = xs.at[:,0].set(bc.geom.geometry.r)
#             f = lambda x_: net_apply(params, inv, exp_fixed, x_)
#             if bc.derivative_order == 0:
#                 yleft = f(Xl)[:, bc.component : bc.component + 1]
#                 yright = f(Xr)[:, bc.component : bc.component + 1]
#             else:
#                 yleft = dde.grad.jacobian((f(Xl), f), Xl, i=bc.component, j=bc.component_x)
#                 yright = dde.grad.jacobian((f(Xr), f), Xr, i=bc.component, j=bc.component_x)
#             return (yleft - yright).reshape(-1)
        
#     elif isinstance(bc, dde.icbc.boundary_conditions.NeumannBC):
        
#         def b_fn(xs):
#             return bc.boundary_normal(xs, 0, xs.shape[0], None)
        
#         def residue(params, inv, xs):
#             values = bc.func(xs, 0, xs.shape[0], None)
#             f_ = lambda x_: net_apply(params, inv, exp_fixed, x_)
#             ys = f_(xs)
#             dydx = dde.grad.jacobian((ys, f_), xs, i=bc.component, j=None)[0]
#             # n = jax.lax.stop_gradient(boundary_normal_jax(xs))
#             # n = jax.lax.stop_gradient(hcb.call(b_fn, xs, result_shape=xs))
#             n = jax.lax.stop_gradient(jax.pure_callback(b_fn, jax.ShapeDtypeStruct(xs.shape, xs.dtype), xs))
#             y = dde.backend.sum(dydx * n, 1, keepdims=True)
#             return (y - values).reshape(-1)
        
#     elif isinstance(bc, dde.icbc.boundary_conditions.PointSetBC):
#         raise NotImplementedError
        
#     else:
    
#         def residue(params, inv, xs):
#             return bc.error(
#                 X=xs,
#                 inputs=xs, 
#                 outputs=net_apply(params, inv, exp_fixed, xs), 
#                 beg=0, 
#                 end=xs.shape[0], 
#                 aux_var=None
#             ).reshape(-1)
            
#     return residue


def train_deeponet_with_pde(
    deeponet, dset, colloc_dset, pde, bcs_dset=None, bc_exp_fns=None,
    steps=200000, optim_type='adam', optim_args=None, rng=jax.random.PRNGKey(0)):
    """Train a DeepONet on observations plus PDE-residual and boundary-condition losses.

    The objective is the sum of:
      * the mean squared data misfit on a batch from ``dset``;
      * the mean squared PDE residual on a batch from ``colloc_dset``;
      * one mean squared BC residual per (``bc_exp_fns``, ``bcs_dset``) pair.

    Args:
        deeponet: model exposing ``init(rng_key=...)`` and ``apply(p, inv, exp, x)``.
        dset: iterable yielding ``((inv, exp, x), s)`` observation batches.
        colloc_dset: iterable yielding ``((inv, exp, x), _)`` collocation batches.
        pde: callable ``pde(x, (y, apply_fn), inv)``. Element ``[0][0, 0]`` of its
            result is used as the residual at a single point.
        bcs_dset: optional list of iterables yielding ``((inv, exp, x), _)`` batches,
            one per boundary condition.
        bc_exp_fns: optional list of pairs whose first element is
            ``bc_fn(p, apply_fn, exp, x)`` (with ``apply_fn(p_, x_)``), giving the BC
            residual at one point. The pairs are matched positionally with
            ``bcs_dset`` (``zip`` semantics). BC losses are used only when both
            ``bcs_dset`` and ``bc_exp_fns`` are given.
        steps: number of optimisation steps.
        optim_type: only ``'adam'`` is supported.
        optim_args: keyword arguments for ``optax.adam``; defaults to an
            exponentially decaying learning rate starting at 1e-3.
        rng: PRNG key used to initialise the parameters.

    Returns:
        ``(params, aux)`` where ``aux`` holds the total loss of the current batches
        every 1000 steps, the matching step numbers, the final optimiser state, the
        datasets and the loss functions.

    Raises:
        ValueError: if ``optim_type`` is not ``'adam'``.
    """

    def _make_bc_loss(bc_fn):
        # Factory so that each BC loss captures its own ``bc_fn``. Closures defined
        # directly in a loop late-bind the loop variable, which made every BC loss
        # evaluate the last function in ``bc_exp_fns``.
        def loss_fn_single(p, inv, exp, x):
            apply_fn = lambda p_, x_: deeponet.apply(p_, inv, exp, x_)
            return bc_fn(p, apply_fn, exp, x)

        def loss_fn_batch(p, invs, exps, xs):
            return jax.vmap(loss_fn_single, in_axes=(None, 0, 0, 0))(p, invs, exps, xs)

        return loss_fn_batch

    if (bcs_dset is None) or (bc_exp_fns is None):
        bcs_training = []
        bcs_dset = []
    else:
        bcs_training = [_make_bc_loss(bc_fn) for (bc_fn, _) in bc_exp_fns]

    @jax.jit
    def compute_residue_single(p, inv, exp, x):
        # Evaluate the PDE at one point, given as a (1, d) batch.
        x = x.reshape(1, -1)
        apply_fn = lambda x_: deeponet.apply(p, inv, exp, x_)
        y = apply_fn(x)
        r_ = pde(x, (y, apply_fn), inv)
        return r_[0][0, 0]

    @jax.jit
    def compute_residue(p, inv, exp, x):
        return jax.vmap(compute_residue_single, in_axes=(None, 0, 0, 0))(p, inv, exp, x)

    @jax.jit
    def data_loss(p, inv, exp, x, s):
        y = deeponet.apply(p, inv, exp, x)
        return y - s

    @jax.jit
    def overall_loss(p, obs, colloc, bcs):
        l = 0.

        (inv_obs, exp_obs, x_obs), out_obs = obs
        l += jnp.mean(data_loss(p, inv_obs, exp_obs, x_obs, out_obs) ** 2)

        (inv_colloc, exp_colloc, x_colloc), _ = colloc
        l += jnp.mean(compute_residue(p, inv_colloc, exp_colloc, x_colloc) ** 2)

        for (fn, in_bc) in zip(bcs_training, bcs):
            (inv_bc, exp_bc, x_bc), _ = in_bc
            l += jnp.mean(fn(p, inv_bc, exp_bc, x_bc) ** 2)

        return l

    if optim_args is None:
        optim_args = {'learning_rate': optax.exponential_decay(1e-3, transition_steps=2000, decay_rate=0.9)}
    if optim_type == 'adam':
        opt = optax.adam(**optim_args)
    else:
        raise ValueError(f'Invalid optim_type "{optim_type}"')

    params = deeponet.init(rng_key=rng)
    solver = jaxopt.OptaxSolver(opt=opt, fun=jax.jit(jax.value_and_grad(overall_loss)), value_and_grad=True)
    opt_state = solver.init_state(params)
    data_iterator = iter(dset)
    colloc_iterator = iter(colloc_dset)
    bcs_iterators = [iter(bc_dset) for bc_dset in bcs_dset]
    update_fn = jax.jit(solver.update)

    losses_steps = []
    losses_train = []

    for s in tqdm.trange(steps):
        anc = next(data_iterator)
        colloc_xs = next(colloc_iterator)
        bc_xs = [next(bc_iter) for bc_iter in bcs_iterators]
        params, opt_state = update_fn(params, opt_state, anc, colloc_xs, bc_xs)
        if (s+1) % 1000 == 0:
            losses_train.append(overall_loss(params, anc, colloc_xs, bc_xs))
            losses_steps.append(s+1)

    aux = {
        'losses_train': losses_train,
        'losses_steps': losses_steps,
        'opt_state': opt_state,
        'dset': dset,
        'colloc_dset': colloc_dset,
        'compute_residue': compute_residue,
        'data_loss': data_loss,
        'bcs_training': bcs_training,
        'overall_loss': overall_loss,
    }
    return params, aux


# def train_deeponet_with_pde(
#     deeponet, dset, colloc_dset, pde, bcs_dset=None, bcs_gen=None, 
#     steps=200000, optim_type='adam', optim_args=None, rng=jax.random.PRNGKey(0)):
        

#     if bcs_gen is None or bcs_dset is None:
#         bcs_training = []
#         bcs_dset = []
        
#     else:
        
#         bcs_training = []
        
#         def _gen_loss(bc_fn, bc_dset):
                    
#             def _gen_loss_single(bc_fn, exp_):
                
#                 bc = bc_fn(exp_)
#                 net_apply = lambda p, inv, exp, x_: deeponet.apply(p, inv.reshape(1, -1), exp.reshape(1, -1), x_.reshape(1, -1))
#                 bc_loss = generate_residue_deeponet(bc=bc, net_apply=net_apply, exp_fixed=exp_)
                  
#                 @jax.jit  
#                 def _loss_single(p, inv, x):
#                     return bc_loss(p, inv, x.reshape(1, -1))[0]
                
#                 return _loss_single
            
#             branches = []
#             for exp in bc_dset.exp_candidate:
#                 branches.append(_gen_loss_single(bc_fn, exp))
            
#             # exp_idxs = jnp.linspace(0, bc_dset.exp_candidate.shape[0], bc_dset.batch_size, endpoint=False, dtype=int)
#             # exps_fixed = bc_dset.exp_candidate[exp_idxs]
            
#             @jax.jit
#             def _loss(p, inv, exp, x, inv_idx, exp_idx, x_idx):
#                 _fn = lambda p, inv, x: jax.lax.switch(exp_idx[0], branches, p, inv, x)
#                 return jax.vmap(jax.jit(_fn), in_axes=(None, 0, 0))(p, inv, x)
            
#             return _loss
        
#         for i, bc_dset in enumerate(bcs_dset):
#             bc_fn = lambda exp_: bcs_gen(exp_)[i]
#             bcs_training.append(_gen_loss(bc_fn, bc_dset))
        
#     @jax.jit
#     def compute_residue_single(p, inv, exp, x):
#         x = x.reshape(1, -1)
#         # exp = exp.reshape(1, -1)
#         # inv = inv.reshape(1, -1)
#         apply_fn = lambda x_: deeponet.apply(p, inv, exp, x_)
#         y = apply_fn(x)
#         # print(y.shape, x.shape, exp.shape)
#         r_ = pde(x, (y, apply_fn), inv)
#         return r_[0][0,0]
    
#     @jax.jit
#     def compute_residue(p, inv, exp, x):
#         return jax.vmap(compute_residue_single, in_axes=(None, 0, 0, 0))(p, inv, exp, x)
    
#     @jax.jit
#     def data_loss(p, inv, exp, x, s):
#         y = deeponet.apply(p, inv, exp, x)
#         return y - s
    
#     @jax.jit
#     def overall_loss(p, obs, colloc, bcs):
                
#         l = 0.
        
#         (inv_obs, exp_obs, x_obs), out_obs = obs
#         l += jnp.mean(data_loss(p, inv_obs, exp_obs, x_obs, out_obs) ** 2)
        
#         (inv_colloc, exp_colloc, x_colloc), _ = colloc
#         l += jnp.mean(compute_residue(p, inv_colloc, exp_colloc, x_colloc) ** 2)
        
#         for (fn, in_bc) in zip(bcs_training, bcs):
#             (inv_bc, exp_bc, x_bc), (inv_idx, exp_idx, x_idx) = in_bc
#             l += jnp.mean(fn(p, inv_bc, exp_bc, x_bc, inv_idx, exp_idx, x_idx) ** 2)
        
#         return l
    
    
#     if optim_args is None:
#         optim_args = {'learning_rate': optax.exponential_decay(1e-3, transition_steps=2000, decay_rate=0.9)}
#     if optim_type == 'adam':
#         opt = optax.adam(**optim_args)
#     else:
#         raise ValueError(f'Invalid optim_type "{optim_type}"')

#     params = deeponet.init(rng_key=rng)
#     solver = jaxopt.OptaxSolver(opt=opt, fun=jax.jit(jax.value_and_grad(overall_loss)), value_and_grad=True)
#     opt_state = solver.init_state(params)
#     data_iterator = iter(dset)
#     colloc_iterator = iter(colloc_dset)
#     bcs_iterators = [iter(bc_dset) for bc_dset in bcs_dset]
#     update_fn = jax.jit(solver.update)

#     losses_steps = []
#     losses_train = []

#     for s in tqdm.trange(steps):
#         anc = next(data_iterator)
#         colloc_xs = next(colloc_iterator)
#         bc_xs = [next(bc_iter) for bc_iter in bcs_iterators]
#         params, opt_state = update_fn(params, opt_state, anc, colloc_xs, bc_xs)
#         if (s+1) % 1000 == 0:
#             losses_train.append(overall_loss(params, anc, colloc_xs, bc_xs))
#             losses_steps.append(s+1)
            
#     aux = {
#         'losses_train': losses_train,
#         'losses_steps': losses_steps,
#         'opt_state': opt_state,
#         'dset': dset,
#         'colloc_dset': colloc_dset,
#         'compute_residue': compute_residue,
#         'data_loss': data_loss,
#         'bcs_training': bcs_training,
#         'overall_loss': overall_loss,
#     }
#     return params, aux
