import numpy as np
import tensorflow as tf
from functools import partial
import pickle
import argparse
from simulator import build_generative_model 
from dataset import AirTrafficDataset2, load_real_data
from configurator import Configurator
import bayesflow as bf
from bayesflow.trainers import Trainer
from bayesflow.networks import InvertibleNetwork
from bayesflow.amortizers import AmortizedPosterior

# Inference network: invertible network over 5 parameters built from 6 spline
# coupling layers; the dense layers inside each coupling block use 256 units
# with ELU activation.
inference_net = InvertibleNetwork(
    num_params=5,
    num_coupling_layers=6,
    coupling_design="spline",
    coupling_settings={
        "dense_args": {"units": 256, "activation": "elu"},
    },
)

# Summary network: a bidirectional LSTM (64 units) followed by a 256-unit ELU
# dense layer and a final 64-unit dense layer with no activation.
summary_net = tf.keras.Sequential(
    layers=[
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
        tf.keras.layers.Dense(256, activation="elu"),
        tf.keras.layers.Dense(64),
    ]
)

def sample_posterior(trainer, dataset_instance, filename='posterior_draws.npy'):
    """Draw posterior samples for every country in the dataset and save them.

    The per-country inputs are concatenated along axis 0, wrapped in a
    forward dict, passed through ``trainer.configurator`` and then fed to
    ``trainer.amortizer.sample`` with ``n_samples=1000``.

    Parameters
    ----------
    trainer
        Object exposing a ``configurator`` callable and an ``amortizer`` with
        a ``sample(input_dict, n_samples=...)`` method.
    dataset_instance
        Object exposing ``country_codes()`` and
        ``to_bayesflow_input_dict_single(code)``.
    filename
        Destination handed to ``np.save``. When it is a path, missing parent
        directories are created first.

    Returns
    -------
    The draws returned by ``trainer.amortizer.sample`` (the same object that
    is written to ``filename``).

    Raises
    ------
    ValueError
        If ``dataset_instance.country_codes()`` yields no countries.
    """
    import os

    tf.experimental.numpy.experimental_enable_numpy_behavior()

    per_country_inputs = [
        dataset_instance.to_bayesflow_input_dict_single(code)
        for code in dataset_instance.country_codes()
    ]
    if not per_country_inputs:
        # np.concatenate would raise an opaque error on an empty list.
        raise ValueError(
            "dataset_instance.country_codes() returned no countries; "
            "nothing to sample."
        )

    # Make sure the destination directory exists before doing any heavy work,
    # otherwise np.save fails with FileNotFoundError after sampling.
    if isinstance(filename, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(filename))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    observed = tf.convert_to_tensor(
        np.concatenate(per_country_inputs, axis=0), dtype=tf.float32
    )

    forward_dict = {
        # Placeholder zeros (one row, 5 values); presumably only present so
        # the configurator finds a "prior_draws" entry -- TODO confirm.
        "prior_draws": np.array([[0.0] * 5], dtype=np.float32),
        "sim_data": observed,
    }

    configured = trainer.configurator(forward_dict)

    # Drop axis 1 from any rank-5 entry (tf.squeeze requires that axis to have
    # size 1). Entries without a ``shape`` attribute (e.g. None) are left as is.
    for key, value in list(configured.items()):
        shape = getattr(value, "shape", None)
        if shape is not None and len(shape) == 5:
            configured[key] = tf.squeeze(value, axis=1)

    posterior_draws_all = trainer.amortizer.sample(configured, n_samples=1000)
    np.save(filename, posterior_draws_all)
    print(f"Posterior samples saved to '{filename}'")
    return posterior_draws_all

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse reads
        ``sys.argv[1:]``, which is the behaviour of calling ``parse_args()``
        with no arguments.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``lr``, ``epochs``, ``batch_size``,
        ``simulation_budget``, ``num_obs`` and ``run_id``.
    """
    parser = argparse.ArgumentParser(
        description="Train the BayesFlow amortized posterior model for air traffic."
    )
    parser.add_argument("--lr", type=float, default=0.00005, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--simulation_budget", type=int, default=1024, help="Simulation budget")
    parser.add_argument("--num_obs", type=int, default=14, help="Number of observations")
    parser.add_argument("--run_id", type=int, default=0, help="Run identifier")
    return parser.parse_args(argv)


def main():
    """Train the amortized posterior offline, then sample and save posterior draws.

    Steps: parse the CLI options, build the generative model and training
    data, assemble the amortizer and trainer from the module-level
    ``inference_net`` / ``summary_net``, run offline training, and finally
    call ``sample_posterior`` to write the draws to
    ``./results/posterior_draws_npe.npy``.
    """
    import os

    args = parse_args()
    # NOTE(review): ``args.run_id`` is parsed but not used anywhere below.

    output_path = './results/posterior_draws_npe.npy'
    # Create the output directory before training so that a missing folder is
    # not discovered only after the (expensive) training run has finished.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    rng = np.random.default_rng(seed=1)
    num_obs = args.num_obs

    # Only the generative model and the training data are needed here; the
    # remaining two return values are intentionally unused.
    generative_model, train_data, _prior, _simulator = build_generative_model(
        rng, num_obs, simulation_budget=args.simulation_budget
    )

    dataset_instance = AirTrafficDataset2()

    configurator = Configurator(num_obs=num_obs)

    amortizer_no_sc = AmortizedPosterior(
        inference_net=inference_net,
        summary_net=summary_net,
    )

    # The learning rate is set on the Trainer: in BayesFlow 1.x
    # ``train_offline`` has no ``lr`` parameter and silently drops unknown
    # keyword arguments, so passing it there left ``--lr`` without effect.
    # NOTE(review): confirm against the installed BayesFlow version.
    trainer_no_sc = bf.trainers.Trainer(
        amortizer=amortizer_no_sc,
        generative_model=generative_model,
        configurator=configurator.configure,
        default_lr=args.lr,
    )

    trainer_no_sc.train_offline(
        train_data,
        epochs=args.epochs,
        batch_size=args.batch_size,
    )

    print("Training NPE complete!")

    sample_posterior(trainer_no_sc, dataset_instance, filename=output_path)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
