import os
import sys

if not __package__:
    # Make CLI runnable from source tree with
    #    python src/package
    # When run that way the package is not importable yet, so the directory
    # two levels above this file is put at the front of sys.path.
    # __file__ is resolved to an absolute path first: a relative __file__
    # (e.g. "src/package/__main__.py", or "__main__.py" when launched from
    # inside the package directory) would otherwise give a relative or empty
    # sys.path entry that is interpreted against whatever the current working
    # directory happens to be when a later import runs.
    package_source_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.insert(0, package_source_path)

import hydra
import jax.numpy as jnp
from jax import config

from omegaconf import DictConfig, OmegaConf
from sklearn.model_selection import train_test_split
# Turn on JAX's "jax_enable_x64" flag (64-bit dtypes). It is executed before the
# project import below, presumably so anything that module builds at import
# time is already 64-bit -- keep this ordering.
config.update("jax_enable_x64", True)

from pi_lr.data.physics import LinearDE

def _mul_pi(x):
    """Return ``x`` multiplied by pi; backs the ``mul_pi`` config resolver."""
    return jnp.pi * x


# Lets config files write e.g. ``${mul_pi:2}`` for 2*pi.
OmegaConf.register_new_resolver("mul_pi", _mul_pi)


@hydra.main(version_base=None, config_path="../config", config_name="main.yaml")
def main(cfg: DictConfig) -> None:
    """Instantiate the dataset and model from ``cfg`` and run the enabled stages.

    Stages (each toggled by a config flag):

    * ``cfg.calc_dim`` -- print ``d_V`` from
      ``equation.affine_variety_dimension`` evaluated at the weights returned
      by ``model.basis_function.projection`` of ``data.y_clean``.
    * ``cfg.train`` -- for each ``i`` in ``range(cfg.n_init)``, take
      ``X, y, y_clean = data[i]``, split it into train/validation/test, fit the
      model (``fit_exact`` when the equation is a ``LinearDE``,
      ``fit_gradient_descent`` otherwise), print the test/train/PINN losses and
      call ``equation.visualize`` with a ``.png`` path under ``cfg.out_dir``.

    Args:
        cfg: Configuration composed by Hydra from ``main.yaml`` in
            ``../config`` plus any command-line overrides.
    """
    # instantiate dataset (it also carries the equation)
    data = hydra.utils.instantiate(cfg.data.dataset)
    equation = data.equation

    # instantiate model; basis and test functions both get the output shape
    # [data.nt, data.nx]
    model = hydra.utils.instantiate(
        cfg.model,
        dim_in=data.dim,
        dim_out_base=[data.nt, data.nx],
        dim_out_test=[data.nt, data.nx],
        domain=equation.domain,
    )

    if cfg.calc_dim:
        # calc affine variety dimension at the projection of the clean targets
        optimal_weight = model.basis_function.projection(data.y_clean, model.test_function)
        d_v = equation.affine_variety_dimension(
            model.basis_function, model.test_function, optimal_weight, iterate=True
        )
        print(f"d_V = {d_v}")
        # print(f"d_V^eff = {equation.effective_affine_variety_dimension(model.basis_function, model.test_function, optimal_weight)}")

    if cfg.train:
        # Make sure the figure directory exists before the first save below.
        os.makedirs(cfg.out_dir, exist_ok=True)

        for i in range(cfg.n_init):
            X, y, y_clean = data[i]

            # 60/20/20 split: hold out 20% for test, then 25% of the remaining
            # 80% for validation. The same seed is reused for every i.
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=cfg.random_seed
            )
            X_train, X_val, y_train, y_val = train_test_split(
                X_train, y_train, test_size=0.25, random_state=cfg.random_seed
            )

            # cap the training set at cfg.n_train samples (no-op if fewer remain)
            X_train = X_train[: cfg.n_train]
            y_train = y_train[: cfg.n_train]

            if isinstance(equation, LinearDE):
                # using exact solution
                model.fit_exact(
                    X_train,
                    y_train,
                    equation=equation,
                    lambda_L2=cfg.lambda_L2,
                    lambda_eq=cfg.lambda_eq,
                )
            else:
                # using gradient descent.
                # The scheduler used to be read from the misspelled key
                # "shceduler"; prefer the correct "scheduler" key and fall back
                # to the legacy one so existing config files keep working.
                opt_cfg = cfg.optimizer
                scheduler = (
                    opt_cfg.scheduler if "scheduler" in opt_cfg else opt_cfg.shceduler
                )
                model.fit_gradient_descent(
                    X_train,
                    y_train,
                    X_val=X_val,
                    y_val=y_val,
                    lambda_L2=cfg.lambda_L2,
                    lambda_eq=cfg.lambda_eq,
                    equation=equation,
                    weight_init=opt_cfg.weight_init,
                    optimizer=opt_cfg.name,
                    learning_rate=opt_cfg.learning_rate,
                    scheduler=scheduler,
                )

            test_loss = model.score(X_test, y_test)
            train_loss = model.score(X_train, y_train)
            pinn_loss = equation.pi_loss(model.basis_function, model.test_function, model.weights)
            print(f"test loss: {test_loss}, train loss: {train_loss}, pinn loss: {pinn_loss}")
            # NOTE(review): the output path does not depend on i, so with
            # n_init > 1 each iteration overwrites the previous figure --
            # confirm this is intended.
            equation.visualize(X, y_clean, X_train, f"{cfg.out_dir}/{repr(equation)}.png")

# Script entry point. main() is called without arguments: the @hydra.main
# decorator on main composes the configuration and passes it in as `cfg`.
if __name__ == "__main__":
    main()
