from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import optax
from jax import lax, jit, grad, random, vmap

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP, residual_stats_logging
from phijax.equations.registry import register_pde


@dataclass
class WindowState:
    """Plain record of per-window values.

    NOTE(review): this class is not referenced elsewhere in the visible part
    of the file, so the field descriptions below are inferred from the field
    names only and should be confirmed against the code that constructs it.
    """

    # presumably the index of the current time window — TODO confirm
    idx: int
    # presumably initial fields at the start of the window (v, u, p and a
    # fourth field T0 whose meaning is not shown here) — TODO confirm
    v0: jnp.ndarray
    u0: jnp.ndarray
    p0: jnp.ndarray
    T0: jnp.ndarray
    # presumably step counters for the window (current / upper bound) — TODO confirm
    num_steps: int
    max_steps: int


def get_dataset(ref_path, Re=1e4, tf=1.0):
    """Load the reference solution from a pickled NumPy dictionary.

    The file at ``ref_path`` must hold a dict with the keys ``"vorticity"``,
    ``"velocity"`` (last axis holds the two velocity components), ``"t"`` and
    ``"coords"``; ``"nu"`` is only required when ``Re`` is None.

    Args:
        ref_path: Path passed to ``jnp.load(..., allow_pickle=True)``.
        Re: Reynolds number; the returned viscosity is ``1 / Re``. Pass
            ``None`` to return the ``"nu"`` value stored in the file instead.
        tf: Fraction of the snapshots to keep; the first
            ``int(tf * number_of_time_steps)`` entries along axis 0 are kept.

    Returns:
        Tuple ``(u_ref, v_ref, w_ref, t, coords, nu)`` where ``t`` has been
        shifted so that its first entry is zero.
    """
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    data = jnp.load(ref_path, allow_pickle=True).item()
    w_ref = jnp.array(data["vorticity"])
    velocity = jnp.array(data["velocity"])

    u_ref = velocity[..., 0]
    v_ref = velocity[..., 1]

    # Flatten the time vector and make it start at zero.
    t = jnp.array(data["t"]).flatten()
    t = t - t[0]

    # Truncate every time-indexed array to the same leading window.
    num_steps = int(tf * t.shape[0])
    u_ref = u_ref[:num_steps]
    v_ref = v_ref[:num_steps]
    w_ref = w_ref[:num_steps]
    t = t[:num_steps]

    print("t.shape", t.shape, "u_ref.shape", u_ref.shape, "v_ref.shape", v_ref.shape, "w_ref.shape", w_ref.shape)

    coords = jnp.array(data["coords"])

    # The stored viscosity used to be read and then unconditionally
    # overwritten; it is now only consulted when no Reynolds number is given.
    nu = data["nu"] if Re is None else 1 / Re

    return u_ref, v_ref, w_ref, t, coords, nu

@jit
def periodic_l2(field1, field2):
    """Return the smallest relative L2 error over all periodic shifts.

    ``field2`` is rolled by every integer offset along axis 0 and axis 1
    (one offset per index of ``field1`` along each of those axes); for each
    shift the L2 norm of ``field1 - shifted`` is divided by the L2 norm of
    ``field1``, and the minimum over all shifts is returned.
    """
    n_rows, n_cols = field1.shape[0], field1.shape[1]
    reference_norm = jnp.sqrt(jnp.sum(field1 ** 2))

    def shifted_error(dx, dy):
        # Roll along axis 0 first, then along axis 1.
        rolled = jnp.roll(field2, dx, axis=0)
        rolled = jnp.roll(rolled, dy, axis=1)
        return jnp.sqrt(jnp.sum((field1 - rolled) ** 2)) / reference_norm

    # Enumerate every (axis-0 shift, axis-1 shift) pair and evaluate them in one vmap.
    grid_x, grid_y = jnp.meshgrid(jnp.arange(n_rows), jnp.arange(n_cols))
    all_errors = vmap(shifted_error)(grid_x.ravel(), grid_y.ravel())

    return jnp.min(all_errors)

    
@register_pde(
    "kolmogorov", aliases=["kf"],
)
class KolmogorovFlow(IVP):
    """Initial-value problem for forced 2-D incompressible Navier-Stokes.

    The underlying network maps ``(t, x, y)`` to three outputs that are used
    as ``(u, v, p)``. Residuals are the two momentum equations (the first one
    carries the forcing ``2 * sin(4 * pi * y)``) and the continuity equation.
    Reference data is loaded once in ``__init__`` through ``get_dataset``.
    """

    loss_keys = ("uics", "vics", "ru", "rv", "rc")

    def __init__(self, config):
        """Load the reference solution and build the vectorized predictors.

        Args:
            config: Configuration object forwarded to the base class;
                ``self.config.pde_config.ref_path`` is read afterwards to
                locate the reference data file.
        """
        super().__init__(config)

        pcfg = self.config.pde_config

        # get_dataset is called with its default Re and tf values.
        u_ref, v_ref, w_ref, t_star, coords, nu = get_dataset(pcfg.ref_path)
        self.u_ref = u_ref
        self.v_ref = v_ref
        self.w_ref = w_ref
        self.t_star = t_star
        self.coords = coords
        self.nu = nu

        # First snapshot of each reference field serves as the initial condition.
        self.u0 = u_ref[0, :]
        self.v0 = v_ref[0, :]
        self.w0 = w_ref[0, :]

        # Body force added to the first momentum residual; depends on y only.
        self.force_fn = lambda x, y: 2 * jnp.sin(4 * jnp.pi * y)

        # Single time value, batched over points (x and y vary together).
        self.u_ic_pred_fn = vmap(self.u_net, (None, None, 0, 0))
        self.v_ic_pred_fn = vmap(self.v_net, (None, None, 0, 0))
        self.w_ic_pred_fn = vmap(self.w_net, (None, None, 0, 0))

        # Outer vmap over time, inner vmap over points: output is (len(t), len(x)).
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0, 0)), (None, 0, None, None))
        self.v_pred_fn = vmap(vmap(self.v_net, (None, None, 0, 0)), (None, 0, None, None))
        self.w_pred_fn = vmap(vmap(self.w_net, (None, None, 0, 0)), (None, 0, None, None))

        # Residuals are evaluated on collocation triples (t, x, y) batched together.
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0, 0))

    def get_ref(self):
        """Return ``(u_ref, v_ref, w_ref, t_star, coords)`` as loaded in ``__init__``."""
        return self.u_ref, self.v_ref, self.w_ref, self.t_star, self.coords

    def net(self, state, t, x, y):
        """Evaluate the network at one point and return its first three outputs.

        Time is divided by the last reference time ``self.t_star[-1]`` before
        being stacked with ``x`` and ``y`` into the network input.
        """
        t = t / self.t_star[-1]
        z = jnp.stack([t, x, y])
        # NOTE(review): apply_fn is taken from self.state while the variables
        # come from the `state` argument — confirm this mix is intended.
        _, out = self.state.apply_fn(state.variables(), z)
        return out[0], out[1], out[2]

    def u_net(self, state, t, x, y):
        """Return the first network output (used as velocity component u)."""
        u, _, _ = self.net(state, t, x, y)
        return u

    def v_net(self, state, t, x, y):
        """Return the second network output (used as velocity component v)."""
        _, v, _ = self.net(state, t, x, y)
        return v

    def p_net(self, state, t, x, y):
        """Return the third network output (used as pressure p)."""
        _, _, p = self.net(state, t, x, y)
        return p

    def w_net(self, state, t, x, y):
        """Return the vorticity ``v_x - u_y`` obtained by differentiating the network."""
        u_y = grad(self.u_net, argnums=3)(state, t, x, y)
        v_x = grad(self.v_net, argnums=2)(state, t, x, y)
        return v_x - u_y

    def r_net(self, state, t, x, y):
        """Return the PDE residuals ``(ru, rv, rc)`` at a single point.

        ``ru`` and ``rv`` are the momentum residuals (advection, pressure
        gradient, viscous term scaled by ``self.nu``; ``ru`` additionally
        subtracts the forcing) and ``rc`` is the divergence ``u_x + v_y``.
        """
        u, v, p = self.net(state, t, x, y)

        # jacrev over (t, x, y) yields, per network output, its three first derivatives.
        (
            (u_t, u_x, u_y),
            (v_t, v_x, v_y),
            (_, p_x, p_y),
        ) = jax.jacrev(self.net, argnums=(1, 2, 3))(state, t, x, y)

        # Hessians w.r.t. (x, y) are nested tuples indexed [row][col];
        # the diagonal entries are the pure second derivatives.
        u_hess = jax.hessian(self.u_net, argnums=(2, 3))(state, t, x, y)
        v_hess = jax.hessian(self.v_net, argnums=(2, 3))(state, t, x, y)

        u_xx = u_hess[0][0]
        u_yy = u_hess[1][1]

        v_xx = v_hess[0][0]
        v_yy = v_hess[1][1]

        force = self.force_fn(x, y)

        ru = u_t + u * u_x + v * u_y + p_x - self.nu * (u_xx + u_yy) - force
        rv = v_t + u * v_x + v * v_y + p_y - self.nu * (v_xx + v_yy)
        rc = u_x + v_y

        return ru, rv, rc

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Compute initial-condition mismatches and PDE residuals for one batch.

        Args:
            state: Training state passed through to the network functions.
            batch: Mapping with an ``'ics'`` entry that unpacks into
                ``(coords, u, v, w)`` and a ``'res'`` entry whose columns
                0, 1, 2 are used as ``t``, ``x``, ``y``.

        Returns:
            Dict keyed by ``loss_keys`` with the raw (unsquared) residuals.
        """
        ics_batch = batch['ics']
        res_batch = batch['res']

        # Initial condition terms: the fourth element of the batch is not used here.
        coords_batch, u_batch, v_batch, _ = ics_batch
        x_ic = coords_batch[:, 0]
        y_ic = coords_batch[:, 1]

        u_ic_pred = self.u_ic_pred_fn(state, 0.0, x_ic, y_ic)
        v_ic_pred = self.v_ic_pred_fn(state, 0.0, x_ic, y_ic)

        ru_pred, rv_pred, rc_pred = self.r_pred_fn(
            state, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2]
        )

        return {
            "uics": u_ic_pred - u_batch,
            "vics": v_ic_pred - v_batch,
            "ru": ru_pred,
            "rv": rv_pred,
            "rc": rc_pred,
        }

    @partial(jit, static_argnums=(0,))
    def compute_metrics(self, state, t, coords, u_ref, v_ref, w_ref):
        """Return relative L2 errors ``(u_error, v_error, w_error)``.

        Predictions are evaluated for every time in ``t`` at every row of
        ``coords`` (columns 0 and 1 are used as x and y).
        """
        x = coords[:, 0]
        y = coords[:, 1]
        u_pred = self.u_pred_fn(state, t, x, y)
        v_pred = self.v_pred_fn(state, t, x, y)
        w_pred = self.w_pred_fn(state, t, x, y)

        # assumes the reference arrays match the (len(t), len(coords))
        # prediction shape — TODO confirm against the caller
        u_error = jnp.linalg.norm(u_ref - u_pred) / jnp.linalg.norm(u_ref)
        v_error = jnp.linalg.norm(v_ref - v_pred) / jnp.linalg.norm(v_ref)
        w_error = jnp.linalg.norm(w_ref - w_pred) / jnp.linalg.norm(w_ref)
        return u_error, v_error, w_error

    def log_errors(self, state, t, coords, u_ref, v_ref, w_ref):
        """Store the relative errors from ``compute_metrics`` in ``self.log_dict``."""
        u_error, v_error, w_error = self.compute_metrics(state, t, coords, u_ref, v_ref, w_ref)
        self.log_dict["u_error"] = u_error
        self.log_dict["v_error"] = v_error
        self.log_dict["w_error"] = w_error

    def set_initial_condition(self, u0, v0, p0, w0, u_star, v_star, w_star, window_t_star, *args):
        """Replace the stored initial fields and per-window reference data.

        Any additional positional arguments are accepted and ignored.
        """
        self.u0 = u0
        self.v0 = v0
        self.p0 = p0
        self.w0 = w0
        self.u_ref_window = u_star
        self.v_ref_window = v_star
        self.w_ref_window = w_star
        self.wt_star = window_t_star

    def log(self, state, batch, t, coords, u_ref, v_ref, w_ref):
        """Delegate to the base-class ``log`` and keep its result in ``self.log_dict``.

        Error metrics are not added here; ``log_errors`` is not invoked.
        """
        self.log_dict = super().log(state, batch, t, coords, u_ref, v_ref, w_ref)
        return self.log_dict

    def _log_stats(self, state, batch, *args):
        """Add summary statistics of the PDE residuals to ``self.log_dict``.

        For each of ``ru``, ``rv`` and ``rc`` every statistic returned by
        ``residual_stats_logging`` is stored under ``stats/<key>_<stat_name>``.
        """
        residuals = self.residuals(state, batch, *args)
        for key in ("ru", "rv", "rc"):
            stats = residual_stats_logging(residuals[key])
            for stat_name, stat_value in stats.items():
                self.log_dict[f"stats/{key}_{stat_name}"] = stat_value

