import numpy as np
import pandas as pd
import os
import time
import argparse
from tqdm import tqdm
import networkx as nx
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import warnings
warnings.filterwarnings('ignore')

def analyze_graph_properties(adj_df, num_nodes_total, seed=42):
    """Build a directed graph from an edge list and compute per-node metrics.

    Args:
        adj_df: Edge-list DataFrame, read positionally: column 0 is the
            source node ID, column 1 the target node ID and column 2 a
            cost-like value that is inverted to obtain the edge weight.
        num_nodes_total: Number of nodes. Nodes ``0 .. num_nodes_total - 1``
            are always present in the graph, even when they have no edges.
        seed: Seed for the pivot sampling done by the approximate
            betweenness computation, so repeated runs give the same result.

    Returns:
        A tuple ``(G, metrics)``. ``G`` is the ``networkx.DiGraph``;
        ``metrics`` maps the names ``'degree'``, ``'betweenness'``,
        ``'closeness'``, ``'pagerank'`` and ``'causal_potential'`` to
        ``{node: value}`` dictionaries.
    """
    G = nx.DiGraph()
    G.add_nodes_from(range(num_nodes_total))

    # itertuples avoids building a Series per row, which iterrows does.
    for row in adj_df.itertuples(index=False, name=None):
        from_node = int(row[0])
        to_node = int(row[1])
        cost = row[2]
        # Inverse cost as weight; non-positive (or NaN) costs get weight 0.
        weight = 1.0 / float(cost) if cost > 0 else 0.0
        # Ignore edges touching IDs outside 0..num_nodes_total-1. The lower
        # bound matters: a negative ID would otherwise add a stray node.
        if 0 <= from_node < num_nodes_total and 0 <= to_node < num_nodes_total:
            G.add_edge(from_node, to_node, weight=weight)

    in_degree_cent = nx.in_degree_centrality(G)
    out_degree_cent = nx.out_degree_centrality(G)
    # Single degree score: the mean of in- and out-degree centrality.
    degree_cent = {node: (in_degree_cent[node] + out_degree_cent[node]) / 2
                   for node in G.nodes()}

    # Approximate betweenness from at most 100 sampled pivot nodes; the seed
    # keeps the sample (and therefore the node selection) reproducible.
    betweenness_cent = nx.betweenness_centrality(
        G, k=min(100, G.number_of_nodes()), seed=seed
    )

    # Best effort: fall back to zeros if closeness cannot be computed, but do
    # not swallow KeyboardInterrupt/SystemExit like a bare ``except`` would.
    try:
        closeness_cent = nx.closeness_centrality(G)
    except Exception:
        closeness_cent = {node: 0.0 for node in G.nodes()}

    pagerank = nx.pagerank(G, max_iter=100)

    # Degree imbalance in [0, 1]: 0 when in-degree equals out-degree (or the
    # node has no edges), 1 when the node is a pure source or a pure sink.
    causal_potential = {}
    for node in G.nodes():
        in_deg = G.in_degree(node)
        out_deg = G.out_degree(node)
        total_deg = in_deg + out_deg
        causal_potential[node] = (
            abs(in_deg - out_deg) / total_deg if total_deg > 0 else 0.0
        )

    return G, {
        'degree': degree_cent,
        'betweenness': betweenness_cent,
        'closeness': closeness_cent,
        'pagerank': pagerank,
        'causal_potential': causal_potential
    }

def select_nodes_kmeans_strategy(G, centrality_metrics, flow_data, num_nodes=40):
    """Pick representative nodes: cluster by flow statistics, keep the best per cluster.

    Each node is described by summary statistics of its time series, the
    nodes are clustered with K-means into ``num_nodes`` clusters, and the
    highest-scoring node of every non-empty cluster is selected. The score
    is a weighted sum of graph centralities (0.60 in total) and normalised
    flow mean/std (0.40 in total).

    If K-means leaves some clusters empty, the selection is topped up with
    the best-scoring nodes not yet chosen, so that ``num_nodes`` nodes are
    returned whenever that many nodes exist.

    Args:
        G: Graph whose node set decides which nodes get a non-zero score.
        centrality_metrics: Dict with the keys ``'degree'``,
            ``'betweenness'``, ``'closeness'``, ``'pagerank'`` and
            ``'causal_potential'``, each mapping node -> value.
        flow_data: Array indexed as ``[time, node, feature]``; feature 0 is
            the flow signal, features 1-2 (if present) add mean/std features.
        num_nodes: Number of clusters, i.e. number of nodes to select.

    Returns:
        ``(selected_nodes, scores, clustering_results)`` where
        ``selected_nodes`` is a sorted list of node indices, ``scores`` maps
        every node index to its score, and ``clustering_results`` holds the
        cluster labels, the scaled features, the cluster centres, per-cluster
        info and the list of top-up nodes under ``'fallback_nodes'``.
    """
    num_total_nodes = flow_data.shape[1]
    # Only features 1 and 2 contribute extra statistics (feature 0 is the flow).
    extra_feature_stop = min(flow_data.shape[2], 3)

    features = []
    flow_importance = np.zeros(num_total_nodes)
    flow_std = np.zeros(num_total_nodes)

    # One pass per node: the flow mean/std feed both the clustering features
    # and the score, so compute them once.
    for i in range(num_total_nodes):
        node_flow = flow_data[:, i, 0]
        mean_flow = np.mean(node_flow)
        std_flow = np.std(node_flow)
        flow_importance[i] = mean_flow
        flow_std[i] = std_flow

        node_features = [mean_flow, std_flow]
        node_features.extend(np.percentile(node_flow, q) for q in (25, 50, 75))
        node_features.append(np.max(node_flow) - np.min(node_flow))

        for feature_idx in range(1, extra_feature_stop):
            feature_data = flow_data[:, i, feature_idx]
            node_features.extend([np.mean(feature_data), np.std(feature_data)])

        features.append(node_features)

    features = np.array(features)

    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    kmeans = KMeans(n_clusters=num_nodes, random_state=42, n_init=10)
    clusters = kmeans.fit_predict(features_scaled)

    # Min-max normalise; the epsilon avoids division by zero when all nodes
    # share the same value.
    flow_importance = (flow_importance - flow_importance.min()) / (
        flow_importance.max() - flow_importance.min() + 1e-10)
    flow_std_norm = (flow_std - flow_std.min()) / (
        flow_std.max() - flow_std.min() + 1e-10)

    scores = {}
    for node in range(num_total_nodes):
        if node not in G:
            scores[node] = 0.0
            continue

        graph_score = (
            0.15 * centrality_metrics['degree'].get(node, 0) +
            0.15 * centrality_metrics['betweenness'].get(node, 0) +
            0.05 * centrality_metrics['closeness'].get(node, 0) +
            0.15 * centrality_metrics['pagerank'].get(node, 0) +
            0.10 * centrality_metrics['causal_potential'].get(node, 0)
        )
        flow_score = 0.3 * flow_importance[node] + 0.1 * flow_std_norm[node]
        scores[node] = graph_score + flow_score

    # Group node indices by cluster label in a single pass.
    members_by_cluster = {}
    for node, label in enumerate(clusters):
        members_by_cluster.setdefault(int(label), []).append(node)

    selected_nodes = []
    cluster_info = {}
    for cluster_id in range(num_nodes):
        cluster_nodes = members_by_cluster.get(cluster_id)
        if not cluster_nodes:
            continue

        # max() returns the first maximum, so ties go to the lowest node index.
        best_node = max(cluster_nodes, key=scores.__getitem__)
        selected_nodes.append(best_node)
        cluster_info[cluster_id] = {
            'selected_node': best_node,
            'cluster_size': len(cluster_nodes),
            'avg_score': np.mean([scores[n] for n in cluster_nodes]),
            'best_score': scores[best_node]
        }

    # Empty clusters (e.g. duplicate feature rows) would leave the selection
    # short of num_nodes; callers size their outputs by num_nodes, so fill the
    # gap with the best-scoring nodes that were not picked.
    fallback_nodes = []
    shortfall = num_nodes - len(selected_nodes)
    if shortfall > 0:
        already_chosen = set(selected_nodes)
        candidates = [n for n in range(num_total_nodes) if n not in already_chosen]
        candidates.sort(key=scores.__getitem__, reverse=True)
        fallback_nodes = candidates[:shortfall]
        selected_nodes.extend(fallback_nodes)

    selected_nodes.sort()

    clustering_results = {
        'clusters': clusters,
        'features_scaled': features_scaled,
        'cluster_centers': kmeans.cluster_centers_,
        'cluster_info': cluster_info,
        'fallback_nodes': fallback_nodes
    }

    return selected_nodes, scores, clustering_results

def create_subset_dataset(data_path, csv_path, output_dir, num_nodes=30, 
                         time_steps=3000, start_time_idx=0):
    """Write a node/time subset of an NPZ dataset and its edge-list CSV.

    The NPZ file must contain a ``'data'`` array indexed as
    ``[time, node, ...]``. The time window
    ``[start_time_idx, start_time_idx + time_steps)`` is clamped to the
    available length. If ``num_nodes`` covers all nodes, only the time
    slicing is applied and the original CSV is reused; otherwise nodes are
    selected with K-means, renumbered ``0 .. n-1``, and a filtered CSV plus
    a node-ID mapping text file are written as well.

    Args:
        data_path: Path of the source ``.npz`` file.
        csv_path: Path of the edge-list CSV (columns: from, to, cost).
        output_dir: Directory for all generated files (created if missing).
        num_nodes: Number of nodes to keep.
        time_steps: Requested number of time steps.
        start_time_idx: Index of the first time step to keep.

    Returns:
        ``(npz_output_path, csv_output_path)`` on success, or
        ``(None, None)`` after printing a message when loading fails, the
        ``'data'`` key is missing or the time window is empty.
    """
    print(f"Creating subset: {num_nodes} nodes, {time_steps} time steps")

    os.makedirs(output_dir, exist_ok=True)

    try:
        data = np.load(data_path)
    except Exception as e:
        print(f"Failed to load data: {e}")
        return None, None

    try:
        return _build_subset_outputs(data, csv_path, output_dir, num_nodes,
                                     time_steps, start_time_idx)
    finally:
        # NPZ archives keep a file handle open until closed; plain arrays
        # (from a .npy file) have no close() method.
        close = getattr(data, 'close', None)
        if close is not None:
            close()


def _build_subset_outputs(data, csv_path, output_dir, num_nodes, time_steps,
                          start_time_idx):
    """Produce the subset files from an already opened NPZ archive."""
    try:
        adj_df = pd.read_csv(csv_path)
    except Exception as e:
        print(f"Failed to load adjacency matrix: {e}")
        return None, None

    if 'data' not in data:
        print("Error: No 'data' key in NPZ file")
        return None, None
    original_data = data['data']

    total_steps = original_data.shape[0]
    total_nodes = original_data.shape[1]

    # Reject empty windows instead of silently writing an empty dataset with
    # a negative step count in its file name.
    if start_time_idx < 0 or start_time_idx >= total_steps or time_steps <= 0:
        print(f"Error: empty time window (start_time_idx={start_time_idx}, "
              f"time_steps={time_steps}, available steps={total_steps})")
        return None, None

    end_time_idx = min(start_time_idx + time_steps, total_steps)
    time_steps = end_time_idx - start_time_idx
    time_slice_data = original_data[start_time_idx:end_time_idx]

    if num_nodes >= total_nodes:
        # All nodes kept: only slice in time and reuse the original CSV.
        npz_output_path = os.path.join(output_dir, 
            f"metrla_subset_t{time_steps}_start{start_time_idx}.npz")

        save_dict = {
            key: time_slice_data if key == 'data' else data[key]
            for key in data.keys()
        }
        np.savez_compressed(npz_output_path, **save_dict)

        print(f"Saved: {npz_output_path}")

        return npz_output_path, csv_path

    G, centrality_metrics = analyze_graph_properties(adj_df, total_nodes)

    selected_nodes, _, _ = select_nodes_kmeans_strategy(
        G, centrality_metrics, time_slice_data, num_nodes
    )

    subset_data = time_slice_data[:, selected_nodes, :]

    # New IDs are the positions in the sorted selection.
    node_map = {old_id: new_id for new_id, old_id in enumerate(selected_nodes)}
    reverse_map = dict(enumerate(selected_nodes))

    # Keep only edges whose endpoints both survived, then renumber them.
    both_selected = (adj_df.iloc[:, 0].isin(selected_nodes) &
                     adj_df.iloc[:, 1].isin(selected_nodes))
    filtered_adj = adj_df[both_selected].copy()
    filtered_adj.iloc[:, 0] = filtered_adj.iloc[:, 0].map(node_map)
    filtered_adj.iloc[:, 1] = filtered_adj.iloc[:, 1].map(node_map)

    # Nodes without any remaining edge. Computed even when no edge survived,
    # because then every selected node is isolated.
    connected = set(filtered_adj.iloc[:, 0]) | set(filtered_adj.iloc[:, 1])
    isolated_nodes = [new_id for new_id in range(len(selected_nodes))
                      if new_id not in connected]

    node_mapping_path = os.path.join(output_dir, 
        f"node_id_mapping_n{num_nodes}_t{time_steps}.txt")
    _write_node_mapping(node_mapping_path, reverse_map, isolated_nodes,
                        num_nodes, time_steps, start_time_idx, end_time_idx)

    npz_output_path = os.path.join(output_dir, 
        f"metrla_subset_kmeans_n{num_nodes}_t{time_steps}.npz")
    csv_output_path = os.path.join(output_dir, 
        f"metrla_subset_kmeans_n{num_nodes}_t{time_steps}.csv")

    # NOTE: the two dicts are stored as pickled object arrays, so reading
    # 'node_mapping' / 'reverse_mapping' back needs np.load(allow_pickle=True).
    np.savez_compressed(npz_output_path, 
                       data=subset_data,
                       selected_nodes=np.array(selected_nodes),
                       node_mapping=node_map,
                       reverse_mapping=reverse_map)

    filtered_adj.to_csv(csv_output_path, index=False)

    print(f"Saved: {npz_output_path}")
    print(f"Saved: {csv_output_path}")

    return npz_output_path, csv_output_path


def _write_node_mapping(path, reverse_map, isolated_nodes, num_nodes,
                        time_steps, start_time_idx, end_time_idx):
    """Write the new_id -> original_id table, plus a note on isolated nodes."""
    lines = [
        "# Node ID Mapping (Continuous ID -> Original ID)\n",
        f"# Total selected nodes: {len(reverse_map)}\n",
        f"# Data subset: {num_nodes} nodes, {time_steps} time steps\n",
        f"# Time range: [{start_time_idx}, {end_time_idx})\n",
        "# Format: new_id,original_id\n",
        "#" + "="*50 + "\n",
    ]
    lines.extend(f"{new_id},{old_id}\n" for new_id, old_id in reverse_map.items())

    if isolated_nodes:
        lines.append("\n# Isolated nodes (no edges in adjacency matrix):\n")
        lines.append(f"# Total isolated nodes: {len(isolated_nodes)}\n")
        lines.append("# new_id -> original_id\n")
        lines.extend(f"# {new_id} -> {reverse_map[new_id]} (isolated)\n"
                     for new_id in isolated_nodes)

    with open(path, 'w') as f:
        f.writelines(lines)

def validate_subset(original_path, subset_path):
    """Print a shape comparison between the original and the subset dataset.

    Both files are NPZ archives with a ``'data'`` array. If the subset also
    stores ``'selected_nodes'``, the number of selected nodes is printed too.

    Args:
        original_path: Path of the original ``.npz`` file.
        subset_path: Path of the generated subset ``.npz`` file.
    """
    # Context managers close the archives' file handles once the arrays
    # have been read.
    with np.load(original_path) as original_data, \
            np.load(subset_path) as subset_data:
        original = original_data['data']
        subset = subset_data['data']

        print(f"Validation: Original {original.shape} -> Subset {subset.shape}")

        if 'selected_nodes' in subset_data:
            selected_nodes = subset_data['selected_nodes']
            print(f"Selected {len(selected_nodes)} nodes")

def compare_csv_files(csv1_path, csv2_path):
    """Print how many distinct node IDs each of two edge-list CSVs references.

    A node ID is any value that occurs in the first or the second column of
    the file.

    Args:
        csv1_path: Path of the first CSV (reported on the left of the arrow).
        csv2_path: Path of the second CSV (reported on the right).
    """
    frames = [pd.read_csv(path) for path in (csv1_path, csv2_path)]

    node_counts = []
    for frame in frames:
        sources = frame.iloc[:, 0].unique()
        targets = frame.iloc[:, 1].unique()
        node_counts.append(len(set(sources).union(targets)))

    count_before, count_after = node_counts
    print(f"CSV Comparison: {count_before} nodes -> {count_after} nodes")

def update_config_for_subset(config_path, new_config_path, npz_path, csv_path, num_nodes, is_full_nodes=False):
    """Write a copy of a config file that points at the generated subset.

    The values of ``adj_filename`` and ``graph_signal_matrix_filename`` are
    replaced by ``csv_path`` and ``npz_path``. Unless ``is_full_nodes`` is
    set, ``num_of_vertices`` is replaced by ``num_nodes`` as well. A short
    explanatory comment header is put in front of the content. Any error is
    reported with a printed message rather than raised.

    Args:
        config_path: Path of the configuration file to read.
        new_config_path: Path of the configuration file to write.
        npz_path: New value for ``graph_signal_matrix_filename``.
        csv_path: New value for ``adj_filename``.
        num_nodes: Number of nodes in the subset.
        is_full_nodes: True when all original nodes were kept.
    """
    import re

    def _set_value(text, pattern, value):
        # Callable replacement, so ``value`` is inserted literally and
        # backslashes in paths are not treated as regex escapes.
        return re.sub(pattern, lambda match: match.group(1) + value, text)

    try:
        with open(config_path, 'r') as source:
            content = source.read()

        content = _set_value(content, r'(adj_filename\s*=\s*)([^\n]+)', csv_path)
        content = _set_value(content, r'(graph_signal_matrix_filename\s*=\s*)([^\n]+)', npz_path)

        if is_full_nodes:
            header = (
                "\n# This configuration uses all nodes with time slicing\n"
                "# All original nodes preserved\n\n"
            )
        else:
            content = _set_value(content, r'(num_of_vertices\s*=\s*)\d+', str(num_nodes))
            header = (
                "\n# This configuration uses a K-means clustered subset\n"
                f"# Subset: {num_nodes} nodes (remapped to 0-{num_nodes-1})\n"
                "# See node_id_mapping file for original IDs\n\n"
            )

        with open(new_config_path, 'w') as target:
            target.write(header + content)

        print(f"Config saved: {new_config_path}")

    except Exception as err:
        print(f"Config update failed: {err}")

if __name__ == "__main__":
    # Command-line entry point: build the subset, optionally validate it and
    # compare the CSVs, then write a configuration file that points at it.
    parser = argparse.ArgumentParser(description='Create subset for metrla dataset')
    parser.add_argument('--data_path', type=str, default='./Orion/data/metrla.npz', 
                       help='metrla.npz file path')
    parser.add_argument('--csv_path', type=str, default='./Orion/data/metrla.csv',
                       help='metrla.csv file path')
    parser.add_argument('--output_dir', type=str, default='./Orion/data/subset',
                       help='Output directory')
    parser.add_argument('--config_path', type=str, 
                       default='./Orion/configurations/Orion_metrla_config.conf',
                       help='Configuration file path')
    parser.add_argument('--num_nodes', type=int, default=5, 
                       help='Number of nodes to keep')
    parser.add_argument('--time_steps', type=int, default=34272,
                       help='Number of time steps')
    parser.add_argument('--start_time_idx', type=int, default=0,
                       help='Starting time index')
    parser.add_argument('--validate', action='store_true',
                       help='Validate subset')
    parser.add_argument('--compare_csv', action='store_true',
                       help='Compare CSVs')
    
    args = parser.parse_args()
    
    os.makedirs(args.output_dir, exist_ok=True)
    
    npz_path, csv_path = create_subset_dataset(
        args.data_path, 
        args.csv_path, 
        args.output_dir,
        args.num_nodes,
        args.time_steps,
        args.start_time_idx
    )
    
    if npz_path is None:
        print("Failed!")
        # SystemExit works everywhere; the exit() helper is injected by the
        # site module and is missing e.g. under ``python -S``.
        raise SystemExit(1)
    
    if args.validate:
        validate_subset(args.data_path, npz_path)
    
    # In the all-nodes case the original CSV is returned unchanged, so there
    # is nothing to compare.
    if args.compare_csv and csv_path != args.csv_path:
        compare_csv_files(args.csv_path, csv_path)
    
    # Re-read the node count to decide which naming scheme applies; the
    # archive is closed again right away.
    with np.load(args.data_path) as data:
        original_num_nodes = data['data'].shape[1]
    is_full_nodes = (args.num_nodes >= original_num_nodes)
    
    # NOTE(review): the names below use the requested --time_steps, while the
    # NPZ/CSV names use the value after clamping to the data length; they
    # differ when the request exceeds the available steps.
    if is_full_nodes:
        new_config_name = f"Orion_metrla_subset_t{args.time_steps}_start{args.start_time_idx}.conf"
    else:
        new_config_name = f"Orion_metrla_subset_kmeans_n{args.num_nodes}_t{args.time_steps}.conf"
    
    new_config_path = os.path.join(args.output_dir, new_config_name)
    
    update_config_for_subset(args.config_path, new_config_path, npz_path, csv_path, 
                           args.num_nodes, is_full_nodes)
    
    print("\nCompleted!")