import os

import cmdstanpy
import numpy as np
import pandas as pd
import tensorflow as tf
from cmdstanpy import CmdStanModel

from dataset import AirTrafficDataset2

# Fit the air-traffic Stan model separately for every country and save the
# posterior draws of the selected parameters as one stacked array.

stan_model = CmdStanModel(stan_file='air_traffic.stan')

dataset = AirTrafficDataset2()
country_codes = dataset.country_codes()

# Parameters extracted from each fit, in the column order of the saved array.
parameters = ["b0", "bt", "b1", "b2", "logsigma"]

# Make sure the output directory exists *before* the (expensive) sampling
# starts, so the run cannot fail at the very end on a missing directory.
output_path = './results/result_stan.npy'
os.makedirs(os.path.dirname(output_path), exist_ok=True)

results = []

for code in country_codes:
    data_dict = dataset.get_country_data(code)
    print(f"Running model for country: {code} with N = {data_dict['N']}")

    # Fixed seed so repeated runs use the same sampler seed for every country.
    fit = stan_model.sample(
        data=data_dict,
        seed=123,
        chains=4,
        iter_warmup=500,
        iter_sampling=1500,
    )

    # One row per draw, one column per entry of `parameters`.
    df_draws = fit.draws_pd()
    draws_array = np.column_stack([df_draws[param].values for param in parameters])
    draws_tf = tf.convert_to_tensor(draws_array, dtype=tf.float32)
    results.append(draws_tf)

if not results:
    # Without any fitted country there is nothing to stack or save.
    raise RuntimeError("No country codes returned by the dataset; nothing to save.")

# Stack once, after all countries have been fitted; the leading axis indexes
# countries in the order of `country_codes`. Stacking requires every country
# to yield the same number of draws (presumably chains * iter_sampling rows
# each — confirm against CmdStanPy's `draws_pd` defaults).
result_stan = tf.stack(results)
np.save(output_path, result_stan)
