import os
import pickle
import sys
import time
import yaml
import itertools
import numpy as np
import subprocess

import scipy.sparse as sp
import scipy.sparse.linalg as splinalg
# # If you want to see the JAX compile log:
# os.environ["JAX_LOG_COMPILES"] = "1"

def select_free_gpu():
    """Expose the GPU with the least used memory through CUDA_VISIBLE_DEVICES.

    Runs ``nvidia-smi --query-gpu=index,memory.used`` and parses each
    ``index, used_memory`` row of its CSV output.  The GPU *index reported by
    nvidia-smi* (not the row position) is written to
    ``os.environ["CUDA_VISIBLE_DEVICES"]``.

    If nvidia-smi cannot be executed, exits with a non-zero status, or yields
    no parseable row, a message is printed and the environment is left
    untouched instead of raising.

    Returns:
        None.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            stdout=subprocess.PIPE,
            text=True,
        )
    except OSError as err:
        # nvidia-smi is missing or not executable (e.g. CPU-only machine).
        print(f"nvidia-smi could not be executed ({err}); CUDA_VISIBLE_DEVICES left unchanged")
        return

    # Map reported GPU index -> used memory, skipping rows that do not parse
    # (blank lines, "[N/A]" memory values, ...).
    gpu_mem = {}
    for line in result.stdout.splitlines():
        fields = [field.strip() for field in line.split(",")]
        if len(fields) != 2:
            continue
        try:
            gpu_mem[fields[0]] = int(fields[1])
        except ValueError:
            continue

    if result.returncode != 0 or not gpu_mem:
        print("nvidia-smi returned no usable GPU information; CUDA_VISIBLE_DEVICES left unchanged")
        return

    # min() returns the first minimum in insertion order, so ties resolve to
    # the first GPU listed.
    free_gpu = min(gpu_mem, key=gpu_mem.get)
    print(f"Selecting GPU {free_gpu} (used memory: {gpu_mem[free_gpu]} MiB)")

    os.environ["CUDA_VISIBLE_DEVICES"] = free_gpu

# Choose the GPU (via CUDA_VISIBLE_DEVICES) before JAX is imported below.
# NOTE(review): presumably the variable has to be set before JAX initialises
# its backend for the selection to take effect - confirm.
select_free_gpu()


import jax
import jax.numpy as jnp

# Switch on JAX's "jax_enable_x64" flag (64-bit precision).
jax.config.update("jax_enable_x64", True)
import optax
import pennylane as qml

import logging
# Root logger at INFO level; each record is prefixed with a 12-hour
# HH:MM:SS timestamp and the name of the emitting source file.
logging.basicConfig( format = "%(asctime)s - %(filename)s  : %(message)s "
                , level = logging.INFO
                , datefmt = "%I:%M:%S"
                )

    
def setting_path(args):
    """Derive result, checkpoint and database paths from the experiment settings.

    The following attributes are written onto ``args``: ``PDE_INFO``,
    ``TRAIN_INFO``, ``RESULT_PATH``, ``CKPT_PATH``, ``SOLUTION_PATH``,
    ``DB_NAME``, ``DB_FOLDER`` and ``DB_ROOT``.

    Args:
        args: Attribute namespace.  Reads ``equation``, ``pde_parameter``,
            ``basistype``, ``boundary_condition``, ``exp_name``, ``cur_time``,
            ``n_qubits``, ``ansatz_name``, ``n_layers``, ``precondition``,
            ``optimizer_type`` and ``DB_forcingtype``; additionally ``N``,
            ``train_size`` and the ``DB_sin_*`` / ``DB_cos_*`` statistics for
            ``'uniform'`` forcing, or ``train_forcingtype`` for ``'ones'``.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        NotImplementedError: If ``args.DB_forcingtype`` is neither
            ``'uniform'`` nor ``'ones'``.
    """
    args.PDE_INFO = f"{args.equation}_{args.pde_parameter}_{args.basistype}_{args.boundary_condition}"
    # str.join (rather than an f-string) keeps the requirement that
    # exp_name, ansatz_name and precondition are already strings.
    args.TRAIN_INFO = '_'.join([
        args.exp_name,
        f'{args.cur_time}',
        f'n_{args.n_qubits}',
        args.ansatz_name,
        f'n_layers_{args.n_layers}',
        args.precondition,
    ])
    args.RESULT_PATH = os.path.join(
        f'training_general_{args.optimizer_type}',
        args.PDE_INFO,
        args.TRAIN_INFO,
    )
    args.CKPT_PATH = os.path.join(args.RESULT_PATH, 'loss_and_ckpt')
    args.SOLUTION_PATH = os.path.join(args.RESULT_PATH, 'solution')

    if args.DB_forcingtype == 'uniform':
        # The database name encodes problem size, train size and the
        # sin/cos mean and sd values (two decimals each).
        args.DB_NAME = (
            f'{args.PDE_INFO}'
            f'_n_qubits_{args.n_qubits}_N_{args.N}'
            f'_train_{args.train_size}'
            f'_{args.DB_forcingtype}'
            f'_{args.DB_sin_mean:.2f}'
            f'_{args.DB_sin_sd:.2f}'
            f'_{args.DB_cos_mean:.2f}'
            f'_{args.DB_cos_sd:.2f}'
        )
    elif args.DB_forcingtype == 'ones':
        # NOTE(review): this branch names the database after
        # args.train_forcingtype rather than args.DB_forcingtype - confirm
        # that is intended.
        args.DB_NAME = args.PDE_INFO + '_' + args.train_forcingtype
    else:
        raise NotImplementedError(
            f"DB_forcingtype {args.DB_forcingtype!r} is not supported; "
            "expected 'uniform' or 'ones'."
        )

    # The database file lives inside the run's own result directory.
    args.DB_FOLDER = os.path.join(args.RESULT_PATH, 'data')
    args.DB_ROOT = os.path.join(args.DB_FOLDER, args.DB_NAME + '.pickle')
    return args

if __name__ == "__main__":
    # ---- Load settings and derive all output paths ----------------------
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    from src.utils import setting
    args, config = setting(BASE_DIR)

    args = setting_path(args)

    # This script only drives full-batch L-BFGS.  Explicit raises are used
    # instead of `assert` so the checks are not stripped under `python -O`.
    if args.optimizer_type != 'LBFGS':
        raise ValueError(f'Optimizer is {args.optimizer_type}, not LBFGS.')
    if args.batch_train != args.train_size:
        raise ValueError(f'batch_train: {args.batch_train} is not equal to train_size: {args.train_size}')
    if args.batch_test != args.test_size:
        raise ValueError(f'batch_test: {args.batch_test} is not equal to test_size: {args.test_size}')

    logging.info('Experiment Results are save in ' + args.RESULT_PATH)
    os.makedirs(args.RESULT_PATH, exist_ok=True)
    os.makedirs(args.SOLUTION_PATH, exist_ok=True)
    os.makedirs(args.CKPT_PATH, exist_ok=True)

    # No learning rate is used here; clear it in both args and the config
    # that is saved below.
    args.lr = None
    config['lr'] = None

    # Save stdout & stderr.
    # The redirection lasts for the rest of the process, so the files are
    # intentionally not closed here.  Line buffering (buffering=1) makes
    # progress visible while the run is in flight and keeps already printed
    # lines if the process is killed.
    sys.stdout = open(os.path.join(args.RESULT_PATH, 'stdout.txt'), 'w', buffering=1)
    sys.stderr = open(os.path.join(args.RESULT_PATH, 'stderr.txt'), 'w', buffering=1)
    from src.utils import save_config
    save_config(config, args.RESULT_PATH)



    # ---- Shorthand for frequently used settings --------------------------
    batch_train = args.batch_train
    batch_test = args.batch_test

    n_qubits = args.n_qubits
    n_layers = args.n_layers
    output_dim = args.output_dim
    NN = args.N
    nn_input_dim = args.nn_input_dim
    sol_dim = args.sol_dim

    num_train_batches = args.train_size // batch_train
    num_test_batches = args.test_size // batch_test


    # Log/plot roughly 20 times per run.  Clamp to at least 1: with
    # EPOCH < 20 the integer division yields 0 and every later
    # `% logging_epoch` would raise ZeroDivisionError.
    logging_epoch = max(1, args.EPOCH // 20)

    key = jax.random.PRNGKey(args.jax_seed)

    args.key = key

    # Plot helper: the imported class name is rebound to an instance of it.
    from src.utils import PlotSystem
    PlotSystem = PlotSystem()

    # Basis system built from args; keep handles to the attributes used below.
    from src.orthogonal import LegendreSystemJax
    OrthogonalSystem = LegendreSystemJax(args)
    xx, S_1D = OrthogonalSystem.xx, OrthogonalSystem.S_1D
    R_1D, M_1D = OrthogonalSystem.R_1D, OrthogonalSystem.M_1D

    # A = S + k^2 * M
    B_matrix = S_1D
    C_matrix = M_1D

    from src.database import DatabaseJax
    Database = DatabaseJax(args, OrthogonalSystem, B_matrix, C_matrix)
    DB, DB_INFO, TRAIN_DB, TEST_DB = Database.database_load(args.DB_ROOT)

    # From here on the matrices stored with the database are used in place
    # of the ones taken from the basis system above.
    x_1d, B_matrix, C_matrix = (DB_INFO[field] for field in ('x_1d', 'B_matrix', 'C_matrix'))

    # Both splits expose the same set of fields.
    split_fields = ('A_matrix', 'pde_param', 'forcing', 'RHS', 'coeff', 'u_true')
    (train_A_matrix, train_pde_param, train_forcing,
     train_RHS, train_coeff, train_u_true) = (TRAIN_DB[field] for field in split_fields)
    (test_A_matrix, test_pde_param, test_forcing,
     test_RHS, test_coeff, test_u_true) = (TEST_DB[field] for field in split_fields)

    # Debug dump: A_matrix and pde_param are printed as values, the
    # remaining fields only by shape.
    for label, shown in (
        ('train_A_matrix', train_A_matrix),
        ('train_pde_param', train_pde_param),
        ('train_forcing', train_forcing.shape),
        ('train_RHS', train_RHS.shape),
        ('train_coeff', train_coeff.shape),
        ('train_u_true', train_u_true.shape),
        ('test_A_matrix', test_A_matrix),
        ('test_pde_param', test_pde_param),
        ('test_forcing', test_forcing.shape),
        ('test_RHS', test_RHS.shape),
        ('test_coeff', test_coeff.shape),
        ('test_u_true', test_u_true.shape),
    ):
        print(f'{label} = {shown}')

    # Network, optimizer and their initial state.
    from src.network import JAXMLP
    angle_net = JAXMLP(output_dim=output_dim, N=NN)

    optimizer = optax.lbfgs()  # L-BFGS

    from src.network import init_params_opt
    nn_params = init_params_opt(args, angle_net)
    opt_state = optimizer.init(nn_params)

    from src.vqls_double import VQLSJax
    VQLS = VQLSJax(args, angle_net, B_matrix, C_matrix)

    # jit functions
    loss_fn, predict_jit = VQLS.make_cost_loc(), VQLS.make_predict()

    # Reconstruction and error metrics used by the evaluation steps.
    reconstruct_1D = OrthogonalSystem.reconstruct_1D
    Relative_L2_1D = OrthogonalSystem.Relative_L2_1D
    Relative_Linf_1D = OrthogonalSystem.Relative_Linf_1D

    from functools import partial
    @ jax.jit
    def train_step(nn_params, opt_state, batch_pde_param, batch_forcing, batch_RHS):
        """Run one optimizer update on a batch.

        Returns ``(loss, updated_params, updated_opt_state)``; ``loss`` is the
        value of ``loss_fn`` at the incoming ``nn_params``, i.e. before the
        update is applied.
        """

        def batch_loss(params):
            # Loss as a function of the parameters only, batch held fixed.
            return loss_fn(params
                           , batch_pde_param=batch_pde_param
                           , batch_forcing=batch_forcing
                           , batch_RHS=batch_RHS)

        loss, grad = jax.value_and_grad(batch_loss)(nn_params)
        # The optimizer receives the current value/gradient and the loss
        # function itself in addition to the usual (grad, state, params).
        updates, new_opt_state = optimizer.update(
            grad, opt_state, nn_params,
            value=loss, grad=grad, value_fn=batch_loss,
        )  # LBFGS
        new_params = optax.apply_updates(nn_params, updates)
        return loss, new_params, new_opt_state

    @ jax.jit
    def eval_step(nn_params, batch_pde_param, batch_A_matrix, batch_forcing, batch_RHS, batch_sol_exact):
        """Predict on a batch and score the reconstruction against ``batch_sol_exact``.

        Returns ``(sol_pred, rel_l2, rel_linf)``.
        """
        predicted = predict_jit(nn_params, batch_pde_param, batch_A_matrix, batch_forcing, batch_RHS)
        # The prediction is passed to reconstruct_1D directly.
        reconstructed = reconstruct_1D(predicted)
        l2_error = Relative_L2_1D(reconstructed, batch_sol_exact)
        linf_error = Relative_Linf_1D(reconstructed, batch_sol_exact)
        return reconstructed, l2_error, linf_error

    @ jax.jit
    def train_eval_step(nn_params, opt_state, batch_pde_param, batch_A_matrix, batch_forcing, batch_RHS, batch_sol_exact):
        """Apply one training update, then evaluate the updated parameters on the same batch.

        Returns ``(nn_params, opt_state, loss, sol_pred, rel_l2, rel_linf)``.
        """
        step_loss, stepped_params, stepped_state = train_step(
            nn_params, opt_state, batch_pde_param, batch_forcing, batch_RHS)
        # Metrics are computed with the parameters returned by train_step.
        metrics = eval_step(
            stepped_params, batch_pde_param, batch_A_matrix, batch_forcing, batch_RHS, batch_sol_exact)
        return (stepped_params, stepped_state, step_loss, *metrics)

    @ jax.jit
    def test_eval_step(nn_params, batch_pde_param, batch_A_matrix, batch_forcing, batch_RHS, batch_sol_exact):
        """Compute the loss and the error metrics on a batch without updating parameters.

        Returns ``(test_loss, sol_pred, rel_l2, rel_linf)``.
        """
        batch_loss = loss_fn(nn_params, batch_pde_param, batch_forcing, batch_RHS)
        metrics = eval_step(
            nn_params, batch_pde_param, batch_A_matrix, batch_forcing, batch_RHS, batch_sol_exact)
        return (batch_loss, *metrics)


    
    # ---- Training state and per-epoch histories --------------------------
    best_cost = None
    best_ckpt = None
    # Index, within the first batch, of the sample whose solution is plotted.
    train_ind = 0
    test_ind = 0
    TRAIN_COST_HIST = []
    TRAIN_L2_HIST = []
    TRAIN_Linf_HIST = []

    TEST_COST_HIST = []
    TEST_L2_HIST = []
    TEST_Linf_HIST = []

    total_start = time.time()
    for ep in range(args.EPOCH):
        batch_train_loss_hist = []
        batch_test_loss_hist = []
        batch_train_l2_hist = []
        batch_train_linf_hist = []
        batch_test_l2_hist = []
        batch_test_linf_hist = []

        start = time.time()

        ##### Train #####
        for batch_ind in range(num_train_batches):
            train_rows = slice(batch_ind * batch_train, (batch_ind + 1) * batch_train)
            train_batch_pde_param = train_pde_param[train_rows]
            train_batch_A_matrix = train_A_matrix[train_rows]
            train_batch_forcing = train_forcing[train_rows]
            train_batch_RHS = train_RHS[train_rows]
            train_batch_sol_exact = train_u_true[train_rows]
            # Training & Eval
            (nn_params, opt_state
             , train_loss, train_batch_sol_pred
             , train_batch_rel_l2, train_batch_rel_linf
             ) = train_eval_step(nn_params, opt_state,
                                  train_batch_pde_param, train_batch_A_matrix, train_batch_forcing, train_batch_RHS, train_batch_sol_exact)

            batch_train_loss_hist.append(jnp.mean(train_loss).item())
            batch_train_l2_hist.append(jnp.mean(train_batch_rel_l2).item())
            batch_train_linf_hist.append(jnp.mean(train_batch_rel_linf).item())


            # Plot one training sample from the first batch every
            # `logging_epoch` epochs.
            if (ep+1)%logging_epoch ==0 and batch_ind == 0:
                PlotSystem.Solution_1D(xx
                , pred_u_dict={'Predicted solution': train_batch_sol_pred[train_ind]}
                , true_u= train_batch_sol_exact[train_ind]
                , NORM={'RelL2':train_batch_rel_l2[train_ind].item()
                        , 'RelLinf':train_batch_rel_linf[train_ind].item()
                        }
                , SAVE_PATH= os.path.join(args.SOLUTION_PATH
                                          , f'train_ind_{train_ind}_solution_epoch_{ep+1}.png')
                , SHOW=False, COLOR='b')


        ##### Test #####
        for batch_ind in range(num_test_batches):
            # Every test array is sliced with the *test* batch size.  Slicing
            # the inputs with batch_train (as before) left them out of step
            # with the batch_test-sized exact solutions whenever the two
            # sizes differ.
            test_rows = slice(batch_ind * batch_test, (batch_ind + 1) * batch_test)
            test_batch_pde_param = test_pde_param[test_rows]
            test_batch_A_matrix = test_A_matrix[test_rows]
            test_batch_forcing = test_forcing[test_rows]
            test_batch_RHS = test_RHS[test_rows]
            test_batch_sol_exact = test_u_true[test_rows]

            (test_loss, test_batch_sol_pred, test_batch_rel_l2, test_batch_rel_linf
             ) = test_eval_step(nn_params
                                ,test_batch_pde_param , test_batch_A_matrix, test_batch_forcing, test_batch_RHS, test_batch_sol_exact)

            batch_test_loss_hist.append(jnp.mean(test_loss).item())
            batch_test_l2_hist.append(jnp.mean(test_batch_rel_l2).item())
            batch_test_linf_hist.append(jnp.mean(test_batch_rel_linf).item())

            if (ep+1)%logging_epoch ==0 and batch_ind == 0:
                # The file name carries test_ind, the index actually plotted
                # (it previously used train_ind).
                PlotSystem.Solution_1D(xx
                , pred_u_dict={'Predicted solution': test_batch_sol_pred[test_ind] }
                , true_u= test_batch_sol_exact[test_ind]
                , NORM={'RelL2':test_batch_rel_l2[test_ind].item()
                        , 'RelLinf':test_batch_rel_linf[test_ind].item()
                        }
                , SAVE_PATH= os.path.join(args.SOLUTION_PATH, f'test_ind_{test_ind}_solution_epoch_{ep+1}.png')
                , SHOW=False
                , COLOR='g'
                )

        elapsed_time = time.time() - start
        total_elapsed_time = time.time() - total_start

        # Epoch-level means over the per-batch values collected above.
        epoch_train_loss_mean = jnp.mean(jnp.array(batch_train_loss_hist)).item()
        epoch_train_rel_l2_mean = jnp.mean(jnp.array(batch_train_l2_hist)).item()
        epoch_train_rel_linf_mean = jnp.mean(jnp.array(batch_train_linf_hist)).item()

        epoch_test_loss_mean = jnp.mean(jnp.array(batch_test_loss_hist)).item()
        epoch_test_rel_l2_mean = jnp.mean(jnp.array(batch_test_l2_hist)).item()
        epoch_test_rel_linf_mean = jnp.mean(jnp.array(batch_test_linf_hist)).item()

        TRAIN_COST_HIST.append(epoch_train_loss_mean)
        TEST_COST_HIST.append(epoch_test_loss_mean)
        TRAIN_L2_HIST.append(epoch_train_rel_l2_mean)
        TRAIN_Linf_HIST.append(epoch_train_rel_linf_mean)
        TEST_L2_HIST.append(epoch_test_rel_l2_mean)
        TEST_Linf_HIST.append(epoch_test_rel_linf_mean)

        # Track the checkpoint with the lowest mean training loss so far.
        # NOTE(review): the training loss is evaluated before the update is
        # applied, whereas nn_params here are the post-update parameters, so
        # the stored parameters are one step ahead of the recorded cost -
        # confirm this is acceptable.
        if (best_cost is None) or (epoch_train_loss_mean < best_cost):
            best_cost = epoch_train_loss_mean
            best_ckpt = {
                        'params': nn_params,
                        'opt_state': opt_state,
                        'epoch': ep
                        }
        # Progress line and history plots on logging epochs and on the final epoch.
        if ep%logging_epoch ==0 or ep == args.EPOCH-1:
            print(f"[Jax, GPU] Epoch {ep+1:02d}, time = {elapsed_time:.6f} sec, elapsed time: {total_elapsed_time:.6f} loss = {epoch_train_loss_mean}")

            PlotSystem.Losses({'Cost (Train)': [TRAIN_COST_HIST, 'b']
                               , 'Cost (Test)': [TEST_COST_HIST, 'g']
                               }
                            , SAVEPATH=os.path.join(args.RESULT_PATH, 'cost_mean_history.png')
                            , log_scale=True, SHOW=False)
            PlotSystem.Losses({r'Rel $L_2$ Error (Train)':[TRAIN_L2_HIST, 'b--']
                            ,r'Rel $L_\text{inf}$ Error (Train)':[TRAIN_Linf_HIST, 'b']
                            ,r'Rel $L_2$ Error (Test)':[TEST_L2_HIST, 'g--']
                            ,r'Rel $L_\text{inf}$ Error (Test)':[TEST_Linf_HIST, 'g']
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'total_error_mean.png')
                            , log_scale=True
                            , SHOW=False
                            )
            PlotSystem.Losses({r'Relative $L_2$ Error':[TRAIN_L2_HIST, 'b--']
                            ,r'Relative $L_\text{inf}$ Error':[TRAIN_Linf_HIST, 'b']
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'train_error_mean.png')
                            , log_scale=True
                            , SHOW=False)
            PlotSystem.Losses({r'Relative $L_2$ Error':[TEST_L2_HIST, 'g--']
                            ,r'Relative $L_\text{inf}$ Error':[TEST_Linf_HIST, 'g']
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'test_error_mean.png')
                            , log_scale=True
                            , SHOW=False
                            )
        
    total_elapsed_time = time.time() - total_start
    print(f"[Jax, GPU] total time = {total_elapsed_time:.6f} sec")

    # Save Best CKPT
    # Text summary: best training cost plus the minimum of every history.
    with open(os.path.join(args.CKPT_PATH, 'best_cost.txt'), 'w') as f:
        print(f"Best Cost       : {best_cost}", file=f)
        print(f"Best Cost (Test): {min(TEST_COST_HIST)}", file=f)
        print(f"Best Train L2   : {min(TRAIN_L2_HIST)}", file=f)
        print(f"Best Train Linf : {min(TRAIN_Linf_HIST)}", file=f)
        print(f"Best Test L2    : {min(TEST_L2_HIST)}", file=f)
        print(f"Best Test Linf  : {min(TEST_Linf_HIST)}", file=f)

    # State at the end of the last epoch.
    checkpoint = {
                'params': nn_params,
                'opt_state': opt_state,
                'epoch': ep
                }

    # Checkpoints and histories, one pickle file each under CKPT_PATH.
    ckpt_dumps = (
        (f'nn_ckpt_epoch_{ep+1}.pkl', checkpoint),
        ('nn_best_ckpt.pkl', best_ckpt),
        ('TRAIN_COST_HIST.pkl', TRAIN_COST_HIST),
        ('TRAIN_L2_HIST.pkl', TRAIN_L2_HIST),
        ('TRAIN_Linf_HIST.pkl', TRAIN_Linf_HIST),
        ('TEST_COST_HIST.pkl', TEST_COST_HIST),
        ('TEST_L2_HIST.pkl', TEST_L2_HIST),
        ('TEST_Linf_HIST.pkl', TEST_Linf_HIST),
    )
    for file_name, payload in ckpt_dumps:
        with open(os.path.join(args.CKPT_PATH, file_name), 'wb') as f:
            pickle.dump(payload, f)

    # Solution: Exact & Prediction
    # The batch variables still hold the values from the final loop
    # iteration, so these files describe the last train and test batches.
    solution_dumps = (
        ('x_1d.pkl', x_1d),
        ('train_batch_pde_param.pkl', train_batch_pde_param),
        ('train_batch_RHS.pkl', train_batch_RHS),
        ('train_batch_sol_exact.pkl', train_batch_sol_exact),
        ('train_batch_sol_pred.pkl', train_batch_sol_pred),
        ('test_batch_pde_param.pkl', test_batch_pde_param),
        ('test_batch_RHS.pkl', test_batch_RHS),
        ('test_batch_sol_exact.pkl', test_batch_sol_exact),
        ('test_batch_sol_pred.pkl', test_batch_sol_pred),
    )
    for file_name, payload in solution_dumps:
        with open(os.path.join(args.SOLUTION_PATH, file_name), 'wb') as f:
            pickle.dump(payload, f)