import jax
import hydra
from omegaconf import OmegaConf
import numpy as np
from copy import deepcopy
from pathlib import Path

import os, sys
if not __package__:
    # Make the CLI runnable straight from the source tree, e.g.
    #    python src/package
    # When this file is executed without a parent package (``__package__`` is
    # empty/None), prepend the parent of this file's directory to sys.path so
    # the absolute ``pi_lr...`` import that follows can be resolved.
    # NOTE(review): this assumes the ``pi_lr`` package directory lives inside
    # that parent directory — confirm against the repository layout.
    package_source_path = os.path.dirname(os.path.dirname(__file__))
    sys.path.insert(0, package_source_path)
    
from pi_lr.data.physics import *

class Data:
    """Dataset of input coordinates ``X`` and noisy observations ``y``.

    Loads ``X.npy``, ``y.npy`` and ``y0.npy`` from a directory, instantiates the
    equation described in that directory's ``config.yaml`` and corrupts ``y``
    with additive Gaussian noise. The noise-free observations are kept in
    ``y_clean``.

    Attributes:
        X: Input coordinates; a 1-D file is given a trailing axis, so ``X`` is
            at least 2-D.
        y: Noisy observations (NumPy array); a 2-D file is given a trailing
            axis, so ``y`` is at least 3-D.
        y_clean: The observations before noise was added (NumPy array, same
            shape as ``y``).
        y0: Array loaded from ``y0.npy``, stored as-is.
        equation: Object built by ``hydra.utils.instantiate`` from the
            ``data.equation`` entry of ``config.yaml``.
    """

    def __init__(self, data_dir: str, std: float = 1.0, seed: int = 0):
        """Load the dataset from ``data_dir`` and add Gaussian noise to ``y``.

        Args:
            data_dir: Directory containing ``X.npy``, ``y.npy``, ``y0.npy`` and
                ``config.yaml``.
            std: Standard deviation of the additive Gaussian noise on ``y``.
            seed: Seed of the JAX PRNG key used to draw the noise. The default
                of 0 reproduces the previously hard-coded key.
        """
        print("Loading data... : ", data_dir)
        data_dir = Path(data_dir)
        self.X = np.load(data_dir / "X.npy")
        self.y = np.load(data_dir / "y.npy")
        self.y0 = np.load(data_dir / "y0.npy")

        equation_config = OmegaConf.load(data_dir / "config.yaml")
        self.equation = hydra.utils.instantiate(equation_config.data.equation)

        if self.X.ndim == 1:
            self.X = self.X[:, np.newaxis]

        if self.y.ndim == 2:
            self.y = self.y[:, :, np.newaxis]

        self._add_noise(std, seed)

    def _add_noise(self, std: float, seed: int = 0) -> None:
        """Keep the current ``y`` as ``y_clean`` and replace ``y`` with a noisy copy."""
        key = jax.random.PRNGKey(seed)
        # Materialise the noise as a NumPy array: ``ndarray += jax.Array`` defers
        # to JAX (higher __array_priority__) and would silently turn ``self.y``
        # into a device jax.Array while ``y_clean`` stays a NumPy array.
        noise = np.asarray(std * jax.random.normal(key, shape=self.y.shape))
        # The add below is out-of-place, so ``y_clean`` is never mutated and no
        # defensive copy is required.
        self.y_clean = self.y
        self.y = self.y + noise

    def __getitem__(self, index):
        """Return ``(X, y[index], y_clean[index])``.

        ``X`` is reshaped to ``(-1, dim)``; both observation slices are flattened.
        """
        return self.X.reshape(-1, self.dim), self.y[index].reshape(-1), self.y_clean[index].reshape(-1)

    def __len__(self):
        """Number of samples, i.e. the size of the first axis of ``y``."""
        return self.y.shape[0]

    def __repr__(self):
        return f"Data(X.shape={self.X.shape}, y.shape={self.y.shape})"

    @property
    def n(self):
        """Number of samples (first axis of ``y``)."""
        return self.y.shape[0]

    @property
    def nt(self):
        """Size of the first axis of ``X``."""
        return self.X.shape[0]

    @property
    def nx(self):
        """1 when ``X`` is 2-D, otherwise the size of the second axis of ``X``."""
        if self.X.ndim == 2:
            return 1
        return self.X.shape[1]

    @property
    def dim(self):
        """Size of the last axis of ``X``."""
        return self.X.shape[-1]

    def resample_data(self, X, y, rng: jax.random.PRNGKey):
        """Shift ``X`` by half the spacing between its first two rows.

        Args:
            X: 2-D coordinate array with at least two rows.
            y: Observations, returned unchanged.
            rng: PRNG key. Currently unused; it is kept for interface
                compatibility.

        Returns:
            Tuple ``(X + h / 2, y)``, where ``h = X[1, :] - X[0, :]``.
        """
        h = X[1, :] - X[0, :]
        X_test = X + (h / 2)
        y_test = y
        return X_test, y_test