import torch
import pickle
import math
import openai
import os
import execnet
import re

import pandas as pd
import numpy as np
import io, tokenize
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from datetime import datetime, timedelta
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.seasonal import seasonal_decompose
from typing import List, Dict, Any
from sklearn.linear_model import LinearRegression
from scipy.stats import linregress
from haversine import haversine, Unit
from torch_geometric.utils import dense_to_sparse

from models.gwnet_arch import GraphWaveNet
from models.gwnet_utils import *
from func_utils import *


def parse_step(step_str, partial=False):
    """Parse one program step written as ``OUT = STEP_NAME(key=value, ...)``.

    Args:
        step_str: Source text of a single step.
        partial: When true, skip argument parsing and return only the output
            variable and the step name.

    Returns:
        dict with ``'output_var'`` and ``'step_name'``; unless ``partial`` is
        set it also has ``'args'``, a dict mapping each keyword to its value.
        Surrounding quotes are stripped from values and values made only of
        decimal digits are converted to ``int``; everything else stays a
        string.

    Raises:
        IndexError: if ``step_str`` yields fewer than three tokens.
    """
    tokens = list(tokenize.generate_tokens(io.StringIO(step_str).readline))

    # Output variable: the text left of the first '='. If several names are
    # listed there, only the first one is kept.
    output_var = step_str.split('=')[0].strip().split(',')[0].strip()

    # Token layout is NAME '=' NAME '(' ..., so the step name is token #3.
    step_name = tokens[2].string

    parsed_result = {
        'output_var': output_var,
        'step_name': step_name
    }

    if partial:
        return parsed_result

    # Keep only the tokens that form key/value pairs. Layout tokens are
    # dropped by type because the tokenizer adds an implicit NEWLINE even when
    # the input has no trailing newline; slicing by a fixed count would leave
    # the closing ')' and that NEWLINE behind as a bogus "')': ''" argument.
    layout_types = (tokenize.NEWLINE, tokenize.NL, tokenize.ENDMARKER)
    arg_tokens = [
        token for token in tokens[4:]
        if token.type not in layout_types and token.string not in (',', '=')
    ]
    if arg_tokens and arg_tokens[-1].string == ')':
        arg_tokens.pop()

    args = {}
    # Tokens alternate key, value, key, value ...; a trailing unpaired token
    # is ignored.
    for key_token, value_token in zip(arg_tokens[::2], arg_tokens[1::2]):
        value = value_token.string.strip('"').strip("'")

        # isdecimal() matches exactly the characters int() can convert.
        if value.isdecimal():
            value = int(value)

        args[key_token.string] = value

    parsed_result['args'] = args
    return parsed_result


class LoadSpatiotemporalDataInterpreter:
    """Interpreter for the ``LOAD_SPATIOTEMPORAL_DATA`` program step.

    Reads the HDF5 dataset selected by ``feature``/``region``, slices the
    look-back window that ends at ``time`` and stores the result in the
    program state under the step's output variable.
    """

    step_name = 'LOAD_SPATIOTEMPORAL_DATA'

    # (feature, region) -> HDF5 file holding that dataset.
    _DATASET_PATHS = {
        ("traffic speed", "LA"): 'data/METR-LA/METR-LA.h5',
        ("traffic speed", "BAY"): 'data/PEMS-BAY/PEMS-BAY.h5',
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/beijing.h5',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/shenzhen.h5',
    }
    _SUPPORTED_UNITS = ('days', 'hours', 'minutes')
    _SUPPORTED_TASKS = ("analysis", "anomaly detection", "prediction")

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(location, time, feature, region, output_var, time_int,
            period, unit, task)``.

        Raises:
            KeyError: if a required argument is missing from the step.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        location = args['location']
        time = args['time']
        feature = args['feature']
        region = args['region']
        time_int = args['time_int']
        period = args['period']
        unit = args['unit']
        task = args['task']
        output_var = parse_result['output_var']

        return location, time, feature, region, output_var, time_int, period, unit, task

    @classmethod
    def _dataset_path(cls, feature, region):
        """Return the HDF5 path for ``feature``/``region`` or raise ValueError."""
        try:
            return cls._DATASET_PATHS[(feature, region)]
        except KeyError:
            raise ValueError(
                f"No dataset available for feature '{feature}' in region '{region}'"
            ) from None

    @classmethod
    def _time_window(cls, time, period, unit):
        """Return ``(start_time, end_time)`` for a window of ``period`` ``unit``s ending at ``time``."""
        end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        if unit not in cls._SUPPORTED_UNITS:
            raise ValueError(f"Unsupported time unit: {unit}")
        # The unit names are exactly timedelta's keyword names.
        start_time = end_time - timedelta(**{unit: period})
        return start_time, end_time

    def execute(self, prog_step, inspect=False):
        """Load the requested window and store it in ``prog_step.state``.

        For the "analysis" and "anomaly detection" tasks only the column of
        ``location`` is stored; for "prediction" every column of the window
        is stored.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the stored DataFrame and a one-line text summary.

        Raises:
            ValueError: for an unsupported task, time unit, or
                feature/region combination.
        """
        location, time, feature, region, output_var, time_int, period, unit, task = self.parse(prog_step)

        # Fail early with a clear message instead of hitting an unbound
        # variable or a missing state entry further down.
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(f"Unsupported task: {task}")

        data = pd.read_hdf(self._dataset_path(feature, region))
        start_time, end_time = self._time_window(time, period, unit)

        # Restrict to the time range; column labels are compared as strings.
        time_df = data.loc[start_time:end_time]
        time_df.columns = time_df.columns.astype(str)
        # Selected for every task, so an unknown location always raises.
        st_df = time_df[[str(location)]]

        if task == "prediction":
            prog_step.state[output_var] = time_df
        else:
            prog_step.state[output_var] = st_df

        return prog_step.state[output_var], self.text_summary([task, location, start_time, end_time, feature], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Describe what was loaded.

        Args:
            inputs: ``[task, location, start_time, end_time, feature]``.
            output: The loaded data (not used in the message).

        Returns:
            A one-line description, or ``None`` for an unrecognised task.
        """
        task, location, start_time, end_time, feature = inputs
        if task == "analysis" or task == "anomaly detection":
            return f"Loaded data for Location: {location}, Feature: {feature}, Time Range: From {start_time} to {end_time}."
        if task == "prediction":
            return f"Loaded data for Feature: {feature}, Time Range: From {start_time} to {end_time}."
        return None
            
class LoadSpatialAuxDataInterpreter:
    """Interpreter for the ``LOAD_SPATIAL_AUX_DATA`` program step.

    Loads the same dataset as the main series but keeps the columns of the
    locations returned by ``get_neighbors`` for the target location.
    """

    step_name = 'LOAD_SPATIAL_AUX_DATA'

    # (feature, region) -> HDF5 file holding that dataset.
    _DATASET_PATHS = {
        ("traffic speed", "LA"): 'data/METR-LA/METR-LA.h5',
        ("traffic speed", "BAY"): 'data/PEMS-BAY/PEMS-BAY.h5',
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/beijing.h5',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/shenzhen.h5',
    }
    # Minutes contained in one unit of each supported look-back unit.
    _MINUTES_PER_UNIT = {'days': 24 * 60, 'hours': 60, 'minutes': 1}

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(spatial_var, location, time, feature, region,
            output_var, time_int, period, unit, constraint)``.

        Raises:
            KeyError: if a required argument is missing from the step.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        spatial_var = args['spatial_var']
        location = args['location']
        time = args['time']
        feature = args['feature']
        region = args['region']
        time_int = args['time_int']
        period = args['period']
        unit = args['unit']
        constraint = args['constraint']
        output_var = parse_result['output_var']

        return spatial_var, location, time, feature, region, output_var, time_int, period, unit, constraint

    @classmethod
    def _dataset_path(cls, feature, region):
        """Return the HDF5 path for ``feature``/``region`` or raise ValueError."""
        try:
            return cls._DATASET_PATHS[(feature, region)]
        except KeyError:
            raise ValueError(
                f"Dataset not found for feature '{feature}' in region '{region}'"
            ) from None

    @classmethod
    def _time_window(cls, time, period, unit, time_int):
        """Return ``(start_time, end_time, steps)`` for the look-back window.

        ``steps`` is the number of ``time_int``-minute samples that fit in
        ``period`` ``unit``s, truncated towards zero.
        """
        end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        if unit not in cls._MINUTES_PER_UNIT:
            raise ValueError(f"Unsupported time unit: {unit}")
        # The unit names are exactly timedelta's keyword names.
        start_time = end_time - timedelta(**{unit: period})
        steps = int((period * cls._MINUTES_PER_UNIT[unit]) / time_int)
        return start_time, end_time, steps

    def execute(self, prog_step, inspect=False):
        """Load neighbour data for the window and store its last ``steps`` rows.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the stored DataFrame and a one-line text summary.

        Raises:
            ValueError: for an unsupported ``spatial_var``, time unit, or
                feature/region combination.
        """
        spatial_var, location, time, feature, region, output_var, time_int, period, unit, constraint = self.parse(prog_step)

        # "neighbour" is the only spatial variable this step knows how to
        # build; anything else used to fail later on an unbound variable.
        if spatial_var != "neighbour":
            raise ValueError(f"Unsupported spatial auxiliary variable: {spatial_var}")

        data = pd.read_hdf(self._dataset_path(feature, region))
        start_time, end_time, steps = self._time_window(time, period, unit, time_int)

        # Restrict to the time range; column labels are compared as strings.
        time_df = data.loc[start_time:end_time]
        time_df.columns = time_df.columns.astype(str)

        # NOTE(review): get_neighbors is defined outside this file; it is
        # assumed to return an iterable of location ids — confirm.
        neighbors = get_neighbors(feature, location, region, threshold=0.6)
        st_df = time_df[[str(num) for num in neighbors]]

        if constraint == "weekdays only":
            st_df = st_df[st_df.index.dayofweek < 5]  # Monday-Friday
        elif constraint == "weekends only":
            st_df = st_df[st_df.index.dayofweek >= 5]  # Saturday-Sunday

        prog_step.state[output_var] = st_df[-steps:]

        return prog_step.state[output_var], self.text_summary([spatial_var, location, start_time, end_time, feature], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Describe what was loaded.

        Args:
            inputs: ``[spatial_var, location, start_time, end_time, feature]``.
            output: The loaded data (not used in the message).
        """
        spatial_var, location, start_time, end_time, feature = inputs
        return f"Loaded {spatial_var} data as spatial auxiliary data for Location: {location}, Feature: {feature}, Time Range: From {start_time} to {end_time}."

class LoadTemporalAuxDataInterpreter:
    """Interpreter for the ``LOAD_TEMPORAL_AUX_DATA`` program step.

    Loads weather series (humidity, pressure, temperature, wind speed) for
    one location over the look-back window and stores them side by side as
    one DataFrame in the program state.
    """

    step_name = 'LOAD_TEMPORAL_AUX_DATA'

    # (feature, region) -> path template; '{}' is replaced by the file stem
    # of each weather variable.
    _WEATHER_FILE_TEMPLATES = {
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/weather/beijing_{}.h5',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/weather/shenzhen_{}.h5',
    }
    # (file stem, output column name), in output column order.
    _WEATHER_VARIABLES = (
        ('humidity', 'Humidity'),
        ('pressure', 'Pressure'),
        ('temperature', 'Temperature'),
        ('ws', 'Wind Speed'),
    )
    # Minutes contained in one unit of each supported look-back unit.
    _MINUTES_PER_UNIT = {'days': 24 * 60, 'hours': 60, 'minutes': 1}

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(temp_var, location, time, feature, region, output_var,
            time_int, period, unit, constraint)``.

        Raises:
            KeyError: if a required argument is missing from the step.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        temp_var = args['temp_var']
        location = args['location']
        time = args['time']
        feature = args['feature']
        region = args['region']
        time_int = args['time_int']
        period = args['period']
        unit = args['unit']
        constraint = args['constraint']
        output_var = parse_result['output_var']

        return temp_var, location, time, feature, region, output_var, time_int, period, unit, constraint

    @classmethod
    def _weather_sources(cls, feature, region):
        """Return ``[(column_name, hdf5_path), ...]`` for ``feature``/``region``.

        Raises:
            ValueError: if no weather data is registered for the combination.
        """
        try:
            template = cls._WEATHER_FILE_TEMPLATES[(feature, region)]
        except KeyError:
            raise ValueError(
                f"No weather data available for feature '{feature}' in region '{region}'"
            ) from None
        return [(label, template.format(stem)) for stem, label in cls._WEATHER_VARIABLES]

    @classmethod
    def _time_window(cls, time, period, unit, time_int):
        """Return ``(start_time, end_time, steps)`` for the look-back window.

        ``steps`` is the number of ``time_int``-minute samples that fit in
        ``period`` ``unit``s, truncated towards zero.
        """
        end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        if unit not in cls._MINUTES_PER_UNIT:
            raise ValueError(f"Unsupported time unit: {unit}")
        # The unit names are exactly timedelta's keyword names.
        start_time = end_time - timedelta(**{unit: period})
        steps = int((period * cls._MINUTES_PER_UNIT[unit]) / time_int)
        return start_time, end_time, steps

    def execute(self, prog_step, inspect=False):
        """Load the weather window and store its last ``steps`` rows.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the stored DataFrame and a one-line text summary.

        Raises:
            ValueError: for an unsupported ``temp_var``, time unit, or
                feature/region combination.
        """
        temp_var, location, time, feature, region, output_var, time_int, period, unit, constraint = self.parse(prog_step)

        # "weather" is the only temporal variable handled; other values used
        # to fall through and fail on unbound names at the return statement.
        if temp_var != "weather":
            raise ValueError(f"Unsupported temporal auxiliary variable: {temp_var}")

        sources = self._weather_sources(feature, region)
        start_time, end_time, steps = self._time_window(time, period, unit, time_int)

        processed_data = []
        for name, path in sources:
            weather_df = pd.read_hdf(path)
            time_df = weather_df.loc[start_time:end_time]
            time_df.columns = time_df.columns.astype(str)
            st_df = time_df[[str(location)]]
            if constraint == "weekdays only":
                st_df = st_df[st_df.index.dayofweek < 5]  # Monday-Friday
            elif constraint == "weekends only":
                st_df = st_df[st_df.index.dayofweek >= 5]  # Saturday-Sunday

            # Name the single column after the weather variable so the
            # concatenated frame has one labelled column per variable.
            st_df.columns = [name]
            processed_data.append(st_df)

        weather_data = pd.concat(processed_data, axis=1)

        prog_step.state[output_var] = weather_data[-steps:]

        return prog_step.state[output_var], self.text_summary([temp_var, location, start_time, end_time], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Describe what was loaded.

        Args:
            inputs: ``[temp_var, location, start_time, end_time]``.
            output: The loaded data (not used in the message).
        """
        temp_var, location, start_time, end_time = inputs
        return f"Loaded {temp_var} data as temporal auxiliary data for Location: {location}, Time Range: From {start_time} to {end_time}."


class ImposeConstraintsInterpreter:
    """Interpreter for the ``IMPOSE_CONSTRAINTS`` program step.

    For analysis / anomaly detection it filters previously loaded data by
    day of week and keeps the trailing samples; for prediction it caps the
    predicted values at ``constraint_val``.
    """

    step_name = 'IMPOSE_CONSTRAINTS'

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(data, time, time_int, period, unit, task, constraint,
            constraint_val, preds, output_var)``.
        """
        parsed = parse_step(prog_step.prog_str)
        arg_names = (
            'data', 'time', 'time_int', 'period', 'unit',
            'task', 'constraint', 'constraint_val', 'preds',
        )
        values = [parsed['args'][name] for name in arg_names]
        values.append(parsed['output_var'])
        return tuple(values)

    def execute(self, prog_step, inspect=False):
        """Apply the constraint and store the result in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the stored value and a one-line text summary.

        Raises:
            ValueError: if ``unit`` is not 'days', 'hours' or 'minutes'
                (analysis / anomaly detection only).
        """
        (data, time, time_int, period, unit, task,
         constraint, constraint_val, preds, output_var) = self.parse(prog_step)

        if task in ("analysis", "anomaly detection"):
            # Window end; the start is derived from the look-back period.
            end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')

            # Minutes per unit; the keys double as timedelta keyword names.
            unit_minutes = {'days': 24 * 60, 'hours': 60, 'minutes': 1}
            if unit not in unit_minutes:
                raise ValueError(f"Unsupported time unit: {unit}")
            start_time = end_time - timedelta(**{unit: period})
            steps = int((period * unit_minutes[unit]) / time_int)

            frame = prog_step.state[data]

            if constraint == "weekdays only":
                frame = frame[frame.index.dayofweek < 5]  # Monday-Friday
            elif constraint == "weekends only":
                frame = frame[frame.index.dayofweek >= 5]  # Saturday-Sunday

            # Keep only the trailing ``steps`` samples.
            prog_step.state[output_var] = frame[-steps:]

        elif task == "prediction":
            # Cap every prediction that exceeds the constraint value.
            forecast = prog_step.state[preds]
            capped = np.where(forecast > constraint_val, constraint_val, forecast)
            prog_step.state[output_var] = capped

        result = prog_step.state[output_var]
        summary = self.text_summary([task, constraint, constraint_val], prog_step.state[output_var])
        return result, summary

    def text_summary(self, inputs: List, output: Any):
        """Describe the constraint that was applied.

        Args:
            inputs: ``[task, constraint, constraint_val]``.
            output: The constrained data (not used in the message).

        Returns:
            A one-line description, or ``None`` for an unrecognised task.
        """
        task, constraint, constraint_val = inputs
        if task == "prediction":
            return f"Imposed constraints considering {constraint} of {constraint_val}."
        if task in ("analysis", "anomaly detection"):
            if constraint == "None":
                return f"No data retrieval constraints available imposed."
            return f"Imposed constraints and retrieved data from {constraint}."
        return None
      
class STTrendInterpreter:
    """Interpreter for the ``ANALYZE_TREND`` program step.

    Fits a straight line to the series stored in the program state, saves a
    plot of the data with that line, and stores a textual summary of the
    slope, p-value and standard error.
    """

    step_name = 'ANALYZE_TREND'

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(data, feature, location, time_int, constraint,
            output_var)``.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        data = args['data']
        feature = args['feature']
        location = args['location']
        time_int = args['time_int']
        constraint = args['constraint']
        output_var = parse_result['output_var']
        return data, feature, location, time_int, constraint, output_var

    def execute(self, prog_step, inspect=False):
        """Run the trend analysis and store its summary in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping; ``state[data]`` must be a DataFrame.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the trend summary and a short status message that
            includes the plot location.
        """
        data, feature, location, time_int, constraint, output_var = self.parse(prog_step)
        # All values of the stored frame, flattened, on a 0..n-1 index.
        series = pd.Series(prog_step.state[data].values.flatten())

        # Clears the frequency attribute on the stored frame's index (this
        # mutates the object held in the program state).
        prog_step.state[data].index.freq = None
        datetime_index = pd.to_datetime(prog_step.state[data].index)

        slope, pval, stderr, plot_path = self.analyze_trend(series, datetime_index, time_int, feature, location, constraint)

        trend_summary = self.generate_summary(slope, pval, stderr)

        prog_step.state[output_var] = trend_summary
        return trend_summary, self.text_summary([plot_path], trend_summary)

    def analyze_trend(self, series, datetime_index, time_int, feature, location, constraint):
        """Fit a linear trend to ``series`` and plot it.

        The fit uses the sample position (0, 1, 2, ...) as the regressor.
        ``time_int`` and ``constraint`` are accepted for interface
        compatibility but do not influence the result: the seasonal
        decompositions that used them were never read, and they made the
        analysis fail on series shorter than two full seasonal cycles.

        Returns:
            Tuple ``(slope, pval, stderr, plot_path)``.
        """
        # Trend line using linear regression on the sample position.
        X = np.arange(len(series)).reshape(-1, 1)
        y = series.values
        model = LinearRegression().fit(X, y)
        trend_line = model.predict(X)
        slope = model.coef_[0]

        # p-value and standard error of the slope (linregress needs 1-D x).
        slope_stats = linregress(X.flatten(), y)
        pval = slope_stats.pvalue
        stderr = slope_stats.stderr

        plot_path = self.gen_trend_plot(series, datetime_index, trend_line, feature, location)

        return slope, pval, stderr, plot_path

    def gen_trend_plot(self, series, datetime_index, trend_line, feature, location):
        """Save a plot of ``series`` with ``trend_line`` and return its path.

        The x axis is the series' own index; ``datetime_index`` is accepted
        but not used. The output directory is created when missing.
        """
        output_dir = "./Visualizations"
        # savefig does not create missing directories.
        os.makedirs(output_dir, exist_ok=True)
        plot_path = os.path.join(output_dir, f"{location}_{feature}_trend_plot.png")

        plt.figure(figsize=(14, 7))
        try:
            plt.plot(series, label='Data')
            plt.plot(series.index, trend_line, color='green', label='Trend')
            plt.title(f"Visualization of {feature} at location {location}")
            plt.legend()
            plt.savefig(plot_path)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close()

        return plot_path

    def generate_summary(self, slope, pval, stderr):
        """Format the regression results as a multi-line text summary."""
        summary = (
            f"Trend Analysis:\n"
            f"-Method: Linear regression on time series data.\n"
            f"-Results: Slope: {slope:.5f}, pvalue: {pval:.5f}, Standard Error: {stderr:.5f}.\n"
        )

        return summary

    def text_summary(self, inputs: List, output: Any):
        """Return a short status message naming the saved plot.

        Args:
            inputs: ``[plot_path]``; a bare path is also accepted.
            output: The trend summary (not used in the message).
        """
        # Unwrap the single-element input list so the message shows the path
        # itself rather than the list's repr.
        plot_path = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
        return (
            f"Trend Analysis Conducted.\n"
            f"Trend plot saved at {plot_path}."
        )

class STSeasonalityInterpreter:
    """Interpreter for the ``ANALYZE_SEASONALITY`` program step.

    Decomposes the series stored in the program state with a daily and a
    weekly period and reports, for each, the variance of the seasonal
    component divided by the variance of the series.
    """

    step_name = 'ANALYZE_SEASONALITY'

    # Number of days that make up one "week" of data under each constraint.
    _DAYS_PER_CYCLE = {"None": 7, "weekdays only": 5, "weekends only": 2}

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(data, time_int, constraint, output_var)``.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        data = args['data']
        time_int = args['time_int']
        constraint = args['constraint']
        output_var = parse_result['output_var']
        return data, time_int, constraint, output_var

    def execute(self, prog_step, inspect=False):
        """Run the seasonality analysis and store its summary in the state.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping; ``state[data]`` must be a DataFrame.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the seasonality summary and a short status message.
        """
        data, time_int, constraint, output_var = self.parse(prog_step)
        # All values of the stored frame, flattened, on a 0..n-1 index.
        series = pd.Series(prog_step.state[data].values.flatten())

        # Clears the frequency attribute on the stored frame's index (this
        # mutates the object held in the program state).
        prog_step.state[data].index.freq = None

        daily_seasonality_strength, weekly_seasonality_strength = self.analyze_seasonality(series, time_int, constraint)

        seasonality_summary = self.generate_summary(daily_seasonality_strength, weekly_seasonality_strength)

        prog_step.state[output_var] = seasonality_summary
        return seasonality_summary, self.text_summary([], seasonality_summary)

    @classmethod
    def _seasonal_periods(cls, time_int, constraint):
        """Return ``(daily_period, weekly_period)`` in number of samples.

        Args:
            time_int: Sampling interval in minutes; must be positive and
                divide a day (1440 minutes) evenly.
            constraint: "None", "weekdays only" or "weekends only".

        Raises:
            ValueError: for an unsupported interval or constraint.
        """
        minutes_per_day = 24 * 60
        try:
            divides_day = time_int > 0 and minutes_per_day % time_int == 0
        except TypeError:
            divides_day = False
        if not divides_day:
            raise ValueError(f"Unsupported time interval: {time_int}")
        if constraint not in cls._DAYS_PER_CYCLE:
            raise ValueError(f"Unsupported constraint: {constraint}")

        # e.g. 5-minute data -> 288 samples per day, 60-minute data -> 24.
        daily_period = int(minutes_per_day // time_int)
        weekly_period = daily_period * cls._DAYS_PER_CYCLE[constraint]
        return daily_period, weekly_period

    def analyze_seasonality(self, series, time_int, constraint):
        """Measure daily and weekly seasonality strength of ``series``.

        Strength is the variance of the additive seasonal component divided
        by the variance of the whole series.

        Returns:
            Tuple ``(daily_seasonality_strength, weekly_seasonality_strength)``.

        Raises:
            ValueError: for an unsupported ``time_int`` or ``constraint``.
        """
        daily_period, weekly_period = self._seasonal_periods(time_int, constraint)

        decomposition_daily = seasonal_decompose(series, model='additive', period=daily_period)
        decomposition_weekly = seasonal_decompose(series, model='additive', period=weekly_period)

        seasonality_daily = decomposition_daily.seasonal
        seasonality_weekly = decomposition_weekly.seasonal

        daily_seasonal_var = np.var(seasonality_daily)
        weekly_seasonal_var = np.var(seasonality_weekly)
        total_var = np.var(series)

        daily_seasonality_strength = daily_seasonal_var / total_var
        weekly_seasonality_strength = weekly_seasonal_var / total_var

        return daily_seasonality_strength, weekly_seasonality_strength

    def generate_summary(self, daily_seasonality_strength, weekly_seasonality_strength):
        """Format both seasonality strengths as a multi-line text summary."""
        summary = (
            f"Seasonality Analysis:\n"
            f"- Method: Perform time series decomposition to identify daily and weekly seasonality. Get the ratio of the variance of the seasonal component to the total variance to calculate the strength of each seasonality.\n"
            f"- Results: Daily Seasonality Strength:{daily_seasonality_strength:.5f}, Weekly Seasonality Strength:{weekly_seasonality_strength:.5f},\n"
        )
        return summary

    def text_summary(self, inputs: List, output: Any):
        """Return a fixed status message; ``inputs`` and ``output`` are unused."""
        return (
            f"Seasonality Analysis Conducted."
        )


class STNeighbourhoodInterpreter:
    """Interpreter for the ``ANALYZE_NEIGHBOURHOOD`` program step.

    Looks up the neighbours of a location via ``get_neighbors`` and stores a
    textual summary of them in the program state.
    """

    step_name = 'ANALYZE_NEIGHBOURHOOD'

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(feature, location, region, output_var)``.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        feature = args['feature']
        location = args['location']
        region = args['region']
        output_var = parse_result['output_var']
        return feature, location, region, output_var

    def execute(self, prog_step, inspect=False):
        """Run the neighbourhood lookup and store its summary in the state.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the neighbourhood summary and a short status message.
        """
        feature, location, region, output_var = self.parse(prog_step)

        neighbors = self.analyze_neighbourhood(feature, location, region)

        neighbourhood_summary = self.generate_summary(neighbors)

        prog_step.state[output_var] = neighbourhood_summary
        return neighbourhood_summary, self.text_summary([neighbors], neighbourhood_summary)

    def analyze_neighbourhood(self, feature, location, region):
        """Return whatever ``get_neighbors`` reports for ``location``.

        NOTE(review): ``get_neighbors`` is defined outside this file; the
        meaning of ``threshold=0.6`` should be confirmed against it.
        """
        neighbors = get_neighbors(feature, location, region, threshold=0.6)

        return neighbors

    def generate_summary(self, neighbors):
        """Format the neighbour list as a multi-line text summary."""
        summary = (
            f"Spatial Analysis:\n"
            f"- Neighbours identified based on weighted adjacency matrix : {neighbors}.\n"

        )
        return summary

    def text_summary(self, inputs: List, output: Any):
        """Generate a short status message listing the neighbours.

        Args:
            inputs: ``[neighbors]`` — a one-element list holding the
                neighbours collection.
            output: The neighbourhood summary (not used in the message).
        """
        # Unpack the single input so the message shows the neighbours
        # themselves rather than a list nested inside another list.
        (neighbors,) = inputs
        return (
            f"Neighbourhood Analysis Conducted.\n"
            f"Neighbouring locations detected:{neighbors}."
        )

class ExplainInterpreter:
    """Interpreter for the ``GEN_EXPLAINATION`` program step.

    Collects the results of earlier steps from the program state, sends them
    to the OpenAI chat completion API and stores the returned natural
    language explanation in the program state.
    """

    step_name = 'GEN_EXPLAINATION'

    _ANALYSIS_TASKS = ("analysis", "anomaly detection")

    # Unit reported to the language model for each supported feature.
    _FEATURE_UNITS = {
        "traffic speed": "Kilometres per hour",
        "air quality": "Micrograms per cubic metre",
    }

    def parse(self, prog_step):
        """Extract this step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(task, data, feature, location, region, time_int,
            horizon, horizon_unit, constraint, constraint_val, trend,
            seasonality, neighbourhood, preds, sensitivity, output_var)``.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        task = args['task']
        data = args['data']
        feature = args['feature']
        location = args['location']
        region = args['region']
        time_int = args['time_int']
        horizon = args['horizon']
        horizon_unit = args['horizon_unit']
        constraint = args['constraint']
        constraint_val = args['constraint_val']
        trend = args['trend']
        seasonality = args['seasonality']
        neighbourhood = args['neighbourhood']
        preds = args['preds']
        sensitivity = args['sensitivity']
        output_var = parse_result['output_var']

        return task, data, feature, location, region, time_int, horizon, horizon_unit, constraint, constraint_val, trend, seasonality, neighbourhood, preds, sensitivity, output_var

    @staticmethod
    def _interval_label(time_int):
        """Return a human-readable label for a sampling interval in minutes."""
        if time_int == 5:
            return "5 minutes"
        if time_int == 60:
            return "1 hour"
        # Other intervals are described in minutes instead of leaving the
        # label undefined.
        return f"{time_int} minutes"

    @classmethod
    def _unit_label(cls, feature):
        """Return the measurement unit for ``feature`` or raise ValueError."""
        try:
            return cls._FEATURE_UNITS[feature]
        except KeyError:
            raise ValueError(f"Unsupported feature: {feature}") from None

    def execute(self, prog_step, inspect=False):
        """Generate the final explanation and store it in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a mutable
                ``state`` mapping.
            inspect: Unused by this interpreter.

        Returns:
            Tuple of the explanation text and a short status message.

        Raises:
            ValueError: for an unsupported task, or (analysis / anomaly
                detection only) an unsupported feature.
        """
        task, data, feature, location, region, time_int, horizon, horizon_unit, constraint, constraint_val, trend, seasonality, neighbourhood, preds, sensitivity, output_var = self.parse(prog_step)

        if task in self._ANALYSIS_TASKS:
            explanation = self._explain_analysis(
                prog_step, data, feature, time_int, constraint,
                trend, seasonality, neighbourhood,
            )
        elif task == "prediction":
            explanation = self._explain_prediction(
                prog_step, feature, location, region, time_int, horizon,
                horizon_unit, constraint, constraint_val, preds, sensitivity,
            )
        else:
            # Previously this fell through to an unbound ``explanation``.
            raise ValueError(f"Unsupported task: {task}")

        prog_step.state[output_var] = explanation

        return explanation, self.text_summary([feature], explanation)

    def _explain_analysis(self, prog_step, data, feature, time_int, constraint, trend, seasonality, neighbourhood):
        """Build the analysis findings text and have the model refine it."""
        series = pd.Series(prog_step.state[data].values.flatten())

        time_interval = self._interval_label(time_int)
        unit = self._unit_label(feature)

        # Any constraint other than the two day-of-week filters means the
        # data covers every day.
        if constraint in ("weekdays only", "weekends only"):
            time_cons = constraint
        else:
            time_cons = "everyday"

        explanation = (
            f"Original data: {series}.\n"
            f"- Unit of data: {unit}.\n"
            f"- Data corresponds to {time_cons}.\n"
        )

        # Each of these arguments is either the literal string "None" or the
        # name of a state variable holding an earlier step's summary.
        for summary_var in (trend, seasonality, neighbourhood):
            if summary_var != "None":
                explanation += f"{prog_step.state[summary_var]}\n"

        return self.refine_explanation_with_gpt(feature, time_interval, time_cons, explanation)

    def _explain_prediction(self, prog_step, feature, location, region, time_int, horizon, horizon_unit, constraint, constraint_val, preds, sensitivity):
        """Ask the model to interpret the predictions stored in the state."""
        predictions = prog_step.state[preds]
        exceed_indices = predictions > constraint_val
        if exceed_indices.any():
            warning_text = f'Warning: Some predictions exceeded the {constraint} of {constraint_val}.'
        else:
            # All predictions are within the acceptable range
            warning_text = f'All predictions are within the {constraint}'

        neighbors = get_neighbors(feature, location, region, threshold=0.6)
        spatiotemporal_sensitivity = prog_step.state[sensitivity]

        # Every bullet line ends with "\n" so the instructions reach the
        # model as a list rather than as one run-on line.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": (
                    f"As an expert in spatiotemporal forecasting, you conducted an analysis and summarized the prediction results for {feature} data at location {location} for the next {horizon}{horizon_unit} as follows. The data is recorded in {time_int}-minute intervals.\n"
                    f"Predictions: {predictions}\n"
                    f"{warning_text}.\n"
                    f"Neighbouring Nodes: {neighbors}\n"
                    f"{spatiotemporal_sensitivity}\n"
                    f"Based on this information, provide a detailed interpretation of the predictions. Begin by stating the Predictions: the predicted {feature} values for the next {horizon}{horizon_unit} and include:\n"
                    f"- Predictions are made using the deep learning based GraphWaveNet Model\n"
                    f"- Constraint Adherence: State if the predictions adhere to the {constraint}\n"
                    f"- Temporal Features: Based on the Timestamp Sensitivities, analyze how distinct daily traffic patterns like morning/evening rush hours, weekends, late-nights or mid day patterns influence the {feature} predictions. Do not mention the maginitude of the impacts.\n"
                    f"- Spatial Features: Based on the Significant Nodes, discuss how data from other sensors influence the forecast. State any Neighbouring Nodes among the Significant Nodes. Do not mention the maginitude of the impacts.\n"
                    f"Formulate your response as if directly answering the query about {feature} prediction, constraints adherence, and the influences of temporal and spatial factors on the predictions, ensuring to explicitly mention the specific nodes and timestamps that have the most significant impact."
                )}],
            temperature=0.7,
            max_tokens=4096,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0)

        return response['choices'][0]['message']['content'].strip()

    def refine_explanation_with_gpt(self, feature, time_interval, time_cons, explanation):
        """Send the collected findings to the chat API and return its reply.

        Args:
            feature: Name of the analysed feature.
            time_interval: Human-readable sampling interval.
            time_cons: Which days the data covers.
            explanation: Findings text appended to the prompt.

        Returns:
            The stripped content of the first returned choice.
        """
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # or "gpt-4" if available
            messages=[
                {"role": "user", "content": f"You are an expert in spatio temporal data analysis. Given below is the time series of {feature} data recorded at every {time_interval}, on {time_cons} along with a summary of numerical findings. Start by giving a comprehensive analysis of the original data. Then, proceed to explain the provided summaries that can be understood by anyone. It is crucial that the final explanation clearly relate the interpretations to the data collection constraints (weekdays, weekends, or every day).\n{explanation}"}
            ],
            temperature=0.7,
            max_tokens=4096,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0
        )
        return response['choices'][0]['message']['content'].strip()

    def text_summary(self, inputs: List, output: Any):
        """Return a fixed status message; ``inputs`` and ``output`` are unused."""
        return (
            f"Final Explanation Generated.\n"
        )

class STAnomalyInterpreter:
    """Interpreter for the ``DETECT_ANOMALY_ST_DATA`` program step.

    Detects anomalies in one location's series, optionally looks for
    anomalies that coincide with those of neighbouring locations and of
    auxiliary (weather) variables, and asks the OpenAI chat API to turn the
    findings into a textual explanation.
    """
    step_name = 'DETECT_ANOMALY_ST_DATA'

    def parse(self, prog_step):
        """Parse ``prog_step.prog_str`` and return the step arguments.

        Returns a tuple ``(data, spatial_aux_data, temp_aux_data,
        temp_reasoning, spatial_reasoning, location, feature, time_int,
        constraint, output_var)``. Raises ``KeyError`` if an argument is
        missing from the step string.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        data = args['data']
        spatial_aux_data = args['spatial_aux_data']
        temp_aux_data = args['temp_aux_data']
        location = args['location']
        temp_reasoning = args['temp_reasoning']
        spatial_reasoning = args['spatial_reasoning']
        feature = args['feature']
        time_int = args['time_int']
        constraint = args['constraint']
        output_var = parse_result['output_var']
        return data, spatial_aux_data, temp_aux_data, temp_reasoning, spatial_reasoning, location, feature, time_int, constraint, output_var

    def detect_anomalies(self, series):
        """Flag points whose residual from a linear trend is an outlier.

        A straight line is fitted against the positional index of ``series``;
        a point is anomalous when its residual lies more than three standard
        deviations away from the mean residual.

        Returns ``(anomalous_values, anomalous_index)`` where the second item
        is a ``pd.Index`` of the labels of the anomalous points.
        """
        X = np.arange(len(series)).reshape(-1, 1)
        y = series.values
        model = LinearRegression().fit(X, y)
        trend_line = model.predict(X)
        residuals = y - trend_line
        std_dev = np.std(residuals)
        mean = np.mean(residuals)
        anomalies = (residuals > mean + 3 * std_dev) | (residuals < mean - 3 * std_dev)
        return series[anomalies], pd.Index(series.index[anomalies])  # Return as pd.Index

    def generate_explanation(self, anomalies_dict, source):
        """Describe the common anomalies found for ``source`` as text.

        ``anomalies_dict`` maps a variable name to the index labels it shares
        with the main location; an empty mapping yields a one-line "none
        detected" message.
        """
        if not anomalies_dict:
            return f"- No common anomalies detected in other {source}.\n"
        explanation = (
            f"Common anomalies refer to unusual or outlier conditions that occur simultaneously in multiple datasets or variables. "
            f"These are significant as they can indicate broader issues or influences affecting the system. For {source}:\n"
        )
        for var, indices in anomalies_dict.items():
            explanation += f"  - {var}: Common anomalies at indices {list(indices)}.\n"
        return explanation

    def _find_common_anomalies(self, anomaly_dates, anomalies_by_name):
        """Return ``{name: shared_index}`` for entries overlapping ``anomaly_dates``.

        Each result is stored under its own name. (The previous inline loops
        reused a stale loop variable as the key, so all hits collapsed onto
        the last column.)
        """
        common = {}
        for name, indices in anomalies_by_name.items():
            shared = anomaly_dates.intersection(indices)
            if not shared.empty:
                common[name] = shared
        return common

    def execute(self, prog_step, inspect=False):
        """Run anomaly detection and store the generated explanation.

        Reads the input frames from ``prog_step.state``, writes the
        explanation text to ``prog_step.state[output_var]`` and returns
        ``(explanation, summary)``.
        """
        (data, spatial_aux_data, temp_aux_data, temp_reasoning, spatial_reasoning,
         location, feature, time_int, constraint, output_var) = self.parse(prog_step)

        loc_data = prog_step.state[data]
        loc_data.index.freq = None
        # Detect anomalies of the main location
        anomaly_vals, anomaly_dates = self.detect_anomalies(loc_data[str(location)])

        neighbor_common_anomalies = {}
        weather_common_anomalies = {}

        # The reasoning flags arrive from the parser as the strings "True"/"False".
        if spatial_reasoning == "True":
            neighbor_df = prog_step.state[spatial_aux_data]
            neighbor_df.index.freq = None
            # Compare as strings: the parser may hand back ``location`` as an
            # int while the column labels are strings (or the reverse).
            neighbor_anomalies = {
                column: self.detect_anomalies(neighbor_df[column])[1]
                for column in neighbor_df.columns
                if str(column) != str(location)
            }
            neighbor_common_anomalies = self._find_common_anomalies(anomaly_dates, neighbor_anomalies)

        if temp_reasoning == "True":
            weather_df = prog_step.state[temp_aux_data]
            weather_df.index.freq = None
            weather_anomalies = {
                column: self.detect_anomalies(weather_df[column])[1]
                for column in weather_df.columns
            }
            weather_common_anomalies = self._find_common_anomalies(anomaly_dates, weather_anomalies)

        neigh_explanation = self.generate_explanation(neighbor_common_anomalies, "neighboring locations")
        weather_explanation = self.generate_explanation(weather_common_anomalies, "weather variables")

        if len(anomaly_dates) == 0:
            anomaly_details = f"No anomalies detected in {feature} data for location {location} during the observed period."
        else:
            anomaly_details = (
                f"Anomaly Values and Timestamps: {anomaly_vals}.\n"
                f"{neigh_explanation}\n" # Additional explanations only if relevant
                f"{weather_explanation}\n"
            )

        if time_int == 5:
            time_interval = "5 minutes"
        elif time_int == 60:
            time_interval = "1 hour"
        else:
            # Any other sampling interval used to leave ``time_interval``
            # unbound; describe it in minutes like the two known cases.
            time_interval = f"{time_int} minutes"

        if constraint == "weekdays only":
            time_cons = "weekdays only"
        elif constraint == "weekends only":
            time_cons = "weekends only"
        else:
            time_cons = "everyday"

        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # or "gpt-4" if available
            messages=[
                {"role": "user", "content": (
                    f"You are an expert in spatiotemporal data anomaly detection. Given below are the details of the anomalies detected in {feature} data for location {location} recorded at every {time_interval}, on {time_cons}. \n"
                    f"{anomaly_details}.\n"
                    f"Start by listing the significant anomalies or a summary of anomalies detected:\n"
                    f"Then provide a detailed interpretation of the detected anomalies, noting the magnitude, specific timestamps, and comparison with normal conditions. Also include an interpretation of the meaning and significance of common anomalies. Consider the following in your analysis:\n"
                    f"- Patterns or trends in the timestamps of anomalies.\n"
                    f"- Possible logical or realistic reasons for these anomalies.\n"
                    f"- Evaluation of common anomalies with neighboring locations and weather conditions, especially considering the {feature} scenario. \n"
                    f"Neighboring Locations' Anomalies: {'None' if not neighbor_common_anomalies else ', '.join(f'{k}: {v}' for k, v in neighbor_common_anomalies.items())}\n"
                    f"Weather Variables' Anomalies: {'Data not evaluated for this scenario' if feature != 'air quality' and not weather_common_anomalies else 'None' if not weather_common_anomalies else ', '.join(f'{k}: {v}' for k, v in weather_common_anomalies.items())}\n"
                    f"Based on these details, also suggest any suspected events and recommend further investigations or immediate actions to be taken. Make sure that the final explanation clearly indicate the data collection constraints (weekdays, weekends, or every day)"
                )}
            ],
            temperature=0.7,
            max_tokens=4096,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0)

        prog_step.state[output_var] = response['choices'][0]['message']['content'].strip()
        return prog_step.state[output_var], self.text_summary([location], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Return a one-line status message naming the analysed location.

        ``inputs`` is a one-element list holding the location; the element is
        unpacked so the message shows the location rather than the list repr.
        """
        location = inputs[0]
        return f"Anomaly Detection completed and corresponding explanation generated for location {location}.\n"


class ForecastInterpreter:
    """Interpreter for the ``FORECAST`` step using a pretrained GraphWaveNet.

    The step reads a history frame from the program state, builds the model
    input (value channel plus time-of-day channel), runs the checkpoint that
    matches the requested feature/region and stores the forecast for one
    location in the program state.
    """
    step_name = 'FORECAST'

    # (feature, region) -> (forecast_horizon, sensor/station file,
    #                       adjacency pickle, num_nodes, checkpoint path)
    _REGION_CONFIG = {
        ("traffic speed", "LA"): (
            12, 'data/METR-LA/adj_METR-LA.pkl', 'data/METR-LA/adj_METR-LA.pkl',
            207, 'data/METR-LA/GraphWaveNet_best_val_MAE.pt'),
        ("traffic speed", "BAY"): (
            12, 'data/PEMS-BAY/adj_PEMS-BAY.pkl', 'data/PEMS-BAY/adj_PEMS-BAY.pkl',
            325, 'data/PEMS-BAY/GraphWaveNet_best_val_MAE.pt'),
        ("air quality", "Beijing"): (
            24, 'data/AirQuality/Beijing/beijing_stations.csv', 'data/AirQuality/Beijing/adj_mx_BEIJING.pkl',
            35, 'data/AirQuality/Beijing/GraphWaveNet_best_val_MAE.pt'),
        ("air quality", "Shenzhen"): (
            24, 'data/AirQuality/Shenzhen/shenzhen_stations.csv', 'data/AirQuality/Shenzhen/adj_mx_SHENZHEN.pkl',
            11, 'data/AirQuality/Shenzhen/GraphWaveNet_best_val_MAE.pt'),
    }

    def parse(self, prog_step):
        """Parse ``prog_step.prog_str`` and return the step arguments.

        Returns ``(data, location, time, feature, region, time_int, period,
        unit, horizon, horizon_unit, output_var)``; ``horizon`` is converted
        to ``int``. Raises ``KeyError`` if an argument is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        data = args['data']
        location = args['location']
        time = args['time']
        feature = args['feature']
        region = args['region']
        time_int = args['time_int']
        period = args['period']
        unit = args['unit']
        horizon = int(args['horizon'])
        horizon_unit = args['horizon_unit']
        output_var = parse_result['output_var']
        return data, location, time, feature, region, time_int, period, unit, horizon, horizon_unit, output_var

    def load_model(self, model_path, adj_mx, nodes, forecast_horizon):
        """Build a GraphWaveNet, load its weights from ``model_path`` and return it in eval mode.

        ``adj_mx`` is iterated to build the ``supports`` tensors;
        ``forecast_horizon`` becomes the model's ``out_dim``. The checkpoint
        is loaded onto the CPU and must contain a ``"model_state_dict"`` key.
        """
        model = GraphWaveNet(
            num_nodes=nodes,
            dropout=0.3,
            supports=[torch.tensor(i) for i in adj_mx],
            gcn_bool=True,
            addaptadj=True,
            in_dim=2,
            out_dim=forecast_horizon,  # Forecast horizon
            residual_channels=32,
            dilation_channels=32,
            skip_channels=256,
            end_channels=512,
            layers=2
        )

        # Load the state dictionary from the .pt file
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model

    def scaler_transform(self, data, mean, std):
        """Standardise ``data`` with the given ``mean`` and ``std``."""
        return (data - mean) / std

    def scaler_inverse_transform(self, data, mean, std):
        """Undo :meth:`scaler_transform`."""
        return (data * std) + mean

    def predict_with_model(self, model, input_data, num_nodes, location_index, mean, std, horizon_steps):
        """Forecast one location from raw (un-normalised) ``input_data``.

        Channel 0 of the input is standardised on a copy, so the caller's
        tensor is left untouched (the previous version normalised it in
        place). The model is called without gradient tracking and the first
        ``horizon_steps`` outputs for ``location_index`` are mapped back to
        the original scale. ``num_nodes`` is accepted for interface
        compatibility and is not used.
        """
        normalized = input_data.clone()
        normalized[..., 0] = self.scaler_transform(normalized[..., 0], mean, std)
        with torch.no_grad():
            # Extra positional arguments: future_data, batch_seen, epoch, train.
            forecast = model(normalized, None, 0, 0, False)
        final_pred = forecast[0, :horizon_steps, location_index, 0]
        return self.scaler_inverse_transform(final_pred, mean, std)

    def _history_steps(self, period, unit, time_int):
        """Number of history samples covering ``period`` ``unit`` at ``time_int`` minutes per sample."""
        if unit == 'days':
            return int((period * 24 * 60) / time_int)
        if unit == 'hours':
            return int((60 * period) / time_int)
        if unit == 'minutes':
            return int(period / time_int)
        raise ValueError(f"Unsupported time unit: {unit}")

    def _horizon_steps(self, horizon, horizon_unit, time_int):
        """Number of forecast samples covering ``horizon`` ``horizon_unit``."""
        if horizon_unit == 'hours':
            return int((60 * horizon) / time_int)
        if horizon_unit == 'minutes':
            return int(horizon / time_int)
        # The message previously printed the history unit instead of the offending value.
        raise ValueError(f"Unsupported horizon: {horizon_unit}")

    def _load_region_assets(self, feature, region):
        """Return ``(forecast_horizon, sensor_id_to_ind, adj_mx, num_nodes, model_path)``.

        Raises ``ValueError`` for a feature/region pair without a pretrained
        model instead of failing later on an unbound variable.
        """
        try:
            forecast_horizon, sensor_file, adj_file, num_nodes, model_path = self._REGION_CONFIG[(feature, region)]
        except KeyError:
            raise ValueError(f"Unsupported feature/region combination: {feature!r} / {region!r}") from None
        loader = get_adj_mx_traffic if feature == "traffic speed" else get_adj_mx_air
        _, sensor_id_to_ind, _ = loader(sensor_file)
        adj_mx, _ = load_adj(adj_file, "doubletransition")
        return forecast_horizon, sensor_id_to_ind, adj_mx, num_nodes, model_path

    def execute(self, prog_step, inspect=False):
        """Forecast the requested horizon for one location.

        Stores the predictions (a NumPy array) in
        ``prog_step.state[output_var]`` and returns ``(predictions, summary)``.
        Raises ``ValueError`` for an unsupported unit, feature/region pair or
        unknown location.
        """
        (data, location, time, feature, region, time_int, period, unit,
         horizon, horizon_unit, output_var) = self.parse(prog_step)

        # Parsing validates the timestamp format; the history itself is taken
        # from the frame already stored in the program state.
        datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        hist_steps = self._history_steps(period, unit, time_int)
        horizon_steps = self._horizon_steps(horizon, horizon_unit, time_int)

        forecast_horizon, sensor_id_to_ind, adj_mx, num_nodes, model_path = self._load_region_assets(feature, region)

        # Shenzhen station ids are looked up as ints, all others as strings.
        sensor_key = int(location) if region == "Shenzhen" else str(location)
        location_index = sensor_id_to_ind.get(sensor_key)
        if location_index is None:
            # Indexing a tensor with None would silently add an axis instead of failing.
            raise ValueError(f"Unknown location {location!r} for region {region!r}")

        self.model = self.load_model(model_path, adj_mx, num_nodes, forecast_horizon)

        time_df = prog_step.state[data]
        hist_data = torch.tensor(time_df.values[-hist_steps:]).float().unsqueeze(2)
        # [Hist_time_Steps, Num Nodes, 1]

        # Time-of-day channel: fraction of the day elapsed at each timestamp.
        stamps = time_df.index[-hist_steps:].values
        time_ind = (stamps - stamps.astype("datetime64[D]")) / np.timedelta64(1, "D")
        time_in_day = np.tile(time_ind[:, np.newaxis], (1, num_nodes))
        time_data = torch.tensor(time_in_day).float().unsqueeze(2)
        # [Hist_time_Steps, Num Nodes, 1]

        input_data = torch.cat([hist_data, time_data], dim=-1).unsqueeze(0)
        # [1, Hist_time_Steps, Num Nodes, 2]

        mean = input_data[..., 0].mean()
        std = input_data[..., 0].std()

        # Normalisation, inference and de-normalisation live in one place.
        # NOTE: slicing silently caps the result at the model's forecast horizon.
        final_pred = self.predict_with_model(self.model, input_data, num_nodes, location_index,
                                             mean, std, horizon_steps)
        predictions = final_pred.numpy()

        prog_step.state[output_var] = predictions

        return prog_step.state[output_var], self.text_summary([location, feature, horizon, horizon_unit], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Return a one-line description of the forecast that was produced."""
        location, feature, horizon, horizon_unit = inputs
        return f"Forecasted {feature} for location {location} for next {horizon} {horizon_unit} using Graph Wavenet Model."
        
class SensitivityAnalysistInterpreter:
    """Interpreter for the ``CONDUCT_SENSITIVITY_ANALYSIS`` step.

    Measures how much a GraphWaveNet forecast for one location changes when
    (a) another node is disconnected in the adjacency matrices (spatial
    sensitivity) and (b) one history time step is zeroed out (temporal
    sensitivity), and stores a textual report in the program state.
    """
    step_name = 'CONDUCT_SENSITIVITY_ANALYSIS'

    # (feature, region) -> (forecast_horizon, sensor/station file,
    #                       adjacency pickle, num_nodes, checkpoint path)
    _REGION_CONFIG = {
        ("traffic speed", "LA"): (
            12, 'data/METR-LA/adj_METR-LA.pkl', 'data/METR-LA/adj_METR-LA.pkl',
            207, 'data/METR-LA/GraphWaveNet_best_val_MAE.pt'),
        ("traffic speed", "BAY"): (
            12, 'data/PEMS-BAY/adj_PEMS-BAY.pkl', 'data/PEMS-BAY/adj_PEMS-BAY.pkl',
            325, 'data/PEMS-BAY/GraphWaveNet_best_val_MAE.pt'),
        ("air quality", "Beijing"): (
            24, 'data/AirQuality/Beijing/beijing_stations.csv', 'data/AirQuality/Beijing/adj_mx_BEIJING.pkl',
            35, 'data/AirQuality/Beijing/GraphWaveNet_best_val_MAE.pt'),
        ("air quality", "Shenzhen"): (
            24, 'data/AirQuality/Shenzhen/shenzhen_stations.csv', 'data/AirQuality/Shenzhen/adj_mx_SHENZHEN.pkl',
            11, 'data/AirQuality/Shenzhen/GraphWaveNet_best_val_MAE.pt'),
    }

    def parse(self, prog_step):
        """Parse ``prog_step.prog_str`` and return the step arguments.

        Returns ``(data, preds, location, time, feature, region, time_int,
        period, unit, horizon, horizon_unit, output_var)``; ``horizon`` is
        converted to ``int``. Raises ``KeyError`` if an argument is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        data = args['data']
        preds = args['preds']
        location = args['location']
        time = args['time']
        feature = args['feature']
        region = args['region']
        time_int = args['time_int']
        period = args['period']
        unit = args['unit']
        horizon = int(args['horizon'])
        horizon_unit = args['horizon_unit']
        output_var = parse_result['output_var']

        return data, preds, location, time, feature, region, time_int, period, unit, horizon, horizon_unit, output_var

    def load_model(self, model_path, adj_mx, nodes, forecast_horizon):
        """Build a GraphWaveNet, load its weights from ``model_path`` and return it in eval mode.

        ``adj_mx`` is iterated to build the ``supports`` tensors;
        ``forecast_horizon`` becomes the model's ``out_dim``. The checkpoint
        is loaded onto the CPU and must contain a ``"model_state_dict"`` key.
        """
        model = GraphWaveNet(
            num_nodes=nodes,
            dropout=0.3,
            supports=[torch.tensor(i) for i in adj_mx],
            gcn_bool=True,
            addaptadj=True,
            in_dim=2,
            out_dim=forecast_horizon,  # Forecast horizon
            residual_channels=32,
            dilation_channels=32,
            skip_channels=256,
            end_channels=512,
            layers=2
        )

        # Load the state dictionary from the .pt file
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model

    def scaler_transform(self, data, mean, std):
        """Standardise ``data`` with the given ``mean`` and ``std``."""
        return (data - mean) / std

    def scaler_inverse_transform(self, data, mean, std):
        """Undo :meth:`scaler_transform`."""
        return (data * std) + mean

    def predict_with_model(self, model, input_data, num_nodes, location_index, mean, std, horizon_steps):
        """Forecast one location from raw (un-normalised) ``input_data``.

        Channel 0 is standardised on a copy so the caller's tensor is not
        modified. The previous in-place version re-normalised the shared
        input on every call, which corrupted later runs. ``num_nodes`` is
        accepted for interface compatibility and is not used.
        """
        normalized = input_data.clone()
        normalized[..., 0] = self.scaler_transform(normalized[..., 0], mean, std)
        with torch.no_grad():
            # Extra positional arguments: future_data, batch_seen, epoch, train.
            forecast = model(normalized, None, 0, 0, False)
        final_pred = forecast[0, :horizon_steps, location_index, 0]
        return self.scaler_inverse_transform(final_pred, mean, std)

    def execute_with_modified_adjacency(self, model_path, modified_adj_mx, num_nodes, input_data, horizon_steps, location_index,
                                       mean, std):
        """Reload the model with ``modified_adj_mx`` and forecast one location.

        ``input_data`` must already be normalised; it is passed to the model
        as is. ``horizon_steps`` is used both as the model's ``out_dim`` and
        as the number of leading forecast steps returned.
        """
        model = self.load_model(model_path, modified_adj_mx, num_nodes, horizon_steps)

        with torch.no_grad():
            # Extra positional arguments: future_data, batch_seen, epoch, train.
            forecast = model(input_data, None, 0, 0, False)

        final_pred = forecast[0, :horizon_steps, location_index, 0]
        return self.scaler_inverse_transform(final_pred, mean, std)

    def adjacency_sensitivity_analysis(self, adj_mx, num_nodes, model_path, input_data, horizon_steps,
                                       location_index, mean, std):
        """Rank nodes by how much disconnecting them changes the forecast.

        For every node other than ``location_index`` all of its rows and
        columns are zeroed in each adjacency matrix, the model is re-run, and
        the mean absolute difference from the baseline forecast is recorded.

        Returns a list of ``(node_index, impact)`` sorted by descending impact.
        """
        baseline_adj_mx = adj_mx
        baseline_forecast = self.execute_with_modified_adjacency(model_path, baseline_adj_mx, num_nodes, input_data, horizon_steps,
                                                                 location_index, mean, std)

        impacts = []

        for node in range(num_nodes):
            if node == location_index:
                continue

            # Work on a copy so the baseline matrices stay intact.
            modified_adj_mx = np.copy(baseline_adj_mx)
            for support in modified_adj_mx:
                support[node, :] = 0  # Remove all outbound connections from the node
                support[:, node] = 0  # Remove all inbound connections to the node

            forecast_with_node_removed = self.execute_with_modified_adjacency(model_path, modified_adj_mx, num_nodes, input_data, horizon_steps,
                                                                              location_index, mean, std)

            impact = (baseline_forecast - forecast_with_node_removed).abs().mean().item()
            impacts.append((node, impact))

        impacts.sort(key=lambda x: x[1], reverse=True)
        return impacts

    def replace_node_id_with_name(self, impacts, sensor_id_to_ind):
        """Translate node indices in ``impacts`` back to sensor ids.

        ``sensor_id_to_ind`` maps sensor id -> index and is inverted here.
        Indices without a known sensor id are reported as
        ``"Unknown ID: <index>"``.
        """
        index_to_sensor_id = {v: k for k, v in sensor_id_to_ind.items()}

        updated_impacts = []
        for node_id, impact in impacts:
            node_name = index_to_sensor_id.get(int(node_id))
            if node_name is None:
                node_name = f"Unknown ID: {node_id}"
            updated_impacts.append((node_name, impact))
        return updated_impacts

    def temporal_sensitivity_analysis(self, model, input_data, time_indices, num_nodes, location_index, mean, std, horizon_steps):
        """Rank history time steps by how much zeroing them changes the forecast.

        ``input_data`` must be the raw (un-normalised) model input, because
        :meth:`predict_with_model` applies the normalisation itself. The
        baseline and every perturbed run therefore share the same scaling.

        Returns a list of ``(time_index, impact)`` sorted by descending impact.
        """
        baseline_forecast = self.predict_with_model(model, input_data, num_nodes, location_index, mean, std, horizon_steps)
        impacts = []

        for step, time_index in enumerate(time_indices):
            perturbed = input_data.clone()
            perturbed[0, step, :, :] = 0  # Zero out the specific time step

            modified_forecast = self.predict_with_model(model, perturbed, num_nodes, location_index, mean, std, horizon_steps)

            impact = (baseline_forecast - modified_forecast).abs().mean().item()
            impacts.append((time_index, impact))

        impacts.sort(key=lambda x: x[1], reverse=True)  # Sort by impact
        return impacts

    def _history_steps(self, period, unit, time_int):
        """Number of history samples covering ``period`` ``unit`` at ``time_int`` minutes per sample."""
        if unit == 'days':
            return int((period * 24 * 60) / time_int)
        if unit == 'hours':
            return int((60 * period) / time_int)
        if unit == 'minutes':
            return int(period / time_int)
        raise ValueError(f"Unsupported time unit: {unit}")

    def _horizon_steps(self, horizon, horizon_unit, time_int):
        """Number of forecast samples covering ``horizon`` ``horizon_unit``."""
        if horizon_unit == 'hours':
            return int((60 * horizon) / time_int)
        if horizon_unit == 'minutes':
            return int(horizon / time_int)
        # The message previously printed the history unit instead of the offending value.
        raise ValueError(f"Unsupported horizon: {horizon_unit}")

    def _load_region_assets(self, feature, region):
        """Return ``(forecast_horizon, sensor_id_to_ind, adj_mx, num_nodes, model_path)``.

        Raises ``ValueError`` for a feature/region pair without a pretrained
        model instead of failing later on an unbound variable.
        """
        try:
            forecast_horizon, sensor_file, adj_file, num_nodes, model_path = self._REGION_CONFIG[(feature, region)]
        except KeyError:
            raise ValueError(f"Unsupported feature/region combination: {feature!r} / {region!r}") from None
        loader = get_adj_mx_traffic if feature == "traffic speed" else get_adj_mx_air
        _, sensor_id_to_ind, _ = loader(sensor_file)
        adj_mx, _ = load_adj(adj_file, "doubletransition")
        return forecast_horizon, sensor_id_to_ind, adj_mx, num_nodes, model_path

    def execute(self, prog_step, inspect=False):
        """Run the spatial and temporal sensitivity analyses.

        Stores a combined text report in ``prog_step.state[output_var]`` and
        returns ``(report, summary)``. Raises ``ValueError`` for an
        unsupported unit, feature/region pair or unknown location.
        """
        (data, preds, location, time, feature, region, time_int, period, unit,
         horizon, horizon_unit, output_var) = self.parse(prog_step)

        # Parsing validates the timestamp format; the history itself is taken
        # from the frame already stored in the program state.
        datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        hist_steps = self._history_steps(period, unit, time_int)
        horizon_steps = self._horizon_steps(horizon, horizon_unit, time_int)

        forecast_horizon, sensor_id_to_ind, adj_mx, num_nodes, model_path = self._load_region_assets(feature, region)

        # Shenzhen station ids are looked up as ints, all others as strings.
        sensor_key = int(location) if region == "Shenzhen" else str(location)
        location_index = sensor_id_to_ind.get(sensor_key)
        if location_index is None:
            # Indexing a tensor with None would silently add an axis instead of failing.
            raise ValueError(f"Unknown location {location!r} for region {region!r}")

        self.model = self.load_model(model_path, adj_mx, num_nodes, forecast_horizon)

        time_df = prog_step.state[data]
        hist_data = torch.tensor(time_df.values[-hist_steps:]).float().unsqueeze(2)
        # [Hist_time_Steps, Num Nodes, 1]

        # Time-of-day channel: fraction of the day elapsed at each timestamp.
        stamps = time_df.index[-hist_steps:].values
        time_ind = (stamps - stamps.astype("datetime64[D]")) / np.timedelta64(1, "D")
        time_in_day = np.tile(time_ind[:, np.newaxis], (1, num_nodes))
        time_data = torch.tensor(time_in_day).float().unsqueeze(2)
        # [Hist_time_Steps, Num Nodes, 1]

        raw_input = torch.cat([hist_data, time_data], dim=-1).unsqueeze(0)
        # [1, Hist_time_Steps, Num Nodes, 2]

        mean = raw_input[..., 0].mean()
        std = raw_input[..., 0].std()

        # Keep both versions: the adjacency analysis feeds the model directly
        # (needs normalised data) while the temporal analysis normalises
        # inside predict_with_model (needs raw data).
        normalized_input = raw_input.clone()
        normalized_input[..., 0] = self.scaler_transform(raw_input[..., 0], mean, std)

        ##### Spatial Sensitivity Analysis #####
        # NOTE: forecast_horizon is passed as the step count because it must
        # match the checkpoint's out_dim; spatial impacts are therefore
        # averaged over the model's full horizon, not just horizon_steps.
        impacts = self.adjacency_sensitivity_analysis(adj_mx, num_nodes, model_path, normalized_input, forecast_horizon,
                                                      location_index, mean, std)
        top_impacts = impacts[:10]  # already sorted by descending impact
        top_impacts = self.replace_node_id_with_name(top_impacts, sensor_id_to_ind)
        impacts_text = ', '.join(f'Location {node} impacts by {impact:.2f}' for node, impact in top_impacts)

        ##### Temporal Sensitivity Analysis #####
        times = pd.to_datetime(stamps)
        temporal_impacts = self.temporal_sensitivity_analysis(self.model, raw_input, times, num_nodes, location_index, mean, std, horizon_steps)
        temp_impact_text = ', '.join(f'Timestamp {timestamp} impacts by {temporal_impact:.2f}' for timestamp, temporal_impact in temporal_impacts)

        # Now, format them into one combined string
        spatiotemporal_sensitivity = (
            f"Significant Nodes with corresponding node impacts: {impacts_text}\n"
            f"Timestamp Sensitivities: {temp_impact_text}\n"
        )
        prog_step.state[output_var] = spatiotemporal_sensitivity

        return prog_step.state[output_var], self.text_summary([location, feature, horizon_steps], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Return the fixed status line; ``inputs`` must hold exactly three items."""
        location, feature, horizon_steps = inputs
        return f"Spatial and Temporal Sensitivity analysis conducted."


class ExtractForecastGT:
    """Interpreter for the ``EXTRACT_FORECAST_GT`` step.

    Reads the stored dataset for a feature/region and extracts the observed
    values of one location over the forecast horizon that starts at the
    given time.
    """
    step_name = 'EXTRACT_FORECAST_GT'

    # (feature, region) -> HDF5 file holding the observations
    _DATA_FILES = {
        ("traffic speed", "LA"): 'data/METR-LA/METR-LA.h5',
        ("traffic speed", "BAY"): 'data/PEMS-BAY/PEMS-BAY.h5',
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/beijing.h5',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/shenzhen.h5',
    }

    def parse(self, prog_step):
        """Parse ``prog_step.prog_str`` and return the step arguments.

        Returns ``(location, time, feature, region, time_int, period, unit,
        horizon, horizon_unit, constraint, constraint_val, output_var)``;
        ``horizon`` is converted to ``int``. Raises ``KeyError`` if an
        argument is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        location = args['location']
        time = args['time']
        feature = args['feature']
        region = args['region']
        time_int = args['time_int']
        period = args['period']
        unit = args['unit']
        horizon = int(args['horizon'])
        horizon_unit = args['horizon_unit']
        constraint = args['constraint']
        constraint_val = args['constraint_val']
        output_var = parse_result['output_var']
        return location, time, feature, region, time_int, period, unit, horizon, horizon_unit, constraint, constraint_val, output_var

    def _horizon_window(self, start_time, horizon, horizon_unit, time_int):
        """Return ``(end_time, horizon_steps)`` for the requested horizon.

        ``time_int`` is the sampling interval in minutes. Supports ``'days'``,
        ``'hours'`` and ``'minutes'``; anything else raises ``ValueError``.
        (``'days'`` used to fall through to the error branch because the
        chain was split into two separate ``if`` statements.)
        """
        if horizon_unit == 'days':
            span = timedelta(days=horizon)
            horizon_steps = int((24 * 60 * horizon) / time_int)
        elif horizon_unit == 'hours':
            span = timedelta(hours=horizon)
            horizon_steps = int((60 * horizon) / time_int)
        elif horizon_unit == 'minutes':
            span = timedelta(minutes=horizon)
            horizon_steps = int(horizon / time_int)
        else:
            raise ValueError(f"Unsupported horizon: {horizon_unit}")
        return start_time + span, horizon_steps

    def execute(self, prog_step, inspect=False):
        """Extract the ground-truth values for the forecast window.

        Stores a flat NumPy array in ``prog_step.state[output_var]`` and
        returns ``(values, summary)``. Raises ``ValueError`` for an
        unsupported feature/region pair or horizon unit.
        """
        (location, time, feature, region, time_int, period, unit, horizon,
         horizon_unit, constraint, constraint_val, output_var) = self.parse(prog_step)

        # Load the .h5 file
        data_path = self._DATA_FILES.get((feature, region))
        if data_path is None:
            # Previously an unknown pair left ``data`` unbound.
            raise ValueError(f"Unsupported feature/region combination: {feature!r} / {region!r}")
        data = pd.read_hdf(data_path)

        # Parse the time as a datetime object
        start_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        end_time, horizon_steps = self._horizon_window(start_time, horizon, horizon_unit, time_int)

        # Filter the data for the specified time range. The slice is copied
        # so that renaming its columns cannot touch the loaded frame.
        time_df = data.loc[start_time:end_time].copy()
        time_df.columns = time_df.columns.astype(str)
        st_df = time_df[[str(location)]]

        # Store the last ``horizon_steps`` rows of the window in the program state.
        prog_step.state[output_var] = st_df[-horizon_steps:].values.flatten()

        return prog_step.state[output_var], self.text_summary([location, start_time, end_time, feature], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Return a one-line description of the extracted window."""
        location, start_time, end_time, feature = inputs
        return f"Extracted ground truth data for Location: {location}, Feature: {feature}, Time Range: From {start_time} to {end_time}."



class ResultInterpreter:
    """Interpreter for the ``REFINE_OUTPUT`` step: hands back a stored state value."""
    step_name = 'REFINE_OUTPUT'

    def parse(self, prog_step):
        """Return the variable name passed as the step's ``var`` argument."""
        return parse_step(prog_step.prog_str)['args']['var']

    def execute(self, prog_step, inspect=False):
        """Look up the parsed variable in ``prog_step.state``.

        Returns ``(value, summary)``; raises ``KeyError`` if the variable is
        not present in the state.
        """
        var_name = self.parse(prog_step)
        value = prog_step.state[var_name]
        summary = self.text_summary([var_name], value)
        return value, summary

    def text_summary(self, inputs: List, output: Any):
        """Format the final-result line from the first input name and ``output``."""
        var_name = inputs[0]
        return "Final result for variable '{}': {}".format(var_name, output)


# Helper function for registering interpreters
def register_step_interpreters():
    """Build the step-name -> interpreter registry.

    Every call creates a fresh instance of each interpreter class and
    returns them in a new dict keyed by the step name used in programs.
    """
    registry = {
        'LOAD_SPATIOTEMPORAL_DATA': LoadSpatiotemporalDataInterpreter(),
        'LOAD_SPATIAL_AUX_DATA': LoadSpatialAuxDataInterpreter(),
        'LOAD_TEMPORAL_AUX_DATA': LoadTemporalAuxDataInterpreter(),
        'IMPOSE_CONSTRAINTS': ImposeConstraintsInterpreter(),
        'ANALYZE_TREND': STTrendInterpreter(),
        'ANALYZE_SEASONALITY': STSeasonalityInterpreter(),
        'ANALYZE_NEIGHBOURHOOD': STNeighbourhoodInterpreter(),
        'GEN_EXPLANATION': ExplainInterpreter(),
        'DETECT_ANOMALY_ST_DATA': STAnomalyInterpreter(),
        'FORECAST': ForecastInterpreter(),
        'CONDUCT_SENSITIVITY_ANALYSIS': SensitivityAnalysistInterpreter(),
        'EXTRACT_FORECAST_GT': ExtractForecastGT(),
        'REFINE_OUTPUT': ResultInterpreter(),
    }
    return registry
