import torch
import pickle
import math
import openai
import os
import execnet
import re

import pandas as pd
import numpy as np
import io, tokenize
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from datetime import datetime, timedelta
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.seasonal import seasonal_decompose
from typing import List, Dict, Any
from sklearn.linear_model import LinearRegression
from scipy.stats import linregress
from haversine import haversine, Unit
from torch_geometric.utils import dense_to_sparse

from models.gwnet_arch import GraphWaveNet
from models.gwnet_utils import *
from func_utils import *



def parse_step(step_str, partial=False):
    """Parse one program step of the form ``OUT = STEP_NAME(key=value, ...)``.

    Args:
        step_str: Source text of a single program step.
        partial: When true, skip argument parsing and return only the output
            variable and the step name.

    Returns:
        A dict with ``'output_var'`` (the text before the first ``=``, cut at
        the first comma) and ``'step_name'`` (the third token of the step).
        Unless ``partial`` is set, it also holds ``'args'``: a dict mapping each
        keyword to its value with surrounding quotes stripped. Values made up
        only of digits are converted to ``int`` (quoted or not).

    Raises:
        IndexError: If the step has fewer than three tokens.
    """
    tokens = list(tokenize.generate_tokens(io.StringIO(step_str).readline))

    # Output variable: everything before '=', keeping only the first name if
    # several comma-separated names are given.
    output_var = step_str.split('=')[0].strip().split(',')[0].strip()

    # Token layout is: OUT, '=', STEP_NAME, '(', <arguments...>
    step_name = tokens[2].string

    parsed_result = {
        'output_var': output_var,
        'step_name': step_name
    }

    if partial:
        return parsed_result

    # Keep only the tokens between the parentheses. Layout tokens (newline,
    # comment, end marker) and the closing ')' must be removed first; otherwise
    # they get paired up as a bogus ')' argument.
    layout_types = (
        tokenize.NEWLINE,
        tokenize.NL,
        tokenize.COMMENT,
        tokenize.INDENT,
        tokenize.DEDENT,
        tokenize.ENDMARKER,
    )
    call_tokens = [tok for tok in tokens[4:] if tok.type not in layout_types]
    if call_tokens and call_tokens[-1].string == ')':
        call_tokens.pop()

    # After dropping separators the remaining tokens alternate key, value.
    arg_tokens = [tok for tok in call_tokens if tok.string not in (',', '=')]

    args = {}
    for key_tok, value_tok in zip(arg_tokens[0::2], arg_tokens[1::2]):
        value = value_tok.string.strip('"').strip("'")

        # Convert numeric strings to integers if applicable
        if value.isdigit():
            value = int(value)

        args[key_tok.string] = value

    parsed_result['args'] = args
    return parsed_result


class LoadSpatiotemporalDataInterpreter:
    """Interpreter for ``LOAD_SPATIOTEMPORAL_DATA`` program steps.

    Reads the HDF5 dataset selected by ``feature``/``region``, keeps the column
    of one location over a look-back window ending at ``time``, and stores the
    result in the program state under the step's output variable.
    """

    step_name = 'LOAD_SPATIOTEMPORAL_DATA'

    # HDF5 file backing each supported (feature, region) combination.
    _DATASET_PATHS = {
        ("traffic speed", "LA"): 'data/METR-LA/METR-LA.h5',
        ("traffic speed", "BAY"): 'data/PEMS-BAY/PEMS-BAY.h5',
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/beijing.h5',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/shenzhen.h5',
    }

    def parse(self, prog_step):
        """Extract the step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(location, time, feature, region, output_var, time_int,
            period, unit, constraints)``.

        Raises:
            KeyError: If one of the expected keyword arguments is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        return (
            args['location'],
            args['time'],
            args['feature'],
            args['region'],
            parse_result['output_var'],
            args['time_int'],
            args['period'],
            args['unit'],
            args['constraints'],
        )

    @classmethod
    def _dataset_path(cls, feature, region):
        """Return the HDF5 path for a (feature, region) pair or raise ValueError."""
        path = cls._DATASET_PATHS.get((feature, region))
        if path is None:
            # Previously this case surfaced later as an UnboundLocalError.
            raise ValueError(
                f"Unsupported feature/region combination: {feature!r} / {region!r}"
            )
        return path

    @staticmethod
    def _resolve_window(time, period, unit, time_int):
        """Return ``(start_time, end_time, steps)`` for the look-back window.

        ``time`` is parsed with the format ``%Y-%m-%d %H:%M:%S`` and used as the
        window end; ``steps`` is the window length divided by ``time_int``
        minutes, truncated to an int.
        """
        end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        if unit == 'days':
            start_time = end_time - timedelta(days=period)
            steps = int((period*24*60)/time_int)
        elif unit == 'hours':
            start_time = end_time - timedelta(hours=period)
            steps = int((60*period)/time_int)
        elif unit == 'minutes':
            start_time = end_time - timedelta(minutes=period)
            steps = int(period/time_int)
        else:
            raise ValueError(f"Unsupported time unit: {unit}")
        return start_time, end_time, steps

    def execute(self, prog_step, inspect=False):
        """Load the requested series and store it in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a ``state`` mapping.
            inspect: Unused.

        Returns:
            Tuple of the stored DataFrame and a one-line text summary.

        Raises:
            ValueError: For an unsupported feature/region pair or time unit.
        """
        location, time, feature, region, output_var, time_int, period, unit, constraints = self.parse(prog_step)

        data = pd.read_hdf(self._dataset_path(feature, region))
        start_time, end_time, steps = self._resolve_window(time, period, unit, time_int)

        # Restrict to the time window, then select the location's column.
        # Column labels are compared as strings.
        time_df = data.loc[start_time:end_time]
        time_df.columns = time_df.columns.astype(str)
        st_df = time_df[[str(location)]]

        if constraints == "weekdays only":
            st_df = st_df[st_df.index.dayofweek < 5]  # Filter for weekdays
        if constraints == "weekends only":
            st_df = st_df[st_df.index.dayofweek >= 5]  # Filter for weekends

        # Keep only the last `steps` rows of the (possibly filtered) window.
        prog_step.state[output_var] = st_df[-steps:]

        return prog_step.state[output_var], self.text_summary([location, start_time, end_time, feature], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Describe what was loaded; ``output`` is not used."""
        location, start_time, end_time, feature = inputs
        return f"Loaded data for Location: {location}, Feature: {feature}, Time Range: From {start_time} to {end_time}."



class LoadSpatialAuxDataInterpreter:
    """Interpreter for ``LOAD_SPATIAL_AUX_DATA`` program steps.

    Reads the HDF5 dataset selected by ``feature``/``region`` and stores the
    columns of the locations returned by ``get_neighbors`` for the target
    location, over a look-back window ending at ``time``.
    """

    step_name = 'LOAD_SPATIAL_AUX_DATA'

    # HDF5 file backing each supported (feature, region) combination.
    _DATASET_PATHS = {
        ("traffic speed", "LA"): 'data/METR-LA/METR-LA.h5',
        ("traffic speed", "BAY"): 'data/PEMS-BAY/PEMS-BAY.h5',
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/beijing.h5',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/shenzhen.h5',
    }

    def parse(self, prog_step):
        """Extract the step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(spatial_var, location, time, feature, region, output_var,
            time_int, period, unit, constraints)``.

        Raises:
            KeyError: If one of the expected keyword arguments is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        return (
            args['spatial_var'],
            args['location'],
            args['time'],
            args['feature'],
            args['region'],
            parse_result['output_var'],
            args['time_int'],
            args['period'],
            args['unit'],
            args['constraints'],
        )

    @classmethod
    def _dataset_path(cls, feature, region):
        """Return the HDF5 path for a (feature, region) pair or raise ValueError."""
        path = cls._DATASET_PATHS.get((feature, region))
        if path is None:
            # Previously this case surfaced later as an UnboundLocalError.
            raise ValueError(
                f"Unsupported feature/region combination: {feature!r} / {region!r}"
            )
        return path

    @staticmethod
    def _resolve_window(time, period, unit, time_int):
        """Return ``(start_time, end_time, steps)`` for the look-back window.

        ``time`` is parsed with the format ``%Y-%m-%d %H:%M:%S`` and used as the
        window end; ``steps`` is the window length divided by ``time_int``
        minutes, truncated to an int.
        """
        end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        if unit == 'days':
            start_time = end_time - timedelta(days=period)
            steps = int((period*24*60)/time_int)
        elif unit == 'hours':
            start_time = end_time - timedelta(hours=period)
            steps = int((60*period)/time_int)
        elif unit == 'minutes':
            start_time = end_time - timedelta(minutes=period)
            steps = int(period/time_int)
        else:
            raise ValueError(f"Unsupported time unit: {unit}")
        return start_time, end_time, steps

    def execute(self, prog_step, inspect=False):
        """Load the neighbour series and store them in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a ``state`` mapping.
            inspect: Unused.

        Returns:
            Tuple of the stored DataFrame and a one-line text summary.

        Raises:
            ValueError: For an unsupported feature/region pair, time unit, or
                ``spatial_var`` (only ``"neighbour"`` is handled).
        """
        spatial_var, location, time, feature, region, output_var, time_int, period, unit, constraints = self.parse(prog_step)

        data = pd.read_hdf(self._dataset_path(feature, region))
        start_time, end_time, steps = self._resolve_window(time, period, unit, time_int)

        # Restrict to the time window; column labels are compared as strings.
        time_df = data.loc[start_time:end_time]
        time_df.columns = time_df.columns.astype(str)

        if spatial_var == "neighbour":
            neighbors = get_neighbors(feature, location, region, threshold=0.6)
            neighbor_list = [str(num) for num in neighbors]
            st_df = time_df[neighbor_list]
        else:
            # Without this guard `st_df` would be unbound below.
            raise ValueError(f"Unsupported spatial variable: {spatial_var}")

        if constraints == "weekdays only":
            st_df = st_df[st_df.index.dayofweek < 5]  # Filter for weekdays
        if constraints == "weekends only":
            st_df = st_df[st_df.index.dayofweek >= 5]  # Filter for weekends

        # Keep only the last `steps` rows of the (possibly filtered) window.
        prog_step.state[output_var] = st_df[-steps:]

        return prog_step.state[output_var], self.text_summary([spatial_var, location, start_time, end_time, feature], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Describe what was loaded; ``output`` is not used."""
        spatial_var, location, start_time, end_time, feature = inputs
        return f"Loaded {spatial_var} data as spatial auxiliary for Location: {location}, Feature: {feature}, Time Range: From {start_time} to {end_time}."

class LoadTemporalAuxDataInterpreter:
    """Interpreter for ``LOAD_TEMPORAL_AUX_DATA`` program steps.

    For ``temp_var == "weather"`` it reads four weather HDF5 files (humidity,
    pressure, temperature, ws) for the region, keeps the target location's
    column from each over a look-back window ending at ``time``, and stores
    them side by side in the program state.
    """

    step_name = 'LOAD_TEMPORAL_AUX_DATA'

    # Directory and file-name prefix of the weather files per (feature, region).
    _WEATHER_DIRS = {
        ("air quality", "Beijing"): ('data/AirQuality/Beijing/weather', 'beijing'),
        ("air quality", "Shenzhen"): ('data/AirQuality/Shenzhen/weather', 'shenzhen'),
    }

    # (file-name suffix, output column name), in output column order.
    _WEATHER_FIELDS = (
        ('humidity', 'Humidity'),
        ('pressure', 'Pressure'),
        ('temperature', 'Temperature'),
        ('ws', 'WS'),
    )

    def parse(self, prog_step):
        """Extract the step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(temp_var, location, time, feature, region, output_var,
            time_int, period, unit, constraints)``.

        Raises:
            KeyError: If one of the expected keyword arguments is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        args = parse_result['args']
        return (
            args['temp_var'],
            args['location'],
            args['time'],
            args['feature'],
            args['region'],
            parse_result['output_var'],
            args['time_int'],
            args['period'],
            args['unit'],
            args['constraints'],
        )

    @classmethod
    def _weather_sources(cls, feature, region):
        """Return ``(column_name, hdf5_path)`` pairs for a (feature, region) pair.

        Raises:
            ValueError: If no weather files are known for the combination.
        """
        located = cls._WEATHER_DIRS.get((feature, region))
        if located is None:
            # Previously this case surfaced later as an UnboundLocalError.
            raise ValueError(
                f"No weather data for feature/region combination: {feature!r} / {region!r}"
            )
        directory, prefix = located
        return [
            (column, f"{directory}/{prefix}_{suffix}.h5")
            for suffix, column in cls._WEATHER_FIELDS
        ]

    @staticmethod
    def _resolve_window(time, period, unit, time_int):
        """Return ``(start_time, end_time, steps)`` for the look-back window.

        ``time`` is parsed with the format ``%Y-%m-%d %H:%M:%S`` and used as the
        window end; ``steps`` is the window length divided by ``time_int``
        minutes, truncated to an int.
        """
        end_time = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
        if unit == 'days':
            start_time = end_time - timedelta(days=period)
            steps = int((period*24*60)/time_int)
        elif unit == 'hours':
            start_time = end_time - timedelta(hours=period)
            steps = int((60*period)/time_int)
        elif unit == 'minutes':
            start_time = end_time - timedelta(minutes=period)
            steps = int(period/time_int)
        else:
            raise ValueError(f"Unsupported time unit: {unit}")
        return start_time, end_time, steps

    def execute(self, prog_step, inspect=False):
        """Load the weather series and store them in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a ``state`` mapping.
            inspect: Unused.

        Returns:
            Tuple of the stored DataFrame (columns ``Humidity``, ``Pressure``,
            ``Temperature``, ``WS``) and a one-line text summary.

        Raises:
            ValueError: For an unsupported ``temp_var`` (only ``"weather"`` is
                handled), feature/region pair, or time unit.
        """
        temp_var, location, time, feature, region, output_var, time_int, period, unit, constraints = self.parse(prog_step)

        if temp_var != "weather":
            # Nothing else is implemented; fail clearly instead of falling
            # through to a return that references values never computed.
            raise ValueError(f"Unsupported temporal variable: {temp_var}")

        # Read all four files up front, in column order.
        frames = [
            (column, pd.read_hdf(path))
            for column, path in self._weather_sources(feature, region)
        ]

        start_time, end_time, steps = self._resolve_window(time, period, unit, time_int)

        processed_data = []
        for column, frame in frames:
            # Restrict to the window and select the location's column; column
            # labels are compared as strings.
            time_df = frame.loc[start_time:end_time]
            time_df.columns = time_df.columns.astype(str)
            st_df = time_df[[str(location)]]
            if constraints == "weekdays only":
                st_df = st_df[st_df.index.dayofweek < 5]  # Filter for weekdays
            if constraints == "weekends only":
                st_df = st_df[st_df.index.dayofweek >= 5]  # Filter for weekends

            # Rename the single location column to the weather variable name.
            st_df.columns = [column]
            processed_data.append(st_df)

        weather_data = pd.concat(processed_data, axis=1)

        # Keep only the last `steps` rows of the combined frame.
        prog_step.state[output_var] = weather_data[-steps:]

        return prog_step.state[output_var], self.text_summary([temp_var, location, start_time, end_time], prog_step.state[output_var])

    def text_summary(self, inputs: List, output: Any):
        """Describe what was loaded; ``output`` is not used."""
        temp_var,location, start_time, end_time = inputs
        return f"Loaded {temp_var} data as temporal auxiliary data for Location: {location}, Time Range: From {start_time} to {end_time}."


class CreateFinalInputInterpreter:
    """Interpreter for ``CREATE_FINAL_INPUT`` program steps.

    Formats the primary frame, plus optional neighbour and weather frames, from
    the program state into one text block, and writes each frame used to a CSV
    file under ``dump_dir``.
    """

    step_name = 'CREATE_FINAL_INPUT'

    # Location and tag of the CSV dumps written by execute(). The defaults
    # reproduce the previously hard-coded paths; override on the class or an
    # instance to dump elsewhere.
    dump_dir = "HumanEval/QueryInput"
    query_id = "119"

    def parse(self, prog_step):
        """Extract the step's arguments from ``prog_step.prog_str``.

        Returns:
            Tuple ``(data0, spatial_data, temporal_data, output_var)`` where the
            first three are state variable names (or the string ``"None"``).

        Raises:
            KeyError: If one of the expected keyword arguments is missing.
        """
        parse_result = parse_step(prog_step.prog_str)
        data0 = parse_result['args']['data0']
        spatial_data = parse_result['args']['spatial_data']
        temporal_data = parse_result['args']['temporal_data']
        output_var = parse_result['output_var']

        return data0, spatial_data, temporal_data , output_var

    def _dump_path(self, prefix):
        """Return the CSV path for one dumped frame, e.g. ``.../input_119.csv``."""
        return f"{self.dump_dir}/{prefix}_{self.query_id}.csv"

    def execute(self, prog_step, inspect=False):
        """Build the final input text and store it in ``prog_step.state``.

        Args:
            prog_step: Step object exposing ``prog_str`` and a ``state`` mapping.
            inspect: Unused.

        Returns:
            Tuple of the assembled input string and a one-line text summary.

        Side effects:
            Sets ``index.freq = None`` on every frame used (mutating the frames
            held in state) and writes each of them to CSV, creating
            ``dump_dir`` if it does not exist.
        """
        data0, spatial_data, temporal_data , output_var = self.parse(prog_step)

        # (section title, state variable, CSV file prefix); the primary data is
        # mandatory, the others are skipped when the argument is the literal
        # string "None".
        sections = [("Input Data", data0, "input")]
        if spatial_data != "None":
            sections.append(("Neighbour Data", spatial_data, "spatial_input"))
        if temporal_data != "None":
            sections.append(("Weather Data", temporal_data, "temporal_input"))

        input_data_parts = []
        for title, var_name, file_prefix in sections:
            frame = prog_step.state[var_name]
            frame.index.freq = None
            input_data_parts.append(f"{title}:\n{frame}")
            # Ensure the dump directory exists so to_csv does not fail on a
            # fresh checkout.
            os.makedirs(self.dump_dir, exist_ok=True)
            frame.to_csv(self._dump_path(file_prefix))

        # Join all parts into a single string
        input_data = ".\n".join(input_data_parts) + "."

        prog_step.state[output_var] = input_data
        return input_data, self.text_summary([data0], input_data)

    def text_summary(self, inputs: List, output: Any):
        """Return a fixed completion message; both arguments are ignored."""
        return "Final Input Data Generated. "

class ResultInterpreter:
    """Interpreter for ``RESULT`` program steps.

    Looks up the state variable named by the step's ``var`` argument and
    returns its value together with a text summary.
    """

    step_name = 'RESULT'

    def parse(self, prog_step):
        """Return the state variable name given by the step's ``var`` argument."""
        return parse_step(prog_step.prog_str)['args']['var']

    def execute(self, prog_step, inspect=False):
        """Return ``(value, summary)`` for the variable the step refers to.

        ``inspect`` is unused. Raises ``KeyError`` if the variable is not in
        ``prog_step.state`` (assuming ``state`` is a dict-like mapping).
        """
        target = self.parse(prog_step)
        value = prog_step.state[target]
        summary = self.text_summary([target], value)
        return value, summary

    def text_summary(self, inputs: List, output: Any):
        """Format the final value; only the first element of ``inputs`` is used."""
        output_var = inputs[0]
        return f"Final result for variable '{output_var}': {output}"


# Helper function for registering interpreters
def register_step_interpreters():
    """Return a new dict mapping each step name to a fresh interpreter instance."""
    registry = (
        ('LOAD_SPATIOTEMPORAL_DATA', LoadSpatiotemporalDataInterpreter),
        ('LOAD_SPATIAL_AUX_DATA', LoadSpatialAuxDataInterpreter),
        ('LOAD_TEMPORAL_AUX_DATA', LoadTemporalAuxDataInterpreter),
        ('CREATE_FINAL_INPUT', CreateFinalInputInterpreter),
        ('RESULT', ResultInterpreter),
    )
    # Instantiate in declaration order so the dict's key order is unchanged.
    return {step: interpreter_cls() for step, interpreter_cls in registry}
