from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, jit, grad, vmap, jacrev, hessian
from jax.flatten_util import ravel_pytree

from jaxpi.models import ForwardIVP
from jaxpi.evaluator import BaseEvaluator
from jaxpi.utils import ntk_fn, flatten_pytree, compute_cosine_similarity, compute_intra_alignment_score, compute_inter_alignment_score

from matplotlib import pyplot as plt


class Wave(ForwardIVP):
    """Physics-informed model for the 1D wave equation ``u_tt - c**2 * u_xx = 0``.

    The loss terms built here penalise: the mismatch with ``u0`` at the first
    time point, a non-zero time derivative at the first time point, non-zero
    values at the first and last spatial grid points, and the PDE residual at
    the collocation points of each batch.

    Attributes such as ``state``, ``config``, ``tol``, ``num_chunks`` and ``M``
    are read but never assigned in this class; they are expected to be
    provided by ``ForwardIVP``.
    """

    def __init__(self, config, u0, t_star, x_star, c):
        """Store the problem data and build the grid prediction functions.

        Args:
            config: Experiment configuration, forwarded to ``ForwardIVP``.
            u0: Reference values compared against the prediction at
                ``(t_star[0], x_star)``.
            t_star: Time grid; its first and last entries become ``t0``/``t1``.
            x_star: Spatial grid; its first and last entries are used as the
                boundary locations.
            c: Coefficient of the wave equation (enters the residual as
                ``c**2``).
        """
        super().__init__(config)

        # Snapshot of the initial state with the leading axis of every leaf
        # dropped (presumably the device-replication axis -- TODO confirm).
        state = jax.tree.map(lambda x: x[0], self.state)
        self.old_state = state

        self.u0 = u0
        self.t_star = t_star
        self.x_star = x_star
        self.c = c

        self.t0 = t_star[0]
        self.t1 = t_star[-1]

        # Predictions over a grid: the outer vmap runs over time, the inner
        # one over space, so outputs are indexed as [time, space].
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0)), (None, 0, None))
        self.r_pred_fn = vmap(vmap(self.r_net, (None, None, 0)), (None, 0, None))

    def u_net(self, params, t, x):
        """Evaluate the network at a single point ``(t, x)``.

        Returns the first entry of the second output of ``state.apply_fn``.
        """
        z = jnp.stack([t, x])
        _, u = self.state.apply_fn(params, z)
        return u[0]

    def u_t_net(self, params, t, x):
        """Return the derivative of ``u_net`` with respect to ``t``."""
        u_t = grad(self.u_net, argnums=1)(params, t, x)
        return u_t

    def r_net(self, params, t, x):
        """Return the PDE residual ``u_tt - c**2 * u_xx`` at ``(t, x)``."""
        u_tt = grad(grad(self.u_net, argnums=1), argnums=1)(params, t, x)
        u_xx = grad(grad(self.u_net, argnums=2), argnums=2)(params, t, x)

        return u_tt - self.c**2 * u_xx

    @partial(jit, static_argnums=(0,))
    def res_and_w(self, params, batch):
        """Compute per-chunk residual losses and their causal weights.

        Args:
            params: Network parameters.
            batch: Array whose column 0 holds times and column 1 holds
                spatial coordinates.

        Returns:
            Tuple ``(l, w)``: ``l`` is the mean squared residual of each of
            the ``num_chunks`` chunks, ``w`` is ``exp(-tol * (M @ l))`` with
            gradients stopped.
        """
        # Only the time column is sorted; the spatial column keeps its batch
        # order, so points are re-paired as (sorted t, original x).
        t_sorted = batch[:, 0].sort()
        # Compute residuals over the full domain
        r_pred = vmap(self.r_net, (None, 0, 0))(params, t_sorted, batch[:, 1])
        # Split residuals into chunks; the batch size must be divisible by
        # num_chunks for the reshape to succeed.
        r_pred = r_pred.reshape(self.num_chunks, -1)
        l = jnp.mean(r_pred**2, axis=1)
        # Compute temporal weights (treated as constants when differentiating)
        w = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ l)))
        return l, w

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        """Compute the individual loss terms.

        Args:
            params: Network parameters.
            batch: Collocation points, column 0 = time, column 1 = space.

        Returns:
            Dict with keys ``"u0"``, ``"u_t0"``, ``"res"`` and ``"bcs"``.
        """
        # Initial condition loss: match u0 at t0
        u0_pred = vmap(self.u_net, (None, None, 0))(params, self.t0, self.x_star)
        u0_loss = jnp.mean((self.u0 - u0_pred) ** 2)

        # Initial velocity loss: u_t is driven to zero at t0
        u_t0_pred = vmap(self.u_t_net, (None, None, 0))(params, self.t0, self.x_star)
        u_t0_loss = jnp.mean((0 - u_t0_pred) ** 2)

        # Boundary condition loss: u is driven to zero at both ends of x_star
        u_bc1_pred = vmap(self.u_net, (None, 0, None))(params, self.t_star, self.x_star[0])
        u_bc2_pred = vmap(self.u_net, (None, 0, None))(params, self.t_star, self.x_star[-1])
        bcs_loss = jnp.mean((u_bc1_pred) ** 2) + jnp.mean((u_bc2_pred) ** 2)

        # Residual loss. The flag is tested by truthiness, exactly as in
        # compute_diag_ntk, so both methods always pick the same branch.
        if self.config.weighting.use_causal:
            l, w = self.res_and_w(params, batch)
            res_loss = jnp.mean(l * w)
        else:
            r_pred = vmap(self.r_net, (None, 0, 0))(params, batch[:, 0], batch[:, 1])
            res_loss = jnp.mean((r_pred) ** 2)

        loss_dict = {"u0": u0_loss, "u_t0": u_t0_loss, "res": res_loss, "bcs": bcs_loss}
        return loss_dict

    @partial(jit, static_argnums=(0,))
    def compute_diag_ntk(self, params, batch):
        """Evaluate ``ntk_fn`` for the initial-condition and residual terms.

        Args:
            params: Network parameters.
            batch: Collocation points, column 0 = time, column 1 = space.

        Returns:
            Dict with keys ``"ics"`` (``ntk_fn`` of ``u_net`` at
            ``(t0, x_star)``) and ``"res"`` (``ntk_fn`` of ``r_net`` at the
            batch points; with causal weighting, averaged per chunk and
            multiplied by the causal weights).
        """
        # NOTE(review): the keys returned here ("ics", "res") differ from the
        # keys of `losses` ("u0", "u_t0", "res", "bcs"); confirm that any
        # consumer pairing the two dicts (e.g. NTK-based loss weighting)
        # copes with that.
        ics_ntk = vmap(ntk_fn, (None, None, None, 0))(
            self.u_net, params, self.t0, self.x_star
        )

        res_ntk_fn = vmap(ntk_fn, (None, None, 0, 0))

        # Consider the effect of causal weights
        if self.config.weighting.use_causal:
            # Sort the time column so the chunks match those of res_and_w
            batch = jnp.stack([batch[:, 0].sort(), batch[:, 1]], axis=1)
            res_ntk = res_ntk_fn(self.r_net, params, batch[:, 0], batch[:, 1])

            # Average over each chunk -> shape (num_chunks,)
            res_ntk = jnp.mean(res_ntk.reshape(self.num_chunks, -1), axis=1)
            _, causal_weights = self.res_and_w(params, batch)
            res_ntk = res_ntk * causal_weights  # multiply by causal weights
        else:
            res_ntk = res_ntk_fn(self.r_net, params, batch[:, 0], batch[:, 1])

        ntk_dict = {"ics": ics_ntk, "res": res_ntk}

        return ntk_dict

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, params, u_test):
        """Return ``||u_pred - u_test|| / ||u_test||`` over the full grid.

        ``u_pred`` is evaluated on ``t_star`` x ``x_star`` and is indexed as
        [time, space]; ``u_test`` must have a matching layout.
        """
        u_pred = self.u_pred_fn(params, self.t_star, self.x_star)
        error = jnp.linalg.norm(u_pred - u_test) / jnp.linalg.norm(u_test)
        return error


class WaveEvaluator(BaseEvaluator):
    """Evaluator that adds wave-model diagnostics to the base evaluator's log.

    Which diagnostics are produced is controlled by ``config.weighting`` and
    ``config.logging`` flags; see ``__call__``.
    """

    def __init__(self, config, model):
        super().__init__(config, model)

    def log_errors(self, params, u_ref):
        """Store the model's relative L2 error against ``u_ref`` as ``"l2_error"``."""
        l2_error = self.model.compute_l2_error(params, u_ref)
        self.log_dict["l2_error"] = l2_error

    def log_preds(self, params):
        """Store an image of the predicted solution as ``"u_pred"``.

        The prediction over ``t_star`` x ``x_star`` is transposed before
        plotting.
        """
        u_pred = self.model.u_pred_fn(params, self.model.t_star, self.model.x_star)
        # Draw on an explicit figure/axes and close that same figure, so the
        # result does not depend on pyplot's notion of the "current" figure.
        fig, ax = plt.subplots(figsize=(6, 5))
        ax.imshow(u_pred.T, cmap="jet")
        self.log_dict["u_pred"] = fig
        plt.close(fig)

    def _log_nonlinearities(self, params):
        """Store the ``alpha`` entry of every ``...Bottleneck_<i>`` layer."""
        suffixes = tuple(
            f"Bottleneck_{i}" for i in range(self.config.arch.num_layers)
        )
        layer_keys = [key for key in params["params"] if key.endswith(suffixes)]
        # Order by the numeric layer suffix so that alpha_i follows layer
        # order even if the params mapping iterates keys lexicographically
        # (where "Bottleneck_10" would come before "Bottleneck_2"). The sort
        # is stable, so keys sharing a suffix keep their original order.
        layer_keys.sort(key=lambda key: int(key.rsplit("_", 1)[1]))
        for i, key in enumerate(layer_keys):
            self.log_dict[f"alpha_{i}"] = params["params"][key]["alpha"]

    def _log_alignment(self, state):
        """Store gradient alignment scores computed on the full space-time grid.

        Compares gradients at ``state`` with gradients at ``model.old_state``,
        both raw and after passing through ``state.tx.update``.
        """
        model = self.model

        # Full (t, x) grid flattened to rows of [t, x]
        tt, xx = jnp.meshgrid(model.t_star, model.x_star, indexing='ij')
        grid_batch = jnp.hstack([tt.flatten()[:, None], xx.flatten()[:, None]])

        loss_grad = grad(model.loss)
        grads = loss_grad(state.params, state.weights, grid_batch)
        grads_dict = jacrev(model.losses)(state.params, grid_batch)

        # optimizer update
        scaled_grads, _ = state.tx.update(grads, state.opt_state, state.params)

        # NOTE(review): model.old_state is not refreshed by this evaluator,
        # and self.grads / self.scaled_grads below are stored but not read
        # here -- confirm the training loop advances the reference state.
        old_state = model.old_state
        old_grads = loss_grad(old_state.params, old_state.weights, grid_batch)
        scaled_old_grads, _ = state.tx.update(
            old_grads, old_state.opt_state, old_state.params
        )

        inter_align, inter_scaled_align = compute_inter_alignment_score(state, grads_dict)

        intra_align = compute_intra_alignment_score(grads, old_grads)
        intra_scaled_align = compute_intra_alignment_score(scaled_grads, scaled_old_grads)
        self.log_dict["inter_align"] = inter_align
        self.log_dict["inter_scaled_align"] = inter_scaled_align
        self.log_dict["intra_align"] = intra_align
        self.log_dict["intra_scaled_align"] = intra_scaled_align

        # update grads and scaled_grads
        self.grads = grads
        self.scaled_grads = scaled_grads

    def __call__(self, state, batch, u_ref):
        """Run the base evaluation and append the enabled diagnostics.

        Args:
            state: Training state; ``params``, ``weights``, ``opt_state`` and
                ``tx`` are read from it.
            batch: Collocation batch passed to the base evaluator and used
                for the causal-weight diagnostic.
            u_ref: Reference solution used when ``log_errors`` is enabled.

        Returns:
            The log dict, which is also kept on ``self.log_dict``.
        """
        self.log_dict = super().__call__(state, batch)

        if self.config.weighting.use_causal:
            # Smallest causal weight over the chunks of the supplied batch
            _, causal_weight = self.model.res_and_w(state.params, batch)
            self.log_dict["cas_weight"] = causal_weight.min()

        if self.config.logging.log_errors:
            self.log_errors(state.params, u_ref)

        if self.config.logging.log_preds:
            self.log_preds(state.params)

        if self.config.logging.log_nonlinearities:
            self._log_nonlinearities(state.params)

        if self.config.logging.log_cossim:
            self._log_alignment(state)

        return self.log_dict
