import numpy as np
import json
import os
from tqdm import tqdm
import sys
import random

def _timestamp():
    """Return the current local time formatted for log lines."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def _history_length(sample):
    """Return how many values the bracketed history in *sample* holds.

    Returns None when the sample has no usable "instruction" text or the
    text contains no bracketed list.
    """
    try:
        text = sample["instruction"]
        start = text.find("[")
        end = text.find("]")
    except (KeyError, TypeError, AttributeError):
        return None
    if start == -1 or end == -1:
        return None
    # Values are joined with ", ", so the value count is comma count plus one.
    return text[start:end + 1].count(",") + 1


def _resume_position(dataset, nodes_per_step):
    """Work out where to resume from previously saved samples.

    Returns ``(time_index, node_position)`` of the next sample to generate,
    or None when the last saved sample cannot be parsed.
    """
    if not dataset:
        return 0, 0
    last_len = _history_length(dataset[-1])
    if last_len is None:
        return None

    # A history of k values belongs to time index k - 1. Count how many
    # trailing samples share that index to see whether the step was finished.
    tail = dataset[-nodes_per_step:] if nodes_per_step > 0 else []
    nodes_done = 0
    for sample in reversed(tail):
        if _history_length(sample) != last_len:
            break
        nodes_done += 1

    time_idx = last_len - 1
    if nodes_done >= nodes_per_step:
        return time_idx + 1, 0
    # The previous run stopped part-way through this time step.
    return time_idx, nodes_done


def _min_max_scale(values):
    """Scale *values* to [0, 1]; constant or empty series are returned as-is."""
    if not values:
        return values
    low = min(values)
    high = max(values)
    if high > low:
        span = high - low
        return [(x - low) / span for x in values]
    return values


def _format_series(values):
    """Render *values* as a bracketed list with four decimal places."""
    return "[" + ", ".join([f"{x:.4f}" for x in values]) + "]"


def extract_traffic_data(data_path, output_path, feature_idx=0, prediction_length=12, sample_interval=5, batch_size=1000,
                          sample_nodes=False, node_count=10, random_seed=42):
    """Build instruction/output samples from a traffic array and save them as JSON.

    For each time index ``t`` and each selected node, one sample is produced
    whose instruction embeds the node's series from index 0 through ``t`` and
    whose output is the following ``prediction_length`` values. History and
    target are min-max scaled independently of each other.

    If ``output_path`` already exists its samples are loaded and generation
    resumes after the last saved sample, including the remaining nodes of a
    time step that was only partly written.

    Args:
        data_path: Path handed to ``numpy.load``; the result is read through
            ``.files``, so an ``.npz`` archive is expected. The ``'data'``
            entry is used when present, otherwise the first entry.
        output_path: JSON file to create or extend. A log is appended to
            ``output_path + ".log"``.
        feature_idx: Index along the array's third axis to keep.
        prediction_length: Number of future values in each sample's output.
        sample_interval: Not used by this function; accepted so existing
            callers keep working.
        batch_size: The whole dataset is re-saved each time this many new
            samples have accumulated.
        sample_nodes: When true, use a seeded random subset of nodes.
        node_count: Size of that subset (all nodes are used if it is not
            smaller than the number of nodes).
        random_seed: Seed passed to ``random.seed`` before sampling nodes.

    Returns:
        The full list of samples (previously saved plus newly generated).
    """
    dataset = []

    if os.path.exists(output_path):
        print(f"Found existing data file {output_path}, attempting to load...")
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            if not isinstance(loaded, list):
                raise ValueError("top-level JSON value is not a list")
            dataset = loaded
            print(f"Successfully loaded existing data, contains {len(dataset)} samples")
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            print(f"Failed to load existing data file: {e}")
            print("Will create new data file...")
            dataset = []

    print(f"Loading data from {data_path}...")
    # The archive keeps a file handle open until closed, so scope it.
    with np.load(data_path) as data:
        print(f"Dataset keys: {data.files}")
        key = 'data' if 'data' in data.files else data.files[0]
        traffic_data = data[key]

    print(f"Original data shape: {traffic_data.shape}")

    traffic_data = traffic_data[:, :, feature_idx:feature_idx+1]
    print(f"Shape after keeping feature: {traffic_data.shape}")

    traffic_data = traffic_data.squeeze(axis=-1)
    print(f"Shape after compression: {traffic_data.shape}")

    time_steps, num_nodes = traffic_data.shape

    selected_nodes = list(range(num_nodes))
    if sample_nodes:
        if node_count >= num_nodes:
            print(f"Specified node count {node_count} is greater than or equal to total nodes {num_nodes}, will use all nodes")
        else:
            random.seed(random_seed)
            selected_nodes = sorted(random.sample(range(num_nodes), node_count))
            print(f"Randomly sampled {len(selected_nodes)} nodes: {selected_nodes[:10]}{'...' if len(selected_nodes) > 10 else ''}")

    resume = _resume_position(dataset, len(selected_nodes))
    if resume is None:
        print("Could not determine resume position from existing data, starting from time index 0")
        start_time_idx, start_node_pos = 0, 0
    else:
        start_time_idx, start_node_pos = resume

    print(f"Starting to extract data pairs, prediction length is {prediction_length}...")
    print(f"Starting from time index {start_time_idx} and node {start_node_pos}")

    # Only indices with a full prediction window after them yield samples.
    num_samples = time_steps - prediction_length
    progress_total = max(num_samples, 0)

    unsaved = 0
    total_new_samples = 0

    # Create the target directory up front so the log and JSON writes below
    # do not fail after all the data has already been loaded.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path + ".log", "a", encoding="utf-8") as log_file:
        log_file.write(f"===== Starting new data processing session {_timestamp()} =====\n")
        log_file.write(f"Data path: {data_path}\n")
        log_file.write(f"Output path: {output_path}\n")
        log_file.write(f"Feature index: {feature_idx}\n")
        log_file.write(f"Prediction length: {prediction_length}\n")
        log_file.write(f"Node sampling: {'Yes' if sample_nodes else 'No'}\n")
        if sample_nodes:
            log_file.write(f"Sampled nodes count: {len(selected_nodes)}\n")
            log_file.write(f"Random seed: {random_seed}\n")
        log_file.write(f"Total time steps: {time_steps}\n")
        log_file.write(f"Total nodes: {num_nodes}\n")
        log_file.flush()

        def log_message(message):
            log_file.write(f"{_timestamp()} - {message}\n")
            log_file.flush()

        with tqdm(total=progress_total, desc="Processing progress", file=sys.stdout) as pbar:
            # Reflect work done by earlier sessions in the progress bar.
            already_done = min(start_time_idx, progress_total)
            if already_done > 0:
                pbar.update(already_done)

            for current_idx in range(start_time_idx, num_samples):
                pbar.update(1)

                # Only the first resumed step may start part-way through the nodes.
                first_node = start_node_pos if current_idx == start_time_idx else 0

                for node_idx in selected_nodes[first_node:]:
                    instruction_data = traffic_data[:current_idx+1, node_idx].tolist()
                    answer_data = traffic_data[current_idx+1:current_idx+1+prediction_length, node_idx].tolist()

                    if len(answer_data) != prediction_length:
                        continue

                    # NOTE: history and target are scaled with separate min/max values.
                    instruction_str = _format_series(_min_max_scale(instruction_data))
                    answer_str = _format_series(_min_max_scale(answer_data))

                    sample = {
                        "instruction": f"Given historical data of traffic flow {instruction_str}, Predict the traffic flow in the next {prediction_length} time steps.",
                        "input": "",
                        "output": answer_str
                    }

                    dataset.append(sample)
                    unsaved += 1
                    total_new_samples += 1

                    if unsaved >= batch_size:
                        save_json(output_path, dataset)
                        log_message(f"Current batch saved, total samples: {len(dataset)}")
                        pbar.set_description(f"Processing progress (Saved: {len(dataset)} samples)")
                        unsaved = 0

        if unsaved:
            save_json(output_path, dataset)
            log_message(f"Last batch of data saved, total samples: {len(dataset)}")

        log_file.write(f"===== Data processing session ended {_timestamp()} =====\n")
        log_file.write(f"Total new samples generated: {total_new_samples}\n")
        log_file.write(f"Total samples: {len(dataset)}\n")

    print(f"\nExtraction completed, generated {total_new_samples} new samples, total samples: {len(dataset)}")
    print(f"Data processing completed, file saved to {output_path}")
    print(f"Detailed log saved to {output_path}.log")
    return dataset

def save_json(output_path, dataset):
    """Write *dataset* to *output_path* as JSON without risking the old file.

    The data is first serialised to ``output_path + ".temp"`` and then moved
    over the destination in a single step, so an interrupted save leaves the
    previous file intact.

    Args:
        output_path: Destination JSON file.
        dataset: JSON-serialisable object to store.

    Raises:
        Exception: Whatever serialisation or file error occurred is re-raised
            after a line is appended to ``output_path + ".log"`` and the
            temporary file is removed.
    """
    temp_path = output_path + ".temp"
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, ensure_ascii=False, indent=2)

        # os.replace overwrites an existing destination in one operation on
        # every platform; a separate remove-then-rename would leave a window
        # in which no dataset file exists.
        os.replace(temp_path, output_path)
    except Exception as e:
        try:
            with open(output_path + ".log", "a", encoding="utf-8") as log:
                log.write(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - Error saving data: {e}\n")
        except OSError:
            # Logging is best-effort; it must not hide the original failure.
            pass
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise

import datetime
# Captured once, when this module is loaded. Formatting this value always
# yields the load time, not the time at which the formatting happens.
import_time = datetime.datetime.now()

if __name__ == "__main__":
    # Script entry point: run one extraction with the settings below.
    # NOTE(review): the output directory starts with "/./", i.e. it is rooted
    # at the filesystem root - confirm this is the intended location.
    target_dir = "/./Fine_tunning/Data_processing_and_data"

    run_options = {
        # Input archive and output JSON location.
        "data_path": "./Orion/data/PEMS03.npz",
        "output_path": os.path.join(target_dir, "pems_dataset.json"),
        # Extraction settings.
        "feature_idx": 0,
        "prediction_length": 12,
        "sample_interval": 5,
        "batch_size": 5000,
        # Node sub-sampling settings.
        "sample_nodes": True,
        "node_count": 2,
        "random_seed": 42,
    }

    extract_traffic_data(**run_options)