import os
import sys
import gc

if not __package__:
    # Make CLI runnable from source tree with
    #    python src/package
    # When the file is executed without a package context, prepend the parent
    # of this file's directory to sys.path so the imports below can be
    # resolved from the source tree.
    package_source_path = os.path.dirname(os.path.dirname(__file__))
    sys.path.insert(0, package_source_path)

import wandb
import optuna
import hydra
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
from sklearn.model_selection import train_test_split
from omegaconf import DictConfig, OmegaConf
from pi_lr import Data
from jax import config
from tqdm import tqdm
from pathlib import Path
import pandas as pd
# Switch JAX to 64-bit mode via its "jax_enable_x64" flag.
config.update("jax_enable_x64", True)
# Config resolver "mul_pi": returns its argument multiplied by pi.
OmegaConf.register_new_resolver("mul_pi", lambda x: jnp.pi * x)

# Plot styling: the 'science' and 'no-latex' style sheets (presumably provided
# by the scienceplots import - confirm) with an 8 pt font.
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({"font.size": 8}) 

def sweep(cfg, use_pi=False):
    """Run the model-size / training-size sweep and write one CSV per setting.

    For every ``dim_out`` in ``cfg.dim_out_list`` a model is instantiated from
    ``cfg.model``.  For every ``n_train`` in ``cfg.n_train_list`` that model is
    fitted ``cfg.n_init`` times: each repetition tunes the regularisation
    weights with Optuna against a validation split, refits with the best
    weights and records the train, test and physics-informed losses.  The mean
    and standard deviation over the repetitions are written to
    ``results_{n_train}_{dim_out}_{pi|lr}.csv`` inside ``cfg.out_dir``, which
    is resolved against Hydra's original working directory.

    Args:
        cfg: Hydra/OmegaConf configuration.  The fields read here are
            ``out_dir``, ``data.dataset``, ``model``, ``dim_out_list``,
            ``n_train_list``, ``n_init``, ``random_seed`` and ``trial``.
        use_pi: When true, ``lambda_eq`` is tuned alongside ``lambda_L2``;
            otherwise ``lambda_eq`` is fixed to 0.

    Returns:
        None.  Results are written to disk.
    """
    results_dir = Path(hydra.utils.get_original_cwd()) / cfg.out_dir
    os.makedirs(results_dir, exist_ok=True)

    data = hydra.utils.instantiate(cfg.data.dataset)
    equation = data.equation
    method_tag = 'pi' if use_pi else 'lr'

    for dim_out in cfg.dim_out_list:
        model = hydra.utils.instantiate(
            cfg.model,
            dim_in=data.dim,
            dim_out_base=[cfg.model.dim_out_base[0], dim_out],
            domain=equation.domain,
        )
        # NOTE(review): presumably invalidates a matrix cached on the equation
        # for the previous model size - confirm against the equation class.
        equation.M = None

        for n_train in tqdm(list(cfg.n_train_list)):
            train_scores, test_scores, pi_scores = [], [], []

            # cfg.n_init repetitions; repetition `rep` uses data[rep].
            for rep in range(cfg.n_init):
                X, y, _ = data[rep]

                # 80/20 split off the test set, then 75/25 of the remainder,
                # i.e. 60% train / 20% validation / 20% test overall.
                X_rest, X_test, y_rest, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=cfg.random_seed
                )
                X_fit, X_val, y_fit, y_val = train_test_split(
                    X_rest, y_rest, test_size=0.25, random_state=cfg.random_seed
                )

                # Only the first n_train training samples are used.
                X_fit, y_fit = X_fit[:n_train], y_fit[:n_train]

                def objective(trial):
                    """Fit with the trial's weights and return the validation score."""
                    l2_weight = trial.suggest_float('lambda_L2', 1e-12, 1e-2)
                    eq_weight = (
                        trial.suggest_float('lambda_eq', 1e-12, 1e-2) if use_pi else 0.
                    )
                    model.fit_exact(
                        X_fit, y_fit,
                        equation=equation, lambda_L2=l2_weight, lambda_eq=eq_weight,
                    )
                    return model.score(X_val, y_val)

                study = optuna.create_study(direction='minimize')
                study.optimize(objective, n_trials=cfg.trial)

                # Refit on the training subset with the best weights found.
                best_params = study.best_trial.params
                best_l2 = best_params['lambda_L2']
                best_eq = best_params['lambda_eq'] if use_pi else 0.
                model.fit_exact(
                    X_fit, y_fit,
                    equation=equation, lambda_L2=best_l2, lambda_eq=best_eq,
                )

                test_score = model.score(X_test, y_test)
                train_score = model.score(X_fit, y_fit)
                pi_score = equation.pi_loss(
                    model.basis_function, model.test_function, model.weights
                )

                train_scores.append(train_score)
                test_scores.append(test_score)
                pi_scores.append(pi_score)

            # The reported size is the basis dimension of the fitted model,
            # not the loop value taken from cfg.dim_out_list.
            basis_dim = model.basis_function.dim
            summary = {
                'n_train': n_train,
                'dim_out': basis_dim,
                'test_loss_mean': np.mean(test_scores),
                'test_loss_std': np.std(test_scores),
                'train_loss_mean': np.mean(train_scores),
                'train_loss_std': np.std(train_scores),
                'pi_loss_mean': np.mean(pi_scores),
                'pi_loss_std': np.std(pi_scores),
                'use_pi': use_pi,
            }

            csv_path = results_dir / f'results_{n_train}_{basis_dim}_{method_tag}.csv'
            pd.DataFrame([summary]).to_csv(csv_path, index=False)

        # Drop the model before the next size is instantiated.
        del model
        gc.collect()

    del data
    gc.collect()
    

def load_and_concatenate_csvs(directory):
    """Load every ``.csv`` file directly inside *directory* into one DataFrame.

    Files are read in sorted file-name order so that the row order of the
    result is reproducible; ``os.listdir`` gives no ordering guarantee.
    Sub-directories are not searched.

    Args:
        directory: Path (``str`` or ``os.PathLike``) of the directory to scan.

    Returns:
        A ``pandas.DataFrame`` holding the rows of all CSV files stacked on
        top of each other, with a fresh ``0..n-1`` index.

    Raises:
        ValueError: If *directory* contains no ``.csv`` files.
    """
    csv_paths = sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.csv')
    )
    if not csv_paths:
        # pandas would raise ValueError("No objects to concatenate") here;
        # keep the exception type but say what is actually wrong.
        raise ValueError(f"No CSV files found in {directory}")
    frames = [pd.read_csv(path) for path in csv_paths]
    return pd.concat(frames, ignore_index=True)

def plot_graphs(df, out_dir):
    """Plot test loss against model size and save the figure as ``results.png``.

    For each training-set size found among the ``use_pi == True`` rows, two
    error-bar curves are drawn on the current matplotlib figure: one for the
    ``use_pi`` rows (labelled "PILR") and one for the remaining rows (labelled
    "RR").  Each curve shows ``test_loss_mean`` with ``test_loss_std`` error
    bars over ``dim_out``, on a logarithmic y-axis.

    Args:
        df: DataFrame with the columns ``use_pi``, ``n_train``, ``dim_out``,
            ``test_loss_mean`` and ``test_loss_std``.
        out_dir: Directory in which ``results.png`` is written.

    Returns:
        None.
    """
    with_pi = df[df['use_pi'] == True]
    without_pi = df[df['use_pi'] == False]

    def curve_for(frame, n):
        """Return the rows of *frame* for training size *n*, ordered by dim_out."""
        return frame[frame['n_train'] == n].sort_values('dim_out')

    # Loop over the n_train values (taken from the use_pi rows only).
    for n in sorted(with_pi['n_train'].unique().tolist()):
        pi_curve = curve_for(with_pi, n)
        rr_curve = curve_for(without_pi, n)
        plt.errorbar(
            pi_curve['dim_out'],
            pi_curve['test_loss_mean'],
            yerr=pi_curve['test_loss_std'],
            label=f'PILR (n={n})',
        )
        plt.errorbar(
            rr_curve['dim_out'],
            rr_curve['test_loss_mean'],
            yerr=rr_curve['test_loss_std'],
            label=f'RR (n={n})',
        )

    plt.xlabel('number of parameters')
    plt.ylabel('MSE on the test data')
    plt.yscale('log')
    plt.legend(bbox_to_anchor=(1.1, 1.25), ncol=3)
    plt.savefig(f'{out_dir}/results.png', dpi=300)
    

# Hydra entry point: runs the sweep and/or the visualization, as configured
@hydra.main(version_base=None, config_path="../config", config_name="sweep.yaml")
def main(cfg: DictConfig):
    """Run the stages enabled in the Hydra configuration.

    Args:
        cfg: Configuration loaded from ``../config/sweep.yaml``.  ``cfg.sweep``
            enables the sweep (with ``cfg.use_pi`` passed through), and
            ``cfg.visualize`` enables loading the result CSVs from
            ``cfg.out_dir`` and plotting them.
    """
    if cfg.sweep:
        sweep(cfg, use_pi=cfg.use_pi)

    if not cfg.visualize:
        return

    # Results live under the directory the program was launched from.
    results_dir = Path(hydra.utils.get_original_cwd()) / cfg.out_dir
    results = load_and_concatenate_csvs(results_dir)
    plot_graphs(results, results_dir)

# Run main() only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
