import os
import torch
import argparse
import numpy as np
import networkx as nx
import torch_geometric.data as pyg_data
import matplotlib.pyplot as plt
from scipy import stats
import networkx.algorithms.community as nx_comm
import pandas as pd
from datetime import datetime
from tqdm import tqdm
from scipy import stats

# Local imports
from dataset import CustomDataset
from model import NodeGCN, NodeGAT, NodeGraphConv
from memorization import calculate_node_memorization_score, plot_node_memorization_analysis
from nodeli import li_node
from onlysim2 import modify_graph
from main import (set_seed, train_models, verify_no_data_leakage, 
                setup_logging, get_model, test)
from torch_geometric.utils import to_networkx

# Edge budgets swept by main(): every value is evaluated as an add-only,
# a delete-only and a balanced (+budget/-budget) rewiring configuration.
EDGE_BUDGETS = [100, 500, 1000]

def parse_args():
    """Define the command-line interface and return the parsed namespace.

    The options fall into four groups: dataset/architecture selection,
    optimisation hyper-parameters, the compute device, and the edge budget
    used by the rewiring step. Arguments are read from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='Memorization Analysis with Graph Rewiring')
    add = cli.add_argument

    # Resolved once up front so the --device declaration stays short.
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- dataset and architecture ---
    add('--dataset', type=str, default='h0.00-r1',
        help='Name of the synthetic dataset (e.g., h0.00-r1)')
    add('--model_type', type=str, default='gcn',
        choices=['gcn', 'gat', 'graphconv', 'graphsage'],
        help='Type of GNN model to use')
    add('--hidden_dim', type=int, default=32, help='Hidden dimension size')
    add('--num_layers', type=int, default=3, help='Number of GNN layers')

    # --- optimisation ---
    add('--lr', type=float, default=0.01, help='Learning rate')
    add('--weight_decay', type=float, default=5e-4,
        help='Weight decay for optimizer')
    add('--epochs', type=int, default=100, help='Number of epochs to train')
    add('--seeds', type=int, nargs='+', default=[42, 123, 456],
        help='Random seeds for multiple runs')
    add('--gat_heads', type=int, default=4,
        help='Number of attention heads for GAT model')

    # --- hardware ---
    add('--device', type=str, default=default_device,
        help='Device to use for training')

    # --- rewiring ---
    add('--edges_to_add', type=int, default=100,
        help='Number of edges to add during rewiring')
    add('--edges_to_delete', type=int, default=0,
        help='Number of edges to delete during rewiring')

    return cli.parse_args()

def load_and_process_dataset(args, dataset_name, logger):
    """Read a syn-cora dataset from disk and wrap it in a PyG ``Data`` object.

    The dataset is loaded through ``CustomDataset`` from the ``syn-cora``
    directory located next to this file. The returned object holds features,
    labels, the edge index and boolean train/val/test masks, plus two extra
    attributes: ``informativeness`` (result of ``li_node``) and ``homophily``
    (fraction of edge-index entries whose endpoints share a label). A summary
    of the dataset is written to ``logger``.

    Args:
        args: Parsed CLI options (not read by this function).
        dataset_name: Dataset name handed to ``CustomDataset``.
        logger: Logger that receives the dataset statistics.

    Returns:
        The populated ``torch_geometric.data.Data`` instance.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    dataset = CustomDataset(root=os.path.join(script_dir, "syn-cora"),
                            name=dataset_name, setting="gcn")

    # Topology: one column per non-zero entry of the adjacency matrix.
    edge_index = torch.from_numpy(np.vstack(dataset.adj.nonzero())).long()

    # Features are used as-is when already a dense array, densified otherwise.
    features = dataset.features
    if not isinstance(features, np.ndarray):
        features = features.todense()
    x = torch.from_numpy(features).float()

    y = torch.from_numpy(dataset.labels).long()
    node_count = len(y)

    # Boolean membership masks for the three predefined splits.
    masks = {}
    for mask_name, split_idx in (('train_mask', dataset.idx_train),
                                 ('val_mask', dataset.idx_val),
                                 ('test_mask', dataset.idx_test)):
        mask = torch.zeros(node_count, dtype=torch.bool)
        mask[split_idx] = True
        masks[mask_name] = mask

    # NetworkX view of the graph, needed by the label-informativeness helper.
    G = nx.Graph()
    G.add_nodes_from(range(node_count))
    G.add_edges_from(edge_index.t().numpy())
    informativeness = li_node(G, dataset.labels)

    # Edge homophily: share of edge entries joining two equally-labelled nodes.
    edges = edge_index.t().numpy()
    sources, targets = edges[:, 0], edges[:, 1]
    homophily = (dataset.labels[sources] == dataset.labels[targets]).mean()

    data = pyg_data.Data(
        x=x,
        y=y,
        edge_index=edge_index,
        **masks,
        num_nodes=node_count
    )
    data.informativeness = informativeness
    data.homophily = homophily

    for line in (
        f"\nDataset Statistics:",
        f"Number of nodes: {data.num_nodes}",
        f"Number of edges: {len(edges)}",
        f"Number of features: {x.shape[1]}",
        f"Number of classes: {len(torch.unique(y))}",
        f"Homophily: {homophily:.4f}",
        f"Label Informativeness: {informativeness:.4f}",
    ):
        logger.info(line)

    return data

def run_memorization_pipeline(data, nodes_dict, args, device, logger, log_dir, 
                           seed, suffix="", timestamp=None):
    """Train for one seed and measure test accuracy plus node memorization.

    Args:
        data: Graph data object providing ``x``, ``edge_index``, ``y`` and
            ``test_mask``.
        nodes_dict: Mapping with ``'shared'``, ``'candidate'`` and
            ``'independent'`` node-index collections.
        args: Parsed CLI options, forwarded to ``train_models``.
        device: Device forwarded to training, testing and scoring.
        logger: Logger for progress and results.
        log_dir: Output directory forwarded to ``train_models``.
        seed: Random seed for this run.
        suffix: Label prepended to the log messages.
        timestamp: Accepted for interface compatibility; not used here.

    Returns:
        ``{'test_accuracy': {'mean', 'ci'}, 'memorization': {...}}`` where
        ``memorization`` has one ``{'mean', 'ci'}`` entry per scored node type
        plus a ``'rate'`` entry (percentage of candidate nodes whose score
        exceeds 0.5). All ``ci`` values are 0 because only one seed is run.
    """
    node_types = ('shared', 'candidate', 'independent')
    shared_idx, candidate_idx, independent_idx = (nodes_dict[key] for key in node_types)

    # The three node subsets must not overlap.
    verify_no_data_leakage(shared_idx, candidate_idx, independent_idx, logger)

    logger.info(f"\n{suffix} Training Model for seed {seed}...")
    set_seed(seed)

    # Restrict training to this single seed.
    models = train_models(
        args, data, shared_idx, candidate_idx, independent_idx,
        device, logger, log_dir, seeds=[seed]
    )

    # Each slot may be a single model or a list of models; take the last one.
    f_model, g_model = (
        slot[-1] if isinstance(slot, list) else slot
        for slot in (models[0], models[1])
    )

    test_acc = test(f_model, data.x, data.edge_index, data.test_mask, data.y, device)

    logger.info(f"\n{suffix} Calculating Memorization Scores for seed {seed}...")
    node_scores = calculate_node_memorization_score(
        f_model, g_model, data, nodes_dict, device, logger
    )

    # Mean memorization score for every node type that was scored.
    memorization_stats = {
        node_type: {'mean': np.mean(node_scores[node_type]['mem_scores']), 'ci': 0}
        for node_type in node_types
        if node_type in node_scores
    }

    # Memorization rate is derived from the candidate nodes only.
    memorization_rate, num_memorized, total_candidates = 0, 0, 0
    if 'candidate' in node_scores:
        candidate_scores = node_scores['candidate']['mem_scores']
        mem_threshold = 0.5  # a node counts as memorized above this score
        num_memorized = sum(score > mem_threshold for score in candidate_scores)
        total_candidates = len(candidate_scores)
        if total_candidates > 0:
            memorization_rate = (num_memorized / total_candidates) * 100
        else:
            memorization_rate = 0

    memorization_stats['rate'] = {'mean': memorization_rate, 'ci': 0}

    logger.info(f"\n{suffix} Results for seed {seed}:")
    logger.info(f"Test Accuracy: {test_acc*100:.2f}%")
    for node_type in node_types:
        if node_type not in memorization_stats:
            continue
        logger.info(f"{node_type.capitalize()} Node Memorization Score: {memorization_stats[node_type]['mean']:.4f}")
    logger.info(f"Memorization Rate: {memorization_rate:.2f}% ({num_memorized}/{total_candidates} nodes)")

    return {
        'test_accuracy': {'mean': test_acc*100, 'ci': 0},
        'memorization': memorization_stats
    }

def plot_comparison(original_scores, rewired_scores, save_path):
    """Save a two-bar chart of memorization rate: original vs. rewired graph.

    Args:
        original_scores: Result dict for the original graph. The bar is built
            from ``original_scores['memorization']['rate']``, which must hold
            ``'mean'`` and may hold ``'raw_values'`` (per-seed rates) and
            ``'ci'``.
        rewired_scores: Result dict for the rewired graph, same layout.
        save_path: Destination passed to ``plt.savefig``.

    Returns:
        None. Nothing is drawn when either input lacks the
        ``['memorization']['rate']`` entry.
    """
    rate_entries = []
    for scores in (original_scores, rewired_scores):
        entry = scores.get('memorization', {}).get('rate')
        if entry is None:
            return
        rate_entries.append(entry)

    def _error_bar(entry):
        # Per-seed values live next to the rate's mean/ci. Use their sample
        # standard deviation when it is defined (two or more values),
        # otherwise fall back to the stored confidence interval.
        raw = entry.get('raw_values')
        if raw is not None and len(raw) > 1:
            return float(np.std(raw, ddof=1))
        return entry.get('ci', 0)

    rates = [entry['mean'] for entry in rate_entries]
    errors = [_error_bar(entry) for entry in rate_entries]

    # Oversized canvas and fonts are kept from the original figure settings.
    plt.figure(figsize=(36, 32))
    x_pos = [0, 1]

    # Bars are drawn one by one so each gets its own legend handle.
    bar1 = plt.bar(x_pos[0], rates[0], color='skyblue',
                   yerr=errors[0], capsize=10, error_kw={'elinewidth': 3})
    bar2 = plt.bar(x_pos[1], rates[1], color='lightgreen',
                   yerr=errors[1], capsize=10, error_kw={'elinewidth': 3})

    # 'sans-serif' is the generic family name matplotlib resolves; the former
    # font='Sans Serif' named a family that does not exist.
    plt.ylabel('Memorization Rate (%)', fontsize=100, fontfamily='sans-serif')
    plt.xticks([])  # the legend identifies the bars, so no x ticks
    plt.yticks(fontsize=100)
    ax = plt.gca()
    ax.tick_params(axis='y', labelsize=100)
    plt.grid(axis='y', alpha=0.3)

    # Legend sits in the upper-right corner of the axes.
    plt.legend([bar1, bar2], ['Original', 'Rewired'],
               loc='upper right',
               fontsize=60,
               frameon=True,
               fancybox=True,
               shadow=True,
               bbox_to_anchor=(1.0, 0.99))

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def get_full_dataset_name(dataset_name):
    """Return ``dataset_name`` with a ``-r1`` suffix appended when it is missing.

    Names that already end in ``-r1`` or start with ``syn-cora-`` are returned
    unchanged; every other name gets ``-r1`` appended verbatim.

    Args:
        dataset_name: Dataset name as given on the command line.

    Returns:
        The normalised dataset name.
    """
    if dataset_name.endswith('-r1') or dataset_name.startswith('syn-cora-'):
        return dataset_name
    # Append directly. str.rstrip('-r1') removes any trailing run of the
    # characters '-', 'r' and '1' (it is not a suffix strip), which used to
    # corrupt names such as 'h0.01' into 'h0.0-r1'.
    return f"{dataset_name}-r1"

def calculate_stats_with_ci(values, confidence=0.95):
    """Return ``(mean, half_width)`` of a two-sided Student-t confidence interval.

    Args:
        values: Sequence of numeric observations.
        confidence: Confidence level of the interval (default 0.95).

    Returns:
        A tuple of the sample mean and the interval half-width. The half-width
        is 0 when fewer than two observations are supplied.
    """
    mean = np.mean(values)
    sample_size = len(values)
    if sample_size <= 1:
        # A single observation carries no spread information.
        return mean, 0

    # Standard error of the mean from the unbiased (ddof=1) standard deviation.
    std_err = np.std(values, ddof=1) / np.sqrt(sample_size)
    # Two-sided critical value with n - 1 degrees of freedom.
    t_val = stats.t.ppf((1 + confidence) / 2, sample_size - 1)
    return mean, t_val * std_err

def log_metrics_with_ci(logger, name, values, percentage=True):
    """Log ``name`` with the mean and confidence interval of ``values``.

    Args:
        logger: Logger that receives the formatted line.
        name: Label printed in front of the statistics.
        values: Observations passed to ``calculate_stats_with_ci``.
        percentage: When true, print two decimals with ``%`` signs; otherwise
            print four decimals without units.

    Returns:
        The ``(mean, ci)`` tuple that was logged.
    """
    mean, ci = calculate_stats_with_ci(values)
    message = (
        f"{name}: {mean:.2f}% ± {ci:.2f}%"
        if percentage
        else f"{name}: {mean:.4f} ± {ci:.4f}"
    )
    logger.info(message)
    return mean, ci

def run_rewiring_analysis(data, nodes_dict, args, device, logger, log_dir, timestamp, 
                         original_scores, edge_add, edge_delete):
    """Rewire the graph with one edge budget and re-run the memorization analysis.

    The graph is rewired through ``modify_graph`` using Louvain communities,
    the memorization pipeline is run once per seed in ``args.seeds``, the
    per-node class probabilities of each seed are written to a CSV file, and
    a comparison plot against ``original_scores`` is saved in ``log_dir``.

    Args:
        data: Original graph data object (``edge_index``, ``y``, ``num_nodes``
            and ``homophily`` are read here).
        nodes_dict: Mapping with ``'shared'``, ``'candidate'`` and
            ``'independent'`` node-index collections.
        args: Parsed CLI options. ``edges_to_add`` and ``edges_to_delete`` are
            overwritten with the budget of this run.
        device: Device used for training and inference.
        logger: Logger for progress and results.
        log_dir: Directory receiving prediction CSVs and the comparison plot.
        timestamp: Forwarded to ``run_memorization_pipeline``.
        original_scores: Aggregated baseline results, used for the plot.
        edge_add: Number of edges to add.
        edge_delete: Number of edges to delete.

    Returns:
        Dict with ``'test_accuracy'`` (``mean``/``ci``), ``'memorization'``
        (per node type: ``mean``, ``ci``, ``std``, ``raw_values``) and
        ``'config'`` (``add``/``delete``).
    """
    # NOTE(review): args is mutated in place; presumably downstream helpers
    # read these fields - confirm before removing.
    args.edges_to_add = edge_add
    args.edges_to_delete = edge_delete

    # NetworkX copy of the original topology for community detection.
    G = nx.Graph()
    G.add_nodes_from(range(data.num_nodes))
    G.add_edges_from(data.edge_index.t().numpy())

    predictions_dir = os.path.join(log_dir, f'predictions_add{edge_add}_del{edge_delete}')
    os.makedirs(predictions_dir, exist_ok=True)

    # Fixed seed keeps the community structure identical across budgets.
    communities = list(nx.community.louvain_communities(G, seed=42))

    original_edge_count = data.edge_index.size(1)
    rewired_data = modify_graph(data, args.dataset, G, communities, edge_add, edge_delete)

    rewired_edge_count = rewired_data.edge_index.size(1)
    edge_diff = rewired_edge_count - original_edge_count
    # Guard the percentage against an edgeless original graph.
    edge_diff_pct = edge_diff / original_edge_count * 100 if original_edge_count else 0.0

    # Edge homophily of the rewired graph (labels are unchanged by rewiring).
    edges = rewired_data.edge_index.t().numpy()
    same_label = (data.y[edges[:, 0]] == data.y[edges[:, 1]]).float()
    rewired_homophily = same_label.mean().item()

    logger.info(f"\nGraph Rewiring Statistics (Add: {edge_add}, Delete: {edge_delete}):")
    logger.info(f"Original edge count: {original_edge_count}")
    logger.info(f"Rewired edge count: {rewired_edge_count}")
    logger.info(f"Net change in edges: {edge_diff:+} ({edge_diff_pct:+.2f}%)")
    logger.info(f"Homophily: {data.homophily:.4f} → {rewired_homophily:.4f}")

    node_types = ['shared', 'candidate', 'independent', 'rate']
    rewired_all_results = {
        'test_accuracies': [],
        'memorization_scores': {node_type: [] for node_type in node_types}
    }

    for seed in args.seeds:
        results = run_memorization_pipeline(
            rewired_data, nodes_dict, args, device, logger,
            log_dir, seed, f"Rewired Graph (Add:{edge_add},Del:{edge_delete},Seed:{seed}):", 
            timestamp
        )

        # The pipeline does not hand back its models, so they are trained
        # again to obtain predictions. Re-seeding first mirrors what the
        # pipeline does before training.
        # NOTE(review): this assumes training is deterministic for a given
        # seed, so the saved predictions match the evaluated model - confirm.
        set_seed(seed)
        models = train_models(
            args, rewired_data, nodes_dict['shared'], nodes_dict['candidate'], 
            nodes_dict['independent'], device, logger, log_dir, seeds=[seed]
        )
        f_model = models[0][-1] if isinstance(models[0], list) else models[0]

        # Inference must run in eval mode so that train-only layers such as
        # dropout do not perturb the saved probabilities.
        f_model.eval()
        with torch.no_grad():
            probs = torch.softmax(
                f_model(rewired_data.x.to(device), rewired_data.edge_index.to(device)),
                dim=1
            )

        probs_np = probs.cpu().numpy()
        true_labels = rewired_data.y.cpu().numpy()

        # One row per node: split membership, label and class probabilities.
        df = pd.DataFrame(index=range(len(true_labels)))
        df['Node_Idx'] = df.index
        df['Is_Train'] = rewired_data.train_mask.cpu().numpy()
        df['Is_Test'] = rewired_data.test_mask.cpu().numpy()
        df['True_Label'] = true_labels
        for class_idx in range(probs_np.shape[1]):
            df[f'Class_{class_idx}_Prob'] = probs_np[:, class_idx]

        csv_path = os.path.join(predictions_dir, f'predictions_seed{seed}.csv')
        df.to_csv(csv_path, index=False)
        logger.info(f"Saved predictions for seed {seed} to {csv_path}")

        rewired_all_results['test_accuracies'].append(results['test_accuracy']['mean'])
        for node_type in node_types:
            if node_type in results['memorization']:
                rewired_all_results['memorization_scores'][node_type].append(
                    results['memorization'][node_type]['mean']
                )

    # Aggregate the per-seed values.
    acc_mean, acc_ci = calculate_stats_with_ci(rewired_all_results['test_accuracies'])
    rewired_scores = {
        'test_accuracy': {'mean': acc_mean, 'ci': acc_ci},
        'memorization': {},
        'config': {'add': edge_add, 'delete': edge_delete}
    }

    for node_type in node_types:
        scores = rewired_all_results['memorization_scores'][node_type]
        if scores:
            mean, ci = calculate_stats_with_ci(scores)
            # The sample standard deviation is undefined (NaN) for a single
            # value, so report 0 when only one seed was run.
            std = np.std(scores, ddof=1) if len(scores) > 1 else 0.0
            rewired_scores['memorization'][node_type] = {
                'mean': mean, 
                'ci': ci,
                'std': std,
                'raw_values': scores
            }

    plot_name = f'memorization_comparison_add{edge_add}_del{edge_delete}.pdf'
    plot_path = os.path.join(log_dir, plot_name)
    plot_comparison(original_scores, rewired_scores, plot_path)

    return rewired_scores

def find_optimal_configuration(configs_results, logger):
    """Pick the best rewiring configuration per category and overall.

    Configurations are ranked lexicographically by: higher test accuracy,
    then lower memorization rate, then lower candidate memorization score,
    then fewer total edge changes.

    Args:
        configs_results: List of result dicts, each with
            ``['test_accuracy']['mean']``, ``['memorization']['rate']['mean']``,
            ``['memorization']['candidate']['mean']`` and
            ``['config']['add' | 'delete']``.
        logger: Accepted for interface compatibility; not used here.

    Returns:
        Dict with keys ``'additions'`` (delete == 0), ``'deletions'``
        (add == 0), ``'balanced'`` (add == delete) and ``'overall'``. A value
        is ``None`` when no configuration falls into that category, including
        ``'overall'`` when ``configs_results`` is empty.
    """
    def score_config(config):
        # Negated terms turn "smaller is better" into "larger is better" so a
        # single max() over the tuple applies every preference in order.
        edge_changes = config['config']['add'] + config['config']['delete']
        return (
            config['test_accuracy']['mean'],
            -config['memorization']['rate']['mean'],
            -config['memorization']['candidate']['mean'],
            -edge_changes,
        )

    categories = {
        'additions': [c for c in configs_results if c['config']['delete'] == 0],
        'deletions': [c for c in configs_results if c['config']['add'] == 0],
        'balanced': [c for c in configs_results
                     if c['config']['add'] == c['config']['delete']],
        # Guarded like the categories above: max() raises on an empty list.
        'overall': list(configs_results),
    }

    return {
        name: max(members, key=score_config) if members else None
        for name, members in categories.items()
    }

def main():
    """Run the whole experiment: baseline analysis, rewiring sweep and report.

    Steps:
      1. Parse the CLI options, normalise the dataset name, set up logging.
      2. Load the dataset and split its training nodes, in index order, into
         shared (first 50%), candidate and independent subsets.
      3. Run the memorization pipeline once per seed on the original graph
         and aggregate the results.
      4. For every budget in ``EDGE_BUDGETS`` evaluate add-only, delete-only
         and balanced rewiring.
      5. Log the baseline, the best configuration per category and overall.
    """
    args = parse_args()

    # Normalised once; every later use reads args.dataset.
    args.dataset = get_full_dataset_name(args.dataset)

    logger, log_dir, timestamp = setup_logging(args)
    logger.info("Starting memorization analysis with graph rewiring...")

    device = torch.device(args.device)

    # ---- Stage 1: original graph ----
    logger.info("\n=== Stage 1: Original Graph Analysis ===")
    data = load_and_process_dataset(args, args.dataset, logger)

    # Split the training nodes: first half shared, the remainder divided
    # between candidate and independent (independent takes any odd leftover).
    train_indices = torch.where(data.train_mask)[0]
    num_nodes = len(train_indices)
    shared_size = int(0.50 * num_nodes)
    split_size = (num_nodes - shared_size) // 2
    candidate_end = shared_size + split_size

    nodes_dict = {
        'shared': train_indices[:shared_size].tolist(),
        'candidate': train_indices[shared_size:candidate_end].tolist(),
        'independent': train_indices[candidate_end:].tolist()
    }

    node_types = ['shared', 'candidate', 'independent', 'rate']
    original_all_results = {
        'test_accuracies': [],
        'memorization_scores': {node_type: [] for node_type in node_types}
    }

    logger.info("\nAnalyzing Original Graph...")
    for seed in args.seeds:
        results = run_memorization_pipeline(
            data, nodes_dict, args, device, logger, 
            log_dir, seed, f"Original Graph (Seed {seed}):", timestamp
        )

        original_all_results['test_accuracies'].append(results['test_accuracy']['mean'])
        for node_type in node_types:
            if node_type in results['memorization']:
                original_all_results['memorization_scores'][node_type].append(
                    results['memorization'][node_type]['mean']
                )

    # Aggregate the per-seed values into mean / CI / std.
    acc_mean, acc_ci = calculate_stats_with_ci(original_all_results['test_accuracies'])
    original_scores = {
        'test_accuracy': {'mean': acc_mean, 'ci': acc_ci},
        'memorization': {}
    }

    for node_type in node_types:
        scores = original_all_results['memorization_scores'][node_type]
        if scores:
            mean, ci = calculate_stats_with_ci(scores)
            # The sample standard deviation is undefined (NaN) for a single
            # value, so report 0 when only one seed was run.
            std = np.std(scores, ddof=1) if len(scores) > 1 else 0.0
            original_scores['memorization'][node_type] = {
                'mean': mean, 
                'ci': ci,
                'std': std,
                'raw_values': scores
            }

    # ---- Stage 2: rewiring sweep ----
    logger.info("\n================== Stage 2: Multiple Edge Budget Analysis ===================================")

    all_configs_results = []

    # (banner, per-budget message, apply budget to additions, to deletions)
    sweep_plan = (
        ("\nRunning addition-only configurations...",
         "\nTesting edge additions: {budget}", True, False),
        ("\nRunning deletion-only configurations...",
         "\nTesting edge deletions: {budget}", False, True),
        ("\nRunning balanced configurations...",
         "\nTesting balanced rewiring: +{budget}/-{budget}", True, True),
    )
    for banner, message, adds_edges, deletes_edges in sweep_plan:
        logger.info(banner)
        for budget in EDGE_BUDGETS:
            logger.info(message.format(budget=budget))
            results = run_rewiring_analysis(
                data, nodes_dict, args, device, logger,
                log_dir, timestamp, original_scores,
                budget if adds_edges else 0,
                budget if deletes_edges else 0
            )
            all_configs_results.append(results)

    best_configs = find_optimal_configuration(all_configs_results, logger)

    # ---- Final report ----
    logger.info("\n================ Final Comparison Across All Configurations ================")

    orig_acc = original_scores['test_accuracy']
    orig_mem = original_scores['memorization']

    def log_metrics(acc, mem):
        # Accuracy is reported with its CI, memorization figures with their std.
        logger.info(f"Test Accuracy: {acc['mean']:.2f}% ± {acc['ci']:.2f}%")
        logger.info(f"Memorization Rate: {mem['rate']['mean']:.2f}% ± {mem['rate']['std']:.2f}%")
        logger.info(f"Candidate Set Memorization Score: {mem['candidate']['mean']:.4f} ± {mem['candidate']['std']:.4f}")

    def log_config(title, delta_title, config):
        # Metrics of one configuration followed by its change vs. the baseline.
        logger.info(title)
        logger.info(f"Edge Changes: +{config['config']['add']}/-{config['config']['delete']}")
        log_metrics(config['test_accuracy'], config['memorization'])

        acc_change = config['test_accuracy']['mean'] - orig_acc['mean']
        mem_rate_change = config['memorization']['rate']['mean'] - orig_mem['rate']['mean']
        mem_score_change = config['memorization']['candidate']['mean'] - orig_mem['candidate']['mean']

        logger.info(delta_title)
        logger.info(f"Test Accuracy: {acc_change:+.2f}%")
        logger.info(f"Memorization Rate: {mem_rate_change:+.2f}%")
        logger.info(f"Candidate Set Memorization Score: {mem_score_change:+.4f}")

    logger.info("\nOriginal Graph Baseline:")
    log_metrics(orig_acc, orig_mem)

    for category in ['additions', 'deletions', 'balanced']:
        if best_configs[category]:
            log_config(
                f"\nBest {category.capitalize()} Configuration:",
                "\nImprovements over baseline:",
                best_configs[category]
            )

    best = best_configs['overall']
    if best is None:
        # No configuration was evaluated, so there is nothing to summarise.
        return
    log_config(
        "\n================ Overall Best Configuration ================",
        "\nOverall Improvements:",
        best
    )
    

# Script entry point: run the full analysis only when executed directly.
if __name__ == '__main__':
    main()
