import numpy as np
import tensorflow as tf
from functools import partial
import pickle
import argparse
from simulator import build_generative_model 
from dataset import AirTrafficDataset2, load_real_data
from configurator import Configurator
from amortizer import AmortizedPosteriorSC
import bayesflow as bf
from bayesflow.trainers import Trainer
from bayesflow.networks import InvertibleNetwork
from bayesflow.amortizers import AmortizedPosterior

# BayesFlow invertible (normalizing-flow) inference network over the 5 model
# parameters: 6 spline coupling layers whose dense sub-networks use 256 ELU units.
inference_net = InvertibleNetwork(
    num_params=5,
    num_coupling_layers=6,
    coupling_design="spline",
    coupling_settings={"dense_args": {"units": 256, "activation": "elu"}},
)

# Summary network: a bidirectional LSTM (64 units per direction) followed by a
# 256-unit ELU dense layer and a final 64-unit linear projection, so every
# input sequence is reduced to a 64-dimensional summary vector.
# NOTE(review): the LSTM expects (batch, time, features) inputs — confirm this
# matches what Configurator.configure emits as summary conditions.
summary_net = tf.keras.Sequential(
    [
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
        tf.keras.layers.Dense(256, activation="elu"),
        tf.keras.layers.Dense(64),
    ]
)

def sample_posterior(trainer, dataset_instance, filename='posterior_draws.npy', n_samples=1000):
    """Sample the trained posterior for every country and save the draws to disk.

    The per-country BayesFlow inputs produced by ``dataset_instance`` are
    concatenated along axis 0. They are then passed through
    ``trainer.configurator`` and finally through ``trainer.amortizer.sample``.

    Args:
        trainer: A trained BayesFlow ``Trainer``; its ``configurator`` and
            ``amortizer`` attributes are used.
        dataset_instance: Object providing ``country_codes()`` and
            ``to_bayesflow_input_dict_single(code)``.
        filename: Destination passed to ``np.save``. Missing parent
            directories are created first so a long training run does not
            fail at the very last step.
        n_samples: Number of posterior draws requested from the amortizer
            (defaults to 1000, the previously hard-coded value).

    Raises:
        ValueError: If ``dataset_instance.country_codes()`` yields no codes.
    """
    import os

    # NOTE(review): this switches TensorFlow to NumPy-style tensor behavior
    # for the whole process, not just for this function.
    tf.experimental.numpy.experimental_enable_numpy_behavior()

    per_country = [
        dataset_instance.to_bayesflow_input_dict_single(code)
        for code in dataset_instance.country_codes()
    ]
    if not per_country:
        # np.concatenate would otherwise raise an opaque "need at least one array" error.
        raise ValueError("No country codes available; nothing to sample.")
    # presumably each entry is array-like with a leading batch axis — TODO confirm
    stacked_inputs = np.concatenate(per_country, axis=0)
    real_data = tf.convert_to_tensor(stacked_inputs, dtype=tf.float32)

    forward_dict = {
        # Placeholder parameters shaped (1, 5) to match num_params=5 of the
        # inference network; presumably only "sim_data" matters when sampling.
        # NOTE(review): batch size 1 here vs. one row per country in sim_data —
        # confirm the configurator does not require matching batch sizes.
        "prior_draws": np.array([[0.0] * 5], dtype=np.float32),
        "sim_data": real_data,
    }
    input_dict_config = trainer.configurator(forward_dict)

    # Remove axis 1 from any rank-5 tensor the configurator returns; entries
    # that are None (unused conditions) are left untouched.
    # NOTE(review): why rank-5 tensors appear (and that axis 1 is a singleton)
    # is not visible here — confirm against Configurator.configure.
    for key, tensor in list(input_dict_config.items()):
        if tensor is not None and len(tensor.shape) == 5:
            input_dict_config[key] = tf.squeeze(tensor, axis=1)

    posterior_draws_all = trainer.amortizer.sample(input_dict_config, n_samples=n_samples)

    # Ensure the output directory exists (e.g. the CLI default "./results/...").
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(filename, posterior_draws_all)
    print(f"Posterior samples saved to '{filename}'")

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing
            a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes ``lr``, ``epochs``, ``batch_size``,
        ``simulation_budget``, ``num_obs``, ``num_c``, ``run_id`` and ``output``.
    """
    parser = argparse.ArgumentParser(
        description="Train the BayesFlow amortized posterior model for air traffic."
    )
    parser.add_argument("--lr", type=float, default=0.00005, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--simulation_budget", type=int, default=1024, help="Simulation budget")
    parser.add_argument("--num_obs", type=int, default=14, help="Number of observations")
    parser.add_argument("--num_c", type=int, default=8, help="Number of countries for training")
    parser.add_argument("--run_id", type=int, default=0, help="Run identifier")
    parser.add_argument(
        "--output",
        type=str,
        default="./results/posterior_draws_sc_8c_32.npy",
        help="Path to save the posterior draws",
    )
    return parser.parse_args(argv)

def main():
    """Train the self-consistent amortized posterior and save posterior draws.

    Steps: parse CLI options, build the generative model and offline training
    data, train ``AmortizedPosteriorSC`` through BayesFlow's ``Trainer``, then
    sample the posterior for the per-country data via ``sample_posterior``.
    """
    cli = parse_args()
    n_obs = cli.num_obs  # CLI default is 14

    # Fixed seed for the RNG handed to the generative-model builder.
    # NOTE(review): --run_id is parsed but never used in this script; confirm
    # whether it was meant to vary this seed or tag the output file.
    generator = np.random.default_rng(seed=1)
    generative_model, train_data, prior, simulator = build_generative_model(
        generator, n_obs, simulation_budget=cli.simulation_budget
    )

    observed = load_real_data(cli.num_c)
    dataset_instance = AirTrafficDataset2()

    amortizer = AmortizedPosteriorSC(
        prior=prior,
        simulator=simulator,
        real_data=observed,
        inference_net=inference_net,
        summary_net=summary_net,
        n_consistency_samples=32,
    )

    config = Configurator(num_obs=n_obs)
    trainer = Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=config.configure,
    )

    training_history = trainer.train_offline(
        train_data,
        epochs=cli.epochs,
        lr=cli.lr,
        batch_size=cli.batch_size,
    )
    print("Training SC complete!")

    sample_posterior(trainer, dataset_instance, filename=cli.output)


if __name__ == "__main__":
    main()
