import os
import pickle
import sys
import time
import yaml
import itertools
import numpy as np

# If you want to see the JAX compile log:
# os.environ["JAX_LOG_COMPILES"] = "1"

def select_free_gpu():
    """Expose only the least-loaded GPU to this process.

    Runs ``nvidia-smi`` to read the used memory of every GPU, picks the one
    with the smallest value (the first one wins on ties) and writes its index
    to the ``CUDA_VISIBLE_DEVICES`` environment variable.

    If ``nvidia-smi`` cannot be executed, or prints nothing that can be parsed
    as ``index, used_memory`` rows, a message is printed and
    ``CUDA_VISIBLE_DEVICES`` is left untouched instead of raising.

    Returns:
        The selected GPU index as a string, or ``None`` if no GPU was selected.
    """
    import subprocess

    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            stdout=subprocess.PIPE,
            text=True,
        )
    except OSError as err:
        # Raised (as FileNotFoundError / PermissionError) when the executable
        # is missing or cannot be started, e.g. on a CPU-only machine.
        print(f"Could not run nvidia-smi ({err}); leaving CUDA_VISIBLE_DEVICES unchanged")
        return None

    # Map the GPU index reported by nvidia-smi to its used memory. Rows that
    # are not of the form "<int>, <int>" (error text, "[N/A]") are skipped.
    gpu_mem = {}
    for line in result.stdout.splitlines():
        fields = [field.strip() for field in line.split(",")]
        if len(fields) == 2 and fields[0].isdigit() and fields[1].isdigit():
            gpu_mem[fields[0]] = int(fields[1])

    if not gpu_mem:
        print("nvidia-smi reported no usable GPU; leaving CUDA_VISIBLE_DEVICES unchanged")
        return None

    # dicts keep insertion order, so ties resolve to the first listed GPU.
    free_gpu = min(gpu_mem, key=gpu_mem.get)
    print(f"Selecting GPU {free_gpu} (used memory: {gpu_mem[free_gpu]} MiB)")

    os.environ["CUDA_VISIBLE_DEVICES"] = free_gpu
    return free_gpu

# Runs at import time, before `import jax` below, so CUDA_VISIBLE_DEVICES is
# already set when JAX is imported (presumably so JAX only sees that device).
select_free_gpu()


import jax
import jax.numpy as jnp

# Switch on JAX's "jax_enable_x64" flag right after importing jax, before any
# array is created in this file.
jax.config.update("jax_enable_x64", True)
import optax
import pennylane as qml

import logging
# Root logger: INFO level, "<time> - <file>  : <message> " with a 12-hour clock.
logging.basicConfig(
    format="%(asctime)s - %(filename)s  : %(message)s ",
    level=logging.INFO,
    datefmt="%I:%M:%S",
)

if __name__ == "__main__":

    # Resolve the experiment arguments/config relative to this script's folder.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    from src.utils import setting
    args, config = setting(BASE_DIR)

    # This script always builds an optax L-BFGS optimizer and trains on the
    # full training set in a single batch, so reject any other configuration.
    # Explicit raises are used because `assert` is stripped under `python -O`.
    if args.optimizer_type != 'LBFGS':
        raise ValueError(f'Optimizer is {args.optimizer_type}, not LBFGS.')
    if args.train_size != args.batch_train:
        raise ValueError(
            f'Not full-batch: train_size={args.train_size}, batch_train={args.batch_train},'
        )

    logging.info('Experiment Results are save in ' + args.RESULT_PATH)
    os.makedirs(args.RESULT_PATH, exist_ok=True)
    os.makedirs(args.SOLUTION_PATH, exist_ok=True)
    os.makedirs(args.CKPT_PATH, exist_ok=True)

    # Save stdout & stderr: from here on, print() output and tracebacks go to
    # files inside RESULT_PATH for the rest of the run.
    # NOTE(review): logging.basicConfig ran before this redirect, so log
    # records presumably still go to the original stderr - confirm.
    sys.stdout = open(os.path.join(args.RESULT_PATH, 'stdout.txt'), 'w')
    sys.stderr = open(os.path.join(args.RESULT_PATH, 'stderr.txt'), 'w')
    from src.utils import save_config
    save_config(config, args.RESULT_PATH)



    # Batch sizes for the training and test passes.
    batch_train = args.batch_train
    batch_test = args.batch_test

    # Model sizes read from the experiment arguments.
    n_qubits = args.n_qubits
    n_layers = args.n_layers
    output_dim = args.output_dim
    NN = args.N

    # Number of whole batches per epoch; a trailing partial batch is dropped.
    num_train_batches = args.train_size // batch_train
    num_test_batches = args.test_size // batch_test

    # Plot/checkpoint interval: about 20 times per run. Clamped to 1 because
    # EPOCH < 20 would give 0 and make every `% logging_epoch` check in the
    # training loop raise ZeroDivisionError.
    logging_epoch = max(1, args.EPOCH // 20)
    input_dim = args.input_dim

    # Seeded PRNG key, also stored on `args` for code that receives `args`.
    key = jax.random.PRNGKey(args.jax_seed)
    args.key = key



    # Plotting helper. NOTE: the instance is bound to the same name as the
    # imported class, so the class itself is no longer reachable afterwards.
    from src.utils import PlotSystem
    PlotSystem = PlotSystem()

    # Orthogonal (Legendre) system built from the experiment arguments.
    from src.orthogonal import LegendreSystemJax
    OrthogonalSystem = LegendreSystemJax(args)

    # X_2D / Y_2D are passed to the solution plots below; AA is handed to the
    # preconditioner and the database loader.
    # NOTE(review): presumably a 2-D evaluation grid and the system matrix - confirm.
    X_2D = OrthogonalSystem.X_2D
    Y_2D = OrthogonalSystem.Y_2D
    AA = OrthogonalSystem.AA


    # Preconditioner object for AA (mode taken from args.precondition).
    # NOTE: as above, the instance shadows the imported `precondition` callable.
    from src.utils import precondition
    precondition = precondition(AA, args.precondition)
    A_tilde = precondition.A_tilde
    PL_inv = precondition.PL_inv
    # coeff_func maps the predicted output to coefficients in eval_step.
    coeff_func = precondition.make_coeff()


    # Network passed to VQLSJax as `angle_net`.
    from src.network import JAXMLP
    angle_net = JAXMLP(input_dim = input_dim
                       , output_dim = output_dim
                       , N=NN)
    
    # optimizer = optax.adam(args.lr) # ADAM
    optimizer = optax.lbfgs() # L-BFGS

    # Initial network parameters and the matching optimizer state.
    from src.network import init_params_opt
    nn_params = init_params_opt(args, angle_net)
    opt_state = optimizer.init(nn_params)


    # Disabled: resume parameters/optimizer state from a pickled checkpoint.
    # if CKPT_PATH:
    #     with open(CKPT_PATH, 'rb') as f:
    #         checkpoint = pickle.load(f)
    #     nn_params = checkpoint['params']
    #     opt_state = checkpoint['opt_state']
    #     ckpt_epoch = checkpoint['epoch']

    # VQLS problem wrapper built from the preconditioned matrix and the network.
    from src.vqls_double import VQLSJax
    VQLS = VQLSJax(args
                   , A_tilde
                   , angle_net=angle_net)

    # jit functions
    # Callables used inside the jitted train/eval steps defined below.
    loss_fn = VQLS.make_cost_loc()
    predict_jit = VQLS.make_predict()
    reconstruct = OrthogonalSystem.reconstruct
    relative_l2 = OrthogonalSystem.relative_l2
    relative_linf = OrthogonalSystem.relative_linf

    from functools import partial
    @jax.jit
    def train_step(nn_params, opt_state, batch_forcing, batch_RHS):
        """Apply one optimizer update on a batch.

        Returns ``(loss, new_params, new_opt_state)``. ``loss`` is the value of
        the objective at the incoming ``nn_params``, i.e. before the update.
        """
        # Objective as a function of the parameters only, with the batch fixed.
        objective = partial(loss_fn, batch_forcing=batch_forcing, batch_RHS=batch_RHS)
        loss, grad = jax.value_and_grad(objective)(nn_params)
        # The value/grad/value_fn keywords are the extra inputs passed to the
        # L-BFGS update.
        updates, new_opt_state = optimizer.update(
            grad, opt_state, nn_params,
            value=loss, grad=grad, value_fn=objective,
        )
        return loss, optax.apply_updates(nn_params, updates), new_opt_state
    
    @jax.jit
    def eval_step(nn_params, batch_forcing, batch_RHS, batch_sol_exact):
        """Predict solutions for a batch and score them against the exact ones.

        Pipeline: ``predict_jit`` -> ``coeff_func`` -> ``reconstruct``.
        Returns ``(sol_pred, rel_l2, rel_linf)``.
        """
        sol_pred = reconstruct(
            coeff_func(predict_jit(nn_params, batch_forcing, batch_RHS))
        )
        return (
            sol_pred,
            relative_l2(sol_pred, batch_sol_exact),
            relative_linf(sol_pred, batch_sol_exact),
        )

    @jax.jit
    def train_eval_step(nn_params, opt_state, batch_forcing, batch_RHS, batch_sol_exact):
        """Run one training update, then evaluate the updated parameters.

        Returns ``(nn_params, opt_state, loss, sol_pred, rel_l2, rel_linf)``.
        The loss comes from ``train_step``; the prediction and errors are
        computed with the parameters returned by that step.
        """
        loss, new_params, new_state = train_step(
            nn_params, opt_state, batch_forcing, batch_RHS
        )
        prediction, l2_err, linf_err = eval_step(
            new_params, batch_forcing, batch_RHS, batch_sol_exact
        )
        return new_params, new_state, loss, prediction, l2_err, linf_err

    @jax.jit
    def test_eval_step(nn_params, batch_forcing, batch_RHS, batch_sol_exact):
        """Evaluate loss and errors on a batch without updating parameters.

        Returns ``(test_loss, sol_pred, rel_l2, rel_linf)``.
        """
        prediction, l2_err, linf_err = eval_step(
            nn_params, batch_forcing, batch_RHS, batch_sol_exact
        )
        return loss_fn(nn_params, batch_forcing, batch_RHS), prediction, l2_err, linf_err
    
    # Load the database and its train/test splits from args.DB_ROOT.
    from src.database import DatabaseJax
    Database = DatabaseJax(args, AA, OrthogonalSystem)
    DB, DB_INFO, TRAIN_DB, TEST_DB = Database.database_load(args.DB_ROOT)

    # Fields read from each split; 'bar_f' is deliberately not loaded.
    train_forcing, train_RHS, train_coeff, train_u_true = (
        TRAIN_DB[field] for field in ('forcing', 'RHS', 'coeff', 'u_true')
    )
    test_forcing, test_RHS, test_coeff, test_u_true = (
        TEST_DB[field] for field in ('forcing', 'RHS', 'coeff', 'u_true')
    )

    # Best training cost seen so far and the checkpoint recorded with it.
    best_cost, best_ckpt = None, None

    # Index (within the first batch) of the sample drawn in the solution plots.
    train_ind = test_ind = 0

    # Per-epoch mean cost and relative errors (one entry appended per epoch).
    TRAIN_COST_HIST, TRAIN_L2_HIST, TRAIN_Linf_HIST = [], [], []
    TEST_COST_HIST, TEST_L2_HIST, TEST_Linf_HIST = [], [], []

    total_start = time.time()
    for ep in range(args.EPOCH):

        # Per-batch metrics of the current epoch; averaged after both passes.
        batch_train_loss_hist = []
        batch_train_l2_hist = []
        batch_train_linf_hist = []

        batch_test_loss_hist = []
        batch_test_l2_hist = []
        batch_test_linf_hist = []

        start = time.time()

        # Solution snapshots are drawn on every `logging_epoch`-th epoch
        # (1-based), for the first batch of each pass only.
        plot_this_epoch = (ep + 1) % logging_epoch == 0

        ##### Train #####
        for batch_ind in range(num_train_batches):

            train_rows = slice(batch_ind * batch_train, (batch_ind + 1) * batch_train)
            train_batch_forcing = train_forcing[train_rows]
            train_batch_RHS = train_RHS[train_rows]
            train_batch_sol_exact = train_u_true[train_rows]

            # One optimizer update followed by evaluation on the same batch.
            (nn_params, opt_state, train_loss
             , train_batch_sol_pred, rel_l2_train_batch, rel_linf_train_batch
             ) = train_eval_step(nn_params, opt_state
                                 , train_batch_forcing, train_batch_RHS, train_batch_sol_exact)

            batch_train_loss_hist.append(jnp.mean(train_loss).item())
            batch_train_l2_hist.append(jnp.mean(rel_l2_train_batch).item())
            batch_train_linf_hist.append(jnp.mean(rel_linf_train_batch).item())

            if plot_this_epoch and batch_ind == 0:
                PlotSystem.Solution_2D(
                    X_2D, Y_2D,
                    pred_u=train_batch_sol_pred[train_ind],
                    true_u=train_batch_sol_exact[train_ind],
                    NORM={'RelL2': rel_l2_train_batch[train_ind].item(),
                          'RelLinf': rel_linf_train_batch[train_ind].item()},
                    SAVE_PATH=os.path.join(args.SOLUTION_PATH, f'train_ind_{train_ind}_solution_epoch_{ep+1}.png'),
                    SHOW=False,
                )

        ##### Test #####
        for batch_ind in range(num_test_batches):

            # All three test arrays must use the *test* batch size; slicing
            # forcing/RHS with batch_train misaligns them with the exact
            # solution whenever the two batch sizes differ.
            test_rows = slice(batch_ind * batch_test, (batch_ind + 1) * batch_test)
            test_batch_forcing = test_forcing[test_rows]
            test_batch_RHS = test_RHS[test_rows]
            test_batch_sol_exact = test_u_true[test_rows]

            (test_loss, test_batch_sol_pred, rel_l2_test_batch, rel_linf_test_batch
             ) = test_eval_step(nn_params
                                , test_batch_forcing, test_batch_RHS, test_batch_sol_exact)

            batch_test_loss_hist.append(jnp.mean(test_loss).item())
            batch_test_l2_hist.append(jnp.mean(rel_l2_test_batch).item())
            batch_test_linf_hist.append(jnp.mean(rel_linf_test_batch).item())

            if plot_this_epoch and batch_ind == 0:
                PlotSystem.Solution_2D(
                    X_2D, Y_2D,
                    pred_u=test_batch_sol_pred[test_ind],
                    true_u=test_batch_sol_exact[test_ind],
                    NORM={'RelL2': rel_l2_test_batch[test_ind].item(),
                          'RelLinf': rel_linf_test_batch[test_ind].item()},
                    # The file name carries the index of the plotted *test* sample.
                    SAVE_PATH=os.path.join(args.SOLUTION_PATH, f'test_ind_{test_ind}_solution_epoch_{ep+1}.png'),
                    SHOW=False,
                )

        elapsed_time = time.time() - start
        total_elapsed_time = time.time() - total_start

        # Epoch-level means over the per-batch means.
        epoch_train_loss_mean = jnp.mean(jnp.array(batch_train_loss_hist)).item()
        epoch_train_rel_l2_mean = jnp.mean(jnp.array(batch_train_l2_hist)).item()
        epoch_train_rel_linf_mean = jnp.mean(jnp.array(batch_train_linf_hist)).item()

        epoch_test_loss_mean = jnp.mean(jnp.array(batch_test_loss_hist)).item()
        epoch_test_rel_l2_mean = jnp.mean(jnp.array(batch_test_l2_hist)).item()
        epoch_test_rel_linf_mean = jnp.mean(jnp.array(batch_test_linf_hist)).item()

        TRAIN_COST_HIST.append(epoch_train_loss_mean)
        TRAIN_L2_HIST.append(epoch_train_rel_l2_mean)
        TRAIN_Linf_HIST.append(epoch_train_rel_linf_mean)

        TEST_COST_HIST.append(epoch_test_loss_mean)
        TEST_L2_HIST.append(epoch_test_rel_l2_mean)
        TEST_Linf_HIST.append(epoch_test_rel_linf_mean)

        # Track the checkpoint with the lowest mean training cost.
        # NOTE(review): train_step computes the loss before applying the
        # update, while nn_params here is already updated, so the stored
        # parameters are one step ahead of best_cost - confirm this is intended.
        if (best_cost is None) or (epoch_train_loss_mean < best_cost):
            best_cost = epoch_train_loss_mean
            best_ckpt = {
                        'params': nn_params,
                        'opt_state': opt_state,
                        'epoch': ep
                        }

        # Periodic progress line, rolling checkpoint and refreshed history plots.
        if ep % logging_epoch == 0 or ep == args.EPOCH - 1:
            print(f"[Jax, GPU] Epoch {ep+1:02d}, time = {elapsed_time:.6f} sec, elapsed time: {total_elapsed_time:.6f} loss = {epoch_train_loss_mean}")
            checkpoint = {
                        'params': nn_params,
                        'opt_state': opt_state,
                        'epoch': ep
                        }
            # Overwritten on every logging epoch; holds the latest state only.
            with open(os.path.join(args.RESULT_PATH, 'nn_ckpt.pkl'), 'wb') as f:
                pickle.dump(checkpoint, f)

            PlotSystem.Losses({'Cost (Train)': [TRAIN_COST_HIST, 'b']
                               , 'Cost (Test)': [TEST_COST_HIST, 'g']
                               }
                            , SAVEPATH=os.path.join(args.RESULT_PATH, 'cost_mean_history.png')
                            , log_scale=True, SHOW=False)
            PlotSystem.Losses({r'Train Rel $L_2$ Error':[TRAIN_L2_HIST, 'b--']
                            ,r'Train Rel $L_\text{inf}$ Error':[TRAIN_Linf_HIST, 'b']
                            ,r'Test Rel $L_2$ Error':[TEST_L2_HIST, 'g--']
                            ,r'Test Rel $L_\text{inf}$ Error':[TEST_Linf_HIST, 'g']
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'total_error_mean.png')
                            , log_scale=True
                            , SHOW=False
                            )
            PlotSystem.Losses({r'Relative $L_2$ Error':TRAIN_L2_HIST
                            ,r'Relative $L_\text{inf}$ Error':TRAIN_Linf_HIST
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'train_error_mean.png')
                            , log_scale=True
                            , SHOW=False)
            PlotSystem.Losses({r'Relative $L_2$ Error':TEST_L2_HIST
                            ,r'Relative $L_\text{inf}$ Error':TEST_Linf_HIST
                            }, SAVEPATH=os.path.join(args.RESULT_PATH, 'test_error_mean.png')
                            , log_scale=True
                            , SHOW=False
                            , COLOR='g'
                            )
        
    total_elapsed_time = time.time() - total_start
    print(f"[Jax, GPU] total time = {total_elapsed_time:.6f} sec")

    def _dump_pickle(obj, directory, filename):
        """Pickle ``obj`` into ``directory/filename`` (binary, overwriting)."""
        with open(os.path.join(directory, filename), 'wb') as fh:
            pickle.dump(obj, fh)

    # Text summary: best (minimum) value of every tracked metric.
    with open(os.path.join(args.CKPT_PATH, 'best_cost.txt'), 'w') as summary:
        summary.writelines([
            f"Best Cost       : {best_cost}\n",
            f"Best Cost (Test): {min(TEST_COST_HIST)}\n",
            f"Best Train L2   : {min(TRAIN_L2_HIST)}\n",
            f"Best Train Linf : {min(TRAIN_Linf_HIST)}\n",
            f"Best Test L2    : {min(TEST_L2_HIST)}\n",
            f"Best Test Linf  : {min(TEST_Linf_HIST)}\n",
        ])

    # Final-epoch checkpoint, then the checkpoint with the lowest train cost.
    checkpoint = {'params': nn_params, 'opt_state': opt_state, 'epoch': ep}
    _dump_pickle(checkpoint, args.CKPT_PATH, f'nn_ckpt_epoch_{ep+1}.pkl')
    _dump_pickle(best_ckpt, args.CKPT_PATH, 'nn_best_ckpt.pkl')

    # Per-epoch metric histories.
    _dump_pickle(TRAIN_COST_HIST, args.CKPT_PATH, 'TRAIN_COST_HIST.pkl')
    _dump_pickle(TRAIN_L2_HIST, args.CKPT_PATH, 'TRAIN_L2_HIST.pkl')
    _dump_pickle(TRAIN_Linf_HIST, args.CKPT_PATH, 'TRAIN_Linf_HIST.pkl')

    _dump_pickle(TEST_COST_HIST, args.CKPT_PATH, 'TEST_COST_HIST.pkl')
    _dump_pickle(TEST_L2_HIST, args.CKPT_PATH, 'TEST_L2_HIST.pkl')
    _dump_pickle(TEST_Linf_HIST, args.CKPT_PATH, 'TEST_Linf_HIST.pkl')

    # Solution: Exact & Prediction (last batch of the last epoch).
    _dump_pickle(train_batch_sol_exact, args.SOLUTION_PATH, 'train_batch_sol_exact.pkl')
    _dump_pickle(train_batch_sol_pred, args.SOLUTION_PATH, 'train_batch_sol_pred.pkl')
    _dump_pickle(test_batch_sol_exact, args.SOLUTION_PATH, 'test_batch_sol_exact.pkl')
    _dump_pickle(test_batch_sol_pred, args.SOLUTION_PATH, 'test_batch_sol_pred.pkl')