import pandas as pd
import numpy as np
import tensorflow as tf

class AirTrafficDataset2:
    """Per-country yearly panel of air traffic, household debt and GDP.

    The three CSV files are joined on ``country`` and ``year`` and the result
    is stored in ``self.df``, sorted by country and then by year. The methods
    below read the columns ``country``, ``year``, ``air_traffic``, ``debt``
    and ``gdp`` from that frame.
    """

    # Country code -> display name for the countries known to this dataset.
    _COUNTRY_NAMES = {
        "AT": "Austria",
        "BE": "Belgium",
        "CZ": "Czechia",
        "DE": "Germany",
        "DK": "Denmark",
        "EL": "Greece",
        "ES": "Spain",
        "FI": "Finland",
        "FR": "France",
        "IE": "Ireland",
        "IT": "Italy",
        "NL": "Netherlands",
        "PL": "Poland",
        "PT": "Portugal",
        "SE": "Sweden",
    }

    def __init__(self):
        """Load the three CSV files and build the merged, sorted panel."""
        air_traffic_df = pd.read_csv("./dataset/air_traffic_EU_US.csv")
        debt_df = pd.read_csv("./dataset/household_debt_EU.csv")
        gdp_df = pd.read_csv("./dataset/gdp_EU.csv")

        # pd.merge defaults to an inner join, so only (country, year) pairs
        # present in all three files are kept.
        df = pd.merge(air_traffic_df, debt_df, on=["country", "year"])
        df = pd.merge(df, gdp_df, on=["country", "year"])
        df["year"] = pd.to_datetime(df["year"], format="%Y")

        # Sort by country and year and reset index
        self.df = df.sort_values(["country", "year"]).reset_index(drop=True)

    def country_codes(self):
        """Return the sorted list of distinct country codes in the panel."""
        return sorted(self.df["country"].unique())

    def country_code_to_name(self):
        """Returns a dictionary mapping country codes to country names."""
        # Hand out a fresh dict so callers cannot mutate the shared table.
        return dict(self._COUNTRY_NAMES)

    @staticmethod
    def _standardize(values):
        """Centre ``values`` on their mean and scale by their population std.

        A zero standard deviation (constant series or a single observation)
        would turn every entry into NaN via 0/0, so in that case the centred
        values (all zeros) are returned unscaled.
        """
        values = np.asarray(values)
        centred = values - np.mean(values)
        std = np.std(values)
        if std == 0:
            return centred
        return centred / std

    def get_country_data(self, code):
        """
        For a given country code, extract and process the time series.
        Returns a dictionary with keys: 'N', 'y', 'debt', 'gdp', and 'lag_y'.

        With d[t] the year-to-year difference in air_traffic:

        - 'y' is d[1:], i.e. the differences starting from the second one.
        - 'lag_y' is d[:-1], i.e. for every entry of 'y' the difference of the
          preceding year.
        - 'debt' and 'gdp' are the predictors aligned with 'y' (dropping the
          first two observations) and standardized within that country. A
          predictor with zero variance is returned as all zeros.
        - 'N' is the common length of the four series.

        Raises:
            ValueError: if the country has fewer than three observations.
                Differencing consumes one observation and the lag a second
                one, so three are needed for a single usable data point.
        """
        df_country = (
            self.df[self.df["country"] == code]
            .sort_values("year")
            .reset_index(drop=True)
        )

        y = df_country["air_traffic"].values
        if len(y) < 3:
            raise ValueError(f"Country {code} does not have enough data points.")

        y_diff = np.diff(y)
        y_target = y_diff[1:]
        lag_y = y_diff[:-1]

        # Skip the first two rows so the predictors line up with y_target.
        debt = self._standardize(df_country["debt"].values[2:])
        gdp = self._standardize(df_country["gdp"].values[2:])

        return {
            "N": len(y_target),
            "y": y_target.tolist(),
            "debt": debt.tolist(),
            "gdp": gdp.tolist(),
            "lag_y": lag_y.tolist(),
        }

    def to_bayesflow_input_dict_single(self, country_code):
        """Pack one country's series into a single array.

        Returns a NumPy array of shape (1, N, 4) whose last axis holds, in
        order, 'y', 'debt', 'gdp' and 'lag_y' from ``get_country_data``.
        """
        d = self.get_country_data(country_code)
        channels = [np.asarray(d[key]) for key in ("y", "debt", "gdp", "lag_y")]

        # (N, 4) after stacking; the leading axis of size 1 indexes the country.
        return np.stack(channels, axis=-1)[np.newaxis, ...]

def load_real_data(num_c=8):
    """Load the real dataset as a float32 tensor.

    Every country returned by ``AirTrafficDataset2.country_codes()`` is
    converted with ``to_bayesflow_input_dict_single`` (one array with a
    leading axis of size 1 per country), the arrays are joined along that
    leading axis, and the first ``num_c`` entries are kept.

    Note that all countries are processed before the slice is taken, and the
    concatenation requires every country's array to have the same shape
    beyond the leading axis.
    """
    dataset = AirTrafficDataset2()

    per_country = [
        dataset.to_bayesflow_input_dict_single(code)
        for code in dataset.country_codes()
    ]

    # One entry per country along axis 0, in the order of country_codes().
    stacked = np.concatenate(per_country, axis=0)

    # Keep only the leading num_c countries.
    selected = stacked[:num_c, :, :]
    return tf.convert_to_tensor(selected, dtype=tf.float32)
