import os
import sys
import gc

if not __package__:
    # Executed as a plain script rather than as part of a package, e.g.
    #    python src/package
    # Prepend the directory two levels above this file to sys.path so that
    # packages located there take precedence when the imports below run.
    package_source_path = os.path.dirname(os.path.dirname(__file__))
    sys.path.insert(0, package_source_path)

import wandb
import optuna
import hydra
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import scienceplots

import jax
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from omegaconf import DictConfig, OmegaConf
from pi_lr import Data
from jax import config
from tqdm import tqdm
from pathlib import Path
import pandas as pd
from pi_lr.data.physics import LinearDE
# Turn on JAX's `jax_enable_x64` flag at import time, before any arrays are created.
config.update("jax_enable_x64", True)
# Register an OmegaConf resolver named `mul_pi` that multiplies its argument by pi.
OmegaConf.register_new_resolver("mul_pi", lambda x: jnp.pi * x)

# Apply the 'science' and 'no-latex' matplotlib styles globally
# (presumably the styles registered by the `scienceplots` import above).
plt.style.use(['science', 'no-latex'])

def _fit_model(model, equation, cfg, X_train, y_train, X_val, y_val,
               lambda_L2, lambda_eq, tuning):
    """Fit ``model`` in place with the given regularisation strengths.

    ``model.fit_exact`` is used when ``equation`` is a ``LinearDE``;
    otherwise ``model.fit_gradient_descent`` is used.

    Args:
        model: Model object exposing ``fit_exact`` / ``fit_gradient_descent``.
        equation: Equation object passed through to the fit call.
        cfg: Hydra config; only read on the gradient-descent branch.
        X_train, y_train: Training data.
        X_val, y_val: Validation data (only passed to gradient descent).
        lambda_L2: Value forwarded as ``lambda_L2``.
        lambda_eq: Value forwarded as ``lambda_eq``.
        tuning: If true, optimiser settings come from ``cfg.optimizer_trial``
            and ``epochs`` is forwarded; otherwise settings come from
            ``cfg.optimizer`` and no ``epochs`` argument is passed.
    """
    if isinstance(equation, LinearDE):
        model.fit_exact(X_train, y_train, equation=equation, lambda_L2=lambda_L2, lambda_eq=lambda_eq)
        return

    # The optimiser config is looked up only here so that configs without
    # these nodes keep working when the exact solver is used.
    optimizer_cfg = cfg.optimizer_trial if tuning else cfg.optimizer
    fit_kwargs = dict(
        X_val=X_val,
        y_val=y_val,
        lambda_L2=lambda_L2,
        lambda_eq=lambda_eq,
        equation=equation,
        weight_init=optimizer_cfg.weight_init,
        optimizer=optimizer_cfg.name,
        learning_rate=optimizer_cfg.learning_rate,
        scheduler=optimizer_cfg.scheduler,
    )
    if tuning:
        fit_kwargs['epochs'] = optimizer_cfg.epochs
    # NOTE(review): the final (non-tuning) fit passes no `epochs`, presumably
    # falling back to the method's own default - confirm this is intended.
    model.fit_gradient_descent(X_train, y_train, **fit_kwargs)


def sweep(cfg, use_pi=False):
    """Run the training-size sweep and write one results CSV per ``n_train``.

    For every ``n_train`` in ``cfg.n_train_list`` and each of ``cfg.n_init``
    dataset draws, the regularisation strengths are tuned with Optuna
    (``cfg.trial`` trials, minimising ``model.score`` on a validation split),
    the model is refitted with the best values, and the train score, test
    score and ``equation.pi_loss`` are recorded. Mean and standard deviation
    over the draws are written to
    ``<original cwd>/<cfg.out_dir>/results_{n_train}_{dim_out}_{pi|lr}.csv``.

    Args:
        cfg: Hydra config for the sweep.
        use_pi: If true, ``lambda_eq`` is tuned as well; otherwise it is
            fixed to ``0.``.
    """
    out_dir = Path(hydra.utils.get_original_cwd()) / cfg.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    seed = cfg.random_seed if cfg.random_seed is not None else 0
    key = jax.random.PRNGKey(seed)
    data = hydra.utils.instantiate(cfg.data.dataset)
    equation = data.equation
    output_shape = [data.nt] + [data.nx] * (data.dim - 1)
    model = hydra.utils.instantiate(
        cfg.model,
        dim_in=data.dim,
        dim_out_base=list(output_shape),
        dim_out_test=list(output_shape),
        domain=equation.domain,
    )

    for n_train in tqdm(list(cfg.n_train_list)):
        train_losses = []
        test_losses = []
        pi_losses = []

        for i in range(cfg.n_init):  # one run per dataset draw
            X_train, y_train, _ = data[i]
            # NOTE(review): the same `key` is passed on every iteration; if
            # resample_data draws randomness from it, each resample uses the
            # same stream - confirm whether a per-iteration
            # jax.random.split is intended.
            X_test, y_test = data.resample_data(X_train, y_train, key)
            X_train, y_train = shuffle(X_train, y_train, random_state=cfg.random_seed)
            X_test, X_val, y_test, y_val = train_test_split(
                X_test, y_test, test_size=0.2, random_state=cfg.random_seed
            )

            X_train = X_train[:n_train]
            y_train = y_train[:n_train]

            # Optuna objective: fit with the suggested strengths and return
            # the validation score. It is consumed immediately below, so it
            # sees this iteration's data.
            def objective(trial):
                lambda_L2 = trial.suggest_float('lambda_L2', cfg.lambda_L2_range[0], cfg.lambda_L2_range[1])
                if use_pi:
                    lambda_eq = trial.suggest_float('lambda_eq', cfg.lambda_eq_range[0], cfg.lambda_eq_range[1])
                else:
                    lambda_eq = 0.
                _fit_model(model, equation, cfg, X_train, y_train, X_val, y_val,
                           lambda_L2, lambda_eq, tuning=True)
                return model.score(X_val, y_val)

            study = optuna.create_study(direction='minimize')
            study.optimize(objective, n_trials=cfg.trial)

            best_params = study.best_trial.params
            best_lambda_L2 = best_params['lambda_L2']
            best_lambda_eq = best_params['lambda_eq'] if use_pi else 0.

            # Refit with the best strengths using the final optimiser settings.
            _fit_model(model, equation, cfg, X_train, y_train, X_val, y_val,
                       best_lambda_L2, best_lambda_eq, tuning=False)

            test_loss = model.score(X_test, y_test)
            train_loss = model.score(X_train, y_train)
            pi_loss = equation.pi_loss(model.basis_function, model.test_function, model.weights)

            train_losses.append(train_loss)
            test_losses.append(test_loss)
            pi_losses.append(pi_loss)

        dim_out = model.basis_function.dim
        entry = dict(
            n_train=n_train,
            dim_out=dim_out,
            test_loss_mean=np.mean(test_losses),
            test_loss_std=np.std(test_losses),
            train_loss_mean=np.mean(train_losses),
            train_loss_std=np.std(train_losses),
            pi_loss_mean=np.mean(pi_losses),
            pi_loss_std=np.std(pi_losses),
            use_pi=use_pi,
        )

        use_pi_str = 'pi' if use_pi else 'lr'
        pd.DataFrame([entry]).to_csv(out_dir / f'results_{n_train}_{dim_out}_{use_pi_str}.csv', index=False)

    # Drop the references before forcing a collection so the model and
    # dataset can be reclaimed right away.
    del model
    del data
    gc.collect()

def load_and_concatenate_csvs(directory):
    """Read every ``.csv`` file directly inside ``directory`` into one DataFrame.

    Files are read in sorted path order so the row order of the result does
    not depend on the order in which the filesystem lists them.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        A DataFrame with the rows of all CSV files and a fresh 0..n-1 index.

    Raises:
        ValueError: If ``directory`` contains no ``.csv`` files.
    """
    csv_paths = sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.csv')
    )
    if not csv_paths:
        # pd.concat would also raise ValueError here, but without saying
        # which directory was empty.
        raise ValueError(f'No .csv files found in {directory}')

    frames = [pd.read_csv(path) for path in csv_paths]
    return pd.concat(frames, ignore_index=True)

def plot_graphs(df, out_dir):
    """Plot mean test loss against ``dim_out`` and save it as ``results.png``.

    One error-bar curve is drawn per method ('PILR' for rows with
    ``use_pi == True``, 'RR' for rows with ``use_pi == False``) and per
    ``n_train`` value found among the ``use_pi == True`` rows. The y axis is
    logarithmic and the error bars use ``test_loss_std``.

    Args:
        df: DataFrame with columns ``use_pi``, ``n_train``, ``dim_out``,
            ``test_loss_mean`` and ``test_loss_std``.
        out_dir: Directory the image ``results.png`` is written to.
    """
    # Elementwise comparisons on the column, not identity checks.
    pi_rows = df[df['use_pi'] == True]  # noqa: E712
    rr_rows = df[df['use_pi'] == False]  # noqa: E712

    # Draw on a dedicated figure instead of pyplot's implicit global one so
    # repeated calls do not pile curves onto the same axes, and always close
    # it so the figure is not kept alive after saving.
    fig, ax = plt.subplots()
    try:
        # Loop over each n_train value
        for n_train_value in sorted(pi_rows['n_train'].unique().tolist()):
            curves = (
                (pi_rows, f'PILR (n={n_train_value})'),
                (rr_rows, f'RR (n={n_train_value})'),
            )
            for rows, label in curves:
                subset = rows[rows['n_train'] == n_train_value].sort_values('dim_out')
                ax.errorbar(
                    subset['dim_out'],
                    subset['test_loss_mean'],
                    yerr=subset['test_loss_std'],
                    label=label,
                )

        ax.set_xlabel('number of parameters')
        ax.set_ylabel('MSE on the test data')
        ax.set_yscale('log')
        ax.legend(bbox_to_anchor=(1.1, 1.25), ncol=3)
        fig.savefig(f'{out_dir}/results.png', dpi=300)
    finally:
        plt.close(fig)
    
    
# Hydra entry point (no wandb setup is performed in this file)
@hydra.main(version_base=None, config_path="../config", config_name="sweep_numel.yaml")
def main(cfg: DictConfig):
    """Run the stages selected in ``cfg``.

    When ``cfg.sweep`` is truthy the sweep is run with ``use_pi=cfg.use_pi``.
    When ``cfg.visualize`` is truthy, all CSV files under
    ``<original cwd>/<cfg.out_dir>`` are loaded and plotted into that
    same directory.
    """
    if cfg.sweep:
        sweep(cfg, use_pi=cfg.use_pi)

    if not cfg.visualize:
        return

    results_dir = Path(hydra.utils.get_original_cwd()) / cfg.out_dir
    results = load_and_concatenate_csvs(results_dir)
    plot_graphs(results, results_dir)

# Invoke the Hydra-decorated entry point only when this file is run as a script.
if __name__ == "__main__":
    main()
