import torch
import random
import json
import os
from tqdm import tqdm
from train.train import test
from train.train_ogbn import test_batch_ogbn
from train.train_batch import test_batch
from model.utils import clone_base_model, clone_processed_model
from utils import set_seed, get_base_model_path, get_processed_model_path
from ogb.nodeproppred import Evaluator

def get_results_save_path(args):
    """Return the path of the results JSON file for the run described by ``args``.

    The directory is derived from ``args.architecture``, ``args.dataset``,
    ``args.edge_ratio``, ``args.random_seed`` and ``args.pp_method``.
    """
    run_dir_parts = (
        'results',
        args.architecture,
        args.dataset,
        f'{args.edge_ratio}',
        str(args.random_seed),
        f'pp_{args.pp_method}/',
    )
    return os.path.join(*run_dir_parts, 'results.json')

def get_visualization_save_path(args):
    """Return the path of the visualization image for the run described by ``args``.

    The directory is derived from ``args.architecture``, ``args.dataset``,
    ``args.edge_ratio``, ``args.random_seed`` and ``args.pp_method``.
    """
    run_dir_parts = (
        'results',
        args.architecture,
        args.dataset,
        f'{args.edge_ratio}',
        str(args.random_seed),
        f'pp_{args.pp_method}/',
    )
    return os.path.join(*run_dir_parts, 'result.jpg')

def shuffle_and_generate_removal_indices(data, seed):
    """Return a random permutation of the column indices of ``data.edge_index``.

    ``set_seed(seed)`` is called first; presumably it seeds the random number
    generators so the permutation is reproducible (verify against ``utils``).
    The result is a plain Python list of ints.
    """
    set_seed(seed)
    edge_count = data.edge_index.size(1)
    permutation = torch.randperm(edge_count)
    return permutation.tolist()

def remove_edges(data, ratio, is_full):
    """Drop a random fraction of undirected edges from ``data.edge_index``.

    ``data.edge_index`` is first reset to a copy of ``data.original_edge_index``
    (when ``is_full`` is true) or of ``data.backup_edge_index`` (otherwise).
    Columns ``(u, v)`` and ``(v, u)`` are grouped as one undirected edge, and
    ``int(ratio * number_of_undirected_edges)`` groups are selected with
    ``random.sample``; every column of a selected group is removed. The
    selection therefore depends on the state of the global ``random`` module.

    Returns the same ``data`` object with its ``edge_index`` replaced.
    """
    source_edge_index = data.original_edge_index if is_full else data.backup_edge_index
    data.edge_index = source_edge_index.clone()

    total_directed = data.edge_index.size(1)

    # Group the column positions of each directed edge under an
    # orientation-independent key (smaller endpoint first).
    columns_by_undirected_edge = {}
    for column, (u, v) in enumerate(data.edge_index.t().tolist()):
        key = (u, v) if u <= v else (v, u)
        columns_by_undirected_edge.setdefault(key, []).append(column)

    # Sample from the undirected edges in first-seen order.
    candidates = list(columns_by_undirected_edge)
    removal_count = int(ratio * len(candidates))
    chosen = random.sample(candidates, removal_count)

    # Every directed column belonging to a chosen undirected edge is dropped.
    dropped_columns = set()
    for key in chosen:
        dropped_columns.update(columns_by_undirected_edge[key])

    # Start from "keep everything" and clear the dropped positions.
    keep_mask = torch.ones(total_directed, dtype=torch.bool, device=data.edge_index.device)
    if dropped_columns:
        dropped = torch.tensor(sorted(dropped_columns), dtype=torch.long,
                               device=data.edge_index.device)
        keep_mask[dropped] = False

    data.edge_index = data.edge_index[:, keep_mask]
    return data

def restore_edges(data, ratio, is_full):
    """Add back a random fraction of the undirected edges missing from ``data.edge_index``.

    An undirected edge is "missing" when it occurs (in either direction) in
    ``data.original_edge_index`` but not in the current ``data.edge_index``.
    ``int(ratio * number_of_missing_edges)`` of them are selected with
    ``random.sample`` (so the choice depends on the global ``random`` state)
    and appended after the existing columns.

    Args:
        data: object exposing ``edge_index`` and ``original_edge_index``
            tensors; both are read as ``(2, num_edges)`` index tensors.
        ratio: fraction of the missing undirected edges to restore.
        is_full: unused; kept so existing callers keep working.

    Returns:
        The same ``data`` object with ``edge_index`` replaced by the
        extended edge index.
    """
    current_edge_index = data.edge_index.clone()

    current_edges = current_edge_index.t().tolist()
    original_edges = data.original_edge_index.t().tolist()

    # Orientation-independent representation: endpoints sorted.
    current_undirected_edges = set(tuple(sorted([u, v])) for u, v in current_edges)
    original_undirected_edges = set(tuple(sorted([u, v])) for u, v in original_edges)

    missing_undirected_edges = list(original_undirected_edges - current_undirected_edges)
    num_undirected_edges_to_restore = int(ratio * len(missing_undirected_edges))
    undirected_edges_to_restore = random.sample(missing_undirected_edges, num_undirected_edges_to_restore)

    edges_to_restore = []
    for u, v in undirected_edges_to_restore:
        edges_to_restore.append([u, v])
        # A self-loop is its own reverse; adding it twice would create a
        # duplicate edge that the original graph did not contain.
        if u != v:
            edges_to_restore.append([v, u])

    # reshape(-1, 2) keeps the tensor two-dimensional even when nothing is
    # restored, so the concatenation below always joins (2, k) tensors and
    # does not rely on torch.cat's special handling of 1-D empty tensors.
    edges_to_restore_tensor = torch.tensor(
        edges_to_restore, dtype=torch.long, device=current_edge_index.device
    ).reshape(-1, 2).t()

    data.edge_index = torch.cat([current_edge_index, edges_to_restore_tensor], dim=1)
    return data

def test_edge_modification(model, data, args, device, indices_or_sequence, action, ratios):
    """Evaluate ``model`` after modifying the graph edges at each ratio.

    For every value in ``ratios`` the edges of ``data`` are modified, the
    model is evaluated, and ``data.edge_index`` is reset to a copy of
    ``data.backup_edge_index`` before the next ratio.

    Args:
        model: model handed to the dataset-specific evaluation function.
        data: graph object; must expose ``edge_index`` and
            ``backup_edge_index`` (plus whatever ``remove_edges`` /
            ``restore_edges`` read).
        args: run configuration; ``args.dataset`` selects the evaluation
            function and ``args.edge_ratio`` selects which stored edge index
            removals start from.
        device: passed through to the batched evaluation functions.
        indices_or_sequence: not used by this function; kept for interface
            compatibility.
        action: ``"remove"`` or ``"restore"``. Any other value leaves the
            edges untouched, so the unmodified graph is evaluated.
        ratios: iterable of fractions handed to the edge-modification helper.

    Returns:
        A list with one test accuracy per ratio, in iteration order.
    """
    # Loop-invariant setup: the evaluation path depends only on the dataset,
    # and the OGB evaluator does not need to be rebuilt for every ratio.
    use_ogbn_eval = args.dataset in ["ogbn-arxiv"]
    use_batch_eval = args.dataset in ["flickr", "reddit"]
    evaluator = Evaluator(name=args.dataset) if use_ogbn_eval else None

    # With edge_ratio == 1.0 removals start from the original edge index,
    # otherwise from the backup edge index.
    is_full = args.edge_ratio == 1.0

    accuracies = []
    for ratio in tqdm(ratios, desc=f"Edge {action.capitalize()}s"):
        with torch.no_grad():
            if action == "remove":
                data = remove_edges(data, ratio, is_full)
            elif action == "restore":
                data = restore_edges(data, ratio, is_full)

            if use_ogbn_eval:
                train_acc, val_acc, test_acc = test_batch_ogbn(model, data, evaluator, device)
                print(train_acc, val_acc, test_acc)
            elif use_batch_eval:
                _, _, test_acc = test_batch(model, data, device)
            else:
                _, _, test_acc = test(model, data)

            torch.cuda.empty_cache()
            # Undo this ratio's modification so every ratio starts from the
            # same backup graph.
            data.edge_index = data.backup_edge_index.clone().detach()
        accuracies.append(test_acc)
    return accuracies

def _run_remove_and_restore(model, data, args, device, removal_indices, ratios):
    """Return ``(removal accuracies, restoration accuracies)`` for one model."""
    remove_acc = test_edge_modification(model, data, args, device, removal_indices, "remove", ratios)
    if args.edge_ratio < 1.0:  # Only perform restoration if edges were removed
        restore_acc = test_edge_modification(
            model, data, args, device,
            (data.original_edge_index, data.edge_index), "restore", ratios,
        )
    else:
        # Placeholder if no restoration: repeat the first removal accuracy.
        restore_acc = [remove_acc[0]] * len(ratios)
    return remove_acc, restore_acc


def run_experiment(data, params, device, args, n, i, seed):
    """Run the edge removal/restoration evaluation for one iteration and save it.

    The base model and each of the ``args.process_num`` processed models are
    loaded and evaluated at the ratios ``0/n, 1/n, ..., n/n``. The per-model
    accuracy lists are written to the JSON file given by
    ``get_results_save_path(args)``.

    Args:
        data: graph object; ``data.backup_edge_index`` is set here to a copy
            of the incoming ``data.edge_index``.
        params: passed through to the model-cloning helpers.
        device: passed through to model loading and evaluation.
        args: run configuration (``process_num``, ``edge_ratio`` and the
            fields used to build model and result paths).
        n: number of ratio steps; must be non-zero because ratios are
            computed as ``step / n``.
        i: iteration index. It selects the model checkpoints; when it is
            ``0`` a fresh results file is written, otherwise the new results
            are appended to the existing file.
        seed: seed handed to ``shuffle_and_generate_removal_indices``.
    """
    # Key order is preserved in the saved JSON: processed models first, then base.
    new_results = {
        **{f"processed_{k}_{action}": [] for k in range(args.process_num) for action in ["remove", "restore"]},
        **{f"base_{action}": [] for action in ["remove", "restore"]}
    }

    ratios = [step / n for step in range(n + 1)]

    # Also seeds the run via set_seed inside the helper, so this call must
    # stay before any edge sampling.
    removal_indices = shuffle_and_generate_removal_indices(data, seed=seed)

    # Base model experiment
    model_path = get_base_model_path(args, i)
    model = clone_base_model(data, model_path, params, device, args)

    print("Running base model experiment...")
    # Snapshot the incoming graph; evaluations reset edge_index to this copy.
    data.backup_edge_index = data.edge_index.clone()

    remove_acc, restore_acc = _run_remove_and_restore(model, data, args, device, removal_indices, ratios)
    new_results["base_remove"].append(remove_acc)
    new_results["base_restore"].append(restore_acc)

    print("Base model experiment completed.")
    del model

    # Processed models experiments
    for j in range(args.process_num):
        model_path = get_processed_model_path(args, i, j)
        model = clone_processed_model(data, model_path, params, device, args)

        print(f"Running processed_{j} model experiment...")

        remove_acc, restore_acc = _run_remove_and_restore(model, data, args, device, removal_indices, ratios)
        new_results[f"processed_{j}_remove"].append(remove_acc)
        new_results[f"processed_{j}_restore"].append(restore_acc)

        print(f"Processed_{j} model experiment completed.")
        del model

    # Save or append results
    results_save_path = get_results_save_path(args)
    results = {}
    if i != 0:
        # Later iterations extend the file written by earlier ones. A missing
        # file must not discard the results that were just computed, so fall
        # back to starting a new file.
        try:
            with open(results_save_path, 'r') as f:
                results = json.load(f)
        except FileNotFoundError:
            print(f"No existing results found at {results_save_path}; starting a new results file.")
    for key, runs in new_results.items():
        results.setdefault(key, []).extend(runs)

    # Save the updated results
    os.makedirs(os.path.dirname(results_save_path), exist_ok=True)
    with open(results_save_path, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"Experiment completed. Results saved to {results_save_path}.")


