import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
import configparser
import argparse
import logging
from datetime import datetime
import time
from pathlib import Path
from tqdm import tqdm
from scipy.interpolate import interp1d
from scipy import stats as sp_stats
from scipy.ndimage import median_filter
from scipy.signal import savgol_filter
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, List, Tuple, Optional
import json
import pickle
import warnings
warnings.filterwarnings('ignore')

try:
    import pywt
    PYWT_AVAILABLE = True
except ImportError:
    PYWT_AVAILABLE = False
    print("Warning: PyWavelets not installed. Wavelet denoising will be disabled.")

class LLMEmbeddingEngine:
    """Turn numeric time series into LLM hidden-state embeddings and compare them.

    The tokenizer/model pair is loaded lazily on the first call to
    ``get_embeddings``, so constructing an engine is cheap.
    """

    def __init__(self, model_path: str, device: str = 'cuda'):
        """Remember the model location; no weights are loaded here.

        Args:
            model_path: Path or hub id passed to ``from_pretrained``.
            device: Preferred torch device; ``'cpu'`` is used whenever CUDA
                is not available.
        """
        self.model_path = model_path
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.tokenizer = None
        self.model = None
        # cache_key -> embedding array previously returned by get_embeddings
        self.embedding_cache = {}

    def _load_model(self):
        """Load tokenizer and model once (bfloat16 weights, ``device_map="auto"``)."""
        if self.model is None:
            print("Loading LLM model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                output_hidden_states=True
            )
            self.model.eval()

    def get_embeddings(self, time_series: np.ndarray, cache_key: Optional[str] = None) -> np.ndarray:
        """Return last-layer hidden states for a 1-D numeric series.

        Values are rendered with 4 decimals, joined by spaces and fed to the
        model as a single prompt.

        Args:
            time_series: Iterable of numbers (one series).
            cache_key: When given, results are memoized under this key and a
                cached entry is returned without touching the model.

        Returns:
            float32 array of the last hidden layer with the batch dimension
            of the single prompt squeezed away (tokens x hidden size).

        Raises:
            ValueError: if ``time_series`` is empty.
        """
        if cache_key and cache_key in self.embedding_cache:
            return self.embedding_cache[cache_key]
        if len(time_series) == 0:
            # Avoid loading the model only to feed it an empty prompt.
            raise ValueError("time_series must contain at least one value")

        self._load_model()

        combined_input = " ".join(f"{val:.4f}" for val in time_series)

        # A single prompt never needs padding, and requesting it raises for
        # tokenizers that define no pad token (e.g. GPT-2).
        inputs = self.tokenizer(combined_input, return_tensors="pt", truncation=True)
        # device_map="auto" decides where the weights live; follow the model.
        model_device = getattr(self.model, 'device', self.device)
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1].squeeze(0)
            embeddings = last_hidden_state.float().cpu().numpy()

        if cache_key:
            self.embedding_cache[cache_key] = embeddings

        return embeddings

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray, cosine_weight: float = 0.7) -> float:
        """Blend cosine and Euclidean similarity of mean-pooled embeddings.

        Both inputs are averaged over axis 0. Cosine similarity is rescaled
        from [-1, 1] to [0, 1]; Euclidean distance ``d`` becomes ``1 / (1 + d)``.

        Returns:
            ``cosine_weight * cos01 + (1 - cosine_weight) * euclid``.
        """
        vec1 = np.mean(emb1, axis=0)
        vec2 = np.mean(emb2, axis=0)

        # 1e-8 guards against division by zero for all-zero vectors.
        cosine_sim = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)
        cosine_sim = (cosine_sim + 1) / 2

        euclidean_dist = np.linalg.norm(vec1 - vec2)
        euclidean_sim = 1 / (1 + euclidean_dist)

        return cosine_weight * cosine_sim + (1 - cosine_weight) * euclidean_sim

class OrionDataPreprocessor:
    
    def __init__(self, config_path: str):
        self.config = self._load_config(config_path)
        self._setup_logging()
        
        self.llm_engine = None
        if self.config.get('use_llm_embedding', False):
            self.llm_engine = LLMEmbeddingEngine(
                self.config['llm_model_path'],
                device='cuda'
            )
        
        self.use_robust_normalization = self.config.get('use_robust_normalization', True)
        self.outlier_method = self.config.get('outlier_method', None)
        self.denoise_method = self.config.get('denoise_method', None)
    
    def _load_config(self, config_path: str) -> dict:
        config = configparser.ConfigParser()
        config.read(config_path)
        
        data_config = dict(config['Data'])
        
        bool_keys = ['use_llm_embedding', 'normalize_adj', 'use_robust_normalization']
        for key in bool_keys:
            if key in data_config:
                data_config[key] = config.getboolean('Data', key)
        
        return data_config
    
    def _setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(message)s',
            handlers=[logging.StreamHandler()]
        )
    
    def detect_and_handle_outliers(self, data: np.ndarray, method='iqr') -> np.ndarray:
        if method is None or method.lower() == 'none':
            return data
        
        if method == 'iqr':
            q1 = np.percentile(data, 25)
            q3 = np.percentile(data, 75)
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr
            
            data_cleaned = np.where(data < lower_bound, lower_bound, data)
            data_cleaned = np.where(data_cleaned > upper_bound, upper_bound, data_cleaned)
        
        elif method == 'rolling_median':
            original_shape = data.shape
            data_flat = data.flatten()
            data_cleaned = median_filter(data_flat, size=5)
            data_cleaned = data_cleaned.reshape(original_shape)
        else:
            return data
        
        return data_cleaned
    
    def denoise_data(self, data: np.ndarray, method='savgol') -> np.ndarray:
        if method is None or method.lower() == 'none':
            return data
        
        if method == 'wavelet' and PYWT_AVAILABLE:
            wavelet = 'db4'
            level = 3
            
            denoised = np.zeros_like(data)
            for i in range(data.shape[1]):
                for j in range(data.shape[2]):
                    signal = data[:, i, j]
                    
                    coeffs = pywt.wavedec(signal, wavelet, level=level)
                    
                    sigma = np.median(np.abs(coeffs[-1])) / 0.6745
                    threshold = sigma * np.sqrt(2 * np.log(len(signal)))
                    coeffs_denoised = [pywt.threshold(c, threshold, 'soft') for c in coeffs]
                    
                    denoised[:, i, j] = pywt.waverec(coeffs_denoised, wavelet)[:len(signal)]
            
            return denoised
        
        elif method == 'savgol':
            denoised = np.zeros_like(data)
            for i in range(data.shape[1]):
                for j in range(data.shape[2]):
                    signal = data[:, i, j]
                    window_length = min(7, len(signal))
                    if window_length % 2 == 0:
                        window_length -= 1
                    if window_length >= 3:
                        denoised[:, i, j] = savgol_filter(signal, window_length=window_length, polyorder=min(2, window_length-1))
                    else:
                        denoised[:, i, j] = signal
            
            return denoised
        else:
            return data
    
    def handle_missing_values(self, data: np.ndarray) -> np.ndarray:
        mask = np.isnan(data) | np.isinf(data)
        missing_ratio = np.sum(mask) / data.size
        
        if missing_ratio > 0:
            for i in range(data.shape[1]):
                for j in range(data.shape[2]):
                    series = data[:, i, j]
                    if np.any(np.isnan(series)):
                        valid_idx = ~np.isnan(series)
                        if np.sum(valid_idx) > 1:
                            f = interp1d(np.where(valid_idx)[0], series[valid_idx], 
                                       kind='linear', fill_value='extrapolate')
                            data[:, i, j] = f(np.arange(len(series)))
                        else:
                            data[:, i, j] = np.nanmean(series)
        
        return data
    
    def validate_data_quality(self, data: Dict) -> Dict:
        quality_report = {}
        
        for key, value in data.items():
            if key != 'time_indices' and isinstance(value, np.ndarray):
                nan_ratio = np.isnan(value).sum() / value.size
                inf_ratio = np.isinf(value).sum() / value.size
                
                flat_data = value.flatten()
                valid_data = flat_data[~(np.isnan(flat_data) | np.isinf(flat_data))]
                
                if len(valid_data) > 0:
                    skewness = sp_stats.skew(valid_data)
                    kurtosis = sp_stats.kurtosis(valid_data)
                    
                    quality_report[key] = {
                        'nan_ratio': nan_ratio,
                        'inf_ratio': inf_ratio,
                        'skewness': skewness,
                        'kurtosis': kurtosis,
                        'needs_transform': abs(skewness) > 2
                    }
        
        return quality_report
    
    def load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        print(f"Loading data from {self.config['graph_signal_matrix_filename']}")
        
        data_file = self.config['graph_signal_matrix_filename']
        data = np.load(data_file)
        
        possible_keys = ['data', 'graph_signal_matrix', 'flow_data', 'traffic_data']
        data_seq = None
        
        for key in possible_keys:
            if key in data:
                data_seq = data[key]
                break
        
        if data_seq is None:
            data_seq = data[list(data.keys())[0]]
        
        print(f"Data shape: {data_seq.shape}")
        
        data_seq = self.handle_missing_values(data_seq)
        
        if self.outlier_method and self.outlier_method.lower() != 'none':
            for f in range(data_seq.shape[2]):
                data_seq[:, :, f] = self.detect_and_handle_outliers(
                    data_seq[:, :, f], 
                    method=self.outlier_method
                )
        
        if self.denoise_method and self.denoise_method.lower() != 'none':
            data_seq = self.denoise_data(data_seq, method=self.denoise_method)
        
        time_indices = self._generate_time_indices(len(data_seq))
        
        quality_report = self.validate_data_quality({'data': data_seq})
        
        adj_matrix = self._load_adj_matrix()
        
        return data_seq, time_indices, adj_matrix
    
    def _generate_time_indices(self, total_timesteps: int) -> np.ndarray:
        time_segments = []
        for segment in self.config['time_segments'].split(';'):
            start, end = map(int, segment.split(','))
            time_segments.append((start, end))
        
        time_indices = np.zeros(total_timesteps, dtype=np.int32)
        points_per_hour = int(self.config['points_per_hour'])
        timesteps_per_day = int(self.config['timesteps_per_day'])
        
        for t in range(total_timesteps):
            hour = (t % timesteps_per_day) // points_per_hour
            
            for seg_idx, (start, end) in enumerate(time_segments):
                if start <= hour < end or (start > end and (hour >= start or hour < end)):
                    time_indices[t] = seg_idx
                    break
        
        return time_indices
    
    def _estimate_gaussian_sigma(self, distances: np.ndarray, strategy: str = 'median') -> float:
        distances = np.asarray(distances)
        distances = distances[np.isfinite(distances)]
        distances = distances[distances > 0]

        if distances.size == 0:
            return 1.0

        strategy = (strategy or 'median').strip().lower()
        if strategy == 'p75':
            sigma = float(np.percentile(distances, 75))
        elif strategy == 'std':
            sigma = float(np.std(distances))
        elif strategy == 'iqr':
            q1 = np.percentile(distances, 25)
            q3 = np.percentile(distances, 75)
            iqr = q3 - q1
            sigma = float(iqr / 1.349) if iqr > 0 else float(np.median(distances))
        else:
            sigma = float(np.median(distances))

        if not np.isfinite(sigma) or sigma <= 1e-12:
            sigma = 1.0

        return sigma

    def _load_adj_matrix(self) -> np.ndarray:
        adj_file = self.config['adj_filename']
        
        if adj_file.endswith('.csv'):
            try:
                adj_df = pd.read_csv(adj_file)
                num_nodes = int(self.config['num_of_vertices'])
                
                adj_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
                
                if adj_df.empty or not all(col in adj_df.columns for col in ['from', 'to']):
                    return adj_matrix
                
                valid_rows = adj_df.dropna(subset=['from', 'to'])
                if valid_rows.empty:
                    return adj_matrix
                
                sigma_cfg = self.config.get('gaussian_sigma', None)
                auto_strategy = self.config.get('gaussian_sigma_auto_strategy', 'median')
                
                if 'distance' not in adj_df.columns:
                    for _, row in valid_rows.iterrows():
                        try:
                            i, j = int(row['from']), int(row['to'])
                            if 0 <= i < num_nodes and 0 <= j < num_nodes:
                                adj_matrix[i, j] = 1.0
                        except (ValueError, TypeError):
                            continue
                            
                    if self.config.get('normalize_adj', False):
                        adj_matrix = self._normalize_adj(adj_matrix).astype(np.float32, copy=False)
                    return adj_matrix
                
                valid_distance_rows = adj_df.dropna(subset=['from', 'to', 'distance'])
                if valid_distance_rows.empty:
                    for _, row in valid_rows.iterrows():
                        try:
                            i, j = int(row['from']), int(row['to'])
                            if 0 <= i < num_nodes and 0 <= j < num_nodes:
                                adj_matrix[i, j] = 1.0
                        except (ValueError, TypeError):
                            continue
                            
                    if self.config.get('normalize_adj', False):
                        adj_matrix = self._normalize_adj(adj_matrix).astype(np.float32, copy=False)
                    return adj_matrix
                
                distances = valid_distance_rows['distance'].to_numpy()
                
                use_auto = False
                sigma_value = None
                if sigma_cfg is None:
                    use_auto = True
                else:
                    try:
                        sigma_value = float(sigma_cfg)
                        if sigma_value <= 0:
                            use_auto = True
                    except (TypeError, ValueError):
                        if isinstance(sigma_cfg, str) and sigma_cfg.strip().lower() == 'auto':
                            use_auto = True
                        else:
                            use_auto = True
                
                if use_auto:
                    sigma = self._estimate_gaussian_sigma(distances, strategy=auto_strategy)
                else:
                    sigma = float(sigma_value)
                
                for _, row in valid_distance_rows.iterrows():
                    try:
                        i, j = int(row['from']), int(row['to'])
                        distance = float(row['distance'])
                        if 0 <= i < num_nodes and 0 <= j < num_nodes:
                            adj_matrix[i, j] = np.exp(-(distance ** 2) / (sigma ** 2))
                    except (ValueError, TypeError):
                        continue
                
                if self.config.get('normalize_adj', False):
                    adj_matrix = self._normalize_adj(adj_matrix).astype(np.float32, copy=False)
                    
            except Exception as e:
                print(f"Error reading CSV file: {e}")
                num_nodes = int(self.config['num_of_vertices'])
                adj_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
                
        elif adj_file.endswith('.npz'):
            data = np.load(adj_file)
            adj_matrix = data['adj_matrix'] if 'adj_matrix' in data else data[list(data.keys())[0]]
            if adj_matrix.dtype != np.float32:
                adj_matrix = adj_matrix.astype(np.float32, copy=False)
        else:
            num_nodes = int(self.config['num_of_vertices'])
            adj_matrix = (np.ones((num_nodes, num_nodes), dtype=np.float32) - 
                        np.eye(num_nodes, dtype=np.float32))
        
        return adj_matrix

    def _normalize_adj(self, adj: np.ndarray) -> np.ndarray:
        d = np.sum(adj, axis=1)
        d_inv_sqrt = np.power(d, -0.5)
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = np.diag(d_inv_sqrt)
        return d_mat_inv_sqrt @ adj @ d_mat_inv_sqrt
    
    def extract_periodic_samples_all(self, data: np.ndarray, time_indices: np.ndarray) -> Dict:
        print("Extracting periodic samples...")
        
        source_len = int(self.config['source_len'])
        target_len = int(self.config['target_len'])
        num_weeks = int(self.config['num_of_weeks'])
        num_days = int(self.config['num_of_days'])
        timesteps_per_week = int(self.config['timesteps_per_week'])
        timesteps_per_day = int(self.config['timesteps_per_day'])
        
        min_history = max(
            source_len,
            source_len + num_weeks * timesteps_per_week,
            source_len + num_days * timesteps_per_day
        )
        
        valid_start = min_history
        valid_end = len(data) - target_len + 1
        num_samples = valid_end - valid_start
        
        print(f"Number of samples: {num_samples}")
        
        x_h_list, x_w_list, x_d_list = [], [], []
        target_list, time_indices_list = [], []
        
        for t in tqdm(range(valid_start, valid_end), desc="Processing"):
            x_h = data[t-source_len:t].transpose(1, 2, 0)
            x_h_list.append(x_h)
            
            target = data[t:t + target_len, :, 0].transpose(1, 0)
            target_list.append(target)
            
            target_time_indices = time_indices[t:t + target_len]
            time_indices_list.append(target_time_indices)
            
            x_w_sample = self._extract_weekly_data_simple(
                data, t, source_len, num_weeks, 
                timesteps_per_week, x_h if self.llm_engine else None
            )
            x_w_list.append(x_w_sample)
            
            x_d_sample = self._extract_daily_data_simple(
                data, t, source_len, num_days,
                timesteps_per_day, x_h if self.llm_engine else None
            )
            x_d_list.append(x_d_sample)
        
        x_h = np.array(x_h_list)
        x_w = np.array(x_w_list)
        x_d = np.array(x_d_list)
        target = np.array(target_list)
        time_indices_array = np.array(time_indices_list)
        
        return {
            'x_h': x_h,
            'x_w': x_w,
            'x_d': x_d,
            'target': target,
            'time_indices': time_indices_array
        }
    
    def _extract_weekly_data_simple(self, data: np.ndarray, current_t: int,
                                   source_len: int, num_weeks: int,
                                   timesteps_per_week: int, 
                                   current_x_h: Optional[np.ndarray]) -> np.ndarray:
        N, F = data.shape[1], data.shape[2]
        
        if self.llm_engine and current_x_h is not None:
            x_w_sample = np.zeros((N, F, source_len))
            
            for node_idx in range(N):
                current_ts = current_x_h[node_idx, 0, :]
                current_emb = self.llm_engine.get_embeddings(current_ts)
                
                week_similarities = []
                for week_offset in range(1, num_weeks + 1):
                    hist_start = current_t - week_offset * timesteps_per_week - source_len
                    if hist_start >= 0:
                        hist_data = data[hist_start:hist_start + source_len, node_idx, 0]
                        hist_emb = self.llm_engine.get_embeddings(hist_data)
                        similarity = self.llm_engine.compute_similarity(
                            current_emb, hist_emb,
                            float(self.config.get('similarity_cosine_weight', 0.7))
                        )
                        week_similarities.append((week_offset, similarity))
                
                best_week = max(week_similarities, key=lambda x: x[1])[0] if week_similarities else 1
                
                window_start = current_t - best_week * timesteps_per_week - source_len
                window_end = window_start + source_len
                
                if window_start >= 0 and window_end <= len(data):
                    x_w_sample[node_idx, :, :] = data[window_start:window_end, node_idx, :].T
                else:
                    x_w_sample[node_idx, :, :] = data[current_t-source_len:current_t, node_idx, :].T
        else:
            window_start = current_t - timesteps_per_week - source_len
            window_end = window_start + source_len
            
            if window_start >= 0 and window_end <= len(data):
                x_w_sample = data[window_start:window_end, :, :].transpose(1, 2, 0)
            else:
                x_w_sample = data[current_t-source_len:current_t, :, :].transpose(1, 2, 0)
        
        return x_w_sample
    
    def _extract_daily_data_simple(self, data: np.ndarray, current_t: int,
                                  source_len: int, num_days: int,
                                  timesteps_per_day: int, 
                                  current_x_h: Optional[np.ndarray]) -> np.ndarray:
        N, F = data.shape[1], data.shape[2]
        
        if self.llm_engine and current_x_h is not None:
            x_d_sample = np.zeros((N, F, source_len))
            
            for node_idx in range(N):
                current_ts = current_x_h[node_idx, 0, :]
                current_emb = self.llm_engine.get_embeddings(current_ts)
                
                day_similarities = []
                for day_offset in range(1, num_days + 1):
                    hist_start = current_t - day_offset * timesteps_per_day - source_len
                    if hist_start >= 0:
                        hist_data = data[hist_start:hist_start + source_len, node_idx, 0]
                        hist_emb = self.llm_engine.get_embeddings(hist_data)
                        similarity = self.llm_engine.compute_similarity(
                            current_emb, hist_emb,
                            float(self.config.get('similarity_cosine_weight', 0.7))
                        )
                        day_similarities.append((day_offset, similarity))
                
                best_day = max(day_similarities, key=lambda x: x[1])[0] if day_similarities else 1
                
                window_start = current_t - best_day * timesteps_per_day - source_len
                window_end = window_start + source_len
                
                if window_start >= 0 and window_end <= len(data):
                    x_d_sample[node_idx, :, :] = data[window_start:window_end, node_idx, :].T
                else:
                    x_d_sample[node_idx, :, :] = data[current_t-source_len:current_t, node_idx, :].T
        else:
            window_start = current_t - timesteps_per_day - source_len
            window_end = window_start + source_len
            
            if window_start >= 0 and window_end <= len(data):
                x_d_sample = data[window_start:window_end, :, :].transpose(1, 2, 0)
            else:
                x_d_sample = data[current_t-source_len:current_t, :, :].transpose(1, 2, 0)
        
        return x_d_sample
    
    def split_periodic_samples(self, samples: Dict) -> Tuple[Dict, Dict, Dict]:
        print("Splitting dataset...")
        
        train_ratio = float(self.config['train_ratio'])
        val_ratio = float(self.config['val_ratio'])
        
        total_samples = samples['x_h'].shape[0]
        
        train_size = int(total_samples * train_ratio)
        val_size = int(total_samples * val_ratio)
        
        train_data = {
            'x_h': samples['x_h'][:train_size],
            'x_w': samples['x_w'][:train_size],
            'x_d': samples['x_d'][:train_size],
            'target': samples['target'][:train_size],
            'time_indices': samples['time_indices'][:train_size]
        }
        
        val_data = {
            'x_h': samples['x_h'][train_size:train_size + val_size],
            'x_w': samples['x_w'][train_size:train_size + val_size],
            'x_d': samples['x_d'][train_size:train_size + val_size],
            'target': samples['target'][train_size:train_size + val_size],
            'time_indices': samples['time_indices'][train_size:train_size + val_size]
        }
        
        test_data = {
            'x_h': samples['x_h'][train_size + val_size:],
            'x_w': samples['x_w'][train_size + val_size:],
            'x_d': samples['x_d'][train_size + val_size:],
            'target': samples['target'][train_size + val_size:],
            'time_indices': samples['time_indices'][train_size + val_size:]
        }
        
        print(f"Train: {train_data['x_h'].shape[0]} samples")
        print(f"Val: {val_data['x_h'].shape[0]} samples")
        print(f"Test: {test_data['x_h'].shape[0]} samples")
        
        return train_data, val_data, test_data
    
    def compute_train_stats(self, train_data: Dict) -> Dict:
        stats = {'feature_stats': {}}
        num_features = train_data['x_h'].shape[2]
        
        for f in range(num_features):
            feature_data = []
            
            feature_data.append(train_data['x_h'][:, :, f, :].flatten())
            feature_data.append(train_data['x_w'][:, :, f, :].flatten())
            feature_data.append(train_data['x_d'][:, :, f, :].flatten())
            
            all_feature_data = np.concatenate(feature_data)
            
            if self.use_robust_normalization:
                median_val = np.median(all_feature_data)
                q1 = np.percentile(all_feature_data, 25)
                q3 = np.percentile(all_feature_data, 75)
                iqr = q3 - q1
                
                mean_val = np.mean(all_feature_data)
                std_val = np.std(all_feature_data)
                
                stats['feature_stats'][f] = {
                    'min': np.min(all_feature_data),
                    'max': np.max(all_feature_data),
                    'mean': mean_val,
                    'std': std_val,
                    'median': median_val,
                    'q1': q1,
                    'q3': q3,
                    'iqr': iqr
                }
            else:
                min_val = np.min(all_feature_data)
                max_val = np.max(all_feature_data)
                mean_val = np.mean(all_feature_data)
                std_val = np.std(all_feature_data)
                
                stats['feature_stats'][f] = {
                    'min': min_val,
                    'max': max_val,
                    'mean': mean_val,
                    'std': std_val
                }
        
        stats['min_flow'] = stats['feature_stats'][0]['min']
        stats['max_flow'] = stats['feature_stats'][0]['max']
        stats['mean_flow'] = stats['feature_stats'][0]['mean']
        stats['std_flow'] = stats['feature_stats'][0]['std']
        
        return stats
    
    def normalize_data_with_stats(self, data: Dict, stats: Dict, split_name: str) -> Dict:
        if self.use_robust_normalization:
            return self.normalize_data_robust(data, stats, split_name)
        else:
            normalized_data = {
                'x_h': data['x_h'].copy(),
                'x_w': data['x_w'].copy(),
                'x_d': data['x_d'].copy(),
                'target': data['target'].copy(),
                'time_indices': data['time_indices'].copy()
            }
            
            num_features = normalized_data['x_h'].shape[2]
            
            for f in range(num_features):
                min_val = stats['feature_stats'][f]['min']
                max_val = stats['feature_stats'][f]['max']
                
                if max_val - min_val > 1e-8:
                    normalized_data['x_h'][:, :, f, :] = (
                        normalized_data['x_h'][:, :, f, :] - min_val
                    ) / (max_val - min_val)
                    
                    normalized_data['x_w'][:, :, f, :] = (
                        normalized_data['x_w'][:, :, f, :] - min_val
                    ) / (max_val - min_val)
                    
                    normalized_data['x_d'][:, :, f, :] = (
                        normalized_data['x_d'][:, :, f, :] - min_val
                    ) / (max_val - min_val)
                else:
                    normalized_data['x_h'][:, :, f, :] = 0
                    normalized_data['x_w'][:, :, f, :] = 0
                    normalized_data['x_d'][:, :, f, :] = 0
            
            min_flow = stats['min_flow']
            max_flow = stats['max_flow']
            
            if max_flow - min_flow > 1e-8:
                normalized_data['target'] = (
                    normalized_data['target'] - min_flow
                ) / (max_flow - min_flow)
            else:
                normalized_data['target'] = 0
            
            for key in ['x_h', 'x_w', 'x_d', 'target']:
                normalized_data[key] = np.clip(normalized_data[key], 0, 1)
            
            return normalized_data
    
    def normalize_data_robust(self, data: Dict, stats: Dict, split_name: str) -> Dict:
        normalized_data = {
            'x_h': data['x_h'].copy(),
            'x_w': data['x_w'].copy(),
            'x_d': data['x_d'].copy(),
            'target': data['target'].copy(),
            'time_indices': data['time_indices'].copy()
        }
        
        num_features = normalized_data['x_h'].shape[2]
        
        for f in range(num_features):
            if 'median' in stats['feature_stats'][f]:
                median = stats['feature_stats'][f]['median']
                iqr = stats['feature_stats'][f]['iqr']
                
                if iqr > 1e-8:
                    normalized_data['x_h'][:, :, f, :] = (
                        normalized_data['x_h'][:, :, f, :] - median
                    ) / iqr
                    
                    normalized_data['x_w'][:, :, f, :] = (
                        normalized_data['x_w'][:, :, f, :] - median
                    ) / iqr
                    
                    normalized_data['x_d'][:, :, f, :] = (
                        normalized_data['x_d'][:, :, f, :] - median
                    ) / iqr
                else:
                    mean = stats['feature_stats'][f]['mean']
                    std = stats['feature_stats'][f]['std']
                    
                    if std > 1e-8:
                        normalized_data['x_h'][:, :, f, :] = (
                            normalized_data['x_h'][:, :, f, :] - mean
                        ) / std
                        
                        normalized_data['x_w'][:, :, f, :] = (
                            normalized_data['x_w'][:, :, f, :] - mean
                        ) / std
                        
                        normalized_data['x_d'][:, :, f, :] = (
                            normalized_data['x_d'][:, :, f, :] - mean
                        ) / std
                    else:
                        normalized_data['x_h'][:, :, f, :] = 0
                        normalized_data['x_w'][:, :, f, :] = 0
                        normalized_data['x_d'][:, :, f, :] = 0
            else:
                mean = stats['feature_stats'][f]['mean']
                std = stats['feature_stats'][f]['std']
                
                if std > 1e-8:
                    normalized_data['x_h'][:, :, f, :] = (
                        normalized_data['x_h'][:, :, f, :] - mean
                    ) / std
                    
                    normalized_data['x_w'][:, :, f, :] = (
                        normalized_data['x_w'][:, :, f, :] - mean
                    ) / std
                    
                    normalized_data['x_d'][:, :, f, :] = (
                        normalized_data['x_d'][:, :, f, :] - mean
                    ) / std
                else:
                    normalized_data['x_h'][:, :, f, :] = 0
                    normalized_data['x_w'][:, :, f, :] = 0
                    normalized_data['x_d'][:, :, f, :] = 0
        
        if self.use_robust_normalization and 'median' in stats['feature_stats'][0]:
            median_flow = stats['feature_stats'][0]['median']
            iqr_flow = stats['feature_stats'][0]['iqr']
            
            if iqr_flow > 1e-8:
                normalized_data['target'] = (
                    normalized_data['target'] - median_flow
                ) / iqr_flow
            else:
                mean_flow = stats['mean_flow']
                std_flow = stats['std_flow']
                if std_flow > 1e-8:
                    normalized_data['target'] = (
                        normalized_data['target'] - mean_flow
                    ) / std_flow
                else:
                    normalized_data['target'] = 0
        else:
            min_flow = stats['min_flow']
            max_flow = stats['max_flow']
            
            if max_flow - min_flow > 1e-8:
                normalized_data['target'] = (
                    normalized_data['target'] - min_flow
                ) / (max_flow - min_flow)
            else:
                normalized_data['target'] = 0
        
        return normalized_data
    
    def save_dataset(self, train_data: Dict, val_data: Dict, test_data: Dict,
                     train_stats: Dict, adj_matrix: np.ndarray):
        """Save the normalized splits, adjacency matrix and train stats as ``.npz``.

        The archive is written to
        ``<output_path>/<stem>_r<hours>_d<days>_w<weeks>_Orion.npz``. The name
        is built from the config keys ``graph_signal_matrix_filename``,
        ``num_of_hours``, ``num_of_days``, ``num_of_weeks`` and
        ``output_path``. The output directory is created if it is missing.

        Stored entries:
          * ``<split>_{x_h,x_w,x_d,target,time_indices}`` for train/val/test;
          * ``adj_matrix``;
          * ``train_{min,max,mean,std}_flow``;
          * ``train_feature_<i>_{min,max[,median,iqr],mean,std}`` for every
            entry of ``train_stats['feature_stats']``.

        A JSON summary is then written via ``self._save_dataset_info``.
        """
        cfg = self.config
        stem = Path(cfg['graph_signal_matrix_filename']).stem
        window_tag = f"r{cfg['num_of_hours']}_d{cfg['num_of_days']}_w{cfg['num_of_weeks']}"

        out_dir = Path(cfg['output_path'])
        out_dir.mkdir(parents=True, exist_ok=True)
        npz_path = out_dir / f"{stem}_{window_tag}_Orion.npz"

        # Array entries: one group per split, in train/val/test order.
        fields = ('x_h', 'x_w', 'x_d', 'target', 'time_indices')
        splits = (('train', train_data), ('val', val_data), ('test', test_data))
        payload = {
            f"{split_name}_{field}": split[field]
            for split_name, split in splits
            for field in fields
        }
        payload['adj_matrix'] = adj_matrix
        for agg in ('min', 'max', 'mean', 'std'):
            payload[f"train_{agg}_flow"] = train_stats[f"{agg}_flow"]

        # Per-feature statistics; median/IQR are optional and stored only when present.
        for feat_idx, feat in train_stats['feature_stats'].items():
            names = ['min', 'max']
            if 'median' in feat:
                names.extend(('median', 'iqr'))
            names.extend(('mean', 'std'))
            for name in names:
                payload[f"train_feature_{feat_idx}_{name}"] = feat[name]

        np.savez_compressed(npz_path, **payload)
        print(f"Dataset saved to: {npz_path}")

        self._save_dataset_info(out_dir, stem, train_data, val_data, test_data, train_stats)
    
    def _save_dataset_info(self, output_dir: Path, base_name: str,
                           train_data: Dict, val_data: Dict, test_data: Dict,
                           train_stats: Dict):
        """Write a JSON sidecar describing the saved dataset.

        The file ``<output_dir>/<base_name>_dataset_info.json`` records:
          * the creation time;
          * ``self.config``;
          * the preprocessing options;
          * the ``.shape`` of every entry in each split;
          * the training statistics used to normalize all splits.

        Args:
            output_dir: Existing directory the JSON file is written into.
            base_name: Dataset stem used to build the file name.
            train_data: Train split; every value must expose ``.shape``.
            val_data: Validation split; every value must expose ``.shape``.
            test_data: Test split; every value must expose ``.shape``.
            train_stats: Training-set statistics (possibly nested dict).
        """
        def _json_default(obj):
            # json cannot encode NumPy scalars (np.float32, np.int64, ...) or
            # arrays. Falling straight back to ``str`` stored numeric statistics
            # as strings, so convert them to native Python numbers/lists first
            # and stringify only objects json truly has no mapping for.
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return str(obj)

        info_file = output_dir / f"{base_name}_dataset_info.json"

        dataset_info = {
            'creation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'config': self.config,
            'preprocessing': {
                'use_robust_normalization': self.use_robust_normalization,
                'outlier_method': self.outlier_method,
                'denoise_method': self.denoise_method
            },
            'data_shapes': {
                'train': {k: v.shape for k, v in train_data.items()},
                'val': {k: v.shape for k, v in val_data.items()},
                'test': {k: v.shape for k, v in test_data.items()}
            },
            'statistics': {
                'train': train_stats,
                'note': 'All datasets normalized using training set statistics'
            }
        }

        with open(info_file, 'w', encoding='utf-8') as f:
            json.dump(dataset_info, f, indent=2, default=_json_default)

        print(f"Dataset info saved to: {info_file}")
    
    def process(self):
        """Run the full preprocessing pipeline end to end.

        The steps are:
          1. load the raw data;
          2. extract periodic samples;
          3. split them into train/val/test;
          4. compute statistics on the training split only;
          5. normalize every split with those training statistics;
          6. save everything.

        The total wall-clock time is printed in minutes.
        """
        t0 = time.time()

        sequence, indices, adjacency = self.load_data()
        samples = self.extract_periodic_samples_all(sequence, indices)
        train_split, val_split, test_split = self.split_periodic_samples(samples)
        stats = self.compute_train_stats(train_split)

        print("Normalizing data...")

        # Val/test are normalized with *training* statistics to avoid leakage.
        tagged_splits = ((train_split, 'train'), (val_split, 'val'), (test_split, 'test'))
        normalized = [
            self.normalize_data_with_stats(split, stats, tag)
            for split, tag in tagged_splits
        ]

        self.save_dataset(*normalized, stats, adjacency)

        elapsed_min = (time.time() - t0) / 60
        print(f"Processing complete! Total time: {elapsed_min:.2f} minutes")

def main(argv: Optional[List[str]] = None):
    """Command-line entry point for Orion data preprocessing.

    Parses ``--config`` (required) and runs ``OrionDataPreprocessor`` on it.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            previous behaviour. Passing a list allows programmatic
            invocation and testing.
    """
    parser = argparse.ArgumentParser(description='Orion Data Preprocessing')
    parser.add_argument('--config', type=str, required=True, help='Configuration file path')
    args = parser.parse_args(argv)

    preprocessor = OrionDataPreprocessor(args.config)
    preprocessor.process()

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()