import pickle
import pandas as pd
import numpy as np

from haversine import haversine, Unit
from torch_geometric.utils import dense_to_sparse


def load_pickle(pickle_file):
    """Load and return the object stored in a pickle file.

    The file is first unpickled with default settings. If that raises
    ``UnicodeDecodeError``, the file is reopened and unpickled again with
    ``encoding='latin1'``. Any other exception from the first attempt is
    reported on stdout and re-raised unchanged.

    Note: unpickling can execute arbitrary code, so only pass files from a
    trusted source.

    Args:
        pickle_file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    try:
        with open(pickle_file, 'rb') as handle:
            return pickle.load(handle)
    except UnicodeDecodeError:
        # Retry with a permissive byte-to-text encoding.
        with open(pickle_file, 'rb') as handle:
            return pickle.load(handle, encoding='latin1')
    except Exception as err:
        print('Unable to load data ', pickle_file, ':', err)
        raise

def get_adj_mx_traffic(pkl_filename):
    """Load a traffic adjacency bundle from a pickle file.

    The pickle is expected to unpack into exactly three items, which are
    returned in the same order.

    Args:
        pkl_filename: Path of the pickle file to read.

    Returns:
        Tuple ``(sensor_ids, sensor_id_to_ind, adj_mx)`` as stored in the file.
    """
    ids, id_to_index, adjacency = load_pickle(pkl_filename)
    return ids, id_to_index, adjacency

def get_adj_mx_air(csv_file, normalized_k=0.1):
    """Build a thresholded station adjacency matrix from a station CSV.

    The CSV is read with pandas and must provide the columns ``station``,
    ``latitude`` and ``longitude``. For every pair of stations the haversine
    distance in kilometres is computed; its reciprocal is divided by the
    standard deviation of all non-zero reciprocals and passed through
    ``exp(-x)``. Weights below ``normalized_k`` are then set to zero.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        normalized_k: Sparsity threshold; weights strictly below it become 0.

    Returns:
        Tuple ``(sensor_ids, sensor_id_to_ind, adj_mx)`` where ``sensor_ids``
        is the ``station`` column, ``sensor_id_to_ind`` maps each station id
        to its row position, and ``adj_mx`` is a square NumPy array with one
        row/column per station.
    """
    station_df = pd.read_csv(csv_file)
    sensor_ids = station_df['station']
    num_sensors = len(sensor_ids)

    # Map every station id to its row position.
    sensor_id_to_ind = {sensor_id: i for i, sensor_id in enumerate(sensor_ids)}

    # Pull the coordinates out once instead of doing a DataFrame lookup for
    # every pair inside the double loop.
    coords = list(zip(station_df['latitude'].tolist(),
                      station_df['longitude'].tolist()))

    # NOTE(review): this matrix stores the *reciprocal* of the distance, so
    # after exp(-x) below, stations that are farther apart receive LARGER
    # weights and co-located stations (and the diagonal) receive weight 1.
    # That looks inverted for an adjacency matrix -- confirm the intent
    # before relying on these weights. Behaviour is kept as-is here.
    dist_mx = np.zeros((num_sensors, num_sensors))
    for i in range(num_sensors):
        for j in range(i + 1, num_sensors):
            distance = haversine(coords[i], coords[j], unit=Unit.KILOMETERS)
            if distance != 0:
                dist_mx[i, j] = dist_mx[j, i] = 1 / distance

    # Normalise by the spread of the non-zero entries. With no non-zero
    # entries (zero/one station, or all co-located) np.std would be NaN, and
    # with all entries equal (e.g. exactly two stations) it is 0; both cases
    # would fill the result with NaN, so fall back to no rescaling.
    nonzero_entries = dist_mx[dist_mx != 0]
    std_distance = np.std(nonzero_entries) if nonzero_entries.size else 0.0
    if std_distance == 0:
        std_distance = 1.0
    normalized_dist_mx = dist_mx / std_distance

    # Apply exponential decay.
    adj_mx = np.exp(-normalized_dist_mx)

    # Apply threshold for sparsity.
    adj_mx[adj_mx < normalized_k] = 0

    # Keep the matrix symmetric (it already is by construction; this guards
    # against future asymmetric edits above).
    adj_mx = np.maximum(adj_mx, adj_mx.T)

    return sensor_ids, sensor_id_to_ind, adj_mx
    

def get_neighbors(feature, location, region, threshold=0.6):
    """Return the sensors whose adjacency weight to ``location`` exceeds a threshold.

    The adjacency data is loaded from a fixed file chosen by the
    ``(feature, region)`` pair. Supported pairs are ``("traffic speed", "LA")``,
    ``("traffic speed", "BAY")``, ``("air quality", "Beijing")`` and
    ``("air quality", "Shenzhen")``.

    The queried location itself is not excluded: it is part of the result
    whenever its own (diagonal) weight is above ``threshold``.

    Args:
        feature: Either ``"traffic speed"`` or ``"air quality"``.
        location: Sensor/station identifier. It is converted with ``int`` for
            the Shenzhen region and with ``str`` for every other region before
            being looked up.
        region: Region name selecting the adjacency file.
        threshold: Only weights strictly greater than this value count as
            neighbours.

    Returns:
        List of sensor ids, in adjacency-row order.

    Raises:
        ValueError: If the feature/region pair is unsupported, or if
            ``location`` is not present in the loaded sensor index.
    """
    adjacency_sources = {
        ("traffic speed", "LA"): 'data/METR-LA/adj_METR-LA.pkl',
        ("traffic speed", "BAY"): 'data/PEMS-BAY/adj_PEMS-BAY.pkl',
        ("air quality", "Beijing"): 'data/AirQuality/Beijing/beijing_stations.csv',
        ("air quality", "Shenzhen"): 'data/AirQuality/Shenzhen/shenzhen_stations.csv',
    }
    source = adjacency_sources.get((feature, region))
    if source is None:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with an actionable message instead.
        raise ValueError(
            f"Unsupported feature/region combination: {feature!r} / {region!r}."
        )

    if feature == "traffic speed":
        _, sensor_id_to_ind, adj_mx = get_adj_mx_traffic(source)
    else:
        _, sensor_id_to_ind, adj_mx = get_adj_mx_air(source)

    # Shenzhen station ids are looked up as integers, all others as strings.
    lookup_key = int(location) if region == "Shenzhen" else str(location)
    location_index = sensor_id_to_ind.get(lookup_key)
    if location_index is None:
        # Indexing the matrix with None would silently add an axis rather
        # than select a row, so reject unknown locations explicitly.
        raise ValueError(
            f"Location {location!r} not found among the sensors of region {region!r}."
        )

    # Reverse mapping: row position -> sensor id.
    index_to_sensor = {index: sensor_id for sensor_id, index in sensor_id_to_ind.items()}

    # Keep every sensor whose weight is above the threshold (filters weak
    # connections).
    return [
        index_to_sensor.get(idx)
        for idx, weight in enumerate(adj_mx[location_index])
        if weight > threshold
    ]

class Program:
    """A program given as text, together with its mutable execution state.

    Attributes set by the constructor:
        prog_str: The raw program text.
        state: The state mapping; the caller's object is used as-is when one
            is supplied, otherwise a new empty dict is created.
        instructions: ``prog_str`` split on newline characters.
    """

    def __init__(self, prog_str, init_state=None):
        if init_state is None:
            init_state = {}
        self.prog_str = prog_str
        self.state = init_state
        self.instructions = prog_str.split('\n')


class ProgramGenerator():
    """Generates program text by sending a prompt to an OpenAI chat model.

    NOTE(review): ``openai`` and ``os`` are used below but are not imported in
    the visible part of this module -- confirm they are imported elsewhere.
    """

    def __init__(self, prompter, temperature=0.7, top_p=0.5, prob_agg='mean'):
        """Store sampling settings and read the API key from the environment.

        Args:
            prompter: Callable that turns the ``inputs`` given to
                ``generate`` into the user message sent to the model.
            temperature: Sampling temperature forwarded to the API.
            top_p: Nucleus-sampling value forwarded to the API.
            prob_agg: ``'mean'`` or ``'sum'``; how ``compute_prob`` aggregates
                token log-probabilities.
        """
        openai.api_key = os.getenv("OPENAI_API_KEY")
        self.prompter = prompter
        self.temperature = temperature
        self.top_p = top_p
        self.prob_agg = prob_agg

    def compute_prob(self, response):
        """Return ``exp`` of the aggregated token log-probabilities.

        Only tokens before the first ``'<|endoftext|>'`` token are used; when
        no such token is present, every token is used. If no tokens remain,
        the result is whatever NumPy yields for an empty input (NaN for
        ``'mean'``, 1.0 for ``'sum'``).

        Args:
            response: Object whose ``choices[0]['logprobs']`` holds parallel
                ``'tokens'`` and ``'token_logprobs'`` sequences.

        Returns:
            The aggregated probability as a NumPy float.

        Raises:
            NotImplementedError: If ``self.prob_agg`` is neither ``'mean'``
                nor ``'sum'``.
        """
        eos = '<|endoftext|>'
        logprobs = response.choices[0]['logprobs']
        tokens = logprobs['tokens']

        # Position of the first EOS token, or the full length when there is
        # none. (Using the loop index after an un-broken loop would drop the
        # last token and fail outright on an empty token list.)
        end = next(
            (i for i, token in enumerate(tokens) if token == eos),
            len(tokens),
        )

        if self.prob_agg == 'mean':
            agg_fn = np.mean
        elif self.prob_agg == 'sum':
            agg_fn = np.sum
        else:
            raise NotImplementedError

        return np.exp(agg_fn(logprobs['token_logprobs'][:end]))

    def generate(self, inputs):
        """Request a completion for ``inputs`` and return the program text.

        Args:
            inputs: Passed to ``self.prompter`` to build the user message.

        Returns:
            Tuple ``(prog, None)`` where ``prog`` is the stripped content of
            the first choice. The second element is always ``None``.
        """
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # or "gpt-4" if available
            messages=[
                # {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": self.prompter(inputs)}
            ],
            temperature=self.temperature,
            max_tokens=512,
            top_p=self.top_p,
            frequency_penalty=0,
            presence_penalty=0
        )
        prog = response.choices[0].message['content'].strip()
        return prog, None


class ProgramInterpreter:
    """Runs a program one instruction at a time via registered step interpreters."""

    def __init__(self):
        # Register the spatiotemporal step interpreters
        self.step_interpreters = register_step_interpreters()

    def execute_step(self, prog_step, inspect=False):
        """Execute a single program step.

        Args:
            prog_step: A ``Program`` whose ``prog_str`` holds one instruction.
            inspect: When true, a textual summary is returned with the result.

        Returns:
            The interpreter's result when ``inspect`` is false. When
            ``inspect`` is true, a 2-tuple ``(output, summary)``: a 2-tuple
            returned by the interpreter is passed through unchanged, and any
            other result is paired with a generic summary string.

        Raises:
            ValueError: If the parsed step name has no registered interpreter.
        """
        # Parse the step to get the step name
        parsed_step = parse_step(prog_step.prog_str, partial=True)
        step_name = parsed_step['step_name']

        # Ensure the step interpreter is registered
        if step_name not in self.step_interpreters:
            raise ValueError(f"Step '{step_name}' not registered in interpreters.")

        # Execute the step using the corresponding interpreter
        result = self.step_interpreters[step_name].execute(prog_step, inspect)

        if not inspect:
            return result

        # Inspection mode: make sure a plain-text summary accompanies the output.
        if isinstance(result, tuple) and len(result) == 2:
            return result  # (output, text_summary)
        return result, f"Step '{step_name}' executed with no additional output."

    def execute(self, prog, init_state=None, inspect=False):
        """Execute every instruction of a program in order.

        Args:
            prog: Program text or a ``Program`` instance.
            init_state: Initial state used when ``prog`` is text; it is not
                used when ``prog`` is already a ``Program``.
            inspect: When true, per-step summaries are collected.

        Returns:
            ``(last_step_output, state)``, or
            ``(last_step_output, state, summary_text)`` when ``inspect`` is
            true. ``summary_text`` has one newline-terminated summary per step.

        Raises:
            TypeError: If ``prog`` is neither a string nor a ``Program``.
        """
        # Initialize program if provided as a string
        if isinstance(prog, str):
            prog = Program(prog, init_state or {})
        elif not isinstance(prog, Program):
            # Explicit check rather than ``assert`` so the validation is not
            # stripped when Python runs with -O.
            raise TypeError("prog must be a string or an instance of Program.")

        # One single-instruction Program per line, all sharing the same state
        # object so that steps can see each other's results.
        prog_steps = [Program(instruction, init_state=prog.state) for instruction in prog.instructions]

        summaries = []
        step_output = None

        # Execute each step
        for prog_step in prog_steps:
            if inspect:
                step_output, step_summary = self.execute_step(prog_step, inspect)
                summaries.append(step_summary + "\n")
            else:
                step_output = self.execute_step(prog_step, inspect)

        # Return appropriate results
        if inspect:
            return step_output, prog.state, "".join(summaries)
        return step_output, prog.state