import os
import pickle
import sys
import time
import yaml
import itertools
import numpy as np


import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
import optax
import pennylane as qml

import logging
# Configure the root logger once at import time: INFO level, 12-hour clock
# timestamps (%I), and the originating file name in every record.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%I:%M:%S",
    format="%(asctime)s - %(filename)s  : %(message)s ",
)

    

if __name__ == "__main__":
    # Script entry point: build the Legendre system, preconditioner, angle
    # network and VQLS cost, train with L-BFGS for args.EPOCH epochs, evaluate
    # on train/test data every epoch, and finally dump checkpoints, metric
    # histories and the last batch of predicted/exact solutions.

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    from src.utils import setting
    args, config = setting(BASE_DIR)

    # The training step below is written for optax.lbfgs and a single
    # full-size training batch (presumably so the L-BFGS line search always
    # sees the same objective).
    assert args.optimizer_type == 'LBFGS', f'Optimizer is {args.optimizer_type}, not LBFGS.'
    assert args.batch_train == args.train_size , f'batch_train: {args.batch_train} is not equal to train_size: {args.train_size}'
    # Fail early on configs that would otherwise crash only after training:
    # with no epochs `ep` is undefined when saving, and with no full test batch
    # the test means are NaN and the final `test_batch_*` dumps are undefined.
    assert args.EPOCH > 0, f'EPOCH must be positive, got {args.EPOCH}'
    assert args.test_size >= args.batch_test > 0, f'batch_test: {args.batch_test} must be in [1, test_size: {args.test_size}]'


    logging.info('Experiment Results are save in ' + args.RESULT_PATH)
    os.makedirs(args.RESULT_PATH, exist_ok=True)
    os.makedirs(args.SOLUTION_PATH, exist_ok=True)
    os.makedirs(args.CKPT_PATH, exist_ok=True)


    # Save stdout & stderr into the result directory. The logging handler was
    # created by basicConfig at import time, so log records still go to the
    # original stderr.
    sys.stdout = open(os.path.join(args.RESULT_PATH, 'stdout.txt'), 'w')
    sys.stderr = open(os.path.join(args.RESULT_PATH, 'stderr.txt'), 'w')
    from src.utils import save_config
    save_config(config, args.RESULT_PATH)


    batch_train = args.batch_train
    batch_test = args.batch_test

    n_qubits = args.n_qubits
    n_layers = args.n_layers
    output_dim = args.output_dim
    NN = args.N
    nn_input_dim = args.nn_input_dim
    sol_dim = args.sol_dim

    # Floor division: a trailing partial batch is never used.
    num_train_batches = args.train_size // batch_train
    num_test_batches = args.test_size // batch_test
    # Plot/print roughly 20 times per run; clamp to 1 so runs with fewer than
    # 20 epochs do not evaluate `x % 0` below.
    logging_epoch = max(1, args.EPOCH // 20)

    key = jax.random.PRNGKey(args.jax_seed)
    args.key = key


    from src.utils import PlotSystem
    PlotSystem = PlotSystem()

    from src.orthogonal import LegendreSystemJax
    OrthogonalSystem = LegendreSystemJax(args)

    xx = OrthogonalSystem.xx
    AA = OrthogonalSystem.AA


    from src.utils import precondition
    precondition = precondition(AA, args.precondition)
    A_tilde = precondition.A_tilde
    PL_inv = precondition.PL_inv
    coeff_func = precondition.make_coeff()


    from src.network import JAXMLP
    angle_net = JAXMLP(output_dim = output_dim
                       , N=NN)


    # optimizer = optax.adam(args.lr) # ADAM
    optimizer = optax.lbfgs() # L-BFGS


    from src.network import init_params_opt
    nn_params = init_params_opt(args, angle_net)
    opt_state = optimizer.init(nn_params)

    from src.vqls_double import VQLSJax
    VQLS = VQLSJax(args
                   , A_tilde
                   , angle_net=angle_net
                   )


    # jit functions
    loss_fn = VQLS.make_cost_loc()
    predict_jit = VQLS.make_predict()
    reconstruct_1D = OrthogonalSystem.reconstruct_1D
    Relative_L2_1D = OrthogonalSystem.Relative_L2_1D
    Relative_Linf_1D = OrthogonalSystem.Relative_Linf_1D

    from functools import partial
    @ jax.jit
    def train_step(nn_params, opt_state, batch_forcing, batch_RHS):
        """One L-BFGS step on the VQLS cost; returns (loss at the old params, new params, new state)."""
        loss_fn_partial = partial(loss_fn, batch_forcing=batch_forcing, batch_RHS=batch_RHS)
        loss, grad = jax.value_and_grad(loss_fn_partial)(nn_params)
        # optax.lbfgs needs the current value/grad plus the objective itself
        # (value_fn) for its line search.
        updates, opt_state = optimizer.update(grad, opt_state, nn_params
                                              , value=loss, grad=grad, value_fn=loss_fn_partial
                                              ) # LBFGS
        nn_params = optax.apply_updates(nn_params, updates)
        return loss, nn_params, opt_state

    @ jax.jit
    def eval_step(nn_params, batch_forcing, batch_RHS, sol_exact_batch):
        """Predict, reconstruct the 1D solution and return (sol_pred, rel_l2, rel_linf)."""
        alpha_predict = predict_jit(nn_params, batch_forcing, batch_RHS)
        coeff_pred = coeff_func(alpha_predict)
        sol_pred  = reconstruct_1D(coeff_pred)
        rel_l2 = Relative_L2_1D(sol_pred, sol_exact_batch)
        rel_linf = Relative_Linf_1D(sol_pred, sol_exact_batch)
        return sol_pred, rel_l2, rel_linf

    @ jax.jit
    def train_eval_step(nn_params, opt_state, batch_forcing, batch_RHS, batch_sol_exact):
        """Train one step, then evaluate the updated params on the same batch."""
        loss, nn_params, opt_state = train_step(nn_params, opt_state, batch_forcing, batch_RHS)
        sol_pred, rel_l2, rel_linf = eval_step(nn_params, batch_forcing, batch_RHS, batch_sol_exact)
        return nn_params, opt_state, loss, sol_pred, rel_l2, rel_linf

    @ jax.jit
    def test_eval_step(nn_params, batch_forcing, batch_RHS, batch_sol_exact):
        """Evaluate cost and errors on a held-out batch without updating params."""
        test_loss = loss_fn(nn_params, batch_forcing, batch_RHS)
        sol_pred, rel_l2, rel_linf = eval_step(nn_params, batch_forcing, batch_RHS, batch_sol_exact)
        return test_loss, sol_pred, rel_l2, rel_linf

    def _slice_batch(arr, batch_ind, batch_size):
        """Return the `batch_ind`-th contiguous slice of `batch_size` entries along axis 0."""
        return arr[batch_ind * batch_size : (batch_ind + 1) * batch_size]

    from src.database import DatabaseJax
    Database = DatabaseJax(args, AA, OrthogonalSystem)
    DB, DB_INFO, TRAIN_DB, TEST_DB = Database.database_load(args.DB_ROOT)

    train_forcing = TRAIN_DB['forcing']
    # train_bar_f = TRAIN_DB['bar_f']
    train_RHS = TRAIN_DB['RHS']
    train_coeff = TRAIN_DB['coeff']
    train_u_true = TRAIN_DB['u_true']

    test_forcing = TEST_DB['forcing']
    # test_bar_f = TEST_DB['bar_f']
    test_RHS = TEST_DB['RHS']
    test_coeff = TEST_DB['coeff']
    test_u_true = TEST_DB['u_true']


    best_cost = None
    best_ckpt = None
    # Index (within the first batch) of the sample that gets plotted.
    train_ind = 0
    test_ind = 0
    TRAIN_COST_HIST = []
    TEST_COST_HIST = []
    TRAIN_L2_HIST = []
    TRAIN_Linf_HIST = []
    TEST_L2_HIST = []
    TEST_Linf_HIST = []

    total_start = time.time()
    for ep in range(args.EPOCH):
        batch_train_loss_hist = []
        batch_test_loss_hist = []
        batch_train_l2_hist = []
        batch_train_linf_hist = []
        batch_test_l2_hist = []
        batch_test_linf_hist = []

        start = time.time()

        ##### Train #####
        for batch_ind in range(num_train_batches):
            train_batch_forcing = _slice_batch(train_forcing, batch_ind, batch_train)
            train_batch_RHS = _slice_batch(train_RHS, batch_ind, batch_train)
            train_batch_sol_exact = _slice_batch(train_u_true, batch_ind, batch_train)

            # Training & Eval
            (nn_params, opt_state, train_loss
             , train_batch_sol_pred, rel_l2_train_batch, rel_linf_train_batch
             ) = train_eval_step(nn_params, opt_state
                                 , train_batch_forcing, train_batch_RHS, train_batch_sol_exact)

            batch_train_loss_hist.append(jnp.mean(train_loss).item())
            batch_train_l2_hist.append(jnp.mean(rel_l2_train_batch).item())
            batch_train_linf_hist.append(jnp.mean(rel_linf_train_batch).item())

            if (ep+1)%logging_epoch ==0 and batch_ind == 0:
                PlotSystem.Solution_1D(xx
                , pred_u_dict={'Predicted solution': train_batch_sol_pred[train_ind]
                                # , 'solution': exact_sol
                                }
                , true_u= train_batch_sol_exact[train_ind]
                , NORM={'RelL2':rel_l2_train_batch[train_ind].item()
                        , 'RelLinf':rel_linf_train_batch[train_ind].item()
                        }
                , SAVE_PATH= os.path.join(args.SOLUTION_PATH, f'train_ind_{train_ind}_solution_epoch_{ep+1}.png')
                , SHOW=False, COLOR='b')


        ##### Test #####
        for batch_ind in range(num_test_batches):
            # Every test array is sliced with batch_test so forcing, RHS and
            # exact solution refer to the same samples.
            test_batch_forcing = _slice_batch(test_forcing, batch_ind, batch_test)
            test_batch_RHS = _slice_batch(test_RHS, batch_ind, batch_test)
            test_batch_sol_exact = _slice_batch(test_u_true, batch_ind, batch_test)

            (test_loss, test_batch_sol_pred, rel_l2_test_batch, rel_linf_test_batch
             ) = test_eval_step(nn_params, test_batch_forcing, test_batch_RHS, test_batch_sol_exact)

            batch_test_loss_hist.append(jnp.mean(test_loss).item())
            batch_test_l2_hist.append(jnp.mean(rel_l2_test_batch).item())
            batch_test_linf_hist.append(jnp.mean(rel_linf_test_batch).item())

            if (ep+1)%logging_epoch ==0 and batch_ind == 0:
                PlotSystem.Solution_1D(xx
                , pred_u_dict={'Predicted solution': test_batch_sol_pred[test_ind] }
                , true_u= test_batch_sol_exact[test_ind]
                , NORM={'RelL2':rel_l2_test_batch[test_ind].item()
                        , 'RelLinf':rel_linf_test_batch[test_ind].item()
                        }
                , SAVE_PATH= os.path.join(args.SOLUTION_PATH, f'test_ind_{test_ind}_solution_epoch_{ep+1}.png')
                , SHOW=False
                , COLOR='g')


        elapsed_time = time.time() - start
        total_elapsed_time = time.time() - total_start

        # Epoch metrics are the unweighted mean of the per-batch means.
        epoch_train_loss_mean = jnp.mean(jnp.array(batch_train_loss_hist)).item()
        epoch_train_rel_l2_mean = jnp.mean(jnp.array(batch_train_l2_hist)).item()
        epoch_train_rel_linf_mean = jnp.mean(jnp.array(batch_train_linf_hist)).item()

        epoch_test_loss_mean = jnp.mean(jnp.array(batch_test_loss_hist)).item()
        epoch_test_rel_l2_mean = jnp.mean(jnp.array(batch_test_l2_hist)).item()
        epoch_test_rel_linf_mean = jnp.mean(jnp.array(batch_test_linf_hist)).item()

        TRAIN_COST_HIST.append(epoch_train_loss_mean)
        TEST_COST_HIST.append(epoch_test_loss_mean)
        TRAIN_L2_HIST.append(epoch_train_rel_l2_mean)
        TRAIN_Linf_HIST.append(epoch_train_rel_linf_mean)
        TEST_L2_HIST.append(epoch_test_rel_l2_mean)
        TEST_Linf_HIST.append(epoch_test_rel_linf_mean)

        # "Best" checkpoint is selected on the mean training cost.
        if (best_cost is None) or (epoch_train_loss_mean < best_cost):
            best_cost = epoch_train_loss_mean
            best_ckpt = {
                        'params': nn_params,
                        'opt_state': opt_state,
                        'epoch': ep
                        }

        if (ep+1)%logging_epoch ==0 or (ep+1) == args.EPOCH:
            print(f"[Jax, GPU] Epoch {ep:02d}, time = {elapsed_time:.6f} sec, elapsed time: {total_elapsed_time:.6f} loss = {epoch_train_loss_mean}")

            PlotSystem.Losses({'Cost (Train)': [TRAIN_COST_HIST, 'b']
                               , 'Cost (Test)': [TEST_COST_HIST, 'g']
                               }
                            , SAVEPATH=os.path.join(args.RESULT_PATH, 'cost_mean_history.png')
                            , log_scale=True, SHOW=False)
            PlotSystem.Losses({r'Train Rel $L_2$ Error':[TRAIN_L2_HIST, 'b']
                            ,r'Train Rel $L_\text{inf}$ Error':[TRAIN_Linf_HIST, 'b']
                            ,r'Test Rel $L_2$ Error':[TEST_L2_HIST, 'g']
                            ,r'Test Rel $L_\text{inf}$ Error':[TEST_Linf_HIST, 'g']
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'total_error_mean.png')
                            , log_scale=True
                            , SHOW=False
                            )
            PlotSystem.Losses({r'Relative $L_2$ Error':TRAIN_L2_HIST
                            ,r'Relative $L_\text{inf}$ Error':TRAIN_Linf_HIST
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'train_error_mean.png')
                            , log_scale=True
                            , SHOW=False)
            PlotSystem.Losses({r'Relative $L_2$ Error':TEST_L2_HIST
                            ,r'Relative $L_\text{inf}$ Error':TEST_Linf_HIST
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'test_error_mean.png')
                            , log_scale=True
                            , SHOW=False
                            , COLOR='g'
                            )

    total_elapsed_time = time.time() - total_start
    print(f"[Jax, GPU] total time = {total_elapsed_time:.6f} sec")

    # Save Best CKPT
    with open(os.path.join(args.CKPT_PATH, 'best_cost.txt'), 'w') as f:
        f.write(f"Best Cost       : {best_cost}\n")
        f.write(f"Best Cost (Test): {min(TEST_COST_HIST)}\n")
        f.write(f"Best Train L2   : {min(TRAIN_L2_HIST)}\n")
        f.write(f"Best Train Linf : {min(TRAIN_Linf_HIST)}\n")
        f.write(f"Best Test L2    : {min(TEST_L2_HIST)}\n")
        f.write(f"Best Test Linf  : {min(TEST_Linf_HIST)}\n")
    # Final-epoch checkpoint (`ep` is the last loop index; EPOCH > 0 is asserted above).
    checkpoint = {
                'params': nn_params,
                'opt_state': opt_state,
                'epoch': ep
                }
    with open( os.path.join(args.CKPT_PATH, f'nn_ckpt_epoch_{ep+1}.pkl'), 'wb') as f:
        pickle.dump(checkpoint, f)
    with open( os.path.join(args.CKPT_PATH, 'nn_best_ckpt.pkl'), 'wb') as f:
        pickle.dump(best_ckpt, f)
    with open( os.path.join(args.CKPT_PATH, 'TRAIN_COST_HIST.pkl'), 'wb') as f:
        pickle.dump(TRAIN_COST_HIST, f)
    with open( os.path.join(args.CKPT_PATH, 'TRAIN_L2_HIST.pkl'), 'wb') as f:
        pickle.dump(TRAIN_L2_HIST, f)
    with open( os.path.join(args.CKPT_PATH, 'TRAIN_Linf_HIST.pkl'), 'wb') as f:
        pickle.dump(TRAIN_Linf_HIST, f)
    with open( os.path.join(args.CKPT_PATH, 'TEST_COST_HIST.pkl'), 'wb') as f:
        pickle.dump(TEST_COST_HIST, f)
    with open( os.path.join(args.CKPT_PATH, 'TEST_L2_HIST.pkl'), 'wb') as f:
        pickle.dump(TEST_L2_HIST, f)
    with open( os.path.join(args.CKPT_PATH, 'TEST_Linf_HIST.pkl'), 'wb') as f:
        pickle.dump(TEST_Linf_HIST, f)

    # Solution: Exact & Prediction (last train/test batch of the final epoch)
    with open( os.path.join(args.SOLUTION_PATH, 'train_batch_sol_exact.pkl'), 'wb') as f:
        pickle.dump(train_batch_sol_exact, f)
    with open( os.path.join(args.SOLUTION_PATH, 'train_batch_sol_pred.pkl'), 'wb') as f:
        pickle.dump(train_batch_sol_pred, f)
    with open( os.path.join(args.SOLUTION_PATH, 'test_batch_sol_exact.pkl'), 'wb') as f:
        pickle.dump(test_batch_sol_exact, f)
    with open( os.path.join(args.SOLUTION_PATH, 'test_batch_sol_pred.pkl'), 'wb') as f:
        pickle.dump(test_batch_sol_pred, f)