from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import optax
from jax import lax, jit, grad, random, vmap

from phijax.models import *
from phijax.data import *
from phijax.equations.base import IVP, residual_stats_logging
from phijax.equations.registry import register_pde


@dataclass
class WindowState:
    """Per-window state record (not referenced in the visible code).

    Presumably tracks one time window of a time-marching run — confirm
    against callers. Field order defines the generated ``__init__``
    signature, so it must be kept as is.
    """

    # NOTE(review): presumably the position of the window in the sequence.
    idx: int

    # Initial fields for the window; array shapes are not fixed here.
    v0: jnp.ndarray
    u0: jnp.ndarray
    p0: jnp.ndarray
    T0: jnp.ndarray

    # Step counters; their exact meaning is defined by callers.
    num_steps: int
    max_steps: int


def get_dataset(ref_path, start_idx=5):
    """Load the reference solution stored as a pickled dict in a ``.npy`` file.

    Args:
        ref_path: path to a file written with ``numpy.save`` of a dict with
            keys ``velocity``, ``pressure``, ``temperature``, ``t``,
            ``coords``, ``alpha1``..``alpha4``, ``Ra``, ``Pr`` and ``Ge``.
        start_idx: number of leading snapshots to drop from the time-indexed
            arrays (defaults to 5, the previously hard-coded value).

    Returns:
        ``(velocity, pressure, temperature, t, coords, alpha1, alpha2,
        alpha3, alpha4, Ra, Pr, Ge)``. The first five are JAX arrays, and
        ``t`` is shifted so the first kept snapshot is at time 0. The
        parameters are returned exactly as stored in the file.

    Note:
        ``allow_pickle=True`` runs pickle deserialization; only load trusted
        files.
    """
    import numpy as np

    # The file holds a 0-d object array wrapping a dict; .item() unwraps it.
    data = np.load(ref_path, allow_pickle=True).item()

    # Time-indexed fields: drop the first ``start_idx`` snapshots.
    velocity = jnp.array(data["velocity"])[start_idx:]
    pressure = jnp.array(data["pressure"])[start_idx:]
    temperature = jnp.array(data["temperature"])[start_idx:]

    t = jnp.array(data["t"])[start_idx:]
    t = t - t[0]  # re-zero time at the first kept snapshot
    coords = jnp.array(data["coords"])

    # Scalar problem parameters, passed through unchanged.
    alpha1 = data["alpha1"]
    alpha2 = data["alpha2"]
    alpha3 = data["alpha3"]
    alpha4 = data["alpha4"]

    Ra = data["Ra"]
    Pr = data["Pr"]
    Ge = data["Ge"]

    return velocity, pressure, temperature, t, coords, alpha1, alpha2, alpha3, alpha4, Ra, Pr, Ge

@register_pde(
    "rayleightaylor", aliases=["rt"],
)
class RayleighTaylor(IVP):
    """Rayleigh-Taylor problem: flow fields ``(u, v, p)`` coupled to a temperature ``T``.

    The network maps ``(t, x, y)`` to ``(u, v, p, T)``. The losses are:

    * initial-condition mismatches at ``t = 0`` against supplied targets
      (``uics``, ``vics``, ``pics``, ``Tics``);
    * boundary predictions of ``u``, ``v`` and ``T``, penalized toward zero
      (``ubcs``, ``vbcs``, ``Tbcs``);
    * four PDE residuals (``resu``, ``resv``, ``resc``, ``rese``); see ``r_net``.
    """

    loss_keys = ("uics", "vics", "pics", "Tics", "ubcs", "vbcs", "Tbcs", "resu", "resv", "resc", "rese")

    def __init__(self, config):
        """Load the reference solution and build the vmapped prediction helpers.

        Args:
            config: experiment configuration. ``config.pde_config.ref_path``
                is passed to ``get_dataset`` to load the reference data.
        """
        super().__init__(config)

        pcfg = self.config.pde_config

        (
            velocity, pressure, temperature, t_star, coords,
            alpha1, alpha2, alpha3, alpha4, Ra, Pr, Ge,
        ) = get_dataset(pcfg.ref_path)
        self.velocity = velocity
        self.pressure = pressure
        self.temperature = temperature
        self.t_star = t_star
        self.coords = coords

        # Initial snapshot. Velocity is indexed [time, point, component] here,
        # with component 0 -> u and 1 -> v (assumed layout; confirm with data).
        self.u0 = velocity[0, :, 0]
        self.v0 = velocity[0, :, 1]
        self.T0 = temperature[0]
        self.p0 = pressure[0]

        self.alpha1, self.alpha2, self.alpha3, self.alpha4 = alpha1, alpha2, alpha3, alpha4
        self.Ra, self.Pr, self.Ge = Ra, Pr, Ge

        # Peak initial speed; the small offset keeps the value strictly positive.
        self.v_scale = jnp.max(jnp.sqrt(self.u0**2 + self.v0**2)) + 1e-3

        # Initial condition: one shared scalar time, a batch of (x, y).
        self.u_ic_pred_fn = vmap(self.u_net, (None, None, 0, 0))
        self.v_ic_pred_fn = vmap(self.v_net, (None, None, 0, 0))
        self.p_ic_pred_fn = vmap(self.p_net, (None, None, 0, 0))
        self.temp_ic_pred_fn = vmap(self.temp_net, (None, None, 0, 0))

        # Boundary points: a batch of (t, x, y) triples.
        self.u_bc_pred_fn = vmap(self.u_net, (None, 0, 0, 0))
        self.v_bc_pred_fn = vmap(self.v_net, (None, 0, 0, 0))
        self.temp_bc_pred_fn = vmap(self.temp_net, (None, 0, 0, 0))

        # Collocation points for the PDE residuals: a batch of (t, x, y).
        self.r_pred_fn = vmap(self.r_net, (None, 0, 0, 0))

        # Full-field predictions: outer vmap over times, inner over (x, y)
        # points, giving arrays indexed [time, point].
        self.u_pred_fn = vmap(vmap(self.u_net, (None, None, 0, 0)), (None, 0, None, None))
        self.v_pred_fn = vmap(vmap(self.v_net, (None, None, 0, 0)), (None, 0, None, None))
        self.p_pred_fn = vmap(vmap(self.p_net, (None, None, 0, 0)), (None, 0, None, None))
        self.temp_pred_fn = vmap(vmap(self.temp_net, (None, None, 0, 0)), (None, 0, None, None))

    def get_ref(self):
        """Return the loaded reference fields, times, coordinates and ``alpha1``..``alpha4``."""
        return self.velocity, self.pressure, self.temperature, self.t_star, self.coords, self.alpha1, self.alpha2, self.alpha3, self.alpha4

    def net(self, state, t, x, y):
        """Evaluate the network at a single point and return ``(u, v, p, T)``.

        Time is divided by the last reference time ``self.t_star[-1]`` before
        it is stacked with ``x`` and ``y`` as the model input.
        """
        z = jnp.stack([t / self.t_star[-1], x, y])
        # NOTE(review): apply_fn is taken from ``self.state``, but the
        # variables come from the ``state`` argument. Confirm both describe the
        # same model.
        _, out = self.state.apply_fn(state.variables(), z)
        return out[0], out[1], out[2], out[3]

    def u_net(self, state, t, x, y):
        """Horizontal velocity ``u`` at ``(t, x, y)``."""
        return self.net(state, t, x, y)[0]

    def v_net(self, state, t, x, y):
        """Vertical velocity ``v`` at ``(t, x, y)``."""
        return self.net(state, t, x, y)[1]

    def p_net(self, state, t, x, y):
        """Pressure ``p`` at ``(t, x, y)``."""
        return self.net(state, t, x, y)[2]

    def temp_net(self, state, t, x, y):
        """Temperature ``T`` at ``(t, x, y)``."""
        return self.net(state, t, x, y)[3]

    def r_net(self, state, t, x, y):
        """Evaluate the four PDE residuals at a single point ``(t, x, y)``.

        With ``(u, v, p, T)`` from ``net``::

            ru = u_t + u u_x + v u_y + p_x - alpha1 (u_xx + u_yy)
            rv = v_t + u v_x + v v_y + p_y - alpha1 (v_xx + v_yy) - alpha2 T
            rc = u_x + v_y
            re = T_t + u T_x + v T_y - alpha4 (T_xx + T_yy)

        NOTE(review): ``alpha3``, ``Ra``, ``Pr`` and ``Ge`` are loaded but not
        used in any residual. Confirm this is intended.

        Returns:
            Tuple ``(ru, rv, rc, re)``.
        """
        u, v, _, temp = self.net(state, t, x, y)

        # First derivatives: outer tuple over outputs (u, v, p, T), inner over
        # the arguments (t, x, y).
        (
            (u_t, u_x, u_y),
            (v_t, v_x, v_y),
            (_, p_x, p_y),
            (temp_t, temp_x, temp_y),
        ) = jax.jacrev(self.net, argnums=(1, 2, 3))(state, t, x, y)

        # Spatial second derivatives of every output in a single transform.
        # Each output's entry is ((f_xx, f_xy), (f_yx, f_yy)). The pressure
        # Hessian is not needed.
        (
            ((u_xx, _), (_, u_yy)),
            ((v_xx, _), (_, v_yy)),
            _,
            ((temp_xx, _), (_, temp_yy)),
        ) = jax.hessian(self.net, argnums=(2, 3))(state, t, x, y)

        ru = u_t + u * u_x + v * u_y + p_x - self.alpha1 * (u_xx + u_yy)
        rv = v_t + u * v_x + v * v_y + p_y - self.alpha1 * (v_xx + v_yy) - self.alpha2 * temp
        rc = u_x + v_y
        re = temp_t + u * temp_x + v * temp_y - self.alpha4 * (temp_xx + temp_yy)

        return ru, rv, rc, re

    @partial(jit, static_argnums=(0,))
    def residuals(self, state, batch):
        """Compute the per-point residual arrays for every entry of ``loss_keys``.

        Args:
            state: model state passed through to ``net``.
            batch: dict with
                ``'ics'``: tuple ``(coords, u, v, p, T)``, where ``coords``
                has x and y in columns 0 and 1 and the rest are the targets at
                ``t = 0``;
                ``'bcs'``: array with columns ``(t, x, y)``;
                ``'res'``: array with columns ``(t, x, y)``.

        Returns:
            Dict mapping each loss key to its residual array. Boundary entries
            are the raw predictions, which the caller presumably drives toward 0.
        """
        coords_ic, u_ic, v_ic, p_ic, temp_ic = batch['ics']
        bcs = batch['bcs']
        res = batch['res']

        # Initial condition at t = 0.
        x_ic, y_ic = coords_ic[:, 0], coords_ic[:, 1]
        u_ic_pred = self.u_ic_pred_fn(state, 0.0, x_ic, y_ic)
        v_ic_pred = self.v_ic_pred_fn(state, 0.0, x_ic, y_ic)
        p_ic_pred = self.p_ic_pred_fn(state, 0.0, x_ic, y_ic)
        temp_ic_pred = self.temp_ic_pred_fn(state, 0.0, x_ic, y_ic)

        # Boundary points.
        t_bc, x_bc, y_bc = bcs[:, 0], bcs[:, 1], bcs[:, 2]
        u_bc_pred = self.u_bc_pred_fn(state, t_bc, x_bc, y_bc)
        v_bc_pred = self.v_bc_pred_fn(state, t_bc, x_bc, y_bc)
        temp_bc_pred = self.temp_bc_pred_fn(state, t_bc, x_bc, y_bc)

        # Interior collocation points.
        ru_pred, rv_pred, rc_pred, re_pred = self.r_pred_fn(state, res[:, 0], res[:, 1], res[:, 2])

        return {
            'uics': u_ic_pred - u_ic,
            'vics': v_ic_pred - v_ic,
            'pics': p_ic_pred - p_ic,
            'Tics': temp_ic_pred - temp_ic,
            'ubcs': u_bc_pred,
            'vbcs': v_bc_pred,
            'Tbcs': temp_bc_pred,
            'resu': ru_pred,
            'resv': rv_pred,
            'resc': rc_pred,
            'rese': re_pred,
        }

    @partial(jit, static_argnums=(0,))
    def compute_metrics(self, state, t, coords, u_ref, v_ref, temp_ref):
        """Relative L2 (Frobenius) errors of ``u``, ``v`` and ``T`` on a time x space grid.

        Predictions are indexed [time, point] (see ``u_pred_fn``). The
        reference arrays are assumed to have the same layout; confirm with
        callers.

        Returns:
            Tuple ``(u_error, v_error, temp_error)``.
        """
        x, y = coords[:, 0], coords[:, 1]
        u_pred = self.u_pred_fn(state, t, x, y)
        v_pred = self.v_pred_fn(state, t, x, y)
        temp_pred = self.temp_pred_fn(state, t, x, y)

        u_error = jnp.linalg.norm(u_ref - u_pred) / jnp.linalg.norm(u_ref)
        v_error = jnp.linalg.norm(v_ref - v_pred) / jnp.linalg.norm(v_ref)
        temp_error = jnp.linalg.norm(temp_ref - temp_pred) / jnp.linalg.norm(temp_ref)

        return u_error, v_error, temp_error

    def log_errors(self, state, t, coords, u_ref, v_ref, temp_ref):
        """Store the relative errors from ``compute_metrics`` in ``self.log_dict``."""
        u_error, v_error, temp_error = self.compute_metrics(state, t, coords, u_ref, v_ref, temp_ref)
        self.log_dict["u_error"] = u_error
        self.log_dict["v_error"] = v_error
        self.log_dict["temp_error"] = temp_error

    def set_initial_condition(self, u0, v0, p0, temp0, u_star, v_star, temp_star, window_t_star, *args):
        """Replace the initial fields and store per-window reference data.

        Any extra positional arguments are accepted and ignored.
        """
        self.u0 = u0
        self.v0 = v0
        self.p0 = p0
        self.T0 = temp0
        self.u_ref_window = u_star
        self.v_ref_window = v_star
        self.temp_ref_window = temp_star
        self.wt_star = window_t_star

    def log(self, state, batch, t, coords, u_ref, v_ref, temp_ref):
        """Delegate to the base-class ``log``, keep its result on ``self.log_dict`` and return it."""
        self.log_dict = super().log(state, batch, t, coords, u_ref, v_ref, temp_ref)
        # Relative-error logging (self.log_errors) is intentionally disabled here.
        return self.log_dict

    def _log_stats(self, state, batch, *args):
        """Write summary statistics of every PDE residual into ``self.log_dict``.

        Extra positional arguments are accepted for interface compatibility.
        They are not forwarded, because ``residuals`` only takes
        ``(state, batch)``; forwarding them would raise a TypeError.
        """
        residuals = self.residuals(state, batch)
        for key in ("resu", "resv", "resc", "rese"):
            for stat_name, stat_value in residual_stats_logging(residuals[key]).items():
                self.log_dict[f"stats/{key}_{stat_name}"] = stat_value

     
   
    



