
import pprint
import numpy as np
from joblib import Parallel, delayed
import time 
import random

from SONv4 import Fl
from MCTS import  gen_deci_dicts, single_test, MCTS, SearchNode
from utils import timing, timing_with_return, redirect_log_file, time_to_str

def test_mse_1():
    """Smoke-test one hand-picked configuration with ``Fl.test``.

    The configuration has is_newtonian=0 (type_non_newtonian=2, poly/Fourier
    order 1), is_isothermal=1, is_compressible=1 and type_body_force=1.
    ``Fl.test`` is called without an explicit dataset argument. The
    configuration is first checked with ``Fl.is_valid_deci``. The returned
    train_loss, deci_dict, params and losses are printed.
    """
    decisions = dict(
        # --- turbulent flow ---
        is_turbulent=0,
        type_non_turbulent=0,
        # --- constitutive equation ---
        is_newtonian=0,
        type_non_newtonian=2,
        poly_order=1,
        Fourier_order=1,
        # --- non_isothermal flow ---
        is_isothermal=1,
        # --- compressible ---
        is_compressible=1,
        is_stokes_hypothesis=1,
        # --- other independent decisions ---
        type_body_force=1,
    )
    assert Fl.is_valid_deci(decisions)
    sol = Fl.test(decisions)
    # The f-string "=" specifier prints the expression text, so `sol` keeps its name.
    print(f"{sol['train_loss']=}, {sol['deci_dict']=}, {sol['params']=}", flush=True)
    print(f"{sol['losses']=}")

def test_non_isothermal():
    """Smoke-test a non-isothermal configuration on the '2d_non_isothermal_ns' data.

    The configuration has is_newtonian=1 (type_newtonian=1), is_isothermal=0
    with thermal conduction, pressure work and viscous diffusion enabled,
    is_compressible=1 and type_body_force=1. The returned train_loss,
    deci_dict, params and losses are printed.
    """
    dataset = ('2d_non_isothermal_ns', "COMSOL")
    decisions = dict(
        # --- turbulent flow ---
        is_turbulent=0,
        type_non_turbulent=0,
        # --- constitutive equation ---
        is_newtonian=1,
        type_newtonian=1,
        # --- non_isothermal flow ---
        is_isothermal=0,
        is_thermal_conductive=1,
        is_pressure_work=1,
        is_viscosity_diffusion=1,
        # --- compressible ---
        is_compressible=1,
        is_stokes_hypothesis=1,
        # --- other independent decisions ---
        type_body_force=1,
    )
    sol = Fl.test(decisions, dataset)
    # The f-string "=" specifier prints the expression text, so `sol` keeps its name.
    print(f"{sol['train_loss']=}, {sol['deci_dict']=}, {sol['params']=}", flush=True)
    print(f"{sol['losses']=}")

def test_mse_comp():
    """Smoke-test a compressible Newtonian configuration on '2d_comp_viscose_newton_ns'.

    The configuration has is_newtonian=1 (type_newtonian=1), is_isothermal=1,
    is_compressible=1 with is_stokes_hypothesis=1, and type_body_force=0.
    The returned train_loss, deci_dict, params and losses are printed.
    """
    dataset = ('2d_comp_viscose_newton_ns', "COMSOL")
    decisions = dict(
        # --- turbulent flow ---
        is_turbulent=0,
        type_non_turbulent=0,
        # --- constitutive equation ---
        is_newtonian=1,
        type_newtonian=1,
        # --- non_isothermal flow ---
        is_isothermal=1,
        # --- compressible ---
        is_compressible=1,
        is_stokes_hypothesis=1,
        # --- other independent decisions ---
        type_body_force=0,
    )
    sol = Fl.test(decisions, dataset)
    # The f-string "=" specifier prints the expression text, so `sol` keeps its name.
    print(f"{sol['train_loss']=}, {sol['deci_dict']=}, {sol['params']=}", flush=True)
    print(f"{sol['losses']=}")

def test_mse_incomp():
    """Smoke-test an is_compressible=0 configuration on '2d_comp_viscose_newton_ns'.

    The configuration has is_newtonian=1 (type_newtonian=1), is_isothermal=1,
    is_compressible=0 and type_body_force=0. The returned train_loss,
    deci_dict, params and losses are printed.
    """
    # NOTE(review): this uses the same dataset as test_mse_comp; presumably
    # deliberate (fitting an incompressible model to compressible data) — confirm.
    dataset = ('2d_comp_viscose_newton_ns', "COMSOL")
    decisions = dict(
        # --- turbulent flow ---
        is_turbulent=0,
        type_non_turbulent=0,
        # --- constitutive equation ---
        is_newtonian=1,
        type_newtonian=1,
        # --- non_isothermal flow ---
        is_isothermal=1,
        # --- compressible ---
        is_compressible=0,
        # --- other independent decisions ---
        type_body_force=0,
    )
    sol = Fl.test(decisions, dataset)
    # The f-string "=" specifier prints the expression text, so `sol` keeps its name.
    print(f"{sol['train_loss']=}, {sol['deci_dict']=}, {sol['params']=}", flush=True)
    print(f"{sol['losses']=}")

def _split_evenly(items, n_parts):
    """Split ``items`` into at most ``n_parts`` consecutive, non-empty slices.

    Slice lengths differ by at most one, and concatenating the slices gives
    back ``items`` in order. An empty ``items`` yields an empty list.
    """
    n_items = len(items)
    n_parts = min(max(1, n_parts), n_items)
    if n_parts == 0:
        return []
    base, extra = divmod(n_items, n_parts)
    parts = []
    start = 0
    for k in range(n_parts):
        # The first `extra` parts take one extra item so sizes stay balanced.
        stop = start + base + (1 if k < extra else 0)
        parts.append(items[start:stop])
        start = stop
    return parts


def para_test(deci_dicts, dataname_tuple, datafold_tuple, out_file=None, n_jobs=1, time_limit=None, n_parts=10):
    """Evaluate decision dicts with ``single_test`` in parallel, part by part.

    The candidates are split into at most ``n_parts`` consecutive,
    near-equal parts. Each part runs through ``joblib.Parallel``. After each
    part the elapsed wall-clock time is compared with ``time_limit``, so a
    long search can stop early and still return what it has computed.

    Args:
        deci_dicts: sliceable sequence of decision dicts to evaluate.
        dataname_tuple: forwarded unchanged to ``single_test``.
        datafold_tuple: forwarded unchanged to ``single_test``.
        out_file: forwarded unchanged to ``single_test``.
        n_jobs: number of joblib workers.
        time_limit: budget in seconds; a falsy value (None/0) means no limit.
            It is only checked between parts, so it can be overshot by up to
            one part's runtime.
        n_parts: maximum number of parts, i.e. time-limit checkpoints.

    Returns:
        List of ``single_test`` results in the same order as ``deci_dicts``.
        The list is truncated if the time limit was reached.
    """
    if isinstance(n_jobs, int) and n_jobs > 0:
        # Parts smaller than n_jobs would leave workers idle, so cap the part
        # count at ceil(len / n_jobs).
        n_parts = min(n_parts, (len(deci_dicts) + n_jobs - 1) // n_jobs)

    sols = []
    start_time = time.time()
    for i, part_i_deci_dicts in enumerate(_split_evenly(deci_dicts, n_parts)):
        sols += Parallel(n_jobs=n_jobs)(
            delayed(single_test)(deci_dict, dataname_tuple, datafold_tuple, out_file)
            for deci_dict in part_i_deci_dicts
        )

        tot_time_elapsed = time.time() - start_time
        print(f"Part {i} finished, {len(part_i_deci_dicts)=}, {tot_time_elapsed=}", flush=True)
        if time_limit and tot_time_elapsed > time_limit:
            print("timeout.", flush=True)
            break
    return sols

@timing_with_return
def brute_force_search(dataname_tuple, datafold_tuple=(0,1), out_file=None, n_jobs=10, verbose=False, time_limit=None, top_k=3):
    """Evaluate every candidate decision dict for a dataset and report the best.

    Candidates come from ``gen_deci_dicts(dataset_name=dataname_tuple[0])``.
    They are shuffled in place and evaluated with ``para_test``. The shuffle
    presumably makes a time-limited run cover a random subset rather than
    always the same prefix.

    Args:
        dataname_tuple: dataset identifier; element 0 selects the candidates.
        datafold_tuple: forwarded to ``para_test`` / ``single_test``.
        out_file: forwarded to ``para_test`` / ``single_test``.
        n_jobs: number of parallel workers.
        verbose: if true, print every evaluated solution.
        time_limit: wall-clock budget in seconds for ``para_test``.
        top_k: number of best solutions to return.

    Returns:
        The ``top_k`` solution dicts with the lowest ``'valid_loss'``, best
        first. Callers in this file unpack the ``timing_with_return``-wrapped
        call as ``(top_sols, search_time)``.

    Raises:
        RuntimeError: if no candidate was evaluated (for example,
            ``gen_deci_dicts`` returned nothing for this dataset).
    """
    deci_dicts = gen_deci_dicts(dataset_name=dataname_tuple[0])
    random.shuffle(deci_dicts)
    print(f"{len(deci_dicts)=},{n_jobs=}")

    sols = para_test(deci_dicts, dataname_tuple, datafold_tuple, out_file, n_jobs, time_limit=time_limit)
    if not sols:
        # Without this guard the caller would see a bare IndexError.
        raise RuntimeError(f"No candidate decision dicts were evaluated for dataset {dataname_tuple[0]!r}.")

    ranked = sorted(sols, key=lambda sol: sol['valid_loss'])
    opt_sol = ranked[0]
    # Sum of the per-candidate 'time' entries. With several workers this is
    # presumably aggregate compute time, not wall-clock time — confirm.
    total_time = time_to_str(sum(sol['time'] for sol in sols))
    print("------result------")
    print("Search terminated successfully. Time used: ", total_time)
    print("Final result:")
    pprint.pprint(opt_sol)

    if verbose:
        print("------solution details------")
        for sol in sols:
            print(sol)
    return ranked[:max(top_k, 0)]

def k_fold_cv_bfs(dataname_tuple, k=5, out_file=None, n_jobs=10, verbose=False, time_limit=None):
    """Run ``brute_force_search`` once per fold and print an ``MCTS`` summary.

    Fold ``fold`` of ``k`` is passed to the search as the datafold tuple
    ``(fold, k)``. Each fold's best solutions and search time are collected
    and then handed to ``MCTS().summary``, whose results are pretty-printed.
    """
    print(dataname_tuple)
    fold_times = []
    fold_best_nodes = []
    for fold in range(k):
        print(f"=== {fold+1}-th fold ===", flush=True)
        best_sols, elapsed = brute_force_search(
            dataname_tuple, (fold, k), out_file, n_jobs, verbose, time_limit=time_limit
        )
        for best in best_sols:
            pprint.pprint(best)
        fold_times.append(elapsed)
        # MCTS.summary takes, per fold, a mapping from SearchNode to its solution dict.
        fold_best_nodes.append({SearchNode(best["deci_dict"]): best for best in best_sols})

    summary_rows, time_stats = MCTS().summary(fold_best_nodes, fold_times)
    for row in summary_rows:
        pprint.pprint(row)
    print("Search time mean={}, std={}".format(*time_stats))

if __name__ == "__main__":
    # Selectable datasets: (dataset name, "COMSOL") pairs.
    datasets = {
        1: ("2d_comp_viscose_newton_ns", "COMSOL"),
        2: ('2d_comp_viscose_new_non_newton', 'COMSOL'),
        3: ('2d_heat_comp_v2', 'COMSOL'),  # max_rollout=60, epsilon=0.05
    }
    dataname_tuple = datasets[3]

    log_root = "./log/FNOv4/"
    out_file = redirect_log_file(log_root, "BruteForce_" + dataname_tuple[0])

    k_fold_cv_bfs(dataname_tuple=dataname_tuple, out_file=out_file)